In [None]:
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
from transformers import DistilBertTokenizerFast, DistilBertModel, DistilBertForSequenceClassification 

dataset = pd.read_csv("clean_COVIDSenti.csv")
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

def tokenize(tweet):
    tokenized = tokenizer(tweet, return_tensors='pt', padding="max_length", max_length = 47) #Max tweet token length is 47
    return tokenized

tweets, labels = dataset['tweet'], dataset['label'] + 1 #Labels need to be 0-indexed
tokenized_tweets = tweets.map(tokenize)
tokenized_tweets, labels = tokenized_tweets.to_list(), labels.to_list()

#Determining correct backend
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Training on Apple GPU")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Training on CUDA")
else:
    print ("MPS device not found.")

tensor([[  101, 21887, 23350,  2529, 21887, 23350,  2828, 26629,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]])
Training on Apple GPU


In [203]:
print(np.min(list(map(lambda x: len(x['input_ids'][0]), tokenized_tweets))))
print(tokenized_tweets[36044])

47
{'input_ids': tensor([[  101,  4501,  4484,  2034,  2048,  2553,  3117, 21887, 23350, 16311,
          2506,  2233,  3458,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}


In [205]:
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset, random_split

class TweetDataset(Dataset):
    def __init__(self, tweets, labels):
        self.x = tweets
        self.y = labels
        
    def __getitem__(self, index):
        # Check that x is really a dictionary before processing
        x = self.x[index]
        x = dict(x)
        x = {key: torch.squeeze(val, dim = 0) for key, val in x.items()}
        y = self.y[index]
        return (x, y)
    
    def __len__(self):
        return len(self.x)
    
folds = 5
early_stopping = 5 #Stop if 5 epochs without improvement on val
train_frac = 0.1
test_frac = 0.8
val_frac = 0.1
batch_size = 64
test_accuracies = []
print(isinstance(labels,pd.Series))
data = TweetDataset(tokenized_tweets, labels)


for fold in range(folds):
    print(f"FOLD {fold}")
    gen = torch.Generator().manual_seed(fold)
    train, val, test = random_split(data, lengths=[train_frac, val_frac, test_frac], generator=gen)
    
    #Dealing with imbalanced class weights for train dataset
    labels_for_counts = list(map(lambda x: x[-1], train))
    frequency = 1 / np.bincount(labels_for_counts)
    class_weights = torch.tensor(frequency, dtype=torch.float32)
    obs_weights = list(map(lambda x: class_weights[x[-1]], train))
        
    #train_sampler = WeightedRandomSampler(weights = obs_weights, num_samples = len(obs_weights))
    train_loader = DataLoader(train, batch_size=batch_size, shuffle = True) #Test with shuffle instead of sampler, maybe?
    val_loader = DataLoader(val, shuffle=False, batch_size=batch_size)
    test_loader = DataLoader(test, shuffle=False, batch_size=batch_size)
    
    #---- TRAINING ACTUAL MODEL FROM HERE ON OUT ----#
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)
    model = model.to(device)

    #Just train these layers to save CONSIDERABLE time -- distilbert has 6 transformer layers
    layers = [model.classifier, model.pre_classifier, model.distilbert.transformer.layer[4], model.distilbert.transformer.layer[5]]
    updated_params = nn.ParameterList([])
    for layer in layers:
        updated_params.extend(layer.parameters())

    for param in model.parameters():
        param.requires_grad = True #Usually, this would be false, but we're updating the entirety of the model now
        
    for param in updated_params:
        param.requires_grad = True

    model.train()

    lr = 0.0001
    epoch = 0
    no_improvement = 0
    curr_acc = 0
    criterion = nn.CrossEntropyLoss() #Without softmax we use CEL
    optimizer = torch.optim.AdamW(model.parameters(), lr = lr)

    while no_improvement < early_stopping:
        epoch += 1
        print(f"Epoch {epoch}")
        
        #Training model layers
        for train_inputs, train_labels in train_loader:
            train_inputs['input_ids'], train_inputs['attention_mask'] = train_inputs['input_ids'].to(device), train_inputs['attention_mask'].to(device)
            train_labels = train_labels.to(device)
            
            model.zero_grad()
            with torch.autocast("mps"):
                output = model(**train_inputs)['logits']   
            loss = criterion(output, train_labels)
            loss.backward()
            optimizer.step()
        
        #Early stopping
        model.eval()
        correct = torch.tensor(0, device = device)
        incorrect = torch.tensor(0, device = device)
        
        for val_inputs, val_labels in val_loader:
            val_inputs['input_ids'], val_inputs['attention_mask'] = val_inputs['input_ids'].to(device), val_inputs['attention_mask'].to(device)
            val_labels = val_labels.to(device)
            probs = model(**val_inputs)['logits']
            preds = torch.argmax(probs, axis = 1)
            preds = preds.to(device)
            correct += (preds == val_labels).sum()
            incorrect += (preds != val_labels).sum()  
        
        accuracy = correct / (correct + incorrect)
        if accuracy > curr_acc:
            print(f"New accuracy has been reached: {accuracy}")
            curr_acc = accuracy
            no_improvement = 0
        else:
            no_improvement += 1
        
        model.train()
        
    model.eval()
    correct = torch.tensor(0, device = device)
    incorrect = torch.tensor(0, device = device)
    
    #Getting test accuracy for CV purposes
    for test_inputs, test_labels in test_loader:
        test_inputs['input_ids'], test_inputs['attention_mask'] = test_inputs['input_ids'].to(device), test_inputs['attention_mask'].to(device)
        test_labels = test_labels.to(device)
        probs = model(**test_inputs)['logits']
        preds = torch.argmax(probs, axis = 1)
        preds = preds.to(device)
        correct += (preds == test_labels).sum()
        incorrect += (preds != test_labels).sum()  
    
    test_accuracy = correct / (correct + incorrect)
    test_accuracies.append(test_accuracy)
    print(f"FOR FOLD {fold}, THE TEST ACCURACY WAS {test_accuracy}")
    print("---------------------------------------")
        



False
FOLD 0


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1
New accuracy has been reached: 0.8541111350059509
Epoch 2
New accuracy has been reached: 0.8673333525657654
Epoch 3
New accuracy has been reached: 0.8733333349227905
Epoch 4
Epoch 5
New accuracy has been reached: 0.8752222061157227
Epoch 6
Epoch 7
Epoch 8
Epoch 9
Epoch 10
FOR FOLD 0, THE TEST ACCURACY WAS 0.8671666383743286
---------------------------------------
FOLD 1


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1
New accuracy has been reached: 0.835444450378418
Epoch 2


KeyboardInterrupt: 