# Otterly Obsessed with Semantics!

In [None]:
import random
import json
import shutil
import pandas as pd
import numpy as np
import os
import tqdm
import torch
import torch.nn as nn
from transformers import set_seed
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from datasets import DatasetDict, Dataset
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, accuracy_score
from src.custom_bert_model import TheOtterBertModel
from src.path_utils import get_project_root
from src.classes.label_hierarchy import LabelHierarchy
from src.classes.semeval_dataset.semeval_otter_set import SemevalOtterSet
from src.classes.run_config import RunConfig
from src.classes.semeval_dataset.semeval_dataset import SemevalDataset

# Setup

In [None]:
def load_run_config_from_env() -> RunConfig:
    """Assemble the run configuration from environment variables.

    Every option falls back to a built-in default when its variable is unset.
    Boolean options are switched on only by the exact string ``'True'``.
    ``limit`` is ``None`` when the variable is unset or holds the string
    ``'None'``; any other value is parsed as an integer.

    Returns:
        A ``RunConfig`` populated from the environment.
    """

    def env_flag(name: str, default: str) -> bool:
        # Plain string comparison: anything other than 'True' counts as False.
        return os.getenv(name, default) == 'True'

    raw_limit = os.getenv('limit', None)
    limit = None if raw_limit is None or raw_limit == 'None' else int(raw_limit)

    return RunConfig(
        dataset_style=os.getenv('dataset_style', 'all_lower'),
        model_name=os.getenv('model_name', 'bert-base-cased'),
        use_custom_head=env_flag('use_custom_head', 'True'),
        freeze_base_model=env_flag('freeze_base_model', 'False'),
        use_hierarchy=env_flag('use_hierarchy', 'True'),
        extra_layers=env_flag('extra_layers', 'False'),
        weight_loss=env_flag('weight_loss', 'False'),
        epochs=int(os.getenv('epochs', 10)),
        lr=float(os.getenv('lr', 5e-5)),
        batch_size=int(os.getenv('batch_size', 32)),
        acc_steps=int(os.getenv('acc_steps', 4)),
        seed=int(os.getenv('seed', 42)),
        limit=limit,
    )

# Resolve the configuration once; `cfg` is read as a global by the cells below.
cfg = load_run_config_from_env()
print(f'Config: {cfg}')

In [None]:
# Token length that Llama inputs are padded / truncated to.
MODEL_MAX_LENGTH = 512
# Llama checkpoints take the quantised LoRA code path further down. The name is
# matched case-insensitively so that repositories spelled "Llama-2-..." are
# detected as well.
IS_LLAMA = 'llama' in cfg.model_name.lower()
print(f'Is llama model: {IS_LLAMA}')

In [None]:
# Seed every random number generator in play with the configured seed.
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    set_seed,
):
    seed_fn(cfg.seed)

# Loading Data

In [None]:
# Both dataset flavours take the same constructor arguments, so pick the class
# once and build the three splits from it.
dataset_cls = SemevalOtterSet if cfg.use_hierarchy else SemevalDataset
semeval_train = dataset_cls('train', cfg.dataset_style)
semeval_val = dataset_cls('validation', cfg.dataset_style)
semeval_dev = dataset_cls('dev', cfg.dataset_style)

# All splits must share the same label alphabet, because the training split's
# alphabet is used for every split below.
assert semeval_train.alphabet.labels() == semeval_val.alphabet.labels()
assert semeval_dev.alphabet.labels() == semeval_train.alphabet.labels()

labels = semeval_train.alphabet.labels()
print('Labels: ' + ', '.join(labels))

In [None]:
# NOTE(review): trust_remote_code=True allows the checkpoint repository to run
# its own Python code - confirm that cfg.model_name is a trusted source.
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model_name,
    trust_remote_code=True,
)

if IS_LLAMA:
    # Llama tokenizer: reuse EOS as the padding token and pad on the right.
    tokenizer.padding_side = "right"
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Convert every sample into a (text, binarized labels, meme id) tuple.
# The binarizer's classes are fixed up front, so it is fitted once here rather
# than re-fitted for every single sample; the resulting encoding is identical.
mlb = MultiLabelBinarizer(classes=labels)
mlb.fit([labels])


def _to_tuples(samples):
    """Return (text, 1 x n_labels indicator array, meme id) for each sample."""
    return [(ds.text, mlb.transform([ds.labels]), ds.meme_id) for ds in samples]


tuples_train = _to_tuples(semeval_train.samples)
tuples_val = _to_tuples(semeval_val.samples)
tuples_dev = _to_tuples(semeval_dev.samples)

# Shuffle in a fixed order (train, val, dev) so the seeded RNG is consumed the
# same way on every run.
random.shuffle(tuples_train)
random.shuffle(tuples_val)
random.shuffle(tuples_dev)

# cfg.limit is None when no limit is configured; slicing with None keeps all.
tuples_train = tuples_train[:cfg.limit]
tuples_val = tuples_val[:cfg.limit]
tuples_dev = tuples_dev[:cfg.limit]

print(f'Train tuples: {len(tuples_train)}')
print(f'Valid tuples: {len(tuples_val)}')
print(f'Dev tuples: {len(tuples_dev)}')

In [None]:
def _split_frame(tuples):
    """Build a frame with 'text', 'id' and one indicator column per label."""
    frame = pd.DataFrame(
        [(text, meme_id) for text, _, meme_id in tuples],
        columns=['text', 'id'],
    )
    # Column names come from the training alphabet; the values are read from
    # the single row of each sample's binarized label array.
    for idx in range(len(labels)):
        column = semeval_train.alphabet.id2lbl[idx]
        frame[column] = [binarized[0][idx] for _, binarized, _ in tuples]
    return frame


df_train = _split_frame(tuples_train)
df_val = _split_frame(tuples_val)
df_dev = _split_frame(tuples_dev)

In [None]:
def _attach_label_vector(row):
    """Collect the per-label columns of a row into one float vector."""
    return {"labels": [float(row[name]) for name in labels]}


# One Hugging Face dataset per split, each carrying a combined "labels" column.
dataset_dict = DatasetDict()
for split_name, frame in (('train', df_train), ('valid', df_val), ('dev', df_dev)):
    dataset_dict[split_name] = Dataset.from_pandas(frame).map(_attach_label_vector)

In [None]:
# Ensure we don't lose data: find the longest tokenized text across all splits
# and report it, so that truncation does not go unnoticed.
max_len = max(
    max(len(tokenizer(sample['text'])['input_ids']) for sample in dataset_dict[split])
    for split in ['train', 'valid', 'dev']
)
print(f'Longest sample: {max_len} tokens (MODEL_MAX_LENGTH: {MODEL_MAX_LENGTH})')
if max_len > MODEL_MAX_LENGTH:
    # NOTE(review): MODEL_MAX_LENGTH is only passed explicitly on the Llama path;
    # other models truncate to the tokenizer's own limit - confirm it matches.
    print(f'WARNING: samples longer than {MODEL_MAX_LENGTH} tokens will be truncated!')

In [None]:
# Tokenize Dataset
def preprocess_samples(samples):
    """Tokenize the 'text' field of a batch of samples.

    Llama inputs are truncated and padded to MODEL_MAX_LENGTH and carry no
    token type ids; all other models are only truncated here.
    """
    if not IS_LLAMA:
        return tokenizer(samples['text'], truncation=True)

    return tokenizer(
        samples['text'],
        truncation=True,
        padding='max_length',
        max_length=MODEL_MAX_LENGTH,
        return_token_type_ids=False,
    )

In [None]:
# Collator that pads each batch with the tokenizer's padding settings.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Tokenize all splits in batches.
tokenized_dataset = dataset_dict.map(preprocess_samples, batched=True)

# Quick look at the tokenized training split and its first five texts.
train_split = tokenized_dataset['train']
print(train_split)
print(train_split[:5]['text'])

In [None]:
# Number of leading alphabet entries that are treated as parent labels.
# NOTE(review): assumes the alphabet lists the parent labels first - confirm.
NUM_PARENT_LABELS = 8

lh = LabelHierarchy()
lbl_parents = labels[:NUM_PARENT_LABELS]
print(f'Parents: {lbl_parents}')

# Loading Model

In [None]:
if IS_LLAMA:
    # 4-bit NF4 quantisation with double quantisation; compute in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # LoRA adapter settings for a sequence classification task.
    peft_config = LoraConfig(
        task_type="SEQ_CLS",
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
    )
    print('Created LORA Config')

In [None]:
# Arguments shared by both model flavours: a multi-label classification head
# with one output per label and the label mapping of the training alphabet.
model_kwargs = dict(
    problem_type='multi_label_classification',
    num_labels=len(labels),
    id2label=semeval_train.alphabet.id2lbl,
    label2id=semeval_train.alphabet.lbl2id,
)

if IS_LLAMA:
    text_classifier = AutoModelForSequenceClassification.from_pretrained(
        cfg.model_name,
        use_cache=False,
        quantization_config=bnb_config,
        device_map="auto",
        **model_kwargs,
    )
    # The tokenizer pads with EOS, but the model config has to know the padding
    # id as well: Llama's classification head uses it to locate the last real
    # token and refuses batches larger than one when it is undefined.
    text_classifier.config.pad_token_id = tokenizer.pad_token_id

    # Prepare the quantized model for training and wrap it with LoRA adapters.
    text_classifier = prepare_model_for_kbit_training(text_classifier)
    text_classifier = get_peft_model(text_classifier, peft_config)

else:
    text_classifier = AutoModelForSequenceClassification.from_pretrained(
        cfg.model_name,
        **model_kwargs,
    )

In [None]:
if cfg.use_custom_head:
    # The custom head is only meaningful together with the label hierarchy.
    assert cfg.use_hierarchy, 'Can only use the custom classification head with the label hierarchy!'
    print('Using custom classification head!')

    # Size layout handed to the head; its meaning is defined by
    # TheOtterBertModel, which is not visible in this file.
    head_sizes = [3, 3, 2, 20]
    classification_head = TheOtterBertModel(
        text_classifier.config.hidden_size,
        head_sizes,
        extra_layers=cfg.extra_layers,
    )
    # Swap the model's stock classifier for the custom head.
    text_classifier.classifier = classification_head

# Training!

In [None]:
def compute_metrics(valid_predictions, thresholds: 'np.ndarray | None' = None, from_logits: bool = True):
    """Compute hierarchical precision, recall and F1 over a set of predictions.

    Each predicted and each gold label set is expanded through the label
    hierarchy (``lh.get_parent_labels_flat``) before the two are compared, and
    the counts are accumulated over all samples.

    Args:
        valid_predictions: Pair ``(predictions, gt_labels)``. Both are iterated
            sample by sample; every row holds one entry per label.
        thresholds: One probability cut-off per label, shape ``(len(labels),)``.
            Defaults to 0.2 for every label.
        from_logits: When True, ``predictions`` are logits that are passed
            through a sigmoid and thresholded. When False, they are used as-is
            as binary indicators.

    Returns:
        Dict with hierarchical precision ``'hp'``, recall ``'hr'`` and F1
        ``'hf'``. A score whose denominator is zero is reported as 0.
    """
    predictions, gt_labels = valid_predictions

    # Build the default lazily, so it always matches the current label set.
    if thresholds is None:
        thresholds = np.full(len(labels), 0.2)

    assert thresholds.shape == (len(labels), )

    if from_logits:
        # Sigmoid (not softmax): every label is scored independently.
        pred_sig = 1 / (1 + np.exp(-predictions))
        predictions_binary = (pred_sig > thresholds).astype(float)
    else:
        predictions_binary = predictions

    id2lbl = semeval_train.alphabet.id2lbl

    def expand_through_hierarchy(indicator_row):
        # Collect what the hierarchy reports for every active label in the row.
        expanded = set()
        for idx, is_active in enumerate(indicator_row):
            if is_active:
                node = lh.get_node_by_label(id2lbl[idx])
                expanded.update(lh.get_parent_labels_flat(node))
        return expanded

    tp = fp = fn = 0

    for pred_bin, gold_bin in zip(predictions_binary, gt_labels):
        pred_parents = expand_through_hierarchy(pred_bin)
        gt_parents = expand_through_hierarchy(gold_bin)

        tp += len(pred_parents & gt_parents)
        fp += len(pred_parents - gt_parents)
        fn += len(gt_parents - pred_parents)

    hp = (tp / (tp + fp)) if (tp + fp) > 0 else 0
    # Recall must be guarded by its own denominator; guarding with (tp + fp)
    # raised ZeroDivisionError when there were predictions but no gold labels.
    hr = (tp / (tp + fn)) if (tp + fn) > 0 else 0
    hf = (2 * (hp * hr) / (hp + hr)) if (hp + hr) > 0 else 0

    return {'hp': hp, 'hr': hr, 'hf': hf}

In [None]:
if cfg.weight_loss:
    # Calculate which labels occur how often in order to weight the loss
    label_counts = {lbl: sum(tokenized_dataset['train'][lbl]) for lbl in labels}
    print(f'Label counts: {label_counts}')
    total_label_count = sum(label_counts.values())
    print(f'Total positive examples: {total_label_count}')

    # weight = total / (n_labels * count). A label without any positive example
    # (possible with a small cfg.limit) is treated as if it occurred once, so
    # the division cannot fail.
    label_weights = {
        lbl: total_label_count / (len(labels) * (label_counts[lbl] if label_counts[lbl] > 0 else 1))
        for lbl in labels
    }
    print(f'Label weights: {label_weights}')
    weight_tensor = torch.Tensor(list(label_weights.values()))

    # The loss is moved to the device of the logits inside compute_loss, so no
    # device is hard-coded here.
    loss_fct = nn.BCEWithLogitsLoss(weight=weight_tensor)

    class CustomTrainer(Trainer):
        """Trainer that replaces the model's loss with the weighted BCE loss."""

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            """Return the weighted BCE loss, plus the model outputs on request.

            ``num_items_in_batch`` is accepted because newer Trainer versions
            pass it; it is not used here.
            """
            targets = inputs.get("labels")
            # forward pass
            outputs = model(**inputs)

            logits = outputs.get('logits')

            # Keep the loss weights on the same device as the logits; this
            # works on CPU, on a single GPU and for sharded models alike.
            loss_fct.to(logits.device)

            # compute custom loss
            loss = loss_fct(logits, targets)
            return (loss, outputs) if return_outputs else loss

In [None]:
# Everything belonging to this run is stored below data/model_results/<run id>.
result_dir = os.path.join(get_project_root(), 'data', 'model_results', cfg.identifier())

training_args = TrainingArguments(
    output_dir=result_dir,
    # Optimisation
    num_train_epochs=cfg.epochs,
    learning_rate=cfg.lr,
    weight_decay=0.01,
    per_device_train_batch_size=cfg.batch_size,
    per_device_eval_batch_size=cfg.batch_size,
    gradient_accumulation_steps=cfg.acc_steps,
    # Evaluate and checkpoint once per epoch; reload the checkpoint with the
    # best hierarchical F1 ('hf') at the end.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='hf',
    report_to="none",
)

# The weighted loss needs the custom trainer defined alongside it.
if cfg.weight_loss:
    trainer_class = CustomTrainer
else:
    trainer_class = Trainer

trainer = trainer_class(
    model=text_classifier,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["valid"],
    compute_metrics=compute_metrics,
)

In [None]:
# Run the fine-tuning loop configured by `training_args`.
trainer.train()

# Evaluation

In [None]:
# Score the *best* model (reloaded at the end of training) on the valid set.
# The metric keys are prefixed with 'valid_'.
trainer_valid_results = trainer.evaluate(
    eval_dataset=tokenized_dataset['valid'],
    metric_key_prefix='valid',
)
print(f'Valid: {trainer_valid_results}')

In [None]:
def logits_from_text(text: str):
    """Run the classifier on a single text and return its logits on the CPU.

    Args:
        text: The raw text to classify.

    Returns:
        The logits of the single input as a CPU tensor (first row of the
        model's logits).
    """
    if IS_LLAMA:
        model_input = tokenizer(
            text,
            truncation=True,
            return_token_type_ids=False,
            padding='max_length',
            max_length=MODEL_MAX_LENGTH,
            return_tensors='pt',
        )
    else:
        model_input = tokenizer(text, truncation=True, return_tensors='pt')

    # Send the inputs to wherever the model's first parameters live instead of
    # assuming CUDA; this also works on CPU-only machines and on Apple silicon.
    device = next(text_classifier.parameters()).device

    # Inference only: make sure dropout and similar layers are disabled.
    text_classifier.eval()
    with torch.no_grad():
        model_input = model_input.to(device)
        logits = text_classifier(**model_input).logits[0].to('cpu')
    return logits

In [None]:
# Dump results of the model to a file, get logits for dev / valid data.
# NOTE(review): result_dir is also the Trainer's output_dir, so wiping it here
# removes the saved checkpoints as well - confirm that this is intended.
if os.path.isdir(result_dir):
    shutil.rmtree(result_dir)
os.makedirs(result_dir)
print(f'Saving model results to path: {result_dir}')

# Dump the valid-set results obtained with the default threshold to file
with open(os.path.join(result_dir, 'valid_results_default_th.json'), 'w') as f:
    json.dump(trainer_valid_results, f, indent=4)

split_logits = {}
split_labels = {}

# Classify each example, dump result to file
for split in ['dev', 'valid']:

    curr_logits = []
    curr_labels = []

    # Get logits for each sample, store labels
    for sample in tqdm.tqdm(tokenized_dataset[split], f'Getting logits in split {split}'):
        logits = logits_from_text(sample['text'])
        # Convert explicitly: letting numpy coerce a list of torch tensors is
        # slow, and the float32 cast keeps half / bfloat16 logits convertible.
        curr_logits.append(logits.detach().float().numpy())
        curr_labels.append(sample['labels'])

    split_logits[split] = np.array(curr_logits)
    split_labels[split] = np.array(curr_labels)

    np.save(os.path.join(result_dir, f'{split}_logits'), split_logits[split])
    np.save(os.path.join(result_dir, f'{split}_labels'), split_labels[split])

print(f'Finished dumping results to file!')

## Thresholds

In [None]:
# Make logits easier to access
valid_logits = split_logits['valid']
valid_labels = split_labels['valid']

# Leaves-only variant: the parent labels must never be predicted directly, so
# that they are derived from the predicted leaves via the hierarchy alone.
# Their logits are therefore set to -inf (sigmoid -> 0). A logit of 0 would
# mean a probability of 0.5 and switch every parent on for any threshold < 0.5.
valid_logits_leaves = np.copy(valid_logits)
valid_logits_leaves[:, :len(lbl_parents)] = -np.inf

dev_logits = split_logits['dev']
dev_labels = split_labels['dev']
dev_logits_leaves = np.copy(dev_logits)
dev_logits_leaves[:, :len(lbl_parents)] = -np.inf

In [None]:
# Determine the best single threshold (shared by all labels) on the valid set
possible_thresholds = [0.001, 0.01, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.5, 0.7]

th_scores = {
    candidate: compute_metrics((valid_logits, valid_labels), np.full(len(labels), candidate))
    for candidate in possible_thresholds
}

# Candidates ordered from best to worst hierarchical F1
best_th_sorted = sorted(th_scores, key=lambda th: th_scores[th]['hf'], reverse=True)
best_th = best_th_sorted[0]
best_th_arr = np.full(len(labels), best_th)

ranking = '\n\t'.join(f'{th:.3f} - {th_scores[th]["hf"]:.4f} (hf)' for th in best_th_sorted)
print('Thresholds:\n\t' + ranking)

print(f'Best threshold: ({best_th}) {th_scores[best_th]}')

# Ensure the best threshold is actually better than the one we guessed
assert th_scores[best_th]['hf'] >= trainer_valid_results['valid_hf']

In [None]:
# Find the best th for each class individually, one for accuracy metric and one for f1 metric
individual_ths_acc = {}
individual_ths_f1 = {}

for idx, label in enumerate(labels):

    # Sigmoid turns this label's logits into probabilities.
    lbl_probs = 1 / (1 + np.exp(-valid_logits[:, idx]))
    lbl_gt = valid_labels[:, idx]

    # Binarize once per candidate threshold and reuse it for both metrics.
    preds_by_th = {th: (lbl_probs > th).astype(float) for th in possible_thresholds}

    # scikit-learn metrics take (y_true, y_pred), in that order.
    best_th_acc = max(possible_thresholds, key=lambda th: accuracy_score(lbl_gt, preds_by_th[th]))
    best_th_f1 = max(possible_thresholds, key=lambda th: f1_score(lbl_gt, preds_by_th[th]))

    individual_ths_acc[label] = best_th_acc
    individual_ths_f1[label] = best_th_f1

print(f'Best ths acc: {individual_ths_acc}')
print(f'Best ths f1: {individual_ths_f1}')

# Threshold arrays in label order (the dicts were filled in that order).
individual_th_arr_acc = np.array(list(individual_ths_acc.values()))
individual_th_arr_f1 = np.array(list(individual_ths_f1.values()))

In [None]:
# Sanity check: with the default threshold, the per-sample logits must
# reproduce the trainer's own validation score (up to a small tolerance).
res_default_th = compute_metrics((valid_logits, valid_labels))
assert abs(res_default_th['hf'] - trainer_valid_results['valid_hf']) < 0.01

# Candidate strategies: one shared threshold on all logits, or per-label
# thresholds (tuned for accuracy / f1) on the leaves-only logits.
res_single_th = compute_metrics((valid_logits, valid_labels), thresholds=best_th_arr)
res_multi_th_acc = compute_metrics((valid_logits_leaves, valid_labels), thresholds=individual_th_arr_acc)
res_multi_th_f1 = compute_metrics((valid_logits_leaves, valid_labels), thresholds=individual_th_arr_f1)

score_report = (
    ('Original th', trainer_valid_results["valid_hf"]),
    ('Best single th', res_single_th["hf"]),
    ('Multi th (acc)', res_multi_th_acc["hf"]),
    ('Multi th (f1)', res_multi_th_f1["hf"]),
)
for title, hf_score in score_report:
    print(f'{title}: {hf_score:.4f} (hf)')

best_valid_result = max([res_single_th, res_multi_th_acc, res_multi_th_f1], key=lambda e: e['hf'])


## Dev

In [None]:
# Pick the thresholds / logits for the DEV data according to whichever
# strategy scored best on the valid set (ties resolve in the order below).
best_valid_hf = best_valid_result['hf']

dev_logits_final = dev_logits
one_th_for_all = False

if res_single_th['hf'] == best_valid_hf:
    one_th_for_all = True
    dev_th = best_th_arr
    print('Using single th for dev set')
elif res_multi_th_acc['hf'] == best_valid_hf:
    dev_logits_final = dev_logits_leaves
    dev_th = individual_th_arr_acc
    print('Using acc th for dev set')
else:
    dev_logits_final = dev_logits_leaves
    dev_th = individual_th_arr_f1
    print('Using f1 th for dev set')

# Compute score on dev data with the chosen thresholds
dev_res = compute_metrics((dev_logits_final, dev_labels), thresholds=dev_th)
print(f'Dev results (with best th): {dev_res}')
