In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from torch.nn.utils.rnn import pad_sequence
from utils.reparam_module import ReparamModule
import numpy as np
import random
from tqdm import tqdm

# Enviroment

In [16]:
seed = 2024

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Dataset

In [17]:
dataset_name = 'yelp-small'
full_dataset_name = 'yelp-small'
num_item_dict = {
    'toy': 11925,
    'sport': 18358,
    'beauty':12102,
    'yelp-small': 20034,
}
path = f'./{dataset_name}-pair.pth'
data = torch.load(path)
num_item = num_item_dict[dataset_name]
SOS = num_item
EOS = num_item + 1

In [18]:
source, target = [], []
source_seqlen, target_seqlen = [], []
for _ in data:
    s, t = _
    source_seqlen.append(len(s) + 2)
    target_seqlen.append(len(t) + 2)
    s = torch.tensor([SOS] + s + [EOS])
    t = torch.tensor([SOS] + t + [EOS])
    source.append(s)
    target.append(t)
source = pad_sequence(source, batch_first=True, padding_value=0)
target = pad_sequence(target, batch_first=True, padding_value=0)
if target.shape[1] < 20:
    target = torch.cat([target, torch.zeros(target.shape[0], 20 - target.shape[1], dtype=torch.int)], dim=-1)
source_seqlen = torch.tensor(source_seqlen)
target_seqlen = torch.tensor(target_seqlen)

In [19]:
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    def __init__(
            self,
            source,
            target,
            source_seqlen,
            target_seqlen,
        ) -> None:
        super().__init__()
        self.source = source.to('cuda')
        self.target = target.to('cuda')
        self.source_seq_len = source_seqlen.to('cuda')
        self.target_seq_len = target_seqlen.to('cuda')

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        src = self.source[index]
        tgt = self.target[index]
        src_len = self.source_seq_len[index]
        tgt_len = self.target_seq_len[index]
        return src, tgt, src_len, tgt_len

# Model

In [20]:
from utils import normal_initialization
from module.layers import SeqPoolingLayer

class ConditionEncoder(nn.Module):
    def __init__(self, K) -> None:
        super().__init__()
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=transformer_layer,
            num_layers=2,
        )
        self.condition_layer = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, K),
        )
        self.condition_emb = nn.Embedding(K, 64)
        self.pooling_layer = SeqPoolingLayer('mean')
        self.tau = 1

    def forward(self, trm_input, src_mask, memory_key_padding_mask, src_seqlen):
        trm_out = self.encoder(
            src=trm_input,
            mask=src_mask,  # BxLxD
            src_key_padding_mask=memory_key_padding_mask,
        )
        trm_out = self.pooling_layer(trm_out, src_seqlen) # BD
        condition = self.condition_layer(trm_out) # BK
        condition = F.gumbel_softmax(condition, tau=self.tau, dim=-1) # BK
        self.condition4loss = condition
        self.tau = max(self.tau * 0.995, 0.1)
        rst = condition @ self.condition_emb.weight # BD
        return rst.unsqueeze(1)

class Generator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # self.item_embedding = nn.Embedding(num_item + 2, 64, padding_idx=0)
        # self.item_embedding_decoder = nn.Embedding(num_item + 2, 64, padding_idx=0)
        self.transformer = nn.Transformer(
            d_model=64,
            nhead=2,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=256,
            dropout=0.5,
            activation='gelu',
            layer_norm_eps=1e-12,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.5)
        self.position_embedding = torch.nn.Embedding(50, 64)
        self.condition_encoder = ConditionEncoder(5)
        self.device = 'cuda'
        self.apply(normal_initialization)
        self.load_pretrained()

    def load_pretrained(self):
        path_dict = {
            'toy': 'saved/SASRec/amazon-toys-noise-50/2024-01-22-19-24-37-334415.ckpt',
            'beauty': 'saved/SASRec/amazon-beauty-noise-50/2024-01-25-10-39-46-322830.ckpt',
            'sport': 'saved/SASRec/amazon-sport-noise-50/2024-01-25-09-36-23-316839.ckpt',
            'yelp-small': 'saved/SASRec/yelp-small-noise-50/2024-01-25-20-34-58-431582.ckpt',
        }
        path = path_dict[dataset_name]
        saved = torch.load(path, map_location='cpu')
        pretrained = saved['parameters']['item_embedding.weight']
        pretrained = torch.cat([
            pretrained,
            nn.init.normal_(torch.zeros(2, 64), std=0.02)
        ])
        self.item_embedding = nn.Embedding.from_pretrained(pretrained, padding_idx=0, freeze=False)
        self.item_embedding_decoder = self.item_embedding

    def condition_mask(self, logits, src):
        mask = torch.zeros_like(logits, device=logits.device, dtype=torch.bool)
        mask = mask.scatter(-1, src.unsqueeze(-2).repeat(1, mask.shape[1], 1), 1)
        logits = torch.masked_fill(logits, ~mask, -torch.inf)
        return logits

    def forward(self, src, tgt, src_mask, tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
            src_seqlen,
            tgt_seqlen,
        ):
        position_ids = torch.arange(src.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        src_position_embedding = self.position_embedding(position_ids)
        src_emb = self.dropout(self.item_embedding(src) + src_position_embedding)

        con_emb = self.condition_encoder(src_emb, src_mask, src_padding_mask, src_seqlen)

        position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        tgt_position_embedding = self.position_embedding(position_ids)
        tgt_emb = self.dropout(
            torch.cat([con_emb, self.item_embedding(tgt[:, 1:])], dim=1) + \
            tgt_position_embedding
        ) # replace [SOS] with condition embedding

        outs = self.transformer(
            src_emb, tgt_emb, src_mask, tgt_mask, None,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
        )
        logits = outs @ self.item_embedding_decoder.weight.T
        logits = self.condition_mask(logits, src)

        return logits
    
    def encode(self, src, src_mask):
        position_ids = torch.arange(src.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        src_position_embedding = self.position_embedding(position_ids)
        src_emb = self.dropout(self.item_embedding(src) + src_position_embedding)

        return self.transformer.encoder(src_emb, src_mask)

    def set_condition(self, condition):
        self.condition = condition

    def decode(self, tgt, memory, tgt_mask):
        # position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        # position_ids = position_ids.reshape(1, -1)
        # tgt_position_embedding = self.position_embedding(position_ids)
        # tgt_emb = self.dropout(self.item_embedding(tgt) + tgt_position_embedding)
        
        # return self.transformer.decoder(tgt_emb, memory, tgt_mask)
        con_emb = self.condition_encoder.condition_emb.weight[self.condition].unsqueeze(0).unsqueeze(0)
        position_ids = torch.arange(tgt.size(1), dtype=torch.long, device=self.device)
        position_ids = position_ids.reshape(1, -1)
        tgt_position_embedding = self.position_embedding(position_ids)
        if tgt.shape[1] == 1:
            # only SOS in
            tgt_emb = self.dropout(con_emb + tgt_position_embedding)
        else:
            # replace SOS with Condition embedding
            tgt_emb = self.dropout(
                torch.cat([con_emb, self.item_embedding(tgt[:, 1:])], dim=1) + \
                tgt_position_embedding
            )
        
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device='cuda')) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, -100000).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    # src_mask = torch.zeros((src_seq_len, src_seq_len),device='cuda').type(torch.bool)
    src_mask = generate_square_subsequent_mask(src_seq_len)

    src_padding_mask = (src == 0)
    tgt_padding_mask = (tgt == 0)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


# Train

In [21]:
from torch.optim.lr_scheduler import CosineAnnealingLR
NUM_EPOCHS = 80
model = Generator().to('cuda')
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
dataset = MyDataset(source, target, source_seqlen, target_seqlen)
lr_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
)

In [22]:
def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    for src, tgt, src_seqlen, tgt_seqlen in tqdm(train_dataloader):
        tgt_input = tgt[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(
            src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, src_padding_mask,
            src_seqlen, tgt_seqlen)
        optimizer.zero_grad()
        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        condition_prob = model.condition_encoder.condition4loss
        reg_loss = - (condition_prob * torch.log(condition_prob + 1e-12)).sum(-1).mean()
        losses += loss.item()
        loss = loss + 1 * reg_loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
    return losses / len(list(train_dataloader))

In [23]:
from timeit import default_timer as timer


for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

100%|██████████| 590/590 [02:30<00:00,  3.93it/s]


Epoch: 1, Train loss: 2.251, Epoch time = 151.965s


100%|██████████| 590/590 [02:30<00:00,  3.92it/s]


Epoch: 2, Train loss: 1.888, Epoch time = 151.839s


100%|██████████| 590/590 [02:30<00:00,  3.92it/s]


Epoch: 3, Train loss: 1.830, Epoch time = 152.146s


100%|██████████| 590/590 [02:21<00:00,  4.16it/s]


Epoch: 4, Train loss: 1.799, Epoch time = 143.190s


100%|██████████| 590/590 [01:48<00:00,  5.41it/s]


Epoch: 5, Train loss: 1.778, Epoch time = 110.641s


100%|██████████| 590/590 [01:46<00:00,  5.52it/s]


Epoch: 6, Train loss: 1.761, Epoch time = 108.344s


100%|██████████| 590/590 [01:38<00:00,  6.01it/s]


Epoch: 7, Train loss: 1.747, Epoch time = 99.802s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 8, Train loss: 1.738, Epoch time = 98.212s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 9, Train loss: 1.729, Epoch time = 98.348s


100%|██████████| 590/590 [01:37<00:00,  6.05it/s]


Epoch: 10, Train loss: 1.719, Epoch time = 99.014s


100%|██████████| 590/590 [01:36<00:00,  6.12it/s]


Epoch: 11, Train loss: 1.713, Epoch time = 98.077s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 12, Train loss: 1.707, Epoch time = 98.172s


100%|██████████| 590/590 [01:37<00:00,  6.06it/s]


Epoch: 13, Train loss: 1.700, Epoch time = 98.945s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 14, Train loss: 1.694, Epoch time = 98.151s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 15, Train loss: 1.690, Epoch time = 98.378s


100%|██████████| 590/590 [01:37<00:00,  6.05it/s]


Epoch: 16, Train loss: 1.683, Epoch time = 98.953s


100%|██████████| 590/590 [01:38<00:00,  5.97it/s]


Epoch: 17, Train loss: 1.679, Epoch time = 100.457s


100%|██████████| 590/590 [01:37<00:00,  6.03it/s]


Epoch: 18, Train loss: 1.675, Epoch time = 99.288s


100%|██████████| 590/590 [01:37<00:00,  6.03it/s]


Epoch: 19, Train loss: 1.670, Epoch time = 99.528s


100%|██████████| 590/590 [01:36<00:00,  6.08it/s]


Epoch: 20, Train loss: 1.666, Epoch time = 98.418s


100%|██████████| 590/590 [01:37<00:00,  6.03it/s]


Epoch: 21, Train loss: 1.662, Epoch time = 99.457s


100%|██████████| 590/590 [01:39<00:00,  5.94it/s]


Epoch: 22, Train loss: 1.659, Epoch time = 100.747s


100%|██████████| 590/590 [01:37<00:00,  6.07it/s]


Epoch: 23, Train loss: 1.654, Epoch time = 98.872s


100%|██████████| 590/590 [01:39<00:00,  5.95it/s]


Epoch: 24, Train loss: 1.651, Epoch time = 100.606s


100%|██████████| 590/590 [01:40<00:00,  5.88it/s]


Epoch: 25, Train loss: 1.647, Epoch time = 102.023s


100%|██████████| 590/590 [01:38<00:00,  6.02it/s]


Epoch: 26, Train loss: 1.643, Epoch time = 99.501s


100%|██████████| 590/590 [01:37<00:00,  6.07it/s]


Epoch: 27, Train loss: 1.641, Epoch time = 98.802s


100%|██████████| 590/590 [01:37<00:00,  6.06it/s]


Epoch: 28, Train loss: 1.637, Epoch time = 98.862s


100%|██████████| 590/590 [01:37<00:00,  6.06it/s]


Epoch: 29, Train loss: 1.632, Epoch time = 99.007s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 30, Train loss: 1.631, Epoch time = 98.147s


100%|██████████| 590/590 [01:37<00:00,  6.05it/s]


Epoch: 31, Train loss: 1.629, Epoch time = 99.199s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 32, Train loss: 1.624, Epoch time = 98.444s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 33, Train loss: 1.622, Epoch time = 98.410s


100%|██████████| 590/590 [01:37<00:00,  6.08it/s]


Epoch: 34, Train loss: 1.620, Epoch time = 98.554s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 35, Train loss: 1.616, Epoch time = 98.477s


100%|██████████| 590/590 [01:37<00:00,  6.05it/s]


Epoch: 36, Train loss: 1.613, Epoch time = 98.921s


100%|██████████| 590/590 [01:37<00:00,  6.05it/s]


Epoch: 37, Train loss: 1.611, Epoch time = 99.122s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 38, Train loss: 1.608, Epoch time = 98.305s


100%|██████████| 590/590 [01:37<00:00,  6.07it/s]


Epoch: 39, Train loss: 1.605, Epoch time = 98.926s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 40, Train loss: 1.603, Epoch time = 98.385s


100%|██████████| 590/590 [01:37<00:00,  6.06it/s]


Epoch: 41, Train loss: 1.601, Epoch time = 98.969s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 42, Train loss: 1.597, Epoch time = 98.216s


100%|██████████| 590/590 [01:37<00:00,  6.07it/s]


Epoch: 43, Train loss: 1.596, Epoch time = 98.752s


100%|██████████| 590/590 [01:37<00:00,  6.07it/s]


Epoch: 44, Train loss: 1.593, Epoch time = 98.575s


100%|██████████| 590/590 [01:37<00:00,  6.08it/s]


Epoch: 45, Train loss: 1.591, Epoch time = 98.694s


100%|██████████| 590/590 [01:36<00:00,  6.11it/s]


Epoch: 46, Train loss: 1.589, Epoch time = 98.077s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 47, Train loss: 1.588, Epoch time = 98.317s


100%|██████████| 590/590 [01:36<00:00,  6.12it/s]


Epoch: 48, Train loss: 1.584, Epoch time = 97.930s


100%|██████████| 590/590 [01:37<00:00,  6.03it/s]


Epoch: 49, Train loss: 1.583, Epoch time = 99.481s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 50, Train loss: 1.581, Epoch time = 98.386s


100%|██████████| 590/590 [01:37<00:00,  6.06it/s]


Epoch: 51, Train loss: 1.577, Epoch time = 98.927s


100%|██████████| 590/590 [01:37<00:00,  6.06it/s]


Epoch: 52, Train loss: 1.575, Epoch time = 98.773s


100%|██████████| 590/590 [01:38<00:00,  6.01it/s]


Epoch: 53, Train loss: 1.576, Epoch time = 99.738s


100%|██████████| 590/590 [01:37<00:00,  6.04it/s]


Epoch: 54, Train loss: 1.572, Epoch time = 99.142s


100%|██████████| 590/590 [01:37<00:00,  6.05it/s]


Epoch: 55, Train loss: 1.570, Epoch time = 99.223s


100%|██████████| 590/590 [01:39<00:00,  5.90it/s]


Epoch: 56, Train loss: 1.569, Epoch time = 101.370s


100%|██████████| 590/590 [01:37<00:00,  6.05it/s]


Epoch: 57, Train loss: 1.567, Epoch time = 99.083s


100%|██████████| 590/590 [01:37<00:00,  6.06it/s]


Epoch: 58, Train loss: 1.564, Epoch time = 98.817s


100%|██████████| 590/590 [01:36<00:00,  6.11it/s]


Epoch: 59, Train loss: 1.564, Epoch time = 98.226s


100%|██████████| 590/590 [01:37<00:00,  6.08it/s]


Epoch: 60, Train loss: 1.562, Epoch time = 98.478s


100%|██████████| 590/590 [01:37<00:00,  6.08it/s]


Epoch: 61, Train loss: 1.558, Epoch time = 98.714s


100%|██████████| 590/590 [01:37<00:00,  6.07it/s]


Epoch: 62, Train loss: 1.559, Epoch time = 98.584s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 63, Train loss: 1.557, Epoch time = 98.349s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 64, Train loss: 1.554, Epoch time = 98.243s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 65, Train loss: 1.552, Epoch time = 98.343s


100%|██████████| 590/590 [01:36<00:00,  6.12it/s]


Epoch: 66, Train loss: 1.552, Epoch time = 97.774s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 67, Train loss: 1.550, Epoch time = 98.408s


100%|██████████| 590/590 [01:37<00:00,  6.08it/s]


Epoch: 68, Train loss: 1.548, Epoch time = 98.551s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 69, Train loss: 1.549, Epoch time = 98.403s


100%|██████████| 590/590 [01:37<00:00,  6.08it/s]


Epoch: 70, Train loss: 1.545, Epoch time = 98.544s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 71, Train loss: 1.543, Epoch time = 98.354s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 72, Train loss: 1.543, Epoch time = 98.313s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 73, Train loss: 1.542, Epoch time = 98.585s


100%|██████████| 590/590 [01:36<00:00,  6.09it/s]


Epoch: 74, Train loss: 1.539, Epoch time = 98.348s


100%|██████████| 590/590 [01:37<00:00,  6.07it/s]


Epoch: 75, Train loss: 1.540, Epoch time = 98.842s


100%|██████████| 590/590 [01:37<00:00,  6.04it/s]


Epoch: 76, Train loss: 1.538, Epoch time = 99.065s


100%|██████████| 590/590 [01:37<00:00,  6.03it/s]


Epoch: 77, Train loss: 1.535, Epoch time = 100.004s


100%|██████████| 590/590 [01:36<00:00,  6.10it/s]


Epoch: 78, Train loss: 1.535, Epoch time = 98.115s


100%|██████████| 590/590 [01:37<00:00,  6.03it/s]


Epoch: 79, Train loss: 1.535, Epoch time = 99.974s


100%|██████████| 590/590 [01:38<00:00,  6.02it/s]


Epoch: 80, Train loss: 1.531, Epoch time = 99.516s


In [24]:
torch.save(model.state_dict(), f'translator-{dataset_name}-con.pth')

# Inference

In [25]:
def inference_mask(logits, src, ys):
    mask = torch.zeros_like(logits, device=logits.device, dtype=torch.bool)
    mask = mask.scatter(-1, src, 1)
    mask = mask.scatter(-1, ys, 0)
    logits = torch.masked_fill(logits, ~mask, -torch.inf)
    return logits

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to('cuda')
    src_mask = src_mask.to('cuda')

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to('cuda')
    for i in range(max_len-1):
        memory = memory.to('cuda')
        tgt_mask = (generate_square_subsequent_mask(ys.size(1))
                    .type(torch.bool)).to('cuda')
        out = model.decode(ys, memory, tgt_mask)
        prob = out[:, -1] @ model.item_embedding_decoder.weight.T
        prob = inference_mask(prob, src, ys)
        _, next_word = torch.max(prob, dim=-1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        if next_word == EOS:
            break
    return ys

def translate(model: torch.nn.Module, src):
    model.eval()
    src = src.reshape(1, -1)
    num_tokens = src.shape[1]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=25, start_symbol=SOS).flatten()
    return tgt_tokens



In [26]:
def preprocess(seq):
    return torch.tensor([SOS] + seq + [EOS], device='cuda')
original_data = torch.load(f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_ori.pth')
seqlist = [_[1][:_[3]] + [_[2][_[3] - 1]] for _ in original_data]
seqlist = [preprocess(_) for _ in seqlist]

In [None]:
model.load_state_dict(torch.load(f'translator-{dataset_name}-con.pth'))

In [None]:
idx = 30
# model.condition = 0
model.set_condition(1)
print(source[idx])
print(translate(model, source[idx]))
print(target[idx])

In [33]:
model.condition = 2
filtered_sequences = []
for seq in tqdm(seqlist[:100]):
    rst = translate(model, seq)
    filtered_sequences.append(rst)

100%|██████████| 100/100 [00:16<00:00,  6.07it/s]


In [32]:
debug = deepcopy(filtered_sequences)

In [34]:
cnt = 0
for a, b in zip(debug, filtered_sequences):
    if len(a) == len(b):
        if (a == b).all():
            cnt += 1
print(cnt)

83


In [None]:
interval = list(range(0, 25001, 5000))
rst_list = []
for i in range(len(interval) - 1):
    rst_list.append(torch.load(f'f-seq-con-{dataset_name}-{interval[i]}-{interval[i + 1]}.pth'))
rst = []
for _ in rst_list:
    rst += _
print(len(rst))

In [None]:
ori_pattern = torch.load(f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_new-pure.pth')
ori_seqlist = [list(_[1][:_[3]]) + [_[2][_[3] - 1]] for _ in list(ori_pattern)]
len(ori_pattern)

In [None]:
max_seq_len = 50
def truncate_or_pad(seq):
    cur_seq_len = len(seq)
    if cur_seq_len > max_seq_len:
        return seq[-max_seq_len:]
    else:
        return seq + [0] * (max_seq_len - cur_seq_len)

train_set = set()
# for _ in ori_seqlist:
#     train_set.add(tuple(_))

for pattern in rst:
    seq = pattern.tolist()[1:-1]
    train_set.add(tuple(seq))
print(len(train_set))

In [None]:
train_list = []
for _ in train_set:
    train_list.append([
        1,
        truncate_or_pad(list(_)[:-1]),
        truncate_or_pad(list(_)[1:]),
        sum([a != 0 for a in list(_)[:-1]]),
        [1] * max_seq_len,
        [0] * max_seq_len,
    ])
print(len(train_list))

In [None]:
torch.save(train_list + ori_pattern, f'./dataset/{full_dataset_name}-noise-50/{dataset_name}/train_gene.pth')