In [1]:
import sys

# Add the local `space-model` directory to the import path; presumably this is
# what makes the `space_model` package imported below resolvable — TODO confirm.
sys.path.append('space-model')

In [2]:
import os
import random

from collections import Counter

import numpy as np

import torch
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, jaccard_score
from sklearn.model_selection import train_test_split

from tqdm import tqdm

import matplotlib.pyplot as plt

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from transformers import get_linear_schedule_with_warmup

from datasets import load_dataset, Dataset

from space_model.model import *
from space_model.loss import *

from logger import get_logger, log_continue
from train import train, eval, plot_results, concept_space_to_preds
from utils import free_resources_deco

In [3]:
# Notebook-wide random seed, passed to `seed_everything` below.
SEED = 42

In [4]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and the current CUDA
    device), exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All of these take the seed as their single positional argument.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


# Seed all RNGs once, up front, with the notebook-wide SEED.
seed_everything(seed=SEED)

In [5]:
def on_gpu(f):
    """Decorator that only runs ``f`` when CUDA is available.

    When CUDA is unavailable the wrapped call prints ``'cuda unavailable'`` and
    returns ``None`` without invoking ``f``.

    Both positional and keyword arguments are forwarded, so decorated functions
    with defaulted parameters can be called as ``func(dev_id=1)``.

    Args:
        f: the function to guard.

    Returns:
        A wrapper with the same name/docstring as ``f``.
    """
    import functools  # local import keeps this cell self-contained

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available():
            return f(*args, **kwargs)
        print('cuda unavailable')
        return None

    return wrapper


# The NVML bindings are installed and imported only when CUDA is available.
# NOTE(review): `!pip` runs in a subshell and is unpinned; consider
# `%pip install pynvml==<version>` so the kernel's own environment is targeted.
if torch.cuda.is_available():
    ! pip install pynvml
    # Presumably the source of the nvml* helpers used by print_gpu_utilization.
    from pynvml import *


@on_gpu
def print_gpu_utilization(dev_id=0):
    """Print the memory currently occupied on GPU ``dev_id``, in MB (best effort).

    Any error raised while querying NVML is printed rather than propagated, so
    this diagnostic helper never interrupts the notebook.

    Args:
        dev_id: index of the GPU to query; defaults to 0, matching
            ``free_gpu_cache``.
    """
    try:
        nvmlInit()
        try:
            handle = nvmlDeviceGetHandleByIndex(dev_id)
            info = nvmlDeviceGetMemoryInfo(handle)
            print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")
        finally:
            # Pair every successful nvmlInit with a shutdown so repeated calls
            # do not leave NVML initialised.
            nvmlShutdown()
    except Exception as e:
        print(e)


@on_gpu
def free_gpu_cache(dev_id=0):
    """Empty PyTorch's CUDA cache, reporting GPU memory use before and after.

    Args:
        dev_id: index of the GPU whose memory use is reported.
    """

    def _report(title):
        # Print a heading followed by the current memory use of the device.
        print(title)
        print_gpu_utilization(dev_id)

    _report("Initial GPU Usage")
    torch.cuda.empty_cache()
    _report("GPU Usage after emptying the cache")


def print_summary(result, dev_id=0):
    """Print runtime, throughput and GPU memory use for a finished training run.

    Args:
        result: object exposing a ``metrics`` mapping with the keys
            ``'train_runtime'`` and ``'train_samples_per_second'``.
        dev_id: index of the GPU whose memory use is reported. It is passed on
            explicitly; the previous no-argument call to
            ``print_gpu_utilization`` raised ``TypeError`` on GPU hosts.
    """
    metrics = result.metrics
    print(f"Time: {metrics['train_runtime']:.2f}")
    print(f"Samples/second: {metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization(dev_id)



In [6]:
# Index of the CUDA device to use when a GPU is present.
device_id = 0

# Prefer the selected GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda', device_id)
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [7]:
# Make the selected GPU the current CUDA device; skipped on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.set_device(device)

In [8]:
# Pretrained checkpoint name passed to the tokenizer/model `from_pretrained` calls.
MODEL_NAME = 'bert-base-cased'

# Dataset identifier passed to `load_dataset`.
DATASET_NAME = 'tdavidson/hate_speech_offensive'

# Number of target classes; also used as the number of concept spaces.
NUM_LABELS = 3
# Latent size of each concept space (passed as `n_latent`).
N_LATENT = 128

# The values below are copied into `config`. MAX_SEQ_LEN is used as the
# tokenizer's max_length and BATCH_SIZE as the DataLoader batch size; the
# optimisation settings are presumably consumed by the training helpers and
# are not used directly in this notebook — TODO confirm.
NUM_EPOCHS = 2
BATCH_SIZE = 16
MAX_SEQ_LEN = 512
LEARNING_RATE = 2e-5
MAX_GRAD_NORM = 1000

In [9]:
# Load the dataset named by DATASET_NAME; the trailing expression displays
# its splits and columns.
dataset = load_dataset(DATASET_NAME)
dataset

DatasetDict({
    train: Dataset({
        features: ['count', 'hate_speech_count', 'offensive_language_count', 'neither_count', 'class', 'tweet'],
        num_rows: 24783
    })
})

In [10]:
# Tokenizer matching MODEL_NAME; the trailing expression displays its settings.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer

BertTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [11]:
def get_preds_from_logits(outputs):
    """Turn model outputs into hard class predictions.

    Args:
        outputs: object exposing ``loss`` and ``logits`` attributes; the class
            axis of ``logits`` is the last one.

    Returns:
        Tuple ``(loss, pred, logits)`` where ``pred`` is a CPU ``int64`` tensor
        holding the arg-max class along the last axis, and ``loss`` / ``logits``
        are passed through unchanged.
    """
    logits = outputs.logits
    class_probs = F.softmax(logits, dim=-1).cpu()
    predicted = class_probs.argmax(dim=-1).long()
    return outputs.loss, predicted, logits


def bert_outputs_callback(outputs):
    """Squash the last hidden state into one score per position.

    Averages the final entry of ``outputs.hidden_states`` over its last axis and
    applies a sigmoid, so every returned value lies in (0, 1).

    Args:
        outputs: object exposing a ``hidden_states`` sequence of tensors.

    Returns:
        Tensor with the last axis of the final hidden state reduced away.
    """
    last_layer = outputs.hidden_states[-1]
    mean_activation = last_layer.mean(dim=-1)
    return torch.sigmoid(mean_activation)

In [12]:
@eval
def eval_epoch(model, val_dataloader, config):
    """Run one gradient-free pass over ``val_dataloader`` and collect the outputs.

    Args:
        model: callable as ``model(input_ids=..., attention_mask=..., labels=...)``
            and exposing a ``device`` attribute.
        val_dataloader: iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        config: dict whose ``'preds_from_logits_func'`` entry maps the model
            outputs to a ``(loss, pred, logits)`` tuple.

    Returns:
        dict with
          * ``'loss'``: sum (not mean) of the per-batch losses,
          * ``'preds'`` / ``'labels'``: flat lists over all batches,
          * ``'cs_preds'``: distance-based predictions, left empty unless the
            model outputs expose ``concept_spaces``.
    """
    loss_sum = 0.0
    predicted, distance_predicted, gold = [], [], []

    with torch.no_grad():
        for batch in tqdm(val_dataloader, total=len(val_dataloader)):
            input_ids, attention_mask, targets = (
                batch[key].to(model.device, dtype=torch.long)
                for key in ('input_ids', 'attention_mask', 'label')
            )

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=targets)
            loss, pred, _ = config['preds_from_logits_func'](outputs)

            predicted.extend(pred.detach().tolist())
            gold.extend(targets.cpu().tolist())

            # Distance-based classification, only for models that expose
            # their concept spaces.
            if hasattr(outputs, 'concept_spaces'):
                cs_batch = concept_space_to_preds(outputs.concept_spaces)

                # Multi-label targets: convert class indices to one-hot rows.
                if len(targets.shape) > 1:
                    cs_batch = F.one_hot(torch.tensor(cs_batch), len(outputs.concept_spaces))
                distance_predicted += cs_batch

            loss_sum += loss.item()

    return {
        'loss': loss_sum,
        'preds': predicted,
        'labels': gold,
        'cs_preds': distance_predicted,
    }


In [13]:
# Run configuration for the baseline BERT evaluation.
# In this notebook the keys read directly are 'experiment_name' (model/log/results
# paths), 'max_seq_len' (tokenizer), 'batch_size' (DataLoader) and
# 'preds_from_logits_func' (eval loop); the rest is presumably consumed by the
# imported training helpers — TODO confirm.
config = {
    'experiment_name': 'no_rationales',
    'log_terminal': True,

    'dataset_name': DATASET_NAME,
    'model_name': MODEL_NAME,

    'num_labels': NUM_LABELS,
    'num_epochs': NUM_EPOCHS,
    'iterations': 1,

    'max_seq_len': MAX_SEQ_LEN,
    'batch_size': BATCH_SIZE,
    'lr': LEARNING_RATE,
    'fp16': False,
    'max_grad_norm': MAX_GRAD_NORM,
    'weight_decay': 0.01,
    'num_warmup_steps': 0,
    'gradient_accumulation_steps': 1,

    'threshold': 0.1,

    # funcs:
    'preds_from_logits_func': get_preds_from_logits,
    'model_outputs_callback': bert_outputs_callback,
}

In [14]:
# File stem shared by the baseline checkpoint, log and results files
# ("/" in the model name is replaced so the stem is a valid file name).
base_name = f'hatexplain-{MODEL_NAME.replace("/", "_")}-{NUM_EPOCHS}'
base_name

'hatexplain-bert-base-cased-2'

In [15]:
class BERTOutputs:
    """Minimal output container exposing ``loss``, ``logits`` and ``hidden_states``."""

    def __init__(self, loss, logits, hidden_states):
        self.loss = loss
        self.logits = logits
        self.hidden_states = hidden_states


class BERTwithRationalLossForSequenceClassification(torch.nn.Module):
    """Encoder plus a linear head that classifies from the first token's state.

    Args:
        model: encoder whose forward output exposes ``hidden_states`` and whose
            ``config`` provides ``hidden_size`` and ``num_labels``.
        rational_weight: stored on the instance; not used by ``forward``.
        ce_weight: multiplier applied to the cross-entropy loss.
    """

    def __init__(self, model, rational_weight=0.5, ce_weight=1.0):
        super(BERTwithRationalLossForSequenceClassification, self).__init__()
        self.model = model

        self.ce_weight = ce_weight
        self.rational_weight = rational_weight

        self.classifier = torch.nn.Linear(model.config.hidden_size, model.config.num_labels)

        # Define `device` up front so code that reads `model.device` (such as
        # the eval loop) works even before `.to()` has been called.
        self.device = getattr(model, 'device', torch.device('cpu'))

    def to(self, device):
        """Move all parameters to ``device`` and remember it as ``self.device``."""
        # `self.model` is a registered sub-module, so this moves it as well.
        super().to(device)
        self.device = device
        return self

    def forward(self, input_ids, attention_mask, labels=None):
        """Classify each sequence from the last layer's first-token state.

        Args:
            input_ids: token ids forwarded to the encoder.
            attention_mask: attention mask forwarded to the encoder.
            labels: optional class indices; when omitted no loss is computed.

        Returns:
            ``BERTOutputs`` with ``loss`` (``ce_weight`` * cross-entropy, or
            ``None`` without labels), ``logits`` and the encoder's
            ``hidden_states``.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        hidden_states = outputs.hidden_states
        # Slice the first token *before* the linear layer so the head is not
        # evaluated on every position only to discard all but one.
        first_token_state = hidden_states[-1][:, 0, :]  # (B, n_embed)

        logits = self.classifier(first_token_state)

        loss = None
        if labels is not None:
            loss = self.ce_weight * F.cross_entropy(logits, labels)

        return BERTOutputs(loss, logits, hidden_states)

In [16]:
# Load the checkpoint onto the CPU first so it can be restored on a host that
# lacks the GPU it was saved from; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(
    f'models/{config["experiment_name"]}/{base_name}.bin',
    map_location='cpu',
)
# output_hidden_states=True is required because the wrapper reads
# `outputs.hidden_states`.
body_bert = AutoModel.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS, output_hidden_states=True)
raw_model = BERTwithRationalLossForSequenceClassification(body_bert)
raw_model.load_state_dict(state_dict)
raw_model

BERTwithRationalLossForSequenceClassification(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((

In [17]:
# Move the baseline model to the selected device (this also records it as
# `raw_model.device`, which the eval loop reads).
raw_model.to(device)

BERTwithRationalLossForSequenceClassification(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((

In [16]:
# Logger for this experiment; presumably writes under logs/<experiment_name>/
# using `base_name` as the file stem — TODO confirm against `get_logger`.
log = get_logger(f'logs/{config["experiment_name"]}', base_name)

In [17]:
def _tokenize_batch(batch):
    """Tokenize a batch of tweets, truncating/padding to ``config['max_seq_len']``."""
    return tokenizer(
        batch['tweet'],
        truncation=True,
        padding='max_length',
        max_length=config['max_seq_len'],
        return_tensors='pt',
    )


# Tokenize every split, then expose the target column under the name 'label'.
tokenized_dataset = dataset.map(_tokenize_batch, batched=True)
tokenized_dataset = tokenized_dataset.rename_columns({'class': 'label'})
# Only these columns are returned (as torch tensors) when batches are read.
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['count', 'hate_speech_count', 'offensive_language_count', 'neither_count', 'label', 'tweet', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 24783
    })
})

In [18]:
# Evaluation batches in dataset order (no shuffling).
# NOTE(review): this iterates the 'train' split — the only split shown in the
# DatasetDict output above — confirm that evaluating on it is intended.
test_dataloader = torch.utils.data.DataLoader(tokenized_dataset['train'], batch_size=config['batch_size'])

In [19]:
@log_continue
def eval_results(log, results_path, model, val_dataloader, config):
    """Evaluate ``model`` and both log and persist the resulting metrics.

    Computes mean loss, accuracy and macro F1/precision/recall from the
    predictions returned by ``eval_epoch``. When distance-based ("CS")
    predictions are available, accuracy and macro F1 are reported for them too.

    Args:
        log: logger exposing ``info``.
        results_path: file stem; metrics are written to
            ``results/<experiment_name>/<results_path>_eval.txt``.
        model: model handed to ``eval_epoch``.
        val_dataloader: batches to evaluate; its length is used to average the
            summed loss.
        config: run configuration; ``'experiment_name'`` selects the results
            directory.
    """
    # Named differently from the function itself to avoid shadowing it.
    epoch_outputs = eval_epoch(model, val_dataloader, config)

    val_labels = epoch_outputs['labels']
    val_preds = epoch_outputs['preds']
    cs_val_preds = epoch_outputs['cs_preds']
    has_cs_preds = len(cs_val_preds) != 0

    cs_val_acc = cs_val_f1 = 'N/A'
    if has_cs_preds:
        cs_val_acc = accuracy_score(val_labels, cs_val_preds)
        cs_val_f1 = f1_score(val_labels, cs_val_preds, average='macro')

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_precision = precision_score(val_labels, val_preds, average='macro')
    val_recall = recall_score(val_labels, val_preds, average='macro')
    mean_val_loss = epoch_outputs['loss'] / len(val_dataloader)

    # (label, value, log it?) — CS rows are logged only when available, but are
    # always written to the results file (as "N/A" when missing).
    report = [
        ('Val loss', mean_val_loss, True),
        ('Val acc', val_acc, True),
        ('CS Val acc', cs_val_acc, has_cs_preds),
        ('Val f1', val_f1, True),
        # Previously written to the file under the wrong label "CS Val acc".
        ('CS Val f1', cs_val_f1, has_cs_preds),
        ('Val precision', val_precision, True),
        ('Val recall', val_recall, True),
    ]

    for label, value, should_log in report:
        if should_log:
            log.info(f'{label}: {value}')

    results_dir = f'results/{config["experiment_name"]}'
    os.makedirs(results_dir, exist_ok=True)

    with open(f'{results_dir}/{results_path}_eval.txt', 'w') as f:
        f.writelines(f'{label}: {value}\n' for label, value, _ in report)

In [22]:
# Evaluate the baseline BERT classifier; metrics are logged and written under
# results/<experiment_name>/.
# NOTE(review): the recorded run shows accuracy of about 0.18 on this 3-class
# dataset — worth checking that the checkpoint's label order matches the
# dataset's `class` column.
eval_results(log, base_name, raw_model, test_dataloader, config)

100%|██████████| 1549/1549 [03:47<00:00,  6.81it/s]
  _warn_prf(average, modifier, msg_start, len(result))
[36m2024-04-12 00:13:13,212 - default.terminal - INFO - Val loss: 6.449129791967595[0m[0m
[36m2024-04-12 00:13:13,213 - default.terminal - INFO - Val acc: 0.17822701045071218[0m[0m
[36m2024-04-12 00:13:13,214 - default.terminal - INFO - Val f1: 0.19654676442690336[0m[0m
[36m2024-04-12 00:13:13,214 - default.terminal - INFO - Val precision: 0.17540648957803487[0m[0m
[36m2024-04-12 00:13:13,215 - default.terminal - INFO - Val recall: 0.4102922459876579[0m[0m


# Space Model

In [20]:
class SpaceModelWithRationalLossForSequenceClassification(torch.nn.Module):
    """Base encoder -> ``SpaceModel`` projection -> linear classification head.

    Args:
        base_model: encoder whose output exposes ``last_hidden_state`` and which
            has a ``device`` attribute.
        n_embed: embedding size passed to ``SpaceModel``.
        n_latent: latent size of each concept space.
        n_concept_spaces: number of concept spaces; also the number of classes
            produced by the head.
        l1: weight of the inter-space loss term.
        l2: weight of the intra-space loss term.
        ce_w: weight of the cross-entropy term.
        rational_weight: stored on the instance; not used by ``forward``.
        fine_tune: when true, the parameters of ``base_model`` are frozen.
    """

    def __init__(
            self,
            base_model,
            n_embed=3,
            n_latent=3,
            n_concept_spaces=2,
            l1=1e-3,
            l2=1e-4,
            ce_w=1.0,
            rational_weight=0.5,
            fine_tune=True
    ):
        super().__init__()

        # NOTE(review): despite the flag's name, fine_tune=True *freezes* the
        # base model; presumably it means "train only the new layers" — confirm.
        if fine_tune:
            for param in base_model.parameters():
                param.requires_grad_(False)

        self.device = base_model.device
        self.base_model = base_model

        # Loss weights.
        self.l1, self.l2 = l1, l2
        self.ce_w = ce_w
        self.rational_weight = rational_weight

        self.space_model = SpaceModel(n_embed, n_latent, n_concept_spaces, output_concept_spaces=True)
        self.classifier = torch.nn.Linear(n_concept_spaces * n_latent, n_concept_spaces)

    def to(self, device):
        """Record ``device`` as ``self.device`` and move all parameters to it."""
        self.device = device
        super().to(device)
        return self

    def forward(self, input_ids, attention_mask, labels=None, hard_rationales=None):
        """Classify the input and, when labels are given, compute the loss.

        ``hard_rationales`` is accepted but not used.

        Returns:
            ``SpaceModelForSequenceClassificationOutput`` built from the loss
            (``0.0`` without labels), the logits, and the concept spaces / raw
            concept spaces produced by the space model.
        """
        encoder_out = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = encoder_out.last_hidden_state  # (B, max_seq_len, 768)

        space_out = self.space_model(token_embeddings)
        logits = self.classifier(space_out.logits)

        if labels is None:
            loss = 0.0
        else:
            ce_term = self.ce_w * F.cross_entropy(logits, labels)
            space_terms = (
                self.l1 * losses.inter_space_loss(space_out.concept_spaces, labels)
                + self.l2 * losses.intra_space_loss(space_out.concept_spaces)
            )
            loss = ce_term + space_terms

        return SpaceModelForSequenceClassificationOutput(
            loss, logits, space_out.concept_spaces, space_out.raw_concept_spaces
        )

In [21]:
def space_outputs_callback(outputs):
    """Return sigmoid scores from the concept space of each predicted class.

    Every raw concept space is averaged over axis 2, the results are stacked
    batch-first, and for each example the row belonging to its predicted class
    is selected and passed through a sigmoid.

    Args:
        outputs: model outputs exposing ``loss``, ``logits`` and
            ``raw_concept_spaces``.
    """
    _, preds, _ = get_preds_from_logits(outputs)

    reduced_spaces = [space.mean(2) for space in outputs.raw_concept_spaces]
    # (B, n_concept_spaces, max_seq_len) per the original annotation.
    batch_first = torch.stack(reduced_spaces).permute(1, 0, 2)

    batch_index = torch.arange(batch_first.size(0))
    return F.sigmoid(batch_first[batch_index, preds])

In [22]:
# Run configuration for the space-model evaluation. This rebinds `config`:
# compared with the baseline configuration it adds the loss weights
# 'l1' / 'l2' / 'ce_w' and swaps 'model_outputs_callback'.
config = {
    'experiment_name': 'no_rationales',
    'log_terminal': True,

    'dataset_name': DATASET_NAME,
    'model_name': MODEL_NAME,

    'num_labels': NUM_LABELS,
    'num_epochs': NUM_EPOCHS,
    'iterations': 1,

    'max_seq_len': MAX_SEQ_LEN,
    'batch_size': BATCH_SIZE,
    'lr': LEARNING_RATE,
    'fp16': False,
    'max_grad_norm': MAX_GRAD_NORM,
    'weight_decay': 0.01,
    'num_warmup_steps': 0,
    'gradient_accumulation_steps': 1,

    # Weights of the inter-space, intra-space and cross-entropy loss terms.
    'l1': 0.1,
    'l2': 1e-5,
    'ce_w': 1.0,

    'threshold': 0.1,

    # funcs:
    'preds_from_logits_func': get_preds_from_logits,
    'model_outputs_callback': space_outputs_callback,
}

In [23]:
# File stem for the space-model checkpoint and results; encodes the model name,
# latent size and epoch count.
space_name = f'hatexplain_space-{MODEL_NAME.replace("/", "_")}-({N_LATENT})_{NUM_EPOCHS}'

In [24]:
# Fresh pretrained encoder that the space model wraps; its weights are
# overwritten when the space-model checkpoint is loaded below.
base_model = AutoModel.from_pretrained(MODEL_NAME)
base_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [25]:
# Load the checkpoint onto the CPU first so it can be restored on a host that
# lacks the GPU it was saved from; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(
    f'models/{config["experiment_name"]}/{space_name}.bin',
    map_location='cpu',
)
space_model = SpaceModelWithRationalLossForSequenceClassification(
    base_model,
    n_embed=768,
    n_latent=N_LATENT,
    n_concept_spaces=NUM_LABELS,
    l1=config['l1'],
    l2=config['l2'],
    ce_w=config['ce_w'],
    rational_weight=0.5,
    # With this flag false the constructor leaves the base model unfrozen;
    # gradients are not used during evaluation anyway.
    fine_tune=False
)
space_model.load_state_dict(state_dict)
space_model

SpaceModelWithRationalLossForSequenceClassification(
  (base_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): 

In [26]:
# Move the space model to the selected device (this also records it as
# `space_model.device`, which the eval loop reads).
space_model.to(device)

SpaceModelWithRationalLossForSequenceClassification(
  (base_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): 

In [27]:
# Evaluate the space model; because its outputs expose `concept_spaces`, the
# distance-based "CS" metrics are reported as well.
eval_results(log, space_name, space_model, test_dataloader, config)

100%|██████████| 1549/1549 [03:49<00:00,  6.75it/s]
  _warn_prf(average, modifier, msg_start, len(result))
[36m2024-04-12 00:34:03,316 - default.terminal - INFO - Val loss: 4.442464914175293[0m[0m
[36m2024-04-12 00:34:03,317 - default.terminal - INFO - Val acc: 0.17620949844651576[0m[0m
[36m2024-04-12 00:34:03,318 - default.terminal - INFO - CS Val acc: 0.1764112496469354[0m[0m
[36m2024-04-12 00:34:03,318 - default.terminal - INFO - Val f1: 0.1951944485029856[0m[0m
[36m2024-04-12 00:34:03,319 - default.terminal - INFO - CS Val f1: 0.19002573598072306[0m[0m
[36m2024-04-12 00:34:03,320 - default.terminal - INFO - Val precision: 0.15714653516595725[0m[0m
[36m2024-04-12 00:34:03,321 - default.terminal - INFO - Val recall: 0.414399300755294[0m[0m
