In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from torch.nn.utils.rnn import pad_sequence
from utils.reparam_module import ReparamModule
import numpy as np
import random
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Enviroment

In [2]:
seed = 2024

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Dataset

In [3]:
dataset_name = 'beauty'
full_dataset_name = 'amazon-beauty'
num_item_dict = {
    'toy': 11925,
    'sport': 18358,
    'beauty':12102,
}
path = f'./{dataset_name}-pair.pth'
data = torch.load(path)
num_item = num_item_dict[dataset_name]
SOS = num_item
EOS = num_item + 1

In [4]:
source, target = [], []
source_seqlen, target_seqlen = [], []
for _ in data:
    s, t = _
    source_seqlen.append(len(s) + 2)
    target_seqlen.append(len(t) + 2)
    s = torch.tensor([SOS] + s + [EOS])
    t = torch.tensor([SOS] + t + [EOS])
    source.append(s)
    target.append(t)
source = pad_sequence(source, batch_first=True, padding_value=0)
target = pad_sequence(target, batch_first=True, padding_value=0)
if target.shape[1] < 20:
    target = torch.cat([target, torch.zeros(target.shape[0], 20 - target.shape[1], dtype=torch.int)], dim=-1)
source_seqlen = torch.tensor(source_seqlen)
target_seqlen = torch.tensor(target_seqlen)

In [5]:
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    def __init__(
            self,
            source,
            target,
            source_seqlen,
            target_seqlen,
        ) -> None:
        super().__init__()
        self.source = source.to('cuda')
        self.target = target.to('cuda')
        self.source_seq_len = source_seqlen.to('cuda')
        self.target_seq_len = target_seqlen.to('cuda')

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        src = self.source[index]
        tgt = self.target[index]
        src_len = self.source_seq_len[index]
        tgt_len = self.target_seq_len[index]
        return src, tgt, src_len, tgt_len

# Model

In [6]:
from utils import normal_initialization
from module.layers import SeqPoolingLayer

class ConditionEncoder(nn.Module):
    def __init__(self, K) -> None:
        super().__init__()
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=transformer_layer,
            num_layers=2,
        )
        self.condition_layer = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, K),
        )
        self.condition_emb = nn.Embedding(K, 64)
        self.pooling_layer = SeqPoolingLayer('mean')
        self.tau = 1

    def forward(self, trm_input, src_mask, memory_key_padding_mask, src_seqlen):
        trm_out = self.encoder(
            src=trm_input,
            mask=src_mask,  # BxLxD
            src_key_padding_mask=memory_key_padding_mask,
        )
        trm_out = self.pooling_layer(trm_out, src_seqlen) # BD
        condition = self.condition_layer(trm_out) # BK
        condition = F.gumbel_softmax(condition, tau=self.tau, dim=-1) # BK
        self.condition4loss = condition
        self.tau = max(self.tau * 0.995, 0.1)
        rst = condition @ self.condition_emb.weight # BD
        return rst.unsqueeze(1)

class Generator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # self.item_embedding = nn.Embedding(num_item + 2, 64, padding_idx=0)
        # self.item_embedding_decoder = nn.Embedding(num_item + 2, 64, padding_idx=0)
        self.transformer = nn.Transformer(
            d_model=64,
            nhead=2,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.5)
        self.position_embedding = torch.nn.Embedding(50, 64)
        self.condition_encoder = ConditionEncoder(5)
        self.device = 'cuda'
        self.apply(normal_initialization)
        self.load_pretrained()

    def load_pretrained(self):
        # path = 'saved/SASRec7/amazon-toys-seq-noise-50/2024-01-24-17-16-57-368371.ckpt'
        path = 'saved/SASRec/amazon-beauty-noise-50/2024-01-25-10-39-46-322830.ckpt'
        saved = torch.load(path, map_location='cpu')
        pretrained = saved['parameters']['item_embedding.weight']
        pretrained = torch.cat([
            pretrained,
            nn.init.normal_(torch.zeros(2, 64), std=0.02)
        ])
        self.item_embedding = nn.Embedding.from_pretrained(pretrained, padding_idx=0, freeze=False)
        self.item_embedding_decoder = self.item_embedding

    def condition_mask(self, logits, src):
        mask = torch.zeros_like(logits, device=logits.device, dtype=torch.bool)
        mask = mask.scatter(-1, src.unsqueeze(-2).repeat(1, mask.shape[1], 1), 1)
        logits = torch.masked_fill(logits, ~mask, -torch.inf)
        return logits

    def forward(self, src, tgt, src_mask, tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
            src_seqlen,
            tgt_seqlen,
        ):
        position_ids = torch.arange(src.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        src_position_embedding = self.position_embedding(position_ids)
        src_emb = self.dropout(self.item_embedding(src) + src_position_embedding)

        con_emb = self.condition_encoder(src_emb, src_mask, src_padding_mask, src_seqlen)

        position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        tgt_position_embedding = self.position_embedding(position_ids)
        tgt_emb = self.dropout(
            torch.cat([con_emb, self.item_embedding(tgt[:, 1:])], dim=1) + \
            tgt_position_embedding
        ) # replace [SOS] with condition embedding

        outs = self.transformer(
            src_emb, tgt_emb, src_mask, tgt_mask, None,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
        )
        logits = outs @ self.item_embedding_decoder.weight.T
        logits = self.condition_mask(logits, src)

        return logits
    
    def encode(self, src, src_mask):
        position_ids = torch.arange(src.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        src_position_embedding = self.position_embedding(position_ids)
        src_emb = self.dropout(self.item_embedding(src) + src_position_embedding)

        return self.transformer.encoder(src_emb, src_mask)

    def set_condition(self, condition):
        self.condition = condition

    def decode(self, tgt, memory, tgt_mask):
        # position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        # position_ids = position_ids.reshape(1, -1)
        # tgt_position_embedding = self.position_embedding(position_ids)
        # tgt_emb = self.dropout(self.item_embedding(tgt) + tgt_position_embedding)
        
        # return self.transformer.decoder(tgt_emb, memory, tgt_mask)
        con_emb = self.condition_encoder.condition_emb.weight[self.condition].unsqueeze(0).unsqueeze(0)
        position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        tgt_position_embedding = self.position_embedding(position_ids)
        if tgt.shape[1] == 1:
            # only SOS in
            tgt_emb = self.dropout(con_emb + tgt_position_embedding)
        else:
            # replace SOS with Condition embedding
            tgt_emb = self.dropout(
                torch.cat([con_emb, self.item_embedding(tgt[:, 1:])], dim=1) + \
                tgt_position_embedding
            )
        
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device='cuda')) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, -100000).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    # src_mask = torch.zeros((src_seq_len, src_seq_len),device='cuda').type(torch.bool)
    src_mask = generate_square_subsequent_mask(src_seq_len)

    src_padding_mask = (src == 0)
    tgt_padding_mask = (tgt == 0)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


# Train

In [7]:
from torch.optim.lr_scheduler import CosineAnnealingLR
NUM_EPOCHS = 80
model = Generator().to('cuda')
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
dataset = MyDataset(source, target, source_seqlen, target_seqlen)
lr_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
)

In [8]:
def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    for src, tgt, src_seqlen, tgt_seqlen in tqdm(train_dataloader):
        tgt_input = tgt[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(
            src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, src_padding_mask,
            src_seqlen, tgt_seqlen)
        optimizer.zero_grad()
        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        condition_prob = model.condition_encoder.condition4loss
        reg_loss = - (condition_prob * torch.log(condition_prob + 1e-12)).sum(-1).mean()
        losses += loss.item()
        loss = loss + 1 * reg_loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
    return losses / len(list(train_dataloader))

In [9]:
from timeit import default_timer as timer


for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

100%|██████████| 133/133 [00:44<00:00,  2.97it/s]


Epoch: 1, Train loss: 2.474, Epoch time = 45.244s


100%|██████████| 133/133 [00:42<00:00,  3.15it/s]


Epoch: 2, Train loss: 1.733, Epoch time = 42.602s


100%|██████████| 133/133 [00:42<00:00,  3.14it/s]


Epoch: 3, Train loss: 1.620, Epoch time = 42.904s


100%|██████████| 133/133 [00:42<00:00,  3.12it/s]


Epoch: 4, Train loss: 1.563, Epoch time = 43.217s


100%|██████████| 133/133 [00:38<00:00,  3.48it/s]


Epoch: 5, Train loss: 1.523, Epoch time = 38.707s


100%|██████████| 133/133 [00:36<00:00,  3.64it/s]


Epoch: 6, Train loss: 1.494, Epoch time = 37.025s


100%|██████████| 133/133 [00:36<00:00,  3.69it/s]


Epoch: 7, Train loss: 1.474, Epoch time = 36.451s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 8, Train loss: 1.461, Epoch time = 36.155s


100%|██████████| 133/133 [00:35<00:00,  3.77it/s]


Epoch: 9, Train loss: 1.447, Epoch time = 35.683s


100%|██████████| 133/133 [00:35<00:00,  3.70it/s]


Epoch: 10, Train loss: 1.434, Epoch time = 36.447s


100%|██████████| 133/133 [00:35<00:00,  3.77it/s]


Epoch: 11, Train loss: 1.418, Epoch time = 35.687s


100%|██████████| 133/133 [00:34<00:00,  3.82it/s]


Epoch: 12, Train loss: 1.411, Epoch time = 35.383s


100%|██████████| 133/133 [00:34<00:00,  3.81it/s]


Epoch: 13, Train loss: 1.408, Epoch time = 35.352s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 14, Train loss: 1.402, Epoch time = 36.155s


100%|██████████| 133/133 [00:35<00:00,  3.74it/s]


Epoch: 15, Train loss: 1.399, Epoch time = 36.028s


100%|██████████| 133/133 [00:35<00:00,  3.78it/s]


Epoch: 16, Train loss: 1.390, Epoch time = 35.623s


100%|██████████| 133/133 [00:34<00:00,  3.80it/s]


Epoch: 17, Train loss: 1.378, Epoch time = 35.434s


100%|██████████| 133/133 [00:35<00:00,  3.77it/s]


Epoch: 18, Train loss: 1.376, Epoch time = 35.775s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 19, Train loss: 1.374, Epoch time = 35.859s


100%|██████████| 133/133 [00:35<00:00,  3.78it/s]


Epoch: 20, Train loss: 1.370, Epoch time = 35.721s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 21, Train loss: 1.366, Epoch time = 35.848s


100%|██████████| 133/133 [00:35<00:00,  3.77it/s]


Epoch: 22, Train loss: 1.362, Epoch time = 35.664s


100%|██████████| 133/133 [00:35<00:00,  3.71it/s]


Epoch: 23, Train loss: 1.355, Epoch time = 36.308s


100%|██████████| 133/133 [00:35<00:00,  3.73it/s]


Epoch: 24, Train loss: 1.351, Epoch time = 36.115s


100%|██████████| 133/133 [00:35<00:00,  3.73it/s]


Epoch: 25, Train loss: 1.351, Epoch time = 36.098s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 26, Train loss: 1.350, Epoch time = 36.215s


100%|██████████| 133/133 [00:35<00:00,  3.73it/s]


Epoch: 27, Train loss: 1.346, Epoch time = 36.165s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 28, Train loss: 1.341, Epoch time = 36.335s


100%|██████████| 133/133 [00:35<00:00,  3.74it/s]


Epoch: 29, Train loss: 1.334, Epoch time = 36.130s


100%|██████████| 133/133 [00:35<00:00,  3.73it/s]


Epoch: 30, Train loss: 1.333, Epoch time = 36.131s


100%|██████████| 133/133 [00:32<00:00,  4.05it/s]


Epoch: 31, Train loss: 1.334, Epoch time = 33.295s


100%|██████████| 133/133 [00:34<00:00,  3.88it/s]


Epoch: 32, Train loss: 1.335, Epoch time = 34.709s


100%|██████████| 133/133 [00:35<00:00,  3.71it/s]


Epoch: 33, Train loss: 1.330, Epoch time = 36.241s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 34, Train loss: 1.323, Epoch time = 36.149s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 35, Train loss: 1.319, Epoch time = 36.165s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 36, Train loss: 1.318, Epoch time = 35.935s


100%|██████████| 133/133 [00:35<00:00,  3.75it/s]


Epoch: 37, Train loss: 1.320, Epoch time = 36.017s


100%|██████████| 133/133 [00:35<00:00,  3.73it/s]


Epoch: 38, Train loss: 1.318, Epoch time = 36.089s


100%|██████████| 133/133 [00:35<00:00,  3.73it/s]


Epoch: 39, Train loss: 1.313, Epoch time = 36.122s


100%|██████████| 133/133 [00:35<00:00,  3.74it/s]


Epoch: 40, Train loss: 1.308, Epoch time = 36.000s


100%|██████████| 133/133 [00:35<00:00,  3.78it/s]


Epoch: 41, Train loss: 1.304, Epoch time = 35.685s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 42, Train loss: 1.304, Epoch time = 35.937s


100%|██████████| 133/133 [00:34<00:00,  3.80it/s]


Epoch: 43, Train loss: 1.304, Epoch time = 35.441s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 44, Train loss: 1.307, Epoch time = 35.904s


100%|██████████| 133/133 [00:35<00:00,  3.78it/s]


Epoch: 45, Train loss: 1.302, Epoch time = 35.770s


100%|██████████| 133/133 [00:35<00:00,  3.77it/s]


Epoch: 46, Train loss: 1.294, Epoch time = 35.806s


100%|██████████| 133/133 [00:34<00:00,  3.81it/s]


Epoch: 47, Train loss: 1.291, Epoch time = 35.307s


100%|██████████| 133/133 [00:35<00:00,  3.70it/s]


Epoch: 48, Train loss: 1.292, Epoch time = 36.387s


100%|██████████| 133/133 [00:34<00:00,  3.81it/s]


Epoch: 49, Train loss: 1.295, Epoch time = 35.356s


100%|██████████| 133/133 [00:35<00:00,  3.80it/s]


Epoch: 50, Train loss: 1.293, Epoch time = 35.429s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 51, Train loss: 1.289, Epoch time = 35.814s


100%|██████████| 133/133 [00:34<00:00,  3.90it/s]


Epoch: 52, Train loss: 1.285, Epoch time = 34.628s


100%|██████████| 133/133 [00:34<00:00,  3.91it/s]


Epoch: 53, Train loss: 1.280, Epoch time = 34.537s


100%|██████████| 133/133 [00:32<00:00,  4.06it/s]


Epoch: 54, Train loss: 1.281, Epoch time = 33.261s


100%|██████████| 133/133 [00:32<00:00,  4.06it/s]


Epoch: 55, Train loss: 1.283, Epoch time = 33.167s


100%|██████████| 133/133 [00:34<00:00,  3.80it/s]


Epoch: 56, Train loss: 1.284, Epoch time = 35.396s


100%|██████████| 133/133 [00:35<00:00,  3.73it/s]


Epoch: 57, Train loss: 1.280, Epoch time = 36.100s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 58, Train loss: 1.274, Epoch time = 35.835s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 59, Train loss: 1.270, Epoch time = 36.188s


100%|██████████| 133/133 [00:35<00:00,  3.71it/s]


Epoch: 60, Train loss: 1.274, Epoch time = 36.255s


100%|██████████| 133/133 [00:34<00:00,  3.83it/s]


Epoch: 61, Train loss: 1.273, Epoch time = 35.282s


100%|██████████| 133/133 [00:34<00:00,  3.85it/s]


Epoch: 62, Train loss: 1.274, Epoch time = 35.076s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 63, Train loss: 1.271, Epoch time = 35.822s


100%|██████████| 133/133 [00:35<00:00,  3.72it/s]


Epoch: 64, Train loss: 1.265, Epoch time = 36.288s


100%|██████████| 133/133 [00:34<00:00,  3.84it/s]


Epoch: 65, Train loss: 1.260, Epoch time = 35.085s


100%|██████████| 133/133 [00:34<00:00,  3.89it/s]


Epoch: 66, Train loss: 1.264, Epoch time = 34.640s


100%|██████████| 133/133 [00:35<00:00,  3.74it/s]


Epoch: 67, Train loss: 1.264, Epoch time = 36.049s


100%|██████████| 133/133 [00:36<00:00,  3.64it/s]


Epoch: 68, Train loss: 1.264, Epoch time = 37.019s


100%|██████████| 133/133 [00:35<00:00,  3.79it/s]


Epoch: 69, Train loss: 1.260, Epoch time = 35.537s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 70, Train loss: 1.258, Epoch time = 35.900s


100%|██████████| 133/133 [00:36<00:00,  3.67it/s]


Epoch: 71, Train loss: 1.253, Epoch time = 36.679s


100%|██████████| 133/133 [00:35<00:00,  3.70it/s]


Epoch: 72, Train loss: 1.258, Epoch time = 36.405s


100%|██████████| 133/133 [00:36<00:00,  3.69it/s]


Epoch: 73, Train loss: 1.255, Epoch time = 36.466s


100%|██████████| 133/133 [00:35<00:00,  3.77it/s]


Epoch: 74, Train loss: 1.255, Epoch time = 35.671s


100%|██████████| 133/133 [00:35<00:00,  3.76it/s]


Epoch: 75, Train loss: 1.254, Epoch time = 35.764s


100%|██████████| 133/133 [00:33<00:00,  3.92it/s]


Epoch: 76, Train loss: 1.249, Epoch time = 34.348s


100%|██████████| 133/133 [00:32<00:00,  4.07it/s]


Epoch: 77, Train loss: 1.247, Epoch time = 33.140s


100%|██████████| 133/133 [00:29<00:00,  4.57it/s]


Epoch: 78, Train loss: 1.249, Epoch time = 29.591s


100%|██████████| 133/133 [00:33<00:00,  3.93it/s]


Epoch: 79, Train loss: 1.251, Epoch time = 34.323s


100%|██████████| 133/133 [00:35<00:00,  3.78it/s]


Epoch: 80, Train loss: 1.248, Epoch time = 35.579s


In [10]:
torch.save(model.state_dict(), f'translator-{dataset_name}-con.pth')

# Inference

In [11]:
def inference_mask(logits, src, ys):
    mask = torch.zeros_like(logits, device=logits.device, dtype=torch.bool)
    mask = mask.scatter(-1, src, 1)
    mask = mask.scatter(-1, ys, 0)
    logits = torch.masked_fill(logits, ~mask, -torch.inf)
    return logits

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to('cuda')
    src_mask = src_mask.to('cuda')

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to('cuda')
    for i in range(max_len-1):
        memory = memory.to('cuda')
        tgt_mask = (generate_square_subsequent_mask(ys.size(1))
                    .type(torch.bool)).to('cuda')
        out = model.decode(ys, memory, tgt_mask)
        prob = out[:, -1] @ model.item_embedding_decoder.weight.T
        prob = inference_mask(prob, src, ys)
        _, next_word = torch.max(prob, dim=-1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        if next_word == EOS:
            break
    return ys

def translate(model: torch.nn.Module, src):
    model.eval()
    src = src.reshape(1, -1)
    num_tokens = src.shape[1]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=25, start_symbol=SOS).flatten()
    return tgt_tokens



In [12]:
def preprocess(seq):
    return torch.tensor([SOS] + seq + [EOS], device='cuda')
original_data = torch.load(f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_ori.pth')
seqlist = [_[1][:_[3]] + [_[2][_[3] - 1]] for _ in original_data]
seqlist = [preprocess(_) for _ in seqlist]

In [None]:
model.load_state_dict(torch.load(f'translator-{dataset_name}-con.pth'))

In [31]:
idx = 30
# model.condition = 0
model.set_condition(1)
print(source[idx])
print(translate(model, source[idx]))
print(target[idx])

tensor([12102, 10812, 10897,  1135, 10669, 10689,  1166, 10896, 11734, 11732,
         7114, 11772, 11773, 11768, 11853, 11878, 11879, 11774, 11775,  8976,
         8977, 11767,     5,   918,  1015,  1076,  1592,  1977,  2414,  2835,
         2837,  2841,  3564,  5580,  9939,  4505,  3387,   366, 11710, 11712,
        11805, 11807,  6440,  6456,   388,   417,   420,  1902,  2313, 12103])
tensor([12102, 11767, 11768, 12103], device='cuda:0')
tensor([12102, 11774, 11775, 12103,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])


In [38]:
model.condition = 3
filtered_sequences = []
for seq in tqdm(seqlist[:100]):
    rst = translate(model, seq)
    filtered_sequences.append(rst)

100%|██████████| 100/100 [00:11<00:00,  8.54it/s]


In [33]:
debug = deepcopy(filtered_sequences)

In [39]:
cnt = 0
for a, b in zip(debug, filtered_sequences):
    if len(a) == len(b):
        if (a == b).all():
            cnt += 1
print(cnt)

95


In [41]:
interval = list(range(0, 25001, 5000))
rst_list = []
for i in range(len(interval) - 1):
    rst_list.append(torch.load(f'f-seq-con-{dataset_name}-{interval[i]}-{interval[i + 1]}.pth'))
rst = []
for _ in rst_list:
    rst += _
print(len(rst))

111815


In [42]:
ori_pattern = torch.load(f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_new-pure.pth')
ori_seqlist = [list(_[1][:_[3]]) + [_[2][_[3] - 1]] for _ in list(ori_pattern)]
len(ori_pattern)

89210

In [43]:
max_seq_len = 50
def truncate_or_pad(seq):
    cur_seq_len = len(seq)
    if cur_seq_len > max_seq_len:
        return seq[-max_seq_len:]
    else:
        return seq + [0] * (max_seq_len - cur_seq_len)

train_set = set()
# for _ in ori_seqlist:
#     train_set.add(tuple(_))

for pattern in rst:
    seq = pattern.tolist()[1:-1]
    train_set.add(tuple(seq))
print(len(train_set))

22326


In [44]:
train_list = []
for _ in train_set:
    train_list.append([
        1,
        truncate_or_pad(list(_)[:-1]),
        truncate_or_pad(list(_)[1:]),
        sum([a != 0 for a in list(_)[:-1]]),
        [1] * max_seq_len,
        [0] * max_seq_len,
    ])
print(len(train_list))

22326


In [45]:
torch.save(train_list + ori_pattern, f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_gene.pth')