In [None]:
from pprint import pprint
from tqdm import tqdm
from random import randint, choices

import numpy as np
import pandas as pd
import time
import math

from torch import nn
from torch.cuda.amp import GradScaler
from torch.utils.data import Dataset, DataLoader
import torch as T
import torch.nn.functional as F

from pytorch_metric_learning import losses

from RecommendationTransformer import RecommendationTransformer

In [None]:
class OutfitDataset(Dataset):
    """Map-style dataset that turns outfit rows into shifted input/target sequences.

    Every sample is laid out as
    ``[context items..., 'start', recommended items..., 'pad', ...]`` and padded to
    ``seq_len + 1`` entries, so the input (all but the last entry) and the target
    (all but the first entry) both have exactly ``seq_len`` positions.

    Args:
        recDataset: one row per outfit; each cell is a label into the embedding
            frame's index, missing cells mark unused slots.
        contextDataset: same layout for the context items; row ``i`` pairs with
            row ``i`` of ``recDataset``.
        EmbeddingDataset: frame with an ``'embedding'`` column, indexed by the
            labels used in the two frames above. It is copied, not modified.
    """

    # Token-type ids written into ``tkn_mask``.
    _CONTEXT_TKN, _START_TKN, _REC_TKN, _PAD_TKN = 0, 1, 2, 3

    def __init__(self, recDataset: pd.DataFrame, contextDataset: pd.DataFrame, EmbeddingDataset: pd.DataFrame):
        self.recs = recDataset.to_numpy()
        self.contexts = contextDataset.to_numpy()

        # Work on a copy so the caller's frame does not silently gain the
        # synthetic 'pad' / 'start' rows.
        self.embeddings = EmbeddingDataset.copy()
        template = self.embeddings.iloc[0]['embedding']
        for special in ('pad', 'start'):
            # Both special tokens are represented by an all-zero vector.
            self.embeddings.loc[special] = pd.Series({'embedding': np.zeros_like(template)})

        self.len = self.recs.shape[0]
        # +1 accounts for the 'start' token between context and recommendations.
        self.seq_len = self.contexts.shape[1] + self.recs.shape[1] + 1
        self.noneType = type(None)

    def __len__(self):
        return self.len

    @staticmethod
    def _present_shuffled(row: np.ndarray) -> np.ndarray:
        """Return the non-missing entries of ``row`` in random order, as a new array."""
        # Boolean indexing copies, so the stored row itself is never shuffled.
        present = row[pd.notna(row)]
        np.random.shuffle(present)
        return present

    def __getitem__(self, i):
        """Build sample ``i``.

        Returns:
            Tuple ``(src, tkn_mask, pad_mask, target, tgt_done)`` of NumPy arrays,
            each with ``seq_len`` entries along the first axis:
            input embeddings (float32), token-type ids (int32), padding flags
            (bool), next-step embeddings (float32) and next-step-is-padding
            flags (float32).
        """
        rec = self._present_shuffled(self.recs[i])
        context = self._present_shuffled(self.contexts[i])
        n_ctx, n_rec = len(context), len(rec)

        pad_len = self.seq_len - n_ctx - n_rec

        labels = [*context, 'start', *rec, *(['pad'] * pad_len)]
        src = np.array(list(self.embeddings.loc[labels, 'embedding']), dtype=np.float32)

        tkn_mask = np.array(
            [self._CONTEXT_TKN] * n_ctx
            + [self._START_TKN]
            + [self._REC_TKN] * n_rec
            + [self._PAD_TKN] * pad_len,
            dtype=np.int32,
        )
        # True exactly on the padding positions.
        pad_mask = (tkn_mask == self._PAD_TKN).astype(np.bool_)
        # The "done" target for a position is whether the *next* token is padding.
        tgt_done = pad_mask.astype(np.float32)

        return src[:-1], tkn_mask[:-1], pad_mask[:-1], src[1:], tgt_done[1:]

In [None]:
# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = T.device('cuda' if T.cuda.is_available() else 'cpu')

# NOTE: the three paths below are identical placeholders; point each one at its
# own feather file before running.
recDataset = pd.read_feather('path_to_my_dataset')
contextDataset = pd.read_feather('path_to_my_dataset')
EmbeddingDataset = pd.read_feather('path_to_my_dataset')

dataset = OutfitDataset(recDataset, contextDataset, EmbeddingDataset)
train_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
# Place the model on the selected device instead of calling .cuda() unconditionally,
# which raises on a machine without CUDA even though `device` fell back to 'cpu'.
RecModel = RecommendationTransformer().to(device)

In [None]:
# Pull a single batch and eyeball dtypes, shapes and the first sample of each tensor.
src, tkn_mask, pad_mask, target, tgt_done = next(iter(train_dataloader))
print(*[t.dtype for t in (src, tkn_mask, pad_mask, target, tgt_done)])
print(*[t.shape for t in (src, tkn_mask, pad_mask, target, tgt_done)])
# For the embedding tensors only the first feature of every position is shown.
print(
    src[0, :, :1],
    tkn_mask[0, :],
    pad_mask[0, :],
    target[0, :, :1],
    tgt_done[0, :],
    sep='\n',
)

In [None]:
# Optimisation hyper-parameters.
lr = 1e-4  # handed to AdamW and reused as OneCycleLR's max_lr
epochs = 300
num_batches = len(train_dataloader)

# Objectives: NT-Xent on the embeddings, binary cross-entropy on the logits.
criterion1 = losses.NTXentLoss(temperature=0.07)
# Alternative embedding loss kept for reference:
# criterion1 = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y), margin=0.0)
criterion2 = nn.BCEWithLogitsLoss()

optimizer = T.optim.AdamW(RecModel.parameters(), lr=lr)
scaler = GradScaler()
# The schedule spans the whole run: one scheduler step per batch.
scheduler = T.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    total_steps=epochs * num_batches,
)

In [None]:
def train(model: nn.Module, epoch=None) -> None:
    """Run one optimisation pass over ``train_dataloader``.

    Relies on the notebook-level ``dataset``, ``train_dataloader``, ``device``,
    ``optimizer``, ``scaler``, ``scheduler``, ``criterion1`` and ``criterion2``.

    Args:
        model: network called as
            ``model(src, tkn_mask, src_mask=..., src_key_padding_mask=...)`` and
            expected to return ``(output, is_done)``.
        epoch: epoch number shown in the progress bar; ``None`` is shown as ``?``.
    """
    model.train()  # turn on train mode
    # Causal mask over the seq_len input positions; identical for every batch.
    src_mask = nn.Transformer.generate_square_subsequent_mask(dataset.seq_len).to(device)

    num_batches = len(train_dataloader)
    # Formatting None with ':3d' raises, so resolve the label once up front.
    epoch_label = '  ?' if epoch is None else f'{epoch:3d}'

    pbar = tqdm(enumerate(train_dataloader), total=num_batches)
    batch_start = time.time()
    for i, (src, tkn_mask, pad_mask, target, tgt_done) in pbar:
        optimizer.zero_grad()

        src         = src.to(device)
        tkn_mask    = tkn_mask.to(device)
        pad_mask    = pad_mask.to(device)
        target      = target.to(device)
        tgt_done    = tgt_done.to(device)

        with T.autocast(device_type='cuda', dtype=T.float16):
            output, is_done = model(
                src,
                tkn_mask,
                src_mask=src_mask,
                src_key_padding_mask=pad_mask
            )

            # Stack predictions on top of targets along dim 0 and give row k of
            # each half the same label, so prediction k and target k form the
            # positive pair for the contrastive loss.
            embeddings = T.cat((output, target))
            indices = T.arange(0, output.size(0), device=output.device)
            labels = T.cat((indices, indices))

            # Each sample's whole sequence is flattened into one vector.
            loss1 = criterion1(embeddings.view(embeddings.shape[0], -1), labels)
            loss2 = criterion2(is_done, tgt_done)
            loss = loss1 + loss2

        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        T.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        scheduler.step()

        lr = scheduler.get_last_lr()[0]
        ms_per_batch = (time.time() - batch_start) * 1000
        pbar.set_postfix_str(f'| epoch {epoch_label} | {i:5d}/{num_batches:5d} batches | '
            f'lr {lr} | ms/batch {ms_per_batch:5.2f} | '
            f'loss1 {loss1.item()} | '
            f'loss2 {loss2.item()} | ')
        batch_start = time.time()

In [None]:
# Optional resume: uncomment to restore a previous run before training.
# RecModel.load_state_dict(T.load('./model_params.pt'))   # load model weights
# optimizer.load_state_dict(T.load('./optim_params.pt'))  # load optimizer state

# Train for `epochs` passes, overwriting both checkpoint files after every epoch
# so the most recent state survives an interrupted run.
# NOTE(review): scheduler and GradScaler state are not checkpointed, so a resumed
# run would restart the one-cycle schedule — confirm whether that is intended.
for epoch in range(1, epochs + 1):
    train(RecModel, epoch)
    T.save(RecModel.state_dict(), './model_params.pt')
    T.save(optimizer.state_dict(), './optim_params.pt')