In [1]:
import os
import pathlib

import torch
import torchvision

from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Resize, Lambda, Compose

import deep_fashion
import backbones, heads, models

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

import utils

from pynvml import *
nvmlInit()

In [2]:
# Ask the CUDA runtime to enumerate GPUs in PCI bus order.
# NOTE(review): presumably this keeps the "cuda:3" index below in sync with the
# NVML index 3 passed to print_memory_usage elsewhere — confirm on this host.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# The model and every batch in this notebook are moved to this device.
device = torch.device("cuda:3")
# Release unused cached GPU memory held by PyTorch's allocator in this kernel.
torch.cuda.empty_cache()

In [3]:
def print_memory_usage(idx):
    """Print the used and total memory reported by NVML for one device.

    Args:
        idx: Device index handed to ``nvmlDeviceGetHandleByIndex``.

    The line printed has the form ``Memory usage: <used>/<total> (<pct>%)``,
    with both byte counts rendered by ``utils.sprint_fancy_num_bytes``.
    """
    mem = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(idx))

    total_txt = utils.sprint_fancy_num_bytes(mem.total)
    used_txt = utils.sprint_fancy_num_bytes(mem.used)
    used_pct = mem.used / mem.total * 100

    report = "Memory usage: {:s}/{:s} ({:2.2f}%)".format(used_txt, total_txt, used_pct)
    print(report)

In [4]:
# Baseline reading, taken before the model or any batch is moved to the GPU.
print_memory_usage(3)

Memory usage: 93.750MiB/11.000GiB (0.83%)


---

In [5]:
# Image preprocessing bundled with the default ResNet-50 weights; the
# transform's ``antialias`` attribute is switched on explicitly.
ctsrbm_image_transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()
ctsrbm_image_transform.antialias = True

ctsrbm_dataset_dir = os.path.join(
    pathlib.Path.home(),
    "data",
    "DeepFashion",
    "Consumer-to-shop Clothes Retrieval Benchmark",
)
ctsrbm_dataset = deep_fashion.ConsToShopClothRetrBM(ctsrbm_dataset_dir, ctsrbm_image_transform)

# Ratio handed to utils.cutdown_list for every split.
cutdown_ratio = 0.02


def _cutdown_split(split_name):
    """Return a Subset of ``ctsrbm_dataset`` over the cut-down indices of ``split_name``."""
    split_idxs = ctsrbm_dataset.get_split_mask_idxs(split_name)
    return Subset(ctsrbm_dataset, utils.cutdown_list(split_idxs, cutdown_ratio))


ctsrbm_train_dataset = _cutdown_split("train")
ctsrbm_test_dataset = _cutdown_split("test")
ctsrbm_val_dataset = _cutdown_split("val")

# All three loaders share the same batching configuration.
_loader_kwargs = dict(batch_size=128, num_workers=4)

ctsrbm_train_loader = DataLoader(ctsrbm_train_dataset, **_loader_kwargs)
ctsrbm_test_loader = DataLoader(ctsrbm_test_dataset, **_loader_kwargs)
ctsrbm_val_loader = DataLoader(ctsrbm_val_dataset, **_loader_kwargs)

In [6]:
# Retrieval model: a ResNet-50 backbone wrapped by RetModel (second
# constructor argument 1024), placed on the notebook's device.
backbone = backbones.ResNet50Backbone()
model = models.RetModel(backbone, 1024)
model = model.to(device)

# Adam over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Exponential learning-rate decay attached to that optimizer.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [7]:
def save_checkpoint(checkpoint_filename, model, optimizer, scheduler, train_losses, val_losses):
    """Write the full training state to ``checkpoint_filename`` via ``torch.save``.

    Args:
        checkpoint_filename: Destination path (or any target ``torch.save``
            accepts). For path-like targets the parent directory is created
            if it does not exist yet.
        model: Object whose ``state_dict()`` is stored under ``"model_state_dict"``.
        optimizer: Object whose ``state_dict()`` is stored under ``"optimizer_state_dict"``.
        scheduler: Object whose ``state_dict()`` is stored under ``"scheduler_state_dict"``.
        train_losses: Stored as-is under ``"train_losses"``.
        val_losses: Stored as-is under ``"val_losses"``.
    """
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "train_losses": train_losses,
        "val_losses": val_losses,
    }

    # torch.save does not create missing directories, so a first run on a
    # fresh machine would otherwise fail after training has already finished.
    # Only path-like targets have a parent directory; file objects are passed through.
    if isinstance(checkpoint_filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(checkpoint_filename))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, checkpoint_filename)


def load_checkpoint(checkpoint_filename):
    """Rebuild model, optimizer and scheduler from a file written by ``save_checkpoint``.

    Fresh objects are constructed with the hyper-parameters hard-coded below
    (RetModel second argument 1024, Adam ``lr=1e-3``, ExponentialLR
    ``gamma=0.95``), and the stored state dicts are then loaded into them.
    The model is placed on the notebook-level ``device``.

    Args:
        checkpoint_filename: Path of the checkpoint to read.

    Returns:
        Tuple ``(model, optimizer, scheduler, train_losses, val_losses)``.
    """
    # Remap stored tensors onto the notebook's device. Without map_location
    # they are restored onto whichever device they were saved from, which
    # fails (or wastes memory on another GPU) when that differs from `device`.
    checkpoint = torch.load(checkpoint_filename, map_location=device)

    # Loading model

    backbone = backbones.ResNet50Backbone()

    model = models.RetModel(backbone, 1024).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Loading optimizer

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1e-3,
    )

    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Loading scheduler

    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=0.95
    )

    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    # Loading other parameters

    train_losses = checkpoint["train_losses"]
    val_losses = checkpoint["val_losses"]

    return model, optimizer, scheduler, train_losses, val_losses

In [8]:
def train_epoch(with_tqdm=True):
    """Run one optimisation pass over ``ctsrbm_train_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``device`` and
    ``ctsrbm_train_loader``. Each batch is indexed as
    ``(anchor, positive, negative)`` images; the three embeddings feed a
    ``TripletMarginLoss`` with default arguments.

    Args:
        with_tqdm: Wrap the loader in a ``tqdm`` progress bar when true.

    Returns:
        The sum of the per-batch loss values (not divided by the batch count).
    """
    model.train()
    total_loss = 0

    # Built with default arguments, so it is the same for every batch.
    triplet_loss = torch.nn.TripletMarginLoss()

    loader_gen = ctsrbm_train_loader
    if with_tqdm: loader_gen = tqdm(loader_gen)

    for train_batch in loader_gen:

        print("Batch start")
        print_memory_usage(3)

        anc_imgs = train_batch[0].to(device)
        pos_imgs = train_batch[1].to(device)
        neg_imgs = train_batch[2].to(device)

        print("Loaded data")
        print_memory_usage(3)

        # Drop each input reference as soon as its embedding has been computed.
        anc_emb = model(anc_imgs)
        del anc_imgs
        pos_emb = model(pos_imgs)
        del pos_imgs
        neg_emb = model(neg_imgs)
        del neg_imgs

        print("Model evaluated")
        print_memory_usage(3)

        print("Deleted data")
        print_memory_usage(3)

        loss = triplet_loss(anc_emb, pos_emb, neg_emb)

        total_loss += loss.item()

        # Clear gradients left by the previous batch; without this every
        # backward() adds onto the old values and step() applies the running
        # sum of all gradients seen so far.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return total_loss


def val_epoch(with_tqdm=True):
    """Evaluate the notebook-level ``model`` on ``ctsrbm_val_loader`` without gradients.

    Each batch is indexed as ``(anchor, positive, negative)`` images, which
    are moved to ``device``, embedded, and scored with a default
    ``TripletMarginLoss``.

    Args:
        with_tqdm: Wrap the loader in a ``tqdm`` progress bar when true.

    Returns:
        The sum of the per-batch loss values (not divided by the batch count).
    """
    model.eval()

    criterion = torch.nn.TripletMarginLoss()
    batches = tqdm(ctsrbm_val_loader) if with_tqdm else ctsrbm_val_loader
    total_loss = 0

    with torch.no_grad():
        for batch in batches:

            print("Batch start")
            print_memory_usage(3)

            # Positions 0, 1, 2 of the batch: anchor, positive, negative.
            anc_imgs, pos_imgs, neg_imgs = (batch[i].to(device) for i in range(3))

            print("Loaded data")
            print_memory_usage(3)

            anc_emb = model(anc_imgs)
            pos_emb = model(pos_imgs)
            neg_emb = model(neg_imgs)

            print("Model evaluated")
            print_memory_usage(3)

            # The inputs are no longer needed once the embeddings exist.
            del anc_imgs, pos_imgs, neg_imgs

            print("Deleted data")
            print_memory_usage(3)

            total_loss += criterion(anc_emb, pos_emb, neg_emb).item()

    return total_loss

---

In [9]:
# Per-epoch mean losses; each training loop in this notebook appends one
# value to both lists per epoch.
train_losses, val_losses = [], []

# Number of epochs completed so far.
current_epoch = 0

In [10]:
# Stage 1: up to 10 epochs after model.freeze_backbone().
# NOTE(review): presumably freeze_backbone() stops the backbone from being
# updated so only the remaining layers train — confirm in `models`.
model.freeze_backbone()
early_stopper = utils.EarlyStopper(patience=5)
max_epoch = current_epoch + 10

while current_epoch < max_epoch:

    current_epoch += 1

    print("Epoch {:d}".format(current_epoch))

    train_loss = train_epoch()
    val_loss = val_epoch()

    # train_epoch/val_epoch return summed batch losses; len(loader) is the
    # number of batches, so these are per-batch means.
    mean_train_loss = train_loss / len(ctsrbm_train_loader)
    mean_val_loss = val_loss / len(ctsrbm_val_loader)

    train_losses.append(mean_train_loss)
    val_losses.append(mean_val_loss)

    # Advance the learning-rate schedule once per epoch. Without this call
    # the ExponentialLR scheduler is created and checkpointed but never
    # changes the learning rate.
    scheduler.step()

    if early_stopper.early_stop(mean_val_loss):
        break

# Persist the stage-1 state so later cells can resume from it.
checkpoint_dir = os.path.join(pathlib.Path.home(), "data", "checkpoints", "fashion_retrieval")
checkpoint_filename = "resnet50_ret_stage1.pth"
checkpoint_full_filename = os.path.join(checkpoint_dir, checkpoint_filename)

save_checkpoint(checkpoint_full_filename, model, optimizer, scheduler, train_losses, val_losses)

In [11]:
# Rebuild the stage-1 checkpoint path (same directory and filename that the
# stage-1 training cell saves to).
checkpoint_dir = os.path.join(pathlib.Path.home(), "data", "checkpoints", "fashion_retrieval")
checkpoint_filename = "resnet50_ret_stage1.pth"
checkpoint_full_filename = os.path.join(checkpoint_dir, checkpoint_filename)

# Replace the live model/optimizer/scheduler and loss histories with the
# checkpointed ones.
model, optimizer, scheduler, train_losses, val_losses = load_checkpoint(checkpoint_full_filename)
# One value is appended to train_losses per epoch, so its length is the
# number of epochs already completed.
current_epoch = len(train_losses)

In [12]:
# Release unused cached GPU memory, then report usage before stage 2 starts.
torch.cuda.empty_cache()
print_memory_usage(3)

Memory usage: 549.875MiB/11.000GiB (4.88%)


In [13]:
# Stage 2 setup.
# NOTE(review): presumably unfreeze_backbone() makes the backbone trainable
# again — confirm in `models`.
model.unfreeze_backbone()
# Fresh early-stopping tracker for this stage.
early_stopper = utils.EarlyStopper(patience=5)
# Allow up to 30 further epochs on top of those already completed.
max_epoch = current_epoch + 30

In [14]:
# Stage 2 training loop: same structure as stage 1, run after
# model.unfreeze_backbone().
# NOTE(review): the recorded output of this cell is a CUDA out-of-memory error
# on the first batch. The loaders use batch_size=128 and train_epoch runs
# three forward passes per batch; a smaller batch size is probably needed on
# this 11 GiB GPU — confirm.
while current_epoch < max_epoch:

    current_epoch += 1

    print("Epoch {:d}".format(current_epoch))

    train_loss = train_epoch()
    val_loss = val_epoch()

    # train_epoch/val_epoch return summed batch losses; len(loader) is the
    # number of batches, so these are per-batch means.
    mean_train_loss = train_loss / len(ctsrbm_train_loader)
    mean_val_loss = val_loss / len(ctsrbm_val_loader)

    train_losses.append(mean_train_loss)
    val_losses.append(mean_val_loss)

    # Advance the learning-rate schedule once per epoch. Without this call
    # the ExponentialLR scheduler never changes the learning rate.
    scheduler.step()

    if early_stopper.early_stop(mean_val_loss):
        break

Epoch 7


  0%|          | 0/16 [00:00<?, ?it/s]

Batch start
Memory usage: 549.875MiB/11.000GiB (4.88%)
Loaded data
Memory usage: 771.875MiB/11.000GiB (6.85%)


OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 3; 10.91 GiB total capacity; 10.55 GiB already allocated; 26.12 MiB free; 10.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Persist the stage-2 state next to the stage-1 checkpoint, under its own
# filename so the stage-1 file is not overwritten.
checkpoint_dir = os.path.join(pathlib.Path.home(), "data", "checkpoints", "fashion_retrieval")
checkpoint_filename = "resnet50_ret_stage2.pth"
checkpoint_full_filename = os.path.join(checkpoint_dir, checkpoint_filename)

save_checkpoint(checkpoint_full_filename, model, optimizer, scheduler, train_losses, val_losses)