<a href="https://colab.research.google.com/github/Tuan-Lee-23/Deep-learning-with-Pytorch/blob/main/Transformer_chatbot_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install libs

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd drive/MyDrive/Project/Chatbot transformer/

/content/drive/MyDrive/Project/Chatbot transformer


In [None]:
!pip install -q wandb 
!pip install -q PyYAML

In [None]:
from collections import Counter
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import math
import torch.nn.functional as F
import numpy as np
import pandas as pd 
from torchtext.legacy.data import Field, Iterator, BucketIterator, TabularDataset
from sklearn.model_selection import train_test_split
import spacy
from torchtext.data.metrics import bleu_score
import sys
from tqdm.notebook import tqdm
import random
import wandb 
import yaml

In [None]:
def save_yaml(data, filename = 'config.yml',):
    with open(filename, 'w') as file:
        yaml.dump(data, file)

def read_yaml(filename = 'config.yml'):
    with open(filename, 'r') as file:
        documents = yaml.load(file, Loader=yaml.SafeLoader)
        return documents

# HyperParameters

In [None]:
opt = read_yaml()
opt

{'CHECKPOINT_PATH': 'checkpoint/my_checkpoint.pth.tar',
 'DATASET_PATH': 'data/dataset.csv',
 'FFN_dimension': 1024,
 'TEST_PATH': 'data/test.csv',
 'TRAIN_PATH': 'data/train.csv',
 'VAL_PATH': 'data/val.csv',
 'batch_size': 32,
 'd_model': 512,
 'dropout': 0.1,
 'heads': 8,
 'load_model': False,
 'lr': 0.001,
 'max_seq_len': 100,
 'num_epochs': 100,
 'num_layers': 6,
 'ratio_print_per_epoch': 0.1,
 'ratio_validate_per_epoch': 0.3,
 'save_model': True,
 'train_size': 0.7,
 'val_size': 0.1,
 'vocab_min_frequency': 2}

# Prepare dataset

In [None]:
df = pd.read_csv(opt['DATASET_PATH'])
df

Unnamed: 0,source,target
0,can we make this quick roxanne korrine and and...,well i thought wed start with pronunciation if...
1,well i thought wed start with pronunciation if...,not the hacking and gagging and spitting part ...
2,not the hacking and gagging and spitting part ...,okay then how bout we try out some french cuis...
3,youre asking me out thats so cute whats your n...,forget it
4,no no its my fault we didnt have a proper intr...,cameron
...,...,...
221611,your orders mr vereker,im to take the sikali with the main column to ...
221612,im to take the sikali with the main column to ...,lord chelmsford seems to want me to stay back ...
221613,lord chelmsford seems to want me to stay back ...,i think chelmsford wants a good man on the bor...
221614,well i assure you sir i have no desire to crea...,and i assure you you do not in fact id be obli...


In [None]:
df = df.sample(frac=1, random_state = 10).reset_index(drop=True)
df['source']

0                                       what a dork huh huh
1                                the safety circuits failed
2                          thats okay ive already had lunch
3         the batterys in thats done with besides youd d...
4                                           are you kidding
                                ...                        
221611                                           what is it
221612    sorry mr sheldrake im full up youll have to ta...
221613    look whos talking the great white father and w...
221614     what are you talking about you got real problems
221615                                          any friends
Name: source, Length: 221616, dtype: object

In [None]:
df = df.iloc[:150000]
total_count = len(df)
train_count = int(opt['train_size'] * total_count)
val_count = int(opt['val_size'] * total_count)
test_count = total_count - train_count - val_count

train_df = df.iloc[:train_count]
val_df = df.iloc[train_count: train_count + val_count]
test_df = df.iloc[train_count + val_count:]

In [None]:
print(train_df.shape, val_df.shape, test_df.shape)
# assert len(train_df) + len(val_df) + len(test_df) == len(df)

(105000, 2) (15000, 2) (30000, 2)


In [None]:
train_df.to_csv(opt['TRAIN_PATH'], index = False)
val_df.to_csv(opt['VAL_PATH'], index = False)
test_df.to_csv(opt['TEST_PATH'], index = False)

In [None]:
# spacy_to = spacy.load("en_core_web_sm")
def tokenizer(text):
    # return [tok.text for tok in spacy_target.tokenizer(text)]
    return text.split(' ')


TEXT = Field(sequential = True, tokenize=tokenizer, lower=True, init_token="<sos>", eos_token="<eos>")

fields = {'source': ('src', TEXT), 'target': ('trg', TEXT)}


train_data, val_data, test_data = TabularDataset.splits(
   path = '',
   train = opt['TRAIN_PATH'],
   validation = opt['VAL_PATH'],
   test = opt['TEST_PATH'],
   format = 'csv',
   fields = fields,
)

In [None]:
TEXT.build_vocab(train_data, max_size = 50000, min_freq = opt['vocab_min_frequency'])

In [None]:
print(TEXT.vocab.stoi["hello"])
print(len(TEXT.vocab))

349
27229


# Model

In [None]:
class Transformer(nn.Module):
    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        )
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out

# Utilities

In [None]:
for temp in test_df.iloc[:3].values:
    print(temp[0])

dieter
if you get lonely go down and see mrs romari i told her you were staying with me
it will how do you figure


In [None]:
def translate_sentence(model, sentence, german, english, device, max_length=50, min_length = 5):
    # Load german tokenizer

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    if type(sentence) == str:
        tokens = [token.lower() for token in tokenizer(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)
    outputs = [english.vocab.stoi["<sos>"]]
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()

        if best_guess == english.vocab.stoi["<eos>"]:
            if len(output) > min_length:
                break
            else:
                temp = output.topk(2, 2).indices[:, :, 1][-1, :].item()
                best_guess = temp
        
        outputs.append(best_guess)

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]
    # remove start token
    return " ".join(translated_sentence[1:])



def bleu(data, model, german, english, device):
    targets = []
    outputs = []

    for example in data.values:
        src = example[0]
        trg = example[1]

        prediction = translate_sentence(model, src, german, english, device).split(" ")
        prediction = prediction[:-1]  # remove <eos> token

        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)


def save_checkpoint(model, optimizer, lr_scheduler, epoch, table_df, filename= opt['CHECKPOINT_PATH']):
    print("=> Saving checkpoint")
    checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch, 
            "table_df": table_df,
    }

    torch.save(checkpoint, filename)

def load_checkpoint(filename = opt['CHECKPOINT_PATH']):
    print("=> Loading checkpoint")
    checkpoint = torch.load(filename)

    return checkpoint

# Dataloader

In [None]:
# We're ready to define everything we need for training our Seq2Seq model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, val_data, test_data),
    batch_size=opt['batch_size'],
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device,
)

# Setup training

## (Hyper)parameters set

In [None]:
# Training hyperparameters
num_epochs = opt['num_epochs']
learning_rate = opt['lr']

# Model hyperparameters
src_vocab_size = len(TEXT.vocab)
trg_vocab_size = len(TEXT.vocab)
embedding_size = opt['d_model']
num_heads = opt['heads']
num_encoder_layers = opt['num_layers']
num_decoder_layers = opt['num_layers']
dropout = opt['dropout']
max_len = opt['max_seq_len']
forward_expansion = opt['FFN_dimension']

src_pad_idx = TEXT.vocab.stoi["<pad>"]

## Model instance

In [None]:
# Model
model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## Learning rate Scheduler, Optimizer

In [None]:
# Adam Optimizer 
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Lr scheduler
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr= 0.1, 
                        steps_per_epoch= len(train_iterator), epochs= opt['num_epochs'])



# Loss function (ignore pad token when calculating loss) (including softmax inside)
pad_idx = TEXT.vocab.stoi["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)  

## Others

In [None]:
sentences = ["what are you talking about", 
             "who the fuck are you", 
             "what are you doing bro",
             "get the fuck out of here", 
             "stay out of my sight dude"]

# example response table dataframe
table_df = pd.DataFrame({'epoch': [], 'step': [], 'query': [], 'response': []})


def validate(model, iterator):
    print("-------Validating-----")
    model.eval()
    for batch_idx, batch in enumerate(iterator):
    # Get input and targets and get to cuda
        inp_data = batch.src.to(device)
        target = batch.trg.to(device)

        # Forward prop
        output = model(inp_data, target[:-1, :])

        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss
        # doesn't take input in that form. For example if we have MNIST we want to have
        # output to be: (N, 10) and targets just (N). Here we can view it in a similar
        # way that we have output_words * batch_size that we want to send in into
        # our cost function, so we need to do some reshapin.
        # Let's also remove the start token while we're at it
        output = output.reshape(-1, output.shape[2])
        
        target = target[1:].reshape(-1)
        loss = criterion(output, target)  # including softmax
        losses.append(loss.item())

    mean_loss = np.mean(losses)
    print(f"Val loss: {mean_loss}\n")
    wandb.log({"loss/val": mean_loss})

## Wandb + checkpoint setup

In [60]:
# Wandb setup
project = "Transformer chatbot"
display_name = "model 1"


# Load model checkpoint
if opt['load_model']:
    run = wandb.init(project=project, resume=True)
    checkpoint = torch.load(wandb.restore(opt['CHECKPOINT_PATH']))

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    start_epoch = checkpoint['epoch']
    table_df = checkpoint['table_df']
    
    model.train()
else:
    wandb.init(project=project)
    start_epoch = 0


# Watch wandb
wandb.watch(model, log_freq=100)

Problem at: <ipython-input-60-b5d42242e4f4> 20 <module>


Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/wandb_run.py", line 1788, in _atexit_cleanup
    self._on_finish()
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/wandb_run.py", line 1971, in _on_finish
    self._poll_exit_response = self._wait_for_finish()
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/wandb_run.py", line 1911, in _wait_for_finish
    poll_exit_resp = self._backend.interface.communicate_poll_exit()
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/interface/interface.py", line 586, in communicate_poll_exit
    resp = self._communicate_poll_exit(poll_exit)
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/interface/interface_shared.py", line 402, in _communicate_poll_exit
    result = self._communicate(rec)
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/interface/interface_shared.py", line 213, in _communicate
    return self._communicate_async(rec, local=local).get(timeout=timeout)


Exception: ignored

# Train

In [None]:
for epoch in range(start_epoch, num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")


    if opt['save_model']:
        save_checkpoint(model, optimizer, lr_scheduler, epoch, table_df)
        wandb.save(opt['CHECKPOINT_PATH']) #wandb save checkopint

    model.train()
    losses = []



    for batch_idx, batch in enumerate(tqdm(train_iterator)):
        # Get input and targets and get to cuda
        inp_data = batch.src.to(device)
        target = batch.trg.to(device)

        # Forward prop
        output = model(inp_data, target[:-1, :])
        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss
        # doesn't take input in that form. For example if we have MNIST we want to have
        # output to be: (N, 10) and targets just (N). Here we can view it in a similar
        # way that we have output_words * batch_size that we want to send in into
        # our cost function, so we need to do some reshapin.
        # Let's also remove the start token while we're at it
        output = output.reshape(-1, output.shape[2])  
        target = target[1:].reshape(-1)

        optimizer.zero_grad()

        loss = criterion(output, target)  #including softmax
        losses.append(loss.item())

        # Back prop
        loss.backward()
        # Clip to avoid exploding gradient issues, makes sure grads are
        # within a healthy range
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        # Gradient descent step
        optimizer.step()

        # lr Scheduler step
        lr_scheduler.step()


        steps_to_print = int(len(train_iterator) * opt['ratio_print_per_epoch'])
        steps_to_validate = int((len(train_iterator) * opt['ratio_validate_per_epoch']))
        
        mean_loss = 0

        # Log metrics step
        if batch_idx %  steps_to_print == 0:
            mean_loss = np.mean(losses)
            print(f"Step [{batch_idx} / {len(train_iterator)}] - train loss: {mean_loss}")
            wandb.log({"loss/train": mean_loss})
            wandb.log({'learning rate': optimizer.param_groups[0]['lr']})
            
        # Validate step
        if batch_idx % steps_to_validate == 0:
            train_loss = np.mean(losses)

            # validate
            model.eval() 
            validate(model, valid_iterator) # log wandb metrics
            
            # calculate bleu score
            train_bleu_score = bleu(train_df[:100], model, TEXT, TEXT, device)
            val_bleu_score = bleu(val_df[:100], model, TEXT, TEXT, device)

            # Example response
            import random
            print("----------Example---------")
            sentence = random.choice(sentences)
            translated_sentence = translate_sentence(
                model, sentence, TEXT, TEXT, device, max_length=50
            )
            print(f"Query: {sentence}")
            print(f"Response: {translated_sentence}")

            # Add data to wandb table ------------- 

            # (create a fk new table because wandb doesn't allow us to reuse the old table)----
            # create a wandb table
            columns= ['epoch', 'step', 'query', 'response']
            my_table = wandb.Table(columns=columns)
            # append data to existing table dataframe
            table_df = table_df.append({'epoch': epoch, 'step': batch_idx, 'query': sentence, 'response': translated_sentence}, ignore_index = True)

            # Log wandb metrics
            wandb.log({'loss/train': train_loss})
            wandb.log({'bleu/train': train_bleu_score})
            wandb.log({'bleu/val': val_bleu_score})

            # log wandb table 
            wandb.log({"Example responses": table_df})
            

            model.train()


[Epoch 0 / 100]
=> Saving checkpoint


  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 10.448338508605957
-------Validating-----
Val loss: 9.180075819948886

----------Example---------
Query: who the fuck are you
Response: i i i i i
Step [328 / 3282] - train loss: 8.105618335250625
Step [656 / 3282] - train loss: 7.590418360373055
Step [984 / 3282] - train loss: 7.30699887006778
-------Validating-----
Val loss: 7.043287188946054

----------Example---------
Query: get the fuck out of here
Response: i i i i i
Step [1312 / 3282] - train loss: 6.934880930495124
Step [1640 / 3282] - train loss: 6.8556437444298615
Step [1968 / 3282] - train loss: 6.796923307524459
-------Validating-----
Val loss: 6.715085722823844

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [2296 / 3282] - train loss: 6.679230627951818
Step [2624 / 3282] - train loss: 6.648866723335924
Step [2952 / 3282] - train loss: 6.622996855766401
-------Validating-----
Val loss: 6.583044104754839

----------Example---------
Query: what are

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.094046592712402
-------Validating-----
Val loss: 6.196909468224708

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [328 / 3282] - train loss: 6.218741061394675
Step [656 / 3282] - train loss: 6.234160269555995
Step [984 / 3282] - train loss: 6.244289671702431
-------Validating-----
Val loss: 6.2334068710458075

----------Example---------
Query: stay out of my sight dude
Response: you you you you you
Step [1312 / 3282] - train loss: 6.238437060301487
Step [1640 / 3282] - train loss: 6.243967764411799
Step [1968 / 3282] - train loss: 6.247065832166275
-------Validating-----
Val loss: 6.242382952246055

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [2296 / 3282] - train loss: 6.244661581825231
Step [2624 / 3282] - train loss: 6.248016711974901
Step [2952 / 3282] - train loss: 6.249933579208654
-------Validating-----
Val loss: 6.244933332758035

----------Examp

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.198040962219238
-------Validating-----
Val loss: 6.210998601101815

----------Example---------
Query: what are you talking about
Response: i i i i i
Step [328 / 3282] - train loss: 6.227940443464389
Step [656 / 3282] - train loss: 6.237745932109826
Step [984 / 3282] - train loss: 6.2470250201848385
-------Validating-----
Val loss: 6.237543858157677

----------Example---------
Query: what are you doing bro
Response: you you you you you
Step [1312 / 3282] - train loss: 6.240927178718842
Step [1640 / 3282] - train loss: 6.248106922461757
Step [1968 / 3282] - train loss: 6.250597005535083
-------Validating-----
Val loss: 6.243752605118457

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [2296 / 3282] - train loss: 6.246878939107479
Step [2624 / 3282] - train loss: 6.248969038564061
Step [2952 / 3282] - train loss: 6.250872271542155
-------Validating-----
Val loss: 6.246046688769465

----------Example---------
Q

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.1497650146484375
-------Validating-----
Val loss: 6.205982644507226

----------Example---------
Query: stay out of my sight dude
Response: you you you you you
Step [328 / 3282] - train loss: 6.222385984913149
Step [656 / 3282] - train loss: 6.232478668592328
Step [984 / 3282] - train loss: 6.238109947398735
-------Validating-----
Val loss: 6.232757844096723

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [1312 / 3282] - train loss: 6.240052436521879
Step [1640 / 3282] - train loss: 6.2446244965694175
Step [1968 / 3282] - train loss: 6.247758353252693
-------Validating-----
Val loss: 6.250245707859925

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [2296 / 3282] - train loss: 6.252955799612577
Step [2624 / 3282] - train loss: 6.256247539368887
Step [2952 / 3282] - train loss: 6.257940014130479
-------Validating-----
Val loss: 6.253298037638203

----------Exampl

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.2679548263549805
-------Validating-----
Val loss: 6.213644524838062

----------Example---------
Query: what are you doing bro
Response: you you you you you
Step [328 / 3282] - train loss: 6.230376312308443
Step [656 / 3282] - train loss: 6.242475262755399
Step [984 / 3282] - train loss: 6.248504406812266
-------Validating-----
Val loss: 6.241038236900723

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [1312 / 3282] - train loss: 6.246955294759472
Step [1640 / 3282] - train loss: 6.251347257444775
Step [1968 / 3282] - train loss: 6.256047591664433
-------Validating-----
Val loss: 6.252494132914249

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [2296 / 3282] - train loss: 6.253759259410347
Step [2624 / 3282] - train loss: 6.256671219354584
Step [2952 / 3282] - train loss: 6.2588367310139015
-------Validating-----
Val loss: 6.254578402877273

----------Example--

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.330493450164795
-------Validating-----
Val loss: 6.21469274175928

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [328 / 3282] - train loss: 6.23691445424742
Step [656 / 3282] - train loss: 6.246656993354405
Step [984 / 3282] - train loss: 6.254579670655514
-------Validating-----
Val loss: 6.248237455132971

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [1312 / 3282] - train loss: 6.254769780486491
Step [1640 / 3282] - train loss: 6.258679904508794
Step [1968 / 3282] - train loss: 6.2641370241528
-------Validating-----
Val loss: 6.259613464511402

----------Example---------
Query: what are you doing bro
Response: you you you you you
Step [2296 / 3282] - train loss: 6.262695874560214
Step [2624 / 3282] - train loss: 6.2658618206069585
Step [2952 / 3282] - train loss: 6.270234826179819
-------Validating-----
Val loss: 6.2670152037749505

----------Example--------

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.242919921875
-------Validating-----
Val loss: 6.234362223807802

----------Example---------
Query: stay out of my sight dude
Response: you you you you you
Step [328 / 3282] - train loss: 6.257914106349897
Step [656 / 3282] - train loss: 6.271811863880598
Step [984 / 3282] - train loss: 6.280684116618014
-------Validating-----
Val loss: 6.270277491969735

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [1312 / 3282] - train loss: 6.275907115055052
Step [1640 / 3282] - train loss: 6.281600202410292
Step [1968 / 3282] - train loss: 6.287641704882138
-------Validating-----
Val loss: 6.28287184464423

----------Example---------
Query: what are you doing bro
Response: you you you you you
Step [2296 / 3282] - train loss: 6.285543699243929
Step [2624 / 3282] - train loss: 6.288917815046651
Step [2952 / 3282] - train loss: 6.290365369823001
-------Validating-----
Val loss: 6.285585764057273

----------Example-------

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.238842964172363
-------Validating-----
Val loss: 6.27701582705721

----------Example---------
Query: who the fuck are you
Response: you you you you you
Step [328 / 3282] - train loss: 6.296127141866469
Step [656 / 3282] - train loss: 6.297452542963926
Step [984 / 3282] - train loss: 6.306617428216843
-------Validating-----
Val loss: 6.299619925375476

----------Example---------
Query: who the fuck are you
Response: the the the the the
Step [1312 / 3282] - train loss: 6.30360693022814
Step [1640 / 3282] - train loss: 6.305594715511859
Step [1968 / 3282] - train loss: 6.310490174682271
-------Validating-----
Val loss: 6.3098599265254505

----------Example---------
Query: what are you doing bro
Response: you you you you you
Step [2296 / 3282] - train loss: 6.311853291792694
Step [2624 / 3282] - train loss: 6.316396075345221
Step [2952 / 3282] - train loss: 6.319315744977478
-------Validating-----
Val loss: 6.315373665305442

----------Example---------
Query

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.4080610275268555
-------Validating-----
Val loss: 6.284838090044387

----------Example---------
Query: get the fuck out of here
Response: i i i i i
Step [328 / 3282] - train loss: 6.309128311941199
Step [656 / 3282] - train loss: 6.323582303248756
Step [984 / 3282] - train loss: 6.328920286656082
-------Validating-----
Val loss: 6.322627698885421

----------Example---------
Query: stay out of my sight dude
Response: you you you you you
Step [1312 / 3282] - train loss: 6.328747008970609
Step [1640 / 3282] - train loss: 6.333922154869207
Step [1968 / 3282] - train loss: 6.340944170090682
-------Validating-----
Val loss: 6.341030792178701

----------Example---------
Query: get the fuck out of here
Response: i i i i i
Step [2296 / 3282] - train loss: 6.344485900829472
Step [2624 / 3282] - train loss: 6.348370600077841
Step [2952 / 3282] - train loss: 6.351497475925935
-------Validating-----
Val loss: 6.350623514223208

----------Example---------
Query: what 

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.222428798675537
-------Validating-----
Val loss: 6.314770900442245

----------Example---------
Query: who the fuck are you
Response: the the the the the
Step [328 / 3282] - train loss: 6.342935418724117
Step [656 / 3282] - train loss: 6.358959018019545
Step [984 / 3282] - train loss: 6.368565704996353
-------Validating-----
Val loss: 6.362666405803266

----------Example---------
Query: stay out of my sight dude
Response: a a a a a
Step [1312 / 3282] - train loss: 6.367996538654532
Step [1640 / 3282] - train loss: 6.375077982196609
Step [1968 / 3282] - train loss: 6.379794800859733
-------Validating-----
Val loss: 6.377867410391993

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [2296 / 3282] - train loss: 6.382527916014324
Step [2624 / 3282] - train loss: 6.3866484218410084
Step [2952 / 3282] - train loss: 6.390597526742778
-------Validating-----
Val loss: 6.385941932600807

----------Example---------
Quer

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.456384658813477
-------Validating-----
Val loss: 6.394723308847306

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [328 / 3282] - train loss: 6.413432423034706
Step [656 / 3282] - train loss: 6.421963880582982
Step [984 / 3282] - train loss: 6.429314000734437
-------Validating-----
Val loss: 6.420072124330439

----------Example---------
Query: stay out of my sight dude
Response: you you you you you
Step [1312 / 3282] - train loss: 6.425709304572317
Step [1640 / 3282] - train loss: 6.431160574044401
Step [1968 / 3282] - train loss: 6.435431427425808
-------Validating-----
Val loss: 6.429701209492028

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [2296 / 3282] - train loss: 6.4335989707489505
Step [2624 / 3282] - train loss: 6.438139787269017
Step [2952 / 3282] - train loss: 6.4431902571555675
-------Validating-----
Val loss: 6.439771945887803

----------Exampl

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.4398722648620605
-------Validating-----
Val loss: 6.432102008576089

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [328 / 3282] - train loss: 6.456245191712726
Step [656 / 3282] - train loss: 6.471627959137911
Step [984 / 3282] - train loss: 6.476988689115648
-------Validating-----
Val loss: 6.467382609317282

----------Example---------
Query: who the fuck are you
Response: i i i i i
Step [1312 / 3282] - train loss: 6.478061203636735
Step [1640 / 3282] - train loss: 6.482613282501351
Step [1968 / 3282] - train loss: 6.486925681514103
-------Validating-----
Val loss: 6.4818243418259645

----------Example---------
Query: stay out of my sight dude
Response: i i i i i
Step [2296 / 3282] - train loss: 6.4865885284240505
Step [2624 / 3282] - train loss: 6.49086616271072
Step [2952 / 3282] - train loss: 6.495359775123246
-------Validating-----
Val loss: 6.495715376999493

----------Example---------
Query: what a

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.765456676483154
-------Validating-----
Val loss: 6.45883971376622

----------Example---------
Query: who the fuck are you
Response: you you you you you
Step [328 / 3282] - train loss: 6.490668609925081
Step [656 / 3282] - train loss: 6.515771142119413
Step [984 / 3282] - train loss: 6.528564307187906
-------Validating-----
Val loss: 6.526234814231741

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [1312 / 3282] - train loss: 6.5316348243215145
Step [1640 / 3282] - train loss: 6.536621585933217
Step [1968 / 3282] - train loss: 6.543223140245456
-------Validating-----
Val loss: 6.543124484224907

----------Example---------
Query: what are you talking about
Response: you you you you you
Step [2296 / 3282] - train loss: 6.547295613242534
Step [2624 / 3282] - train loss: 6.551659243920493
Step [2952 / 3282] - train loss: 6.556188808340545
-------Validating-----
Val loss: 6.555866141624396

----------Example------

  0%|          | 0/3282 [00:00<?, ?it/s]

Step [0 / 3282] - train loss: 6.724778175354004
-------Validating-----
Val loss: 6.519866196652676

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [328 / 3282] - train loss: 6.56001751643972
Step [656 / 3282] - train loss: 6.579818554077013
Step [984 / 3282] - train loss: 6.595970321062342
-------Validating-----
Val loss: 6.59065245763746

----------Example---------
Query: who the fuck are you
Response: you you you you you
Step [1312 / 3282] - train loss: 6.598235229024459
Step [1640 / 3282] - train loss: 6.603787534409782
Step [1968 / 3282] - train loss: 6.610922996309725
-------Validating-----
Val loss: 6.606196867628685

----------Example---------
Query: get the fuck out of here
Response: you you you you you
Step [2296 / 3282] - train loss: 6.610541448479852


KeyboardInterrupt: ignored

In [None]:
print("----------Example---------")
sentence = "what a dork huh huh"
translated_sentence = translate_sentence(
    model, sentence, TEXT, TEXT, device, max_length=50
)
print(f"Query: {sentence}")
print(f"Response: {translated_sentence}")

# Evaluate on test set

In [None]:
# running on entire test data takes a while
score = bleu(test_df.iloc[:10], model, TEXT, TEXT, device)
print(f"Bleu score {score * 100:.2f}")