# The dojo

## Set up the training device

In [None]:
# Optional: turn this working directory into a checkout of the chrome-vision repository
# (uncomment on a fresh machine). Replace <branch> with the branch to check out.
# NOTE: `git reset --hard` discards any local changes in the working tree.
# !git init
# !git remote add origin https://github.com/andreasalstrup/chrome-vision.git
# !git pull origin <branch>
# !git reset --hard FETCH_HEAD

In [None]:
# Show the NVIDIA GPUs visible to this machine, with driver version and memory usage.
!nvidia-smi

In [None]:
import torch
from torch import nn
import torchvision
import torchvision.io as io

# Seeding is disabled, so training runs are not reproducible; uncomment to fix torch's RNG.
#torch.manual_seed(42)

print(f'PyTorch version: {torch.__version__}\ntorchvision version: {torchvision.__version__}')

# Use the GPU when one is available. `device` is a plain string ("cuda" or "cpu");
# later cells pass it to `.to(...)`, the train/test helpers and `torch.device(...)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Using device: {device}')

# Hyperparameters

In [None]:
import torchvision.models as models

# ---- Data loading --------------------------------------------------------
BATCH_SIZE = 256          # samples per batch for every DataLoader in this notebook
IMAGE_RESIZE = 64         # side length (pixels) of the square crops fed to the encoder

# ---- Model: passed straight into the ChromeMoCo / ChromeMoCoV2 constructors ----
OUT_FEATURES = 128        # -> feature_dim
QUEUE_SIZE = 65_536       # -> queue_size
MOMENTUM = 0.9            # -> momentum
SOFTMAX_TEMP = 0.07       # -> softmax_temp

# ---- Encoder -------------------------------------------------------------
ENCODER = models.resnet50  # the constructor itself (not an instance) -> base_encoder

# ---- Optimizer -----------------------------------------------------------
OPTIMIZER = torch.optim.Adam
LEARNING_RATE = 1e-3
ADJUST_LEARNING_RATE = False   # when True, the training loop calls adjust_learning_rate every epoch
BETAS = (0.9, 0.999)
EPS = 1e-8
WEIGHT_DECAY = 0.00001

# ---- Loss ----------------------------------------------------------------
LOSS_FN = nn.CrossEntropyLoss()

# ---- Training loop -------------------------------------------------------
EPOCHS = 400

## Data loading

### Combine train data

In [None]:
from chrome_utils.merge_dir import MergeDir

# Merge the training split into a single directory and write a CSV index for it
# (the exact layout is defined by chrome_utils.merge_dir.MergeDir).
src_dir, dst_dir, dst_index = (
    'data/leftImg8bit/train',
    'data/leftImg8bit/train_combined',
    'data/leftImg8bit/indices/trainIndex/combined.csv',
)

MergeDir(src_dir, dst_dir, dst_index)

### Combine test data

In [None]:
from chrome_utils.merge_dir import MergeDir

# Same merge as for the training split, applied to the test split.
src_dir, dst_dir, dst_index = (
    'data/leftImg8bit/test',
    'data/leftImg8bit/test_combined',
    'data/leftImg8bit/indices/testIndex/combined.csv',
)

MergeDir(src_dir, dst_dir, dst_index)

### Use ChromeCut to prepare the train and test datasets

In [None]:
from chrome_cut import ChromeCut

cutter = ChromeCut()

# One row per split:
# (combined annotations CSV, merged image dir, output name,
#  output image dir, directory receiving the new annotations CSV).
# The dataset cell below reads "<name>.csv" from that directory and images from the output dir.
_cut_jobs = (
    ("data/leftImg8bit/indices/trainIndex/combined.csv",
     "data/leftImg8bit/train_combined/",
     "trainCombinedCut",
     "data/leftImg8bit/train/cut/trainCombinedCut",
     "data/leftImg8bit/indices/trainIndex"),
    ("data/leftImg8bit/indices/testIndex/combined.csv",
     "data/leftImg8bit/test_combined/",
     "testCombinedCut",
     "data/leftImg8bit/test/cut/testCombinedCut",
     "data/leftImg8bit/indices/testIndex"),
)

for annotations_csv, source_dir, cut_name, cut_dir, index_dir in _cut_jobs:
    cutter.CutImagesInFolder(annotations_file=annotations_csv,
                             img_dir=source_dir,
                             name=cut_name,
                             new_img_dir=cut_dir,
                             new_annotations_file_location=index_dir)

### Custom dataset

In [None]:
from torchvision import transforms
from chrome_utils.transforms import ContrastiveTransform

# Augmentation pipeline (same recipe as MoCo v1, but at IMAGE_RESIZE instead of 224 px).
# Normalisation uses the standard ImageNet per-channel mean/std.
_moco_v1_steps = [
    transforms.ToPILImage(),
    transforms.Resize((IMAGE_RESIZE, IMAGE_RESIZE)),
    transforms.RandomResizedCrop(IMAGE_RESIZE, scale=(0.2, 1.0)),
    transforms.RandomGrayscale(p=0.2),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# ContrastiveTransform wraps the pipeline; presumably it applies it twice per image to
# produce the query/key pair inspected further down — confirm in chrome_utils.transforms.
transform_MoCoV1 = ContrastiveTransform(transforms.Compose(_moco_v1_steps))

In [None]:
from torch.utils.data import DataLoader
from custom_image_dataset import CustomImageDataset

# Datasets over the cut images and CSV indices produced by the ChromeCut cell.
TRAIN_DATA = CustomImageDataset("data/leftImg8bit/indices/trainIndex/trainCombinedCut.csv",
                                "data/leftImg8bit/train/cut/trainCombinedCut",
                                transform=transform_MoCoV1)
TEST_DATA = CustomImageDataset("data/leftImg8bit/indices/testIndex/testCombinedCut.csv",
                               "data/leftImg8bit/test/cut/testCombinedCut",
                               transform=transform_MoCoV1)

# Both loaders drop the trailing partial batch, so every batch holds exactly BATCH_SIZE
# samples; only the training loader shuffles.
TRAIN_DATALOADER = DataLoader(TRAIN_DATA, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
TEST_DATALOADER = DataLoader(TEST_DATA, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

for _split_name, _loader in (('train', TRAIN_DATALOADER), ('test', TEST_DATALOADER)):
    print(f'Len of {_split_name} dataloader: {len(_loader)} batches of {BATCH_SIZE}')

#### Inspect one batch from the training dataloader

In [None]:
# Pull a single batch to sanity-check tensor shapes.
# Element 0 is labelled the query image and element 1 the key image.
train_features_batch = next(iter(TRAIN_DATALOADER))
query_shape, key_shape = train_features_batch[0].shape, train_features_batch[1].shape

print(f"Train features query_image shape: {query_shape}")
print(f"Train features key_image shape: {key_shape}")

## Model

#### ChromeMoCo

In [None]:
from model.chrome_moco import ChromeMoCo

# Build ChromeMoCo from the hyperparameter cell. Run either this cell or the ChromeMoCoV2
# cell: whichever runs last defines `model` and `optimizer` for the training loop.
model = ChromeMoCo(
    base_encoder=ENCODER,
    feature_dim=OUT_FEATURES,
    queue_size=QUEUE_SIZE,
    momentum=MOMENTUM,
    softmax_temp=SOFTMAX_TEMP,
).to(device)

optimizer = OPTIMIZER(
    params=model.parameters(),
    lr=LEARNING_RATE,
    betas=BETAS,
    eps=EPS,
    weight_decay=WEIGHT_DECAY,
)

#### ChromeMoCoV2

In [None]:
from model.chrome_moco_v2 import ChromeMoCoV2

# NOTE(review): this rebinds the *global* OUT_FEATURES. Later cells also read it: the saved
# model name in the training loop and the ChromeMoCo built for evaluation. Re-run the
# hyperparameter cell before switching back to ChromeMoCo.
# The meaning of the two entries is defined by ChromeMoCoV2 — TODO confirm there.
OUT_FEATURES = (1000, 10)

_moco_v2_kwargs = dict(base_encoder=ENCODER,
                       feature_dim=OUT_FEATURES,
                       queue_size=QUEUE_SIZE,
                       momentum=MOMENTUM,
                       softmax_temp=SOFTMAX_TEMP)
model = ChromeMoCoV2(**_moco_v2_kwargs).to(device)

optimizer = OPTIMIZER(params=model.parameters(), lr=LEARNING_RATE, betas=BETAS,
                      eps=EPS, weight_decay=WEIGHT_DECAY)

### Training loop - MoCo

In [None]:
import os, os.path as path
from timeit import default_timer as timer
from tqdm.auto import tqdm
from model.evaluation import train_step, test_step, adjust_learning_rate # use torchmetrics.utilities.data.select_topk
from chrome_utils.model_utils import print_train_time, accuracy_top_k, saveModel, saveCheckpoint, loadCheckpoint
from chrome_utils.show_progress import ShowProgress

# Evaluate on the test split every TEST_EVERY epochs, and always after the final epoch.
TEST_EVERY = 10

# Setup progress curves
show_training = ShowProgress('Train')
show_testing = ShowProgress('Test')
time_start = timer()

# Setup/Load checkpoint.
# CHECKPOINT is written after each finished epoch. CHECKPOINT_BACKUP is written at the start
# of each epoch with the previous epoch's state, presumably so a usable copy survives a crash
# while CHECKPOINT itself is being written.
checkpoint_epoch = -1
CHECKPOINT = 'model/models/checkpoint.pt'
CHECKPOINT_BACKUP = 'model/models/checkpoint_backup.pt'

if path.exists(CHECKPOINT):
    # Resume: restores model/optimizer, last finished epoch, progress curves and dataloaders.
    (model, optimizer, checkpoint_epoch, show_training, show_testing, TRAIN_DATALOADER, TEST_DATALOADER) = loadCheckpoint(CHECKPOINT, model, optimizer)

# Train loop
for epoch in tqdm(range(EPOCHS)):

    # Skip epochs already trained (covered by the resumed checkpoint).
    if epoch <= checkpoint_epoch:
        continue

    print(f'\n\tEpoch: {epoch}\n')

    saveCheckpoint(CHECKPOINT_BACKUP,
                   model,
                   optimizer,
                   epoch - 1,
                   show_training,
                   show_testing,
                   TRAIN_DATALOADER,
                   TEST_DATALOADER)

    if ADJUST_LEARNING_RATE:
        # NOTE(review): this passes the optimizer's *current* lr, not LEARNING_RATE. If
        # adjust_learning_rate derives the lr from a base value (as MoCo's cosine/step
        # schedule does), the decay compounds every epoch — confirm what it expects.
        current_learning_rate = optimizer.param_groups[0]['lr']
        optimizer = adjust_learning_rate(optimizer, epoch, EPOCHS, current_learning_rate)
        print(optimizer.param_groups[0]['lr'])

    # Train
    (train_loss, train_top1, train_top5) = train_step(model=model,
                                                     data_loader=TRAIN_DATALOADER,
                                                     loss_fn=LOSS_FN,
                                                     optimizer=optimizer,
                                                     accuracy_fn=accuracy_top_k,
                                                     device=device)
    # Draw train curve
    print(f'Train loss: {train_loss:.5f} | Train acc1: {train_top1:.2f}% | Train acc5: {train_top5:.2f}%')
    show_training.appendData(train_loss, train_top1, train_top5)
    show_training.draw_curve(epoch)

    # Test periodically, and on the last epoch so the saved test curve ends at the final model.
    if epoch % TEST_EVERY == 0 or epoch == EPOCHS - 1:
        (test_loss, test_top1, test_top5) = test_step(model=model,
                                                    data_loader=TEST_DATALOADER,
                                                    loss_fn=LOSS_FN,
                                                    accuracy_fn=accuracy_top_k,
                                                    device=device)

        # Draw test curve
        print(f'Test loss: {test_loss:.5f} | Test acc1: {test_top1:.2f}% | Test acc5: {test_top5:.2f}%')
        show_testing.appendData(test_loss, test_top1, test_top5)
        show_testing.draw_curve(epoch)

    saveCheckpoint(CHECKPOINT,
                   model,
                   optimizer,
                   epoch,
                   show_training,
                   show_testing,
                   TRAIN_DATALOADER,
                   TEST_DATALOADER)

# Save model and curves
NAME = f"{model.__class__.__name__}_BatchSize{BATCH_SIZE}_OutFeat{OUT_FEATURES}_LR{LEARNING_RATE}_Adj{ADJUST_LEARNING_RATE}_ImageSize{IMAGE_RESIZE}_Epochs{EPOCHS}"
saveModel("model/models", f"{NAME}.pt", model)
show_training.saveFig(f"model/models/{NAME}_train.png")
show_testing.saveFig(f"model/models/{NAME}_test.png")

# Remove checkpoints now that the final model is saved. Each file is checked on its own:
# the backup can be missing while the main checkpoint exists (and vice versa).
for checkpoint_file in (CHECKPOINT, CHECKPOINT_BACKUP):
    if path.exists(checkpoint_file):
        os.remove(checkpoint_file)

# Print time taken
time_end = timer()
total_train = print_train_time(time_start, time_end, str(next(model.parameters()).device))

# Evaluation

### Combine validation data

In [None]:
from chrome_utils.merge_dir import MergeDir

# Same merge as for the train/test splits, applied to the validation split.
src_dir, dst_dir, dst_index = (
    'data/leftImg8bit/val',
    'data/leftImg8bit/val_combined',
    'data/leftImg8bit/indices/valIndex/combined.csv',
)

MergeDir(src_dir, dst_dir, dst_index)

### Use ChromeCut to prepare the validation dataset

In [None]:
from chrome_cut import ChromeCut

cutter = ChromeCut()

# Cut the merged validation images. The validation loader below reads
# "<name>.csv" from new_annotations_file_location and images from new_img_dir.
val_cut_kwargs = {
    "annotations_file": "data/leftImg8bit/indices/valIndex/combined.csv",
    "img_dir": "data/leftImg8bit/val_combined/",
    "name": "valCombinedCut",
    "new_img_dir": "data/leftImg8bit/val/cut/valCombinedCut",
    "new_annotations_file_location": "data/leftImg8bit/indices/valIndex",
}
cutter.CutImagesInFolder(**val_cut_kwargs)

### Load dataset

In [None]:
from torch.utils.data import DataLoader
from custom_image_dataset import CustomImageDataset

# NOTE(review): validation reuses the random training augmentations (transform_MoCoV1),
# so evaluation scores vary between runs; drop_last=True also skips the final partial batch.
VAL_DATA = CustomImageDataset(
    "data/leftImg8bit/indices/valIndex/valCombinedCut.csv",
    "data/leftImg8bit/val/cut/valCombinedCut",
    transform=transform_MoCoV1,
)

VAL_DATALOADER = DataLoader(VAL_DATA, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

### Evaluate models

In [None]:
import os
import matplotlib.pyplot as plt
from model.chrome_moco import ChromeMoCo
from model.evaluation import evaluate_models
from chrome_utils.model_utils import accuracy_top_k

# Template model that evaluate_models uses for every saved model in DIR.
# NOTE(review): OUT_FEATURES is read from the global namespace; if the ChromeMoCoV2 cell
# has run, it is (1000, 10) here rather than the hyperparameter cell's value.
model = ChromeMoCo(base_encoder=ENCODER,
                  feature_dim=OUT_FEATURES,
                  queue_size=QUEUE_SIZE,
                  momentum=MOMENTUM,
                  softmax_temp=SOFTMAX_TEMP).to(device)

# Evaluate models in directory
DIR = "model/models/"
NAME = "eval.png"
df = evaluate_models(model=model,
                     models_dir=DIR,
                     data_loader=VAL_DATALOADER,
                     loss_fn=LOSS_FN,
                     accuracy_fn=accuracy_top_k,
                     device=device)

# Plot results: one group of horizontal bars per model (rebinding instead of inplace mutation).
df = df.set_index('model_name')
ax = df.plot(kind='barh', figsize=(10, 8))
ax.set_xlabel('Accuracy / Loss')

# ax.text positions are in data coordinates, so a bar of value v ends at x = v.
# Accuracy labels (percentages) go just inside their bar; the loss label goes just past its bar.
for i, (acc1, acc5, loss) in enumerate(zip(df['model_acc1'], df['model_acc5'], df['model_loss'])):
    ax.text(acc1 - 5, i, f'{round(acc1, 2)}%', color='white', fontweight='bold')
    ax.text(acc5 - 5, i + 0.15, f'{round(acc5, 2)}%', color='white', fontweight='bold')
    ax.text(loss + 1, i - 0.2, f'{round(loss, 2)}', color='black', fontweight='bold')

ax.figure.savefig(os.path.join(DIR, NAME), bbox_inches='tight')

### Example of how to run ChromeVision

In [None]:
import chromevision as cVision
import torch
import cv2
from chrome_cut import ChromeCut  # used below; previously only imported by earlier cells

# Weights loaded into `model` (the ChromeMoCo built in the evaluation cell) and the demo image.
# NOTE(review): this file name says OutFeat10 / Epochs200, while the hyperparameter cell sets
# OUT_FEATURES = 128 and EPOCHS = 400. load_state_dict likely fails on a feature_dim
# mismatch — confirm which checkpoint is intended.
MODEL_PATH = "model/models/ChromeMoCo_BatchSize256_OutFeat10_LR0.001_AdjFalse_ImageSize64_Epochs200.pt"
IMAGE_PATH = "data/leftImg8bit/val/munster/munster_000039_000019_leftImg8bit.png"

#The model used for object detection
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(device)))
chromeModel = cVision.Chromevision(model, ChromeCut())
identifiedImage = chromeModel.identify(IMAGE_PATH)

# cv2.imshow opens a native window; waitKey(0) blocks until a key is pressed.
cv2.imshow('Result', identifiedImage)
cv2.waitKey(0)
cv2.destroyAllWindows()