In [1]:
import torch
from torch.utils.data import DataLoader

import os
# Go to the repository root so that the `src.*` imports and the relative
# `data/` and `models/` paths used below resolve.
# Guarded so that re-running this cell does not climb one directory higher
# each time: the root is recognised by the presence of the `src` package.
if not os.path.isdir("src"):
    os.chdir("..")  # go to the root dir

# Get the Dataset

In [4]:
# Passed as `max_sent_size` to both datasets and to the Seq2Seq model
MAX_SENT_SIZE = 12
# Passed as `max_tokens` to `build_vocab` (cap on the vocabulary size)
MAX_TOKENS = 30_000

In [5]:
from src.data.make_dataset import ParanmtDataset

# Training dataset built from the preprocessed ParaNMT file.
# NOTE(review): `train=True` with `seed=42` presumably selects a seeded train
# split of the file — confirm in ParanmtDataset.
train_dataset = ParanmtDataset(
    path='data/interim/preprocessed_paranmt.tsv',
    max_sent_size=MAX_SENT_SIZE,
    train=True,
    seed=42,
)

In [6]:
# Build the vocabulary from the training dataset only; it is later handed to
# the validation dataset rather than rebuilt there.
# NOTE(review): `min_freq=2` presumably drops tokens seen fewer than twice — confirm.
train_dataset.build_vocab(
    min_freq=2,
    specials=['<unk>', '<pad>', '<sos>', '<eos>'],
    max_tokens=MAX_TOKENS,
)

In [17]:
# Vocabulary of the training dataset; reused below for the validation dataset,
# the model dimensions, the padding index and the Seq2Seq model itself.
vocab = train_dataset.vocab

In [18]:
# Vocabulary size; this value becomes INPUT_DIM / OUTPUT_DIM of the model
len(vocab)

29878

In [8]:
# Validation dataset built from the same file, with the same seed and
# `train=False`.
# NOTE(review): presumably the complementary split of the training dataset — confirm.
val_dataset = ParanmtDataset(
    path='data/interim/preprocessed_paranmt.tsv',
    max_sent_size=MAX_SENT_SIZE,
    vocab=vocab, # reuse the training vocabulary to avoid data leakage
    train=False,
    seed=42,
)

In [9]:
# Display the dataframe held by the training dataset
train_dataset.df

Unnamed: 0,similarity,lenght_diff,toxic_sent,neutral_sent,toxic_val,neutral_val
0,0.699613,0.151515,"[they, simply, hit, the, ground, dead, .]","[they, just, died, on, the, spot, .]",0.985227,0.000696
1,0.736382,0.166667,"[why, not, ,, what, the, hell, .]","[after, all, ,, why, not, .]",0.886357,0.000042
2,0.716637,0.032258,"[cigarettes, and, beer, kick, ass, .]","[cigarettes, and, beer, are, great, !]",0.997185,0.000066
3,0.623832,0.111111,"[half, of, dodds, ', breasts, disappeared, .]","[half, of, dodd, 's, chest, dissolved, .]",0.981391,0.002114
4,0.834294,0.227273,"[flew, out, of, nigeria, ,, crashed, here, .]","[he, 's, taken, out, of, nigeria, and, crashed...",0.643518,0.001495
...,...,...,...,...,...,...
234587,0.943562,0.307692,"[you, have, got, to, be, fucking, kidding, me, .]","[you, have, to, be, kidding, me, .]",0.991443,0.000125
234588,0.931757,0.043478,"[it, 's, kiki, ,, the, witch, .]","[that, 's, kiki, the, witch, .]",0.976002,0.015735
234589,0.840842,0.263158,"[take, the, dog, and, hit, it, with, a, brick, .]","[grab, the, dog, to, hit, a, brick]",0.993978,0.053225
234590,0.891694,0.055556,"[we, all, got, ta, die, sometime, ,, right, ?]","[we, 're, all, gon, na, die, someday, ,, right...",0.963766,0.094380


## Build the Dataloaders

In [10]:
# Batch size shared by the train and validation DataLoaders
batch_size = 16

In [11]:
from src.data.make_dataset import collate_batch

# Both loaders share the batch size and the collate function; only the
# training loader shuffles its samples.
_loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_batch)

train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [12]:
# Sanity check: pull a single batch and print the shapes of both tensors
for batch in train_dataloader:
    toxic_sent, neutral_sent = batch
    print(f"toxic_sent.shape: {toxic_sent.shape}")
    print(f"neutral_sent.shape: {neutral_sent.shape}")
    break  # the first batch is enough

toxic_sent.shape: torch.Size([12, 16])
neutral_sent.shape: torch.Size([12, 16])


In [13]:
# Prefer the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

# Load the Model

- Simple Encoder-Decoder (Seq2Seq) architecture

In [14]:
from src.models.seq2seq.encoder import Encoder
from src.models.seq2seq.decoder import Decoder
from src.models.seq2seq import Seq2Seq

In [19]:
# Source and target share one vocabulary, so both dimensions are its size
INPUT_DIM = OUTPUT_DIM = len(vocab)
# Index of the padding token in the vocabulary
PADDING_IDX = vocab['<pad>']

# Network size and regularisation
EMBED_DIM, NUM_HIDDEN = 128, 256
N_LAYERS = 2
DROPOUT = 0.5

In [20]:
# Build the encoder and decoder of the model. They use the same embedding size,
# hidden size, depth, dropout and padding index; only the vocabulary-side
# dimension argument has a different name.
_shared_kwargs = dict(
    embed_dim=EMBED_DIM,
    hidden_dim=NUM_HIDDEN,
    num_layers=N_LAYERS,
    dropout=DROPOUT,
    padding_idx=PADDING_IDX,
)

encoder = Encoder(input_dim=INPUT_DIM, **_shared_kwargs).to(device)
decoder = Decoder(output_dim=OUTPUT_DIM, **_shared_kwargs).to(device)

In [21]:
# Wrap the encoder and decoder into the full sequence-to-sequence model
model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    device=device,
    max_sent_size=MAX_SENT_SIZE,
    vocab=vocab,
)
model = model.to(device)

# Nothing has been trained yet, so start from the worst possible loss
best_loss = float('inf')

In [23]:
# Adam over all model parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Padding positions are excluded from the loss (PADDING_IDX is vocab['<pad>'])
criterion = torch.nn.CrossEntropyLoss(ignore_index=PADDING_IDX)

In [22]:
from src.models.train_model import train

# Run training for 20 epochs on the train/validation loaders. The returned
# value replaces `best_loss`, so re-running this cell starts from the previous
# best loss instead of `inf`.
# NOTE(review): presumably `train` saves a checkpoint to `ckpt_path` when the
# validation loss improves and clips gradients to `clip_grad` — confirm in
# src/models/train_model.py.
best_loss = train(
    model=model,
    loaders=(train_dataloader, val_dataloader),
    optimizer=optimizer,
    criterion=criterion,
    epochs=20,
    device=device,
    best_loss=best_loss,
    ckpt_path='models/seq2seq.pt',
    clip_grad=1,
)

Training 1: 100%|██████████| 13033/13033 [04:20<00:00, 49.94it/s, loss=4.07]
Evaluating 1: 100%|██████████| 3259/3259 [00:16<00:00, 195.14it/s, loss=3.94]
Training 2: 100%|██████████| 13033/13033 [04:19<00:00, 50.29it/s, loss=3.57]
Evaluating 2: 100%|██████████| 3259/3259 [00:16<00:00, 192.94it/s, loss=3.81]
Training 3: 100%|██████████| 13033/13033 [04:20<00:00, 50.09it/s, loss=3.4]
Evaluating 3: 100%|██████████| 3259/3259 [00:16<00:00, 191.94it/s, loss=3.69]
Training 4: 100%|██████████| 13033/13033 [04:22<00:00, 49.73it/s, loss=3.29]
Evaluating 4: 100%|██████████| 3259/3259 [00:16<00:00, 200.23it/s, loss=3.64]
Training 5: 100%|██████████| 13033/13033 [04:22<00:00, 49.65it/s, loss=3.21]
Evaluating 5: 100%|██████████| 3259/3259 [00:17<00:00, 191.55it/s, loss=3.61]
Training 6: 100%|██████████| 13033/13033 [04:21<00:00, 49.92it/s, loss=3.14]
Evaluating 6: 100%|██████████| 3259/3259 [00:16<00:00, 195.08it/s, loss=3.58]
Training 7: 100%|██████████| 13033/13033 [04:21<00:00, 49.80it/s, loss=

In [21]:
# Load the checkpointed model for prediction.
# `map_location` remaps the stored tensors onto the device chosen above, so the
# checkpoint also loads on a machine without the GPU it was saved from.
# NOTE(review): the checkpoint is a fully pickled model object. Only load files
# you trust (unpickling can execute arbitrary code). Recent PyTorch releases
# default `torch.load` to `weights_only=True`, which rejects such checkpoints —
# pass `weights_only=False` there, or switch to saving a state_dict.
model = torch.load('models/seq2seq.pt', map_location=device)
# Switch to inference mode so dropout is disabled while predicting
model.eval();

In [40]:
import numpy as np  # used for the random sampling below; not imported elsewhere in the notebook
from nltk.tokenize.treebank import TreebankWordDetokenizer

detokenizer = TreebankWordDetokenizer()

# NOTE(review): `df` (with token-list columns 'toxic_sent' / 'neutral_sent') and
# `val_idx` (index labels of the validation rows) are not defined in any cell
# above. They must already exist in the kernel before this cell runs — confirm
# where they come from (presumably the validation dataset).

# let's see how our model works
num_examples = 10
for _ in range(num_examples):
    # pick a random validation row by its index label
    idx = val_idx[np.random.randint(0, len(val_idx))]
    # `idx` is a label taken from `val_idx`, so look the row up directly instead
    # of re-slicing the whole validation subset twice per iteration
    toxic_sent = detokenizer.detokenize(df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(df.loc[idx, 'neutral_sent'])

    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    print('prediction:', model.predict(toxic_sent))
    print("\n")

toxic_sent: well, i mean, honesly mostly schmucks.
neutral_sent: well...mostly dummies.
prediction: well, i mean, well.


toxic_sent: they're tragic, not ridiculous.
neutral_sent: tragic, never comic.
prediction: they're not funny.


toxic_sent: so nelson was crazy, like you said.
neutral_sent: so nelson was a control freak like you said.
prediction: so crazy was crazy, you crazy.


toxic_sent: i mean, this is retarded.
neutral_sent: i mean, this is crazy.
prediction: i mean, this is mean.


toxic_sent: it's useless!
neutral_sent: this is futile!
prediction: it's no use!


toxic_sent: cunning fox, this ernie allen.
neutral_sent: a clever old fox, ernie allen.
prediction: fox fox, fox fox.


toxic_sent: fuck! let me out of here!
neutral_sent: get me out of here!
prediction: let me out of here!


toxic_sent: i'm going to arrest you.
neutral_sent: here to arrest you.
prediction: i'll arrest you.


toxic_sent: grotesque as i promised.
neutral_sent: grotesque, as promised . - okay.
predicti