In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from torch.nn.utils.rnn import pad_sequence
from utils.reparam_module import ReparamModule
import numpy as np
import random
from tqdm import tqdm

# Environment

In [37]:
# Seed every random number generator used by the notebook so runs are repeatable.
seed = 2024

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Seed all visible GPUs, not only the current one (silently ignored without CUDA).
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# cuDNN benchmarking can select different algorithms from run to run, so it is
# switched off alongside `deterministic = True`.
torch.backends.cudnn.benchmark = False

# Dataset

In [38]:
# Which dataset to use and how many items it contains.
dataset_name = 'toy'
full_dataset_name = 'amazon-toys'
num_item_dict = dict(toy=11925, sport=18358)

# Two extra ids right after the item range mark sequence start and end.
num_item = num_item_dict[dataset_name]
SOS, EOS = num_item, num_item + 1

# Pair file for the selected dataset; iterated as (source, target) pairs below.
path = f'./{dataset_name}-pair.pth'
data = torch.load(path)

In [39]:
# Turn the raw (source, target) id lists into right-padded id matrices.
# The target matrix is padded out to at least this many columns.
MIN_TARGET_LEN = 20

source, target = [], []
source_seqlen, target_seqlen = [], []
for s, t in data:
    # Lengths are recorded before SOS/EOS are added.
    source_seqlen.append(len(s))
    target_seqlen.append(len(t))
    source.append(torch.tensor([SOS] + s + [EOS]))
    target.append(torch.tensor([SOS] + t + [EOS]))

# 0 is the padding id (it is also the embedding padding_idx and the loss ignore_index).
source = pad_sequence(source, batch_first=True, padding_value=0)
target = pad_sequence(target, batch_first=True, padding_value=0)
if target.shape[1] < MIN_TARGET_LEN:
    # new_zeros keeps the dtype/device of `target`; the previous int32 block only
    # worked through implicit type promotion inside torch.cat.
    extra = target.new_zeros(target.shape[0], MIN_TARGET_LEN - target.shape[1])
    target = torch.cat([target, extra], dim=-1)
source_seqlen = torch.tensor(source_seqlen)
target_seqlen = torch.tensor(target_seqlen)

In [40]:
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    """Map-style dataset over pre-padded source/target tensors held on one device.

    Args:
        source: tensor indexed along dim 0, one row per example.
        target: tensor indexed along dim 0, one row per example.
        source_seqlen: per-example source lengths.
        target_seqlen: per-example target lengths.
        device: device all four tensors are moved to once, up front.
            Defaults to ``'cuda'`` (the previous hard-coded behaviour).
    """

    def __init__(
            self,
            source,
            target,
            source_seqlen,
            target_seqlen,
            device='cuda',
        ) -> None:
        super().__init__()
        # Moving everything once avoids a host-to-device copy per batch.
        self.source = source.to(device)
        self.target = target.to(device)
        self.source_seq_len = source_seqlen.to(device)
        self.target_seq_len = target_seqlen.to(device)

    def __len__(self):
        """Number of examples (rows of ``source``)."""
        return len(self.source)

    def __getitem__(self, index):
        """Return ``(src, tgt, src_len, tgt_len)`` for example ``index``."""
        return (
            self.source[index],
            self.target[index],
            self.source_seq_len[index],
            self.target_seq_len[index],
        )

# Model

In [41]:
from utils import normal_initialization

class Generator(nn.Module):
    """Encoder-decoder Transformer over item-id sequences.

    Input tokens (source and target side) share ``item_embedding`` plus a learned
    position embedding; output logits are the decoder states projected onto
    ``item_embedding_decoder`` and then restricted to ids present in the source.
    """

    # Default checkpoint whose ``item_embedding.weight`` initialises the input embedding.
    # Alternative: 'saved/SASRec8/amazon-toys-seq-noise-50/2024-01-24-16-37-41-603118.ckpt'
    PRETRAINED_PATH = 'saved/SASRec7/amazon-toys-seq-noise-50/2024-01-24-17-16-57-368371.ckpt'

    def __init__(self) -> None:
        super().__init__()
        # num_item real items plus the SOS and EOS ids; id 0 is padding.
        self.item_embedding = nn.Embedding(num_item + 2, 64, padding_idx=0)
        self.item_embedding_decoder = nn.Embedding(num_item + 2, 64, padding_idx=0)
        self.transformer = nn.Transformer(
            d_model=64,
            nhead=2,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.5)
        # Sequences longer than 50 tokens cannot be embedded.
        self.position_embedding = torch.nn.Embedding(50, 64)
        # Kept for backward compatibility; the methods below follow the input's device.
        self.device = 'cuda'
        self.apply(normal_initialization)
        # Must run after the generic initialisation so the pretrained weights survive.
        self.load_pretrained()

    def load_pretrained(self, path=None):
        """Replace ``item_embedding`` with pretrained weights plus two fresh rows.

        Args:
            path: checkpoint file; defaults to ``PRETRAINED_PATH``. The checkpoint
                must hold ``['parameters']['item_embedding.weight']`` with 64 columns.
        """
        if path is None:
            path = self.PRETRAINED_PATH
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        saved = torch.load(path, map_location='cpu')
        pretrained = saved['parameters']['item_embedding.weight']
        # Two extra randomly initialised rows are appended for the SOS/EOS ids.
        extra_rows = nn.init.normal_(torch.zeros(2, 64), std=0.02)
        pretrained = torch.cat([pretrained, extra_rows])
        self.item_embedding = nn.Embedding.from_pretrained(pretrained, padding_idx=0, freeze=False)

    def condition_mask(self, logits, src):
        """Set every logit whose id does not occur in ``src`` to ``-inf``.

        Args:
            logits: ``(batch, tgt_len, vocab)`` scores.
            src: ``(batch, src_len)`` source ids; every id in it (including any
                padding/SOS/EOS it contains) stays selectable.
        """
        allowed = torch.zeros_like(logits, device=logits.device, dtype=torch.bool)
        # Repeat the source ids for every target position, then mark them as allowed.
        index = src.unsqueeze(-2).repeat(1, allowed.shape[1], 1)
        allowed = allowed.scatter(-1, index, 1)
        return torch.masked_fill(logits, ~allowed, float('-inf'))

    def _embed(self, tokens):
        """Item embedding + position embedding for ``tokens`` ``(batch, len)``, with dropout."""
        position_ids = torch.arange(tokens.size(1), dtype=torch.long, device=tokens.device)
        position_embedding = self.position_embedding(position_ids.reshape(1, -1))
        return self.dropout(self.item_embedding(tokens) + position_embedding)

    def forward(self, src, tgt, src_mask, tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
            src_seqlen,
            tgt_seqlen,
        ):
        """Return source-conditioned logits of shape ``(batch, tgt_len, vocab)``.

        ``src_seqlen`` and ``tgt_seqlen`` are accepted for interface compatibility
        but are not used.
        """
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        # Fifth positional argument is memory_mask, which is left unset.
        outs = self.transformer(
            src_emb, tgt_emb, src_mask, tgt_mask, None,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
        )
        logits = outs @ self.item_embedding_decoder.weight.T
        return self.condition_mask(logits, src)

    def encode(self, src, src_mask):
        """Run only the encoder on ``src`` (no key-padding mask is applied)."""
        return self.transformer.encoder(self._embed(src), src_mask)

    def decode(self, tgt, memory, tgt_mask):
        """Run only the decoder on ``tgt`` against encoder output ``memory``."""
        return self.transformer.decoder(self._embed(tgt), memory, tgt_mask)

def generate_square_subsequent_mask(sz, device='cuda'):
    """Additive causal mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-100000`` otherwise, so a
    position cannot attend to later positions.

    Args:
        sz: sequence length.
        device: device the mask is created on (defaults to ``'cuda'`` as before).
    """
    # Strict upper triangle = positions that lie in the future.
    future = torch.triu(torch.ones((sz, sz), device=device), diagonal=1) == 1
    return torch.zeros((sz, sz), device=device).masked_fill(future, -100000)

def create_mask(src, tgt):
    """Build the attention and padding masks for a ``(src, tgt)`` batch.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False ``(src_len, src_len)`` bool mask,
        ``tgt_mask`` is the additive causal mask, and the padding masks are
        True where the token id equals 0. Masks live on the inputs' devices.
    """
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device=tgt.device)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=src.device).type(torch.bool)

    src_padding_mask = (src == 0)
    tgt_padding_mask = (tgt == 0)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


# Train

In [42]:
# Build the model on the GPU; constructing Generator also loads the pretrained item embeddings.
model = Generator().to('cuda')
# Id 0 is the padding value of `source`/`target`, so padded target positions are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
# Wraps the padded tensors built above; MyDataset moves them to the GPU once.
dataset = MyDataset(source, target, source_seqlen, target_seqlen)

In [43]:
def train_epoch(model, optimizer, batch_size=256):
    """Train ``model`` for one pass over the module-level ``dataset``.

    Uses teacher forcing: the decoder input is the target without its last
    token and the loss target is the target without its first token.

    Args:
        model: the Generator to train (switched to train mode).
        optimizer: optimizer over ``model``'s parameters.
        batch_size: mini-batch size (default 256, the previous fixed value).

    Returns:
        Mean of the per-batch losses (computed with the module-level ``loss_fn``).
    """
    model.train()
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    total_loss = 0.0
    for src, tgt, src_seqlen, tgt_seqlen in tqdm(train_dataloader):
        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        optimizer.zero_grad()
        # The source padding mask doubles as the memory key-padding mask.
        logits = model(
            src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, src_padding_mask,
            src_seqlen, tgt_seqlen)
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # len(DataLoader) is the batch count; materialising the loader with list()
    # re-ran (and re-shuffled) the whole dataset just to count batches.
    return total_loss / len(train_dataloader)

In [45]:
from timeit import default_timer as timer
NUM_EPOCHS = 10

# One line per epoch: mean training loss and wall-clock duration.
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()
    elapsed = end_time - start_time
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time = {elapsed:.3f}s")

100%|██████████| 153/153 [00:16<00:00,  9.08it/s]


Epoch: 1, Train loss: 1.096, Epoch time = 17.350s


100%|██████████| 153/153 [00:16<00:00,  9.11it/s]


Epoch: 2, Train loss: 1.080, Epoch time = 17.295s


100%|██████████| 153/153 [00:16<00:00,  9.22it/s]


Epoch: 3, Train loss: 1.067, Epoch time = 17.102s


100%|██████████| 153/153 [00:16<00:00,  9.50it/s]


Epoch: 4, Train loss: 1.054, Epoch time = 16.505s


100%|██████████| 153/153 [00:16<00:00,  9.38it/s]


Epoch: 5, Train loss: 1.041, Epoch time = 16.678s


100%|██████████| 153/153 [00:16<00:00,  9.26it/s]


Epoch: 6, Train loss: 1.033, Epoch time = 16.962s


100%|██████████| 153/153 [00:16<00:00,  9.36it/s]


Epoch: 7, Train loss: 1.023, Epoch time = 16.808s


100%|██████████| 153/153 [00:16<00:00,  9.54it/s]


Epoch: 8, Train loss: 1.017, Epoch time = 16.621s


100%|██████████| 153/153 [00:16<00:00,  9.56it/s]


Epoch: 9, Train loss: 1.010, Epoch time = 16.423s


100%|██████████| 153/153 [00:16<00:00,  9.51it/s]


Epoch: 10, Train loss: 1.005, Epoch time = 16.460s


In [46]:
# Persist the trained weights (state_dict only) under a per-dataset file name.
torch.save(model.state_dict(), f'translator-{dataset_name}.pth')

# Inference

In [47]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode one sequence from ``src``.

    Args:
        model: Generator providing ``encode``, ``decode``, ``condition_mask`` and
            ``item_embedding_decoder``.
        src: ``(1, src_len)`` source ids; only batch size 1 is supported because
            the chosen token is read with ``.item()``.
        src_mask: ``(src_len, src_len)`` encoder attention mask.
        max_len: maximum number of output tokens, including ``start_symbol``.
        start_symbol: id the output sequence starts with.

    Returns:
        ``(1, out_len)`` long tensor starting with ``start_symbol``; decoding stops
        after the module-level ``EOS`` id is produced or at ``max_len`` tokens.
    """
    # Follow the model's device instead of assuming 'cuda'.
    device = next(model.parameters()).device
    src = src.to(device)
    src_mask = src_mask.to(device)

    # Inference only: do not record an autograd graph for every decoding step.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)
        for _ in range(max_len - 1):
            # As bool, the additive mask is True exactly at the blocked (future) positions.
            tgt_mask = (generate_square_subsequent_mask(ys.size(1))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            prob = out[:, -1] @ model.item_embedding_decoder.weight.T
            # Training restricts logits to ids present in the source (forward() applies
            # condition_mask); apply the same restriction here so decoding cannot
            # emit an id the source does not contain.
            prob = model.condition_mask(prob.unsqueeze(1), src).squeeze(1)
            _, next_word = torch.max(prob, dim=-1)
            next_word = next_word.item()

            ys = torch.cat(
                [ys, torch.full((1, 1), next_word, dtype=ys.dtype, device=device)], dim=1)
            if next_word == EOS:
                break
    return ys

def translate(model: torch.nn.Module, src, max_len: int = 25):
    """Decode a single source sequence with greedy search.

    Args:
        model: trained Generator; it is switched to eval mode as a side effect.
        src: 1-D tensor of source ids (reshaped to a batch of one).
        max_len: maximum number of output tokens including the start symbol
            (default 25, the previous fixed value).

    Returns:
        1-D tensor of decoded ids, starting with the module-level ``SOS`` id.
    """
    model.eval()
    src = src.reshape(1, -1)
    num_tokens = src.shape[1]
    # All-False mask: every source position may attend to every other one.
    src_mask = torch.zeros((num_tokens, num_tokens), dtype=torch.bool, device=src.device)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=max_len, start_symbol=SOS)
    return tgt_tokens.flatten()



In [48]:
def preprocess(seq):
    """Return ``seq`` (a list of ids) wrapped in SOS ... EOS as a CUDA tensor."""
    wrapped = [SOS] + seq
    wrapped.append(EOS)
    return torch.tensor(wrapped, device='cuda')
# Load the original training records and rebuild one id list per record.
# NOTE(review): torch.load unpickles the file; only load trusted data.
original_data = torch.load(f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_ori.pth')
# For each record r: the first r[3] elements of r[1], followed by element r[3] - 1 of r[2].
# Presumably r[1] is the input sequence, r[2] the next-item targets and r[3] the
# unpadded length -- TODO confirm against the dataset writer.
seqlist = [_[1][:_[3]] + [_[2][_[3] - 1]] for _ in original_data]
# Wrap each list in SOS/EOS and turn it into a GPU tensor.
seqlist = [preprocess(_) for _ in seqlist]
# seqlist = pad_sequence(seqlist, batch_first=True).to('cuda')

In [49]:

# Greedy-decode the first 100 prepared sequences, one sequence per call.
filtered_sequences = []
for seq in tqdm(seqlist[:100]):
    rst = translate(model, seq)
    filtered_sequences.append(rst)

100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


In [50]:
# Preview the first five decoded sequences.
filtered_sequences[:5]

[tensor([11925,  7753, 10158, 11472, 11926], device='cuda:0'),
 tensor([11925,     1,  2296,  2297, 11926], device='cuda:0'),
 tensor([11925,  3030,  1664,  4662,  7233,     1,  2296,  3372, 11926],
        device='cuda:0'),
 tensor([11925,     1,  1074,  4012, 11926], device='cuda:0'),
 tensor([11925,     1,  2002,  6552,  6906,  7235,  9849,   281,  1919,  3652,
          5364, 11926], device='cuda:0')]

In [51]:
# Preview the first five input sequences for comparison with the decoded ones.
seqlist[:5]

[tensor([11925,  7753, 10158, 11472, 11926], device='cuda:0'),
 tensor([11925,     1,  2296,  2297, 11926], device='cuda:0'),
 tensor([11925,  3030,  1664,  4662,  7233,     1,  2296,  3372, 11926],
        device='cuda:0'),
 tensor([11925,     1,  1074,  4012, 11926], device='cuda:0'),
 tensor([11925,     1,  2002,  6552,  6906,  7235,  9849,   281,  1919,  3652,
          5364, 11926], device='cuda:0')]

In [17]:
# Count decoded sequences that reproduce their input exactly.
# torch.equal is False when the lengths differ, whereas `(a == b).all()` raises
# on a length mismatch (or silently broadcasts a length-1 tensor).
cnt = sum(int(torch.equal(a, b)) for a, b in zip(filtered_sequences, seqlist[:100]))
print(cnt)

100


# Check inference

In [96]:
# Result files are split by index range: 0-5000, 5000-10000, 10000-15000, 15000-20000.
interval = list(range(0, 20001, 5000))
rst_list = [
    torch.load(f'f-seq-{dataset_name}-{start}-{stop}.pth')
    for start, stop in zip(interval[:-1], interval[1:])
]
# Flatten the per-file lists into one list, preserving file order.
rst = []
for chunk in rst_list:
    rst.extend(chunk)

In [100]:
# Count loaded results that are identical to their input sequence.
# torch.equal is False when the lengths differ, whereas `(a == b).all()` raises
# on a length mismatch (or silently broadcasts a length-1 tensor).
cnt = sum(int(torch.equal(a, b)) for a, b in zip(rst, seqlist))

In [43]:
# Scratch calculation: 0.999 raised to the 1000th power.
0.999 ** 1000

0.36769542477096373