In [None]:
import pandas as pd
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer
from torch import cuda

# Pick the accelerator once; the model and every batch are moved to `device` later on.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Runs on {device}')

In [None]:
# Hyperparameters and run configuration.
MAX_LEN = 512              # tokenizer pad/truncate length, in tokens
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 20
EPOCHS = 5
LEARNING_RATE = 1e-05
THRESHOLD = 0.5            # a sigmoid score >= THRESHOLD is classified as "Relevant"
TRAIN_BACKBONE = True      # read by the training cell to decide whether DistilBERT itself is fine-tuned
SEED = 200                 # same value the train/test split passes as random_state

# Seed torch (all devices) so head initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(SEED);

In [None]:
!pip install wandb
import wandb

wandb.login()
wandb.init(
    # set the wandb project where this run will be logged
    project="wat-distilbert",
    
    # track hyperparameters and run metadata
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": "distilbert",
        "epochs": EPOCHS,
        "train_distilbert": TRAIN_BACKBONE,
    }
)

In [None]:
# Load the WaTA sentences and encode the string labels as integers (0 = Irrelevant, 1 = Relevant).
id2label = {0: "Irrelevant", 1: "Relevant"}
label2id = {"Irrelevant": 0, "Relevant": 1}

df = pd.read_csv('../data/WaTA_dataset.csv', encoding = "ISO-8859-1")
# `.map` turns any label not in label2id into NaN. Fail loudly here instead of letting a
# missing target surface later as an opaque tensor-construction error inside the DataLoader.
encoded_class = df['Class'].map(label2id)
unknown_labels = df.loc[encoded_class.isna(), 'Class'].unique()
assert len(unknown_labels) == 0, f"Unexpected labels in 'Class': {list(unknown_labels)}"
df['Class'] = encoded_class.astype(int)
df.head(10)

In [None]:
# Class balance of the full dataset; the shares are printed as percentages.
count = df['Class'].value_counts()
n_irrelevant, n_relevant = count[0], count[1]
n_total = n_irrelevant + n_relevant

print(f'Number of irrelevant sentences: {n_irrelevant}')
print(f'Number of relevant sentences: {n_relevant}')
print(f'Percentage of irrelevant: {n_irrelevant / n_total:.2%}')
print(f'Percentage of relevant: {n_relevant / n_total:.2%}')

In [None]:
def tokenize(sentence, tokenizer, max_length=None):
    """Encode one sentence into fixed-length input ids and an attention mask.

    Args:
        sentence: raw text to encode.
        tokenizer: a Hugging Face tokenizer exposing ``encode_plus``.
        max_length: pad/truncate length; ``None`` falls back to the notebook-wide ``MAX_LEN``.

    Returns:
        ``(ids, mask)`` as 1-D ``torch.long`` tensors, padded/truncated to ``max_length``.
    """
    if max_length is None:
        max_length = MAX_LEN
    inputs = tokenizer.encode_plus(
        sentence,
        None,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        # Only input_ids and attention_mask are consumed downstream, so don't build token_type_ids.
        return_token_type_ids=False,
    )
    ids = inputs['input_ids']
    mask = inputs['attention_mask']
    return torch.tensor(ids, dtype=torch.long), torch.tensor(mask, dtype=torch.long)

class RelevanceDataset(Dataset):
    """Map-style dataset yielding tokenized sentences and float relevance targets.

    ``data`` must be a DataFrame with a text column ``Sentence`` and a label
    column ``Class`` (encoded earlier in the notebook as 0 = Irrelevant, 1 = Relevant).
    Each item is a dict with ``ids``, ``mask`` (from ``tokenize``), the raw
    ``sentence`` and a float ``targets`` tensor.
    """

    def __init__(self, data, tokenizer):
        self.len = len(data)
        self.data = data
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        # Positional access (.iloc): DataLoader indices run 0..len-1, which only coincides
        # with label-based lookup when the frame has a reset RangeIndex.
        sentence = self.data['Sentence'].iloc[index]
        label = self.data['Class'].iloc[index]

        ids, mask = tokenize(sentence, self.tokenizer)

        return {
            'ids': ids,
            'mask': mask,
            'sentence': sentence,
            'targets': torch.tensor(label, dtype=torch.float)
        }

    def __len__(self):
        return self.len

In [None]:
# The tokenizer must come from the same checkpoint as the model backbone, and
# BinaryClassifier loads "distilbert-base-uncased". The cased tokenizer uses a different
# vocabulary, so its ids would look up unrelated rows of the uncased embedding table.
# NOTE(review): checkpoints trained before this fix (e.g. the wandb run restored below)
# were fed cased ids. Retrain or re-evaluate them rather than mixing the two.
BACKBONE_NAME = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(BACKBONE_NAME)

# Hold out 10% of the sentences for testing; random_state keeps the split fixed across runs.
train_size = 0.9
train_dataset = df.sample(frac=train_size, random_state=200)
test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = RelevanceDataset(train_dataset, tokenizer)
testing_set = RelevanceDataset(test_dataset, tokenizer)

In [None]:
# Shuffle training batches every epoch.
train_params = {
    'batch_size': TRAIN_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0
}

# Keep evaluation order fixed. Shuffling does not change the confusion counts, but it
# changes which sentences land in a smaller final batch (when the test size is not a
# multiple of the batch size), so batch-averaged losses and progress prints vary run to run.
test_params = {
    'batch_size': VALID_BATCH_SIZE,
    'shuffle': False,
    'num_workers': 0
}

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

In [None]:
class BinaryClassifier(torch.nn.Module):
    """DistilBERT encoder plus an MLP head giving one relevance probability per sentence.

    Args:
        pretrained_name: Hugging Face checkpoint for the backbone. The tokenizer must be
            loaded from the same checkpoint so input ids index the right vocabulary.
    """

    def __init__(self, pretrained_name="distilbert-base-uncased"):
        super().__init__()
        self.backbone = DistilBertModel.from_pretrained(pretrained_name)
        # Size the head from the checkpoint's config instead of hard-coding DistilBERT-base's 768.
        hidden_size = self.backbone.config.dim
        self.head = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
            # Emit probabilities, which is the input torch.nn.BCELoss expects.
            torch.nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask):
        """Return one probability per sequence, taken from the first token's hidden state."""
        backbone_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = backbone_out[0]  # last hidden state, presumably (batch, seq_len, hidden)
        return self.head(hidden_state[:, 0])

    def unfreeze_backbone(self):
        """Make every backbone parameter trainable."""
        for param in self.backbone.parameters():
            param.requires_grad = True

    def freeze_backbone(self):
        """Disable gradients for the backbone so only the head is updated."""
        for param in self.backbone.parameters():
            param.requires_grad = False

In [None]:
# Build the classifier on the selected device; the cell output shows its architecture.
model = BinaryClassifier().to(device)
model

In [None]:
# Optionally replace the freshly built model with a classifier trained in an earlier wandb run.
# Set USE_PRETRAINED_CLASSIFIER = False to start from the pretrained DistilBERT weights instead.
# (Another candidate checkpoint was run "wat-distilbert/ru3wl9xi".)
USE_PRETRAINED_CLASSIFIER = True

if USE_PRETRAINED_CLASSIFIER:
    import inspect

    saved_model = wandb.restore('classifier.bin', run_path="wat-distilbert/8p404q8d")
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
    load_kwargs = {'map_location': device}
    # classifier.bin holds a whole pickled nn.Module (written with torch.save(model)), not a
    # state_dict. torch >= 2.6 defaults to weights_only=True, which rejects such files, so
    # opt out explicitly wherever the argument exists.
    # SECURITY: full unpickling can execute arbitrary code; only load checkpoints you trust.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(saved_model.name, **load_kwargs)
    model.to(device)

model

In [None]:
# Adam over every model parameter; binary cross-entropy on the model's sigmoid probabilities.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_function = torch.nn.BCELoss()

In [None]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero.

    The model can predict no positives at all (early in training or on a skewed
    batch). A plain division would then raise ZeroDivisionError in the progress prints.
    """
    return numerator / denominator if denominator else 0.0


def compute_accuracy(tp, tn, fp, fn):
    """Fraction of predictions that match the target (0.0 when there are no predictions)."""
    return _safe_ratio(tp + tn, tp + tn + fp + fn)


def compute_precision(tp, fp):
    """Fraction of positive predictions that are truly positive (0.0 if nothing was predicted positive)."""
    return _safe_ratio(tp, tp + fp)


def compute_recall(tp, fn):
    """Fraction of actual positives predicted positive (0.0 if there are no actual positives)."""
    return _safe_ratio(tp, tp + fn)


def compute_f1(tp, fn, fp):
    """F1 score, i.e. harmonic mean of precision and recall: tp / (tp + (fn + fp) / 2)."""
    return _safe_ratio(tp, tp + (fn + fp) / 2)


def pred_to_class(pred, threshold=0.5):
    """Binarise probabilities: a float tensor shaped like ``pred`` holding 1.0 where
    ``pred >= threshold`` and 0.0 elsewhere.

    Uses the ``threshold`` argument (the global ``THRESHOLD`` used to override it);
    callers pass ``THRESHOLD`` explicitly, so their results are unchanged.
    """
    return (pred >= threshold).float()

In [None]:
def test_example(model, tokenizer, nb_relevant=20, nb_irrelevant=20):
    """Print predictions for the first few relevant and irrelevant test sentences.

    Walks the global ``testing_set`` in order. It shows up to ``nb_relevant`` sentences
    with target 1 and up to ``nb_irrelevant`` sentences with target 0, each followed by
    the predicted and true label. ``tokenizer`` is unused because the dataset already
    tokenizes; it is kept for call compatibility.

    The model runs in eval mode without autograd, so dropout is off and no graph is kept.
    Its previous train/eval mode is restored afterwards, so the function is safe to call
    between training epochs.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in range(len(testing_set)):
                if nb_relevant <= 0 and nb_irrelevant <= 0:
                    break  # both quotas filled; don't tokenize the rest of the test set
                test_data = testing_set[i]
                ids, mask, sentence, target = test_data['ids'], test_data['mask'], test_data['sentence'], test_data["targets"]
                target_class = int(target.item())
                if target_class == 1:
                    if nb_relevant <= 0:
                        continue
                    nb_relevant -= 1
                elif target_class == 0:
                    if nb_irrelevant <= 0:
                        continue
                    nb_irrelevant -= 1
                # The model is trained on batches; add the batch dimension: (seq_len,) -> (1, seq_len).
                ids = ids.to(device, dtype=torch.long).unsqueeze(0)
                mask = mask.to(device, dtype=torch.long).unsqueeze(0)
                outputs = model(ids, mask)
                pred_class = pred_to_class(outputs, THRESHOLD)
                print(sentence)
                print(f"pred: {id2label[int(pred_class.item())]}, target: {id2label[target_class]}")
    finally:
        model.train(was_training)

test_example(model, tokenizer)

In [None]:
# TRAINING

# Fine-tune the whole network, or only the head. Parameters require grad by default,
# so the backbone must be frozen explicitly for TRAIN_BACKBONE = False to have any effect.
if TRAIN_BACKBONE:
    model.unfreeze_backbone()
else:
    model.freeze_backbone()

for i in range(EPOCHS):
    total_loss = 0
    nb_steps = 0
    tp, fp, fn, tn = 0, 0, 0, 0
    model.train()
    for j, data in enumerate(training_loader):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float).reshape(-1, 1)

        outputs = model(ids, mask)
        loss = loss_function(outputs, targets)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Confusion counts accumulate over the epoch, using the predictions made before this step's update.
        pred_class = pred_to_class(outputs.detach(), THRESHOLD)
        t_preds, t_targets = pred_class == 1, targets == 1
        f_preds, f_targets = pred_class == 0, targets == 0
        tp += (t_preds & t_targets).sum().item()
        fp += (t_preds & f_targets).sum().item()
        fn += (f_preds & t_targets).sum().item()
        tn += (f_preds & f_targets).sum().item()

        nb_steps += 1

        if j > 0 and j % 500 == 0:
            # Batch loss is this step only; the "running" figures cover the epoch so far.
            print(f"Training batch loss at step {j}: {loss.item()}")
            print(f"Training running loss at step {j}: {total_loss / nb_steps}")
            print(f"Training running accuracy at step {j}: {compute_accuracy(tp, tn, fp, fn)}")
            print(f"Training running precision at step {j}: {compute_precision(tp, fp)}")
            print(f"Training running recall at step {j}: {compute_recall(tp, fn)}")
            print(f"Training running f1 at step {j}: {compute_f1(tp, fn, fp)}")

    avg_loss = total_loss / max(nb_steps, 1)  # guard against an empty loader
    accuracy = compute_accuracy(tp, tn, fp, fn)
    precision = compute_precision(tp, fp)
    recall = compute_recall(tp, fn)
    f1 = compute_f1(tp, fn, fp)
    wandb.log({"epoch": i, "loss": avg_loss, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1})
    print(f"Epoch {i}, avg loss: {avg_loss}")
    print(f"Epoch {i}, accuracy: {accuracy}")
    print(f"Epoch {i}, precision: {precision}")
    print(f"Epoch {i}, recall: {recall}")
    print(f"Epoch {i}, f1: {f1}")

    test_example(model, tokenizer)

In [None]:
# EVALUATION

model.eval()
tp, fp, fn, tn = 0, 0, 0, 0
total_loss = 0
nb_samples = 0
with torch.no_grad():
    for i, data in enumerate(testing_loader):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float).reshape(-1, 1)

        outputs = model(ids, mask)
        loss = loss_function(outputs, targets)
        # BCELoss() returns a batch mean. Weight it by batch size so the final figure is a
        # per-sentence average that does not depend on how a smaller last batch is composed.
        batch_size = targets.size(0)
        total_loss += loss.item() * batch_size
        nb_samples += batch_size

        pred_class = pred_to_class(outputs, THRESHOLD)
        t_preds, t_targets = pred_class == 1, targets == 1
        f_preds, f_targets = pred_class == 0, targets == 0
        tp += (t_preds & t_targets).sum().item()
        fp += (t_preds & f_targets).sum().item()
        fn += (f_preds & t_targets).sum().item()
        tn += (f_preds & f_targets).sum().item()

        if i % 100 == 0:
            # Batch loss is this step only; the "running" metrics cover all batches so far.
            print(f"Validation batch loss at step {i}: {loss.item()}")
            print(f"Validation running accuracy at step {i}: {compute_accuracy(tp, tn, fp, fn)}")
            print(f"Validation running precision at step {i}: {compute_precision(tp, fp)}")
            print(f"Validation running recall at step {i}: {compute_recall(tp, fn)}")
            print(f"Validation running f1 at step {i}: {compute_f1(tp, fn, fp)}")

avg_loss = total_loss / max(nb_samples, 1)  # guard against an empty test set
accuracy = compute_accuracy(tp, tn, fp, fn)
precision = compute_precision(tp, fp)
recall = compute_recall(tp, fn)
f1 = compute_f1(tp, fn, fp)
print(f"Avg loss: {avg_loss}")
print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"f1: {f1}")

In [None]:
import os

# Write the trained classifier and tokenizer vocabulary into the wandb run directory
# and register each file for upload.
run_dir = wandb.run.dir
out_model = os.path.join(run_dir, "classifier.bin")
out_tokenizer = os.path.join(run_dir, "tokenizer.bin")
out_model_pt = os.path.join(run_dir, "classifier.pt")

# Whole pickled module: this is the format the restore cell reads back with torch.load.
torch.save(model, out_model)
# NOTE(review): save_vocabulary usually takes a directory; given a file path it presumably
# writes the vocab to exactly that path. Confirm, or switch to save_pretrained(dir).
tokenizer.save_vocabulary(out_tokenizer)
for artifact_name in ('classifier.bin', 'tokenizer.bin'):
    wandb.save(artifact_name)

# Portable weights-only copy (state_dict), loadable without unpickling the class.
torch.save(model.state_dict(), out_model_pt)

wandb.save('classifier.pt')

In [None]:
# Spot-check up to 10 relevant and 10 irrelevant test sentences with the final model.
test_example(model, tokenizer, nb_relevant=10, nb_irrelevant=10)

In [None]:
# Mark the wandb run as finished and flush any pending metric/file uploads.
wandb.finish()

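THRESHOLD TESTS

Test-split metrics from the evaluation cell at different decision thresholds (`THRESHOLD`); the blocks labeled in parentheses under 0.5 are individual wandb runs. The loss does not depend on the threshold. The small loss differences between the unlabeled blocks most likely come from the shuffled test loader, which averages per-batch losses over differently composed batches.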

THRESHOLD: 0.5

Avg loss: 0.7038330654244102 \
Accuracy: 0.8570290720828355 \
Precision: 0.8728854519091348 \
Recall: 0.9495268138801262 \
f1: 0.9095945605640896

(icy-silence-9) \
Avg loss: 0.831273452559438 \
Accuracy: 0.8410991636798089 \
Precision: 0.9219539584503088 \
Recall: 0.8633017875920084 \
f1: 0.8916644040184633

(different-dream-10) \
Avg loss: 0.5872668912708168 \
Accuracy: 0.8614097968936678 \
Precision: 0.8940162271805274 \
Recall: 0.926919032597266 \
f1: 0.9101703665462054

(worthy-breeze-12) \
Avg loss: 0.9387986349253287 \
Accuracy: 0.8566308243727598 \
Precision: 0.8933673469387755 \
Recall: 0.9206098843322819 \
f1: 0.9067840497151735

(frosty-star-13) \
Avg loss: 0.3630746433008758 \
Accuracy: 0.8546395858223815 \
Precision: 0.881011403073872 \
Recall: 0.9342797055730809 \
f1: 0.9068639959173258

THRESHOLD: 0.6

Avg loss: 0.7061429985850636 \
Accuracy: 0.8554360812425329 \
Precision: 0.8759159745969711 \
Recall: 0.9426919032597266 \
f1: 0.9080779944289693

THRESHOLD: 0.7

Avg loss: 0.7037493148331239 \
Accuracy: 0.8566308243727598 \
Precision: 0.8801775147928994 \
Recall: 0.9384858044164038 \
f1: 0.9083969465648855

THRESHOLD: 0.8

Avg loss: 0.7040119320611536 \
Accuracy: 0.8566308243727598 \
Precision: 0.8847305389221557 \
Recall: 0.9321766561514195 \
f1: 0.9078341013824884

**THRESHOLD: 0.9**

Avg loss: 0.7037559155104978 \
Accuracy: 0.8610115491835922 \
Precision: 0.8955680081507896 \
Recall: 0.9242902208201893 \
f1: 0.9097024579560156

THRESHOLD: 0.95

Avg loss: 0.7070154592197108 \
Accuracy: 0.8582238152130626 \
Precision: 0.9017671517671517 \
Recall: 0.9121976866456362 \
f1: 0.9069524307370622