In [1]:
LOAD_MODEL = 'models/attention2.01.pt'
MODEL_CHECKPOINT = 'models/attention2.01.pt'
DATASET_PATH = 'data/interim/preprocessed_paranmt.tsv'

In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader

import os
os.chdir("..") # go to the root dir

# Get the Dataset

In [3]:
MAX_SENT_SIZE = 32
MAX_TOKENS = 10_000

In [4]:
from src.data.make_dataset import ParanmtDataset

train_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    train=True,
    seed=42,
)

In [5]:
train_dataset.build_vocab(
    min_freq=2,
    specials=['<unk>', '<pad>', '<sos>', '<eos>'],
    max_tokens=MAX_TOKENS,
)

In [6]:
enc_vocab = train_dataset.toxic_vocab
dec_vocab = train_dataset.neutral_vocab

print("size of encoder vocab:", len(enc_vocab))
print("size of decoder vocab:", len(dec_vocab))

size of encoder vocab: 10000
size of decoder vocab: 10000


In [7]:
val_dataset = ParanmtDataset(
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    vocabs=(enc_vocab, dec_vocab), # avoid data leakage
    train=False,
    seed=42,
    take_first=10_000,
)

In [8]:
train_dataset.df

Unnamed: 0,similarity,lenght_diff,toxic_sent,neutral_sent,toxic_val,neutral_val
0,0.811567,0.179487,"[you, know, i, hate, that, health, food, shit, .]","[you, know, i, hate, a, healthy, diet, .]",0.999437,0.000569
1,0.883822,0.250000,"[what, the, hell, is, going, on, here, ?]","[what, is, going, on, there, ?]",0.877907,0.000041
2,0.769068,0.303030,"[she, tried, to, kill, her, own, father, with,...","[however, ,, mike, ,, she, tried, to, beat, hi...",0.966588,0.024886
3,0.823836,0.157895,"[have, a, shitty, day, .]","[have, a, bad, day, .]",0.996943,0.000633
4,0.670003,0.320513,"[you, ever, think, of, screaming, instead, of,...","[did, it, ever, occur, to, you, to, scream, yo...",0.999311,0.011481
...,...,...,...,...,...,...
470047,0.945723,0.173077,"[I, would, slap, you, even, if, mala, does, no...","[i, would, have, slapped, you, even, if, mala,...",0.987526,0.196128
470048,0.767978,0.272727,"[death, to, the, al, fayed, !, (, grunts, )]","[the, death, of, al, fayed, !]",0.997817,0.000219
470049,0.766673,0.068966,"[i, think, he, is, manure, ,, wolf, .]","[i, think, he, is, buggered, ,, wolf, .]",0.970698,0.000387
470050,0.776357,0.173913,"[can, not, even, take, care, of, your, own, go...","[can, not, you, even, take, care, of, your, so...",0.999640,0.000586


In [9]:
len(train_dataset), len(val_dataset)

(470052, 10000)

# Build the Dataloaders

In [10]:
batch_size = 128

In [11]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [12]:
# let's check if shape and everything is ok
for batch in train_dataloader:
    toxic_sent, neutral_sent = batch
    print("toxic_sent.shape:", toxic_sent.shape)
    print("neutral_sent.shape:", neutral_sent.shape)
    break

toxic_sent.shape: torch.Size([128, 32])
neutral_sent.shape: torch.Size([128, 32])


In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Load the Model

- EncoderDecoder (Seq2Seq) with Attention mechanism

In [14]:
from src.models.attention.encoder import Encoder
from src.models.attention.decoder import Decoder
from src.models.attention.attention import Attention
from src.models.attention import Seq2SeqAttention

In [15]:
## Encoder
enc_input_dim = len(enc_vocab)
enc_embed_dim = 128
enc_hidden_dim = 256
enc_dropout = 0.5
enc_padding_idx = enc_vocab['<pad>']

## Decoder
dec_output_dim = len(dec_vocab)
dec_embed_dim = 128
dec_hidden_dim = 256
dec_dropout = 0.5
dec_padding_idx = dec_vocab['<pad>']

In [27]:
# load the encoder and decoder for our model
encoder = Encoder(
    input_dim=enc_input_dim,
    embed_dim=enc_embed_dim,
    hidden_dim=enc_hidden_dim,
    dec_hidden_dim=dec_hidden_dim,
    dropout=enc_dropout,
    vocab=enc_vocab,
    padding_idx=enc_padding_idx,
).to(device)

attention = Attention(
    enc_hidden_dim,
    dec_hidden_dim,
)

decoder = Decoder(
    output_dim=dec_output_dim,
    embed_dim=dec_embed_dim,
    hidden_dim=dec_hidden_dim,
    attention=attention,
    enc_hidden_dim=enc_hidden_dim,
    dropout=dec_dropout,
    vocab=dec_vocab,
    padding_idx=dec_padding_idx,
).to(device)

In [28]:
best_loss = float('inf')

model = Seq2SeqAttention(
    encoder=encoder,
    decoder=decoder,
    device=device,
    max_sent_size=MAX_SENT_SIZE,
).to(device)

In [29]:
from src.models.utils import count_parameters

print(f"number of parameters in model: {count_parameters(model)//1e6}M")

number of parameters in model: 13.0M


In [30]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss(ignore_index=dec_vocab['<pad>'])

In [31]:
from src.models.train_model import train

best_loss = train(
    model=model,
    loaders=(train_dataloader, val_dataloader),
    optimizer=optimizer,
    criterion=criterion,
    epochs=20,
    device=device,
    best_loss=best_loss,
    ckpt_path=MODEL_CHECKPOINT,
    clip_grad=1,
    teacher_force={
        'value': 0.95,
        'gamma': 1.0,
        'update_every_n_epoch': 50,
    } # first 10 epoch teacher force 1, after it will be turned off
)

Training 1: 100%|██████████| 3673/3673 [05:26<00:00, 11.25it/s, loss=3.29]
Evaluating 1: 100%|██████████| 79/79 [00:02<00:00, 34.17it/s, loss=5.2] 
Training 2: 100%|██████████| 3673/3673 [05:26<00:00, 11.25it/s, loss=2.6] 
Evaluating 2: 100%|██████████| 79/79 [00:02<00:00, 33.88it/s, loss=5.29]
Training 3: 100%|██████████| 3673/3673 [05:26<00:00, 11.24it/s, loss=2.42]
Evaluating 3: 100%|██████████| 79/79 [00:02<00:00, 33.99it/s, loss=5.31]
Training 4: 100%|██████████| 3673/3673 [05:27<00:00, 11.21it/s, loss=2.31]
Evaluating 4: 100%|██████████| 79/79 [00:02<00:00, 34.07it/s, loss=5.27]
Training 5: 100%|██████████| 3673/3673 [05:27<00:00, 11.21it/s, loss=2.22]
Evaluating 5: 100%|██████████| 79/79 [00:02<00:00, 33.92it/s, loss=5.31]
Training 6: 100%|██████████| 3673/3673 [05:26<00:00, 11.23it/s, loss=2.16]
Evaluating 6: 100%|██████████| 79/79 [00:02<00:00, 33.61it/s, loss=5.27]
Training 7: 100%|██████████| 3673/3673 [05:27<00:00, 11.20it/s, loss=2.11]
Evaluating 7: 100%|██████████| 79/79 

In [None]:
torch.save(model, 'models/attention2.02.pt')

In [32]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()

# let's see how our model works
num_examples = 10
num_sentence = 3
dataset = train_dataset
for _ in range(num_examples):
    idx = np.random.randint(0, len(dataset))
    toxic_sent = detokenizer.detokenize(dataset.df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(dataset.df.loc[idx, 'neutral_sent'])
    
    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    preds = model.predict(
        toxic_sent,
        beam=True,
        beam_search_num_candidates=num_sentence,
        post_process_text=False,
    ) # let's use beam search
    print("predictions:")
    for i in range(num_sentence):
        print(f"\t{i+1})", preds[i])
    print("\n")

toxic_sent: I am going to start running a breast artery.
neutral_sent: I will start the harvest of the mammary artery.
predictions:
	1) ['I', 'am', 'going', 'to', 'start', 'running', 'a', 'breast', 'artery', '.', '<eos>']
	2) ['I', 'am', 'going', 'to', 'start', 'running', 'a', '<unk>', 'artery', '.', '<eos>']
	3) ['I', 'am', 'going', 'to', 'start', 'running', 'the', 'breast', 'artery', '.', '<eos>']


toxic_sent: stop playing dumb with us!
neutral_sent: stop playing mute with us!
predictions:
	1) ['stop', 'playing', 'games', 'with', 'us', '!', '<eos>']
	2) ['stop', 'messing', 'with', 'us', '!', '<eos>']
	3) ['stop', 'playing', 'with', 'us', '!', '<eos>']


toxic_sent: damn, i broke the door.
neutral_sent: i broke the frame.
predictions:
	1) ['hell', ',', 'i', 'broke', 'the', 'door', '.', '<eos>']
	2) ['i', 'broke', 'the', 'door', '.', '<eos>']
	3) ['hell', ',', 'i', 'broke', 'my', 'door', '.', '<eos>']


toxic_sent: everybody shut up!
neutral_sent: come on, silence all!
predictions:
	1

In [38]:
# let's load the model and predict
model = torch.load(MODEL_CHECKPOINT)
model.to(device)
model.eval()

Seq2SeqAttention(
  (encoder): Encoder(
    (vocab): Vocab()
    (embedding): Embedding(10000, 128, padding_idx=1)
    (rnn): LSTM(128, 256, batch_first=True, bidirectional=True)
    (fc_hidden): Linear(in_features=512, out_features=256, bias=True)
    (fc_cell): Linear(in_features=512, out_features=256, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (vocab): Vocab()
    (embedding): Embedding(10000, 128, padding_idx=1)
    (rnn): LSTM(640, 256, batch_first=True)
    (attention): Attention(
      (attn): Linear(in_features=768, out_features=256, bias=True)
      (v): Linear(in_features=256, out_features=1, bias=False)
    )
    (dropout): Dropout(p=0.5, inplace=False)
    (fc_out): Linear(in_features=896, out_features=10000, bias=True)
  )
)

In [39]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()

# let's see how our model works
num_examples = 10
num_sentence = 3
dataset = train_dataset
for _ in range(num_examples):
    idx = np.random.randint(0, len(dataset))
    toxic_sent = detokenizer.detokenize(dataset.df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(dataset.df.loc[idx, 'neutral_sent'])
    
    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    preds = model.predict(
        toxic_sent,
        beam=True,
        beam_search_num_candidates=num_sentence,
        post_process_text=False,
    ) # let's use beam search
    print("predictions:")
    for i in range(num_sentence):
        print(f"\t{i+1})", preds[i])
    print("\n")

toxic_sent: i will collect your heads.
neutral_sent: I will get your head.
predictions:
	1) ['I', 'will', 'take', 'your', 'heads', '.', '<eos>']
	2) ['I', 'will', 'be', 'your', 'heads', '.', '<eos>']
	3) ['I', 'am', 'going', 'to', 'take', 'your', 'heads', '.', '<eos>']


toxic_sent: leave one ship and leave, or you will die.
neutral_sent: i leave a vessel and get out, or muri i here.
predictions:
	1) ['leave', 'one', 'ship', 'and', 'leave', ',', 'or', 'you', 'will', 'die', '.', '<eos>']
	2) ['leave', 'the', 'ship', 'and', 'leave', ',', 'or', 'you', 'will', 'die', '.', '<eos>']
	3) ['leave', 'one', 'ship', ',', 'leave', ',', 'or', 'you', 'will', 'die', '.', '<eos>']


toxic_sent: well, it is going to be both our asses if you are wrong.
neutral_sent: if you are wrong, it is going to cost us both.
predictions:
	1) ['well', ',', 'it', 'is', 'going', 'to', 'be', 'both', 'if', 'you', 'are', 'wrong', '.', '<eos>']
	2) ['well', ',', 'it', 'will', 'be', 'both', 'if', 'you', 'are', 'wrong', '.',