In [11]:
from typing import List
from tqdm.auto import tqdm
import os

# dataset 
data_root = os.path.join(
    os.getenv("EXPSTORE"),
    "datasets",
    "Tweebank-dev",
    "converted"
)

assert os.path.exists(data_root) == True

data_splits = {
    "train": "en-ud-tweet-train.fixed.conllu",
    "val": "en-ud-tweet-dev.fixed.conllu",
    "test": "en-ud-tweet-test.fixed.conllu"
}


class ConlluRowInfo:
    word: str
    lemma: str
    pos: str

    def __init__(self, word: str, lemma: str, pos: str) -> None:
        self.word = word
        self.lemma = lemma
        self.pos = pos

    def __str__(self) -> str:
        rep = {
            "word": self.word,
            "lemma": self.lemma,
            "pos": self.pos
        }
        return str(rep)


class ConlluRow:
    info: List[ConlluRowInfo]
    # text: str

    def __init__(self, infos: List[ConlluRowInfo]) -> None:
        self.info = infos

    def __str__(self) -> str:
        return f"info : {self.info}"

In [12]:
def read_data(filename):
    # ============ read ==============
    with open(filename, "r") as f:
        raw_data = f.readlines()
        
    # =============== process =============
    lines = list()
    buffer = list()
    for _, line in tqdm(enumerate(raw_data), desc="reading lines from file"):
        if line == "\n":
            lines.append(buffer)
            buffer = list()
        else:
            buffer.append(line)
            
        
    # make sure that buffer is always empty after the loop ends
    assert len(buffer) == 0
    
    # ========== orga in objects ============
    processed_lines = list()
    for idx, l in tqdm(enumerate(lines), desc="organising in objects"):
        l_info = list()
        for info in l[2:]:
            temp = info.split("\t")
            
            # need idx 1, 2,3 : word, lemma and pos
            word = temp[1]
            lemma = temp[2]
            tag = temp[3]
            
            l_info.append(ConlluRowInfo(word, lemma, tag))
            
        processed_lines.append(ConlluRow(l_info))
        
    # ===========================================
    return processed_lines

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange

In [14]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 128) -> None:
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        positions = torch.arange(0, max_len)
        # vectorize
        positions = rearrange(positions, "n -> n 1")
        
        # but in log space
        # -2i * n / d
        # even steps , since 2i
        # n = 10e3
        denominator = -torch.arange(0, d_model, 2) * torch.log(torch.tensor(10.0e3) / d_model)
        # exp since we took log from the original equation, which was 1/n^(2i / d)
        denominator = torch.exp(denominator)
        
        # positional encoding tensor
        pe = torch.zeros(size=(max_len, 1, d_model))
        
        # encode the first dim
        pe[:, 0, 0::2] = torch.sin(positions * denominator)
        # second dim
        pe[:, 0, 1::2] = torch.cos(positions * denominator)
        
        # register as a buffer, variable but without gradient update
        self.register_buffer("positional_encoding", pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has a shape of (seq_len, batch_size, embedding_dim)
        # so you pass the embedded vectors for a sequence
        
        # residual connection + dropout
        x = x + self.positional_encoding[:x.size(0)]  # type: ignore
        return self.dropout(x)      


In [15]:
class PosTaggerTransformer(nn.Module):
    def __init__(self, 
                 vocab_size: int, 
                 d_model: int, 
                 n_heads: int, 
                 n_encoder_layers: int, 
                 n_decoder_layers: int, 
                 dropout: float,
                 max_len: int, 
                 n_tags: int) -> None:
        super().__init__()
        
        self.d_model = d_model
        
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.positional_encoder = PositionalEncoding(d_model=self.d_model, max_len=max_len)
        self.dropout = nn.Dropout(dropout)
        
        self.transformer = nn.Transformer(
            d_model=self.d_model,
            nhead=n_heads,
            num_decoder_layers=n_encoder_layers,
            num_encoder_layers=n_decoder_layers, 
            dropout=dropout,
        )
        
        self.linear = nn.Linear(d_model, n_tags)
        self.log_softmax = nn.LogSoftmax(dim=-1)
        
    def forward(self, source, target):
        # ================ get embed ================
        src = self.embedding(source) * torch.sqrt(torch.tensor(self.d_model))
        tgt = self.embedding(target) * torch.sqrt(torch.tensor(self.d_model))
        
        # ================ rearrange the shapes =================
        # default shape : bs, seq, embed
        # transformer needs : seq, bs, embed
        # can also use permute but this approach is more intuitive
        src = rearrange(src, "bs seq embed -> seq bs embed")
        tgt = rearrange(tgt, "bs seq embed -> seq bs embed")
        
        # =================== pos enc ================
        src_pe = self.positional_encoder(src)
        tgt_pe = self.positional_encoder(tgt)
        
        # =========== pass through transformer =============
        out = self.transformer(src_pe, tgt_pe)
        
        # ================ final linear layer ================
        out = self.linear(out)
        out = self.log_softmax(out)
        
        return out        

In [16]:
from torch.utils.data import Dataset, DataLoader


class TweebankDataset(Dataset):
    def __init__(self, file_name: str, max_seq_len: int, file_reader_fn=read_data) -> None:
        super().__init__()
        
        self.MAX_SEQ_LEN = max_seq_len
        
        # ================ tags ===================
        
        self.UNIQUE_TAGS = ['PRON', 'NUM', 'NOUN', 'CCONJ', 'ADV', 'SCONJ', 
                               'ADP', 'AUX', 'PROPN', 'SYM', 'DET', 
                               'INTJ', 'PUNCT', 'X', 'ADJ', 'VERB', 'PART', '</PAD>']
        self.tag_dict = dict()
        self.__encode_tags()
        
        self.n_classes = len(self.UNIQUE_TAGS)
        
        # ================= data ===================
        self.data = file_reader_fn(file_name)
        
        # ============== vocab =====================
        self.vocab = list()
        self.__build_vocab()
        
        self.vocab_size = len(self.vocab)
        
        self.word_dict = dict()
        self.__encode_words()
        
    # ======================= tag encoding ===============
    def __encode_tags(self) -> None:
        for idx, tag in enumerate(self.UNIQUE_TAGS):
            self.tag_dict[tag] = idx
            
    # ======================= vocab building and encoding ===============
    def __build_vocab(self) -> None:
        vocabulary = set()
        for idx in range(len(self.data)):
            words = [i.word for i in self.data[idx].info]
            for w in words:
                vocabulary.add(w)
        
        # ============ add oov and pad ===============
        vocabulary.add("</OOV>")
        vocabulary.add("</PAD>")
        self.vocab = list(vocabulary)
        
    def __encode_words(self) -> None:
        for idx, word in enumerate(self.vocab):
            self.word_dict[word] = idx
        
        
        
        
    # ========================== dataset methods =================   
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx):
        words = [i.word for i in self.data[idx].info]
        tags = [i.pos for i in self.data[idx].info]
        
        # ============ convert to ids =================
        # using idx from vocab
        oov_idx = self.word_dict["</OOV>"]
        word_ids = list()
        
        for w in words:
            if w in self.word_dict.keys():
                word_ids.append(self.word_dict[w])
            else:
                word_ids.append(oov_idx)
                
        # ================ same for tags ============
        tag_ids = list()
        for t in tags:
            tag_ids.append(self.tag_dict[t])
        
        # ============== pad words ===============
        padding_idx = self.word_dict["</PAD>"]

        
        # left pad
        padded_words = torch.ones(self.MAX_SEQ_LEN, dtype=torch.long) * padding_idx
        padded_words[-len(word_ids):] = torch.tensor(word_ids)
        
        padded_tags = torch.ones(self.MAX_SEQ_LEN, dtype=torch.long) * self.tag_dict.get("</PAD>")  # type: ignore        
        padded_tags[-len(tags):] = torch.tensor(tag_ids)
        
        return {
            "source": padded_words, 
            "targets": padded_tags
        }   
        

# ds = TweebankDataset(os.path.join(data_root, data_splits["train"]), 128)
# for d in ds:
#     print(d)


In [17]:
# device
device = "cuda"

# hparams
max_len = 128
d_model = 512


trainset = TweebankDataset(os.path.join(data_root, data_splits["train"]), max_len)
valset = TweebankDataset(os.path.join(data_root, data_splits["val"]), max_len)
testset = TweebankDataset(os.path.join(data_root, data_splits["test"]), max_len)

train_vocab_size = trainset.vocab_size


data_loader_args = {
    "pin_memory": True,
    "batch_size": 64,
}

train_loader = DataLoader(trainset, shuffle=True, **data_loader_args)
val_loader = DataLoader(valset, shuffle=False, **data_loader_args)
test_loader = DataLoader(testset, shuffle=False, **data_loader_args)

reading lines from file: 0it [00:00, ?it/s]

organising in objects: 0it [00:00, ?it/s]

reading lines from file: 0it [00:00, ?it/s]

organising in objects: 0it [00:00, ?it/s]

reading lines from file: 0it [00:00, ?it/s]

organising in objects: 0it [00:00, ?it/s]

In [18]:
import torch.optim as optim


model = PosTaggerTransformer(
    vocab_size=train_vocab_size,
    d_model=d_model, 
    n_heads=4, 
    n_encoder_layers=4, 
    n_decoder_layers=4, 
    dropout=0.1, 
    max_len=max_len, 
    n_tags=trainset.n_classes)
model = model.cuda()

In [19]:
optimizer = optim.Adam(model.parameters())
criterion = nn.NLLLoss(ignore_index=trainset.tag_dict["</PAD>"])

# fp16
scaler = torch.cuda.amp.GradScaler()

epochs = 5
run_validation_every_n_step = 10


In [20]:
from tqdm.auto import trange

for e in trange(epochs):

    steps = 0
    for batch in train_loader:
        # switch to train mode
        model.train()

        source = batch["source"]
        targets = batch["targets"].long()

        # send data to device
        source = source.to(device)
        targets = targets.to(device)

        # zero out optimizer to accumulate new grads
        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = model(source, targets)
            logits = rearrange(logits, "seq bs probas -> bs probas seq")

            # loss
            loss = criterion(logits, targets)

        # ======== validation ==============
        if steps % run_validation_every_n_step == 0:
            val_losses = []

            # switch context
            model.eval()
            with torch.no_grad():
                for val_batch in val_loader:
                    source = val_batch["source"]
                    targets = val_batch["targets"].long()

                    source = source.to(device)
                    targets = targets.to(device)

                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        logits = model(source, targets)
                        logits = rearrange(
                            logits, "seq bs probas -> bs probas seq")
                        val_loss = criterion(logits, targets)

                    val_losses.append(val_loss.item())

                    # preds = torch.max(logits, dim=-1).indices

                # log
                print(f"Epoch:: [{e + 1}]/[{epochs}] Step:: {steps}")
                print(
                    f"Train Loss:: {loss} __________ Val Loss:: {torch.mean(torch.tensor(val_losses))}")

        # switch context
        model.train()
        scaler.scale(loss).backward()  # type: ignore
        # loss.backward()
        scaler.step(optimizer)
        # optimizer.step()
        scaler.update()
        steps += 1


  0%|          | 0/5 [00:00<?, ?it/s]

Epoch:: [1]/[5] Step:: 0
Train Loss:: 2.969268560409546 __________ Val Loss:: 2.973590850830078
Epoch:: [1]/[5] Step:: 10
Train Loss:: 2.579211711883545 __________ Val Loss:: 2.5029423236846924
Epoch:: [1]/[5] Step:: 20
Train Loss:: 1.1059503555297852 __________ Val Loss:: 0.9716982245445251
Epoch:: [2]/[5] Step:: 0
Train Loss:: 0.9828959703445435 __________ Val Loss:: 0.9093950390815735
Epoch:: [2]/[5] Step:: 10
Train Loss:: 0.13950476050376892 __________ Val Loss:: 0.10858168452978134
Epoch:: [2]/[5] Step:: 20
Train Loss:: 0.02538725547492504 __________ Val Loss:: 0.014244730584323406
Epoch:: [3]/[5] Step:: 0
Train Loss:: 0.007921725511550903 __________ Val Loss:: 0.0035086784046143293
Epoch:: [3]/[5] Step:: 10
Train Loss:: 0.002956085605546832 __________ Val Loss:: 0.001426293863914907
Epoch:: [3]/[5] Step:: 20
Train Loss:: 0.0016054840525612235 __________ Val Loss:: 0.0007603598642162979
Epoch:: [4]/[5] Step:: 0
Train Loss:: 0.0020909642335027456 __________ Val Loss:: 0.00059688952

In [21]:
torch.save(model.state_dict(), "saved.pt")