# About the Notebook

All the transformer models used for the final ensemble were based out of this notebook. This pipeline runs a 5-fold training. The folds are stratified using the iterstrat package which helps to stratify multi-label data. Specifically, `MultilabelStratifiedKFold` was used to create the folds. You can also check the complete data preprocessing and preparation stage in [this notebook](https://www.kaggle.com/code/atharvaingle/datasolve-eda-cv-setup).

# Setup

In [1]:
DEBUG = False
WANDB = True
ENVIRON = "lambdalabs"
EXP_NAME = "dbv3l-15ep"
NOTES = "5 fold experiment, with the best settings till now, deberta-v3-large, 15 epochs, 4e-5, 16 bs, previous run failed because of catastrophic forgetting"

# Setup Environment

In [2]:
import pkgutil
from pathlib import Path

PROJECT = "DataSolve-2022"

if ENVIRON == "jarvislabs":
    ROOT_DIR = Path(f"/home/{PROJECT}")
    ARTIFACTS_DIR = Path("/home/artifacts")
    SETUP_SCRIPT_PATH = Path("/home/setup.sh")
elif ENVIRON == "lambdalabs":
    ROOT_DIR = Path(f"/home/ubuntu/{PROJECT}")
    ARTIFACTS_DIR = Path("/home/ubuntu/artifacts")
    SETUP_SCRIPT_PATH = Path("/home/setup.sh")
elif ENVIRON == "kaggle":
    ROOT_DIR = Path(f"/kaggle/working/{PROJECT}")
    ARTIFACTS_DIR = Path("/kaggle/working/artifacts")
    SETUP_SCRIPT_PATH = Path("/kaggle/input/datasolve-setup-script/setup.sh")
    
if not pkgutil.find_loader("omegaconf") and ENVIRON == "kaggle":
    !bash {SETUP_SCRIPT_PATH} {ENVIRON} "true"

In [3]:
# load secret keys
%load_ext dotenv
%dotenv {ROOT_DIR}/.env

# Configuration

In [4]:
import os, gc
gc.enable()
from omegaconf import OmegaConf

class Config:
    # GENERAL
    debug = DEBUG
    wandb = WANDB
    seed = 42
    train_csv = "train_folds_5.csv"
    fold = 0 # will be overriden later
    
    # MODEL
    model = dict(
        model_name_or_path = "microsoft/deberta-v3-large",
        gradient_checkpointing = False,
        reinit_last_layers = 0,
        output_hidden_states = False,
        output_last_hidden_state = False,
        output_pooled_embeds = False,
    )


    # TRACKING
    exp_name = EXP_NAME
    tags = ["clspool", f"{model['model_name_or_path']}", "512", "5_fold_split"]
    notes = NOTES
    upload_artifacts_to_wandb = True
    
    # DATA
    data = dict(
        max_length = 512,
        truncation = True,
        pad_to_multiple_of = 8,
    )
    
    # TRAINING ARGUMENTS
    training_args = dict(
        # general
        seed = seed,
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        save_total_limit = 1,

        # train settings
        num_train_epochs = 15,
        lr_scheduler_type = "cosine",
        warmup_ratio = 0.2,
        per_device_train_batch_size = 16,
        per_device_eval_batch_size = 16,
        gradient_accumulation_steps = 1,
        learning_rate = 4e-5,
        weight_decay = 0.01,
        max_grad_norm = 1.0,
        
        # misc
        adam_epsilon = 1e-6,
        fp16 = True,
        dataloader_num_workers = min(6, os.cpu_count()),
        load_best_model_at_end = True,
        metric_for_best_model = "eval_f1",
        greater_is_better = True,
        group_by_length = True,
        length_column_name = "length",
        report_to = "wandb" if WANDB else "none",
        dataloader_pin_memory = True,
    )


# CONFIG SETTINGS
config_dict = {x:dict(Config.__dict__)[x] for x in dict(Config.__dict__) if not x.startswith('_')}
cfg = OmegaConf.create(config_dict)
if cfg.debug: 
    cfg.tags += ["debug"]
    cfg.training_args.num_train_epochs = 2
    cfg.model.model_name_or_path = "microsoft/deberta-v3-base"

OUTPUT_DIR = Path(ARTIFACTS_DIR/f'{cfg.exp_name}')
print(f"Saving outputs to {OUTPUT_DIR}")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"EXPERIMENT: {cfg.exp_name}, DESC: {cfg.notes}\n")
print(OmegaConf.to_yaml(cfg, resolve=True))

Saving outputs to /home/ubuntu/artifacts/dbv3l-15ep
EXPERIMENT: dbv3l-15ep, DESC: 5 fold experiment, with the best settings till now, deberta-v3-large, 15 epochs, 4e-5, 16 bs, previous run failed because of catastrophic forgetting

debug: false
wandb: true
seed: 42
train_csv: train_folds_5.csv
fold: 0
model:
  model_name_or_path: microsoft/deberta-v3-large
  gradient_checkpointing: false
  reinit_last_layers: 0
  output_hidden_states: false
  output_last_hidden_state: false
  output_pooled_embeds: false
exp_name: dbv3l-15ep
tags:
- clspool
- microsoft/deberta-v3-large
- '512'
- 5_fold_split
notes: 5 fold experiment, with the best settings till now, deberta-v3-large, 15 epochs,
  4e-5, 16 bs, previous run failed because of catastrophic forgetting
upload_artifacts_to_wandb: true
data:
  max_length: 512
  truncation: true
  pad_to_multiple_of: 8
training_args:
  seed: 42
  evaluation_strategy: epoch
  save_strategy: epoch
  save_total_limit: 1
  num_train_epochs: 15
  lr_scheduler_type: c

# Imports

In [5]:
import copy
import glob
import shutil
import pickle
import warnings
import logging
import numpy as np
import pandas as pd
from pprint import pprint
from dataclasses import dataclass
from typing import Dict, List, Tuple, Callable, Optional, Union

import wandb
from wandb import AlertLevel

from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F

import datasets, transformers
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    EvalPrediction,
    PreTrainedTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.modeling_outputs import ModelOutput

# SYSTEM SETTINGS
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
os.environ["WANDB_SILENT"] = "true"
set_seed(cfg.seed)
if not cfg.debug:
    warnings.simplefilter("ignore")
    logging.disable(logging.WARNING)

  from pandas.core.computation.check import NUMEXPR_INSTALLED


# Helper Functions

In [6]:
def delete_checkpoints(dir):
    for file in glob.glob(f"{dir}/checkpoint-*"):
        shutil.rmtree(file, ignore_errors=True)


def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()

def delete_file(path: str):
    if os.exists(path):
        os.remove(path)

def save_pickle(obj, filepath):
    with open(filepath, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
def process_config_for_wandb(cfg: OmegaConf):
    """
    Only keep relevant part of config for logging
    """
    tmp_cfg = copy.deepcopy(cfg)
    cfg_dict = OmegaConf.to_container(tmp_cfg, resolve=True, throw_on_missing=True)
    del cfg_dict["training_args"]
    return cfg_dict

# Read and process data

In [7]:
# READ DATA
df = pd.read_csv(ROOT_DIR/'input'/cfg.train_csv)
if cfg.debug:
    df = df.sample(100, random_state=42).reset_index(drop=True)
LABEL_COLS = [col for col in df.columns if col not in ["id", "name", "document_text", "fold"]]
print(len(LABEL_COLS))
df.head()

50


Unnamed: 0,id,name,document_text,Accounting and Finance,Antitrust,Banking,Broker Dealer,Commodities Trading,Compliance Management,Consumer protection,...,Research,Risk Management,Securities Clearing,Securities Issuing,Securities Management,Securities Sales,Securities Settlement,Trade Pricing,Trade Settlement,fold
0,4772,Consent Order in the Matter of Solium Financia...,"Solium Financial Services LLC (""SFS"") is a bro...",0,0,0,1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
1,4774,Alberta Securities Commission Warns Investors ...,A new year brings new investment opportunities...,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,4775,Exempt Market Dealer Agrees to Settlement,The Alberta Securities Commission (ASC) has co...,0,0,0,1,0,1,0,...,0,0,0,0,0,1,1,0,1,2
3,4776,Canadian Securities Regulators Announces Consu...,The Canadian Securities Administrators (CSA) p...,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,4
4,4778,CSA Consultation Paper 51-405 Consideration of...,"On April 6, 2017, the Canadian Securities Admi...",0,0,0,0,0,0,1,...,0,0,0,0,0,0,1,0,0,2


In [8]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model.model_name_or_path)

def tokenize_func(example: pd.Series, tokenizer: PreTrainedTokenizer, max_length: int = 512, truncation: bool = True, mode: str="train"):
    tokenized = tokenizer(
    example["text"],
    truncation=truncation,
    max_length=max_length,
    add_special_tokens=True,
)
    if mode == "train":
        tokenized["labels"] = [example[i] for i in LABEL_COLS]
    tokenized["length"] = len(tokenized["input_ids"])
    return tokenized

def preprocess_data(df_: pd.DataFrame, mode:str="train"):
    df_["text"] = tokenizer.cls_token + df_["name"] + tokenizer.sep_token + df_["document_text"] + tokenizer.sep_token
    tok_ds = Dataset.from_pandas(df_)
    tok_ds = tok_ds.map(lambda x: tokenize_func(x, tokenizer, max_length=cfg.data.max_length, truncation=cfg.data.truncation, mode=mode), num_proc=2)
    return tok_ds

tok_ds = preprocess_data(df, mode="train")

   

#0:   0%|          | 0/4930 [00:00<?, ?ex/s]

 

#1:   0%|          | 0/4929 [00:00<?, ?ex/s]

In [9]:
tok_ds[0]

{'id': 4772,
 'name': 'Consent Order in the Matter of Solium Financial Services LLC',
 'document_text': 'Solium Financial Services LLC ("SFS") is a broker-dealer with a principal place of business at 50 Tice Boulevard, Suite A-18 Woodcliff Lake, New Jersey 07677, and is registered as a broker-dealer with the Alabama Securities Commission ("Commission"). During the period from at least January 2009 to June 6, 2019, SFS acted as broker-dealer in Alabama as the term broker-dealer is defined by Title 8, Chapter 6, 8-6-2 of the Act. Code of Alabama, 8-6-3(a) states that it is unlawful for a person to transact business in Alabama as a broker-dealer or agent unless such person is registered under the Act. By engaging in the conduct set forth above, SFS acted as an unregistered broker-dealer in Alabama in violation of 8-6-3(a) of the Act. This Order concludes the investigation by the Commission and any other action that the Commission could commence under applicable Alabama law as it relates t

# Metrics

In [10]:
def post_process_logits(logits: np.ndarray, threshold=0.5):
    # first, apply sigmoid on logits which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(logits))
    # next, use threshold to turn them into integer predictions
    preds = np.zeros(probs.shape)
    preds[np.where(probs >= threshold)] = 1
    preds = preds.flatten().astype(int)
    return preds

def compute_metrics(p: EvalPrediction):
    # `predictions` might return last_hidden_state or pooled_embeds
    # In that case, take the first element (array) of the tuple for logits
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = post_process_logits(logits)
    labels = p.label_ids.flatten()
    f1_macro_average = f1_score(labels, preds, average='macro')
    roc_auc = roc_auc_score(labels, preds, average = 'macro')
    accuracy = accuracy_score(labels, preds)
    # return as dictionary
    return {
        'f1': f1_macro_average,
        'roc_auc': roc_auc,
        'accuracy': accuracy
    }

# Custom Trainer

In [11]:
# class MultilabelTrainer(Trainer):
#     def compute_loss(self, model, inputs, return_outputs=False):
#         labels = inputs.pop("labels")
#         outputs = model(**inputs)
#         logits = outputs.logits
#         loss_fct = torch.nn.BCEWithLogitsLoss()
#         loss = loss_fct(logits.view(-1, self.model.config.num_labels), 
#                         labels.float().view(-1, self.model.config.num_labels))
#         return (loss, outputs) if return_outputs else loss

# Model

In [12]:
@dataclass
class CustomModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooled_embeds: Optional[torch.FloatTensor] = None


@dataclass
class ModelConfig:
    model_name_or_path: str
    gradient_checkpointing: Optional[bool] = False
    reinit_last_layers: Optional[int] = 0
    output_hidden_states: Optional[bool] = False
    output_last_hidden_state: Optional[bool] = False
    output_pooled_embeds: Optional[bool] = False

def reinit_last_layers(model: Union[nn.Module, PreTrainedModel], num_layers: int):
    """Re-initialize the last-k transformer layers"""
    if num_layers > 0:
        model.encoder.layer[-num_layers:].apply(model._init_weights)
    
class DataSolveModel(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.model_config = config
        self.hf_config = AutoConfig.from_pretrained(self.model_config.model_name_or_path)
        
        self.backbone = AutoModel.from_pretrained(self.model_config.model_name_or_path, config=self.hf_config)
        
        if self.model_config.gradient_checkpointing:
            self.backbone.gradient_checkpointing_enable()

        # Initialize last-k transformer (backbone) layers
        reinit_last_layers(self.backbone, self.model_config.reinit_last_layers)
        
        self.output = nn.Linear(self.hf_config.hidden_size, 50)
        
    def forward(self, input_ids, attention_mask, labels=None):
        trans_out = self.backbone(input_ids, attention_mask=attention_mask)
        last_hidden_state = trans_out.last_hidden_state
        pooled_embeds = last_hidden_state[:, 0] # CLS Token
        
        logits = self.output(pooled_embeds)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, 50), labels.float().view(-1, 50))
            
        return CustomModelOutput(
            loss=loss,
            logits=logits,
            hidden_states=trans_out.hidden_states if self.model_config.output_hidden_states else None,
            last_hidden_state=last_hidden_state if self.model_config.output_last_hidden_state else None,
            pooled_embeds=pooled_embeds if self.model_config.output_pooled_embeds else None,
        )

# Train

In [13]:
def train_fold(cfg, fold):
    set_seed(cfg.seed)
    cfg.fold = fold
    if cfg.wandb:
        # init wandb run
        wandb.init(
            project="DataSolve-2022",
            group=cfg.exp_name,
            name=f"fold_{cfg.fold}",
            tags=cfg.tags,
            notes=cfg.notes,
            config=process_config_for_wandb(cfg),
            save_code=True,
        )
        # send slack notification
        wandb.alert(
            title=f"Experiment {cfg.exp_name}",
            text=f"🚀 Starting experiment {cfg.exp_name} (fold_{cfg.fold}), Description: {cfg.notes}",
            level=AlertLevel.INFO,
            wait_duration=0,
        )

    OUT_DIR = OUTPUT_DIR/f'fold_{cfg.fold}'
    os.makedirs(OUT_DIR, exist_ok=True)
    
    # filter train and val data
    train_ds = tok_ds.filter(lambda x: x["fold"] != cfg.fold, desc="Filtering train idxs")
    val_ds = tok_ds.filter(lambda x: x["fold"] == cfg.fold, desc="Filtering valid idxs")

    # sort by length to have similar length samples in each batch for speeding up evaluation
    val_ds = val_ds.sort("length")
    val_ids = val_ds['id']
    # remove unwanted columns
    keep_cols = {"input_ids", "attention_mask", "labels", "token_type_ids"}
    remove_cols = [c for c in train_ds.column_names if c not in keep_cols]
    train_ds = train_ds.remove_columns(remove_cols)
    val_ds = val_ds.remove_columns(remove_cols)
    train_ds.set_format("torch")
    val_ds.set_format("torch")

    # init model
    model_config = ModelConfig(**cfg.model)
    model = DataSolveModel(model_config)

    # init trainer
    training_args = TrainingArguments(output_dir=OUT_DIR, **cfg.training_args)
    trainer = Trainer(
                model,
                args=training_args,
                data_collator=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=cfg.data.pad_to_multiple_of),
                train_dataset=train_ds,
                eval_dataset=val_ds,
                tokenizer=tokenizer,
                compute_metrics=compute_metrics,
    )
    # train
    trainer.train()

    # ---------------------------------- Save, log, cleanup and upload to W&B ------------------------

    # Save model
    delete_checkpoints(OUT_DIR)
    trainer.save_model()
    clear_memory()

    # Infer on validation set and extract logits and labels
    eval_out = trainer.predict(val_ds)
    logits = eval_out.predictions[0] if isinstance(eval_out.predictions, tuple) else eval_out.predictions
    labels = eval_out.label_ids
    
    oof_dict = {"id": val_ids, "logits": logits, "labels": labels}
    save_pickle(oof_dict, OUT_DIR/f"oof_{cfg.fold}.pkl")
    
    fin_f1_score = np.round(eval_out.metrics["test_f1"], 6)

    print("*" * 30)
    print(f"  Experiment {cfg.exp_name}, Fold {cfg.fold}, F1-SCORE: {fin_f1_score}")
    print("*" * 30)

    # save experiment config file
    config_file_save_path = OUT_DIR / f"{cfg.exp_name}_config.yaml"
    with open(config_file_save_path, "w") as fp:
        OmegaConf.save(config=cfg, f=fp.name)

    # log artifacts to wandb
    if cfg.wandb:
        wandb.log({"cv": fin_f1_score})
        if cfg.upload_artifacts_to_wandb:
            shutil.copyfile(config_file_save_path, os.path.join(wandb.run.dir, f"{cfg.exp_name}_config.yaml"))
            model_artifact = wandb.Artifact(name=cfg.exp_name, type="model")
            model_artifact.add_dir(OUT_DIR)
            wandb.log_artifact(model_artifact, aliases=f"fold_{cfg.fold}")

        wandb.alert(
            title=f"Experiment {cfg.exp_name}",
            text=f"🎉 Finished experiment {cfg.exp_name} (fold_{cfg.fold}), Score: {fin_f1_score}",
            level=AlertLevel.INFO,
            wait_duration=0,
        )
        wandb.finish()
            
    del model, trainer, eval_out, train_ds, val_ds; clear_memory();

In [None]:
for fold in range(5):
    train_fold(cfg, fold)

Filtering train idxs:   0%|          | 0/10 [00:00<?, ?ba/s]

Filtering valid idxs:   0%|          | 0/10 [00:00<?, ?ba/s]

Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,No log,0.235258,0.648483,0.605215,0.919256
2,0.372900,0.190268,0.718239,0.657466,0.931961
3,0.206300,0.14781,0.820948,0.76785,0.948467
4,0.160500,0.124416,0.860252,0.818393,0.957229
5,0.126800,0.106454,0.886075,0.852261,0.963953
6,0.097800,0.09811,0.897677,0.868193,0.967132
7,0.074200,0.093254,0.90498,0.876075,0.969384
8,0.055900,0.087396,0.917118,0.910615,0.971666
9,0.040800,0.081806,0.923885,0.916153,0.97405
10,0.029100,0.080604,0.928175,0.917849,0.97566


******************************
  Experiment dbv3l-15ep, Fold 0, F1-SCORE: 0.930153
******************************


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666862536667395, max=1.0)…

Filtering train idxs:   0%|          | 0/10 [00:00<?, ?ba/s]

Filtering valid idxs:   0%|          | 0/10 [00:00<?, ?ba/s]

Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,No log,0.23458,0.595673,0.567181,0.91625
2,0.372100,0.17683,0.772997,0.718693,0.93874
3,0.205000,0.152576,0.802578,0.745916,0.9458
4,0.159900,0.121493,0.867104,0.836034,0.95857
5,0.124800,0.10648,0.881656,0.844629,0.96362
6,0.096100,0.099585,0.891668,0.856098,0.96646
7,0.073400,0.088934,0.911911,0.894989,0.9712
8,0.055200,0.084161,0.919702,0.909699,0.97327
9,0.039400,0.084855,0.92089,0.918086,0.9732
10,0.027700,0.080498,0.9281,0.920766,0.9759


******************************
  Experiment dbv3l-15ep, Fold 1, F1-SCORE: 0.930679
******************************


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669114000008752, max=1.0…

Filtering train idxs:   0%|          | 0/10 [00:00<?, ?ba/s]

Filtering valid idxs:   0%|          | 0/10 [00:00<?, ?ba/s]

Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,No log,0.232452,0.636522,0.595654,0.918852
2,0.371000,0.181365,0.762802,0.702579,0.937815
3,0.205500,0.147619,0.827435,0.78124,0.948902
4,0.157900,0.128414,0.843465,0.790447,0.9544
5,0.127300,0.110345,0.873745,0.828483,0.961677
6,0.097900,0.10082,0.890632,0.85109,0.966016
7,0.075300,0.092039,0.909216,0.892213,0.969868
8,0.056700,0.086202,0.919435,0.909831,0.972734
9,0.040700,0.081471,0.92629,0.909682,0.975467
10,0.028700,0.081438,0.927258,0.911583,0.975732


# Inference

In [None]:
# init a new wandb run for storing submission artifacts
wandb.init(
    project="DataSolve-2022",
    group=cfg.exp_name,
    name="inference",
    tags=cfg.tags,
    notes=cfg.notes,
    config=process_config_for_wandb(cfg)
)
SUB_OUT_DIR = OUTPUT_DIR/'submission'
os.makedirs(SUB_OUT_DIR, exist_ok=True)

In [None]:
test_df = pd.read_csv(ROOT_DIR/'input'/'test.csv')
sub_df = pd.read_csv(ROOT_DIR/'input'/'sample_submission.csv')
test_df.head()

In [None]:
test_ds = preprocess_data(test_df, mode="inference")
# sort test dataset to have similar length samples in a batch to speed up inference
test_ds = test_ds.sort("length")
test_ds

In [None]:
all_logits = []
for fold in range(5):
    print(f"{'='*10} Inferring FOLD-{fold} {'='*10}")
    model_dir = OUTPUT_DIR/f'fold_{fold}'
    model_config = ModelConfig(**cfg.model)
    model = DataSolveModel(model_config)
    model.load_state_dict(torch.load(model_dir/'pytorch_model.bin'))
    trainer_args = TrainingArguments("./tmp", per_device_eval_batch_size = 16)
    trainer = Trainer(model, trainer_args, data_collator=DataCollatorWithPadding(tokenizer))
    out = trainer.predict(test_ds)
    logits = out.predictions[0] if isinstance(out.predictions, tuple) else out.predictions
    all_logits.append(logits)
    
    del model, trainer; clear_memory();

In [None]:
fin_logits = np.mean(all_logits, axis=0)
# save test logits to a file
test_dict = {"id": test_ds['id'], "logits": fin_logits}
save_pickle(test_dict, SUB_OUT_DIR/f"{cfg.exp_name}_test_logits.pkl")

# Create submission

In [None]:
ids = []
for id_ in test_ds['id']:
    for col in LABEL_COLS:
        ids.append(f"{id_}_{col}")

predictions = post_process_logits(fin_logits, threshold=0.5)
print(predictions.shape)
sub_df['id'] = ids
sub_df['predictions'] = predictions
sub_df.to_csv(SUB_OUT_DIR/f"{cfg.exp_name}_sub.csv", index=False)
sub_df.head()

# Upload artifacts to W&B

In [None]:
infer_artifact = wandb.Artifact(name=cfg.exp_name, type="model")
infer_artifact.add_dir(SUB_OUT_DIR)
wandb.log_artifact(infer_artifact, aliases="submission")
wandb.finish()

# Clean-up

In [None]:
if ENVIRON == "kaggle":
    shutil.rmtree("./tmp", ignore_errors=True)
    shutil.rmtree(ROOT_DIR, ignore_errors=True)