In [None]:
import os
import pathlib

import numpy as np
import pandas as pd

import torch
import torchvision

from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Resize, Lambda, Compose

import deep_fashion
import backbones, heads, models

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

import utils

from time import time

In [1]:
from datetime import datetime

In [2]:
# Timestamp of this run, formatted like '12-07-2023--19:06:12'.
# NOTE(review): the original right-hand side was missing; the day-month-year
# field order is inferred from the recorded cell output — confirm.
now = datetime.now().strftime("%d-%m-%Y--%H:%M:%S")

In [4]:
now

'12-07-2023--19:06:12'

In [None]:
# Physical GPU indices this notebook is allowed to use.
device_idxs = [3, 4, 5, 6]

# Order devices by PCI bus ID and expose only the selected GPUs to CUDA.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": ",".join(map(str, device_idxs)),
})

# Handle for the first visible GPU, plus a generic CUDA device handle.
first_device = torch.device("cuda", 0)
device = torch.device("cuda")

In [None]:
# Presumably prints the current memory usage of the listed GPUs — see utils.
utils.print_memory_usage(device_idxs)

# Write the CUDA allocator summary to log.txt (the file is overwritten each run).
with open("log.txt", "w") as log_file:
    log_file.write(torch.cuda.memory_summary() + "\n")

---

In [None]:
# Run-level bookkeeping: later cells record per-stage epoch counts here and
# store the dict inside every checkpoint.
training_metadata = dict()

In [None]:
# Image preprocessing shipped with the default ResNet50 weights, with
# antialiasing switched on.
ctsrbm_image_transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()
ctsrbm_image_transform.antialias = True

ctsrbm_dataset_dir = os.path.join(pathlib.Path.home(), "data", "DeepFashion", "Consumer-to-shop Clothes Retrieval Benchmark")

ctsrbm_dataset = deep_fashion.ConsToShopClothRetrBM(ctsrbm_dataset_dir, ctsrbm_image_transform)

# Ratio handed to utils.cutdown_list for every split
# (presumably the fraction of indices that is kept — confirm in utils).
cutdown_ratio = 0.02

ctsrbm_train_dataset = Subset(ctsrbm_dataset, utils.cutdown_list(ctsrbm_dataset.get_split_mask_idxs("train"), cutdown_ratio))
ctsrbm_test_dataset = Subset(ctsrbm_dataset, utils.cutdown_list(ctsrbm_dataset.get_split_mask_idxs("test"), cutdown_ratio))
ctsrbm_val_dataset = Subset(ctsrbm_dataset, utils.cutdown_list(ctsrbm_dataset.get_split_mask_idxs("val"), cutdown_ratio))

batch_size = 256
num_workers = 16

# Only the training loader shuffles: without it every epoch would present the
# exact same batches in the exact same order. Evaluation loaders keep a fixed
# order so their results are comparable between epochs.
ctsrbm_train_loader = DataLoader(ctsrbm_train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
ctsrbm_test_loader = DataLoader(ctsrbm_test_dataset, batch_size=batch_size, num_workers=num_workers)
ctsrbm_val_loader = DataLoader(ctsrbm_val_dataset, batch_size=batch_size, num_workers=num_workers)

In [None]:
backbone = backbones.ResNet50Backbone()

# The model is placed on the first visible GPU and replicated by DataParallel
# over one device id per entry of device_idxs (ids 0..N-1).
# NOTE(review): the meaning of the 1024 argument is defined by models.RetModel.
model = torch.nn.DataParallel(
    models.RetModel(backbone, 1024).to(first_device),
    device_ids=[position for position, _ in enumerate(device_idxs)],
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Multiplies the learning rate by gamma on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
def save_checkpoint(checkpoint_filename, model, optimizer, scheduler, train_losses, val_losses, training_metadata):
    """Serialize the complete training state to ``checkpoint_filename``.

    Args:
        checkpoint_filename: Destination path (or any target accepted by
            ``torch.save``). For path-like targets, missing parent
            directories are created first.
        model: The network to save. If it is wrapped in ``DataParallel`` or
            ``DistributedDataParallel``, the wrapped module's weights are
            stored, so the keys carry no ``module.`` prefix.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        train_losses: Training loss history, stored as given.
        val_losses: Validation loss history, stored as given.
        training_metadata: Extra run information, stored as given.
    """

    # Unwrap parallel containers; a plain module is saved directly.
    parallel_wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    bare_model = model.module if isinstance(model, parallel_wrappers) else model

    checkpoint = {
        "model_state_dict": bare_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "train_losses": train_losses,
        "val_losses": val_losses,
        "training_metadata": training_metadata
        }

    # Create the target directory up front so a missing folder cannot
    # discard the results of a finished training run.
    if isinstance(checkpoint_filename, (str, os.PathLike)):
        checkpoint_dir = os.path.dirname(os.fspath(checkpoint_filename))
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(checkpoint, checkpoint_filename)


def load_checkpoint(checkpoint_filename):
    """Rebuild model, optimizer and scheduler from a saved checkpoint.

    Fresh objects are created with the same construction arguments used
    elsewhere in this notebook and are then populated from the checkpoint.
    Relies on the notebook globals ``first_device`` and ``device_idxs``.

    Args:
        checkpoint_filename: Path of a file written by ``save_checkpoint``.

    Returns:
        Tuple ``(model, optimizer, scheduler, train_losses, val_losses,
        training_metadata)``; ``model`` is wrapped in ``DataParallel``.
    """

    # Deserialize onto the CPU so loading neither depends on the GPU layout at
    # save time nor keeps a second copy of the weights on the GPU. The
    # load_state_dict calls below move the values to the target devices.
    checkpoint = torch.load(checkpoint_filename, map_location="cpu")

    # Loading model

    backbone = backbones.ResNet50Backbone()

    model = models.RetModel(backbone, 1024).to(first_device)
    model = torch.nn.DataParallel(model, device_ids=list(range(len(device_idxs))))
    # Weights were saved from the unwrapped module, so load them into it.
    model.module.load_state_dict(checkpoint["model_state_dict"])

    # Loading optimizer

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1e-3,
    )

    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Loading scheduler

    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=0.95
    )

    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    # Loading other parameters

    train_losses = checkpoint["train_losses"]
    val_losses = checkpoint["val_losses"]
    training_metadata = checkpoint["training_metadata"]

    return model, optimizer, scheduler, train_losses, val_losses, training_metadata

In [None]:
def train_epoch(with_tqdm=True):
    """Run one optimisation pass over ``ctsrbm_train_loader``.

    Operates on the notebook globals ``model``, ``optimizer``, ``scaler``,
    ``ctsrbm_train_loader`` and ``device_idxs``. Each batch is indexed as
    (anchor, positive, negative) images at positions 0, 1 and 2.

    Args:
        with_tqdm: Wrap the loader in a tqdm progress bar when true.

    Returns:
        The sum of the per-batch triplet losses; the caller divides by the
        number of batches to obtain the mean.
    """

    model.train()
    total_loss = 0

    # The loss module only holds hyperparameters, so build it once.
    triplet_loss = torch.nn.TripletMarginLoss()

    loader_gen = ctsrbm_train_loader
    if with_tqdm:
        loader_gen = tqdm(loader_gen)

    for train_batch in loader_gen:

        anc_imgs = train_batch[0]
        pos_imgs = train_batch[1]
        neg_imgs = train_batch[2]

        # backward() adds to existing .grad buffers; clear them so each step
        # uses the gradient of the current batch only.
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():

            anc_emb = model(anc_imgs)
            pos_emb = model(pos_imgs)
            neg_emb = model(neg_imgs)

            # Memory report taken right after the three forward passes.
            utils.print_memory_usage(device_idxs)

            loss = triplet_loss(anc_emb, pos_emb, neg_emb)

        total_loss += loss.item()

        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and adjust its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return total_loss


def val_epoch(with_tqdm=True):
    """Evaluate the model once over ``ctsrbm_val_loader`` without gradients.

    Operates on the notebook globals ``model``, ``ctsrbm_val_loader`` and
    ``device_idxs``. Each batch is indexed as (anchor, positive, negative)
    images at positions 0, 1 and 2.

    Args:
        with_tqdm: Wrap the loader in a tqdm progress bar when true.

    Returns:
        The sum of the per-batch triplet losses.
    """

    model.eval()

    criterion = torch.nn.TripletMarginLoss()
    running_loss = 0

    with torch.no_grad():

        batches = tqdm(ctsrbm_val_loader) if with_tqdm else ctsrbm_val_loader

        for batch in batches:

            anchors, positives, negatives = batch[0], batch[1], batch[2]

            with torch.cuda.amp.autocast():

                anc_emb = model(anchors)
                pos_emb = model(positives)
                neg_emb = model(negatives)

                # Memory report taken right after the three forward passes.
                utils.print_memory_usage(device_idxs)

                batch_loss = criterion(anc_emb, pos_emb, neg_emb)

            running_loss += batch_loss.item()

    return running_loss

---

In [None]:
# Per-epoch mean losses; both training stages append to these lists.
train_losses, val_losses = [], []

# Number of epochs completed so far, counted across stages.
current_epoch = 0

In [None]:
# Stage 1 setup: call the wrapped model's freeze_backbone()
# (presumably stops the backbone from being trained — confirm in models).
model.module.freeze_backbone()
# Fresh early stopper for this stage; it is queried with the mean validation loss.
early_stopper = utils.EarlyStopper(patience=5)
# Stage 1 runs for at most 10 more epochs from the current position.
max_epoch = current_epoch + 10

In [None]:
# Throughput bookkeeping for the plot below, recorded once per epoch:
# cumulative seconds spent inside train_epoch, and samples processed.
train_time_diffs = []
train_data_points = []
train_seconds = 0.0

while current_epoch < max_epoch:

    current_epoch += 1

    print("Epoch {:d}".format(current_epoch))

    epoch_start = time()
    train_loss = train_epoch()
    train_seconds += time() - epoch_start

    train_time_diffs.append(train_seconds)
    train_data_points.append(len(ctsrbm_train_dataset))

    val_loss = val_epoch()

    # The epoch functions return summed batch losses; convert to means.
    mean_train_loss = train_loss / len(ctsrbm_train_loader)
    mean_val_loss = val_loss / len(ctsrbm_val_loader)

    train_losses.append(mean_train_loss)
    val_losses.append(mean_val_loss)

    # Advance the exponential learning-rate decay once per epoch; without
    # this call the learning rate stays at its initial value.
    scheduler.step()

    if early_stopper.early_stop(mean_val_loss):
        break

training_metadata["stage_1_epochs"] = current_epoch

# Training throughput: cumulative samples seen versus cumulative training time.
plt.plot(train_time_diffs, np.cumsum(train_data_points))
plt.title("{:d} workers".format(num_workers))

plt.xlabel("Time (s)")
plt.ylabel("Data points")

plt.grid()

plt.savefig("{:d}_workers_{:d}_batch_stage1.png".format(num_workers, batch_size))

In [None]:
# Destination of the stage 1 checkpoint, under the user's home directory.
checkpoint_dir = os.path.join(pathlib.Path.home(), "data", "checkpoints", "fashion_retrieval")
checkpoint_filename = "resnet50_ret_stage1.pth"
checkpoint_full_filename = os.path.join(checkpoint_dir, checkpoint_filename)

save_checkpoint(
    checkpoint_full_filename,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_losses=train_losses,
    val_losses=val_losses,
    training_metadata=training_metadata,
)

In [None]:
# Location of the stage 1 checkpoint, under the user's home directory.
checkpoint_dir = os.path.join(pathlib.Path.home(), "data", "checkpoints", "fashion_retrieval")
checkpoint_filename = "resnet50_ret_stage1.pth"
checkpoint_full_filename = os.path.join(checkpoint_dir, checkpoint_filename)

# Replace the live training objects with the ones restored from disk.
(
    model,
    optimizer,
    scheduler,
    train_losses,
    val_losses,
    training_metadata,
) = load_checkpoint(checkpoint_full_filename)

# One loss entry is appended per finished epoch, so the history length is
# the number of epochs already trained.
current_epoch = len(train_losses)

In [None]:
# Stage 2 setup: call the wrapped model's unfreeze_backbone()
# (presumably makes the backbone trainable again — confirm in models).
model.module.unfreeze_backbone()
# Fresh early stopper so stage 1 history does not influence stage 2.
early_stopper = utils.EarlyStopper(patience=5)
# Stage 2 runs for at most 30 more epochs from the current position.
max_epoch = current_epoch + 30

In [None]:
# Release cached GPU memory that PyTorch's allocator holds but no longer uses.
torch.cuda.empty_cache()

In [None]:
# Memory report for the selected GPUs before stage 2 starts.
utils.print_memory_usage(device_idxs)

In [None]:
while current_epoch < max_epoch:

    current_epoch += 1

    print("Epoch {:d}".format(current_epoch))

    train_loss = train_epoch()
    val_loss = val_epoch()

    # The epoch functions return summed batch losses; convert to means.
    mean_train_loss = train_loss / len(ctsrbm_train_loader)
    mean_val_loss = val_loss / len(ctsrbm_val_loader)

    train_losses.append(mean_train_loss)
    val_losses.append(mean_val_loss)

    # Advance the exponential learning-rate decay once per epoch; without
    # this call the learning rate stays at its initial value.
    scheduler.step()

    if early_stopper.early_stop(mean_val_loss):
        break

# current_epoch counts both stages, so subtract the stage 1 share.
training_metadata["stage_2_epochs"] = current_epoch - training_metadata["stage_1_epochs"]

In [None]:
# Destination of the stage 2 checkpoint, under the user's home directory.
checkpoint_dir = os.path.join(pathlib.Path.home(), "data", "checkpoints", "fashion_retrieval")
checkpoint_filename = "resnet50_ret_stage2.pth"
checkpoint_full_filename = os.path.join(checkpoint_dir, checkpoint_filename)

save_checkpoint(
    checkpoint_full_filename,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_losses=train_losses,
    val_losses=val_losses,
    training_metadata=training_metadata,
)

In [None]:
plt.figure(figsize=(12, 6))

# Epochs are numbered from 1 on the x axis.
plt.plot(range(1, len(train_losses) + 1), train_losses, label="train", marker=".")
plt.plot(range(1, len(val_losses) + 1), val_losses, label="val", marker=".")

# Mark the boundary between the two training stages halfway between the last
# stage 1 epoch and the first stage 2 epoch. The position comes from the
# recorded metadata, so it stays correct when early stopping changes the
# length of stage 1.
stage_1_epochs = training_metadata.get("stage_1_epochs")
if stage_1_epochs is not None:
    plt.axvline(stage_1_epochs + 0.5, ymin=0.02, ymax=0.98, color="black", linestyle="--")

plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.yscale("log")

plt.title("Loss - ResNet50 - Ret")
plt.legend()
plt.grid()

plt.show()