In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from torch.nn.utils.rnn import pad_sequence
from utils.reparam_module import ReparamModule
import numpy as np
import random
from tqdm import tqdm

# Enviroment

In [4]:
seed = 2024

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Dataset

In [13]:
dataset_name = 'sport'
full_dataset_name = 'amazon-sport'
num_item_dict = {
    'toy': 11925,
    'sport': 18358,
}
path = f'./{dataset_name}-pair.pth'
data = torch.load(path)
num_item = num_item_dict[dataset_name]
SOS = num_item
EOS = num_item + 1

In [6]:
source, target = [], []
source_seqlen, target_seqlen = [], []
for _ in data:
    s, t = _
    source_seqlen.append(len(s) + 2)
    target_seqlen.append(len(t) + 2)
    s = torch.tensor([SOS] + s + [EOS])
    t = torch.tensor([SOS] + t + [EOS])
    source.append(s)
    target.append(t)
source = pad_sequence(source, batch_first=True, padding_value=0)
target = pad_sequence(target, batch_first=True, padding_value=0)
if target.shape[1] < 20:
    target = torch.cat([target, torch.zeros(target.shape[0], 20 - target.shape[1], dtype=torch.int)], dim=-1)
source_seqlen = torch.tensor(source_seqlen)
target_seqlen = torch.tensor(target_seqlen)

In [7]:
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    def __init__(
            self,
            source,
            target,
            source_seqlen,
            target_seqlen,
        ) -> None:
        super().__init__()
        self.source = source.to('cuda')
        self.target = target.to('cuda')
        self.source_seq_len = source_seqlen.to('cuda')
        self.target_seq_len = target_seqlen.to('cuda')

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        src = self.source[index]
        tgt = self.target[index]
        src_len = self.source_seq_len[index]
        tgt_len = self.target_seq_len[index]
        return src, tgt, src_len, tgt_len

# Model

In [8]:
from utils import normal_initialization
from module.layers import SeqPoolingLayer

class ConditionEncoder(nn.Module):
    def __init__(self, K) -> None:
        super().__init__()
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=transformer_layer,
            num_layers=1,
        )
        self.condition_layer = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, K),
        )
        self.condition_emb = nn.Embedding(K, 64)
        self.pooling_layer = SeqPoolingLayer('mean')
        self.tau = 5

    def forward(self, trm_input, src_mask, memory_key_padding_mask, src_seqlen):
        trm_out = self.encoder(
            src=trm_input,
            mask=src_mask,  # BxLxD
            src_key_padding_mask=memory_key_padding_mask,
        )
        trm_out = self.pooling_layer(trm_out, src_seqlen) # BD
        condition = self.condition_layer(trm_out) # BK
        condition = F.gumbel_softmax(condition, tau=self.tau, dim=-1) # BK
        self.tau = max(self.tau * 0.999, 0.2)
        rst = condition @ self.condition_emb.weight # BD
        return rst.unsqueeze(1)

class Generator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.item_embedding = nn.Embedding(num_item + 2, 64, padding_idx=0)
        self.item_embedding_decoder = nn.Embedding(num_item + 2, 64, padding_idx=0)
        self.transformer = nn.Transformer(
            d_model=64,
            nhead=2,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.5)
        self.position_embedding = torch.nn.Embedding(50, 64)
        self.condition_encoder = ConditionEncoder(5)
        self.device = 'cuda'
        self.apply(normal_initialization)
        # self.load_pretrained()

    def load_pretrained(self):
        # path = 'saved/SASRec8/amazon-toys-seq-noise-50/2024-01-24-16-37-41-603118.ckpt'
        path = 'saved/SASRec7/amazon-toys-seq-noise-50/2024-01-24-17-16-57-368371.ckpt'
        saved = torch.load(path, map_location='cpu')
        pretrained = saved['parameters']['item_embedding.weight']
        pretrained = torch.cat([
            pretrained,
            nn.init.normal_(torch.zeros(2, 64), std=0.02)
        ])
        self.item_embedding = nn.Embedding.from_pretrained(pretrained, padding_idx=0, freeze=False)

    def forward(self, src, tgt, src_mask, tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
            src_seqlen,
            tgt_seqlen,
        ):
        position_ids = torch.arange(src.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        src_position_embedding = self.position_embedding(position_ids)
        src_emb = self.dropout(self.item_embedding(src) + src_position_embedding)

        con_emb = self.condition_encoder(src_emb, src_mask, src_padding_mask, src_seqlen)

        position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        tgt_position_embedding = self.position_embedding(position_ids)
        tgt_emb = self.dropout(
            torch.cat([con_emb, self.item_embedding(tgt[:, 1:])], dim=1) + \
            tgt_position_embedding
        ) # replace [SOS] with condition embedding

        outs = self.transformer(
            src_emb, tgt_emb, src_mask, tgt_mask, None,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
        )

        return outs @ self.item_embedding_decoder.weight.T
    
    def encode(self, src, src_mask):
        position_ids = torch.arange(src.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        src_position_embedding = self.position_embedding(position_ids)
        src_emb = self.dropout(self.item_embedding(src) + src_position_embedding)

        return self.transformer.encoder(src_emb, src_mask)

    def set_condition(self, condition):
        self.conditon = condition

    def decode(self, tgt, memory, tgt_mask):
        # position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        # position_ids = position_ids.reshape(1, -1)
        # tgt_position_embedding = self.position_embedding(position_ids)
        # tgt_emb = self.dropout(self.item_embedding(tgt) + tgt_position_embedding)
        
        # return self.transformer.decoder(tgt_emb, memory, tgt_mask)
        con_emb = self.condition_encoder.condition_emb.weight[self.condition].unsqueeze(0).unsqueeze(0)
        position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        tgt_position_embedding = self.position_embedding(position_ids)
        if tgt.shape[1] == 1:
            # only SOS in
            tgt_emb = self.dropout(con_emb + tgt_position_embedding)
        else:
            # replace SOS with Condition embedding
            tgt_emb = self.dropout(
                torch.cat([con_emb, self.item_embedding(tgt[:, 1:])], dim=1) + \
                tgt_position_embedding
            )
        
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device='cuda')) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, -100000).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device='cuda').type(torch.bool)

    src_padding_mask = (src == 0)
    tgt_padding_mask = (tgt == 0)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


# Train

In [9]:
model = Generator().to('cuda')
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-3)
dataset = MyDataset(source, target, source_seqlen, target_seqlen)

In [10]:
def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    for src, tgt, src_seqlen, tgt_seqlen in tqdm(train_dataloader):
        tgt_input = tgt[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(
            src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, src_padding_mask,
            src_seqlen, tgt_seqlen)
        optimizer.zero_grad()
        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
    return losses / len(list(train_dataloader))

In [26]:
from timeit import default_timer as timer
NUM_EPOCHS = 10

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

100%|██████████| 257/257 [00:52<00:00,  4.89it/s]


Epoch: 1, Train loss: 3.676, Epoch time = 53.654s


100%|██████████| 257/257 [00:52<00:00,  4.90it/s]


Epoch: 2, Train loss: 3.659, Epoch time = 53.341s


100%|██████████| 257/257 [00:52<00:00,  4.92it/s]


Epoch: 3, Train loss: 3.641, Epoch time = 53.042s


100%|██████████| 257/257 [00:52<00:00,  4.86it/s]


Epoch: 4, Train loss: 3.626, Epoch time = 53.636s


100%|██████████| 257/257 [00:52<00:00,  4.91it/s]


Epoch: 5, Train loss: 3.612, Epoch time = 53.175s


100%|██████████| 257/257 [00:52<00:00,  4.94it/s]


Epoch: 6, Train loss: 3.595, Epoch time = 52.923s


100%|██████████| 257/257 [00:51<00:00,  4.97it/s]


Epoch: 7, Train loss: 3.580, Epoch time = 52.550s


100%|██████████| 257/257 [00:52<00:00,  4.88it/s]


Epoch: 8, Train loss: 3.562, Epoch time = 53.764s


100%|██████████| 257/257 [00:52<00:00,  4.92it/s]


Epoch: 9, Train loss: 3.548, Epoch time = 53.076s


100%|██████████| 257/257 [00:52<00:00,  4.90it/s]


Epoch: 10, Train loss: 3.536, Epoch time = 53.251s


In [14]:
torch.save(model.state_dict(), f'translator-{dataset_name}-con.pth')

# Inference

In [15]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to('cuda')
    src_mask = src_mask.to('cuda')

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to('cuda')
    for i in range(max_len-1):
        memory = memory.to('cuda')
        tgt_mask = (generate_square_subsequent_mask(ys.size(1))
                    .type(torch.bool)).to('cuda')
        out = model.decode(ys, memory, tgt_mask)
        prob = out[:, -1] @ model.item_embedding_decoder.weight.T
        _, next_word = torch.max(prob, dim=-1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        if next_word == EOS:
            break
    return ys

def translate(model: torch.nn.Module, src):
    model.eval()
    src = src.reshape(1, -1)
    num_tokens = src.shape[1]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=25, start_symbol=SOS).flatten()
    return tgt_tokens



In [16]:
def preprocess(seq):
    return torch.tensor([SOS] + seq + [EOS], device='cuda')
original_data = torch.load(f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_ori.pth')
seqlist = [_[1][:_[3]] + [_[2][_[3] - 1]] for _ in original_data]
seqlist = [preprocess(_) for _ in seqlist]
# seqlist = pad_sequence(seqlist, batch_first=True).to('cuda')
# seqlen = torch.tensor([_[3] for _ in original_data], device='cuda')

In [24]:
idx = 0
print(source[idx])
translate(model, source[idx])
print(target[idx])

tensor([18358,  7216,  3088,  1543,  7170, 11503,  6678,  9199,  4282,     1,
        13469,  2323,  5045,  6555, 15279,  4124,  8803, 14213, 18359,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])
tensor([18358,  4282,  8803, 18359,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])


In [17]:
model.condition = 0
filtered_sequences = []
for seq in tqdm(seqlist[:100]):
    rst = translate(model, seq)
    filtered_sequences.append(rst)

100%|██████████| 100/100 [00:13<00:00,  7.17it/s]


In [22]:
seqlist[:5]

[tensor([18358,     1, 11982, 15853,  3328, 13373, 17788, 18359],
        device='cuda:0'),
 tensor([18358,  7216,  3088,  1543,  7170, 11503,  6678,  9199,  4282,     1,
         13469,  2323,  5045,  6555, 15279,  4124,  8803, 14213, 18359],
        device='cuda:0'),
 tensor([18358,  6299, 10219,  7951,     1,  3012, 14658,  8777, 14243,  1115,
           501,  2951, 10838, 11361, 14700,  9255,  4156,  6446,  3662,  9842,
         14213, 15076, 18359], device='cuda:0'),
 tensor([18358,  1691,  6245, 10593, 18359], device='cuda:0'),
 tensor([18358, 13465,   272,  2569,  4712,  9939,     1,  5822, 18359],
        device='cuda:0')]

In [5]:
interval = list(range(0, 20001, 5000))
rst_list = []
for i in range(len(interval) - 1):
    rst_list.append(torch.load(f'f-seq-con-{dataset}-0-{interval[i]}-{interval[i + 1]}.pth'))
rst = []
for _ in rst_list:
    rst += _

In [13]:
max_seq_len = 50
def truncate_or_pad(seq):
    cur_seq_len = len(seq)
    if cur_seq_len > max_seq_len:
        return seq[-max_seq_len:]
    else:
        return seq + [0] * (max_seq_len - cur_seq_len)


train_set = set()

for pattern in rst:
    seq = pattern.tolist()[1:-1]
    seq_len = len(seq)
    for _ in range(1, seq_len):
        train_set.add(tuple(
            truncate_or_pad(seq[:_]) + [seq[_]],
        ))

train_list = []
for _ in train_set:
    train_list.append([
        1,
        list(_)[:-1],
        list(_)[-1],
        sum([a != 0 for a in list(_)[:-1]]),
        1,
        0
    ])
print(len(train_list))

105671


In [14]:
ori_pattern = torch.load('./dataset/amazon-toys-noise-50/toy/train_new-pure.pth')

In [95]:
torch.save(train_list + ori_pattern, './dataset/amazon-toys-noise-50/toy/train_gene.pth')