In [1]:
!pip install transformers



In [2]:
import os
import time
import random
import numpy as np
import pandas as pd

from sklearn.metrics import roc_curve, auc

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import DistilBertForSequenceClassification
from transformers import AdamW
from transformers import DistilBertTokenizerFast

In [3]:
RANDOM_SEED = 2020
torch.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

DATA_PATH = "/content/"

## Dataset

In [4]:
class CustomDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [5]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_df = pd.read_csv(os.path.join(DATA_PATH, "sat_train.tsv"), sep="\t")
valid_df = pd.read_csv(os.path.join(DATA_PATH, "sat_valid.tsv"), sep="\t")
test_df = pd.read_csv(os.path.join(DATA_PATH, "sat_test.tsv"), sep="\t")

train_encodings = tokenizer(train_df["context"].values.tolist(), truncation=True, padding=True)
valid_encodings = tokenizer(valid_df["context"].values.tolist(), truncation=True, padding=True)
test_encodings = tokenizer(test_df["context"].values.tolist(), truncation=True, padding=True)

train_dataset = CustomDataset(train_encodings, train_df["label"].values)
valid_dataset = CustomDataset(valid_encodings, valid_df["label"].values)
test_dataset = CustomDataset(test_encodings, test_df["label"].values)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

## Test function

In [6]:
def train(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, device: str):
    model.train()

    epoch_loss = 0

    for _, batch in enumerate(loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()


    return epoch_loss / len(loader)


def evaluate(model: nn.Module, loader: DataLoader, device: str):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for _, batch in enumerate(loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs[0]
            epoch_loss += loss.item()

    return epoch_loss / len(loader)


def test(
    model: nn.Module,
    loader: DataLoader
):

    with torch.no_grad():
        y_real = []
        y_pred = []
        model.eval()

        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            output = model(input_ids, attention_mask=attention_mask)[0]
            y_pred += [output]
            y_real += [batch["labels"]]
            
        y_real = torch.cat(y_real)
        y_pred = torch.cat(y_pred)[:,1]

    fpr, tpr, _ = roc_curve(y_real, y_pred)
    auroc = auc(fpr, tpr)

    return auroc


def epoch_time(start_time: int, end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [7]:
def demo_bert(model):
    sat_test = [ 
        "Speculations about the meaning and purpose of prehistoric art [rely] heavily on analogies drawn with modern-day hunter-gatherer societies.",
        "Such primitive societies, [as] Steven Mithen emphasizes in The Prehistory of the Modern Mind, tend to view man and beast, animal and plant, organic and inorganic spheres, as participants in an integrated, animated totality.",
        "The dual expressions of this tendency are anthropomorphism (the practice of regarding animals as humans) and totemism (the practice of regarding humans as animals), both of [which] spread through the visual art and the mythology of primitive cultures.",
        "When considered in this light, the visual preoccupation of early humans with the nonhuman creatures [inhabited] their world becomes profoundly meaningful.",
        "In the practice of totemism, he has suggested, an unlettered humanity “broods upon [itself] and its place in nature.”",
    ]
    sat_label = [1, 1, 1, 0, 1]
    sat_encodings = [tokenizer(sentence) for sentence in sat_test]
    with torch.no_grad():
        outputs = []
        for sat_encoding in sat_encodings:
            input_ids = torch.LongTensor([sat_encoding["input_ids"]])
            attention_mask = torch.LongTensor(sat_encoding["attention_mask"])
            output = model(input_ids, attention_mask=attention_mask)
            outputs += [output[0]]
        output = torch.cat(outputs)[:,1]
    return output.tolist()

## Before fine-tuning

In [8]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
_ = model.to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

In [9]:
test_auroc = test(model, test_loader)
print(f'| SAT Dataset Test AUROC: {test_auroc:.5f}')

| SAT Dataset Test AUROC: 0.38889


In [10]:
demo_bert(model)

[0.14826509356498718,
 0.16729339957237244,
 0.13443192839622498,
 0.15767356753349304,
 0.12004546821117401]

## Fine tuning

In [11]:
N_EPOCHS = 5
optimizer = AdamW(model.parameters(), lr=5e-5)


for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, device)
    valid_loss = evaluate(model, valid_loader, device)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f"Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s")
    print(f"\tTrain Loss: {train_loss:.5f}")
    print(f"\t Val. Loss: {valid_loss:.5f}")

test_auroc = test(model, test_loader)
print(f'| Test AUROC: {test_auroc:.5f}')    

Epoch: 01 | Time: 1m 7s
	Train Loss: 0.53139
	 Val. Loss: 0.45607
Epoch: 02 | Time: 1m 7s
	Train Loss: 0.50950
	 Val. Loss: 0.44636
Epoch: 03 | Time: 1m 7s
	Train Loss: 0.41529
	 Val. Loss: 0.51411
Epoch: 04 | Time: 1m 7s
	Train Loss: 0.31879
	 Val. Loss: 0.44216
Epoch: 05 | Time: 1m 7s
	Train Loss: 0.23448
	 Val. Loss: 0.46200
| Test AUROC: 0.80303


In [12]:
test_auroc = test(model, test_loader)
print(f'| SAT Dataset Test AUROC: {test_auroc:.5f}')

| SAT Dataset Test AUROC: 0.80303


In [13]:
demo_bert(model)

[2.0104339122772217,
 0.9977397322654724,
 2.1262519359588623,
 1.8164923191070557,
 1.252334475517273]