In [1]:
LOAD_MODEL = 'models/seq2seq_2.01.pt'
MODEL_CHECKPOINT = 'models/seq2seq_2.01.pt'
DATASET_PATH = 'data/interim/preprocessed_paranmt3.tsv'

In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader

import os
os.chdir("..") # go to the root dir

# Get the Dataset

In [3]:
MAX_SENT_SIZE = 10
MAX_TOKENS = 30_000

In [4]:
from src.data.make_dataset import ParanmtDataset

train_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    train=True,
    seed=42,
)

In [5]:
train_dataset.build_vocab(
    min_freq=2,
    specials=['<unk>', '<pad>', '<sos>', '<eos>'],
    max_tokens=MAX_TOKENS,
)

In [6]:
enc_vocab = train_dataset.toxic_vocab
dec_vocab = train_dataset.neutral_vocab

print("size of encoder vocab:", len(enc_vocab))
print("size of decoder vocab:", len(dec_vocab))

size of encoder vocab: 12137
size of decoder vocab: 14544


In [7]:
val_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    vocabs=(enc_vocab, dec_vocab), # avoid data leakage
    train=False,
    seed=42,
    take_first=10_000,
)

In [8]:
train_dataset.df

Unnamed: 0,similarity,lenght_diff,toxic_sent,neutral_sent,toxic_val,neutral_val
0,0.784351,0.190476,"[dressing, like, a, bum, ?]","[walk, like, a, bum, ?]",0.911522,0.040193
1,0.610479,0.000000,"[you, bum, !]","[slacker, !]",0.968336,0.006487
2,0.866624,0.060606,"[chloe, ,, stop, being, so, paranoid, .]","[chloe, ,, stop, being, so, paranoid, .]",0.611291,0.021283
3,0.949912,0.157895,"[it, stinks, in, here, .]","[stinks, in, here, .]",0.698503,0.013900
4,0.885716,0.347826,"[i, want, to, silence, you, .]","[i, silence, you, .]",0.921562,0.020940
...,...,...,...,...,...,...
162956,0.816062,0.076923,"[that, is, a, stupid, excuse, !]","[that, is, a, cowardly, excuse, !]",0.999643,0.042241
162957,0.756084,0.300000,"[what, the, hell, ?, huh, ?]","[what, is, wrong, ?]",0.825482,0.000042
162958,0.618064,0.333333,"[like, your, pecker, .]","[like, a, cue, .]",0.978997,0.000061
162959,0.687292,0.037037,"[my, friend, i, asexual, beast, .]","[my, girl, i, asexual, bestie, .]",0.997259,0.071518


In [9]:
len(train_dataset), len(val_dataset)

(162961, 10000)

## Build the Dataloaders

In [10]:
batch_size = 128

In [11]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [12]:
# let's check if shape and everything is ok
for batch in train_dataloader:
    toxic_sent, neutral_sent = batch
    print("toxic_sent.shape:", toxic_sent.shape)
    print("neutral_sent.shape:", neutral_sent.shape)
    break

toxic_sent.shape: torch.Size([128, 10])
neutral_sent.shape: torch.Size([128, 10])


In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Load the Model

- Simple EncoderDecoder (Seq2Seq) architerture
- a little trick was used (every time the decoder carries with it context vector from the encoder)

In [14]:
from src.models.seq2seq.encoder import Encoder
from src.models.seq2seq.decoder2 import Decoder2 # NOTE: using different Decoder than first notebook
from src.models.seq2seq import Seq2Seq

In [15]:
INPUT_DIM = len(enc_vocab)
OUTPUT_DIM = len(dec_vocab)
EMBED_DIM = 128
NUM_HIDDEN = 256
N_LAYERS = 1
DROPOUT = 0.3
ENC_PADDING_IDX = enc_vocab['<pad>']
DEC_PADDING_IDX = dec_vocab['<pad>']

In [16]:
# load the encoder and decoder for our model
encoder = Encoder(
    input_dim=INPUT_DIM,
    embed_dim=EMBED_DIM,
    hidden_dim=NUM_HIDDEN,
    num_layers=N_LAYERS,
    dropout=DROPOUT,
    vocab=enc_vocab,
    padding_idx=ENC_PADDING_IDX
).to(device)

decoder = Decoder2(
    output_dim=OUTPUT_DIM,
    embed_dim=EMBED_DIM,
    hidden_dim=NUM_HIDDEN,
    num_layers=N_LAYERS,
    dropout=DROPOUT,
    vocab=dec_vocab,
    padding_idx=DEC_PADDING_IDX
).to(device)

In [17]:
best_loss = float('inf')

model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    device=device,
    max_sent_size=MAX_SENT_SIZE,
).to(device)

In [18]:
from src.models.utils import count_parameters

print(f"number of parameters in model: {count_parameters(model)//1e6}M")

number of parameters in model: 19.0M


In [19]:
# model = torch.load(LOAD_MODEL)
# model.to(device)

In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss(ignore_index=dec_vocab['<pad>'])

In [21]:
from src.models.train_model import train

best_loss = train(
    model=model,
    loaders=(train_dataloader, val_dataloader),
    optimizer=optimizer,
    criterion=criterion,
    epochs=20,
    device=device,
    best_loss=best_loss,
    ckpt_path=MODEL_CHECKPOINT,
    clip_grad=1,
    teacher_force={
        'value': 1,
        'gamma': 0,
        'update_every_n_epoch': 5,
    } # first 5 epoch teacher force 1, after it will be turned off
)

Training 1: 100%|██████████| 1274/1274 [00:40<00:00, 31.68it/s, loss=3.48]
Evaluating 1: 100%|██████████| 79/79 [00:00<00:00, 94.12it/s, loss=4.92]
Training 2: 100%|██████████| 1274/1274 [00:40<00:00, 31.22it/s, loss=2.81]
Evaluating 2: 100%|██████████| 79/79 [00:00<00:00, 91.25it/s, loss=4.72]
Training 3: 100%|██████████| 1274/1274 [00:40<00:00, 31.22it/s, loss=2.55]
Evaluating 3: 100%|██████████| 79/79 [00:00<00:00, 81.78it/s, loss=4.86]
Training 4: 100%|██████████| 1274/1274 [00:40<00:00, 31.12it/s, loss=2.36]
Evaluating 4: 100%|██████████| 79/79 [00:00<00:00, 95.30it/s, loss=5.07]
Training 5: 100%|██████████| 1274/1274 [00:40<00:00, 31.25it/s, loss=2.21]
Evaluating 5: 100%|██████████| 79/79 [00:00<00:00, 93.65it/s, loss=5.11]


Update teacher force to 0


Training 6: 100%|██████████| 1274/1274 [00:40<00:00, 31.36it/s, loss=3.31]
Evaluating 6: 100%|██████████| 79/79 [00:00<00:00, 94.50it/s, loss=3.37]
Training 7: 100%|██████████| 1274/1274 [00:40<00:00, 31.40it/s, loss=3.1]
Evaluating 7: 100%|██████████| 79/79 [00:00<00:00, 92.83it/s, loss=3.34]
Training 8: 100%|██████████| 1274/1274 [00:40<00:00, 31.22it/s, loss=2.97]
Evaluating 8: 100%|██████████| 79/79 [00:00<00:00, 99.72it/s, loss=3.36] 
Training 9: 100%|██████████| 1274/1274 [00:40<00:00, 31.24it/s, loss=2.87]
Evaluating 9: 100%|██████████| 79/79 [00:00<00:00, 97.11it/s, loss=3.36] 
Training 10: 100%|██████████| 1274/1274 [00:40<00:00, 31.38it/s, loss=2.78]
Evaluating 10: 100%|██████████| 79/79 [00:00<00:00, 95.35it/s, loss=3.41]
Training 11: 100%|██████████| 1274/1274 [00:40<00:00, 31.38it/s, loss=2.7]
Evaluating 11: 100%|██████████| 79/79 [00:00<00:00, 94.83it/s, loss=3.46]
Training 12: 100%|██████████| 1274/1274 [00:40<00:00, 31.48it/s, loss=2.63]
Evaluating 12: 100%|██████████| 

In [22]:
# # let's load the model and predict
# model = torch.load(MODEL_CHECKPOINT)
# model.to(device)
# model.eval()

In [26]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()

# let's see how our model works
num_examples = 10
num_sentence = 3
dataset = train_dataset
for _ in range(num_examples):
    idx = np.random.randint(0, len(dataset))
    toxic_sent = detokenizer.detokenize(dataset.df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(dataset.df.loc[idx, 'neutral_sent'])
    
    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    # let's use beam search
    preds = model.predict(
        toxic_sent,
        beam=True,
        beam_search_num_candidates=num_sentence,
        post_process_text=False,
    )
    print("predictions:")
    for i in range(num_sentence):
        print(f"\t{i+1})", preds[i])
    print("\n")

toxic_sent: if you die
neutral_sent: even if you died
predictions:
	1) ['if', 'you', 'die', '<eos>']
	2) ['if', 'you', 'die', 'die', '<eos>']
	3) ['when', 'you', 'die', '<eos>']


toxic_sent: i am going to die
neutral_sent: will i die here will i
predictions:
	1) ['i', 'am', 'die', '<eos>']
	2) ['i', 'am', 'dying', '<eos>']
	3) ['i', 'am', 'die', 'to', '<eos>']


toxic_sent: paula i crazy come on.
neutral_sent: paulie, this is crazy.
predictions:
	1) ['paulie', 'i', 'crazy', 'crazy', '.', '.', '<eos>']
	2) ['paulie', 'i', 'crazy', 'on', '.', '.', '<eos>']
	3) ['paulie', 'i', 'crazy', 'crazy', '.', '<eos>']


toxic_sent: drunk my ass.
neutral_sent: you are drunk.
predictions:
	1) ['drunk', 'my', '.', '.', '<eos>']
	2) ['drunk', 'my', '.', '.', '.', '<eos>']
	3) ['drunk', '.', '.', '.', '.', '<eos>']


toxic_sent: are you crazy?
neutral_sent: merlin! are you mad?
predictions:
	1) ['are', 'you', 'mad', '?', '<eos>']
	2) ['are', 'you', 'mad', '?', '?', '<eos>']
	3) ['have', 'you', 'mad', '