In [1]:
LOAD_MODEL = 'models/seq2seq.01.pt'
MODEL_CHECKPOINT = 'models/seq2seq.01.pt'
DATASET_PATH = 'data/interim/preprocessed_paranmt3.tsv'

In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader

import os
os.chdir("..") # go to the root dir

# Get the Dataset

In [3]:
MAX_SENT_SIZE = 12
MAX_TOKENS = 30_000

In [4]:
from src.data.make_dataset import ParanmtDataset

train_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    train=True,
    seed=42,
)

In [5]:
train_dataset.build_vocab(
    min_freq=2,
    specials=['<unk>', '<pad>', '<sos>', '<eos>'],
    max_tokens=MAX_TOKENS,
)

In [6]:
enc_vocab = train_dataset.toxic_vocab
dec_vocab = train_dataset.neutral_vocab

In [7]:
print("size of encoder vocab:", len(enc_vocab))
print("size of decoder vocab:", len(dec_vocab))

size of encoder vocab: 16608
size of decoder vocab: 19520


In [8]:
val_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    vocabs=(enc_vocab, dec_vocab), # avoid data leakage
    train=False,
    seed=42,
    take_first=10_000,
)

In [9]:
train_dataset.df

Unnamed: 0,similarity,lenght_diff,toxic_sent,neutral_sent,toxic_val,neutral_val
0,0.607065,0.034483,"[get, yourself, a, good, hooker, .]","[go, get, yourself, a, nice, bird, .]",0.999393,0.001473
1,0.918913,0.200000,"[long, hair, ,, pretty, little, mouth, ,, perf...","[long, hair, ,, nice, lips, ,, perfect, butt, ?]",0.997611,0.106524
2,0.634488,0.185185,"[you, should, have, finished, it, and, ordered...","[you, should, have, ordered, that, pilot, to, ...",0.953872,0.014506
3,0.865425,0.357143,"[what, the, hell, is, he, saying, ?]","[what, does, he, say, ?]",0.967230,0.000041
4,0.845061,0.040000,"[then, why, are, you, crazy, ?]","[then, why, with, the, crazy, ?]",0.994152,0.007104
...,...,...,...,...,...,...
238619,0.809637,0.000000,"[parasites, kill, puppies, .]","[germs, can, kill, puppies, .]",0.997600,0.001762
238620,0.660608,0.111111,"[damn, it, ,, where, is, everybody, ?, bruno, !]","[where, the, hell, did, everyone, go, ?]",0.997657,0.176493
238621,0.697438,0.222222,"[i, am, tired, of, hearing, you, two, gallinas...","[i, am, tired, of, hearing, your, whining, .]",0.955657,0.273435
238622,0.711497,0.105263,"[it, is, a, blow, job, .]","[she, is, a, whack, job, .]",0.994690,0.192220


In [10]:
len(train_dataset), len(val_dataset)

(238624, 10000)

# Build the Dataloaders

In [11]:
batch_size = 32

In [12]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [13]:
# let's check if shape and everything is ok
for batch in train_dataloader:
    toxic_sent, neutral_sent = batch
    print("toxic_sent.shape:", toxic_sent.shape)
    print("neutral_sent.shape:", neutral_sent.shape)
    break

toxic_sent.shape: torch.Size([32, 12])
neutral_sent.shape: torch.Size([32, 12])


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Load the Model

- Simple EncoderDecoder (Seq2Seq) architerture

In [15]:
from src.models.seq2seq.encoder import Encoder
from src.models.seq2seq.decoder import Decoder
from src.models.seq2seq import Seq2Seq

In [16]:
INPUT_DIM = len(enc_vocab)
OUTPUT_DIM = len(dec_vocab)
EMBED_DIM = 128
NUM_HIDDEN = 256
N_LAYERS = 2
DROPOUT = 0.3
ENC_PADDING_IDX = enc_vocab['<pad>']
DEC_PADDING_IDX = dec_vocab['<pad>']

In [17]:
# load the encoder and decoder for our model
encoder = Encoder(
    input_dim=INPUT_DIM,
    embed_dim=EMBED_DIM,
    hidden_dim=NUM_HIDDEN,
    num_layers=N_LAYERS,
    dropout=DROPOUT,
    vocab=enc_vocab,
    padding_idx=ENC_PADDING_IDX
).to(device)

decoder = Decoder(
    output_dim=OUTPUT_DIM,
    embed_dim=EMBED_DIM,
    hidden_dim=NUM_HIDDEN,
    num_layers=N_LAYERS,
    dropout=DROPOUT,
    vocab=dec_vocab,
    padding_idx=DEC_PADDING_IDX
).to(device)

In [18]:
best_loss = float('inf')

model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    device=device,
    max_sent_size=MAX_SENT_SIZE,
).to(device)

In [19]:
from src.models.utils import count_parameters

print(f"number of parameters in model: {count_parameters(model)//1e6}M")

number of parameters in model: 26.0M


In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=dec_vocab['<pad>'])

In [21]:
from src.models.train_model import train

best_loss = train(
    model=model,
    loaders=(train_dataloader, val_dataloader),
    optimizer=optimizer,
    criterion=criterion,
    epochs=30,
    device=device,
    best_loss=best_loss,
    ckpt_path=MODEL_CHECKPOINT,
    clip_grad=1,
    teacher_force={
        'value': 1,
        'gamma': 0,
        'update_every_n_epoch': 10,
    } # first 10 epoch teacher force 1, after it will be turned off
)

Training 1: 100%|██████████| 7457/7457 [04:00<00:00, 31.05it/s, loss=3.76]
Evaluating 1: 100%|██████████| 313/313 [00:02<00:00, 143.49it/s, loss=5.42]
Training 2: 100%|██████████| 7457/7457 [03:59<00:00, 31.13it/s, loss=3.13]
Evaluating 2: 100%|██████████| 313/313 [00:02<00:00, 154.96it/s, loss=5.49]
Training 3: 100%|██████████| 7457/7457 [04:00<00:00, 30.98it/s, loss=2.89]
Evaluating 3: 100%|██████████| 313/313 [00:02<00:00, 145.94it/s, loss=5.65]
Training 4: 100%|██████████| 7457/7457 [04:02<00:00, 30.73it/s, loss=2.75]
Evaluating 4: 100%|██████████| 313/313 [00:02<00:00, 148.30it/s, loss=5.67]
Training 5: 100%|██████████| 7457/7457 [03:59<00:00, 31.10it/s, loss=2.64]
Evaluating 5: 100%|██████████| 313/313 [00:02<00:00, 146.70it/s, loss=5.76]
Training 6: 100%|██████████| 7457/7457 [04:00<00:00, 31.06it/s, loss=2.57]
Evaluating 6: 100%|██████████| 313/313 [00:02<00:00, 146.59it/s, loss=5.86]
Training 7: 100%|██████████| 7457/7457 [04:00<00:00, 30.98it/s, loss=2.5] 
Evaluating 7: 100%|

Update teacher force to 0


Training 11: 100%|██████████| 7457/7457 [03:58<00:00, 31.31it/s, loss=3.67]
Evaluating 11: 100%|██████████| 313/313 [00:02<00:00, 150.89it/s, loss=3.69]
Training 12: 100%|██████████| 7457/7457 [04:01<00:00, 30.94it/s, loss=3.5] 
Evaluating 12: 100%|██████████| 313/313 [00:02<00:00, 144.02it/s, loss=3.62]
Training 13: 100%|██████████| 7457/7457 [03:58<00:00, 31.23it/s, loss=3.41]
Evaluating 13: 100%|██████████| 313/313 [00:02<00:00, 150.61it/s, loss=3.63]
Training 14: 100%|██████████| 7457/7457 [03:58<00:00, 31.32it/s, loss=3.35]
Evaluating 14: 100%|██████████| 313/313 [00:02<00:00, 143.69it/s, loss=3.59]
Training 15: 100%|██████████| 7457/7457 [03:59<00:00, 31.19it/s, loss=3.3] 
Evaluating 15: 100%|██████████| 313/313 [00:02<00:00, 143.53it/s, loss=3.58]
Training 16: 100%|██████████| 7457/7457 [04:00<00:00, 30.99it/s, loss=3.26]
Evaluating 16: 100%|██████████| 313/313 [00:02<00:00, 145.06it/s, loss=3.56]
Training 17: 100%|██████████| 7457/7457 [03:58<00:00, 31.32it/s, loss=3.22]
Evalua

Update teacher force to 0


Training 21: 100%|██████████| 7457/7457 [03:59<00:00, 31.13it/s, loss=3.1] 
Evaluating 21: 100%|██████████| 313/313 [00:02<00:00, 145.29it/s, loss=3.52]
Training 22: 100%|██████████| 7457/7457 [03:58<00:00, 31.25it/s, loss=3.08]
Evaluating 22: 100%|██████████| 313/313 [00:02<00:00, 144.49it/s, loss=3.5] 
Training 23: 100%|██████████| 7457/7457 [03:57<00:00, 31.35it/s, loss=3.05]
Evaluating 23: 100%|██████████| 313/313 [00:02<00:00, 145.59it/s, loss=3.51]
Training 24: 100%|██████████| 7457/7457 [04:01<00:00, 30.92it/s, loss=3.03]
Evaluating 24: 100%|██████████| 313/313 [00:02<00:00, 144.18it/s, loss=3.49]
Training 25: 100%|██████████| 7457/7457 [03:58<00:00, 31.30it/s, loss=3.01]
Evaluating 25: 100%|██████████| 313/313 [00:02<00:00, 141.95it/s, loss=3.5] 
Training 26: 100%|██████████| 7457/7457 [03:58<00:00, 31.24it/s, loss=2.99]
Evaluating 26: 100%|██████████| 313/313 [00:02<00:00, 141.40it/s, loss=3.52]
Training 27: 100%|██████████| 7457/7457 [03:58<00:00, 31.23it/s, loss=2.97]
Evalua

Update teacher force to 0


In [22]:
best_loss

3.466131195854455

In [23]:
# let's load the model and predict
model = torch.load(MODEL_CHECKPOINT)
model.to(device)
model.eval()

Seq2Seq(
  (encoder): Encoder(
    (vocab): Vocab()
    (embedding): Embedding(16608, 128, padding_idx=1)
    (rnn): GRU(128, 256, num_layers=2, batch_first=True, dropout=0.3)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (decoder): Decoder(
    (vocab): Vocab()
    (embedding): Embedding(19520, 128, padding_idx=1)
    (rnn): GRU(128, 256, num_layers=2, batch_first=True, dropout=0.3)
    (fc_out): Sequential(
      (0): Linear(in_features=256, out_features=1024, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=1024, out_features=19520, bias=True)
    )
    (dropout): Dropout(p=0.3, inplace=False)
  )
)

In [27]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()

# let's see how our model works
num_examples = 10
num_sentence = 3
dataset = train_dataset
for _ in range(num_examples):
    idx = np.random.randint(0, len(dataset))
    toxic_sent = detokenizer.detokenize(dataset.df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(dataset.df.loc[idx, 'neutral_sent'])
    
    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    preds = model.predict(toxic_sent, beam=True, beam_search_num_sentence=num_sentence) # let's use beam search
    print("predictions:")
    for i in range(num_sentence):
        print(f"\t{i+1})", preds[i])
    print("\n")

toxic_sent: what the fuck he is talking about?!
neutral_sent: what is he talking about?
predictions:
	1) what is he talking about?
	2) what is he talking about about
	3) what is he talking talking?


toxic_sent: it is fucking painful.
neutral_sent: it hurts a lot.
predictions:
	1) it is bloody.
	2) it is fucking . .
	3) it is bloody . .


toxic_sent: we can fuck tomorrow.
neutral_sent: we can love each other tomorrow.
predictions:
	1) we can die tomorrow tomorrow tomorrow.
	2) we can die tomorrow tomorrow . .
	3) we can die tomorrow tomorrow tomorrow . .


toxic_sent: i fucking told you.
neutral_sent: i already told you.
predictions:
	1) i told you.
	2) i told you!
	3) i told you . .


toxic_sent: studying torah . asshole,
neutral_sent: i am studying the torah, piping.
predictions:
	1) happy, . . . . . . .
	2) happy, . . . . . .
	3) happy, . . . . .


toxic_sent: you get one shot.
neutral_sent: you have one chance.
predictions:
	1) you shot one chance.
	2) you shot one shot.
	3) you ha