In [None]:
# NOTE(review): `!` runs in a system shell, so this reports the first `python3` on PATH,
# which is not necessarily the interpreter running this kernel.
!python3 --version

In [None]:
# NOTE(review): like the cell above, this checks the pip of the `python3` found on PATH,
# which may differ from the kernel's environment.
!python3 -m pip --version

In [None]:
!python3 -m pip install -r requirements/dev.txt

In [None]:
# List the UGallery data files with their on-disk sizes (`-s`) in human-readable units (`-h`).
!ls data/UGallery -sh

In [None]:
# ugallery_data_utils.py
import numpy as np


def load_embeddings(embedding_path, embedding_shape=(13297, 2048)):
    """Load pickled ``(artwork_id, embedding)`` pairs into a dense matrix plus id <-> row maps.

    Args:
        embedding_path: path to a ``.npy`` file holding a sequence of
            ``(artwork_id_hash, embedding_vector)`` pairs (loaded with ``allow_pickle=True``).
        embedding_shape: shape ``(n_rows, dim)`` of the returned matrix. Row ``i`` receives the
            ``i``-th pair in file order; rows beyond the number of pairs stay zero. Pass ``None``
            to size the matrix from the file (one row per pair, width of the first vector).

    Returns:
        ``(embedding, artwork_id2index, artwork_index2id)``: a float64 matrix, a dict mapping
        artwork id -> row index, and the inverse dict.

    Raises:
        ValueError: if the file holds more pairs than ``embedding_shape`` has rows, if an artwork
            id appears twice, or if ``embedding_shape`` is ``None`` and the file is empty.
        AssertionError: if every loaded embedding is all zeros (sanity check).
    """
    # SECURITY: allow_pickle=True can execute arbitrary code embedded in a crafted file;
    # only load trusted, locally generated embedding files.
    data = np.load(embedding_path, allow_pickle=True)
    if embedding_shape is None:
        if len(data) == 0:
            raise ValueError(f"No embeddings found in {embedding_path}")
        embedding_shape = (len(data), np.size(data[0][1]))
    if len(data) > embedding_shape[0]:
        # Fail early with a clear message instead of an IndexError halfway through the copy.
        raise ValueError(
            f"{embedding_path} holds {len(data)} embeddings but embedding_shape "
            f"only has {embedding_shape[0]} rows"
        )
    # Generate indexes and contiguous embedding
    embedding = np.zeros(shape=embedding_shape)
    artwork_id2index = dict()
    artwork_index2id = dict()
    for i, (artwork_id_hash, artwork_embedding) in enumerate(data):
        # Explicit check (not assert) so it still runs under `python -O`.
        if artwork_id_hash in artwork_id2index:
            raise ValueError(f"Duplicate artwork id in {embedding_path}: {artwork_id_hash!r}")
        artwork_id2index[artwork_id_hash] = i
        artwork_index2id[i] = artwork_id_hash
        embedding[i] = artwork_embedding
    assert not np.all(embedding == 0)
    assert artwork_id2index
    assert artwork_index2id
    return embedding, artwork_id2index, artwork_index2id

# PyTorch DataLoaders

In [None]:
import json
import random
import time
from collections import Counter, defaultdict
from copy import deepcopy
from os.path import join

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler
from torchvision import transforms

In [None]:
from tqdm.notebook import tqdm

## Custom Dataset (1): Loading data

In [None]:
# Load the precomputed ResNet-50 artwork embeddings together with the
# artwork-id <-> row-index lookup tables.
embeddings_path = join("data", "UGallery", "ugallery_resnet50_embeddings.npy")
loaded_data = load_embeddings(embeddings_path)
(
    embeddings,
    artwork_id2index,
    artwork_index2id,
) = loaded_data
print(f"embeddings shape: {embeddings.shape}")

## Custom Dataset (2): Custom Dataset and transform

In [None]:
class ToTensor:
    """Sample transform: turn every numpy array in a sample dict into a float32 torch tensor.

    The dict keys are kept unchanged; only the values are converted.
    """

    def __call__(self, sample):
        converted = {}
        for key, array in sample.items():
            converted[key] = torch.from_numpy(array).float()
        return converted

In [None]:
class UGalleryDataset(Dataset):
    """UGallery triples (user profile, positive item, negative item) backed by an embedding matrix.

    Rows come from ``csv_file``. The ``profile`` column holds a stringified list of artwork ids
    (e.g. ``"['a', 'b']"``). ``__getitem__`` reads the columns positionally:
    - column 0: the profile
    - column 1: the positive item ("pi")
    - column 2: the negative item ("ni")

    Call ``prepare`` exactly once before indexing. It maps artwork ids to rows of ``embedding``
    and converts the triples to a numpy array.

    Args:
        csv_file: path or buffer accepted by ``pandas.read_csv``.
        embedding: 2-D array whose rows are artwork embeddings.
        transform: optional callable applied to each sample dict.
    """

    # TODO(Antonio): Options for training, validation and testing. Training
    # and validation are stored as csv files, but testing is a json file.
    # Based on torchvision.Dataset maybe

    def __init__(self, csv_file, embedding, transform=None):
        # Dataframe
        self.triples = pd.read_csv(csv_file)
        # "['a', 'b']" -> ["a", "b"]
        # NOTE(review): an empty profile "[]" parses to [""] rather than [].
        profile_to_list = lambda p: p[1:-1].replace("'", "").split(", ")
        self.triples["profile"] = self.triples["profile"].map(profile_to_list)
        # Caching profile sizes (read by SameProfileSizeBatchSampler to group rows)
        self.profile_sizes = tuple(self.triples["profile"].map(len))
        self.embedding = embedding
        self.transform = transform
        self.__ready = False

    def prepare(self, id2index=None):
        """Map artwork ids through ``id2index`` (if given) and freeze the triples as a numpy array.

        Raises:
            RuntimeError: if the dataset was already prepared.
        """
        if self.__ready:
            raise RuntimeError("Dataset was already prepared")
        if id2index:
            self.__apply_mapping(id2index)
        self.triples = self.triples.to_numpy()
        self.__ready = True
        print("Dataset is ready")

    def __apply_mapping(self, id2index):
        """Replace every artwork id (scalar cell or list element) with ``id2index[id]``."""
        def map_id2index(element):
            if isinstance(element, list):
                return [id2index[e] for e in element]
            return id2index[element]
        # DataFrame.applymap was deprecated in pandas 2.1 (renamed DataFrame.map) and removed in
        # pandas 3.0; use whichever element-wise method this pandas version provides.
        if hasattr(pd.DataFrame, "map"):
            self.triples = self.triples.map(map_id2index)
        else:
            self.triples = self.triples.applymap(map_id2index)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``{"profile", "pi", "ni"}`` embedding rows for triple ``idx``, then ``transform``.

        Before ``transform``, "profile" has one row per profile item and "pi"/"ni" are single rows.

        Raises:
            RuntimeError: if ``prepare`` has not been called yet.
        """
        if not self.__ready:
            # Before prepare() the triples are still a DataFrame of raw ids and cannot be indexed.
            raise RuntimeError("Dataset is not ready: call prepare() before indexing")
        profile = self.embedding[self.triples[idx, 0], :]
        pi = self.embedding[self.triples[idx, 1]]
        ni = self.embedding[self.triples[idx, 2]]

        sample = {
            "profile": profile,
            "pi": pi,
            "ni": ni,
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

## Custom Dataset (3): Custom BatchSampler

In [None]:
class SameProfileSizeBatchSampler(BatchSampler):
    """Batch sampler that only groups samples whose user profiles have the same length.

    Indices drawn from ``sampler`` are queued per profile size, read from
    ``sampler.data_source.profile_sizes``. A batch is emitted as soon as one queue holds
    ``batch_size`` indices. Equal profile sizes within a batch let the profiles be stacked into
    a single tensor.

    Args:
        sampler: index sampler exposing ``data_source.profile_sizes``.
        batch_size: maximum number of samples per batch.
        bump_rate: probability, after each full batch, of also flushing one randomly chosen
            non-empty partial queue as a smaller batch. Has no effect when ``drop_last`` is True.
        drop_last: if True, partial batches are never emitted.
    """

    def __init__(self, sampler, batch_size, bump_rate=0.05, drop_last=False):
        self.sampler = sampler
        assert hasattr(self.sampler.data_source, "profile_sizes")
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.bump_rate = bump_rate

    def __iter__(self):
        batch_queue = defaultdict(list)  # profile size -> pending indices
        profile_sizes = self.sampler.data_source.profile_sizes
        for idx in self.sampler:
            p_size = profile_sizes[idx]
            batch_queue[p_size].append(idx)
            if len(batch_queue[p_size]) == self.batch_size:
                batch, batch_queue[p_size] = batch_queue[p_size][:], []
                yield batch
                # Occasionally flush a random partial queue early (presumably so small groups
                # are spread through the epoch instead of all arriving at the end).
                if random.random() < self.bump_rate and not self.drop_last:
                    possible_keys = [k for k, v in batch_queue.items() if v]
                    if possible_keys:
                        bumped_key = random.choice(possible_keys)
                        batch, batch_queue[bumped_key] = batch_queue[bumped_key][:], []
                        yield batch
        if not self.drop_last:
            # Emit the leftover partial batches, visiting profile sizes in random order
            for k in random.sample(list(batch_queue.keys()), len(batch_queue)):
                if batch_queue[k]:
                    yield batch_queue[k]

    def __len__(self):
        """Number of batches per epoch, assuming the sampler yields every dataset index once.

        The count is exact when ``drop_last`` is True or ``bump_rate`` is 0. With bumping it is a
        lower bound, because early flushes can only split a profile-size group into more batches.
        """
        counter = Counter(self.sampler.data_source.profile_sizes)
        if self.drop_last:
            # The remainder is dropped per profile-size group, so count full batches per group.
            # len(sampler) // batch_size over-counts whenever several groups leave a remainder.
            return sum(count // self.batch_size for count in counter.values())
        return sum((count + self.batch_size - 1) // self.batch_size for count in counter.values())

## Custom Dataset (4): `collate_fn` (merging samples into batches)

In [None]:
try:
    from torch.utils.data import default_collate
except ImportError:  # torch < 1.11 only exposes it from the private module
    from torch.utils.data._utils.collate import default_collate


def merge_samples(data):
    """Collate a list of sample dicts into ``(batch, target)`` for the DataLoader.

    Args:
        data: list of dicts sharing the same keys (e.g. ``"profile"``, ``"pi"``, ``"ni"``) whose
            values are tensors. Tensors under one key must all have the same shape; this is why
            the batches come from ``SameProfileSizeBatchSampler``.

    Returns:
        batch: dict mapping each key to its tensors stacked along a new leading dimension of
            size ``len(data)``.
        target: ``torch.ones(len(data), 1, 1)``.

    Raises:
        RuntimeError: if the tensors under one key differ in shape (raised by ``torch.stack``).
    """
    # default_collate stacks dict values key by key. Inside DataLoader worker processes it writes
    # straight into shared memory, which the previous hand-written version attempted with the
    # deprecated ``Tensor.storage()`` API. The old torch.cat + view approach could silently
    # re-batch profiles of different lengths into the wrong batch size; stack rejects them.
    batch = default_collate(data)
    target = torch.ones(len(data), 1, 1)
    return batch, target

## Training (1): Model definition

In [None]:
class CuratorNet(nn.Module):
    """CuratorNet scorer: a shared item encoder plus a profile encoder that pools a user's items.

    Encoders:
    - Item encoder: maps each 2048-d input vector to 200-d through two SELU layers. It is shared
      by candidate items and by the artworks in a profile.
    - Profile encoder: concatenates max- and average-pooling over the encoded profile items,
      then applies three more SELU layers.

    A score is the batched dot product of the profile vector with an item vector.
    """

    def __init__(self):
        super().__init__()

        # Common (item) encoder, shared by candidate items and profile artworks
        self.selu_common1 = nn.Linear(2048, 200)
        self.selu_common2 = nn.Linear(200, 200)

        # Profile encoder: pooling over the profile's items, then an MLP on [max ; avg]
        self.maxpool = nn.AdaptiveMaxPool2d((1, 200))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 200))
        self.selu_pu1 = nn.Linear(200 + 200, 300)
        self.selu_pu2 = nn.Linear(300, 300)
        self.selu_pu3 = nn.Linear(300, 200)

        self.reset_parameters()

    def _encode_items(self, items):
        """Shared two-layer SELU encoder applied over the last dimension (2048 -> 200)."""
        hidden = F.selu(self.selu_common1(items))
        return F.selu(self.selu_common2(hidden))

    def _encode_profile(self, profile):
        """Encode each profile's items, pool them (max ++ avg), and map the result to a 200-d user vector."""
        encoded = self._encode_items(profile)
        user = torch.cat((self.maxpool(encoded), self.avgpool(encoded)), dim=-1)
        for layer in (self.selu_pu1, self.selu_pu2, self.selu_pu3):
            user = F.selu(layer(user))
        return user

    def forward(self, profile, pi, ni):
        """Return ``x_ui - x_uj``: the positive item's score minus the negative item's score.

        Assumed shapes (TODO confirm against callers): ``profile`` (B, k, 2048), ``pi`` and ``ni``
        (B, 2048). The output is (B, 1, 1).
        """
        positive = self._encode_items(pi)
        negative = self._encode_items(ni)
        user = self._encode_profile(profile)

        x_ui = torch.bmm(user, positive.unsqueeze(-1))
        x_uj = torch.bmm(user, negative.unsqueeze(-1))
        return x_ui - x_uj

    def recommend(self, profile, items, grad_enabled=False):
        """Score every item in ``items`` against its batch entry's profile.

        Returns ``bmm(user, encoded_items^T)`` with every size-1 dimension squeezed away.
        Autograd is active only when ``grad_enabled`` is True.
        """
        with torch.set_grad_enabled(grad_enabled):
            user = self._encode_profile(profile)
            encoded_items = self._encode_items(items)
            return torch.bmm(user, encoded_items.transpose(-1, -2)).squeeze()

    def reset_parameters(self):
        """Xavier-uniform init of every Linear weight; biases keep their default initialization."""
        # NOTE(review): SELU self-normalization is usually paired with LeCun-normal init;
        # Xavier-uniform is kept as originally chosen.
        for layer in (self.selu_common1, self.selu_common2,
                      self.selu_pu1, self.selu_pu2, self.selu_pu3):
            nn.init.xavier_uniform_(layer.weight)


## Training (2): Include LR scheduling

In [None]:
SETTINGS = {
    "dataloader:batch_size": 4096 * 3,
    "dataloader:num_workers": 4,
    "training:num_epochs": 150,  # 300 is ideal
    "optimizer:lr": 0.0001,  # an earlier run may have used 0.0001 * 2
    "optimizer:weight_decay": 0.001,
    "scheduler:factor": 0.6,
}
# Rule of thumb: double the learning rate if you double the batch size.

# Seed the Python, NumPy and torch RNGs before the model and samplers are created. The pipeline
# draws randomness from `random` (batch-sampler bumping/shuffling) and torch (weight init,
# RandomSampler). The seed is kept out of SETTINGS so the TensorBoard run name is unchanged.
# CUDA kernels can still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# TensorBoard run name: the model/dataset prefix followed by "_<setting>=<value>" for every
# setting, where <setting> is the part of the key after the ":".
for k, v in SETTINGS.items():
    print(k, v)
summary_writer_name = "CuratorNet_UGallery" + "".join(
    f"_{k.split(':')[1]}={v}" for k, v in SETTINGS.items()
)
print(summary_writer_name)

In [None]:
# Model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CuratorNet()
model = model.to(device)

In [None]:
# Training criteria
optimizer = optim.Adam(model.parameters(), lr=SETTINGS["optimizer:lr"], weight_decay=SETTINGS["optimizer:weight_decay"])
criterion = nn.BCEWithLogitsLoss()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=SETTINGS["scheduler:factor"], patience=2, verbose=True)

In [None]:
# Dataloaders (training)
training_dataset = UGalleryDataset(
    csv_file="data/UGallery/train_public.csv",
    embedding=embeddings,
    transform=transforms.Compose([
        ToTensor(),
    ]))
training_dataset.prepare(artwork_id2index)
training_sampler = RandomSampler(training_dataset)
training_batch_sampler = SameProfileSizeBatchSampler(sampler=training_sampler, batch_size=SETTINGS["dataloader:batch_size"])
training_dataloader = DataLoader(training_dataset, collate_fn=merge_samples, batch_sampler=training_batch_sampler, num_workers=SETTINGS["dataloader:num_workers"], pin_memory=True)

In [None]:
# Dataloaders (validation)
validation_dataset = UGalleryDataset(
    csv_file="data/UGallery/validation_public.csv",
    embedding=embeddings,
    transform=transforms.Compose([
        ToTensor(),
    ]))
validation_dataset.prepare(artwork_id2index)
validation_sampler = SequentialSampler(validation_dataset)
validation_batch_sampler = SameProfileSizeBatchSampler(sampler=validation_sampler, bump_rate=0.0, batch_size=SETTINGS["dataloader:batch_size"])
validation_dataloader = DataLoader(validation_dataset, collate_fn=merge_samples, batch_sampler=validation_batch_sampler, num_workers=SETTINGS["dataloader:num_workers"], pin_memory=True)

In [None]:
from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter(f"runs/{summary_writer_name}", flush_secs=20)


def _run_phase(model, device, criterion, optimizer, dataloader, phase, epoch):
    """Run one pass of ``phase`` ("train" or "validation") over ``dataloader``.

    Parameters are only updated when ``phase == "train"``. Every 40 batches, the running mean
    loss is logged to the module-level TensorBoard ``writer``.

    Returns:
        ``(mean_loss, accuracy)`` over the samples seen. Accuracy is the fraction of model
        outputs (x_ui - x_uj) that are positive.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    is_train = phase == "train"
    model.train(is_train)  # same as model.eval() for validation
    n_batches = len(dataloader)
    running_loss = 0.0
    running_acc = 0
    running_x = 0

    # Iterate over data
    for i_batch, (batch, target) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch} ({phase})")):
        batch = {k: v.to(device) for k, v in batch.items()}
        target = target.to(device)

        # Restart params gradients
        optimizer.zero_grad()

        # Forward pass (autograd graph only built while training)
        with torch.set_grad_enabled(is_train):
            output = model(**batch)
            loss = criterion(output, target)
            # Backward pass
            if is_train:
                loss.backward()
                optimizer.step()

        # Statistics
        batch_size = output.size(0)
        running_loss += loss.item() * batch_size
        running_acc += (output.detach() > 0).sum().item()
        running_x += batch_size

        if i_batch % 40 == 39:
            writer.add_scalar(
                f"{phase} loss",
                running_loss / running_x,
                (epoch - 1) * n_batches + i_batch,
            )
            # flush(), not close(): after close() the SummaryWriter opens a fresh event file on
            # the next add_scalar, leaving one file per logging step.
            writer.flush()

    if running_x == 0:
        raise ValueError(f"The {phase} dataloader yielded no samples")
    # Averages over the samples actually seen (equals the dataset size when every index is
    # batched exactly once).
    return running_loss / running_x, running_acc / running_x


def train_model(model, device, criterion, optimizer, scheduler, dataloaders, num_epochs=1, experiment_name=None):
    """Train ``model``, keeping (and checkpointing) the weights with the best validation accuracy.

    Each epoch runs a "train" pass and then a "validation" pass. Whenever validation accuracy
    improves, a checkpoint with the model/optimizer/scheduler state is written under
    ``checkpoints/``. After every validation pass the scheduler is stepped with the accuracy.

    Args:
        model: module called as ``model(**batch)``.
        device: device the model and batches are moved to.
        criterion: loss applied to ``(output, target)``.
        optimizer: optimizer over the model's parameters.
        scheduler: scheduler whose ``step`` takes the validation accuracy; its ``last_epoch``,
            ``num_bad_epochs`` and ``patience`` are read (as on ``ReduceLROnPlateau``).
        dataloaders: dict with ``"train"`` and ``"validation"`` loaders yielding ``(batch, target)``.
        num_epochs: number of epochs to run.
        experiment_name: currently unused.

    Returns:
        ``(model, best_validation_acc, best_epoch)``:
        - ``model``: the model with the best weights loaded.
        - ``best_validation_acc``: the best validation accuracy.
        - ``best_epoch``: ``scheduler.last_epoch`` when that accuracy was reached, or ``None`` if
          validation accuracy never rose above 0.
    """
    import os  # local import: the notebook only imports `join` from os.path

    model = model.to(device)
    start = time.time()
    best_model_wts = deepcopy(model.state_dict())
    best_validation_acc = 0.0
    # Was previously unbound (UnboundLocalError at the return) when no epoch improved.
    best_epoch = None
    # Checkpoint
    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories
    checkpoint_filename = f"{model.__class__.__name__}_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
    checkpoint_filepath = join(checkpoint_dir, checkpoint_filename)
    print(f"Checkpoints stored at {checkpoint_filepath}")

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")

        # Each epoch has a training and validation phase
        for phase in ["train", "validation"]:
            epoch_loss, epoch_acc = _run_phase(
                model, device, criterion, optimizer, dataloaders[phase], phase, epoch
            )

            # Logging
            print(f"{phase.title()} loss: {epoch_loss}")
            print(f"{phase.title()} acc = {100 * epoch_acc}%")
            if phase != "validation":
                continue

            # Deepcopy if model is good
            if epoch_acc > best_validation_acc:
                print(f"New best model with ~{round(100 * epoch_acc, 4)}% acc ({epoch_acc})")
                best_validation_acc = epoch_acc
                # Scheduler step counter at this point (this epoch's scheduler.step has not run yet)
                best_epoch = scheduler.last_epoch
                best_model_wts = deepcopy(model.state_dict())
                torch.save({
                    "epoch": best_epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "validation_accuracy": best_validation_acc,
                }, checkpoint_filepath)
                print(f"Saved model at {checkpoint_filepath}")

            # Scheduler step on validation accuracy
            print(f"Scheduler: {scheduler.num_bad_epochs} bad epoch(s) (patience={scheduler.patience})")
            scheduler.step(epoch_acc)

        print()

    elapsed = time.time() - start
    print(f"Training completed in {elapsed // 60:.0f}m {elapsed % 60:.0f}s")
    print(f"Best validation accuracy: ~{round(100 * best_validation_acc, 4)}%")

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_validation_acc, best_epoch

## Training (3): Train the model, keeping and saving the best weights

In [None]:
# Train, ending with the weights that achieved the best validation accuracy loaded into `model`.
model, validation_accuracy, best_epoch = train_model(
    model=model,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    dataloaders={"train": training_dataloader, "validation": validation_dataloader},
    num_epochs=SETTINGS["training:num_epochs"],
)

In [None]:
# Save a general checkpoint (usable for inference or for resuming training), following
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
checkpoint_filename = f"{model.__class__.__name__}_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
checkpoint_path = join("checkpoints", checkpoint_filename)
torch.save(
    {
        "best_epoch": best_epoch,
        "epoch": scheduler.last_epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "validation_accuracy": validation_accuracy,
    },
    checkpoint_path,
)
print("Saved model")