In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
! unzip /content/drive/MyDrive/nlp/lab2/lab2.zip -d /content

Archive:  /content/drive/MyDrive/nlp/lab2/lab2.zip
   creating: /content/lab2/
  inflating: /content/lab2/lab2.ipynb  
   creating: /content/lab2/tb_logs/
   creating: /content/lab2/tb_logs/ner_model/
   creating: /content/lab2/tb_logs/ner_model/version_0/
  inflating: /content/lab2/tb_logs/ner_model/version_0/events.out.tfevents.1699636469.ilyusha-Aspire-A515-43.27272.5  
  inflating: /content/lab2/tb_logs/ner_model/version_0/events.out.tfevents.1699636249.ilyusha-Aspire-A515-43.27272.2  
  inflating: /content/lab2/tb_logs/ner_model/version_0/events.out.tfevents.1699636402.ilyusha-Aspire-A515-43.27272.3  
  inflating: /content/lab2/tb_logs/ner_model/version_0/events.out.tfevents.1699636463.ilyusha-Aspire-A515-43.27272.4  
   creating: /content/lab2/tb_logs/ner_model/version_0/checkpoints/
  inflating: /content/lab2/tb_logs/ner_model/version_0/checkpoints/epoch=3-step=189372.ckpt  
 extracting: /content/lab2/tb_logs/ner_model/version_0/hparams.yaml  
  inflating: /content/lab2/tb_logs/

In [3]:
! pip install lightning torchmetrics

Collecting lightning
  Downloading lightning-2.1.2-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchmetrics
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m62.0 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.1.2-py3-none-any.whl (776 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.9/776.9 kB[0m [31m60.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning, lightning
Successfully installed lightning-2.1.2 lightning-utilities-0.10.0 pytorch-lightning-2.1.2 torchmetrics-1.2.1


In [4]:
import string
import numpy as np
import pandas as pd
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.loggers import TensorBoardLogger
import pytorch_lightning as pl

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

from gensim.models import KeyedVectors
from gensim.models import FastText

In [5]:
def parse_conllu(filename, size: int = None):
    lines = open(filename, 'r')
    parsed_data = []
    iter = 0
    for line in lines:
        iter += 1
        if size and size == iter:
            return parsed_data
        line = line.split("\n")[0]
        if line != "":
            splitted = line.split("\t")
            if len(splitted) != 3:
                continue
            parsed_data.append(splitted[1:])
    return parsed_data

In [7]:
# train_data = parse_conllu("dataset/train.conllu")
# sentences = train_data[:, 0]
# fasttext_model = FastText(sentences=sentences,
#                           epochs=5,
#                           vector_size=100,
#                           window=5,
#                           min_count=1,
#                           sg=1)
fasttext_wv = KeyedVectors.load("lab2/geowac_lemmas_none_fasttextskipgram_300_5_2020/model.model")

In [8]:
print(fasttext_wv.vector_size)

300


In [9]:
class StrangeCONLLDataset(Dataset):
    def __init__(self, raw_data):
        self.data = raw_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample, label = self.data[idx]
        if label in ['B-LOC', 'I-LOC']:
            label = 0
        elif label in ['B-PER', 'I-PER']:
            label = 1
        elif label in ['B-ORG', 'I-ORG']:
            label = 2
        elif label in ['B-MISC', 'I-MISC']:
            label = 3
        else:
            label = 4
        return sample, label


In [49]:
params = {
    "num_classes": 5,
    "hidden_dim": 100,
    "learning_rate": 3e-4,
    "max_epochs": 200,
    "batch_size": 256,
    "num_workers": 2,
    "train_set_size": 10000,
    "valid_set_size": 2000,
    "test_set_size": 2000
}

In [45]:
train_dataset = StrangeCONLLDataset(parse_conllu("lab2/dataset/train.conllu", params["train_set_size"]))
val_dataset = StrangeCONLLDataset(parse_conllu("lab2/dataset/val.conllu", params["valid_set_size"]))
test_dataset = StrangeCONLLDataset(parse_conllu("lab2/dataset/test.conllu", params["test_set_size"]))

dataloader = {
    "train": DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True, num_workers=params["num_workers"]),
    "val": DataLoader(val_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"]),
    "test": DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=params["num_workers"])
}

In [50]:
class NERModel(pl.LightningModule):
    def __init__(self, word_vector, hidden_dim, output_dim):
        super(NERModel, self).__init__()
        self.word_vector = word_vector
        self.rnn = nn.RNN(word_vector.vector_size, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

        self.metrics = MetricCollection([
            MulticlassAccuracy(num_classes=params["num_classes"], average="micro"),
            MulticlassPrecision(num_classes=params["num_classes"], average="macro"),
            MulticlassRecall(num_classes=params["num_classes"], average="macro")
        ])

    def forward(self, word):
        embedded = torch.tensor(self.word_vector[word]).to(self.device)
        output, _ = self.rnn(embedded)
        predictions = self.fc(output)
        return predictions.softmax(dim=1)

    def shared_step(self, batch, batch_idx, stage):
        words, targets = batch
        predictions = self(words)
        loss = nn.CrossEntropyLoss()(predictions, targets)
        self.metrics(predictions, targets)
        self.log(f"{stage}_loss", loss)
        for metric_name in self.metrics:
          self.log(f"{stage}_{metric_name}", self.metrics[metric_name].compute())
        return loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, batch_idx, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=params["learning_rate"])
        return optimizer


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ner_model = NERModel(fasttext_wv, params["hidden_dim"], params["num_classes"])

In [13]:
tensorboard_writer = TensorBoardLogger("/content/drive/MyDrive/nlp/tb_logs/tb_logs", name="ner_model")

In [30]:
print(ner_model.device)

cpu


In [51]:
trainer = pl.Trainer(accelerator="gpu",
                     max_epochs=params["max_epochs"],
                     logger=tensorboard_writer)
trainer.fit(ner_model,
            train_dataloaders=dataloader["train"],
            val_dataloaders=dataloader["val"])

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params
---------------------------------------------
0 | rnn     | RNN              | 40.2 K
1 | fc      | Linear           | 505   
2 | metrics | MetricCollection | 0     
---------------------------------------------
40.7 K    Trainable params
0         Non-trainable params
40.7 K    Total params
0.163     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=200` reached.


In [48]:
test_metrics = trainer.test(ner_model, dataloaders=dataloader["test"], verbose=False)
pprint(test_metrics)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_MulticlassAccuracy': 0.885495126247406,
  'test_MulticlassPrecision': 0.1989288032054901,
  'test_MulticlassRecall': 0.200042724609375,
  'test_loss': 1.043349027633667}]


In [None]:
actual_labels = ["Location", "Person", "Organization", "Miscellaneous", "Other"]
for i, (word, label) in enumerate(dataloader["test"]):
    predicted = ner_model(word)
    predicted_label = torch.argmax(predicted)
    print(f"Word: {word[0]}: predicted label: \"{actual_labels[predicted_label]}\" with probability: {predicted[0][predicted_label]}. Expected label: \"{actual_labels[label]}\"")
    if i == 100:
        break

In [None]:
tensorboard_writer.close()