In [None]:
# Show the NVIDIA GPUs visible to this kernel (the binary classifiers below are spread across CUDA devices).
!nvidia-smi

In [None]:
from IPython.display import display
from typing import Optional, Tuple, Union, Dict
import os
import functools
from pathlib import Path
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    AutoModelForSequenceClassification,
    PreTrainedModel,
    AutoConfig
)
import datasets

from metrics import compute_metrics

In [None]:
# Backbone to fine-tune: a directory name under `models_dir` (defined in the next cell).
# Other candidate backbones (commented out):
# model_name = 'Linq-AI-Research/Linq-Embed-Mistral'
# model_name = 'dunzhang/stella_en_1.5B_v5'
# model_name = str(models_dir / "stella_en_400M_v5") #'dunzhang/stella_en_400M_v5'
# model_name = 'HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1'
# model_name = 'textdetox/xlmr-large-toxicity-classifier'
# model_name = 'JungleLee/bert-toxic-comment-classification'
model_name = "deberta-v3-large"
# model_name = 'Alibaba-NLP/gte-base-en-v1.5'
# model_name = 'avsolatorio/GIST-small-Embedding-v0'

# Sub-directory appended to models_dir / model_name to build `model_dirname`.
# NOTE(review): looks like a Hugging Face hub cache snapshot hash — confirm it matches model_name.
model_appendix = "snapshots/64a8c8eab3e352a784c658aef62be1662607476f"

In [None]:
# Run configuration: the model name doubles as the run "version" that namespaces all outputs.
version = model_name
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Input locations (relative to the notebook's working directory).
models_dir = Path("../hub")
datasets_dir = Path("../datasets")
model_dirname = str(models_dir / model_name / model_appendix)

# Output locations, one sub-directory per version.
results_dir = Path("../results/multimodel_finetuning") / version
train_log_dir = Path("../training_logs/multimodel_finetuning") / version
save_dir = models_dir / "model_finetuning" / version

In [None]:
def tokenize_examples(examples, tokenizer, classes):
    """Tokenize a single issue/answer pair and attach its label vector.

    Args:
        examples: one dataset row (used with ``batched=False``) holding ``'issue'``,
            ``'post_text'`` and one column per entry of ``classes``.
        tokenizer: callable tokenizer; called with truncation to 700 tokens.
        classes: label column names, in the order the label vector should follow.

    Returns:
        The tokenizer output with an added ``'labels'`` list (one value per class).
    """
    # The last character of the issue is dropped and replaced by a period.
    issue = examples['issue'][:-1]
    answer = examples['post_text']
    prompt = f"Issue: {issue}.\nAnswer: {answer}"

    encoded = tokenizer(prompt, truncation=True, max_length=700, padding=True)
    encoded['labels'] = [examples[column] for column in classes]
    return encoded


# Custom batch preprocessor used as the Trainer's data_collator.
def collate_fn(batch, pad_token_id):
    """Pad a list of tokenized examples into a single batch.

    Only ``input_ids``, ``attention_mask`` and ``labels`` are kept; any other keys in
    the examples are dropped.

    Args:
        batch: list of dicts holding 1-D tensors under the three keys above.
        pad_token_id: value used to right-pad ``input_ids``.

    Returns:
        Dict with padded ``input_ids`` / ``attention_mask`` (batch-first; mask padded
        with 0) and the stacked ``labels`` cast to float.
    """
    def column(key):
        return [example[key] for example in batch]

    pad = torch.nn.utils.rnn.pad_sequence
    return {
        'input_ids': pad(column('input_ids'), batch_first=True, padding_value=pad_token_id),
        'attention_mask': pad(column('attention_mask'), batch_first=True, padding_value=0),
        'labels': torch.stack(column('labels')).type(torch.float),
    }


# Custom Trainer that accepts per-class label weights and computes a multi-label BCE loss.
class CustomTrainer(Trainer):
    """``Trainer`` whose loss is ``binary_cross_entropy_with_logits`` over all label columns.

    Args:
        label_weights: tensor passed as ``pos_weight`` to the BCE loss (one weight per
            class), or ``None`` for an unweighted loss.
        **kwargs: forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, label_weights, **kwargs):
        super().__init__(**kwargs)
        self.label_weights = label_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the (optionally weighted) multi-label BCE loss for one batch.

        The parameter order matches ``Trainer.compute_loss`` so positional calls bind
        correctly. ``num_items_in_batch`` is accepted for compatibility but unused: the
        loss uses BCE's default mean reduction over all elements.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # Shallow copy so popping "labels" does not mutate the caller's batch dict.
        inputs = dict(inputs)
        labels = inputs.pop("labels")

        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Keep targets and weights on the logits' device, which need not be args.device.
        pos_weight = None if self.label_weights is None else self.label_weights.to(logits.device)
        loss = F.binary_cross_entropy_with_logits(
            logits,
            labels.to(device=logits.device, dtype=torch.float32),
            pos_weight=pos_weight,
        )
        return (loss, outputs) if return_outputs else loss


In [None]:
# Load the appropriateness corpus from its local dataset directory
# (the train / validation / test splits are used further below).
ds = datasets.load_dataset(str(datasets_dir / 'appropriateness-corpus'))

# Fine-grained label columns: one binary classifier is trained per entry.
classes = [
    'Excessive Intensity',
    'Emotional Deception',
    'Missing Seriousness',
    'Missing Openness',
    'Unclear Meaning',
    'Missing Relevance',
    'Confusing Reasoning',
    'Detrimental Orthography',
    'Reason Unclassified'
]

# Full label hierarchy, indexed by position:
#   0-8   fine-grained classes (same order as `classes`)
#   9-12  dimensions (layer 1), each the union of some fine-grained classes
#   13    overall inappropriateness (layer 0), the union of the dimensions
all_classes = [
    *classes,
    'Toxic Emotions',  # 1
    'Missing Commitment',  # 1
    'Missing Intelligibility',  # 1
    'Other Reasons',  # 1
    'Inappropriateness',  # 0
]

# all_classes index -> indices of the classes it aggregates.
# Fine-grained classes map to themselves; the composites list their children.
id_mappings = {i: [i] for i in range(len(classes))}
id_mappings.update({
    9: [0, 1],
    10: [2, 3],
    11: [4, 5, 6],
    12: [7, 8],
    13: [9, 10, 11, 12],
})
assert sorted(id_mappings) == list(range(len(all_classes))), "every class needs a mapping"

class2id = {name: idx for idx, name in enumerate(all_classes)}
id2class = {idx: name for name, idx in class2id.items()}

In [None]:
# Tokenize every split (might not run on a compute node, hence the on-disk cache below).
tokenizer = AutoTokenizer.from_pretrained(model_dirname)
# Only fall back to EOS when the tokenizer has no pad token; never overwrite eos_token
# (the previous chained assignment replaced eos_token with the pad token and raised
# KeyError for tokenizers that define no pad token).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenized_ds = ds.map(functools.partial(tokenize_examples, tokenizer=tokenizer, classes=classes), batched=False)

In [None]:
# Cache the tokenized corpus and the pad token id, so later cells can run without the tokenizer.
tokenized_dir = datasets_dir / "tokenized-appropriateness-corpus" / model_name
tokenized_ds.save_to_disk(str(tokenized_dir))
(tokenized_dir / "pad_token_id").write_text(str(tokenizer.pad_token_id))

In [None]:
# Reload the cached tokenized corpus and the pad token id written alongside it.
tokenized_dir = datasets_dir / "tokenized-appropriateness-corpus" / model_name
tokenized_ds = datasets.load_from_disk(str(tokenized_dir))
pad_token_id = int((tokenized_dir / "pad_token_id").read_text())
display(tokenized_ds)
print(f"{pad_token_id=}")
tokenized_ds = tokenized_ds.with_format('torch')

# All-ones pos_weight: every class is weighted equally (equivalent to unweighted BCE).
# (The previously loaded, never-used `labels` column has been dropped.)
label_weights = torch.ones(len(classes))
label_weights

In [None]:
# Shuffle every split with a fixed seed so the example order is reproducible across runs.
SHUFFLE_SEED = 42
tokenized_ds = tokenized_ds.shuffle(seed=SHUFFLE_SEED)

In [None]:
# Sanity check on the train labels — presumably (num_train_examples, len(classes)); confirm.
tokenized_ds['train']['labels'].shape

In [None]:
class MultiBinaryClassifier(nn.Module):
    """One-vs-rest multi-label classifier built from independent single-logit models.

    Each of the ``num_classes`` sub-models is an ``AutoModelForSequenceClassification``
    loaded with ``num_labels=1`` (optionally LoRA-wrapped). Their logits are concatenated
    into a ``(batch_size, num_classes)`` tensor. When CUDA is available, the sub-models
    are placed round-robin on the visible CUDA devices.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int = 9,
        quantization_config=None,
        lora_config=None,
    ):
        """Load ``num_classes`` copies of ``model_name`` as single-logit classifiers.

        Args:
            model_name: local path of the pretrained backbone (``local_files_only``).
            num_classes: number of binary classifiers / output logits.
            quantization_config: optional quantization config forwarded to ``from_pretrained``.
            lora_config: optional PEFT config. When given, every backbone is passed through
                ``prepare_model_for_kbit_training`` and wrapped with ``get_peft_model``.
                When ``None``, the plain backbones are used.
        """
        super().__init__()
        self.num_classes = num_classes
        self.config = AutoConfig.from_pretrained(model_name, local_files_only=True)

        # Initialize binary classifiers, spread round-robin over the CUDA devices.
        models = [
            AutoModelForSequenceClassification.from_pretrained(
                model_name,
                num_labels=1,
                trust_remote_code=True,
                quantization_config=quantization_config,
                device_map=(
                    f"cuda:{i % torch.cuda.device_count()}"
                    if torch.cuda.is_available()
                    else "cpu"
                ),
                local_files_only=True,
            )
            for i in range(num_classes)
        ]
        if lora_config is not None:
            # Without a LoRA config there is nothing to wrap (get_peft_model needs a config).
            models = [prepare_model_for_kbit_training(m) for m in models]
            models = [get_peft_model(m, lora_config) for m in models]
        # Plain list kept for backward compatibility; the ModuleList is what registers
        # the sub-models as children (parameters, state_dict, .to()).
        self.models = models
        self.binary_classifiers = nn.ModuleList(models)

    @staticmethod
    def _module_device(module: nn.Module) -> Optional[torch.device]:
        """Device of ``module``'s first parameter, or ``None`` if it has no parameters."""
        first = next(module.parameters(), None)
        return None if first is None else first.device

    @staticmethod
    def _to_device(
        tensor: Optional[torch.Tensor], device: Optional[torch.device]
    ) -> Optional[torch.Tensor]:
        """Move ``tensor`` to ``device``; ``None`` tensors or devices pass through unchanged."""
        if tensor is None or device is None:
            return tensor
        return tensor.to(device)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, Dict]:
        """Run every binary classifier and optionally compute the mean per-class BCE loss.

        Args:
            input_ids: token ids forwarded to every sub-model (moved to its device).
            attention_mask: attention mask forwarded alongside ``input_ids``.
            labels: either a multi-hot ``(batch_size, num_classes)`` tensor, where column
                ``i`` is the target of classifier ``i``, or a tensor of class indices
                (one-vs-rest targets ``labels == i``).
            return_dict: return ``{"loss", "logits"}`` when True, else a tuple.

        Returns:
            Logits of shape ``(batch_size, num_classes)`` on ``input_ids``' device (or
            the first classifier's output device when ``input_ids`` is None). The loss is
            the mean over classes of each classifier's BCE loss, or ``None`` without labels.
        """
        # Get logits from all binary classifiers; each may live on a different device.
        all_logits = []
        for classifier in self.binary_classifiers:
            clf_device = self._module_device(classifier)
            outputs = classifier(
                input_ids=self._to_device(input_ids, clf_device),
                attention_mask=self._to_device(attention_mask, clf_device),
                return_dict=True,
            )
            all_logits.append(outputs.logits)

        # Gather on one device before concatenating; each tensor is presumably
        # (batch_size, 1) since the sub-models use num_labels=1.
        out_device = input_ids.device if input_ids is not None else all_logits[0].device
        logits = torch.cat([t.to(out_device) for t in all_logits], dim=1)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            multi_hot = labels.dim() == 2 and labels.shape[1] == self.num_classes
            if not multi_hot:
                labels = labels.reshape(-1)
            per_class_losses = []
            for i in range(self.num_classes):
                target = labels[:, i].float() if multi_hot else (labels == i).float()
                per_class_losses.append(
                    F.binary_cross_entropy_with_logits(logits[:, i], target)
                )
            loss = torch.stack(per_class_losses).mean()

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return {
            "loss": loss,
            "logits": logits,
        }

    def get_predicted_class(self, logits: torch.Tensor) -> torch.Tensor:
        """Index of the highest-probability class per row of ``logits``."""
        # Apply sigmoid and get the class with highest probability
        probabilities = torch.sigmoid(logits)
        return torch.argmax(probabilities, dim=1)

    def save_pretrained(self, save_directory: str):
        """Save each classifier to ``save_directory/classifier_<i>`` plus the backbone config."""
        os.makedirs(save_directory, exist_ok=True)
        # Save each binary classifier
        for i, classifier in enumerate(self.binary_classifiers):
            classifier_dir = os.path.join(save_directory, f"classifier_{i}")
            classifier.save_pretrained(classifier_dir)

        # Save config
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory: str, **kwargs):
        """Rebuild a model written by ``save_pretrained``.

        ``load_directory`` holds only the config and the ``classifier_<i>`` sub-directories
        (no backbone weights), so ``__init__`` is bypassed. Each classifier is loaded
        directly from its sub-directory. ``kwargs`` are accepted for API symmetry but are
        currently ignored.
        """
        num_classes = sum(
            1
            for d in os.listdir(load_directory)
            if d.startswith("classifier_") and os.path.isdir(os.path.join(load_directory, d))
        )

        # Calling cls(load_directory) would try to load num_classes backbones from a
        # directory without weights (and without a LoRA config), then discard them anyway.
        model = cls.__new__(cls)
        nn.Module.__init__(model)
        model.num_classes = num_classes
        model.config = AutoConfig.from_pretrained(load_directory, local_files_only=True)

        # NOTE(review): each classifier_<i> holds what the sub-model's save_pretrained wrote
        # (a PEFT adapter when LoRA was used) — confirm it reloads correctly via this call.
        classifiers = [
            AutoModelForSequenceClassification.from_pretrained(
                os.path.join(load_directory, f"classifier_{i}"), local_files_only=True
            )
            for i in range(num_classes)
        ]
        model.models = classifiers
        model.binary_classifiers = nn.ModuleList(classifiers)
        return model

In [None]:
# 4-bit quantization is disabled: bitsandbytes is unavailable for CUDA 12.x here and is
# installed CPU-only. To re-enable it, uncomment the config below and pass
# `quantization_config=quantization_config` to MultiBinaryClassifier.
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,  # enable 4-bit quantization
#     bnb_4bit_quant_type="nf4",  # information-theoretically optimal dtype for normally distributed weights
#     bnb_4bit_use_double_quant=True,  # also quantize the quantization constants (nested quantization)
#     bnb_4bit_compute_dtype=torch.bfloat16,  # optimized fp format for ML
# )

# LoRA adapter configuration shared by every binary classifier.
lora_config = LoraConfig(
    r=8,  # the dimension (rank) of the low-rank matrices
    lora_alpha=16,  # scaling factor for LoRA activations vs pre-trained weight activations
    # target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    # target_modules=["query_proj", "value_proj"],
    lora_dropout=0.05,  # dropout probability of the LoRA layers
    bias="none",  # whether to train bias weights; "none" trains no biases
    task_type="SEQ_CLS",
)

# One LoRA-wrapped binary classifier per fine-grained class.
model = MultiBinaryClassifier(
    model_name=model_dirname,
    num_classes=len(classes),
    # quantization_config=quantization_config,
    lora_config=lora_config
)

In [None]:
# Training schedule: evaluate, checkpoint and log once per epoch; reload the best checkpoint at the end.
training_args = TrainingArguments(
    output_dir=train_log_dir,
    learning_rate=1e-4,
    per_device_train_batch_size=3,  # tested with 16 GB of GPU RAM
    per_device_eval_batch_size=3,
    num_train_epochs=10,
    # weight_decay=0.01,
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_strategy="epoch",
    load_best_model_at_end=True,
)

# Batch collation and metric computation, bound to this run's pad id and label hierarchy.
data_collator = functools.partial(collate_fn, pad_token_id=pad_token_id)
metrics_fn = functools.partial(compute_metrics, id2class=id2class, id_mappings=id_mappings)  # TODO extend to all classes

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['validation'],
    data_collator=data_collator,
    compute_metrics=metrics_fn,
    label_weights=label_weights.to(device),
)

In [None]:
# Fine-tune all binary classifiers; per-epoch checkpoints are written to train_log_dir.
trainer.train()

In [None]:
if not results_dir.exists():
    results_dir.mkdir(parents=True)
val_metrics = trainer.evaluate(tokenized_ds['validation'], metric_key_prefix="validation")
val_metrics_df = pd.DataFrame(val_metrics, index=[0])
display(val_metrics_df)
val_metrics_df.to_csv(results_dir / "validation.csv")

In [None]:
if not results_dir.exists():
    results_dir.mkdir(parents=True)
test_metrics = trainer.evaluate(tokenized_ds['test'], metric_key_prefix="test")
test_metrics_df = pd.DataFrame(test_metrics, index=[0])
display(test_metrics_df)
test_metrics_df.to_csv(results_dir / "test.csv")

In [None]:
# Save the fine-tuned model and its tokenizer to save_dir
# (MultiBinaryClassifier writes one classifier_<i> sub-directory per classifier plus the config).
trainer.model.save_pretrained(str(save_dir))
tokenizer.save_pretrained(str(save_dir))

In [None]:
# Reload the fine-tuned classifiers saved above.
# NOTE(review): MultiBinaryClassifier.from_pretrained accepts **kwargs but never forwards them,
# so device_map / num_labels / problem_type / trust_remote_code below have no effect — confirm intent.
model = MultiBinaryClassifier.from_pretrained(
    str(save_dir),
    device_map=device,
    #quantization_config=quantization_config,
    num_labels=len(classes),
    problem_type="multi_label_classification",
    trust_remote_code=True,
    local_files_only=True,
)
model.config.pad_token_id = tokenizer.pad_token_id