In [112]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from torch.nn.utils.rnn import pad_sequence
from utils.reparam_module import ReparamModule
import numpy as np
import random
from tqdm import tqdm

# Environment

In [113]:
# One shared seed for the Python, NumPy and PyTorch (CPU + current CUDA device)
# random number generators.
seed = 2024

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True

# Dataset

In [114]:
# Id layout: 0 is used as the padding value when sequences are batched, and the
# two ids directly after `num_item` mark the start / end of a sequence.
num_item = 11925
SOS, EOS = num_item, num_item + 1

# `data` is iterated below as (source_list, target_list) pairs.
# NOTE: torch.load unpickles the file, so only load files from a trusted source.
path = './toys-pair.pth'
data = torch.load(path)

In [115]:
# Turn the (source, target) id lists in `data` into two zero-padded id matrices
# plus the per-example lengths.
TARGET_MIN_WIDTH = 20  # `target` is right-padded with zeros to at least this many columns

source, target = [], []
source_seqlen, target_seqlen = [], []
# Unpack directly in the loop header: binding `_` would shadow IPython's
# "last output" variable for the rest of the session.
for s, t in data:
    # Lengths are recorded before the SOS/EOS markers are added.
    source_seqlen.append(len(s))
    target_seqlen.append(len(t))
    source.append(torch.tensor([SOS] + s + [EOS]))
    target.append(torch.tensor([SOS] + t + [EOS]))
source = pad_sequence(source, batch_first=True, padding_value=0)
target = pad_sequence(target, batch_first=True, padding_value=0)
if target.shape[1] < TARGET_MIN_WIDTH:
    # new_zeros keeps the dtype/device of `target`, so the concatenation never
    # mixes int32 with int64.
    filler = target.new_zeros(target.shape[0], TARGET_MIN_WIDTH - target.shape[1])
    target = torch.cat([target, filler], dim=-1)
source_seqlen = torch.tensor(source_seqlen)
target_seqlen = torch.tensor(target_seqlen)

In [116]:
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    """Index-aligned view over source/target id matrices and their lengths.

    All four tensors are moved to ``device`` once, at construction time, so
    ``__getitem__`` only performs indexing.

    Args:
        source: tensor indexed by example along dim 0.
        target: tensor indexed by example along dim 0.
        source_seqlen: per-example source lengths.
        target_seqlen: per-example target lengths.
        device: where the tensors are stored. Defaults to ``'cuda'``, which
            matches the previous hard-coded behaviour.
    """

    def __init__(
            self,
            source,
            target,
            source_seqlen,
            target_seqlen,
            device='cuda',
        ) -> None:
        super().__init__()
        self.source = source.to(device)
        self.target = target.to(device)
        self.source_seq_len = source_seqlen.to(device)
        self.target_seq_len = target_seqlen.to(device)

    def __len__(self):
        """Number of examples, taken from the first dimension of ``source``."""
        return len(self.source)

    def __getitem__(self, index):
        """Return ``(src, tgt, src_len, tgt_len)`` for example ``index``."""
        return (
            self.source[index],
            self.target[index],
            self.source_seq_len[index],
            self.target_seq_len[index],
        )

# Model

In [117]:
from utils import normal_initialization

class Generator(nn.Module):
    """Encoder-decoder Transformer over item-id sequences.

    The item embedding table is used both to look up input tokens and to score
    outputs: ``forward`` returns the decoder states multiplied by
    ``item_embedding.weight.T``.

    The learned position table has 50 rows, so inputs wider than 50 positions
    cannot be embedded.
    """

    def __init__(self) -> None:
        super().__init__()
        # `item_embedding` is not created here; load_pretrained() builds it.
        self.transformer = nn.Transformer(
            d_model=64,
            nhead=2,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.5)
        self.position_embedding = torch.nn.Embedding(50, 64)
        # Kept so existing readers of `.device` keep working; the methods below
        # take the device from their input tensors instead.
        self.device = 'cuda'
        # Runs before load_pretrained(), so the item embeddings created there
        # are not touched by this initializer.
        self.apply(normal_initialization)
        self.load_pretrained()

    def load_pretrained(self):
        """Create ``item_embedding`` from a saved checkpoint.

        Reads ``saved['parameters']['item_embedding.weight']``, appends two
        rows drawn from a normal distribution (std 0.02), and registers the
        result as a trainable embedding with ``padding_idx=0``.
        """
        path = 'saved/SASRec8/amazon-toys-seq-noise-50/2024-01-24-16-37-41-603118.ckpt'
        saved = torch.load(path, map_location='cpu')
        pretrained = saved['parameters']['item_embedding.weight']
        # The two extra rows are presumably for the SOS / EOS ids -- confirm
        # against the vocabulary layout. Assumes the checkpoint rows are 64-wide.
        pretrained = torch.cat([
            pretrained,
            nn.init.normal_(torch.zeros(2, 64), std=0.02)
        ])
        self.item_embedding = nn.Embedding.from_pretrained(pretrained, padding_idx=0, freeze=False)

    def _embed(self, tokens):
        """Item embedding + position embedding, followed by dropout."""
        # Build the position ids on the same device as the ids being embedded.
        position_ids = torch.arange(tokens.size(1), dtype=torch.long, device=tokens.device)
        position_ids = position_ids.reshape(1, -1)
        return self.dropout(self.item_embedding(tokens) + self.position_embedding(position_ids))

    def forward(self, src, tgt, src_mask, tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask
        ):
        """Score every item for each target position.

        Args:
            src, tgt: id tensors; dim 1 is the sequence dimension.
            src_mask, tgt_mask: attention masks passed to ``nn.Transformer``.
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask:
                key-padding masks passed to ``nn.Transformer``.

        Returns:
            Decoder outputs projected onto the item embedding table.
        """
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        outs = self.transformer(
            src_emb, tgt_emb, src_mask, tgt_mask, None,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
        )

        return outs @ self.item_embedding.weight.T

    def encode(self, src, src_mask, src_padding_mask=None):
        """Run only the encoder stack on ``src``.

        ``src_padding_mask`` is optional; when omitted the call behaves exactly
        as before (no key-padding mask is applied).
        """
        return self.transformer.encoder(
            self._embed(src), mask=src_mask, src_key_padding_mask=src_padding_mask
        )

    def decode(self, tgt, memory, tgt_mask, tgt_padding_mask=None, memory_key_padding_mask=None):
        """Run only the decoder stack on ``tgt`` against the encoder ``memory``.

        The two padding masks are optional; when omitted the call behaves
        exactly as before.
        """
        return self.transformer.decoder(
            self._embed(tgt), memory, tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

def generate_square_subsequent_mask(sz, device='cuda'):
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` (position ``i`` may attend to
    ``j``) and ``-100000`` otherwise.

    Args:
        sz: side length of the square mask.
        device: device for the result. Defaults to ``'cuda'``, which matches
            the previous hard-coded behaviour.

    Returns:
        A float32 tensor of shape ``(sz, sz)``.
    """
    allowed = torch.tril(torch.ones((sz, sz), dtype=torch.bool, device=device))
    mask = torch.zeros((sz, sz), dtype=torch.float, device=device)
    # A large finite negative is used rather than -inf.
    return mask.masked_fill(~allowed, -100000)

def create_mask(src, tgt):
    """Build the four masks used for one (src, tgt) batch.

    Returns, in order:
        src_mask: all-False boolean square mask sized by ``src.shape[1]``.
        tgt_mask: the mask from ``generate_square_subsequent_mask`` sized by
            ``tgt.shape[1]``.
        src_padding_mask: True where ``src`` equals 0.
        tgt_padding_mask: True where ``tgt`` equals 0.
    """
    src_width, tgt_width = src.shape[1], tgt.shape[1]

    # Nothing is masked on the source side.
    src_mask = torch.zeros((src_width, src_width), device='cuda').type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_width)

    # Id 0 is the padding value.
    return src_mask, tgt_mask, src.eq(0), tgt.eq(0)


# Train

In [118]:
# Objects shared by the training cells below.
dataset = MyDataset(source, target, source_seqlen, target_seqlen)

model = Generator()
model = model.to('cuda')

# Positions whose target id is 0 are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.98),
    eps=1e-9,
)

In [119]:
def train_epoch(model, optimizer):
    """Train ``model`` for one shuffled pass over the global ``dataset``.

    Each batch uses teacher forcing: the decoder input is the target without
    its last column, and the loss is computed against the target without its
    first column.

    Args:
        model: module called as ``model(src, tgt_input, *masks)``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    train_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    total_loss = 0.0
    # The length tensors yielded by the dataset are not needed here.
    for src, tgt, _src_len, _tgt_len in tqdm(train_dataloader):
        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        optimizer.zero_grad()
        # The source padding mask doubles as the memory key-padding mask.
        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # len() of the loader is the batch count; wrapping it in list() would run a
    # second full pass over the data just to count batches.
    return total_loss / len(train_dataloader)

In [120]:
from timeit import default_timer as timer
NUM_EPOCHS = 50

# Train for NUM_EPOCHS epochs, printing the mean loss and wall-clock time of each.
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()

    epoch_seconds = end_time - start_time
    report = f"Epoch: {epoch}, Train loss: {train_loss:.3f}, " + f"Epoch time = {epoch_seconds:.3f}s"
    print(report)

100%|██████████| 127/127 [00:37<00:00,  3.37it/s]


Epoch: 1, Train loss: 5.834, Epoch time = 38.203s


100%|██████████| 127/127 [00:37<00:00,  3.39it/s]


Epoch: 2, Train loss: 3.619, Epoch time = 37.899s


100%|██████████| 127/127 [00:37<00:00,  3.38it/s]


Epoch: 3, Train loss: 3.227, Epoch time = 38.082s


100%|██████████| 127/127 [00:37<00:00,  3.40it/s]


Epoch: 4, Train loss: 2.995, Epoch time = 37.876s


100%|██████████| 127/127 [00:37<00:00,  3.38it/s]


Epoch: 5, Train loss: 2.824, Epoch time = 38.029s


100%|██████████| 127/127 [00:37<00:00,  3.41it/s]


Epoch: 6, Train loss: 2.693, Epoch time = 37.755s


100%|██████████| 127/127 [00:37<00:00,  3.41it/s]


Epoch: 7, Train loss: 2.588, Epoch time = 37.688s


100%|██████████| 127/127 [00:38<00:00,  3.32it/s]


Epoch: 8, Train loss: 2.508, Epoch time = 38.713s


100%|██████████| 127/127 [00:37<00:00,  3.37it/s]


Epoch: 9, Train loss: 2.449, Epoch time = 38.136s


100%|██████████| 127/127 [00:37<00:00,  3.40it/s]


Epoch: 10, Train loss: 2.385, Epoch time = 37.802s


100%|██████████| 127/127 [00:37<00:00,  3.42it/s]


Epoch: 11, Train loss: 2.338, Epoch time = 37.736s


100%|██████████| 127/127 [00:37<00:00,  3.41it/s]


Epoch: 12, Train loss: 2.296, Epoch time = 37.681s


100%|██████████| 127/127 [00:37<00:00,  3.41it/s]


Epoch: 13, Train loss: 2.262, Epoch time = 37.670s


100%|██████████| 127/127 [00:37<00:00,  3.39it/s]


Epoch: 14, Train loss: 2.238, Epoch time = 37.924s


100%|██████████| 127/127 [00:37<00:00,  3.37it/s]


Epoch: 15, Train loss: 2.214, Epoch time = 38.392s


100%|██████████| 127/127 [00:38<00:00,  3.33it/s]


Epoch: 16, Train loss: 2.189, Epoch time = 38.661s


100%|██████████| 127/127 [00:38<00:00,  3.34it/s]


Epoch: 17, Train loss: 2.174, Epoch time = 38.719s


100%|██████████| 127/127 [00:37<00:00,  3.42it/s]


Epoch: 18, Train loss: 2.154, Epoch time = 37.613s


100%|██████████| 127/127 [00:30<00:00,  4.13it/s]


Epoch: 19, Train loss: 2.135, Epoch time = 31.181s


100%|██████████| 127/127 [00:30<00:00,  4.21it/s]


Epoch: 20, Train loss: 2.125, Epoch time = 30.626s


100%|██████████| 127/127 [00:30<00:00,  4.19it/s]


Epoch: 21, Train loss: 2.100, Epoch time = 30.788s


100%|██████████| 127/127 [00:29<00:00,  4.25it/s]


Epoch: 22, Train loss: 2.098, Epoch time = 30.379s


100%|██████████| 127/127 [00:30<00:00,  4.20it/s]


Epoch: 23, Train loss: 2.089, Epoch time = 30.715s


100%|██████████| 127/127 [00:30<00:00,  4.22it/s]


Epoch: 24, Train loss: 2.075, Epoch time = 30.554s


100%|██████████| 127/127 [00:29<00:00,  4.26it/s]


Epoch: 25, Train loss: 2.066, Epoch time = 30.492s


100%|██████████| 127/127 [00:30<00:00,  4.21it/s]


Epoch: 26, Train loss: 2.053, Epoch time = 30.885s


100%|██████████| 127/127 [00:30<00:00,  4.10it/s]


Epoch: 27, Train loss: 2.049, Epoch time = 31.430s


100%|██████████| 127/127 [00:30<00:00,  4.15it/s]


Epoch: 28, Train loss: 2.038, Epoch time = 31.048s


100%|██████████| 127/127 [00:30<00:00,  4.12it/s]


Epoch: 29, Train loss: 2.036, Epoch time = 31.298s


100%|██████████| 127/127 [00:30<00:00,  4.14it/s]


Epoch: 30, Train loss: 2.030, Epoch time = 31.127s


100%|██████████| 127/127 [00:30<00:00,  4.15it/s]


Epoch: 31, Train loss: 2.018, Epoch time = 31.038s


100%|██████████| 127/127 [00:31<00:00,  4.09it/s]


Epoch: 32, Train loss: 2.015, Epoch time = 31.491s


100%|██████████| 127/127 [00:30<00:00,  4.14it/s]


Epoch: 33, Train loss: 2.009, Epoch time = 31.133s


100%|██████████| 127/127 [00:30<00:00,  4.13it/s]


Epoch: 34, Train loss: 2.004, Epoch time = 31.207s


100%|██████████| 127/127 [00:30<00:00,  4.18it/s]


Epoch: 35, Train loss: 2.002, Epoch time = 30.838s


100%|██████████| 127/127 [00:30<00:00,  4.16it/s]


Epoch: 36, Train loss: 2.001, Epoch time = 31.129s


100%|██████████| 127/127 [00:30<00:00,  4.12it/s]


Epoch: 37, Train loss: 1.984, Epoch time = 31.265s


100%|██████████| 127/127 [00:30<00:00,  4.16it/s]


Epoch: 38, Train loss: 1.982, Epoch time = 31.026s


100%|██████████| 127/127 [00:30<00:00,  4.19it/s]


Epoch: 39, Train loss: 1.977, Epoch time = 30.997s


100%|██████████| 127/127 [00:30<00:00,  4.18it/s]


Epoch: 40, Train loss: 1.972, Epoch time = 30.824s


100%|██████████| 127/127 [00:30<00:00,  4.14it/s]


Epoch: 41, Train loss: 1.970, Epoch time = 31.330s


100%|██████████| 127/127 [00:30<00:00,  4.11it/s]


Epoch: 42, Train loss: 1.958, Epoch time = 31.369s


100%|██████████| 127/127 [00:30<00:00,  4.16it/s]


Epoch: 43, Train loss: 1.961, Epoch time = 31.026s


100%|██████████| 127/127 [00:30<00:00,  4.16it/s]


Epoch: 44, Train loss: 1.959, Epoch time = 31.185s


100%|██████████| 127/127 [00:31<00:00,  4.09it/s]


Epoch: 45, Train loss: 1.952, Epoch time = 31.501s


100%|██████████| 127/127 [00:30<00:00,  4.15it/s]


Epoch: 46, Train loss: 1.952, Epoch time = 31.068s


100%|██████████| 127/127 [00:30<00:00,  4.20it/s]


Epoch: 47, Train loss: 1.945, Epoch time = 30.732s


100%|██████████| 127/127 [00:30<00:00,  4.21it/s]


Epoch: 48, Train loss: 1.944, Epoch time = 30.836s


100%|██████████| 127/127 [00:31<00:00,  4.09it/s]


Epoch: 49, Train loss: 1.942, Epoch time = 31.554s


100%|██████████| 127/127 [00:30<00:00,  4.19it/s]


Epoch: 50, Train loss: 1.942, Epoch time = 30.778s


In [121]:
# Persist the whole trained module object (not only its state_dict).
torch.save(obj=model, f='translator-toy.pth')

In [130]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode one sequence by repeatedly taking the highest-scoring next id.

    Args:
        model: object exposing ``encode``, ``decode`` and ``item_embedding``.
        src: source ids; must hold a single sequence, because the chosen id is
            read with ``.item()``.
        src_mask: attention mask passed to ``model.encode``.
        max_len: upper bound on the length of the returned sequence, including
            the start symbol.
        start_symbol: id placed at the first position of the output.

    Returns:
        A ``(1, L)`` long tensor on CUDA that starts with ``start_symbol`` and
        stops after ``EOS`` or once ``L == max_len``.

    NOTE(review): ``model.encode`` is called without a key-padding mask, so any
    padded source positions are attended to here -- confirm this is intended.
    """
    src = src.to('cuda')
    src_mask = src_mask.to('cuda')

    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        # The encoder output does not change between steps; compute it once.
        memory = model.encode(src, src_mask).to('cuda')
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device='cuda')
        for _step in range(max_len - 1):
            # Boolean form of the causal mask: True marks blocked positions.
            tgt_mask = (generate_square_subsequent_mask(ys.size(1))
                        .type(torch.bool)).to('cuda')
            out = model.decode(ys, memory, tgt_mask)
            # Score the last position against every item embedding.
            logits = out[:, -1] @ model.item_embedding.weight.T
            _, next_word = torch.max(logits, dim=-1)
            next_word = next_word.item()

            # new_full keeps the dtype and device of `ys`.
            ys = torch.cat([ys, ys.new_full((1, 1), next_word)], dim=1)
            if next_word == EOS:
                break
    return ys

def translate(model: torch.nn.Module, src):
    """Greedy-decode a single source sequence and return the ids as a 1-D tensor.

    Puts ``model`` in eval mode, reshapes ``src`` to one row, and decodes at
    most 25 positions starting from ``SOS`` with an all-False source mask.
    """
    model.eval()
    src_row = src.reshape(1, -1)
    width = src_row.shape[1]
    # All-False mask: no source position is blocked.
    open_mask = torch.zeros(width, width).type(torch.bool)
    decoded = greedy_decode(model, src_row, open_mask, max_len=25, start_symbol=SOS)
    return decoded.flatten()
idx = 27
# Hand-written probe: five item ids followed by 45 padding zeros.
probe = torch.tensor([1, 2, 3, 4, 5] + 45 * [0], device='cuda')
rst = translate(model, probe)
print(rst)
# NOTE(review): the row printed below is the target of dataset example `idx`,
# while the decoded input is the hand-written probe above, so the two printed
# tensors are not a source/target pair.
print(target[idx])

tensor([11925,  4644,  4644, 11926], device='cuda:0')
tensor([11925,  7170,  7173, 11926,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])


In [100]:
# Display the decoded sequence currently held in `rst`.
# NOTE(review): this cell's execution count (100) is lower than the cell that
# assigns `rst` (130), so the output shown may come from an earlier run.
rst

tensor([11925,     1,  2296, 11926], device='cuda:0')

In [101]:
# Display the padded source row of dataset example 7.
source[7]

tensor([11925,  3030,  1664,  4662,  7233,     1,  2296,  3372, 11926,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])