In [1]:
from jflegDataset import JflegDataset
from torch.utils.data import DataLoader
from _utils import tokenizerSetup, SpecialToken
from model import S2S
import torch

In [2]:

# Dataset CSV locations (relative paths, resolved against the current working directory).
TRAIN_PATH, VAL_PATH = "dataset/train.csv", "dataset/eval.csv"


In [3]:
tokenizer = tokenizerSetup()

# Sanity-check batch tokenization: pad a few unequal-length prefixes to a common length.
# (The captured output shows this tokenizer pads on the left.)
sentences = [
    "It will rain in the",
    "I want to eat a big bowl of",
    "My dog is",
]
a = tokenizer(sentences, padding=True, return_tensors="pt")
a

{'input_ids': tensor([[50259, 50259, 50259,  1026,   481,  6290,   287,   262],
        [   40,   765,   284,  4483,   257,  1263,  9396,   286],
        [50259, 50259, 50259, 50259, 50259,  3666,  3290,   318]]), 'attention_mask': tensor([[0, 0, 0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1]])}

In [4]:
from torch.utils.data import Dataset
import torch
import pandas as pd
import numpy as np
import random


class JflegDataset(Dataset):
    """JFLEG grammar-correction pairs served as tokenized (input, target_in, target_out) triples.

    The CSV at ``path`` must contain ``input`` and ``target`` columns. Rows that share the
    same ``input`` are grouped, so each item is one source sentence, and a reference
    correction is drawn at random on every ``__getitem__`` call.

    Args:
        path: CSV file to load.
        tokenizer: tokenizer exposing ``bos_token``/``eos_token`` plus ``bos_token_id`` and
            ``pad_token_id``, callable like a Hugging Face tokenizer (presumably; see
            ``_process_sequence`` for the exact call).
        max_len: every sequence is padded/truncated to exactly this many positions.
    """

    def __init__(self, path, tokenizer, max_len=128) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data = pd.read_csv(path)
        self._preprocess()

    def _preprocess(self):
        """Collapse duplicate inputs into one row whose ``target`` is a list of all
        their references, then strip the leading ``grammar: `` prefix from inputs."""
        # agg(list) always yields a list. agg(np.array) is unwrapped by pandas to a bare
        # string for single-reference groups, which random.choice would then sample
        # character-by-character.
        self.data = self.data.groupby(
            'input')['target'].agg(list).reset_index()
        self.data["input"] = self.data["input"].str.replace(
            r'^grammar: ', '', regex=True)

    def _process_sequence(self, sequence):
        """Wrap ``sequence`` in BOS/EOS and tokenize it to exactly ``max_len`` positions.

        Returns:
            dict mapping each tokenizer output key (e.g. ``input_ids``,
            ``attention_mask``) to a 1-D tensor of length ``max_len``.
        """
        # NOTE(review): with truncation, an over-long sequence may lose its EOS token
        # depending on the tokenizer's truncation side -- confirm this is acceptable.
        sequence = f"{self.tokenizer.bos_token} {sequence} {self.tokenizer.eos_token}"
        result = self.tokenizer(sequence, return_tensors="pt",
                                padding="max_length", truncation=True, max_length=self.max_len)
        # Drop only the batch axis; a bare squeeze() would also collapse max_len == 1 to 0-d.
        return {key: value.squeeze(0) for key, value in result.items()}

    def _right_shift(self, original_tensor: torch.Tensor, shift, filling_value) -> torch.Tensor:
        """Shift a 1-D tensor right by ``shift`` positions, filling vacated slots.

        Length, dtype and device of ``original_tensor`` are preserved. ``shift == 0``
        returns a copy; shifts longer than the tensor fill it entirely.

        Raises:
            ValueError: if ``shift`` is negative.
        """
        if shift < 0:
            raise ValueError(f"shift must be non-negative, got {shift}")
        if shift == 0:
            # original_tensor[:-0] would be empty, so handle this case explicitly.
            return original_tensor.clone()
        shift = min(shift, original_tensor.size(0))
        head = torch.full((shift,), filling_value,
                          dtype=original_tensor.dtype, device=original_tensor.device)
        return torch.cat((head, original_tensor[:-shift]))

    def __len__(self):
        # One item per distinct (grouped) input sentence.
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(input, target_in, target_out)`` token dicts for one source sentence.

        ``target_out`` is a randomly chosen reference wrapped in BOS/EOS; ``target_in``
        is a copy of it with the BOS position replaced by padding (id and mask).
        """
        row = self.data.iloc[index]
        source = self._process_sequence(row["input"])
        target_out = self._process_sequence(random.choice(row["target"]))

        bos_positions = torch.where(
            target_out["input_ids"] == self.tokenizer.bos_token_id)[0]

        target_in = {
            "input_ids": target_out["input_ids"].clone(),
            "attention_mask": target_out["attention_mask"].clone(),
        }
        # NOTE(review): target_in and target_out are position-aligned (only BOS differs), so
        # a causal decoder can see the token it must predict unless the training loss shifts
        # predictions by one step. `_right_shift` exists if a shifted decoder input was
        # intended -- confirm against the training loop.
        target_in["input_ids"][bos_positions] = self.tokenizer.pad_token_id
        target_in["attention_mask"][bos_positions] = 0

        return source, target_in, target_out

    def decode(self, embedding):
        """Turn a sequence of token ids back into text, keeping special tokens."""
        return self.tokenizer.decode(embedding, skip_special_tokens=False)

In [5]:
# Wrap each CSV split in a dataset and an unshuffled, batch-size-1 loader.
ds_train = JflegDataset(TRAIN_PATH, tokenizer)
dl_train = DataLoader(ds_train, batch_size=1)

ds_eval = JflegDataset(VAL_PATH, tokenizer)
dl_eval = DataLoader(ds_eval, batch_size=1)

In [6]:
# Peek at one training example: discard the first batch and unpack the second.
it = iter(dl_train)
next(it)
# NOTE: this rebinds `input`, shadowing the Python builtin for the rest of the session.
input, target_in, target_out = next(it)

target_in, target_out

({'input_ids': tensor([[50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
           50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
             352,   767,    12,    20,   657,   661,  5526,   284,  2421,   477,
            139

In [9]:
import torch
from torch import nn, Tensor
import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Args:
        d_model: embedding size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence the precomputed table covers.
        batch_first: if True, ``forward`` expects ``[batch, seq_len, d_model]``;
            otherwise ``[seq_len, batch, d_model]`` (the tutorial layout, kept as default).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000,
                 batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd-indexed column than even-indexed column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer layout stays [max_len, 1, d_model] so existing state_dicts still load.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]`` if ``batch_first``
               else ``[seq_len, batch_size, embedding_dim]``
        """
        if self.batch_first:
            # pe[:seq_len] is [seq_len, 1, d]; move the singleton onto the batch axis.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class TokenEmbedding(nn.Module):
    """Token-id lookup scaled by sqrt(emb_size), as in the original Transformer paper."""

    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        # Cast so integer ids of any dtype can index the embedding table.
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class S2S(nn.Module):
    """Encoder-decoder Transformer producing, for every target position, a probability
    distribution over the vocabulary.

    Args:
        ntoken: vocabulary size (rows of the shared embedding, width of the output layer).
        d_model: model width; must be divisible by the 4 attention heads.
    """

    def __init__(self, ntoken: int, d_model: int):
        super().__init__()
        self.model_type = 'Transformer'
        # The transformer below is batch_first, so positions run along dim 1.
        self.pos_encoder = PositionalEncoding(d_model, 0.1, batch_first=True)
        self.embedding = TokenEmbedding(ntoken, d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=4,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True
        )

        self.head = nn.Linear(d_model, ntoken)
        # Normalise over the vocabulary (last) axis; dim=0 would normalise across the batch.
        self.sm = nn.Softmax(dim=-1)

    def _decoder_self_attention_mask(self, target_is_pad: Tensor) -> Tensor:
        """Boolean decoder self-attention mask (True = blocked), shape
        ``[batch * nhead, seq_len, seq_len]`` as ``nn.MultiheadAttention`` expects.

        Blocks future positions and padded keys, but always lets a position attend to
        itself. With left-padded targets, a pad query would otherwise have every key
        blocked, and the all ``-inf`` softmax row would produce NaN.
        """
        seq_len = target_is_pad.size(1)
        positions = torch.arange(seq_len, device=target_is_pad.device)
        future = positions.unsqueeze(0) > positions.unsqueeze(1)      # [query, key]: key after query
        itself = positions.unsqueeze(0) == positions.unsqueeze(1)
        blocked = (future.unsqueeze(0) | target_is_pad.unsqueeze(1)) & ~itself
        # MultiheadAttention orders the flattened dim batch-major: index = b * nhead + h.
        return blocked.repeat_interleave(self.transformer.nhead, dim=0)

    def forward(self, input: Tensor, target: Tensor, input_mask: Tensor, target_mask: Tensor) -> Tensor:
        """Run the encoder over ``input`` and the decoder over ``target``.

        Args:
            input: source token ids, ``[batch, src_len]``.
            target: decoder-input token ids, ``[batch, tgt_len]``.
            input_mask / target_mask: tokenizer-style attention masks (1 = real token,
                0 = padding), same shapes as the ids.

        Returns:
            Probabilities of shape ``[batch, tgt_len, ntoken]``.
            NOTE(review): if training with ``nn.CrossEntropyLoss``, feed it logits
            (``self.head`` output) rather than these probabilities.
        """
        # PyTorch padding masks use True for positions to IGNORE -- the inverse of the
        # tokenizer's attention_mask, so invert rather than casting directly to bool.
        input_is_pad = input_mask == 0
        target_is_pad = target_mask == 0

        src = self.pos_encoder(self.embedding(input))
        tgt = self.pos_encoder(self.embedding(target))

        encoded = self.transformer.encoder(
            src, src_key_padding_mask=input_is_pad)

        decoded = self.transformer.decoder(
            tgt,
            memory=encoded,
            tgt_mask=self._decoder_self_attention_mask(target_is_pad),
            memory_key_padding_mask=input_is_pad,  # don't attend to padded encoder outputs
        )

        return self.sm(self.head(decoded))

In [10]:
model = S2S(ntoken=tokenizer.vocab_size + len(SpecialToken), d_model=768)

# Smoke test: push the first training batch through the model and report shapes.
for input, target_in, target_out in dl_train:
    print(target_in["input_ids"].shape)
    out = model(
        input=input["input_ids"],
        target=target_in["input_ids"],
        input_mask=input["attention_mask"],
        target_mask=target_in["attention_mask"],
    )
    i = torch.argmax(out, dim=2)

    print(f"{out.shape=}")
    break

torch.Size([1, 128])
out.shape=torch.Size([1, 128, 50260])
