In [1]:
LOAD_MODEL = 'models/seq2seq.01.pt'
MODEL_CHECKPOINT = 'models/seq2seq.01.pt'
DATASET_PATH = 'data/interim/preprocessed_paranmt.tsv'

In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader

import os
os.chdir("..") # go to the root dir

# Get the Dataset

In [3]:
MAX_SENT_SIZE = 10
MAX_TOKENS = 8_000

In [4]:
from src.data.make_dataset import ParanmtDataset

train_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    train=True,
    seed=42,
)

In [5]:
train_dataset.build_vocab(
    min_freq=2,
    specials=['<unk>', '<pad>', '<sos>', '<eos>'],
    max_tokens=MAX_TOKENS,
)

In [6]:
enc_vocab = train_dataset.toxic_vocab
dec_vocab = train_dataset.neutral_vocab

In [7]:
print("size of encoder vocab:", len(enc_vocab))
print("size of decoder vocab:", len(dec_vocab))

size of encoder vocab: 8000
size of decoder vocab: 8000


In [8]:
val_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    vocabs=(enc_vocab, dec_vocab), # avoid data leakage
    train=False,
    seed=42,
    take_first=10_000,
)

In [9]:
train_dataset.df

Unnamed: 0,similarity,lenght_diff,toxic_sent,neutral_sent,toxic_val,neutral_val
0,0.708038,0.171429,"[what, the, hell, i, danger, looking, at, ?]","[what, the, hell, is, safe, watch, ?]",0.888703,0.130954
1,0.606822,0.238095,"[lisa, ,, hit, him, again, .]","[lisa, ,, one, more, .]",0.957538,0.000053
2,0.719271,0.051282,"[what, are, you, doing, with, that, hooker, ?]","[what, are, you, doing, with, that, outsider, ?]",0.998877,0.000056
3,0.821008,0.047619,"[we, are, going, to, hit, him, !]","[it, is, going, to, hit, !]",0.997299,0.014387
4,0.725030,0.096774,"[i, do, not, fucking, believe, it, !]","[i, do, not, freaking, believe, it]",0.957814,0.056393
...,...,...,...,...,...,...
157735,0.827812,0.200000,"[I, will, make, you, fall, !]","[I, am, going, to, fall, !]",0.590488,0.006672
157736,0.625040,0.333333,"[i, fucking, my, girlfriend, .]","[satisfying, my, girl, behind, my, back, .]",0.999578,0.029578
157737,0.815115,0.041667,"[he, is, going, to, shoot, again, .]","[he, is, going, to, fire, again, .]",0.989201,0.008294
157738,0.866068,0.037037,"[oh, ,, mars, solid, ,, you, stink, .]","[oh, ,, mars, solid, ,, you, smell, .]",0.999077,0.072257


In [10]:
len(train_dataset), len(val_dataset)

(157740, 10000)

# Build the Dataloaders

In [11]:
batch_size = 128

In [12]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [13]:
# let's check if shape and everything is ok
for batch in train_dataloader:
    toxic_sent, neutral_sent = batch
    print("toxic_sent.shape:", toxic_sent.shape)
    print("neutral_sent.shape:", neutral_sent.shape)
    break

toxic_sent.shape: torch.Size([128, 10])
neutral_sent.shape: torch.Size([128, 10])


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Load the Model

- Simple EncoderDecoder (Seq2Seq) architerture

In [15]:
from src.models.seq2seq.encoder import Encoder
from src.models.seq2seq.decoder import Decoder
from src.models.seq2seq import Seq2Seq

In [16]:
INPUT_DIM = len(enc_vocab)
OUTPUT_DIM = len(dec_vocab)
EMBED_DIM = 256
NUM_HIDDEN = 512
N_LAYERS = 6
DROPOUT = 0.5
ENC_PADDING_IDX = enc_vocab['<pad>']
DEC_PADDING_IDX = dec_vocab['<pad>']

In [17]:
# load the encoder and decoder for our model
encoder = Encoder(
    input_dim=INPUT_DIM,
    embed_dim=EMBED_DIM,
    hidden_dim=NUM_HIDDEN,
    num_layers=N_LAYERS,
    dropout=DROPOUT,
    vocab=enc_vocab,
    padding_idx=ENC_PADDING_IDX
).to(device)

decoder = Decoder(
    output_dim=OUTPUT_DIM,
    embed_dim=EMBED_DIM,
    hidden_dim=NUM_HIDDEN,
    num_layers=N_LAYERS,
    dropout=DROPOUT,
    vocab=dec_vocab,
    padding_idx=DEC_PADDING_IDX
).to(device)

In [18]:
best_loss = float('inf')

model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    device=device,
    max_sent_size=MAX_SENT_SIZE,
).to(device)

In [19]:
from src.models.utils import count_parameters

print(f"number of parameters in model: {count_parameters(model)//1e6}M")

number of parameters in model: 32.0M


In [20]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss(ignore_index=dec_vocab['<pad>'])

In [21]:
from src.models.train_model import train

best_loss = train(
    model=model,
    loaders=(train_dataloader, val_dataloader),
    optimizer=optimizer,
    criterion=criterion,
    epochs=20,
    device=device,
    best_loss=best_loss,
    ckpt_path=MODEL_CHECKPOINT,
    clip_grad=1,
    teacher_force={
        'value': 0.85,
        'gamma': 1.0,
        'update_every_n_epoch': 10,
    } # first 10 epoch teacher force 1, after it will be turned off
)

Training 1: 100%|██████████| 1233/1233 [01:09<00:00, 17.63it/s, loss=4.57]
Evaluating 1: 100%|██████████| 79/79 [00:01<00:00, 58.37it/s, loss=4.34]
Training 2: 100%|██████████| 1233/1233 [01:09<00:00, 17.69it/s, loss=3.68]
Evaluating 2: 100%|██████████| 79/79 [00:01<00:00, 58.08it/s, loss=4.04]
Training 3: 100%|██████████| 1233/1233 [01:09<00:00, 17.65it/s, loss=3.28]
Evaluating 3: 100%|██████████| 79/79 [00:01<00:00, 57.33it/s, loss=3.95]
Training 4: 100%|██████████| 1233/1233 [01:09<00:00, 17.64it/s, loss=3.06]
Evaluating 4: 100%|██████████| 79/79 [00:01<00:00, 58.74it/s, loss=3.91]
Training 5: 100%|██████████| 1233/1233 [01:09<00:00, 17.66it/s, loss=2.9] 
Evaluating 5: 100%|██████████| 79/79 [00:01<00:00, 56.57it/s, loss=3.84]
Training 6: 100%|██████████| 1233/1233 [01:10<00:00, 17.60it/s, loss=2.77]
Evaluating 6: 100%|██████████| 79/79 [00:01<00:00, 58.56it/s, loss=3.81]
Training 7: 100%|██████████| 1233/1233 [01:09<00:00, 17.65it/s, loss=2.66]
Evaluating 7: 100%|██████████| 79/79 

Update teacher force to 0.85



Training 11: 100%|██████████| 1233/1233 [01:10<00:00, 17.57it/s, loss=2.34]
Evaluating 11: 100%|██████████| 79/79 [00:01<00:00, 57.33it/s, loss=3.77]
Training 12: 100%|██████████| 1233/1233 [01:09<00:00, 17.64it/s, loss=2.27]
Evaluating 12: 100%|██████████| 79/79 [00:01<00:00, 57.56it/s, loss=3.81]
Training 13: 100%|██████████| 1233/1233 [01:09<00:00, 17.64it/s, loss=2.22]
Evaluating 13: 100%|██████████| 79/79 [00:01<00:00, 56.15it/s, loss=3.84]
Training 14: 100%|██████████| 1233/1233 [01:09<00:00, 17.69it/s, loss=2.17]
Evaluating 14: 100%|██████████| 79/79 [00:01<00:00, 57.70it/s, loss=3.82]
Training 15: 100%|██████████| 1233/1233 [01:10<00:00, 17.60it/s, loss=2.13]
Evaluating 15: 100%|██████████| 79/79 [00:01<00:00, 57.62it/s, loss=3.86]
Training 16: 100%|██████████| 1233/1233 [01:10<00:00, 17.60it/s, loss=2.08]
Evaluating 16: 100%|██████████| 79/79 [00:01<00:00, 56.44it/s, loss=3.87]
Training 17: 100%|██████████| 1233/1233 [01:10<00:00, 17.60it/s, loss=2.04]
Evaluating 17: 100%|███

Update teacher force to 0.85





In [22]:
# let's load the model and predict
model = torch.load(MODEL_CHECKPOINT)
model.to(device)
model.eval()

Seq2Seq(
  (encoder): Encoder(
    (vocab): Vocab()
    (embedding): Embedding(8000, 256, padding_idx=1)
    (rnn): LSTM(256, 512, num_layers=6, batch_first=True, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (vocab): Vocab()
    (embedding): Embedding(8000, 256, padding_idx=1)
    (rnn): LSTM(256, 512, num_layers=6, batch_first=True, dropout=0.5)
    (fc_out): Linear(in_features=512, out_features=8000, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [23]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()

# let's see how our model works
num_examples = 10
num_sentence = 3
dataset = val_dataset
for idx in range(num_examples):
    idx = np.random.randint(0, len(dataset))
    toxic_sent = detokenizer.detokenize(dataset.df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(dataset.df.loc[idx, 'neutral_sent'])
    
    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    
    # let's use beam search
    # i turned off postprocess_text on purpose 
    # to see everything (postprocess_text removes some tokens and detokenize the sentence)
    preds = model.predict(
        toxic_sent,
        beam=True,
        beam_search_num_candidates=num_sentence,
        post_process_text=False
    )
    print("predictions:")
    for i in range(num_sentence):
        print(f"\t{i+1})", preds[i])
    print("\n")

toxic_sent: let me tell the damn story now!
neutral_sent: I am telling this story!
predictions:
	1) ['give', 'me', 'the', 'keys', '!', '<eos>']
	2) ['give', 'me', 'the', 'money', '!', '<eos>']
	3) ['give', 'me', 'the', 'gun', '!', '<eos>']


toxic_sent: you fucking believe this?
neutral_sent: do you believe that?
predictions:
	1) ['do', 'you', 'believe', 'it', '?', '<eos>']
	2) ['can', 'you', 'believe', 'it', '?', '<eos>']
	3) ['do', 'you', 'believe', 'this', '?', '<eos>']


toxic_sent: shut up, bean paste!
neutral_sent: stop talking, bean paste!
predictions:
	1) ['quiet', ',', '<unk>', '<unk>', '!', '<eos>']
	2) ['quiet', ',', '<unk>', '!', '<eos>']
	3) ['hush', ',', '<unk>', '<unk>', '!', '<eos>']


toxic_sent: you swear to god? you crazy?
neutral_sent: you swear by god?
predictions:
	1) ['are', 'you', 'kidding', 'me', '?', '<eos>']
	2) ['you', 'are', 'kidding', 'me', '?', '<eos>']
	3) ['are', 'you', 'kidding', 'me', 'too', '?', '<eos>']


toxic_sent: get the hell out of here!
neutra

## Beam Search vs Greedy Search

In [24]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()

# let's see how our model works
num_examples = 10
num_sentence = 3
dataset = val_dataset
for idx in range(num_examples):
    idx = np.random.randint(0, len(dataset))
    toxic_sent = detokenizer.detokenize(dataset.df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(dataset.df.loc[idx, 'neutral_sent'])
    
    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    
    preds = model.predict(
        toxic_sent,
        beam=True,
        beam_search_num_candidates=num_sentence,
        post_process_text=False
    )
    print("Beam Search predictions:")
    for i in range(num_sentence):
        print(f"\t{i+1})", preds[i])
    print("\n")
    
    
    preds = model.predict(
        toxic_sent,
        beam=False,
        post_process_text=False
    )
    print("Greedy Search prediction:")
    print(preds)
    print("\n")

toxic_sent: you scream like a hen.
neutral_sent: you sound like you are howling.
Beam Search predictions:
	1) ['you', 'smell', 'like', 'a', '<unk>', '.', '<eos>']
	2) ['you', 'smell', 'like', 'cattle', '.', '<eos>']
	3) ['you', 'smell', 'like', 'a', 'girl', '.', '<eos>']


Greedy Search prediction:
['you', 'smell', 'like', 'a', '<unk>', '.']


toxic_sent: my mother is on die.
neutral_sent: my mom i dying.
Beam Search predictions:
	1) ['my', 'father', 'is', 'dying', '.', '<eos>']
	2) ['my', 'wife', 'is', 'dying', '.', '<eos>']
	3) ['my', 'mother', 'is', 'dying', '.', '<eos>']


Greedy Search prediction:
['my', 'father', 'is', 'dying', '.']


toxic_sent: shut your mouth!
neutral_sent: close your mouth.
Beam Search predictions:
	1) ['close', 'your', 'mouth', '!', '<eos>']
	2) ['keep', 'your', 'mouth', '!', '<eos>']
	3) ['open', 'your', 'mouth', '!', '<eos>']


Greedy Search prediction:
['close', 'your', 'mouth', '!']


toxic_sent: she makin' you crazy?
neutral_sent: is he driving you craz

### Actually Greedy search doing great job