In [1]:
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import math

In [2]:
df = pd.read_csv('../data/raw/filtered.tsv', sep='\t')
sents = df[(df['similarity'] < 0.7) & (df['ref_tox'] > df['trn_tox'])]
sents = sents[['reference', 'translation']]

In [3]:
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

class DetoxDataset(Dataset):
    def __init__(self, sents, train=True, vocab=None):
        self.tokenizer = get_tokenizer('spacy', language='en_core_web_md')
        self.train = train
        self.sents = [sent[0] for sent in sents.values]
        if self.train:
            self.labels = [sent[1] for sent in sents.values]
        else:
            self.labels = []
        if vocab is None:
            self.vocab = self.build_vocab()

    def build_vocab(self):
        vocab = build_vocab_from_iterator(map(self.tokenizer, self.sents + self.labels),
                                          min_freq=1,
                                          specials=special_symbols,
                                          special_first=True)
        vocab.set_default_index(UNK_IDX)
        return vocab
    
    def __len__(self):
        return len(self.sents)

    def get_tokens(self, sentence):
        tokens = self.tokenizer(sentence)
        return self.vocab(tokens)

    def __getitem__(self, idx):
        if self.train:
            return self.get_tokens(self.sents[idx]), self.get_tokens(self.labels[idx])
        else:
            return self.get_tokens(self.sents[idx])

In [4]:
dataset = DetoxDataset(sents)

In [5]:
MAX_SIZE = 30

def collate_batch(batch):
    sentences_batch, labels_batch = [], []
    sentences_padding_mask, labels_padding_mask = [], []
    for _sent, _label in batch:
        _sent = _sent[:MAX_SIZE]
        _label = _label[:MAX_SIZE]

        sent_mask = [1] * (len(_sent) + 2)
        label_mask = [1] * (len(_label) + 2)

        _sent.append(EOS_IDX)
        while len(_sent) < MAX_SIZE + 1:
            _sent.append(PAD_IDX)
            sent_mask.append(0)
        sentences_batch.append(torch.cat((torch.tensor([BOS_IDX]),
                                          torch.tensor(_sent))))
        sentences_padding_mask.append(torch.tensor(sent_mask))

        _label.append(EOS_IDX)
        while len(_label) < MAX_SIZE + 1:
            _label.append(PAD_IDX)
            label_mask.append(0)
        labels_batch.append(torch.cat((torch.tensor([BOS_IDX]),
                                          torch.tensor(_label))))
        labels_padding_mask.append(torch.tensor(label_mask))

    sentences_batch = torch.stack(sentences_batch, dim=0)
    labels_batch = torch.stack(labels_batch, dim=0)

    senteces_mask = torch.zeros((len(batch) * 8, MAX_SIZE + 2, MAX_SIZE + 2)).type(torch.bool)
    labels_mask = torch.zeros((len(batch) * 8, MAX_SIZE + 2, MAX_SIZE + 2)).type(torch.bool)

    sentences_padding_mask = torch.stack(sentences_padding_mask, dim=0).type(torch.bool)
    labels_padding_mask = torch.stack(labels_padding_mask, dim=0).type(torch.bool)

    return sentences_batch, labels_batch, senteces_mask, labels_mask, sentences_padding_mask, labels_padding_mask

In [6]:
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

BATCH_SIZE = 8

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

In [7]:
print(len(dataset.vocab))

53709


In [8]:
# for batch in train_loader:
#     sent, label, src_padding_mask, tgt_padding_mask = batch
#     print(sent[0], src_padding_mask[0], sep='\n')
#     break

In [9]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
    
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

In [10]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, 
                 vocab_size,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 emb_size=512,
                 nhead=8,
                 dim_feedforward=2048,
                 dropout=0.1,
                 ):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.generator = nn.Linear(emb_size, vocab_size)
        self.tok_embedding = TokenEmbedding(vocab_size, emb_size)
        self.pos_encoder = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                sentence,
                target,
                senteces_mask, 
                targets_mask,
                sentence_padding_mask,
                target_padding_mask,
                ):
        sentence_embedding = self.tok_embedding(sentence)
        target_embedding = self.tok_embedding(target)
        sentence_embedding = self.pos_encoder(sentence_embedding)
        target_embedding = self.pos_encoder(target_embedding)
        outs = self.transformer(sentence_embedding,
                                target_embedding, 
                                senteces_mask,
                                targets_mask,
                                src_key_padding_mask=sentence_padding_mask, 
                                tgt_key_padding_mask=target_padding_mask
                                )
        return self.generator(outs)


In [11]:
model = Seq2SeqTransformer(vocab_size=len(dataset.vocab), num_encoder_layers=2, num_decoder_layers=2)

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [13]:
def train(epoch, model, optimizer, criterion, device, train_loader):
    model = model.to(device)
    losses = []
    pbar = tqdm(train_loader)
    for batch in pbar:
        # torch.cuda.empty_cache()
        sent, label, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask = batch
        sent, label, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask = sent.to(device), label.to(device), senteces_mask.to(device), labels_mask.to(device), src_padding_mask.to(device), tgt_padding_mask.to(device)

        output = model(sent, label, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask)
        output = output.permute(0, 2, 1)
        # print(output.shape)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if np.isnan(loss.item()):
        #     print(loss.item())
        #     print(output)
        #     print('---')
        #     print(label)
            print(losses)
            break
        losses.append(loss.item())

        pbar.set_description(f'Epoch {epoch}: Avg Loss: {np.mean(losses):.5f}')

def evaluate(epoch, model, criterion, device, val_loader):
    model.to(device)
    losses = []
    pbar = tqdm(val_loader)

    model.eval()
    with torch.no_grad():
        for batch in pbar:
            sent, label, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask = batch
            sent, label, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask = sent.to(device), label.to(device), senteces_mask.to(device), labels_mask.to(device), src_padding_mask.to(device), tgt_padding_mask.to(device)

            output = model(sent, label, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask)
            output = output.permute(0, 2, 1)
            # print(output.shape)
            loss = criterion(output, label)
            losses.append(loss.item())

            pbar.set_description(f'\tEpoch {epoch}: Avg Val Loss: {np.mean(losses):.5f}')
        
        for batch in val_loader:
            sent, target, src_padding_mask, tgt_padding_mask = batch
            sent, target = sent[0], target[0]
            print(' '.join(dataset.vocab.lookup_tokens(sent.tolist())))
            print(' '.join(dataset.vocab.lookup_tokens(target.tolist())))
            pred = predict(sent, model, device)
            print(' '.join(dataset.vocab.lookup_tokens(pred[0].tolist())))
            break
        
def predict(batch, model, device):

    model = model.to(device)

    sent, _, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask = batch

    sent = sent[0].to(device).reshape(1, -1)
    label = torch.tensor([PAD_IDX]).to(device).reshape(1, -1)
    senteces_mask = senteces_mask[0].to(device).reshape(1, -1)
    labels_mask = torch.zeros((8, len(label), len(label))).to(device).type(torch.bool)
    src_padding_mask = src_padding_mask[0].to(device).reshape(1, -1)
    tgt_padding_mask = torch.ones((1, len(label))).to(device).type(torch.bool)
    model.eval()
    print(senteces_mask.shape)
    print(labels_mask.shape)
    print(src_padding_mask.shape)
    print(tgt_padding_mask.shape)
    with torch.no_grad():
        for i in range(MAX_SIZE):
            output = model(sent, label, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask)
            output = output.permute(0, 2, 1)
            token = torch.argmax(output[0, :, i]).item()
            label = torch.tensor(label[0].tolist() + [token]).to(device).reshape(1, -1)
            labels_mask = torch.zeros((8, len(label), len(label))).to(device).type(torch.bool)
            tgt_padding_mask = torch.ones((1, len(label))).to(device).type(torch.bool)
            if token == EOS_IDX:
                break
    return label


In [14]:
for batch in val_loader:
    sent, target, senteces_mask, labels_mask, src_padding_mask, tgt_padding_mask = batch
    sent, target = sent[0], target[0]
    print(' '.join(dataset.vocab.lookup_tokens(sent.tolist())))
    print(' '.join(dataset.vocab.lookup_tokens(target.tolist())))
    pred = predict(batch, model, "cuda")
    print(' '.join(dataset.vocab.lookup_tokens(pred[0].tolist())))
    break

<bos> You ever had a pissed - off marine on your ass ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<bos> you ever been pissed off by a Marine ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
torch.Size([1, 1024])
torch.Size([8, 1, 1])
torch.Size([1, 32])
torch.Size([1, 1])


  return torch._transformer_encoder_layer_fwd(
  return torch._native_multi_head_attention(


RuntimeError: shape '[1, 1, 1, 2]' is invalid for input of size 1

In [16]:
EPOCHS = 10

for epoch in range(10):
    train(epoch, model, optimizer, criterion, 'cuda', train_loader)
    evaluate(epoch, model, criterion, 'cuda', val_loader)

  0%|          | 0/11423 [00:00<?, ?it/s]


[]


	Epoch 0: Avg Val Loss: nan:   7%|â–‹         | 89/1270 [00:02<00:27, 42.70it/s]


KeyboardInterrupt: 