In [1]:
! pip install transformers

In [1]:
! rm -rf space-model
! git clone https://github.com/StepanTita/space-model.git
! git clone https://github.com/StepanTita/nano-BERT.git

In [2]:
import sys

# Make the cloned `space-model` checkout importable. The guard keeps repeated
# runs of this cell from appending duplicate entries to sys.path.
if 'space-model' not in sys.path:
    sys.path.append('space-model')

In [27]:
import functools
import json
import math
import os
import random
from collections import Counter

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split

from tqdm import tqdm

import matplotlib.pyplot as plt

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification

from space_model.model import SpaceModelForClassification, SpaceModel
from space_model.loss import *

In [4]:
# Notebook-wide random seed; passed to seed_everything() and to train_test_split().
SEED = 42

In [5]:
def seed_everything(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices),
    exports PYTHONHASHSEED and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the currently selected one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner can pick different kernels from run to run, so it is
    # switched off alongside enabling deterministic mode.
    torch.backends.cudnn.benchmark = False

# Apply the notebook-wide seed before any data is loaded or any model is built.
seed_everything(seed=SEED)

In [6]:
def on_gpu(f):
    """Decorator that runs `f` only when a CUDA device is available.

    Args:
        f: Callable to guard.

    Returns:
        A wrapper that forwards all positional and keyword arguments to `f`
        and returns its result when CUDA is available; otherwise it prints
        'cuda unavailable' and returns None.
    """
    @functools.wraps(f)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available():
            return f(*args, **kwargs)
        print('cuda unavailable')
        return None
    return wrapper

In [7]:
# GPU monitoring dependencies are installed and imported only when CUDA is
# present. The nvml* functions called by print_gpu_utilization() presumably
# come from this star import.
if torch.cuda.is_available():
    ! pip install pynvml
    from pynvml import *
    # NOTE(review): `cuda` is not referenced by any of the visible cells.
    from numba import cuda

@on_gpu
def print_gpu_utilization():
    """Print the memory currently in use on GPU 0, in MB.

    Best-effort: any exception raised while querying NVML is printed instead
    of being propagated.
    """
    try:
        nvmlInit()
        mem_info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0))
        used_mb = mem_info.used // 1024**2
        print(f"GPU memory occupied: {used_mb} MB.")
    except Exception as err:
        print(err)

@on_gpu
def free_gpu_cache():
    """Empty PyTorch's CUDA cache, reporting GPU memory usage before and after."""

    def report(header):
        # Print a heading followed by the current GPU memory figure.
        print(header)
        print_gpu_utilization()

    report("Initial GPU Usage")
    torch.cuda.empty_cache()
    report("GPU Usage after emptying the cache")

def print_summary(result):
    """Print runtime, throughput and GPU memory for a finished training run.

    Args:
        result: Object with a `metrics` mapping containing the keys
            'train_runtime' and 'train_samples_per_second'.
    """
    metrics = result.metrics
    runtime = metrics['train_runtime']
    print(f"Time: {runtime:.2f}")
    throughput = metrics['train_samples_per_second']
    print(f"Samples/second: {throughput:.2f}")
    print_gpu_utilization()

In [8]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [9]:
# The file holds one JSON record per line. It is read with an explicit UTF-8
# encoding so the result does not depend on the platform default, and blank
# lines are skipped instead of crashing json.loads.
with open('nano-BERT/data/imdb_train.json', encoding='utf-8') as f:
    data = pd.DataFrame([json.loads(line) for line in f if line.strip()])
data

In [10]:
# Same format as the training file: one JSON record per line, read as UTF-8,
# with blank lines skipped.
with open('nano-BERT/data/imdb_test.json', encoding='utf-8') as f:
    test_data = pd.DataFrame([json.loads(line) for line in f if line.strip()])

In [11]:
def encode_label(label):
    """Map an IMDB sentiment tag to its integer class id.

    Args:
        label: Either 'pos' or 'neg'.

    Returns:
        1 for 'pos', 0 for 'neg'.

    Raises:
        Exception: If `label` is neither of the two known tags.
    """
    if label == 'neg':
        return 0
    if label == 'pos':
        return 1
    raise Exception(f'Unknown Label: {label}!')

class IMDBDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one review per access.

    Args:
        dataframe: Frame with a `text` column and, unless `inference` is set,
            a `label` column. Each `text` entry is joined with single spaces,
            so it is presumably a list of tokens — TODO confirm against the
            data files.
        tokenizer: Callable invoked as ``tokenizer(sentence, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')``. It
            must return a mapping whose values support ``squeeze(0)``.
        label_encoder: Callable mapping a raw label to an integer class id.
        max_seq_len: Length every sequence is padded / truncated to.
        inference: When True no labels are read and 'label_ids' is an empty
            list.
    """

    def __init__(self, dataframe, tokenizer, label_encoder, max_seq_len, inference=False):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.max_seq_len = max_seq_len
        self.inference = inference

    def __getitem__(self, index):
        """Return the tokenized example at position `index`.

        Returns:
            Dict with every tokenizer output (leading dimension squeezed away)
            plus 'label_ids': a scalar long tensor, or [] in inference mode.
        """
        # Positional lookup: `index` runs over 0..len-1, so this also works
        # for frames whose index is not a fresh RangeIndex.
        sentence = ' '.join(self.data.text.iloc[index])

        if self.inference:
            label_ids = []
        else:
            encoded = self.label_encoder(self.data.label.iloc[index])
            label_ids = torch.tensor(encoded, dtype=torch.long)

        # Use the tokenizer given to this dataset, not a notebook-level global.
        inputs = self.tokenizer(
            sentence,
            max_length=self.max_seq_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        item = {k: v.squeeze(0) for k, v in inputs.items()}
        item['label_ids'] = label_ids
        return item

    def __len__(self):
        """Number of rows in the wrapped frame at construction time."""
        return self.len

In [41]:
# Checkpoint name passed to AutoTokenizer, AutoModelForSequenceClassification
# and AutoModel below.
MODEL_NAME = 'bert-base-uncased'

NUM_EPOCHS = 5  # initial value of config['epochs']; overwritten later in the notebook
BATCH_SIZE = 256  # training batch size; validation uses twice this value
MAX_SEQ_LEN = 256  # every review is padded / truncated to this many tokens
LEARNING_RATE = 2e-4  # Adam learning rate
# Stored as config['max_grad_norm'].
# NOTE(review): no visible cell reads that key, so no gradient clipping is applied.
MAX_GRAD_NORM = 1000

In [42]:
# Tokenizer matching MODEL_NAME; handed to the IMDBDataset instances below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer

In [43]:
# Wrap the full train and test frames; examples are tokenized lazily on access.
train_dataset = IMDBDataset(data, tokenizer, encode_label, MAX_SEQ_LEN)
test_dataset = IMDBDataset(test_data, tokenizer, encode_label, MAX_SEQ_LEN)

In [44]:
# Baseline: pretrained checkpoint with a sequence-classification head, moved to
# the selected device.
cls_bert = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
).to(device)

In [45]:
# Freeze every parameter under `cls_bert.bert`; only parameters outside that
# submodule (presumably just the classification head) remain trainable.
for param in cls_bert.bert.parameters():
    param.requires_grad_(False)

In [46]:
def count_parameters(model):
    """Return the total number of trainable scalar weights in `model`.

    Only parameters with `requires_grad` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [47]:
# Trainable parameters left in the baseline after freezing `cls_bert.bert`.
count_parameters(cls_bert)

In [48]:
# Sanity check: loss of the baseline on a single training example before any
# training has taken place.
sample = train_dataset[0]
ids, mask, targets = (
    sample[key].unsqueeze(0).to(device)  # add a batch dimension, then move to the device
    for key in ('input_ids', 'attention_mask', 'label_ids')
)

outputs = cls_bert(input_ids=ids, attention_mask=mask, labels=targets)
initial_loss = outputs[0]
initial_loss

In [49]:
def eval(f):
    """Decorator that switches the model to eval mode before calling `f`.

    The decorated function must take the model as its first positional
    argument. Note that this name shadows the built-in `eval` for the rest of
    the notebook.

    Args:
        f: Callable of the form ``f(model, *args, **kwargs)``.

    Returns:
        A wrapper that calls ``model.eval()`` and then returns ``f``'s result.
    """
    @functools.wraps(f)  # keep the wrapped function's name and docstring
    def wrapper(model, *args, **kwargs):
        model.eval()
        return f(model, *args, **kwargs)
    return wrapper

def train(f):
    """Decorator that switches the model to train mode before calling `f`.

    The decorated function must take the model as its first positional
    argument.

    Args:
        f: Callable of the form ``f(model, *args, **kwargs)``.

    Returns:
        A wrapper that calls ``model.train()`` and then returns ``f``'s result.
    """
    @functools.wraps(f)  # keep the wrapped function's name and docstring
    def wrapper(model, *args, **kwargs):
        model.train()
        return f(model, *args, **kwargs)
    return wrapper

In [50]:
@train
def train_epoch(model, train_dataloader, optimizer):
    """Run one optimisation pass over `train_dataloader`.

    Args:
        model: Module called with `input_ids`, `attention_mask` and `labels`
            that returns an object exposing `.loss` and `.logits`.
        train_dataloader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'label_ids' tensors.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Tuple ``(train_loss, train_preds, train_labels)``: the sum of the
        per-batch losses, and flat lists of predicted and true class ids.
    """
    train_loss = 0.0
    train_preds = []
    train_labels = []

    for batch in tqdm(train_dataloader, total=len(train_dataloader)):
        ids = batch['input_ids'].to(device, dtype=torch.long)
        mask = batch['attention_mask'].to(device, dtype=torch.long)
        targets = batch['label_ids'].to(device, dtype=torch.long)

        outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
        loss, logits = outputs.loss, outputs.logits

        # The predicted class is the argmax over the last (class) dimension.
        # Softmax is monotonic, so it is not needed to pick the winner.
        train_preds += logits.detach().argmax(dim=-1).tolist()
        train_labels += targets.tolist()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    return train_loss, train_preds, train_labels

@eval
def eval_epoch(model, val_dataloader):
    """Evaluate `model` on every batch of `val_dataloader` without gradients.

    Args:
        model: Module called with `input_ids`, `attention_mask` and `labels`
            that returns an object exposing `.loss` and `.logits`.
        val_dataloader: Iterable of batches holding 'input_ids',
            'attention_mask' and 'label_ids' tensors.

    Returns:
        Tuple ``(val_loss, val_preds, val_labels)``: the sum of the per-batch
        losses, and flat lists of predicted and true class ids.
    """
    val_loss = 0.0
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for batch in tqdm(val_dataloader, total=len(val_dataloader)):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)
            targets = batch['label_ids'].to(device, dtype=torch.long)

            outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
            loss, logits = outputs.loss, outputs.logits

            # The predicted class is the argmax over the last (class)
            # dimension; softmax would not change which index wins.
            val_preds += logits.argmax(dim=-1).tolist()
            val_labels += targets.tolist()

            val_loss += loss.item()
    return val_loss, val_preds, val_labels

In [51]:
def _classification_metrics(labels, preds):
    """Compute accuracy, macro F1, precision and recall for one split."""
    return {
        'acc': accuracy_score(labels, preds),
        'f1': f1_score(labels, preds, average='macro'),
        'precision': precision_score(labels, preds),
        'recall': recall_score(labels, preds),
    }


def training(model, train_data, val_data, config):
    """Train `model` with Adam and evaluate it after every epoch.

    Args:
        model: Module accepted by train_epoch() / eval_epoch().
        train_data: Dataset used for optimisation (shuffled every epoch).
        val_data: Dataset used for evaluation, batched at twice the training
            batch size.
        config: Mapping providing 'lr', 'weight_decay', 'batch_size' and
            'epochs'.

    Returns:
        Dict of per-epoch lists keyed 'train_losses', 'val_losses' and
        'train_<m>' / 'val_<m>' for m in acc, f1, precision, recall. Losses
        are the mean loss per batch, so the train and validation curves share
        a scale even though their batch counts differ.
    """
    model = model.to(device)

    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )

    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=2 * config['batch_size'])

    # One optimizer step per batch, including the final partial batch.
    num_train_steps = len(train_dataloader) * config['epochs']
    print(f'Train steps: {num_train_steps}')

    metric_names = ('acc', 'f1', 'precision', 'recall')
    history = {'train_losses': [], 'val_losses': []}
    for name in metric_names:
        history[f'train_{name}'] = []
        history[f'val_{name}'] = []

    for epoch_num in range(config['epochs']):
        print(f'Epoch: {epoch_num + 1}')

        # train stage
        train_loss, train_preds, train_labels = train_epoch(model, train_dataloader, optimizer)

        # eval stage
        val_loss, val_preds, val_labels = eval_epoch(model, val_dataloader)

        # The epoch functions return summed batch losses; convert them to a
        # per-batch mean before recording and printing.
        train_loss /= max(len(train_dataloader), 1)
        val_loss /= max(len(val_dataloader), 1)

        # metrics
        train_metrics = _classification_metrics(train_labels, train_preds)
        val_metrics = _classification_metrics(val_labels, val_preds)

        history['train_losses'].append(train_loss)
        history['val_losses'].append(val_loss)
        for name in metric_names:
            history[f'train_{name}'].append(train_metrics[name])
            history[f'val_{name}'].append(val_metrics[name])

        print()
        print(f'Train loss: {train_loss} | Val loss: {val_loss}')
        for name in metric_names:
            print(f'Train {name}: {train_metrics[name]} | Val {name}: {val_metrics[name]}')
    return history

In [52]:
# Shared run configuration. training() and unsup_training() read 'epochs',
# 'batch_size', 'lr' and 'weight_decay'.
# NOTE(review): 'gradient_accumulation_steps', 'fp16' and 'max_grad_norm' are
# not read by any visible cell.
config = {
    'epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'gradient_accumulation_steps': 0,
    'fp16': False,
    'lr': LEARNING_RATE,
    'max_grad_norm': MAX_GRAD_NORM,
    'weight_decay': 0.01,
}

In [53]:
# Baseline run: train the classifier with a frozen encoder, validating on the test set.
history = training(cls_bert, train_dataset, test_dataset, config)

In [54]:
# Empty the CUDA cache after the baseline run (prints usage before and after).
free_gpu_cache()

In [55]:
def plot_results(history, do_val=True):
    """Plot every metric recorded in `history`, one figure per metric.

    The loss figure is always drawn. Accuracy, F1, precision and recall
    figures are drawn only when the matching 'train_<metric>' key is present,
    so loss-only histories produce a single, correctly titled figure.

    Args:
        history: Mapping of per-epoch lists. 'train_losses' is required;
            'val_*' keys are required only when `do_val` is True.
        do_val: Whether to overlay the validation curve on each figure.
    """
    epochs = list(range(len(history['train_losses'])))

    # (history key suffix, legend label suffix, figure title)
    panels = (
        ('losses', 'loss', 'Train / Validation Loss'),
        ('acc', 'acc', 'Train / Validation Accuracy'),
        ('f1', 'f1', 'Train / Validation F1'),
        ('precision', 'precision', 'Train / Validation Precision'),
        ('recall', 'recall', 'Train / Validation Recall'),
    )

    for key_suffix, label_suffix, title in panels:
        train_key = f'train_{key_suffix}'
        if train_key not in history:
            continue

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.plot(epochs, history[train_key], label=f'train_{label_suffix}')
        if do_val:
            ax.plot(epochs, history[f'val_{key_suffix}'], label=f'val_{label_suffix}')

        # Title and legend are set on this figure's own axes so they can never
        # land on a previously created figure.
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.legend(loc='upper right')

    # Render all figures created above, not only the last one.
    plt.show()

In [56]:
# Learning curves of the baseline run.
plot_results(history)

In [57]:
class SpaceModelForSequenceClassificationOutput:
    """Plain container for one forward pass of the space-based classifier.

    Attributes are stored exactly as given: `loss`, `logits` and
    `concept_spaces`, each defaulting to None.
    """

    def __init__(self, loss=None, logits=None, concept_spaces=None):
        self.loss, self.logits, self.concept_spaces = loss, logits, concept_spaces

In [58]:
class SpaceBertForSequenceClassification(torch.nn.Module):
    """Encoder + SpaceModel + linear classifier for sequence classification.

    Args:
        base_model: Encoder called with `input_ids` and `attention_mask` whose
            output exposes `last_hidden_state`.
        n_embed: Embedding size passed to SpaceModel; presumably it has to
            match the encoder's hidden size — TODO confirm.
        n_latent: Latent size passed to SpaceModel.
        n_concept_spaces: Number of concept spaces passed to SpaceModel. It is
            also used as the classifier's output size, i.e. the number of
            logits. NOTE(review): this ties the number of classes to the
            number of concept spaces — confirm this is intended.
        dropout: Currently unused.
        l1: Weight of the inter-space loss term.
        l2: Weight of the intra-space loss term.
        fine_tune: When True (the default) every parameter of `base_model` is
            frozen. NOTE(review): the name suggests the opposite — confirm.
    """

    def __init__(self, base_model, n_embed=3, n_latent=3, n_concept_spaces=2, dropout=0.1, l1=1e-3, l2=1e-4, fine_tune=True):
        super().__init__()

        if fine_tune:
            for p in base_model.parameters():
                p.requires_grad_(False)

        self.bert = base_model

        self.space_model = SpaceModel(n_embed, n_latent, n_concept_spaces, output_concept_spaces=True)

        self.classifier = torch.nn.Linear(n_concept_spaces * n_latent, n_concept_spaces)

        self.l1 = l1
        self.l2 = l2

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute logits and the training loss.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Attention mask for the encoder.
            labels: Optional class ids. When given, cross-entropy is added to
                the space regularisation terms.

        Returns:
            SpaceModelForSequenceClassificationOutput with `loss`, `logits`
            and `concept_spaces`. Without labels the loss consists of the
            weighted inter-/intra-space terms only.
        """
        # Keyword arguments: the second positional parameter of an encoder's
        # forward is not guaranteed to be the attention mask.
        embed = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        out = self.space_model(embed)

        logits = self.classifier(out.logits)

        space_penalty = self.l1 * inter_space_loss(out.concept_spaces) + self.l2 * intra_space_loss(out.concept_spaces)
        if labels is None:
            loss = space_penalty
        else:
            loss = F.cross_entropy(logits, labels) + space_penalty

        return SpaceModelForSequenceClassificationOutput(loss, logits, out.concept_spaces)

In [59]:
# Pretrained encoder loaded via AutoModel; it is passed as `base_model` to
# SpaceBertForSequenceClassification below.
bert = AutoModel.from_pretrained(MODEL_NAME)
bert

In [60]:
# Space-model classifier on top of the encoder, which the constructor freezes
# by default (fine_tune=True). n_embed=768 presumably matches the hidden size
# of this checkpoint — TODO confirm.
space_bert = SpaceBertForSequenceClassification(bert, n_embed=768, n_latent=3, n_concept_spaces=2)
space_bert

In [61]:
# Trainable parameters of the space-model classifier.
count_parameters(space_bert)

In [62]:
# Same supervised procedure and config as the baseline, now for space_bert.
space_history = training(space_bert, train_dataset, test_dataset, config)

In [63]:
# Empty the CUDA cache after the space-model run.
free_gpu_cache()

In [64]:
# Learning curves of the supervised space-model run.
plot_results(space_history)

In [65]:
@train
def train_unsup_epoch(model, train_dataloader, optimizer):
    """Run one label-free optimisation pass over `train_dataloader`.

    The model is called without labels, so it has to return a loss on its own.

    Args:
        model: Module called with `input_ids` and `attention_mask` that
            returns an object exposing `.loss`.
        train_dataloader: Iterable of batches holding 'input_ids' and
            'attention_mask' tensors.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The sum of the per-batch losses.
    """
    train_loss = 0.0

    for batch in tqdm(train_dataloader, total=len(train_dataloader)):
        ids = batch['input_ids'].to(device, dtype=torch.long)
        mask = batch['attention_mask'].to(device, dtype=torch.long)

        loss = model(input_ids=ids, attention_mask=mask).loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    return train_loss

@eval
def eval_unsup_epoch(model, val_dataloader):
    """Compute the label-free loss over `val_dataloader` without gradients.

    Args:
        model: Module called with `input_ids` and `attention_mask` that
            returns an object exposing `.loss`.
        val_dataloader: Iterable of batches holding 'input_ids' and
            'attention_mask' tensors.

    Returns:
        The sum of the per-batch losses.
    """
    val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(val_dataloader, total=len(val_dataloader)):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            val_loss += model(input_ids=ids, attention_mask=mask).loss.item()
    return val_loss

In [66]:
def unsup_training(model, train_data, val_data, config):
    """Train `model` with Adam on its label-free loss, evaluating every epoch.

    Args:
        model: Module accepted by train_unsup_epoch() / eval_unsup_epoch().
        train_data: Dataset used for optimisation (shuffled every epoch).
        val_data: Dataset used for evaluation, batched at twice the training
            batch size.
        config: Mapping providing 'lr', 'weight_decay', 'batch_size' and
            'epochs'.

    Returns:
        Dict with per-epoch lists 'train_losses' and 'val_losses'. Each entry
        is the mean loss per batch, so both curves share a scale even though
        their batch counts differ.
    """
    model = model.to(device)

    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )

    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=2 * config['batch_size'])

    # One optimizer step per batch, including the final partial batch.
    num_train_steps = len(train_dataloader) * config['epochs']
    print(f'Train steps: {num_train_steps}')

    history = {
        'train_losses': [],
        'val_losses': [],
    }

    for epoch_num in range(config['epochs']):
        print(f'Epoch: {epoch_num + 1}')
        # train stage
        train_loss = train_unsup_epoch(model, train_dataloader, optimizer)

        # eval stage
        val_loss = eval_unsup_epoch(model, val_dataloader)

        # The epoch functions return summed batch losses; convert them to a
        # per-batch mean before recording and printing.
        train_loss /= max(len(train_dataloader), 1)
        val_loss /= max(len(val_dataloader), 1)

        history['train_losses'].append(train_loss)
        history['val_losses'].append(val_loss)

        print()
        print(f'Train loss: {train_loss} | Val loss: {val_loss}')
    return history

In [67]:
# Split the training frame: 80% for the label-free stage, 20% for the
# supervised stage that follows it.
unsup_train, sup_train = train_test_split(data, random_state=SEED, test_size=0.2)

In [68]:
# Give each split a fresh 0..n-1 index before wrapping it; IMDBDataset looks
# rows up by the integer position it receives.
unsup_train_dataset = IMDBDataset(unsup_train.reset_index(drop=True), tokenizer, encode_label, MAX_SEQ_LEN)
sup_train_dataset = IMDBDataset(sup_train.reset_index(drop=True), tokenizer, encode_label, MAX_SEQ_LEN)

In [69]:
# Mutates the shared config in place: the label-free stage runs for 2 epochs.
config['epochs'] = 2

In [70]:
# Label-free stage: no labels are passed to the model, so the loss consists of
# the weighted inter-/intra-space terms only.
# NOTE(review): `space_bert` is the instance already trained in the supervised
# run above, so this stage does not start from fresh weights — confirm intent.
unsup_history = unsup_training(space_bert, unsup_train_dataset, test_dataset, config)

In [71]:
# Empty the CUDA cache after the label-free stage.
free_gpu_cache()

In [72]:
# Loss curves of the label-free stage (this history holds losses only).
plot_results(unsup_history)

In [73]:
# Mutates the shared config in place: the supervised stage runs for 1 epoch.
config['epochs'] = 1

In [74]:
# Supervised stage: continue training the same space_bert on the 20% split.
sup_history = training(space_bert, sup_train_dataset, test_dataset, config)

In [75]:
# Empty the CUDA cache after the supervised stage.
free_gpu_cache()

In [76]:
# Learning curves of the supervised stage.
plot_results(sup_history)