In [1]:
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import BertModel, BertConfig, DistilBertModel, DistilBertTokenizer
from torch.optim import AdamW
import torch.nn as nn
from transformers import get_scheduler
import torch
from tqdm.notebook import tqdm
import evaluate
import random
import argparse
import os
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize

In [53]:
import nlpaug.augmenter.word as naw

def augment(text):
    """Return a synonym-augmented version of ``text``.

    Words are replaced using nlpaug's ``SynonymAug`` backed by WordNet with
    ``aug_p=0.3``. The augmenter is created on the first call and cached on
    the function object, so repeated calls do not rebuild it each time.

    Args:
        text: The input string to augment.

    Returns:
        The augmented string. If the augmenter produces no output, ``text``
        is returned unchanged.
    """
    augmenter = getattr(augment, "_augmenter", None)
    if augmenter is None:
        augmenter = naw.SynonymAug(aug_src='wordnet', aug_p=0.3)
        augment._augmenter = augmenter

    augmented_text = augmenter.augment(text)

    # NOTE(review): some nlpaug releases appear to return a plain str for a
    # single input instead of a one-element list; indexing [0] on a str would
    # silently return only its first character. Confirm against the installed
    # version.
    if isinstance(augmented_text, str):
        return augmented_text
    # Nothing came back (e.g. empty input): fall back to the original text
    # rather than raising IndexError.
    if not augmented_text:
        return text
    return augmented_text[0]

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer shared by every cell below (used by tokenize_function).
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Seed Python's and torch's RNGs and request deterministic cuDNN behaviour.
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def tokenize_function(example):
    """Tokenize ``example`` with the notebook-level ``tokenizer``.

    The text is padded to the tokenizer's maximum length, truncated if it is
    longer, and returned as PyTorch tensors together with the attention mask
    and token type ids.

    Args:
        example: Text passed straight through to ``tokenizer``.

    Returns:
        Whatever ``tokenizer(...)`` returns for these options.
    """
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "return_attention_mask": True,
        "return_token_type_ids": True,
        "return_tensors": 'pt',
    }
    return tokenizer(example, **encoding_options)

In [55]:
import csv
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.sampler import BatchSampler, WeightedRandomSampler

def custom_transform(item):
    """Build an augmented, order-swapped training example from ``item``.

    Both texts are passed through ``augment``; the pair is then joined in
    reverse order (text2 first) around ``' [SEP] '`` and tokenized. The
    returned encoding carries the original label and the swapped, augmented
    texts under ``"text1"`` / ``"text2"``.

    Args:
        item: Mapping with ``"text1"``, ``"text2"`` and ``"label"`` entries.

    Returns:
        The tokenized encoding with ``"label"``, ``"text1"`` and ``"text2"``
        added.
    """
    # text1 is augmented before text2 so random draws happen in a fixed order.
    first_augmented = augment(item["text1"])
    second_augmented = augment(item["text2"])

    swapped_pair = ' [SEP] '.join((second_augmented, first_augmented))
    encoded = tokenize_function(swapped_pair)

    encoded["label"] = item["label"]
    # The texts are stored swapped, mirroring the order used in the input.
    encoded["text1"] = second_augmented
    encoded["text2"] = first_augmented
    return encoded

class myDataset(Dataset):
    """Text-pair classification dataset read from a CSV file.

    The first row of the file is treated as a header and skipped. Every other
    non-empty row must provide ``text1``, ``text2`` and an integer ``label``
    in its first three columns. Rows are kept in memory as dicts in
    ``self.data``; tokenization happens lazily in ``__getitem__``.
    """

    def __init__(self, csv_file):
        """Load all rows of ``csv_file`` into ``self.data``.

        Args:
            csv_file: Path of the CSV file to read.

        Raises:
            ValueError: If a label column cannot be parsed as an integer.
            IndexError: If a non-empty row has fewer than three columns.
        """
        self.data = []
        # newline='' is what the csv module expects: it lets the reader handle
        # line endings itself so newlines embedded in quoted fields survive.
        with open(csv_file, 'r', newline='') as csvfile:
            csvreader = csv.reader(csvfile)
            next(csvreader, None)  # skip the header row (no-op on empty file)
            for row in csvreader:
                # csv.reader yields [] for blank lines; skip them instead of
                # failing on row[0].
                if not row:
                    continue
                self.data.append({"text1": row[0],
                                  "text2": row[1],
                                  "label": int(row[2])})

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize row ``index`` and return it with label and raw texts.

        The two texts are joined with ``' [SEP] '`` and passed to
        ``tokenize_function``; ``"label"``, ``"text1"`` and ``"text2"`` are
        then added to the resulting encoding.
        """
        item = self.data[index]
        combined_text = item["text1"]+' [SEP] '+item["text2"]
        tokenize_input = tokenize_function(combined_text)
        tokenize_input["label"] = item["label"]
        tokenize_input["text1"] = item["text1"]
        tokenize_input["text2"] = item["text2"]
        return tokenize_input
    
# Load the full dataset and split it 80/20 into train and held-out sets.
dataset = myDataset('assignment_A.csv')
print("Dataset size: ", len(dataset))

# Derive the split sizes from the data instead of hard-coding 3200/800, so
# the cell still runs if the CSV grows or shrinks. For 4000 rows this yields
# exactly [3200, 800].
n_train = len(dataset) * 4 // 5
n_test = len(dataset) - n_train
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [n_train, n_test])

def get_train_loader(augment_prob=0.4, batch_size=16):
    """Build the shuffled training loader with extra augmented positives.

    Every training example whose label is 1 is, with probability
    ``augment_prob``, passed through ``custom_transform`` and added as an
    additional example. The loader iterates over the original training split
    concatenated with those extra examples.

    The notebook-level ``train_dataset`` is only read, never rebound, so
    calling this function repeatedly does not stack augmentation on top of
    augmentation.

    Args:
        augment_prob: Probability of adding an augmented copy of a positive
            example.
        batch_size: Batch size of the returned loader.

    Returns:
        A ``DataLoader`` over the original plus augmented examples, with
        shuffling enabled.
    """
    more_positives = []
    for x in train_dataset:
        # random.random() is only drawn for positives (short-circuit `and`).
        if x['label']==1 and random.random()<augment_prob:
            more_positives.append(custom_transform(x))
    # The base split is `train_dataset`; the previously referenced
    # `train_dataset1` is not defined anywhere in this notebook.
    augmented_train = ConcatDataset([train_dataset, more_positives])
    return DataLoader(augmented_train, batch_size=batch_size, shuffle=True)

# Held-out loader: fixed order, no shuffling.
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

# Training loader (original split plus augmented positives).
train_loader = get_train_loader()

# Number of batches per training epoch.
print(len(train_loader))

Dataset size:  4000
209


In [59]:
class Model(torch.nn.Module):
    """Encoder plus a small classification head.

    The head takes the encoder's output at sequence position 0 and applies
    Linear(768, 768) -> tanh -> LayerNorm -> Dropout(0.2) ->
    Linear(768, num_classes), returning unnormalized class scores.
    """

    def __init__(self, bert, num_classes):
        """Store the encoder and create the head layers.

        Args:
            bert: Encoder module; called as ``bert(input_ids,
                attention_mask=...)`` and indexed with ``[0]``.
            num_classes: Number of output classes.
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes

        hidden_size = 768  # feature width the head expects from the encoder
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.dropout1 = nn.Dropout(0.2)
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Compute class scores for a batch.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            token_type_ids: Accepted for interface compatibility; not used.

        Returns:
            Tensor of class scores produced by the final linear layer.
        """
        encoder_output = self.bert(input_ids, attention_mask=attention_mask)
        # First element of the encoder output, sliced at sequence position 0.
        first_token_state = encoder_output[0][:, 0, :]

        projected = torch.tanh(self.linear1(first_token_state))
        regularized = self.dropout1(self.layer_norm1(projected))
        return self.linear2(regularized)

In [63]:
# Number of passes over the training loader.
n_epochs = 15

# Pretrained DistilBERT encoder with its dropout overridden to 0.2.
distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased', dropout=0.2)

# Two-class classifier on top of the encoder, moved to the selected device.
model = Model(bert=distilbert, num_classes=2)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(params=model.parameters(), lr=2e-5, weight_decay=0.001)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [43]:
def do_eval(eval_dataloader, print_count = 0):
    """Evaluate the notebook-level ``model`` on ``eval_dataloader``.

    Puts the model in eval mode, runs every batch without gradient tracking,
    and accumulates accuracy plus the counts needed for precision and recall
    of class 1. Up to ``print_count`` misclassified examples are printed with
    their texts and ground-truth label.

    Args:
        eval_dataloader: Iterable of batches with ``input_ids``,
            ``token_type_ids``, ``attention_mask``, ``label``, ``text1`` and
            ``text2`` entries.
        print_count: Maximum number of misclassified examples to print.

    Returns:
        Tuple ``(score, precision, recall)`` where ``score`` is the result of
        the accuracy metric's ``compute()``. Precision / recall are 0 when
        their denominator is 0.
    """
    model.eval()
    metric = evaluate.load("accuracy")
    counter,true_positives,false_positives,false_negatives = 0,0,0,0
    # Inference only: disable autograd so no graph is built or kept per batch.
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            # The tokenizer output carries an extra dim of size 1 per sample;
            # squeeze it so tensors are (batch, seq).
            batch_input_ids = batch['input_ids'].squeeze(1).to(device)
            batch_token_type_ids = batch['token_type_ids'].squeeze(1).to(device)
            batch_attention_mask = batch['attention_mask'].squeeze(1).to(device)
            batch_labels = batch['label'].float().to(device)
            outputs = model(batch_input_ids,batch_attention_mask,batch_token_type_ids)
            _, predictions = torch.max(outputs,1)
            metric.add_batch(predictions=predictions, references=batch_labels)
            wrong_indices = torch.nonzero(predictions!=batch_labels)
            for i in range(wrong_indices.shape[0]):
                if counter == print_count:
                    break
                counter+=1
                index = wrong_indices[i,0].item()
                # The value printed after "Actual prediction" is the
                # ground-truth label of the misclassified example.
                print(f"{counter}, Text1: {batch['text1'][index]}, Text2: {batch['text2'][index]}, Actual prediction: {batch_labels[index].item()}")
            # With 0/1 labels: product is 1 only for true positives, and the
            # difference is +1 for false positives, -1 for false negatives.
            difference = predictions - batch_labels
            true_positives += torch.sum(predictions * batch_labels).item()
            false_positives += torch.sum(difference == 1).item()
            false_negatives += torch.sum(difference == -1).item()
    score = metric.compute()
    precision, recall = 0,0
    if (true_positives+false_positives)!=0:
        precision = true_positives/(true_positives+false_positives)
    if (true_positives+false_negatives)!=0:
        recall = true_positives/(true_positives+false_negatives)
    return score,precision,recall

In [64]:
from torch.optim.lr_scheduler import CosineAnnealingLR
# One scheduler step per optimizer step, over all epochs.
num_training_steps = n_epochs * len(train_loader)
progress_bar = tqdm(range(num_training_steps))
# Linear decay of the learning rate to zero, without warmup.
lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

for epoch in range(n_epochs):
    # do_eval() switches the model to eval mode, so re-enable training mode
    # at the start of every epoch.
    model.train()
    metric = evaluate.load("accuracy")
    for batch in train_loader:
        # The tokenizer output carries an extra dim of size 1 per sample;
        # squeeze it so tensors are (batch, seq).
        batch_input_ids = batch['input_ids'].squeeze(1).to(device)
        batch_token_type_ids = batch['token_type_ids'].squeeze(1).to(device)
        batch_attention_mask = batch['attention_mask'].squeeze(1).to(device)
        batch_labels = batch['label'].to(device)
        optimizer.zero_grad()
        outputs = model(batch_input_ids,batch_attention_mask,batch_token_type_ids)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # Advance the bar once per optimizer step; without this it stays at 0.
        progress_bar.update(1)
        _, predictions = torch.max(outputs,1)
        metric.add_batch(predictions=predictions, references=batch_labels)
    print("Training Accuracy: ",metric.compute())
    print("Validation Metrics: ",do_eval(test_loader,print_count=0))
        

  0%|          | 0/3135 [00:00<?, ?it/s]

Training Accuracy:  {'accuracy': 0.6458770614692654}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.7625}, 0.5645161290322581, 0.17676767676767677)
Training Accuracy:  {'accuracy': 0.7145427286356821}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.7975}, 0.7142857142857143, 0.30303030303030304)
Training Accuracy:  {'accuracy': 0.7712143928035982}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.8275}, 0.6704545454545454, 0.5959595959595959)
Training Accuracy:  {'accuracy': 0.824287856071964}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.84375}, 0.6553191489361702, 0.7777777777777778)
Training Accuracy:  {'accuracy': 0.8623688155922039}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.87}, 0.7117117117117117, 0.797979797979798)
Training Accuracy:  {'accuracy': 0.9013493253373314}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.8925}, 0.8783783783783784, 0.6565656565656566)
Training Accuracy:  {'accuracy': 0.9187406296851575}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.9075}, 0.8297872340425532, 0.7878787878787878)
Training Accuracy:  {'accuracy': 0.9421289355322339}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.92375}, 0.8870056497175142, 0.7929292929292929)
Training Accuracy:  {'accuracy': 0.9583208395802099}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.925}, 0.8254716981132075, 0.8838383838383839)
Training Accuracy:  {'accuracy': 0.9667166416791604}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.9225}, 0.8148148148148148, 0.8888888888888888)
Training Accuracy:  {'accuracy': 0.9724137931034482}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.93}, 0.8585858585858586, 0.8585858585858586)
Training Accuracy:  {'accuracy': 0.9781109445277362}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.93}, 0.865979381443299, 0.8484848484848485)
Training Accuracy:  {'accuracy': 0.9817091454272864}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.9325}, 0.8564356435643564, 0.8737373737373737)
Training Accuracy:  {'accuracy': 0.9835082458770614}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.93375}, 0.8756476683937824, 0.8535353535353535)
Training Accuracy:  {'accuracy': 0.9850074962518741}


  0%|          | 0/50 [00:00<?, ?it/s]

Validation Metrics:  ({'accuracy': 0.9325}, 0.8673469387755102, 0.8585858585858586)


In [66]:
# Final evaluation on the held-out split, printing up to 10 misclassified pairs.
do_eval(test_loader, print_count=10)

  0%|          | 0/50 [00:00<?, ?it/s]

1, Text1: depression, Text2: motivation, Actual prediction: 1.0
2, Text1: Whenever he walks into the room, I get horny... why?, Text2: Having a bipolar Mom means I never know which Mom I get that day., Actual prediction: 0.0
3, Text1: money, Text2: feeling alone and no one to talk to feeling sad about an ex that came back i to my life and just ghosted me tired of this feeling, i'm a widow and just feeling overwhelmed with nobody to share my fears about all this with, Actual prediction: 0.0
4, Text1: grieving, Text2: sad, Actual prediction: 1.0
5, Text1: When I was a child, I was bullied a bunch and didn't have supportive friends. Although it was only verbal bullying, I wonder if that's what caused my social awkwardness., Text2: I have no problem beating up a none year old if they don't stop bullying my little girl., Actual prediction: 1.0
6, Text1: i feel lonely and wat to talk with someone, Text2: lost feelings, feeling alone, Actual prediction: 1.0
7, Text1: i feel helpless, Text2: i

({'accuracy': 0.9325}, 0.8673469387755102, 0.8585858585858586)