In [1]:
import sys
sys.path.append("nlp_project")

import numpy as np
import pandas as pd
import torch

from transformers import BertTokenizer, BertModel

from scripts.read_write_data import read_processed_data

KernelInterrupted: Execution interrupted by the Jupyter kernel.

In [2]:
TRAIN_PATH = "nlp_project/data/processed/train.conll"
DEV_PATH = "nlp_project/data/processed/dev.conll"
TEST_PATH = "nlp_project/data/processed/test.conll"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False)
bert_model = BertModel.from_pretrained("bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
sentences = []
targets = []
for words, labels, _, _ in read_processed_data(TRAIN_PATH):
    sentences.append(words)
    targets.append(list(map(int, labels)))

In [5]:
print(sentences[0])
print(targets[0])

['My', 'dad', 'just', 'does', "n't", 'understand', '?']
[0, 0, 0, 0, 0, 0, 0]


In [6]:
from typing import List

def add_bert_tags(documents: List[List[str]], doc_tags: List[List[str]]):
    updated_docs = [['[CLS]', *doc, '[SEP]'] for doc in documents]
    updated_tags = [[0, *tags, 
] for tags in doc_tags]  # normally this would be <pad>, *tags, <pad>

    return updated_docs, updated_tags

bert_sentences, bert_targets = add_bert_tags(sentences, targets)

print(bert_sentences[0])
print(bert_targets[0])

['[CLS]', 'My', 'dad', 'just', 'does', "n't", 'understand', '?', '[SEP]']
[0, 0, 0, 0, 0, 0, 0, 0]


### Data Loader

In [7]:
from torch.utils import data

class BertDataset(data.Dataset):

    def __init__(self, documents: List[List[str]], doc_tags: List[List[str]]):
        self.docs = documents
        self.doc_tags = doc_tags

    def __len__(self):
        return len(self.docs)

    def __getitem__(self, idx: int):
        sentence, tags = self.docs[idx], self.doc_tags[idx]

        enc_sentence = []
        enc_tags = []
        is_heads = []
        for word, tag in zip(sentence, tags):
            tokens = bert_tokenizer.tokenize(word)
            ids = bert_tokenizer.convert_tokens_to_ids(tokens)

            tag = [tag] + [0] * (len(tokens) - 1)  # target label should only be assigned to head
            is_head = [1] + [0] * (len(tokens) - 1)  # pay attention only to the first part of the word

            enc_sentence += ids
            enc_tags += tag
            is_heads += is_head
        
        sentence = " ".join(sentence)
        return sentence, enc_sentence, enc_tags

dataset_test = BertDataset([bert_sentences[0]], [bert_targets[0]])
dataset_test[0]

("[CLS] My dad just does n't understand ? [SEP]",
 [101, 1422, 4153, 1198, 1674, 183, 112, 189, 2437, 136],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [8]:
def pad_batch(batch):
    batch_t = list(map(list, zip(*batch)))  # transpose
    sentences, enc_sents, enc_targets = batch_t
    max_len = max((map(len, enc_sents)))
    
    pad_sents = []
    pad_targets = []
    for enc_sent, enc_target in zip(enc_sents, enc_targets):
        pad_sents.append(enc_sent + [0] * (max_len - len(enc_sent)))
        pad_targets.append(enc_target + [0] * (max_len - len(enc_target)))
    
    pad_sents = torch.LongTensor(pad_sents)
    pad_targets = torch.LongTensor(pad_targets)

    return sentences, pad_sents, pad_targets

### Model

In [9]:
# N - number of samples, T - number of tokens, EMB - embedding dimension
bert_model(torch.LongTensor([dataset_test[0][1]]))[0].shape  # (N, T, EMB)

torch.Size([1, 10, 768])

In [10]:
import torch.nn.functional as F

torch.manual_seed(1)

# TODO: come up with an original name for the model lol
class SomeModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(768, out_features=2)
        
    def forward(self, inputs):
        """
        inputs: (N, T)
        """
        embeds = bert_model(inputs)[0]   # (N, T, 768)
        preds = self.linear(embeds)  # (N, T, 2)
        log_probs = F.softmax(preds, dim=2)
        return log_probs

In [11]:
def train(model, iterator, optimizer, criterion):          
    model.train()

    for i, batch in enumerate(iterator):
        sent, enc_sent, enc_targets = batch

        # predictions
        pred_tags = model(inputs=enc_sent) # (N, T, 2)
        pred_tags = pred_tags.view(-1, 2) # (N x T, 2)
        
        # true label for each word
        targets = enc_targets.flatten()  # (N x T)

        # loss
        batch_loss = criterion(pred_tags, targets)
        
        # optimization
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
                        
        if i % 10 == 0:
            print(f"Batch {i} loss: {batch_loss:.4f}")

### Training

In [12]:
import torch.nn as nn
from torch.optim import Adam


model = SomeModel()
model.to(device)

train_dataset = BertDataset(
    documents=bert_sentences,
    doc_tags=bert_targets    
)

train_iter = data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=pad_batch
)

optimizer = Adam(model.parameters(), lr=0.001)

criterion = nn.CrossEntropyLoss(ignore_index=0)

In [13]:
train(model, train_iter, optimizer, criterion)

NameError: name 'F' is not defined

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=b2f14aee-af04-4db5-af55-57a3a58b9f40' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>