# BERTweet-Large

## Imports

In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from captum.attr import *
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, logging
from pydictobject import DictObject
from sklearn.metrics import classification_report

from BERTweet.TweetNormalizer import *

# Seed the Python, NumPy and PyTorch (CPU + all GPUs) RNGs and force cuDNN onto
# deterministic kernels so runs are reproducible across restarts.
# NOTE(review): some CUDA ops can still be nondeterministic; this reduces, but may
# not eliminate, run-to-run variation.
seed = 12345678
# seed = 87654321  # alternate seed; swap with the line above to use it

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False  # benchmark mode may choose different algorithms per run
torch.backends.cudnn.deterministic = True

# Device configuration: first GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Silence transformers log messages below ERROR level.
logging.set_verbosity_error()

# Authenticate with Weights & Biases (used for experiment logging during training/testing).
import wandb
wandb.login()

# NOTE(review): `os` is already imported at the top of this cell; this re-import is redundant but harmless.
import os
# Disable Hugging Face tokenizers parallelism, which otherwise warns when the
# DataLoader's worker processes fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# This notebook needs a GPU (memory is reported via torch.cuda); fail fast here
# instead of silently running on the CPU fallback chosen in the setup cell.
assert torch.cuda.is_available()

In [None]:
def log_memory():
    """Print the CUDA memory currently held by tensors, in gigabytes (1e9 bytes)."""
    allocated_bytes = torch.cuda.memory_allocated()
    gigabytes = allocated_bytes / 1e9
    print(gigabytes)

## Data

In [None]:
class AlzheimersTweetsDataset(Dataset):
    """Tokenized tweet-classification dataset read from a ``.csv`` or ``.xlsx`` file.

    The file must provide a ``tweet`` text column and a ``label`` column. Every
    tweet is normalized with ``normalizeTweet`` and tokenized with the
    ``vinai/bertweet-large`` tokenizer up front. ``__getitem__`` therefore only
    indexes pre-built tensors.

    Args:
        root: Path to the data file. Its extension selects the pandas reader.
        transform: Optional callable applied once to the full ``input_ids`` tensor.
        target_transform: Optional callable applied once to the full label tensor.
        padding: Forwarded to the tokenizer. ``True`` pads to the longest tweet;
            ``"max_length"`` pads to ``max_length``.
        max_length: Maximum token length; longer tweets are truncated.

    Raises:
        ValueError: If ``root`` does not end in ``.csv`` or ``.xlsx``.
    """

    def __init__(self, root, transform=None, target_transform=None, padding=True, max_length=200):
        # Read the table before loading the (slow) tokenizer so an unsupported path
        # fails immediately with a clear message. Otherwise df would be None and
        # the failure would surface later as an opaque TypeError from len(None).
        path = str(root)
        if path.endswith(".csv"):
            self.df = pd.read_csv(root)
        elif path.endswith(".xlsx"):
            self.df = pd.read_excel(root)
        else:
            raise ValueError(f"Unsupported data file {path!r}: expected a .csv or .xlsx path")

        self.tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-large", use_fast=False)
        self.tokenizer.model_max_length = 512
        self.transform = transform
        self.target_transform = target_transform

        self.length = len(self.df)

        # NOTE(review): normalizeTweet receives the whole column array here, but a
        # single string elsewhere in this notebook. Confirm the local
        # BERTweet.TweetNormalizer accepts both.
        self.tokens = self.tokenizer(normalizeTweet(self.df["tweet"].values), padding=padding, max_length=max_length, truncation=True, return_tensors='pt')
        self.tweets = self.tokens['input_ids']
        self.amasks = self.tokens['attention_mask']
        self.labels = torch.LongTensor(self.df["label"].values)

        if self.transform is not None:
            self.tweets = self.transform(self.tweets)

        if self.target_transform is not None:
            self.labels = self.target_transform(self.labels)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # One sample: (input_ids, attention_mask, label).
        return self.tweets[idx], self.amasks[idx], self.labels[idx]

In [None]:
def get_data(path=None, augment=None, **kwargs):
    """Build a single evaluation dataset, or the (train, validation) pair.

    With ``path`` set, return one dataset. ``"test"`` is shorthand for
    ``data/test.csv``; any other value is used as the file path directly.
    Without ``path``, return ``(trainset, valset)``. The training file is chosen
    by ``augment``, and the validation set is always ``data/val.csv``.

    ``augment`` options:
        ``50/50`` or ``"50/50"``      -> data/train_augment_5050.csv
        ``75/25`` or ``"75/25"``      -> data/train_augment_7525.csv
        ``"ta"``                       -> data/train_ta.csv
        ``"sentiment140"``             -> data/train_sentiment140.csv
        ``"parental"``                 -> data/train_parental.csv
        anything else (incl. ``None``) -> data/train.csv

    Extra keyword arguments are forwarded to ``AlzheimersTweetsDataset``.
    """
    if path:
        if path == "test":
            return AlzheimersTweetsDataset("data/test.csv", **kwargs)

        return AlzheimersTweetsDataset(path, **kwargs)

    # `50/50` and `75/25` are evaluated as arithmetic (1.0 and 3.0), so numeric
    # configs keep matching exactly as before. The string spellings let the ratio
    # be written literally instead of silently falling back to data/train.csv.
    if augment in (50 / 50, "50/50"):
        train_file = "data/train_augment_5050.csv"
    elif augment in (75 / 25, "75/25"):
        train_file = "data/train_augment_7525.csv"
    elif augment == "ta":
        train_file = "data/train_ta.csv"
    elif augment == "sentiment140":
        train_file = "data/train_sentiment140.csv"
    elif augment == "parental":
        train_file = "data/train_parental.csv"
    else:
        train_file = "data/train.csv"

    trainset = AlzheimersTweetsDataset(train_file, **kwargs)
    valset = AlzheimersTweetsDataset("data/val.csv", **kwargs)
    return trainset, valset

def make_loader(dataset, batch_size, *, shuffle=True, num_workers=2):
    """Wrap ``dataset`` in a DataLoader.

    Args:
        dataset: Any map-style dataset.
        batch_size: Samples per batch.
        shuffle: Reshuffle every epoch. Defaults to True, which matches the
            previous fixed behaviour. Pass False for a stable evaluation order.
        num_workers: Loader worker processes; 0 loads in the main process.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        # Pinned host memory only speeds up host-to-GPU copies. Skip it on
        # CPU-only machines, where torch would just warn and ignore it.
        pin_memory=torch.cuda.is_available(),
        num_workers=num_workers,
    )

## Model

In [None]:
def get_model():
    """Load the ``vinai/bertweet-large`` checkpoint as a sequence-classification model.

    NOTE(review): no ``num_labels`` is passed, so the head size comes from the
    transformers default for this config. Confirm it matches the binary labels
    used in this notebook.
    """
    checkpoint = "vinai/bertweet-large"
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return classifier

## Training

In [None]:
def train(model, train_loader, val_loader, optimizer, scheduler, config):
    """Fine-tune ``model`` and return the epoch with the best validation accuracy.

    Each epoch does the following:
    - one pass over ``train_loader``, accumulating gradients over ``config.accum`` batches;
    - one ``scheduler.step()`` and an evaluation on ``val_loader``;
    - logging of accuracies to wandb;
    - a checkpoint saved to ``results/model_<epoch>``.

    Args:
        model: Sequence-classification model, already on ``device``.
        train_loader: Loader yielding ``(input_ids, attention_mask, labels)``.
        val_loader: Loader with the same item layout, used for model selection.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        config: Object exposing ``epochs``, ``accum`` and ``log_interval``.

    Returns:
        Index of the epoch with the highest validation accuracy (earliest on ties),
        or None when ``config.epochs`` is 0.
    """
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    # wandb.watch(model, log="all", log_freq=10)

    best_epoch = None
    best_val_accuracy = -1

    # Run training and track with wandb
    example_ct = 0  # training examples seen so far; used as the wandb step
    batch_ct = 0    # training batches seen so far, across all epochs
    num_batches = len(train_loader)

    for epoch in tqdm(range(config.epochs)):
        model.train()
        train_correct, val_correct = 0, 0
        for batch_idx, (batch, masks, labels) in enumerate(train_loader):
            batch, masks, labels = batch.to(device), masks.to(device), labels.to(device)

            output = model(batch, attention_mask=masks, labels=labels)
            loss = output.loss

            # Scale so gradients summed over config.accum batches average to one batch's worth.
            (loss / config.accum).backward()

            predicted = output.logits.argmax(dim=-1)
            train_correct += (predicted == labels).sum().item()

            example_ct += len(batch)
            batch_ct += 1

            # Step every config.accum batches counted *within this epoch*, and flush a
            # partial accumulation on the epoch's last batch. Comparing the global
            # batch_ct to len(train_loader) only flushed at the end of epoch 0, so
            # leftover gradients leaked into the next epoch when len % accum != 0.
            if (batch_idx + 1) % config.accum == 0 or batch_idx + 1 == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            if batch_ct % (config.log_interval * config.accum) == 0:
                wandb.log({"epoch": epoch, "loss": loss.item()}, step=example_ct)
                print(f"Loss after {str(example_ct).zfill(5)} examples: {loss.item():.3f}")

        # The LR schedule advances once per epoch, not once per optimizer step.
        scheduler.step()

        model.eval()
        with torch.no_grad():
            for batch, masks, labels in val_loader:
                batch, masks, labels = batch.to(device), masks.to(device), labels.to(device)
                # Labels are not passed: only logits are needed, so skip the loss.
                output = model(batch, attention_mask=masks)

                predicted = output.logits.argmax(dim=-1)
                val_correct += (predicted == labels).sum().item()

        train_accuracy = train_correct / len(train_loader.dataset)
        val_accuracy = val_correct / len(val_loader.dataset)

        # Strict '>' keeps the earliest epoch when accuracies tie.
        if val_accuracy > best_val_accuracy:
            best_epoch = epoch
            best_val_accuracy = val_accuracy

        wandb.log({"train_accuracy": train_accuracy, "val_accuracy": val_accuracy}, step=example_ct)
        print(f"Epoch {str(epoch).zfill(2)} Summary: (Train %: {train_accuracy:%}, Val%: {val_accuracy:%})")

        # Checkpoint every epoch so the best one can be reloaded after training.
        # model.save_pretrained(os.path.join(wandb.run.dir, f"model_{epoch}"))
        model.save_pretrained(os.path.join("results", f"model_{epoch}"))

    return best_epoch

## Final Test

In [None]:
def test(config, model, data_dir="test", use_wandb=True, print_str=True):
    """Evaluate ``model`` on a held-out dataset and report its accuracy.

    Args:
        config: ``wandb.config`` when ``use_wandb`` is True. Otherwise a plain dict,
            wrapped in ``DictObject`` for attribute access. It must provide
            ``batch_size`` and ``accum``.
        model: Sequence-classification model, already on ``device``.
        data_dir: ``"test"`` for data/test.csv, or a path to a .csv/.xlsx file.
        use_wandb: Log ``test_accuracy`` to the active wandb run.
        print_str: If True, print sklearn's classification report and return None.
            If False, return ``(accuracy, report_dict)`` instead.
    """
    if not use_wandb:
        # Give plain dicts the same attribute access (config.batch_size) as wandb.config.
        config = DictObject(config)

    model.eval()

    # Named `dataset` so the local does not shadow this function's own name.
    dataset = get_data(data_dir)
    test_loader = make_loader(dataset, batch_size=config.batch_size // config.accum)

    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch, masks, labels in test_loader:
            batch, masks = batch.to(device), masks.to(device)
            # Pass the attention mask exactly as training/validation do. Without it
            # the model also attends to padding tokens and predictions can change.
            output = model(batch, attention_mask=masks)

            predicted = output.logits.argmax(dim=-1)

            y_true.extend(labels.tolist())
            y_pred.extend(predicted.cpu().tolist())

    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

    if use_wandb:
        wandb.log({"test_accuracy": accuracy})

    if print_str:
        print(classification_report(y_true, y_pred))
    else:
        return accuracy, classification_report(y_true, y_pred, output_dict=True)


## Pipeline

In [None]:
def make(config):
    """Create the model, data loaders, optimizer and LR scheduler described by ``config``.

    Each loader step uses ``config.batch_size // config.accum`` samples. Gradient
    accumulation in ``train`` then restores an effective batch of
    ``config.batch_size``, provided ``batch_size`` is divisible by ``accum``.

    ``config.scheduler`` is either falsy, giving a constant LR, or a sequence
    whose first three items are LinearLR's ``start_factor``, ``end_factor`` and
    ``total_iters``. Further items are ignored.

    Returns:
        ``(model, train_loader, val_loader, optimizer, scheduler)``
    """
    step_batch = config.batch_size // config.accum

    train_set, val_set = get_data(augment=config.augment)
    train_loader = make_loader(train_set, batch_size=step_batch)
    val_loader = make_loader(val_set, batch_size=step_batch)

    model = get_model().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    if not config.scheduler:
        # start_factor == end_factor == 1 over 0 iterations: the LR never changes.
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, 1, 1, 0)
    else:
        schedule = config.scheduler
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=schedule[0],
            end_factor=schedule[1],
            total_iters=schedule[2],
        )

    return model, train_loader, val_loader, optimizer, scheduler


In [None]:
def model_pipeline(hyperparameters):
    """Run one wandb-tracked experiment: build, train, reload the best epoch, test.

    Args:
        hyperparameters: Dict with the keys read by ``make``, ``train`` and ``test``:
            ``epochs``, ``batch_size``, ``accum``, ``learning_rate``,
            ``log_interval``, ``augment`` and ``scheduler``.

    Returns:
        The model reloaded from ``results/model_<best_epoch>``.

    Raises:
        ValueError: If ``batch_size`` is not divisible by ``accum``.
    """
    # Use an explicit check rather than `assert`, which is skipped under `python -O`.
    if hyperparameters["batch_size"] % hyperparameters["accum"] != 0:
        raise ValueError(
            f"batch_size ({hyperparameters['batch_size']}) must be divisible by "
            f"accum ({hyperparameters['accum']})"
        )

    # tell wandb to get started
    with wandb.init(project="Alzheimers", config=hyperparameters):
        # access all HPs through wandb.config, so logging matches execution!
        config = wandb.config

        # make the model, data, and optimization problem
        model, train_loader, val_loader, optimizer, scheduler = make(config)
        print(model)

        # and use them to train the model
        best_epoch = train(model, train_loader, val_loader, optimizer, scheduler, config)
        print("Best Epoch:", best_epoch)

        # Reload the best checkpoint. First release the last-epoch weights and the
        # optimizer state, so they are not kept on the GPU alongside the reloaded copy.
        model_cls = type(model)
        del model, optimizer, scheduler
        torch.cuda.empty_cache()
        # model.from_pretrained(os.path.join(wandb.run.dir, f"model_{best_epoch}"))
        model = model_cls.from_pretrained(os.path.join("results", f"model_{best_epoch}")).to(device)
        test(config, model)

    return model

## Run

In [None]:
# Experiment hyperparameters consumed by model_pipeline -> make / train / test.
config = {
    "epochs": 40,            # passes over the training set
    "batch_size": 32,        # effective batch; each loader step uses batch_size // accum samples
    "accum": 1,              # gradient-accumulation steps (must divide batch_size)
    "learning_rate": 1e-5,   # AdamW learning rate
    "log_interval": 4,       # log training loss every log_interval * accum batches
    "augment": "parental",   # training-file selector understood by get_data
    # "scheduler": [1, 0.1, 25, "linear"]
    "scheduler": None        # None -> constant LR; else [start_factor, end_factor, total_iters, ...] for LinearLR (only the first three items are read)
}

In [None]:
# `stop` is not defined in this notebook, so this cell raises NameError. That is
# presumably a deliberate guard to halt "Run All" before the long training run below.
stop

In [None]:
# Train, pick the best epoch on validation, and evaluate it on data/test.csv (logged to wandb).
model = model_pipeline(config)

In [None]:
# Bytes currently allocated by tensors on CUDA device 0.
print(torch.cuda.memory_allocated(0))

In [None]:
# `stop` is not defined in this notebook, so this cell raises NameError. That is
# presumably a deliberate guard to halt "Run All" before the analysis cells below.
stop

## Test Generalization

In [None]:
# Load the fine-tuned checkpoint directly through the Auto class. There is no need
# to instantiate a fresh bertweet-large model first just to call the classmethod.
model = AutoModelForSequenceClassification.from_pretrained("results_sentiment140/model_13/").to(device)
# model = AutoModelForSequenceClassification.from_pretrained("results_mask(1-10 are corrupted)/model_2").to(device)

In [None]:
# Evaluate the loaded model on data/test.csv without wandb; prints a classification report.
test(config, model, use_wandb=False)

In [None]:
# Same evaluation on a separate generalization file (the path is used as given, relative to the working directory).
test(config, model, data_dir="test_generalization2.xlsx", use_wandb=False)

In [None]:
# tes = []
# gen = []

# for i in tqdm(range(40)):
#     model = get_model().from_pretrained(f"results_sentiment140/model_{i}/").to(device)
#     tes.append(test(config, model, use_wandb=False, print_str=False)[0])
#     gen.append(test(config, model, data_dir="test_generalization2.xlsx", use_wandb=False, print_str=False)[0])

# df = pd.DataFrame({"test": tes, "gen": gen})
# df.plot.line(subplots=True)


## Captum

In [None]:
# Slow (Python) BERTweet tokenizer, configured like the one in AlzheimersTweetsDataset.
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-large", use_fast=False)
tokenizer.model_max_length = 512

# Integrated-gradients baseline: a sequence made entirely of the <pad> token.
# pad_token_id reads the id directly instead of relying on encode("<pad>")
# placing it at index 1 after a single leading special token.
PAD_IND = tokenizer.pad_token_id
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

In [None]:
def _forward_logits(input_ids, attention_mask=None):
    """Captum forward function: classifier logits of the global `model`, looked up at call time."""
    return model(input_ids, attention_mask=attention_mask).logits


# Layer Integrated Gradients, attributed at the RoBERTa embedding layer.
lig = LayerIntegratedGradients(_forward_logits, model.roberta.embeddings)

In [None]:
# Captum attribution + visualization helper for batches of tweets.

def interpret_sentences(model, sentences, masks, labels, tokenize=None, target=1):
    """Compute Layer Integrated Gradients attributions and build Captum visualization records.

    Args:
        model: Fine-tuned sequence-classification model exposing ``roberta.embeddings``,
            already on ``device``.
        sentences: Token-id tensor of shape ``(n, seq_len)``, or a list of 1-D token-id
            tensors. At the indices listed in ``tokenize``, list items are raw tweet strings.
        masks: Attention masks matching ``sentences`` (tensor or list). Entries at the
            ``tokenize`` indices are overwritten.
        labels: True class index per sentence (0 = negative, 1 = positive).
        tokenize: Indices of raw-text entries to normalize and tokenize, padded to
            ``tokenizer.model_max_length``. The given lists are modified in place.
        target: Class index whose logit is attributed. Defaults to 1 ("positive").

    Returns:
        A list with one ``visualization.VisualizationDataRecord`` per sentence.
    """
    log_memory()
    classes = ["negative", "positive"]

    for i in tokenize or ():
        tokenized = tokenizer(normalizeTweet(sentences[i]), padding="max_length", truncation=True, return_tensors='pt')
        # The tokenizer returns a batch of one. Drop that axis so the entry stacks
        # with the other 1-D rows; keeping it breaks torch.stack.
        sentences[i] = tokenized['input_ids'][0]
        masks[i] = tokenized['attention_mask'][0]

    if isinstance(sentences, list):
        sentences = torch.stack(sentences)
        masks = torch.stack(masks)

    sentences = sentences.to(device)
    masks = masks.to(device)

    # Decode each token id once, and drop <pad> tokens from the displayed words.
    text = []
    for sentence in sentences.cpu():
        words = (tokenizer.decode(token_id) for token_id in sentence)
        text.append([word for word in words if word != "<pad>"])

    model.zero_grad()
    with torch.no_grad():
        logits = model(sentences, attention_mask=masks).logits.cpu()
    pred_ind = logits.argmax(dim=-1)
    # VisualizationDataRecord expects a probability, not a raw logit.
    probs = logits.softmax(dim=-1)

    # Attribute through *this* `model`, so predictions and attributions always come
    # from the same network rather than from whatever global model is bound.
    layer_ig = LayerIntegratedGradients(
        lambda x, attention_mask=None: model(x, attention_mask=attention_mask).logits,
        model.roberta.embeddings,
    )

    # All-<pad> baseline with the same length as the inputs, rather than a fixed
    # model_max_length that only fits inputs padded to exactly 512 tokens.
    reference_indices = token_reference.generate_reference(sentences.shape[1], device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions, delta = layer_ig.attribute(
        sentences,
        reference_indices,
        target=target,
        additional_forward_args=(masks,),
        n_steps=500,
        return_convergence_delta=True,
        internal_batch_size=24,
    )

    vis_data_records = []
    for i in range(len(sentences)):
        print(f"pred: {classes[pred_ind[i]]} ({probs[i][target]:.2f}), delta: {abs(delta[i])}")

        # Collapse the embedding dimension to one score per token, then L2-normalize.
        attr = attributions[i].sum(dim=-1)
        attr = attr / torch.norm(attr)
        attr = attr.cpu().detach().numpy()

        vis_data_records.append(visualization.VisualizationDataRecord(
                                attr,
                                probs[i][target],
                                classes[pred_ind[i]],
                                classes[labels[i]],
                                classes[target],
                                attr.sum(),
                                text[i],
                                delta[i]))

    return vis_data_records

In [None]:
# Generalization tweets padded/truncated to 512 tokens and split by label for attribution.
# Bound to `gen_set` rather than `test`, so the `test()` evaluation function defined
# earlier in this notebook is not overwritten by a dataset.
gen_set = get_data("test_generalization.csv", padding="max_length", max_length=512)

neg_indices = gen_set.labels == 0
pos_indices = gen_set.labels == 1

neg_tweets = gen_set.tweets[neg_indices]
neg_amasks = gen_set.amasks[neg_indices]
neg_labels = gen_set.labels[neg_indices]

pos_tweets = gen_set.tweets[pos_indices]
pos_amasks = gen_set.amasks[pos_indices]
pos_labels = gen_set.labels[pos_indices]

In [None]:
i = 0  # index of the 16-tweet chunk of positive examples to explain

# Tensor slicing clamps the end index, so the last chunk may hold fewer than 16 tweets.
sentences = pos_tweets[i * 16:(i + 1) * 16]
masks = pos_amasks[i * 16:(i + 1) * 16]
labels = pos_labels[i * 16:(i + 1) * 16]

pos_data_records = interpret_sentences(model, sentences, masks, labels)
_ = visualization.visualize_text(pos_data_records)

In [None]:
i = 0  # index of the 16-tweet chunk of negative examples to explain

# Tensor slicing clamps the end index, so the last chunk may hold fewer than 16 tweets.
sentences = neg_tweets[i * 16:(i + 1) * 16]
masks = neg_amasks[i * 16:(i + 1) * 16]
labels = neg_labels[i * 16:(i + 1) * 16]

neg_data_records = interpret_sentences(model, sentences, masks, labels)
# visualize_text renders the table itself. Binding the returned HTML to `_` (as the
# positive-tweet cell does) keeps Jupyter from rendering it again as the cell output.
_ = visualization.visualize_text(neg_data_records)