In [1]:
import os
import pandas as pd
import numpy as np
import torch
import pytorch_lightning as pl

In [2]:
import gensim.downloader as dl

pretrained_weights_name = "word2vec-google-news-300"
model_dl_path = os.path.join(
    dl.BASE_DIR, pretrained_weights_name, f"{pretrained_weights_name}.gz")


if os.path.exists(model_dl_path):
    # load model
    print(f"Loading model from {model_dl_path}")
    gnews_embeddings = dl.load(pretrained_weights_name)
else:
    # download
    print(f"Model will be downloaded at {model_dl_path}")
    gnews_embeddings = dl.load("word2vec-google-news-300")

Loading model from /home/shawon/gensim-data/word2vec-google-news-300/word2vec-google-news-300.gz


In [3]:
# add PAD to embeddings

# 0 padding, 300 embedding dims
gnews_embeddings.add_vector("<PAD>", np.zeros(300))

# need it later for loading the embeddings in pytorch model
padding_idx = len(gnews_embeddings.index_to_key) - 1



In [4]:
# https://github.com/Oneplus/Tweebank

train_file = os.path.join(
    "/mnt/Others/experiments/datasets/Tweebank-dev/converted/"
    "en-ud-tweet-train.fixed.conllu")

# assert os.path.exists(train_file)

with open(train_file) as f:
    data = f.readlines()

In [5]:
# break line at every "\n"
tweets = list()
buffer = list()
for idx, tw in enumerate(data):
    if tw == "\n":
        # one partition here
        tweets.append(buffer)
        buffer = []
    else:
        # keep appending
        buffer.append(tw)
        
tweets[0]

['# tweet_id = feb_jul_16.1463316480\n',
 "# text = RT @USER991: Dear diary,       I've been rapping in 3 accents and no longer know which one is truly mine. I am a sadting - Drake URL217…\n",
 '1\tRT\trt\tX\t_\t_\t10\tdiscourse\t_\t_\n',
 '2\t@USER991\t@USER\tX\t_\t_\t1\tdiscourse\t_\tSpaceAfter=No\n',
 '3\t:\t:\tPUNCT\t_\t_\t1\tpunct\t_\t_\n',
 '4\tDear\tdear\tADJ\t_\t_\t5\tamod\t_\t_\n',
 '5\tdiary\tdiary\tNOUN\t_\t_\t10\tvocative\t_\tSpaceAfter=No\n',
 '6\t,\t,\tPUNCT\t_\t_\t10\tpunct\t_\t_\n',
 '7\tI\ti\tPRON\t_\t_\t10\tnsubj\t_\tSpaceAfter=No\n',
 "8\t've\t've\tAUX\t_\t_\t10\taux\t_\t_\n",
 '9\tbeen\tbe\tAUX\t_\t_\t10\taux\t_\t_\n',
 '10\trapping\trap\tVERB\t_\t_\t0\troot\t_\t_\n',
 '11\tin\tin\tADP\t_\t_\t13\tcase\t_\t_\n',
 '12\t3\tNUMBER\tNUM\t_\t_\t13\tnummod\t_\t_\n',
 '13\taccents\taccent\tNOUN\t_\t_\t10\tobl\t_\t_\n',
 '14\tand\tand\tCCONJ\t_\t_\t17\tcc\t_\t_\n',
 '15\tno\tno\tADV\t_\t_\t16\tadvmod\t_\t_\n',
 '16\tlonger\tlonger\tADV\t_\t_\t17\tadvmod\t_\t_\n',
 '17\tknow\

In [6]:
# format for tokens
# number - word - lemma - pos - _ - _ - id - role, -, - 

'4\tDear\tdear\tADJ\t_\t_\t5\tamod\t_\t_\n'.split("\t")


['4', 'Dear', 'dear', 'ADJ', '_', '_', '5', 'amod', '_', '_\n']

In [7]:
# need idx 1, 2,3 : word, lemma and pos

class ConlluRowInfo:
    word: str
    lemma: str
    pos: str
    
    def __init__(self, word: str, lemma: str, pos: str) -> None:
        self.word = word
        self.lemma = lemma
        self.pos = pos
        
    def __str__(self) -> str:
        rep = {
            "word": self.word,
            "lemma": self.lemma,
            "pos": self.pos
        }
        return str(rep)

In [8]:
from typing import List

class ConlluRow:
    info: List[ConlluRowInfo]
    # text: str
    
    def __init__(self, infos: List[ConlluRowInfo]) -> None:
        self.info = infos
        
    def __str__(self) -> str:
        return f"info : {self.info}"

In [9]:
structured_tweets = list()

for tweet in tweets:
    # text = tweet[1].replace("# text = ", "")
    info_in_tweet = list()
    # start from idx 2
    for infos in tweet[2:]:
        buffer = infos.split("\t")
        try:
            word = buffer[1]
            lemma = buffer[2]
            tag = buffer[3]
            info_in_tweet.append(ConlluRowInfo(word, lemma, tag))
        except IndexError:
            print(buffer)
        except AttributeError as e:
            print(e.name)
    structured_tweets.append(ConlluRow(info_in_tweet))  

In [10]:
# time to define the torch dataset

from torch.utils.data import Dataset
from tqdm.auto import trange, tqdm
from typing import Dict, List

class TweebankDataset(Dataset):
    def __init__(self, filename, w2v_weights=gnews_embeddings) -> None:
        self.filename = filename
        
        self.w2v = w2v_weights
        self.data = list()
        self.__read_data()
        
        self.MAX_SEQ_LEN = 50 # default value
        # self.__find_max_seq_len()
        
        self.UNIQUE_TAGS = ['PRON', 'NUM', 'NOUN', 'CCONJ', 'ADV', 'SCONJ', 
                               'ADP', 'AUX', 'PROPN', 'SYM', 'DET', 
                               'INTJ', 'PUNCT', 'X', 'ADJ', 'VERB', 'PART', "<PAD>"]
        self.tag_dict = dict()
        self.__encode_tags()
        
        self.number_tags = len(self.UNIQUE_TAGS)
        
        self.vocabulary = self.w2v.index_to_key  # type: ignore
            
    
    def __len__(self) ->  int:
        return len(self.data)
    
    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        # ============== collect ===================
        words = [i.word for i in self.data[idx].info]
        # lemmas = [i.lemma for i in self.data[idx].info]
        tags = [i.pos for i in self.data[idx].info]
        
        
        # =================== convert using word2vec weights ==========
        for idx in range(len(words)):
            try:
                w2v_idx = self.w2v.key_to_index[words[idx]]  # type: ignore 
            except KeyError:
                w2v_idx = 0 # </s>
            words[idx] = w2v_idx
            tags[idx] = self.tag_dict[tags[idx]]
            
        
        # ============== pad words ===============
        # left pad
        padded_words = np.zeros(self.MAX_SEQ_LEN, dtype=np.int32)
        padded_words[-len(words):] = words
        
        # ============== pad tags =================
        padded_tags = np.ones(self.MAX_SEQ_LEN, dtype=np.int32) * self.tag_dict.get("<PAD>")  # type: ignore        
        padded_tags[-len(tags):] = tags
        
        return {
            "words": torch.tensor(padded_words),
            "tags": torch.tensor(padded_tags),
        }
        
    def __find_max_seq_len(self) -> None:
        seq_lens = []
        
        for idx in range(len(self.data)):
            words = [i.word for i in self.data[idx].info]
            seq_lens.append(len(words))
        
        
        self.MAX_SEQ_LEN = max(seq_lens)
        
    def __encode_tags(self) -> None:
        for idx, tag in enumerate(self.UNIQUE_TAGS):
            self.tag_dict[tag] = idx
        
    def __read_data(self) -> None:
        with open(self.filename, "r") as f:
            data = f.readlines()
            
            # ============ read the text file =============
            lines = list()
            buffer = list()
            for _, line in tqdm(enumerate(data)):
                if line == "\n":
                    lines.append(buffer)
                    buffer = []
                else:
                    buffer.append(line)
                    
            # ============== organize in objects ==============
            for idx, line in tqdm(enumerate(lines)):
                # from line index 2 and onwards
                line_info = list()
                for info in line[2:]:
                    buffer = info.split("\t")
                
                    try:
                        word = buffer[1]
                        lemma = buffer[2]
                        tag = buffer[3]
                        
                        line_info.append(ConlluRowInfo(word, lemma, tag))
                        
                    except IndexError:
                        print(buffer)
                        
                
                lines[idx] = ConlluRow(line_info)    

            self.data = lines


In [11]:
dataset = TweebankDataset(train_file)
sample = dataset[0]
sample

0it [00:00, ?it/s]

0it [00:00, ?it/s]

{'words': tensor([    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0, 31905,
             0,     0, 12654, 14263,     0,    20,   190,    42, 40105,     1,
           234, 22860,     0,    86,   951,   177,    48,    45,     4,  2604,
          2747,     0,    20,   248,     0,     0,     0, 10297,     0,     0],
        dtype=torch.int32),
 'tags': tensor([17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
         17, 13, 13, 12, 14,  2, 12,  0,  7,  7, 15,  6,  1,  2,  3,  4,  4, 15,
         10,  1,  7,  4,  0, 12,  0,  7, 10,  2, 12,  8, 13, 12],
        dtype=torch.int32)}

In [12]:
dataset.tag_dict

{'PRON': 0,
 'NUM': 1,
 'NOUN': 2,
 'CCONJ': 3,
 'ADV': 4,
 'SCONJ': 5,
 'ADP': 6,
 'AUX': 7,
 'PROPN': 8,
 'SYM': 9,
 'DET': 10,
 'INTJ': 11,
 'PUNCT': 12,
 'X': 13,
 'ADJ': 14,
 'VERB': 15,
 'PART': 16,
 '<PAD>': 17}

In [13]:
# https://stackabuse.com/python-how-to-flatten-list-of-lists/


# import itertools

# all_tags = [data["tags"] for data in dataset]
# all_tags = list(itertools.chain(*all_tags))
# unique_tags = set(all_tags)
# print(list(unique_tags))

In [88]:
# dataloaders
from torch.utils.data import DataLoader

bs = 128
dl_args = {
    "pin_memory": True,
    "batch_size": bs
}


training_set = dataset
validation_set = TweebankDataset("/mnt/Others/experiments/datasets/Tweebank-dev/converted/en-ud-tweet-dev.fixed.conllu")
test_set = TweebankDataset("/mnt/Others/experiments/datasets/Tweebank-dev/converted/en-ud-tweet-test.fixed.conllu")

train_loader = DataLoader(training_set, shuffle=True, **dl_args)
val_loader = DataLoader(validation_set, shuffle=False, **dl_args)
test_loader = DataLoader(test_set, shuffle=False, **dl_args)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [89]:
assert training_set.tag_dict == validation_set.tag_dict == test_set.tag_dict

In [116]:
# model
# https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class LSTMTagger(nn.Module):
    def __init__(self, 
                 embedding_dim: int, 
                 hidden_dim: int,  
                 tagset_size: int,
                 padding_idx=padding_idx, 
                 freeze_embeddings=True, 
                 w2v_weights=gnews_embeddings) -> None:
        
        super(LSTMTagger, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.taget_size = tagset_size
        
        
        embedding_tensors = torch.from_numpy(w2v_weights.vectors) # type: ignore        
        self.word_embeddings = nn.Embedding.from_pretrained(
            embedding_tensors, freeze=freeze_embeddings, padding_idx=padding_idx)
        
        self.lstm = nn.LSTM(
            embedding_dim, 
            hidden_dim, 
            batch_first=True,
            bidirectional=True)
        
        self.attention =  nn.MultiheadAttention(hidden_dim * 2, num_heads=4, dropout=0.1, batch_first=True)
        self.relu = nn.ReLU()
        
        self.linear = nn.Linear(hidden_dim * 2, tagset_size)

        
    def forward(self, words):
        embeds = self.word_embeddings(words)
        
        lstm_out, _ = self.lstm(embeds)
        
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        relu_out = self.relu(attn_out)
        
        linear_out = self.linear(relu_out)

        logits = F.log_softmax(linear_out, dim=-1)
        return logits
        

In [121]:
tagset_size = len(dataset.UNIQUE_TAGS)
model = LSTMTagger(embedding_dim=300, hidden_dim=100,  tagset_size=tagset_size)

In [131]:
from einops import rearrange

# run a sample forward pass
sample = dataset[42]

with torch.no_grad():
    words = sample["words"].unsqueeze(0)
    tags = sample["tags"].unsqueeze(0).long()
    
    out = model(words)
    
    # apparently nllloss expects inputs in shape (bs, n_classes, feature_dims.......)
    # https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss
    out = rearrange(out, "1 words probas -> 1 probas words")
    
    print(f"out :: {out.size()}")
    print(f"tags :: {tags.size()}")
    
# sample_loss = F.nll_loss(out, tags, ignore_index=17)
sample_loss = nn.NLLLoss(ignore_index=17)
print(sample_loss(input=out, target=tags))

out :: torch.Size([1, 18, 50])
tags :: torch.Size([1, 50])
tensor(2.9064)


In [83]:
# https://discuss.pytorch.org/t/loss-function-for-multi-class-with-probabilities-as-output/60866

x = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.int32)
_, i = torch.max(x, dim=-1)

print(i)

y = torch.randn((4, 3), dtype=torch.float32)
print(y)

print(F.nll_loss(F.log_softmax(y, dim=-1), i))

tensor([2, 0, 2, 1])
tensor([[ 1.0466,  0.2878,  0.9655],
        [ 0.3407,  0.7659, -1.7061],
        [-0.4508, -1.6277, -0.3213],
        [ 0.1605,  1.8744, -1.1999]])
tensor(0.7249)


In [132]:
device = torch.device("cuda")
model = model.to(device)

In [134]:
optimizer = optim.AdamW(params=model.parameters())
# ignore the index for PAD
criterion = nn.NLLLoss(ignore_index=training_set.tag_dict.get("<PAD>"))  # type: ignore        
run_validation_every_n_step = 10


# fp16
scaler = torch.cuda.amp.GradScaler()

epochs = 50
for e in trange(epochs):

    steps = 0
    for batch in train_loader:
        # switch to train mode
        model.train()
        
        words = batch["words"]
        tags = batch["tags"].long()
        
        # send data to device
        words = words.to(device)
        tags = tags.to(device)
        
        # zero out optimizer to accumulate new grads
        optimizer.zero_grad()
        
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = model(words)
            logits = rearrange(logits, "bs words probas -> bs probas words")
            
            # loss
            loss = criterion(logits, tags)
        
        
        # ======== validation ==============
        if steps % run_validation_every_n_step == 0:
            val_losses = []
            
            # switch context
            model.eval()
            with torch.no_grad():
                for val_batch in val_loader:
                    words = val_batch["words"]
                    tags = val_batch["tags"].long()
                    
                    words = words.to(device)
                    tags = tags.to(device)
                    
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        logits = model(words)
                        logits = rearrange(logits, "bs words probas -> bs probas words")
                        val_loss = criterion(logits, tags)

                    val_losses.append(val_loss.item())
                    
                    # preds = torch.max(logits, dim=-1).indices

                # log
                print(f"Epoch:: [{e + 1}]/[{epochs}] Step:: {steps}")
                print(f"Train Loss:: {loss} __________ Val Loss:: {torch.mean(torch.tensor(val_losses))}")
        
        # switch context
        model.train()
        scaler.scale(loss).backward()  # type: ignore
        # loss.backward()
        scaler.step(optimizer)
        # optimizer.step()
        scaler.update()
        steps += 1
    


  0%|          | 0/300 [00:00<?, ?it/s]

torch.Size([128, 50])
Epoch:: [1]/[300] Step:: 0
Train Loss:: 2.901834487915039 __________ Val Loss:: 2.90346360206604
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [1]/[300] Step:: 10
Train Loss:: 2.6507415771484375 __________ Val Loss:: 2.669147253036499
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [2]/[300] Step:: 0
Train Loss:: 2.6422011852264404 __________ Val Loss:: 2.664486885070801
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [2]/[300] Step:: 10
Train Loss:: 2.5923993587493896 __________ Val Loss:: 2.6027467250823975
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [

torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [19]/[300] Step:: 0
Train Loss:: 0.9850803017616272 __________ Val Loss:: 1.047589659690857
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [19]/[300] Step:: 10
Train Loss:: 0.9010156393051147 __________ Val Loss:: 1.0141125917434692
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [20]/[300] Step:: 0
Train Loss:: 0.9633456468582153 __________ Val Loss:: 0.9993138909339905
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [20]/[300] Step:: 10
Train Loss:: 0.8557214736938477 __________ Val Loss:: 0.9856648445129395
torch.Size([128, 50])
torch.Size([103, 50])
t

torch.Size([128, 50])
Epoch:: [36]/[300] Step:: 0
Train Loss:: 0.5193265676498413 __________ Val Loss:: 0.9140600562095642
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [36]/[300] Step:: 10
Train Loss:: 0.4784262776374817 __________ Val Loss:: 0.9054408669471741
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [37]/[300] Step:: 0
Train Loss:: 0.45328643918037415 __________ Val Loss:: 0.9397297501564026
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [37]/[300] Step:: 10
Train Loss:: 0.46534526348114014 __________ Val Loss:: 0.9212509989738464
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50]

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [53]/[300] Step:: 10
Train Loss:: 0.351662814617157 __________ Val Loss:: 1.2099086046218872
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [54]/[300] Step:: 0
Train Loss:: 0.2658579349517822 __________ Val Loss:: 1.2336333990097046
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [54]/[300] Step:: 10
Train Loss:: 0.30260610580444336 __________ Val Loss:: 1.2447614669799805
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [55]/[300] Step:: 0
Train Loss:: 0.24314193427562714 __________ Val Loss:: 1.2646712064743042
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])

torch.Size([128, 50])
Epoch:: [71]/[300] Step:: 0
Train Loss:: 0.17527984082698822 __________ Val Loss:: 1.6793403625488281
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [71]/[300] Step:: 10
Train Loss:: 0.15432222187519073 __________ Val Loss:: 1.6522059440612793
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [72]/[300] Step:: 0
Train Loss:: 0.194455087184906 __________ Val Loss:: 1.6715832948684692
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [72]/[300] Step:: 10
Train Loss:: 0.1667330414056778 __________ Val Loss:: 1.7123708724975586
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [88]/[300] Step:: 10
Train Loss:: 0.10393884032964706 __________ Val Loss:: 2.1086788177490234
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [89]/[300] Step:: 0
Train Loss:: 0.1175660714507103 __________ Val Loss:: 2.1647088527679443
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [89]/[300] Step:: 10
Train Loss:: 0.14346961677074432 __________ Val Loss:: 2.1337037086486816
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [90]/[300] Step:: 0
Train Loss:: 0.1042332872748375 __________ Val Loss:: 2.1503145694732666
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50]

torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [106]/[300] Step:: 0
Train Loss:: 0.0795929804444313 __________ Val Loss:: 2.465742826461792
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [106]/[300] Step:: 10
Train Loss:: 0.06934485584497452 __________ Val Loss:: 2.6405935287475586
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [107]/[300] Step:: 0
Train Loss:: 0.07043678313493729 __________ Val Loss:: 2.460073232650757
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [107]/[300] Step:: 10
Train Loss:: 0.11226954311132431 __________ Val Loss:: 2.694276809692383
torch.Size([128, 50])
torch.Size([103, 5

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [123]/[300] Step:: 10
Train Loss:: 0.08279520273208618 __________ Val Loss:: 2.907627820968628
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [124]/[300] Step:: 0
Train Loss:: 0.06714147329330444 __________ Val Loss:: 2.727725028991699
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [124]/[300] Step:: 10
Train Loss:: 0.05699387192726135 __________ Val Loss:: 2.7917091846466064
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [125]/[300] Step:: 0
Train Loss:: 0.07005968689918518 __________ Val Loss:: 2.771411657333374
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 

torch.Size([128, 50])
Epoch:: [141]/[300] Step:: 0
Train Loss:: 0.04048476740717888 __________ Val Loss:: 3.0465736389160156
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [141]/[300] Step:: 10
Train Loss:: 0.05770348384976387 __________ Val Loss:: 3.184488296508789
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [142]/[300] Step:: 0
Train Loss:: 0.04132966697216034 __________ Val Loss:: 3.169886589050293
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [142]/[300] Step:: 10
Train Loss:: 0.043993785977363586 __________ Val Loss:: 3.125291109085083
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128,

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [158]/[300] Step:: 10
Train Loss:: 0.019738174974918365 __________ Val Loss:: 3.411634683609009
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [159]/[300] Step:: 0
Train Loss:: 0.024650853127241135 __________ Val Loss:: 3.322239875793457
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [159]/[300] Step:: 10
Train Loss:: 0.02515656314790249 __________ Val Loss:: 3.3776466846466064
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [160]/[300] Step:: 0
Train Loss:: 0.030780071392655373 __________ Val Loss:: 3.3716766834259033
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([1

torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [176]/[300] Step:: 0
Train Loss:: 0.016986314207315445 __________ Val Loss:: 3.686530351638794
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [176]/[300] Step:: 10
Train Loss:: 0.027905762195587158 __________ Val Loss:: 3.8391971588134766
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [177]/[300] Step:: 0
Train Loss:: 0.02261264994740486 __________ Val Loss:: 3.685045003890991
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [177]/[300] Step:: 10
Train Loss:: 0.02637769840657711 __________ Val Loss:: 3.7206408977508545
torch.Size([128, 50])
torch.Size([10

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [193]/[300] Step:: 10
Train Loss:: 0.019829915836453438 __________ Val Loss:: 3.767313003540039
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [194]/[300] Step:: 0
Train Loss:: 0.03110303357243538 __________ Val Loss:: 3.6566030979156494
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [194]/[300] Step:: 10
Train Loss:: 0.0255808774381876 __________ Val Loss:: 3.721571683883667
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [195]/[300] Step:: 0
Train Loss:: 0.023585090413689613 __________ Val Loss:: 3.716283082962036
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128,

torch.Size([128, 50])
Epoch:: [211]/[300] Step:: 0
Train Loss:: 0.019756663590669632 __________ Val Loss:: 4.065723896026611
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [211]/[300] Step:: 10
Train Loss:: 0.011909584514796734 __________ Val Loss:: 3.999823570251465
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [212]/[300] Step:: 0
Train Loss:: 0.01217583566904068 __________ Val Loss:: 4.019618511199951
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [212]/[300] Step:: 10
Train Loss:: 0.020386217162013054 __________ Val Loss:: 4.050903797149658
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [228]/[300] Step:: 10
Train Loss:: 0.10241381824016571 __________ Val Loss:: 2.776186227798462
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [229]/[300] Step:: 0
Train Loss:: 0.07889308035373688 __________ Val Loss:: 2.8038699626922607
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [229]/[300] Step:: 10
Train Loss:: 0.06286415457725525 __________ Val Loss:: 2.6884586811065674
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [230]/[300] Step:: 0
Train Loss:: 0.029106339439749718 __________ Val Loss:: 2.7992231845855713
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([12

Epoch:: [246]/[300] Step:: 0
Train Loss:: 0.007248890586197376 __________ Val Loss:: 3.9294540882110596
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [246]/[300] Step:: 10
Train Loss:: 0.003712325356900692 __________ Val Loss:: 3.84932541847229
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [247]/[300] Step:: 0
Train Loss:: 0.010500626638531685 __________ Val Loss:: 3.877917528152466
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [247]/[300] Step:: 10
Train Loss:: 0.0034918312449008226 __________ Val Loss:: 3.855804681777954
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [248]

Epoch:: [263]/[300] Step:: 10
Train Loss:: 0.0027230146806687117 __________ Val Loss:: 4.110350131988525
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [264]/[300] Step:: 0
Train Loss:: 0.005441031884402037 __________ Val Loss:: 4.207213878631592
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [264]/[300] Step:: 10
Train Loss:: 0.01432106364518404 __________ Val Loss:: 4.176031112670898
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [265]/[300] Step:: 0
Train Loss:: 0.006078512407839298 __________ Val Loss:: 4.2125420570373535
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [265]

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [281]/[300] Step:: 10
Train Loss:: 0.002012859797105193 __________ Val Loss:: 4.424371242523193
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [282]/[300] Step:: 0
Train Loss:: 0.015118706971406937 __________ Val Loss:: 4.462736129760742
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [282]/[300] Step:: 10
Train Loss:: 0.01703474670648575 __________ Val Loss:: 4.5326738357543945
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [283]/[300] Step:: 0
Train Loss:: 0.010606308467686176 __________ Val Loss:: 4.353391170501709
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([12

torch.Size([128, 50])
Epoch:: [299]/[300] Step:: 0
Train Loss:: 0.006138088181614876 __________ Val Loss:: 4.295166015625
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [299]/[300] Step:: 10
Train Loss:: 0.010249102488160133 __________ Val Loss:: 4.428346157073975
torch.Size([128, 50])
torch.Size([103, 50])
torch.Size([128, 50])
Epoch:: [300]/[300] Step:: 0
Train Loss:: 0.010336766950786114 __________ Val Loss:: 4.4639129638671875
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128, 50])
Epoch:: [300]/[300] Step:: 10
Train Loss:: 0.0034600638318806887 __________ Val Loss:: 4.433796405792236
torch.Size([128, 50])
torch.Size([103, 50])


In [145]:
def accuracy(pred:torch.Tensor, true: torch.Tensor) -> torch.Tensor:

    
    non_padded_pred = (pred != torch.tensor(17)).nonzero()
    non_padded_true = (true != torch.tensor(17)).nonzero()
    
    correct = pred[non_padded_pred].eq(true[non_padded_true])
    return correct.sum() / true.size()[1]

acc = []
model.eval()
with torch.no_grad():
    for batch in tqdm(train_loader):
        words = batch["words"].to(device)
        tags = batch["tags"].long().to(device)
        
        
        logits = model(words)
        
        preds = torch.max(logits, dim=-1).indices
        
        
        a = accuracy(preds, tags)
#         print(a)
        
        acc.append(a)

print(f"Mean acc : {torch.mean(torch.tensor(acc))}")

  0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([128, 50])


RuntimeError: The size of tensor a (6400) must match the size of tensor b (1891) at non-singleton dimension 0