# The dojo

## Set up the training device

In [None]:
# Optional bootstrap: uncomment to turn the current directory into a checkout of the
# chrome-vision repository. Replace <branch> with the branch to train from.
# WARNING: `git reset --hard FETCH_HEAD` overwrites local changes to tracked files.
# !git init
# !git remote add origin https://github.com/andreasalstrup/chrome-vision.git
# !git pull origin <branch>
# !git reset --hard FETCH_HEAD

In [None]:
# Print NVIDIA GPU / driver status via the shell (only informative on machines with an NVIDIA driver).
!nvidia-smi

In [None]:
import torch
from torch import nn
import torchvision
import torchvision.io as io

# Seed torch's global RNG (all devices) so weight init, shuffling and random
# augmentations are repeatable across runs.
torch.manual_seed(42)

print(f'PyTorch version: {torch.__version__}\ntorchvision version: {torchvision.__version__}')

# Plain device-name string; later cells pass it to `.to(device)` and to the train/test helpers.
# `torch.device` is intentionally not imported by name, so this assignment shadows nothing.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Using device: {device}')

## Hyperparameters

In [None]:
import torchvision.models as models

# ---- Data loading -----------------------------------------------------------
BATCH_SIZE = 32            # samples per DataLoader batch
IMAGE_RESIZE = 64          # square side length (pixels) used by the augmentation pipeline

# ---- Model: ChromeMoCo constructor arguments --------------------------------
OUT_FEATURES = 128         # passed as feature_dim (presumably the embedding size — TODO confirm)
QUEUE_SIZE = 2 ** 16       # 65536; passed as queue_size (presumably # of queued negative keys — TODO confirm)
MOMENTUM = 0.9             # passed as momentum (presumably the key-encoder update coefficient — TODO confirm)
SOFTMAX_TEMP = 0.07        # passed as softmax_temp (temperature applied to the logits — TODO confirm)

# ---- Encoder ----------------------------------------------------------------
# The constructor itself, not an instance; presumably instantiated inside ChromeMoCo.
ENCODER = models.resnet50

# ---- Optimizer (class + keyword arguments used in the model cell) -----------
OPTIMIZER = torch.optim.Adam
LEARNING_RATE = 0.001
BETAS = (0.9, 0.999)
EPS = 1e-8
WEIGHT_DECAY = 1e-5

# ---- Loss function ----------------------------------------------------------
LOSS_FN = nn.CrossEntropyLoss()

# ---- Training loop ----------------------------------------------------------
EPOCHS = 200

## Data loading

### Combine train data

In [None]:
from utils.merge_dir import MergeDir

# Merge the training folder (which has sub-folders such as bremen/) into one directory,
# together with a CSV index.
# NOTE(review): semantics inferred from the helper's name and arguments — confirm in
# utils/merge_dir.py, including whether re-running this cell is safe.
dst_index = 'data/leftImg8bit/indices/trainIndex/combined.csv'   # index CSV for the merged set
dst_dir = 'data/leftImg8bit/train_combined'                      # merged output folder
src_dir = 'data/leftImg8bit/train'                               # source training folder

MergeDir(src_dir,
         dst_dir,
         dst_index)

### Use `chromeCutter` to prepare the cut datasets

In [None]:
from dataset import chromeCutter

# Build the "combinedCut" training set from the merged images listed in combined.csv.
# The data-loading cell reads it back from indices/trainIndex/combinedCut.csv and
# train/cut/combinedCut. NOTE(review): exact cutting behaviour lives in dataset.py — confirm there.
chromeCutter(
    "data/leftImg8bit/indices/trainIndex/combined.csv",
    "data/leftImg8bit/train_combined/",
    "combinedCut",
    "train",
)

# Alternative inputs (bremen train split / berlin test split):
# chromeCutter("data/leftImg8bit/indices/trainIndex/bremen.csv", "data/leftImg8bit/train/bremen/", "bremenCut", "train")
# chromeCutter("data/leftImg8bit/indices/testIndex/berlin.csv", "data/leftImg8bit/test/berlin/", "berlinCut", "test")

### Custom dataset

In [None]:
from torchvision import transforms
from utils.transforms import ContrastiveTransform

# MoCo v1-style augmentation pipeline, wrapped in ContrastiveTransform.
# NOTE(review): ContrastiveTransform presumably applies the pipeline twice to yield a
# (query, key) pair of views, and the input is presumably an image tensor (hence
# ToPILImage) — TODO confirm both in utils/transforms.py and dataset.py.
transform_MoCoV1 = ContrastiveTransform(
    transforms.Compose([
        transforms.ToPILImage(),                                        # tensor/ndarray -> PIL image
        transforms.Resize((IMAGE_RESIZE, IMAGE_RESIZE)),                # fixed square resize first
        transforms.RandomResizedCrop(IMAGE_RESIZE, scale=(0.2, 1.0)),   # MoCo v1 crops 224 px; IMAGE_RESIZE here
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),                     # brightness, contrast, saturation, hue
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],                # standard ImageNet channel stats
                             std=[0.229, 0.224, 0.225]),
    ])
)

In [None]:
from torch.utils.data import DataLoader
from dataset import CustomImageDataset

# Training set: the "combinedCut" crops produced by the chromeCutter cell.
train_data = CustomImageDataset("data/leftImg8bit/indices/trainIndex/combinedCut.csv","data/leftImg8bit/train/cut/combinedCut", transform=transform_MoCoV1)
#train_data = CustomImageDataset("data/leftImg8bit/indices/trainIndex/bremenCut.csv","data/leftImg8bit/train/cut/bremenCut", transform=transform_MoCoV1)

# NOTE: no held-out split is prepared yet (the berlin chromeCutter call is commented out),
# so the "test" loader below is deliberately built from the TRAINING data and every
# "test" metric is a training-set metric. Once the test split is cut, replace with:
#   test_data = CustomImageDataset("data/leftImg8bit/indices/testIndex/berlinCut.csv",
#                                  "data/leftImg8bit/test/cut/berlinCut", transform=transform_MoCoV1)
# transform_MoCoV1 is already a ContrastiveTransform — do not wrap it a second time.
test_data = train_data

# drop_last=True keeps every training batch at exactly BATCH_SIZE
# (presumably needed by ChromeMoCo's fixed-size queue update — TODO confirm).
train_dataloader = DataLoader(train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              drop_last=True)

test_dataloader = DataLoader(test_data,
                             batch_size=BATCH_SIZE,
                             shuffle=False)

print(f'Len of train dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}')
print(f'Len of test dataloader: {len(test_dataloader)} batches of {BATCH_SIZE}')

# With drop_last=True a dataset smaller than one batch yields zero training batches;
# fail here with a clear message instead of deep inside the training loop.
assert len(train_dataloader) > 0, (
    f'train_data has {len(train_data)} samples, fewer than BATCH_SIZE={BATCH_SIZE}; '
    'with drop_last=True the training loader yields no batches'
)

#### Inspect one batch from the training dataloader

In [None]:
# Pull one batch from the shuffled training loader and report the shapes of its first
# two entries: the query view and the key view.
train_features_batch = next(iter(train_dataloader))

print(f"Train features query_image shape: {train_features_batch[0].shape}")
print(f"Train features key_image shape: {train_features_batch[1].shape}")

## Model

#### MoCo

In [None]:
from model.chrome_vision import ChromeMoCo

# Build the MoCo model from the hyperparameter cell and move it to the selected device.
model = ChromeMoCo(
    base_encoder=ENCODER,
    feature_dim=OUT_FEATURES,
    queue_size=QUEUE_SIZE,
    momentum=MOMENTUM,
    softmax_temp=SOFTMAX_TEMP,
).to(device)

# Optimizer over every parameter the model exposes. The parameter list must stay
# identical across runs so that optimizer state restored from a checkpoint lines up.
optimizer = OPTIMIZER(
    params=model.parameters(),
    lr=LEARNING_RATE,
    betas=BETAS,
    eps=EPS,
    weight_decay=WEIGHT_DECAY,
)

### Training loop - MoCo

In [None]:
import os, os.path as path
from timeit import default_timer as timer
from tqdm.auto import tqdm
from model.evaluation import train_step, test_step # use torchmetrics.utilities.data.select_topk
from model.utilis import print_train_time, accuracy_top_k, saveModel, saveCheckpoint, loadCheckpoint
from utils.show_training import ShowTraining

showTraining = ShowTraining()
time_start = timer()

# Resumable training:
# - CHECKPOINT holds the state after the last completed epoch.
# - CHECKPOINT_BACKUP is rewritten at the start of every epoch with the previous epoch's
#   state (epoch - 1), so one older copy exists while CHECKPOINT is being rewritten.
checkpoint_epoch = -1
CHECKPOINT = 'model/models/checkpoint.pt'
CHECKPOINT_BACKUP = 'model/models/checkpoint_backup.pt'

# Make sure the checkpoint directory exists before the first save.
os.makedirs(path.dirname(CHECKPOINT), exist_ok=True)

if path.exists(CHECKPOINT):
    # Resume: restores model, optimizer, last completed epoch, training history and train loader.
    (model, optimizer, checkpoint_epoch, showTraining, train_dataloader) = loadCheckpoint(CHECKPOINT, model, optimizer)

for epoch in tqdm(range(EPOCHS)):
    # Skip epochs already covered by the loaded checkpoint.
    if epoch <= checkpoint_epoch:
        continue

    saveCheckpoint(CHECKPOINT_BACKUP, model, optimizer, epoch - 1, showTraining, train_dataloader)

    print(f'\n\tEpoch: {epoch}\n')

    (loss, top1, top5) = train_step(model=model,
                                    data_loader=train_dataloader,
                                    loss_fn=LOSS_FN,
                                    optimizer=optimizer,
                                    accuracy_fn=accuracy_top_k,
                                    device=device)

    # Per-epoch evaluation (disabled). Note: test_dataloader currently wraps train_data.
    # test_step(model=model,
    #            data_loader=test_dataloader,
    #            loss_fn=LOSS_FN,
    #            accuracy_fn=accuracy_top_k,
    #            device=device)

    showTraining.appendData(loss, top1, top5)
    showTraining.draw_curve(epoch)
    saveCheckpoint(CHECKPOINT, model, optimizer, epoch, showTraining, train_dataloader)

NAME = f"{model.__class__.__name__}_BatchSize{BATCH_SIZE}_LR{LEARNING_RATE}_ImageSize{IMAGE_RESIZE}_Epochs{EPOCHS}"
saveModel("model/models", f"{NAME}.pt", model)
showTraining.saveFig(f"model/models/{NAME}.png")

# Training finished and the final model is saved, so the resume files are no longer needed.
# Remove each one independently: either may be absent, and an unconditional
# os.remove on a missing file raises FileNotFoundError.
for leftover_checkpoint in (CHECKPOINT, CHECKPOINT_BACKUP):
    if path.exists(leftover_checkpoint):
        os.remove(leftover_checkpoint)

# Print time taken (this session only; time spent before a resume is not included).
time_end = timer()
total_train = print_train_time(time_start, time_end, str(next(model.parameters()).device))

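### Evaluate the model on the test dataloader

Note: `test_dataloader` is currently built from the training data, so these are training-set metrics until a held-out split is wired in.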

In [None]:
# Run the evaluation helper over test_dataloader and display its result.
# NOTE(review): test_dataloader is built from train_data in the data-loading cell,
# so these are training-set metrics until a held-out split is used.
model_results = test_step(
    model=model,
    data_loader=test_dataloader,
    loss_fn=LOSS_FN,
    accuracy_fn=accuracy_top_k,
    device=device,
)

model_results