In [1]:
import os

import torch
import torchvision

from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Resize, Lambda, Compose

import deep_fashion
import backbones, heads, models

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

---

In [2]:
# Data: DeepFashion Consumer-to-Shop retrieval dataset, preprocessed with the
# transforms bundled with the default pretrained ResNet-50 weights.
BATCH_SIZE = 256
NUM_WORKERS = 4
SHUFFLE_SEED = 0  # fixes the training-set shuffle order so runs are reproducible

ctsrbm_image_transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()
ctsrbm_dataset_dir = os.path.join("..", "fashion_datasets_raw", "DeepFashion", "ConsumerToShop_Retrieval_BM")

ctsrbm_dataset = deep_fashion.ConsToShopClothRetrBM(ctsrbm_dataset_dir, ctsrbm_image_transform)

ctsrbm_train_dataset = Subset(ctsrbm_dataset, ctsrbm_dataset.get_split_mask_idxs("train"))
ctsrbm_test_dataset = Subset(ctsrbm_dataset, ctsrbm_dataset.get_split_mask_idxs("test"))
ctsrbm_val_dataset = Subset(ctsrbm_dataset, ctsrbm_dataset.get_split_mask_idxs("val"))

# Only the training split is shuffled, so every epoch sees the samples in a new
# (seeded) order; the evaluation splits keep a fixed order.
ctsrbm_train_loader = DataLoader(
    ctsrbm_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
ctsrbm_test_loader = DataLoader(ctsrbm_test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
ctsrbm_val_loader = DataLoader(ctsrbm_val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

---

In [9]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [10]:
# Retrieval model on a ResNet-50 backbone. The 1024 is passed straight to
# models.RetModel (presumably the embedding size — TODO confirm).
backbone = backbones.ResNet50Backbone()
model = models.RetModel(backbone, 1024)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters())

# StepLR needs the optimizer it schedules and a step_size; calling it with no
# arguments raises TypeError. Multiply the LR by LR_GAMMA every LR_STEP_SIZE
# scheduler steps (TODO tune both).
LR_STEP_SIZE = 10
LR_GAMMA = 0.1
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

In [None]:
def train_epoch(loader=None, net=None, opt=None, dev=None):
    """Run one optimisation pass with a triplet margin loss.

    Every argument is optional. When omitted, the notebook globals are used
    (``ctsrbm_train_loader``, ``model``, ``optimizer``, ``device``), so
    ``train_epoch()`` keeps working as before.

    Args:
        loader: iterable of batches; element 0/1/2 of each batch are the
            anchor/positive/negative inputs.
        net: model mapping a batch of inputs to embeddings.
        opt: optimizer over ``net``'s parameters.
        dev: device every input tensor is moved to.

    Returns:
        (num_batches, total_loss): number of batches processed and the sum of
        their per-batch losses. The result is (0, 0) for an empty loader;
        divide total by count for the mean.
    """
    if loader is None:
        loader = ctsrbm_train_loader
    if net is None:
        net = model
    if opt is None:
        opt = optimizer
    if dev is None:
        dev = device

    net.train()

    # Build the loss module once instead of once per batch.
    triplet_loss = torch.nn.TripletMarginLoss()

    num_batches = 0
    total_loss = 0

    for train_batch in loader:
        anc_imgs = train_batch[0].to(dev)
        pos_imgs = train_batch[1].to(dev)
        neg_imgs = train_batch[2].to(dev)

        # Clear the previous batch's gradients. Without this, backward() adds
        # onto them and every step uses the sum of all gradients so far.
        opt.zero_grad()

        anc_emb = net(anc_imgs)
        pos_emb = net(pos_imgs)
        neg_emb = net(neg_imgs)

        loss = triplet_loss(anc_emb, pos_emb, neg_emb)

        loss.backward()
        opt.step()

        num_batches += 1
        total_loss += loss.item()

    return num_batches, total_loss


def val_epoch(loader=None, net=None, dev=None):
    """Compute the triplet margin loss over a validation loader without training.

    Every argument is optional. When omitted, the notebook globals are used
    (``ctsrbm_val_loader``, ``model``, ``device``), so ``val_epoch()`` keeps
    working as before.

    Args:
        loader: iterable of batches; element 0/1/2 of each batch are the
            anchor/positive/negative inputs.
        net: model mapping a batch of inputs to embeddings; put in eval mode.
        dev: device every input tensor is moved to.

    Returns:
        (num_batches, total_loss): number of batches evaluated and the sum of
        their per-batch losses. The result is (0, 0) for an empty loader.
    """
    if loader is None:
        loader = ctsrbm_val_loader
    if net is None:
        net = model
    if dev is None:
        dev = device

    net.eval()

    # Build the loss module once instead of once per batch.
    triplet_loss = torch.nn.TripletMarginLoss()

    num_batches = 0
    total_loss = 0

    # No autograd graph is needed for evaluation.
    with torch.no_grad():

        for val_batch in loader:

            anc_imgs = val_batch[0].to(dev)
            pos_imgs = val_batch[1].to(dev)
            neg_imgs = val_batch[2].to(dev)

            anc_emb = net(anc_imgs)
            pos_emb = net(pos_imgs)
            neg_emb = net(neg_imgs)

            loss = triplet_loss(anc_emb, pos_emb, neg_emb)

            num_batches += 1
            total_loss += loss.item()

    return num_batches, total_loss


class EarlyStopper:
    """Decide when to stop training because the validation loss stopped improving.

    A loss strictly below the best seen so far records a new best and resets
    the counter. A loss above ``best + min_delta`` increments the counter. A
    loss inside ``[best, best + min_delta]`` leaves the counter unchanged.
    Stopping is signalled once the counter reaches ``patience``.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        # Epochs since the last improvement whose loss exceeded best + min_delta.
        self.counter = 0
        # Best (lowest) validation loss observed so far.
        self.min_val_loss = float("inf")

    def early_stop(self, val_loss):
        """Record ``val_loss`` and return True if training should stop."""
        if val_loss < self.min_val_loss:
            self.min_val_loss = val_loss
            self.counter = 0
            return False

        exceeds_tolerance = val_loss > self.min_val_loss + self.min_delta
        if not exceeds_tolerance:
            return False

        self.counter += 1
        return self.counter >= self.patience

In [None]:
# Per-epoch mean losses, accumulated across both training phases (one entry per epoch).
mean_train_loss_list = []
mean_val_loss_list = []

# Epoch budgets per phase (upper bounds: early stopping may end a phase sooner).
FROZEN_BACKBONE_EPOCHS = 10
FINETUNE_EPOCHS = 30


def run_training_phase(num_epochs, patience=1, min_delta=0):
    """Train for up to ``num_epochs`` epochs, stopping early on the validation loss.

    Each epoch runs ``train_epoch()`` then ``val_epoch()``. It appends the
    mean losses to ``mean_train_loss_list`` / ``mean_val_loss_list`` and steps
    the global ``scheduler`` once. A fresh ``EarlyStopper`` judges each phase
    independently.

    Args:
        num_epochs: maximum number of epochs in this phase.
        patience: forwarded to ``EarlyStopper``.
        min_delta: forwarded to ``EarlyStopper``.

    Returns:
        The number of epochs actually run.

    Raises:
        RuntimeError: if the train or validation loader yields no batches
            (the mean loss would otherwise be a division by zero).
    """
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)

    epochs_run = 0
    for epoch_num in range(1, num_epochs + 1):

        num_train_batches, train_loss = train_epoch()
        num_val_batches, val_loss = val_epoch()
        epochs_run = epoch_num

        if num_train_batches == 0 or num_val_batches == 0:
            raise RuntimeError("Train or validation loader produced no batches")

        mean_train_loss = train_loss / num_train_batches
        mean_val_loss = val_loss / num_val_batches

        mean_train_loss_list.append(mean_train_loss)
        mean_val_loss_list.append(mean_val_loss)

        # Step the LR schedule once per epoch, after train_epoch() has run the optimizer steps.
        scheduler.step()

        print(f"epoch {epoch_num}/{num_epochs}: train loss {mean_train_loss:.4f}, val loss {mean_val_loss:.4f}")

        if early_stopper.early_stop(mean_val_loss):
            break

    return epochs_run


# Phase 1: train with the backbone frozen via model.freeze_backbone().
model.freeze_backbone()
run_training_phase(FROZEN_BACKBONE_EPOCHS)

# Phase 2: unfreeze the backbone and fine-tune the whole model.
model.unfreeze_backbone()
run_training_phase(FINETUNE_EPOCHS)

###

# Directory checkpoints are written to (created if missing).
# TODO: choose the final location; "checkpoints" is relative to the working directory.
checkpoints_dirname = "checkpoints"
os.makedirs(checkpoints_dirname, exist_ok=True)

# Epoch numbers restart at 1 in each training phase, so name the checkpoint
# after the total number of epochs trained (one loss entry is appended per epoch).
current_epoch_num = len(mean_train_loss_list)
checkpoint_filename = os.path.join(checkpoints_dirname, "checkpoint_{:d}.pth".format(current_epoch_num))

print("Saving checkpoint to \"{:s}\"".format(checkpoint_filename))

# Model/optimizer/scheduler state plus the loss history.
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
    "mean_train_loss_list": mean_train_loss_list,
    "mean_val_loss_list": mean_val_loss_list
    }

torch.save(checkpoint, checkpoint_filename)