# About the Notebook

Knowledge distillation trials for multi-label classification. Using deberta-v3-base as the student model and deberta-v3-large as the teacher model.

# Setup

In [1]:
DEBUG = False
WANDB = False
ENVIRON = "jarvislabs"
NOTES = "knowledge distillation on deberta-v3-base with trained deberta-v3-large"

# Setup Environment

In [2]:
import pkgutil
from pathlib import Path

PROJECT = "DataSolve-2022"

if ENVIRON == "jarvislabs":
    ROOT_DIR = Path(f"/home/{PROJECT}")
    ARTIFACTS_DIR = Path("/home/artifacts")
    SETUP_SCRIPT_PATH = Path("/home/setup.sh")
elif ENVIRON == "lambdalabs":
    ROOT_DIR = Path(f"/ubuntu/home/{PROJECT}")
    ARTIFACTS_DIR = Path("/ubuntu/home/artifacts")
    SETUP_SCRIPT_PATH = Path("/home/setup.sh")
elif ENVIRON == "kaggle":
    ROOT_DIR = Path(f"/kaggle/working/{PROJECT}")
    ARTIFACTS_DIR = Path("/kaggle/working/artifacts")
    SETUP_SCRIPT_PATH = Path("/kaggle/input/datasolve-setup-script/setup.sh")
    
if not pkgutil.find_loader("omegaconf") and ENVIRON == "kaggle":
    !bash {SETUP_SCRIPT_PATH} {ENVIRON} "true"

In [3]:
# load secret keys
%load_ext dotenv
%dotenv {ROOT_DIR}/.env

In [4]:
# DOWNLOAD TEACHER MODEL
import os, wandb
experiment = "crimson-elevator-29"
if not os.path.exists(f"./{experiment}"):
    api = wandb.Api()
    artifact = api.artifact(f"gladiator/DataSolve-2022/{experiment}:v0", type="model")
    artifact_dir = artifact.download(f"./{experiment}")

# Configuration

In [5]:
import os, gc
gc.enable()
from omegaconf import OmegaConf

class Config:
    # GENERAL
    debug = DEBUG
    wandb = WANDB
    seed = 42
    train_csv = "train_processed.csv"
    
    # MODEL
    model = dict(
        model_name_or_path = "microsoft/deberta-v3-large",
        gradient_checkpointing = True,
        output_hidden_states = False,
        output_last_hidden_state = False,
        output_pooled_embeds = False
    
    )
    teacher_model_path = "crimson-elevator-29/pytorch_model.bin"

    # TRACKING
    tags = ["clspool", f"{model['model_name_or_path']}", "512", "tts_split"]
    notes = NOTES
    upload_artifacts_to_wandb = True
    
    # DATA
    data = dict(
        max_length = 512,
        truncation = True,
        pad_to_multiple_of = 8,
    )
    
    # TRAINING ARGUMENTS
    training_args = dict(
        # general
        seed = seed,
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        save_total_limit = 1,

        # train settings
        num_train_epochs = 8,
        lr_scheduler_type = "cosine",
        warmup_ratio = 0.2,
        per_device_train_batch_size = 16,
        per_device_eval_batch_size = 16,
        gradient_accumulation_steps = 1,
        learning_rate = 5e-5,
        weight_decay = 0.01,
        max_grad_norm = 1.0,
        
        # misc
        # eval_accumulation_steps=10,
        adam_epsilon = 1e-6,
        fp16 = True,
        dataloader_num_workers = min(6, os.cpu_count()),
        load_best_model_at_end = True,
        metric_for_best_model = "eval_f1",
        greater_is_better = True,
        group_by_length = True,
        length_column_name = "length",
        report_to = "wandb" if WANDB else "none",
        dataloader_pin_memory = True,
    )


# CONFIG SETTINGS
config_dict = {x:dict(Config.__dict__)[x] for x in dict(Config.__dict__) if not x.startswith('_')}
cfg = OmegaConf.create(config_dict)
if cfg.debug: cfg.tags += ["debug"]
if cfg.debug: cfg.training_args.num_train_epochs = 2
print(OmegaConf.to_yaml(cfg, resolve=True))

debug: false
wandb: false
seed: 42
train_csv: train_processed.csv
model:
  model_name_or_path: microsoft/deberta-v3-large
  gradient_checkpointing: true
  output_hidden_states: false
  output_last_hidden_state: false
  output_pooled_embeds: false
teacher_model_path: crimson-elevator-29/pytorch_model.bin
tags:
- clspool
- microsoft/deberta-v3-large
- '512'
- tts_split
notes: knowledge distillation on deberta-v3-base with trained deberta-v3-large
upload_artifacts_to_wandb: true
data:
  max_length: 512
  truncation: true
  pad_to_multiple_of: 8
training_args:
  seed: 42
  evaluation_strategy: epoch
  save_strategy: epoch
  save_total_limit: 1
  num_train_epochs: 8
  lr_scheduler_type: cosine
  warmup_ratio: 0.2
  per_device_train_batch_size: 16
  per_device_eval_batch_size: 16
  gradient_accumulation_steps: 1
  learning_rate: 5.0e-05
  weight_decay: 0.01
  max_grad_norm: 1.0
  adam_epsilon: 1.0e-06
  fp16: true
  dataloader_num_workers: 6
  load_best_model_at_end: true
  metric_for_best_m

# Imports

In [6]:
import copy
import glob
import shutil
import pickle
import warnings
import logging
import numpy as np
import pandas as pd
from pprint import pprint
from dataclasses import dataclass
from typing import Dict, List, Tuple, Callable, Optional

import wandb
from wandb import AlertLevel

from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F

import datasets, transformers
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    EvalPrediction,
    PreTrainedTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.modeling_outputs import ModelOutput

# SYSTEM SETTINGS
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
os.environ["WANDB_SILENT"] = "false"
set_seed(cfg.seed)
if not cfg.debug:
    warnings.simplefilter("ignore")
    logging.disable(logging.WARNING)

# Helper Functions

In [7]:
def delete_checkpoints(dir):
    for file in glob.glob(f"{dir}/checkpoint-*"):
        shutil.rmtree(file, ignore_errors=True)


def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()

def delete_file(path: str):
    if os.exists(path):
        os.remove(path)

def save_pickle(obj, filepath):
    with open(filepath, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
def process_hydra_config_for_wandb(cfg: OmegaConf):
    """
    Only keep relevant part of config for logging
    """
    tmp_cfg = copy.deepcopy(cfg)
    cfg_dict = OmegaConf.to_container(tmp_cfg, resolve=True, throw_on_missing=True)
    del cfg_dict["training_args"]
    return cfg_dict

# Init W&B run

In [8]:
if cfg.wandb:
    wandb.init(
        project="DataSolve-2022",
        tags=cfg.tags,
        notes=cfg.notes,
        config=process_hydra_config_for_wandb(cfg),
        save_code=True,
    )
    wandb.alert(
        title=f"Experiment {wandb.run.name}",
        text=f"ðŸš€ Starting experiment {wandb.run.name}, Description: {cfg.notes}",
        level=AlertLevel.INFO,
        wait_duration=0,
    )

EXP_NAME = wandb.run.name if cfg.wandb else "debug"

# Read and process data

In [9]:
# READ DATA
df = pd.read_csv(ROOT_DIR/'input'/cfg.train_csv)
if cfg.debug:
    df = df.sample(100, random_state=42).reset_index(drop=True)
LABEL_COLS = [col for col in df.columns if col not in ["id", "name", "document_text"]]
print(len(LABEL_COLS))
df.head()

50


Unnamed: 0,id,name,document_text,Accounting and Finance,Antitrust,Banking,Broker Dealer,Commodities Trading,Compliance Management,Consumer protection,...,Required Disclosures,Research,Risk Management,Securities Clearing,Securities Issuing,Securities Management,Securities Sales,Securities Settlement,Trade Pricing,Trade Settlement
0,4772,Consent Order in the Matter of Solium Financia...,"Solium Financial Services LLC (""SFS"") is a bro...",0,0,0,1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
1,4774,Alberta Securities Commission Warns Investors ...,A new year brings new investment opportunities...,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,4775,Exempt Market Dealer Agrees to Settlement,The Alberta Securities Commission (ASC) has co...,0,0,0,1,0,1,0,...,0,0,0,0,0,0,1,1,0,1
3,4776,Canadian Securities Regulators Announces Consu...,The Canadian Securities Administrators (CSA) p...,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,1,0,0
4,4778,CSA Consultation Paper 51-405 Consideration of...,"On April 6, 2017, the Canadian Securities Admi...",0,0,0,0,0,0,1,...,1,0,0,0,0,0,0,1,0,0


In [10]:
X = df[[col for col in df.columns if col not in LABEL_COLS]]
y = df[LABEL_COLS]

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, val_index in msss.split(X, y):
    print(len(train_index))
    print(len(val_index))
    val_df = df.loc[val_index]
    train_df = df.loc[train_index]

7885
1974


In [11]:
train_df.shape, val_df.shape

((7885, 53), (1974, 53))

In [12]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model.model_name_or_path)

def tokenize_func(example: pd.Series, tokenizer: PreTrainedTokenizer, max_length: int = 512, truncation: bool = True, mode: str="train"):
    tokenized = tokenizer(
    example["text"],
    truncation=truncation,
    max_length=max_length,
    add_special_tokens=True,
)
    if mode == "train":
        tokenized["labels"] = [example[i] for i in LABEL_COLS]
    tokenized["length"] = len(tokenized["input_ids"])
    return tokenized

def preprocess_data(df_: pd.DataFrame, mode:str="train"):
    df_["text"] = tokenizer.cls_token + df_["name"] + tokenizer.sep_token + df_["document_text"] + tokenizer.sep_token
    tok_ds = Dataset.from_pandas(df_)
    tok_ds = tok_ds.map(lambda x: tokenize_func(x, tokenizer, max_length=cfg.data.max_length, truncation=cfg.data.truncation, mode=mode), num_proc=2)
    return tok_ds

train_ds = preprocess_data(train_df, mode="train")
val_ds = preprocess_data(val_df, mode="train")
print(len(train_ds), len(val_ds))

   

#0:   0%|          | 0/3943 [00:00<?, ?ex/s]

 

#1:   0%|          | 0/3942 [00:00<?, ?ex/s]

    

#0:   0%|          | 0/987 [00:00<?, ?ex/s]

#1:   0%|          | 0/987 [00:00<?, ?ex/s]

7885 1974


In [13]:
train_ds[0]

{'id': 4772,
 'name': 'Consent Order in the Matter of Solium Financial Services LLC',
 'document_text': 'Solium Financial Services LLC ("SFS") is a broker-dealer with a principal place of business at 50 Tice Boulevard, Suite A-18 Woodcliff Lake, New Jersey 07677, and is registered as a broker-dealer with the Alabama Securities Commission ("Commission"). During the period from at least January 2009 to June 6, 2019, SFS acted as broker-dealer in Alabama as the term broker-dealer is defined by Title 8, Chapter 6, 8-6-2 of the Act. Code of Alabama, 8-6-3(a) states that it is unlawful for a person to transact business in Alabama as a broker-dealer or agent unless such person is registered under the Act. By engaging in the conduct set forth above, SFS acted as an unregistered broker-dealer in Alabama in violation of 8-6-3(a) of the Act. This Order concludes the investigation by the Commission and any other action that the Commission could commence under applicable Alabama law as it relates t

# Metrics

In [14]:
def post_process_logits(logits: np.ndarray, threshold=0.5):
    # first, apply sigmoid on logits which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(logits))
    # next, use threshold to turn them into integer predictions
    preds = np.zeros(probs.shape)
    preds[np.where(probs >= threshold)] = 1
    preds = preds.flatten().astype(int)
    return preds

def compute_metrics(p: EvalPrediction):
    # `predictions` might return last_hidden_state or pooled_embeds
    # In that case, take the first element (array) of the tuple for logits
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = post_process_logits(logits)
    labels = p.label_ids.flatten()
    f1_macro_average = f1_score(labels, preds, average='macro')
    roc_auc = roc_auc_score(labels, preds, average = 'macro')
    accuracy = accuracy_score(labels, preds)
    # return as dictionary
    return {
        'f1': f1_macro_average,
        'roc_auc': roc_auc,
        'accuracy': accuracy
    }

# Custom Trainer

In [15]:
class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
    
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs_stu = model(**inputs)
        # extract bce loss and logits from student
        loss_stu = outputs_stu.loss
        logits_stu = outputs_stu.logits
        # extract logits from teacher
        with torch.no_grad():
            outputs_tea = self.teacher_model(**inputs)
            loss_tea = outputs_tea.loss
            logits_tea = outputs_tea.logits
        # soften probabilites and compute distillation loss
        # kdl_loss = nn.KLDivLoss(reduction="batchmean")
        # temperature
        # logits_stu = torch.sigmoid(logits_stu) / self.args.temperature
        # logits_tea = torch.sigmoid(logits_tea) / self.args.temperature
        
        # loss_kd = self.args.temperature ** 2 \
                    # * sum(kdl_loss(logits_stu[idx], logits_tea[idx]) for idx in range(50)) 
        # return weighted student loss
        # loss = self.args.alpha * loss_stu + (1. - self.args.alpha) * loss_kd
        # Compute losses
        distillation_loss = nn.BCEWithLogitsLoss()(logits_tea.view(-1, 50),logits_stu.view(-1, 50))
        loss = self.args.alpha * loss_stu + (1. - self.args.alpha) * distillation_loss
        return (loss, outputs_stu) if return_outputs else loss

# Model

In [16]:
@dataclass
class CustomModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooled_embeds: Optional[torch.FloatTensor] = None


@dataclass
class ModelConfig:
    model_name_or_path: str
    gradient_checkpointing: Optional[bool] = False
    output_hidden_states: Optional[bool] = False
    output_last_hidden_state: Optional[bool] = False
    output_pooled_embeds: Optional[bool] = False

class DataSolveModel(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.model_config = config
        self.hf_config = AutoConfig.from_pretrained(self.model_config.model_name_or_path)
        self.backbone = AutoModel.from_pretrained(self.model_config.model_name_or_path, config=self.hf_config)
        if self.model_config.gradient_checkpointing:
            self.backbone.gradient_checkpointing_enable()
        self.output = nn.Linear(self.hf_config.hidden_size, 50)

    def forward(self, input_ids, attention_mask, labels=None):
        trans_out = self.backbone(input_ids, attention_mask=attention_mask)
        last_hidden_state = trans_out.last_hidden_state
        pooled_embeds = last_hidden_state[:, 0]  # CLS token
        logits = self.output(pooled_embeds)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, 50), 
                            labels.float().view(-1, 50))

        return CustomModelOutput(
            loss=loss,
            logits=logits,
            hidden_states=trans_out.hidden_states if self.model_config.output_hidden_states else None,
            last_hidden_state=last_hidden_state if self.model_config.output_last_hidden_state else None,
            pooled_embeds=pooled_embeds if self.model_config.output_pooled_embeds else None,
        )

# Train

In [None]:
set_seed(cfg.seed)

OUT_DIR = Path(ARTIFACTS_DIR/f'{EXP_NAME}')
print(f"Saving outputs to {OUT_DIR}")
os.makedirs(OUT_DIR, exist_ok=True)
print(f"EXPERIMENT: {EXP_NAME}, DESC: {cfg.notes}\n")
os.makedirs(OUT_DIR, exist_ok=True)

# sort by length to have similar length samples in each batch for speeding up evaluation
val_ds = val_ds.sort("length")

# remove unwanted columns
keep_cols = {"input_ids", "attention_mask", "labels", "token_type_ids"}
remove_cols = [c for c in train_ds.column_names if c not in keep_cols]
train_ds = train_ds.remove_columns(remove_cols)
val_ds = val_ds.remove_columns(remove_cols)
train_ds.set_format("torch")
val_ds.set_format("torch")

# init model
teacher_model_config = ModelConfig("microsoft/deberta-v3-large")
teacher_model = DataSolveModel(teacher_model_config)
teacher_model.load_state_dict(torch.load(cfg.teacher_model_path))
teacher_model.to(torch.device("cuda"))
teacher_model.eval()

model_config = ModelConfig(**cfg.model)
model = DataSolveModel(model_config)

# init trainer
training_args = DistillationTrainingArguments(output_dir=OUT_DIR, **cfg.training_args)
trainer = DistillationTrainer(
            model,
            teacher_model=teacher_model,
            args=training_args,
            data_collator=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=cfg.data.pad_to_multiple_of),
            train_dataset=train_ds,
            eval_dataset=val_ds,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
)
# train
trainer.train()

# ---------------------------------- Save, log, cleanup and upload ---------------------

# Save model
delete_checkpoints(OUT_DIR)
trainer.save_model()
clear_memory()

# Save oof predictions
logits, _, metrics = trainer.predict(val_ds)
if isinstance(logits, tuple):
    logits = logits[0]

oof_dict = {"id": val_df["id"], "logits": logits}
save_pickle(oof_dict, OUT_DIR/f"{EXP_NAME}_oof.pkl")

fin_f1_score = np.round(metrics["test_f1"], 6)

print("=" * 30)
print(f"  EXP {EXP_NAME}, F1-SCORE: {fin_f1_score}")
print("=" * 30)
if cfg.wandb:
    wandb.log({"cv": fin_f1_score})

# save experiment config file
config_file_save_path = OUT_DIR / f"{EXP_NAME}_config.yaml"
with open(config_file_save_path, "w") as fp:
    OmegaConf.save(config=cfg, f=fp.name)
shutil.copyfile(config_file_save_path, os.path.join(wandb.run.dir, f"{EXP_NAME}_config.yaml"))
clear_memory()

Saving outputs to /home/artifacts/debug
EXPERIMENT: debug, DESC: knowledge distillation on deberta-v3-base with trained deberta-v3-large



Epoch,Training Loss,Validation Loss


# Inference

In [None]:
test_df = pd.read_csv(ROOT_DIR/'input'/'test.csv')
sub_df = pd.read_csv(ROOT_DIR/'input'/'sample_submission.csv')
test_df.head()

In [None]:
test_ds = preprocess_data(test_df, mode="inference")
# sort test dataset to have similar length samples in a batch to speed up inference
test_ds = test_ds.sort("length")
test_ds

In [None]:
model_config = ModelConfig(model_name_or_path=cfg.model.model_name_or_path)
model = DataSolveModel(model_config)
model.load_state_dict(torch.load(OUT_DIR/'pytorch_model.bin'))
trainer_args = TrainingArguments("./tmp", per_device_eval_batch_size = cfg.training_args.per_device_eval_batch_size)
trainer = Trainer(model, trainer_args, data_collator=DataCollatorWithPadding(tokenizer))
out = trainer.predict(test_ds)
logits = out.predictions[0] if isinstance(out.predictions, tuple) else out.predictions

# Create submission

In [None]:
ids = []
for id_ in test_ds['id']:
    for col in LABEL_COLS:
        ids.append(f"{id_}_{col}")

In [None]:
predictions = post_process_logits(logits, threshold=0.5)
predictions.shape

In [None]:
sub_df["id"] = ids
sub_df['predictions'] = predictions

In [None]:
sub_df

In [None]:
sub_df.to_csv(OUT_DIR/f"{EXP_NAME}_sub.csv", index=False)
test_logits_dict = {"id":  test_ds['id'], "logits": logits}
save_pickle(test_logits_dict, OUT_DIR/f"{EXP_NAME}_test_logits.pkl")
if cfg.wandb:
    # log artifacts to wandb
    if cfg.upload_artifacts_to_wandb:
        model_artifact = wandb.Artifact(name=EXP_NAME, type="model")
        model_artifact.add_dir(OUT_DIR)
        wandb.log_artifact(model_artifact)

    wandb.alert(
        title=f"Experiment {EXP_NAME}",
        text=f"ðŸŽ‰ Finished experiment {EXP_NAME}, Score: {fin_f1_score}",
        level=AlertLevel.INFO,
        wait_duration=0,
    )

    wandb.save(f"{EXP_NAME}_sub.csv")
    wandb.finish()

# Clean-up

In [None]:
if ENVIRON == "kaggle":
    shutil.rmtree("./tmp", ignore_errors=True)
    shutil.rmtree(ROOT_DIR, ignore_errors=True)

In [None]:
with open("crimson-elevator-29/crimson-elevator-29_test_logits.pkl", "rb") as handler:

    logits = pickle.load(handler)['logits']
def sigmoid(z):
    return 1/(1 + np.exp(-z))

In [None]:
logits.shape

In [None]:
sigmoid(logits)

In [None]:
logits_stu = []
for l in logits:
    probs = [[p, 1-p] for p in l]
    logits_stu.append(probs)
logits_stu
# logits_stu = torch.Tensor(logits_stu)
# logits_tea = torch.Tensor(logits_stu)
# sum(torch.nn.KLDivLoss()(logits_stu[:, i], logits_tea[:, i]) for i in range(50))

In [None]:
outputs_stu = model(**inputs)
# extract bce loss and logits from student
loss_stu = outputs_stu.loss
logits_stu = outputs_stu.logits
# extract logits from teacher
with torch.no_grad():
    outputs_tea = self.teacher_model(**inputs)
    logits_tea = outputs_tea.logits
# soften probabilites and compute distillation loss
kdl_loss = nn.KLDivLoss(reduction="batchmean")
# temperature
logits_stu = torch.sigmoid(logits_stu) / self.args.temperature
logits_tea = torch.sigmoid(logits_tea) / self.args.temperature

loss_kd = self.args.temperature ** 2 \
            * sum(kdl_loss(logits_stu[idx], logits_tea[idx]) for idx in range(50)) 
# return weighted student loss
loss = self.args.alpha * loss_stu + (1. - self.args.alpha) * loss_kd
return (loss, outputs_stu) if return_outputs else loss