In [17]:
from torch.utils.data import Dataset
import torch
import pandas as pd

from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
from torch.utils.data import DataLoader
from torch import nn
from sklearn.metrics import accuracy_score
from transformers import AdamW
import time
from typing import Dict, Any
import tqdm
import os

In [4]:
class NorecDataset(Dataset):
    """Map-style dataset of (document, sentiment label) pairs read from a CSV file.

    The CSV is read with pandas (first row treated as the header) and its
    columns are renamed positionally: the first becomes ``Sentiment`` and
    the second becomes ``Document``.
    """

    def __init__(self, path):
        """Load the CSV at ``path`` and keep documents and labels as plain lists."""
        frame = pd.read_csv(path, sep=',')
        # Positional rename: whatever the header says, column 0 is the label
        # and column 1 is the text.
        frame.columns = ["Sentiment", "Document"]
        self.documents = list(frame['Document'])
        self.labels = list(frame['Sentiment'])

    def __getitem__(self, idx):
        """Return ``(document, label)`` where label is a one-element ``LongTensor``."""
        text = self.documents[idx]
        target = torch.LongTensor([self.labels[idx]])
        return (text, target)

    def __len__(self):
        """Number of documents in the dataset."""
        return len(self.documents)

In [5]:
# Training split used to fit the classifier.
train_set = NorecDataset(path="./data/train.csv")
# Development split used for validation loss / checkpoint selection.
val_set = NorecDataset(path="./data/dev.csv")

In [16]:
def save_model(folder: str, name: str, epoch: int, model, train_time_sum, parameters) -> None:
    """Write a checkpoint consisting of run metadata and the model weights.

    Two things are written into ``folder``:

    * ``<name>.pt`` -- a ``torch.save``'d dict with the keys ``name``,
      ``epoch``, ``train_time_sum`` and ``parameters``.
    * whatever ``model.save_pretrained(folder)`` produces.

    Args:
        folder: Target directory; created if missing. A trailing path
            separator is optional.
        name: Base name (without extension) of the metadata file.
        epoch: Epoch index stored in the metadata. When it is 0 and the
            folder already exists, a notice is printed.
        model: Object exposing ``save_pretrained``. If it has a ``module``
            attribute (e.g. a parallel wrapper), the wrapped object is saved.
        train_time_sum: Value stored under ``"train_time_sum"``.
        parameters: Value stored under ``"parameters"``.

    Raises:
        OSError: If the folder cannot be created (e.g. permission denied).
            Only the "already exists" case is tolerated.
    """
    folder_existed = os.path.isdir(folder)
    # exist_ok tolerates a pre-existing folder but still surfaces real
    # failures such as permission errors instead of hiding them.
    os.makedirs(folder, exist_ok=True)
    if folder_existed and epoch == 0:
        print("Folder already exist!")

    # os.path.join is correct whether or not `folder` ends with a separator;
    # plain string concatenation would write "<folder><name>.pt" next to it.
    path = os.path.join(folder, name + ".pt")

    torch.save({
        "name": name,
        "epoch": epoch,
        "train_time_sum": train_time_sum,
        "parameters": parameters
    }, path)

    # Unwrap wrappers that keep the real model under `.module`.
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(folder)

    print(f"Saved model to {folder}")

In [13]:
def _batch_loss(model, tokenizer, batch, dev):
    """Tokenize one ``(documents, labels)`` batch, run the model on ``dev``, return the loss tensor."""
    doc, y = batch
    # truncation=True is required for max_length to actually cap the sequence
    # length; without it over-long documents are passed through untruncated.
    encoded = tokenizer(list(doc), add_special_tokens=True, max_length=256, truncation=True,
                        padding=True, return_attention_mask=True, return_tensors="pt")
    outputs = model(input_ids=encoded["input_ids"].to(dev),
                    token_type_ids=encoded["token_type_ids"].to(dev),
                    attention_mask=encoded["attention_mask"].to(dev),
                    # Flatten to one label per example; squeeze() would yield a
                    # 0-d tensor for a batch of one.
                    labels=y.reshape(-1).to(dev))
    return outputs.loss


def train(train_set, val_set):
    """Fine-tune ``NbAiLab/nb-bert-base`` as a 2-class sequence classifier.

    Runs 2 epochs with batch size 1 and AdamW (lr=3e-5). After each epoch the
    mean validation loss is computed without gradient updates, and the model
    is checkpointed via ``save_model`` to ``./models/`` (name ``test``)
    whenever that loss improves on the best seen so far.

    Args:
        train_set: Dataset yielding ``(document, label_tensor)`` pairs used
            for optimisation.
        val_set: Dataset of the same form used only for evaluation.

    Returns:
        None. Progress is printed and checkpoints are written to disk.
    """
    tokenizer = BertTokenizer.from_pretrained("NbAiLab/nb-bert-base")
    model = BertForSequenceClassification.from_pretrained(
        "NbAiLab/nb-bert-base", num_labels=2)
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    # Order is irrelevant for a mean loss, so validation is not shuffled.
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
    epochs = 2
    optimizer = AdamW(model.parameters(), lr=3e-5)

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"

    folder = "./models/"
    name = "test"

    best_val_loss = float("inf")
    train_time_sum = 0
    model = model.to(dev)
    checkpointing_enabled = True
    for epoch in range(epochs):
        print("\n")
        print(f"----- Starting epoch: {epoch} -----")
        start_epoch_time = time.time()
        total_train_loss = 0
        total_val_loss = 0
        training_passes = 0
        validation_passes = 0

        model.train()
        for batch in tqdm.tqdm(train_loader):
            optimizer.zero_grad()
            loss = _batch_loss(model, tokenizer, batch, dev)
            total_train_loss += loss.item()
            loss.backward()
            optimizer.step()
            training_passes += 1

        # Evaluation only: no dropout, no gradients, and crucially no
        # optimizer steps, so the validation data never updates the weights.
        model.eval()
        with torch.no_grad():
            for batch in tqdm.tqdm(val_loader):
                loss = _batch_loss(model, tokenizer, batch, dev)
                total_val_loss += loss.item()
                validation_passes += 1

        epoch_time = time.time() - start_epoch_time
        # Cumulative wall-clock time over all epochs so far (each epoch's
        # figure includes its validation pass); stored in the checkpoint.
        train_time_sum += epoch_time
        # NaN for an empty loader: avoids ZeroDivisionError, and NaN never
        # compares lower than the best loss, so no checkpoint is written.
        train_loss = total_train_loss / training_passes if training_passes else float("nan")
        val_loss = total_val_loss / validation_passes if validation_passes else float("nan")

        print(f"Epoch: {epoch}. Training loss: {train_loss}, val loss: {val_loss}, epoch time: {epoch_time}")

        if val_loss < best_val_loss and checkpointing_enabled:
            parameters = sum(
                p.numel() for p in model.parameters() if p.requires_grad)
            print(
                f"Val loss is lower than previous best which was at {best_val_loss}. Saving model, with number of parameters: {parameters}")
            save_model(folder, name, epoch, model, train_time_sum, parameters)
            best_val_loss = val_loss
    

In [11]:
# Start fine-tuning; the cells above must have been run first so that
# `train`, `train_set` and `val_set` are defined.
train(train_set=train_set, val_set=val_set)

Some weights of the model checkpoint at NbAiLab/nb-bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initi



----- Starting epoch: 0 -----


  0%|          | 6/2674 [00:11<1:22:54,  1.86s/it]


KeyboardInterrupt: 