In [1]:
# File the training run writes its checkpoint to; the same path is read back
# later when the trained model is loaded for prediction.
MODEL_CHECKPOINT = 'models/transformer2.01.pt'
# Separate name for a checkpoint to start from; it is not referenced by the
# cells visible in this notebook.
LOAD_MODEL = 'models/transformer2.01.pt'
# Tab-separated file handed to ParanmtDataset for both splits.
DATASET_PATH = 'data/interim/preprocessed_paranmt3.tsv'

In [2]:
import numpy as np

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.optim import Adam

import os
# Go to the root dir, but only if we are not already there: the project root is
# the directory that holds the `src` package imported by the cells below.
# Without this guard, re-running the cell would climb one more level each time
# and break every relative path and `src.*` import.
if not os.path.isdir("src"):
    os.chdir("..")

## Get the Dataset

In [3]:
# Upper bound passed to build_vocab as `max_tokens`.
MAX_TOKENS = 25000
# Sentence length given to the datasets and reused as the model's max size.
MAX_SENT_SIZE = 32

In [4]:
from src.data.make_dataset import ParanmtDataset

# Training split; the vocabularies are built from this dataset in the next cell.
train_dataset = ParanmtDataset(
    train=True,
    seed=42,
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
)

In [5]:
# Build the vocabularies on the training split only. The order of `specials`
# is kept exactly as it was, since it is part of the call's input.
train_dataset.build_vocab(
    specials=['<unk>', '<pad>', '<sos>', '<eos>'],
    max_tokens=MAX_TOKENS,
    min_freq=2,
)

In [6]:
# The encoder reads toxic sentences and the decoder writes neutral ones, so each
# side gets the matching vocabulary from the training set.
enc_vocab, dec_vocab = train_dataset.toxic_vocab, train_dataset.neutral_vocab

In [7]:
# Report how many entries each vocabulary ended up with
# (build_vocab was called with max_tokens=MAX_TOKENS).
print("size of encoder vocab:", len(enc_vocab))
print("size of decoder vocab:", len(dec_vocab))

size of encoder vocab: 25000
size of decoder vocab: 25000


In [8]:
# Validation split, limited to its first 10_000 examples. It receives the
# vocabularies built on the training split instead of building its own.
val_dataset = ParanmtDataset(
    train=False,
    seed=42,
    take_first=10_000,
    path=DATASET_PATH,
    max_sent_size=MAX_SENT_SIZE,
    vocabs=(enc_vocab, dec_vocab), # reuse train vocabs to avoid data leakage
)

In [9]:
# Number of examples in each split (the validation set was built with take_first=10_000).
len(train_dataset), len(val_dataset)

(499273, 10000)

## Create the DataLoaders

In [10]:
# Number of sentence pairs per batch; used by both the train and validation loaders.
batch_size = 256

In [11]:
# Training batches are drawn in a fresh random order on every pass.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Validation batches keep the dataset order.
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

In [12]:
# let's check if shape and everything is ok
# Sanity check: pull a single batch from the training loader and print the
# shapes of its two tensors.
for batch in train_dataloader:
    # each batch unpacks into a (toxic, neutral) pair
    toxic_sent, neutral_sent = batch
    print("toxic_sent.shape:", toxic_sent.shape)
    print("neutral_sent.shape:", neutral_sent.shape)
    break  # one batch is enough for the check

toxic_sent.shape: torch.Size([256, 32])
neutral_sent.shape: torch.Size([256, 32])


In [13]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

# Load the Model

- Transformer architecture

In [14]:
from src.models.transformer.encoder import Encoder
from src.models.transformer.decoder import Decoder
from src.models.transformer import Transformer

In [15]:
# configure some parameters for the model
# Settings shared by the encoder and the decoder
hidden_dim = 256
heads = 4
ff_expantion = 4
max_size = MAX_SENT_SIZE

# Encoder side: sized by the toxic-sentence vocabulary
enc_num_layers = 3
enc_dropout = 0.1
enc_input_dim = len(enc_vocab)
enc_padding_idx = enc_vocab['<pad>']

# Decoder side: sized by the neutral-sentence vocabulary
dec_num_layers = 3
dec_dropout = 0.1
dec_output_dim = len(dec_vocab)
dec_padding_idx = dec_vocab['<pad>']

In [16]:
# load the encoder and decoder for our model
# Both halves share hidden size, head count, feed-forward expansion and max
# length; they differ in vocabulary, depth and dropout settings.
encoder = Encoder(
    vocab=enc_vocab,
    input_dim=enc_input_dim,
    max_size=max_size,
    hidden_dim=hidden_dim,
    heads=heads,
    ff_expantion=ff_expantion,
    num_layers=enc_num_layers,
    dropout=enc_dropout,
    device=device,
)
encoder = encoder.to(device)

decoder = Decoder(
    vocab=dec_vocab,
    output_dim=dec_output_dim,
    max_size=max_size,
    hidden_dim=hidden_dim,
    heads=heads,
    ff_expantion=ff_expantion,
    num_layers=dec_num_layers,
    dropout=dec_dropout,
    device=device,
)
decoder = decoder.to(device)



In [17]:
# Wrap the encoder/decoder pair into the full sequence-to-sequence model.
model = Transformer(
    max_sent_size=MAX_SENT_SIZE,
    device=device,
    decoder=decoder,
    encoder=encoder,
).to(device)

# Starting value handed to train() as `best_loss`; infinity means any finite
# loss compares as lower.
best_loss = float('inf')

In [18]:
# Loss over the decoder outputs; positions holding the decoder's padding index
# are excluded from it.
criterion = torch.nn.CrossEntropyLoss(ignore_index=decoder.padding_idx)
# AdamW over every parameter of the model.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)

In [19]:
from src.models.train_model import train

# Run training; the value returned by train() replaces `best_loss`, and
# MODEL_CHECKPOINT is passed as the checkpoint path.
best_loss = train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    loaders=(train_dataloader, val_dataloader),
    device=device,
    epochs=10,
    ckpt_path=MODEL_CHECKPOINT,
    best_loss=best_loss,
)

Training 1: 100%|██████████| 1951/1951 [04:33<00:00,  7.14it/s, loss=4.9] 
Evaluating 1: 100%|██████████| 40/40 [00:02<00:00, 18.86it/s, loss=4.03]
Training 2: 100%|██████████| 1951/1951 [04:30<00:00,  7.21it/s, loss=3.96]
Evaluating 2: 100%|██████████| 40/40 [00:02<00:00, 19.01it/s, loss=3.68]
Training 3: 100%|██████████| 1951/1951 [04:31<00:00,  7.19it/s, loss=3.69]
Evaluating 3: 100%|██████████| 40/40 [00:02<00:00, 19.30it/s, loss=3.46]
Training 4: 100%|██████████| 1951/1951 [04:33<00:00,  7.14it/s, loss=3.51]
Evaluating 4: 100%|██████████| 40/40 [00:02<00:00, 18.74it/s, loss=3.29]
Training 5: 100%|██████████| 1951/1951 [04:30<00:00,  7.20it/s, loss=3.37]
Evaluating 5: 100%|██████████| 40/40 [00:02<00:00, 18.40it/s, loss=3.18]
Training 6: 100%|██████████| 1951/1951 [04:31<00:00,  7.18it/s, loss=3.26]
Evaluating 6: 100%|██████████| 40/40 [00:02<00:00, 18.56it/s, loss=3.07]
Training 7: 100%|██████████| 1951/1951 [04:31<00:00,  7.19it/s, loss=3.16]
Evaluating 7: 100%|██████████| 40/40 

In [20]:
# let's load the model and predict
# let's load the model and predict
# NOTE(security): torch.load unpickles the stored object, which can execute
# arbitrary code; only load checkpoints produced by this project.
# map_location remaps the saved tensors onto the active device, so a checkpoint
# written on a GPU machine also loads when `device` fell back to the CPU.
model = torch.load(MODEL_CHECKPOINT, map_location=device)
model.to(device)
model.eval()
None  # keeps the cell from printing the repr returned by model.eval()

In [25]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()

# let's see how our model works on a few randomly picked validation pairs
num_examples = 10
df = val_dataset
# Seeded generator: the same examples are shown on every run, so the printed
# predictions can be compared between checkpoints.
rng = np.random.default_rng(42)
for _ in range(num_examples):
    # the upper bound is exclusive, matching np.random.randint
    idx = int(rng.integers(0, len(df)))
    toxic_sent = detokenizer.detokenize(df.df.loc[idx, 'toxic_sent'])
    neutral_sent = detokenizer.detokenize(df.df.loc[idx, 'neutral_sent'])

    print('toxic_sent:', toxic_sent)
    print('neutral_sent:', neutral_sent)
    # only the first element of the beam-search result is printed
    print('prediction:', model.predict(toxic_sent, post_process_text=True, use_beam_search=True)[0])
    print("\n")

toxic_sent: see her die a thousand ways.
neutral_sent: and seeing her die in thousands of ways.
prediction: see her death.


toxic_sent: with explosives around my ankles, ready to explode.
neutral_sent: explosives still around their ankles, still ready to explode.
prediction: with my, ready to explode.


toxic_sent: fucked up my life.
neutral_sent: he messed up his life.
prediction: i screwed up my life.


toxic_sent: and he is the son of death.
neutral_sent: he is a dead man.
prediction: and he is the son of death.


toxic_sent: they said if i told anyone, they would kill zak.
neutral_sent: they said that if we told anybody that they said they would hurt zach.
prediction: they said if i told anyone, they would have killed.


toxic_sent: but now you are completely crazy.
neutral_sent: now i know you are completely mad.
prediction: but now you are completely mad.


toxic_sent: his alibi is bullshit.
neutral_sent: his alibi is bogus.
prediction: his is nonsense.


toxic_sent: damn i i sh

In [84]:
from torchtext.data.metrics import bleu_score
from tqdm import tqdm

def calculate_bleu(dataset, model):
    """Compute the corpus BLEU score of the model's predictions on ``dataset``.

    Every item of ``dataset`` is unpacked as a ``(toxic_sent, neutral_sent)``
    pair of index tensors. The toxic tensor is sent through ``model.predict``
    and the neutral tensor is converted back to tokens with
    ``model.decoder.vocab.lookup_tokens`` to serve as the single reference.

    Args:
        dataset: sized, indexable collection of (source, target) tensor pairs.
        model: object exposing ``device``, ``predict`` and ``decoder.vocab``.

    Returns:
        The value of ``bleu_score(preds, trgs)`` over all collected pairs.
    """
    preds = []
    trgs = []
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            toxic_sent, neutral_sent = dataset[i]
            toxic_sent = toxic_sent.to(model.device).unsqueeze(0)
            pred = model.predict(toxic_sent, post_process_text=False)

            # assumes predict() returns a token sequence wrapped in <sos> ... <eos> -- TODO confirm
            pred = pred[1:-1] # remove <sos> and <eos>

            # tolist() yields plain ints and, unlike numpy(), does not require a CPU tensor
            neutral_sent = model.decoder.vocab.lookup_tokens(neutral_sent.tolist())
            neutral_sent = neutral_sent[1:] # remove <sos>
            # A target without <eos> (e.g. cut at the maximum length) is kept
            # whole instead of aborting the whole evaluation with a ValueError.
            if '<eos>' in neutral_sent:
                neutral_sent = neutral_sent[:neutral_sent.index('<eos>')]

            preds.append(pred)
            trgs.append([neutral_sent])

    return bleu_score(preds, trgs)

In [85]:
# Score the loaded model on the whole validation split; the returned value is
# displayed as the cell output.
calculate_bleu(dataset=val_dataset, model=model)

100%|██████████| 10000/10000 [04:58<00:00, 33.46it/s]


0.19219723264416075