In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import nltk
import string
from tqdm import tqdm

In [7]:
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [56]:
# Dataset class

class Dataset_Translation(Dataset):
    def __init__(self, lines, word2idx_X=None, word2idx_Y=None):
        punctuation = set(string.punctuation) | set(["'", '”', '“', "’", "‘"])
        self.len = len(lines)
        self.X = []
        self.Y = []
        self.vocab_X = set()
        self.vocab_Y = set()
        for line in tqdm(lines):
            line_splited = line.split('\t')
            # Tokenize the input and target strings
            tokenized_texts = [nltk.word_tokenize(sent) for sent in line_splited[:2]]
            # Remove punctuation
            tokenized_texts = [[word for word in sent if word not in punctuation] for sent in tokenized_texts]
            self.X.append(tokenized_texts[0])
            self.Y.append(tokenized_texts[1])

            # Update the vocabulary
            self.vocab_X.update(tokenized_texts[0])
            self.vocab_Y.update(tokenized_texts[1])

        self.vocab_X.add('<s>')
        self.vocab_X.add('</s>')
        self.vocab_X.add('<PAD>')
        # self.vocab_X.add('<UNK>') # DESCOMENTAR-HO

        self.vocab_Y.add('<s>')
        self.vocab_Y.add('</s>')
        self.vocab_Y.add('<PAD>')
        # self.vocab_Y.add('<UNK>')

        if word2idx_X is None:
            self.word2idx_X = {word: idx for idx, word in enumerate(sorted(list(self.vocab_X)))}
        else:
            self.word2idx_X = word2idx_X

        if word2idx_Y is None:
            self.word2idx_Y = {word: idx for idx, word in enumerate(sorted(list(self.vocab_Y)))}
        else:
            self.word2idx_Y = word2idx_Y

        self.MAX_LENGTH_out = max([len(y) for y in self.Y])
        self.MAX_LENGTH_in = max([len(x) for x in self.X])

        for i in range(len(self.X)):
            self.Y[i] = ['<s>'] + self.Y[i] + ['</s>'] + ['<PAD>'] * (self.MAX_LENGTH_out - len(self.Y[i]))
            self.X[i] = ['<PAD>'] * (self.MAX_LENGTH_in - len(self.X[i])) + ['<s>'] + self.X[i] + ['</s>']

        for i in range(len(self.X)):
            self.X[i] = [self.word2idx_X[word] if word in self.word2idx_X else 0 for word in self.X[i]] # Canviar a "UNK" # Que no em donava temps a tornar a entrenar el model
            self.Y[i] = [self.word2idx_Y[word] if word in self.word2idx_Y else 0 for word in self.Y[i]] # Canviar a "UNK"

        self.X = torch.tensor(self.X)
        self.Y = torch.tensor(self.Y)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

In [8]:
with open('spa.txt', 'r') as f:
    lines = f.readlines()

from sklearn.model_selection import train_test_split

train_lines, test_lines = train_test_split(lines, test_size=0.2, random_state=42)
test_lines, val_lines = train_test_split(test_lines, test_size=0.5, random_state=42)

train_dataset = Dataset_Translation(train_lines)

100%|██████████| 102467/102467 [00:20<00:00, 4880.71it/s]
100%|██████████| 12808/12808 [00:02<00:00, 5361.88it/s]
100%|██████████| 12809/12809 [00:02<00:00, 5086.89it/s]


In [57]:
test_dataset = Dataset_Translation(test_lines, word2idx_X=train_dataset.word2idx_X, word2idx_Y=train_dataset.word2idx_Y)
val_dataset = Dataset_Translation(val_lines, word2idx_X=train_dataset.word2idx_X, word2idx_Y=train_dataset.word2idx_Y)

100%|██████████| 12808/12808 [00:03<00:00, 3211.45it/s]
100%|██████████| 12809/12809 [00:02<00:00, 4833.41it/s]


In [59]:
# Dataloader class
train_dataloader = DataLoader(train_dataset, batch_size=150, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=150, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=150, shuffle=False)

In [9]:
print(train_dataset.vocab_X.__len__(), train_dataset.vocab_Y.__len__())

14166 27914


In [39]:
class Translator(nn.Module):
    def __init__(self, vocab_size_input, vocab_size_target, target_len, S_token_id, hidden_size=300, n_layers=1, dropout=0.1):
        super(Translator, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.target_len = target_len
        self.S_token_id = S_token_id

        self.embedding = nn.Embedding(vocab_size_input, hidden_size)
        self.lstm_enc = nn.LSTM(hidden_size, hidden_size, n_layers, batch_first=True)
        self.lstm_dec = nn.LSTM(hidden_size, hidden_size, n_layers, batch_first=True)

        self.projection = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, vocab_size_target)
        )

    def forward(self, input):
        batch_size = input.size(0)
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_size).to(self.device)
        cell = torch.zeros(self.n_layers, batch_size, self.hidden_size).to(self.device)


        embedded = self.embedding(input)
        _, (hidden, cell) = self.lstm_enc(embedded, (hidden, cell))


        inp = self.embedding(torch.tensor([self.S_token_id]*batch_size).to(self.device)).unsqueeze(1)
        h0 = hidden
        c0 = torch.zeros_like(h0).to(self.device)

        pred = self.projection(inp)
        for i in range(self.target_len - 1): # rm <SOS>
            out, (h0, c0) = self.lstm_dec(inp, (h0, c0)) # 1, batch, 512

            pred = torch.cat((pred, self.projection(out)), dim=1)
            inp = out

        return pred


In [40]:
model = Translator(len(train_dataset.vocab_X), len(train_dataset.vocab_Y), train_dataset.MAX_LENGTH_out+2, train_dataset.word2idx_Y['<s>'])

model.to(model.device)


Translator(
  (embedding): Embedding(14166, 300)
  (lstm_enc): LSTM(300, 300, batch_first=True)
  (lstm_dec): LSTM(300, 300, batch_first=True)
  (projection): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=300, out_features=27914, bias=True)
  )
)

In [41]:
model(next(iter(train_dataloader))[0].to(model.device)).size()

torch.Size([150, 51, 27914])

In [44]:
import torch.optim as optim

# Define the optimizer and loss function
optimizer = optim.AdamW(model.parameters())
criterion = nn.CrossEntropyLoss()

# Set the number of epochs
num_epochs = 2

# Training loop
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, targets in tqdm(train_dataloader):
        inputs = inputs.to(model.device)
        targets = targets.to(model.device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        outputs = outputs.permute(0, 2, 1)

        # Compute the loss
        loss = criterion(outputs, targets)

        # Backward pass
        loss.backward()

        # Update the weights
        optimizer.step()

        # Update the running loss
        running_loss += loss.item()

    # Print the average loss for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_dataloader)}")


100%|██████████| 684/684 [08:50<00:00,  1.29it/s]


Epoch 1/2, Loss: 0.8770264389744976


100%|██████████| 684/684 [08:56<00:00,  1.27it/s]

Epoch 2/2, Loss: 0.8042848033514637





In [71]:
def compute_metrics(model, dataloader):
    model.eval()
    model.target_len = dataloader.dataset.MAX_LENGTH_out+2
    special_tokens = ['<s>', '</s>', '<PAD>', '<UNK>']
    metrics = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0, "meteor": 0}
    word2idx_Y_inv = {idx: word for word, idx in dataloader.dataset.word2idx_Y.items()}
    counter = 0
    for inputs, targets in tqdm(dataloader):
        with torch.no_grad():
            batch_size = inputs.size(0)
            inputs = inputs.to(model.device)
            outputs = model(inputs)
            outputs = outputs.max(2)[1]

            outputs = outputs.cpu().numpy()
            targets = targets.cpu().numpy()

            for i in range(batch_size):
                outputs_i = outputs[i]
                targets_i = targets[i]

                outputs_i = [word2idx_Y_inv[idx] for idx in outputs_i]
                targets_i = [word2idx_Y_inv[idx] for idx in targets_i]

                outputs_i = [word for word in outputs_i if word not in special_tokens]
                targets_i = [word for word in targets_i if word not in special_tokens]

                targets_i = ' '.join(targets_i)
                outputs_i = ' '.join(outputs_i)


                metrics["bleu1"] += nltk.translate.bleu_score.sentence_bleu([targets_i], outputs_i, weights=(1, 0, 0, 0))
                metrics["bleu2"] += nltk.translate.bleu_score.sentence_bleu([targets_i], outputs_i, weights=(0, 1, 0, 0))
                metrics["bleu3"] += nltk.translate.bleu_score.sentence_bleu([targets_i], outputs_i, weights=(0, 0, 1, 0))
                metrics["bleu4"] += nltk.translate.bleu_score.sentence_bleu([targets_i], outputs_i, weights=(0, 0, 0, 1))
                # metrics["meteor"] += nltk.translate.meteor_score.meteor_score(nltk.word_tokenize(targets_i), nltk.word_tokenize(outputs_i)) # No se perque esta fallant

                counter += 1

    for key in metrics:
        metrics[key] /= counter
    return metrics

In [72]:
compute_metrics(model, test_dataloader)

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
100%|██████████| 86/86 [00:31<00:00,  2.72it/s]


{'bleu1': 0.2154662556879726,
 'bleu2': 0.0994629776335736,
 'bleu3': 0.05914867591293268,
 'bleu4': 0.034924967506193304,
 'meteor': 0.0}