# 🦆 Training

✅ **needs to be checked**
- mixed precision (minor code refactoring done) ✅
- loss function (cross entropy loss and rce loss)  ✅
- collate (is it properly implemented?)
    - assumed the torch padding function would be used, but it is not
- entity_extraction function
- does DeBERTa accept long sequence lengths even though the config sets a maximum sequence length of 512?

✅ **needs to be added**
- training depends on device ✅
- model output to prediction string converter
- f1macro calculation by prediction string
- training mode without gradient accumulation

---

- Environment Setting
- Argument Setting

## Environment Setting

In [1]:
import os
import os.path as osp
import sys

# Root directory of the competition data.
DATASET_PATH = '../../feedback-prize-2021'

# Local project code goes first on the import path; the Longformer sources
# are appended at the end.
sys.path.insert(0, './codes')
sys.path.extend(['longformer/tvm/python/', 'longformer/'])

In [2]:
import re
import random
import easydict
import argparse

from random import shuffle
from tqdm import tqdm
from glob import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import h5py
import ftfy
import dill as pickle
import wandb

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

from torch.cuda.amp import autocast, GradScaler

from transformers import DebertaV2Model

# torch.use_deterministic_algorithms(True)
# from longformer.longformer import Longformer, LongformerConfig, RobertaModel
# from longformer.sliding_chunks import pad_to_window_size

In [3]:
def seed_everything(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    # Environment flags: the hash seed and the cuBLAS workspace configuration.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'CUBLAS_WORKSPACE_CONFIG': ":4096:8",
    })

    # Apply the same seed to every generator used in this notebook.
    for seeder in (random.seed,
                   np.random.seed,
                   torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

**Why set `os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"`?**
- [torch.use_deterministic_algorithms](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html)
> **A handful of CUDA operations are nondeterministic if the CUDA version is 10.2 or greater**, unless the environment variable `CUBLAS_WORKSPACE_CONFIG=:4096:8` or `CUBLAS_WORKSPACE_CONFIG=:16:8` is set. See the CUDA documentation for more details: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility If one of these environment variable configurations is not set, a RuntimeError will be raised from these operations when called with CUDA tensors:



## Argument Setting

In [4]:
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is wrong for argparse: ``bool("False")`` is ``True`` because
    any non-empty string is truthy. This converter accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_config():
    """Build the run configuration namespace.

    The parser is evaluated with an empty argument list (``args=[]``), so only
    the defaults below are used. After parsing, the distributed-training
    fields ``device``, ``rank``, ``ddp`` (and ``world_size`` under DDP) are
    filled in.

    Returns:
        argparse.Namespace holding every option plus the derived fields.
    """
    parser = argparse.ArgumentParser(description="use huggingface models")
    parser.add_argument("--wandb_user", default='ducky', type=str)
    parser.add_argument("--wandb_project", default='feedback_deberta_large', type=str)
    parser.add_argument("--dataset_path", default='../../feedback-prize-2021', type=str)
    parser.add_argument("--save_path", default='result', type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--min_len", default=0, type=int)
    parser.add_argument("--use_groupped_weights", default=False, type=_str2bool)
    # NOTE(review): default is the bool False although the type is int; kept as-is.
    parser.add_argument("--global_attn", default=False, type=int)
    parser.add_argument("--label_smoothing", default= 0.1, type=float)
    parser.add_argument("--epochs", default=9, type=int)
    parser.add_argument("--batch_size", default=4, type=int)
    parser.add_argument("--grad_acc_steps", default=2, type=int)
    parser.add_argument("--grad_checkpt", default=True, type=_str2bool)
    parser.add_argument("--data_prefix", default='', type=str)
    parser.add_argument("--max_grad_norm", default=35 * 8, type=int)
    parser.add_argument("--start_eval_at", default=0, type=int)
    parser.add_argument("--lr", default=3e-5, type=float)
    parser.add_argument("--min_lr", default=3e-5, type=float)
    parser.add_argument("--weight_decay", default=1e-2, type=float)
    parser.add_argument("--weights_pow", default=0.1, type=float)
    parser.add_argument("--dataset_version", default=2, type=int)
    parser.add_argument("--warmup_steps", default=500, type=int)
    parser.add_argument("--decay_bias", default=False, type=_str2bool)
    parser.add_argument("--val_fold", default=0, type=int)
    parser.add_argument("--num_worker", default=8, type=int)
    parser.add_argument("--local_rank", type=int, default=-1, help="do not modify!")
    parser.add_argument("--device", type=int, default=0, help="select the gpu device to train")

    # optimizer
    parser.add_argument("--rce_weight", default=0.1, type=float)
    parser.add_argument("--ce_weight", default=0.9, type=float)

    # model related arguments
    parser.add_argument("--model", default="microsoft/deberta-v3-large", type=str)
    parser.add_argument("--cnn1d", default=False, type=_str2bool)
    parser.add_argument("--extra_dense", default= False, type=_str2bool)
    parser.add_argument("--dropout_ratio", default=0.0, type=float)
    # Empty argv on purpose: only the defaults above are ever used.
    args = parser.parse_args(args=[])

    if args.local_rank != -1:
        print('[ DDP ] local rank', args.local_rank)
        torch.cuda.set_device(args.local_rank)
        # The original called `dist.init_process_group`, but `dist` is never
        # imported in this file; use the fully qualified module instead.
        torch.distributed.init_process_group(backend='nccl')
        args.device = torch.device("cuda", args.local_rank)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

        # checking settings for distributed training
        assert args.batch_size % args.world_size == 0, f'--batch_size {args.batch_size} must be multiple of world size'
        assert torch.cuda.device_count() > args.local_rank, 'insufficient CUDA devices for DDP command'

        args.ddp = True
    else:
        # NOTE(review): the parsed `--device` index is overwritten here.
        args.device = torch.device("cuda")
        args.rank = 0
        args.ddp = False

    return args

## Load data

In [5]:
from module.utils import get_token_weights
from module.utils import get_prepare_data
from module.utils import get_all_texts
from module.utils import get_id_to_ix_map
from module.utils import get_fold_data

In [6]:
def get_data_files(args):
    """Load every data artefact needed for one train/validation fold.

    Args:
        args: Config namespace; reads ``use_groupped_weights``, ``weights_pow``,
            ``dataset_path``, ``seed`` and ``val_fold``.

    Returns:
        Tuple ``(all_texts, token_weights, data, csv, train_ids, val_ids,
        train_text_ids, val_text_ids)``.
    """
    token_weights = get_token_weights(args.use_groupped_weights, args.weights_pow)
    data = get_prepare_data()
    csv = pd.read_csv(osp.join(args.dataset_path, 'train.csv'))
    all_texts = get_all_texts(args)
    id_to_ix_map = get_id_to_ix_map()

    # Fold assignment for this seed; a text_id looks like `16585724607E`.
    folds = get_fold_data()[args.seed][250]['normed']

    # Every fold except the validation one goes to training, in fold order.
    train_text_ids = []
    for fold in range(5):
        if fold == args.val_fold:
            continue
        train_text_ids.extend(folds[fold])
    val_text_ids = folds[args.val_fold]

    def to_rows(text_ids):
        # Translate text ids into row indices.
        return [id_to_ix_map[text_id] for text_id in text_ids]

    train_ids = to_rows(train_text_ids)
    val_ids = to_rows(val_text_ids)

    return all_texts, token_weights, data, csv, train_ids, val_ids, train_text_ids, val_text_ids

## Wandb

In [7]:
def wandb_setting(args):
    """Log in to Weights & Biases, start a run and name it after the hyperparameters.

    Args:
        args: Config namespace supplying the wandb entity/project and the
            hyperparameters that are encoded in the run name.
    """
    wandb.login()
    run = wandb.init(entity=args.wandb_user, project=args.wandb_project)

    # (tag, value) pairs rendered as `tag<value>` and joined with underscores.
    tagged_values = (
        ('fold', args.val_fold),
        ('minlr', args.min_lr),
        ('maxlr', args.lr),
        ('wd', args.weight_decay),
        ('warmup', args.warmup_steps),
        ('gradnorm', args.max_grad_norm),
        ('biasdecay', args.decay_bias),
        ('ls', args.label_smoothing),
        ('wp', args.weights_pow),
        ('data', args.dataset_version),
        ('rce', args.rce_weight),
    )
    run.name = '_'.join(['v3'] + [f'{tag}{value}' for tag, value in tagged_values])

## 🐣 Dataset & Dataloader

In [8]:
class TrainDataset(torch.utils.data.Dataset):
    """Training samples with label smoothing and per-token class weights.

    Args:
        ids: Row indices into the arrays stored in ``data``.
        data: Mapping with ``tokens``, ``attention_masks``, ``num_tokens`` and
            ``{data_prefix}cbio_labels`` entries, each indexable by row.
        label_smoothing: Smoothing factor applied to the label distribution.
        token_weights: Sequence indexable by class id giving each class weight.
        data_prefix: Prefix selecting which label array to read.
    """
    def __init__(self, ids, data, label_smoothing, token_weights, data_prefix):
        self.ids = ids
        self.data = data
        self.label_smoothing = label_smoothing
        self.token_weights = token_weights
        self.data_prefix = data_prefix

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(tokens, attention_mask, cbio_labels, class_weight, num_tokens)``."""
        i = self.ids[idx]

        # load train data
        tokens = self.data['tokens'][i]
        attention_mask = self.data['attention_masks'][i]
        num_tokens = self.data['num_tokens'][i, 0]

        # label smoothing
        # Computed out of place: with in-memory arrays `self.data[...][i]` is a
        # view, and the former in-place `*=` / `+=` rewrote the stored labels,
        # so smoothing compounded every time a sample was fetched.
        raw_labels = self.data[f'{self.data_prefix}cbio_labels'][i]
        num_classes = raw_labels.shape[-1]
        cbio_labels = raw_labels * (1 - self.label_smoothing) + self.label_smoothing / num_classes

        # class weight per token
        # NOTE(review): class_weight inherits attention_mask's dtype; float
        # weights would be truncated if the mask is stored as integers.
        class_weight = np.zeros_like(attention_mask)
        argmax_labels = cbio_labels.argmax(-1)

        for class_i in range(1, num_classes):
            class_weight[argmax_labels == class_i] = self.token_weights[class_i]

        class_none_index = argmax_labels == 0      # 0 is the text that is not entity
        class_none_index[num_tokens - 1:] = False  # special token & padding
        class_weight[class_none_index] = self.token_weights[0]
        class_weight[0] = 0  # the first position never contributes to the loss

        return tokens, attention_mask, cbio_labels, class_weight, num_tokens
    

In [9]:
class ValDataset(torch.utils.data.Dataset):
    """Validation samples with model inputs plus what is needed for span scoring.

    Args:
        ids: Row indices into the arrays stored in ``data``.
        data: Mapping with ``tokens``, ``attention_masks``, ``num_tokens``,
            ``token_offsets`` and ``cbio_labels`` entries, each indexable by row.
        csv: DataFrame with ``id``, ``discourse_type`` and ``predictionstring``
            columns.
        all_texts: Mapping from text id to the raw text.
        val_text_ids: Text ids aligned position-by-position with ``ids``.
        class_names: Class names indexed by class id (index 0 is the non-entity class).
        token_weights: Sequence indexable by class id giving each class weight.
    """
    def __init__(self, ids, data, csv, all_texts, val_text_ids, class_names, token_weights):
        self.ids = ids
        self.data = data
        self.csv = csv
        # Raw string: '\s' in a normal literal is an invalid escape sequence.
        self.space_regex = re.compile(r'[\s\n]')
        self.all_texts = all_texts
        self.val_text_ids = val_text_ids
        self.class_names = class_names
        self.token_weights = token_weights

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(tokens, attention_mask, cbio_labels, class_weight,
        token_bounds, gt_dict, index_map, num_tokens)``."""
        i = self.ids[idx]

        # load text data & text dataframe
        text_id = self.val_text_ids[idx]
        text = self.all_texts[text_id]
        sample_df = self.csv.query('id == @text_id')

        # load ground truth prediction string for f1macro metric
        # NOTE(review): `split_predstring` is not defined or imported in the
        # visible cells; it must be in scope before a sample with annotations
        # is fetched. Its result is presumably a (first word, last word) pair.
        gt_dict = {}
        for class_i in range(1, 8):
            class_name = self.class_names[class_i]
            class_df = sample_df.query('discourse_type == @class_name')
            if len(class_df):
                gt_dict[class_i] = [(x[0], x[1]) for x in class_df.predictionstring.map(split_predstring)]

        # load valid data
        tokens = self.data['tokens'][i]
        attention_mask = self.data['attention_masks'][i]
        num_tokens = self.data['num_tokens'][i, 0]
        token_bounds = self.data['token_offsets'][i]
        cbio_labels = self.data['cbio_labels'][i]

        # class weight per token
        class_weight = np.zeros_like(attention_mask)
        argmax_labels = cbio_labels.argmax(-1)

        for class_i in range(1, 15):
            class_weight[argmax_labels == class_i] = self.token_weights[class_i]

        class_none_index = argmax_labels == 0
        class_none_index[num_tokens - 1:] = False
        class_weight[class_none_index] = self.token_weights[0]
        class_weight[0] = 0

        # Character -> word index map. Starting at the first non-whitespace
        # character, every character gets the zero-based index of its
        # whitespace-separated word; whitespace keeps the preceding word's index.
        # The start offset comes from lstrip so that empty or whitespace-only
        # text yields an empty map instead of an IndexError.
        index_map = []
        current_word = 0
        blank = False
        first_char_ix = len(text) - len(text.lstrip())
        for char_ix in range(first_char_ix, len(text)):
            if self.space_regex.match(text[char_ix]) is not None:
                blank = True
            elif blank:
                current_word += 1
                blank = False
            index_map.append(current_word)

        return tokens, attention_mask, cbio_labels, class_weight, token_bounds, gt_dict, index_map, num_tokens

In [10]:
def get_dataloader(train_ids, val_ids, data, csv, all_texts, val_text_ids, class_names, token_weights, args):
    """Build the training and validation data loaders.

    Args:
        train_ids, val_ids: Row indices for the training / validation split.
        data, csv, all_texts, val_text_ids, class_names, token_weights:
            Forwarded unchanged to ``TrainDataset`` / ``ValDataset``.
        args: Config namespace; reads ``label_smoothing``, ``data_prefix``,
            ``batch_size`` and ``num_worker``.

    Returns:
        Tuple ``(train_dataloader, val_dataloader)``.
    """
    train_dataset = TrainDataset(train_ids, data, args.label_smoothing, token_weights, args.data_prefix)
    val_dataset = ValDataset(val_ids, data, csv, all_texts, val_text_ids, class_names, token_weights)

    num_workers = args.num_worker

    # NOTE(review): no `shuffle=True`, so training batches follow the order of
    # `train_ids` every epoch — confirm this is intentional.
    train_dataloader = DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=args.batch_size, num_workers=num_workers)
    # The validation loader now honours `args.num_worker` (default 8, the
    # previously hard-coded value). `persistent_workers=True` is rejected by
    # DataLoader when there are no workers, so enable it only with workers.
    val_dataloader = DataLoader(val_dataset, collate_fn=val_collate_fn, batch_size=args.batch_size,
                                num_workers=num_workers, persistent_workers=num_workers > 0)

    return train_dataloader, val_dataloader

## Collate

In [11]:
# Flipped to False once the first batch has been collated in this process.
first_batch = True
def train_collate_fn(ins):
    """Stack training samples into tensors, cropping the sequence axis.

    The first batch collated in a process is cropped to 2048 positions; later
    batches are cropped to the longest ``num_tokens`` in the batch rounded up
    to a multiple of 8.

    Args:
        ins: List of samples whose last element is ``num_tokens`` and whose
            other elements are NumPy arrays with the sequence on axis 0.

    Returns:
        Tuple of tensors, one per array field (``num_tokens`` is dropped).
    """
    global first_batch
    if not first_batch:
        longest = max(sample[-1] for sample in ins)
        max_len = (longest + 7) // 8 * 8
    else:
        max_len = 2048
        first_batch = False

    batch = []
    for field in range(len(ins[0]) - 1):
        stacked = np.stack([sample[field][:max_len] for sample in ins])
        batch.append(torch.from_numpy(stacked))
    return tuple(batch)
    
def val_collate_fn(ins):
    """Stack validation samples, keeping the last three fields un-tensorised.

    Array fields are cropped to the longest ``num_tokens`` in the batch rounded
    up to a multiple of 8.

    Args:
        ins: List of samples; the last three elements of each are passed
            through as-is and the earlier elements are NumPy arrays.

    Returns:
        Tuple of stacked tensors followed by a list of the third-from-last
        fields, a list of the second-to-last fields and a NumPy array of the
        last fields.
    """
    longest = max(sample[-1] for sample in ins)
    max_len = (longest + 7) // 8 * 8

    tensors = tuple(
        torch.from_numpy(np.stack([sample[field][:max_len] for sample in ins]))
        for field in range(len(ins[0]) - 3)
    )
    third_last = [sample[-3] for sample in ins]
    second_last = [sample[-2] for sample in ins]
    last = np.array([sample[-1] for sample in ins])

    return tensors + (third_last, second_last, last)

## Loss

In [12]:
def custom_ce(preds, gts, class_weight):
    """Weighted soft-label cross entropy.

    Args:
        preds: Logits with the class axis last.
        gts: Target distribution with the same shape as ``preds``.
        class_weight: Per-position weights (``preds`` shape without the class axis).

    Returns:
        Scalar loss: the weighted mean over the last remaining axis, averaged
        over everything else.
    """
    log_probs = torch.log_softmax(preds, -1)
    per_position = (log_probs * gts).sum(-1)
    per_sequence = (per_position * class_weight).sum(-1) / class_weight.sum(-1)

    return -per_sequence.mean()

def custom_rce(preds, gts, class_weight):
    """Weighted reverse cross entropy.

    The predicted probabilities are multiplied by ``log_softmax`` of the
    targets (the targets themselves are passed through ``log_softmax``).

    Args:
        preds: Logits with the class axis last.
        gts: Targets with the same shape as ``preds``.
        class_weight: Per-position weights (``preds`` shape without the class axis).

    Returns:
        Scalar loss: the weighted mean over the last remaining axis, averaged
        over everything else.
    """
    probs = torch.exp(torch.log_softmax(preds, -1))
    log_targets = torch.log_softmax(gts, -1)
    per_position = (probs * log_targets).sum(-1)
    per_sequence = (per_position * class_weight).sum(-1) / class_weight.sum(-1)

    return -per_sequence.mean()

class Criterion():
    """Wrapper that sums several loss terms with per-term ratios.

    ``args.criterion_list`` names the terms (``"crossentropy"``,
    ``"custom_ce"`` or ``"custom_rce"``) and ``args.criterion_ratio`` gives the
    matching multipliers. ``"crossentropy"`` also reads ``args.label_smoothing``
    and, when present, ``args.class_weight``.

    Raises:
        ValueError: on an unknown criterion name, or when the name and ratio
            lists differ in length.
    """
    def __init__(self, args):
        self.args = args
        self.criterion_names = self.args.criterion_list
        self.criterion_ratios = args.criterion_ratio
        # zip() in calculate_loss would silently drop unmatched entries.
        if len(self.criterion_names) != len(self.criterion_ratios):
            raise ValueError(
                f'criterion_list has {len(self.criterion_names)} entries but '
                f'criterion_ratio has {len(self.criterion_ratios)}')
        self.criterions = self.get_criterions()

    def get_criterions(self):
        """Return one callable per name in ``args.criterion_list``, in order."""
        criterions = []
        for criterion in self.args.criterion_list:
            if criterion == "crossentropy":
                criterions.append(nn.CrossEntropyLoss(weight=getattr(self.args, 'class_weight', None),
                                                      label_smoothing=self.args.label_smoothing))
            elif criterion == "custom_ce":
                criterions.append(custom_ce)
            elif criterion == "custom_rce":
                criterions.append(custom_rce)
            else:
                # Skipping an unknown name would shift every later criterion
                # onto the wrong name and ratio.
                raise ValueError(f'unknown criterion: {criterion!r}')

        return criterions

    def reshape(self, preds, gts):
        """Flatten logits to (N, classes) and soft labels to hard class ids (N,).

        TODO: fix the dataloader to argmax label version and change this code
        """
        # Class count comes from the logits; reshape (unlike view) also
        # accepts non-contiguous tensors.
        return preds.reshape(-1, preds.size(-1)), gts.argmax(-1).reshape(-1)

    def calculate_loss(self, preds, gts, class_weight=None):
        """Return the ratio-weighted sum of all configured loss terms.

        ``class_weight`` is forwarded only to the custom losses.
        """
        total_loss = 0
        for criterion_name, criterion, ratio in zip(self.criterion_names, self.criterions, self.criterion_ratios):
            if criterion_name in ['custom_ce', 'custom_rce']:
                current_loss = criterion(preds, gts, class_weight)
            else:
                # TODO: tailored for current dataloader
                preds_, gts_ = self.reshape(preds, gts)
                current_loss = criterion(preds_, gts_)
            total_loss = total_loss + current_loss * ratio

        return total_loss

    def __call__(self, preds, gts, class_weight=None):
        return self.calculate_loss(preds, gts, class_weight)



def get_criterion(args):
    """Return the ``Criterion`` built from ``args``."""
    return Criterion(args)

## 🧣 Model

In [13]:
class DebertaV3Large(torch.nn.Module):
    """Token classifier on top of the microsoft/deberta-v3-large encoder.

    Per-token encoder states, optionally passed through three parallel 1-D
    convolutions (kernel sizes 1, 3 and 5), are projected to 15 logits.

    Args:
        args: Config namespace; reads ``grad_checkpt``, ``cnn1d`` and
            ``extra_dense``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Pretrained backbone with its pooler attribute cleared.
        self.feats = DebertaV2Model.from_pretrained('microsoft/deberta-v3-large')
        self.feats.pooler = None
        if args.grad_checkpt:
            self.feats.gradient_checkpointing_enable()

        hidden = 1024
        if args.cnn1d:
            # Length-preserving convolutions: padding == kernel_size // 2.
            for kernel in (1, 3, 5):
                conv = nn.Conv1d(hidden, hidden, kernel_size=kernel, padding=kernel // 2)
                setattr(self, f'conv1d_layer{kernel}', conv)
            self.output_length = hidden * 3
        else:
            self.output_length = hidden

        # LayerNorm followed by either a single linear layer or a small MLP.
        head = [nn.LayerNorm(self.output_length)]
        if args.extra_dense:
            head += [nn.Linear(self.output_length, 256), nn.GELU(), nn.Linear(256, 15)]
        else:
            head.append(nn.Linear(self.output_length, 15))
        self.class_projector = nn.Sequential(*head)

    def forward(self, tokens, mask):
        """Return per-token logits for ``tokens`` under attention ``mask``."""
        hidden_states = self.feats(tokens, mask, return_dict=False)[0]

        if not self.args.cnn1d:
            return self.class_projector(hidden_states)

        # Conv1d expects (batch, hidden, seq); transpose back afterwards.
        channels_first = hidden_states.transpose(1, 2)
        branches = [
            F.relu(conv(channels_first))
            for conv in (self.conv1d_layer1, self.conv1d_layer3, self.conv1d_layer5)
        ]
        merged = torch.cat(branches, dim=1).transpose(1, 2)
        return self.class_projector(merged)

In [14]:
def get_model(args):
    """Build the model named by ``args.model`` on ``args.device``.

    Every dropout layer is set to ``args.dropout_ratio``; under DDP the model
    is wrapped in ``DistributedDataParallel``.

    Args:
        args: Config namespace; reads ``model``, ``device``, ``dropout_ratio``,
            ``ddp`` and, under DDP, ``local_rank``.

    Returns:
        The model, possibly wrapped in ``DistributedDataParallel``.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    # Fail early: previously an unknown name left `model` unbound and
    # surfaced as an UnboundLocalError further down.
    if args.model != 'microsoft/deberta-v3-large':
        raise ValueError(f'unsupported model: {args.model!r}')

    model = DebertaV3Large(args).to(args.device)

    # dropout layer
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.p = args.dropout_ratio
        elif type(m).__name__ == 'StableDropout':
            # DeBERTa's own dropout is not an nn.Dropout subclass and keeps
            # its rate in `drop_prob`; without this branch the override
            # never reached the encoder.
            m.drop_prob = args.dropout_ratio

    # distributed training
    if args.ddp:
        # Local import: `DDP` is not imported anywhere in this file.
        from torch.nn.parallel import DistributedDataParallel as DDP

        # device_ids must be the GPU index on this node (local rank), not the
        # global rank.
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    return model


## 🚵🏻 Main

In [15]:
%load_ext autoreload
%autoreload 2

In [16]:
from module.metric import calc_acc, process_sample, make_match_dict

from module.utils import get_data_files
from module.dataset import get_dataloader
from module.loss import get_criterion
from module.optimizer import get_optimizer
from module.scheduler import get_scheduler
from model.model import get_model

In [17]:
# Seed all RNGs for this run.
# NOTE(review): this seed (42) differs from args.seed (default 0), which
# get_data_files uses to choose the fold split — confirm that is intended.
seed_everything(42)
args = get_config()
# wandb_setting(args)

# Class names indexed by class id; index 0 is the non-entity class.
class_names = ['None',
               'Lead',
               'Position',
               'Evidence',
               'Claim',
               'Concluding Statement',
               'Counterclaim',
               'Rebuttal']

# create directory to save model (exist_ok avoids the check-then-create race)
os.makedirs(args.save_path, exist_ok=True)

# Override the default batch size for this notebook session.
args.batch_size = 2

In [18]:
# Replace the device set by get_config() with GPU index 2; the parameter
# dumps further down report device='cuda:2'.
args.device = 2

In [19]:
# Build the model. If cells ran top to bottom, `get_model` is the version
# imported from `model.model` above, not the one defined in this notebook.
args.model = 'microsoft/deberta-v3-large'
model = get_model(args)

Some weights of the model checkpoint at microsoft/deberta-v3-large were not used when initializing DebertaV2Model: ['mask_predictions.classifier.bias', 'lm_predictions.lm_head.bias', 'mask_predictions.LayerNorm.bias', 'mask_predictions.classifier.weight', 'mask_predictions.LayerNorm.weight', 'mask_predictions.dense.weight', 'lm_predictions.lm_head.dense.weight', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.dense.bias', 'lm_predictions.lm_head.dense.bias']
- This IS expected if you are initializing DebertaV2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
# Load the texts, class weights, tokenised data, CSV and the fold split.
all_texts, token_weights, data, csv, train_ids, val_ids, train_text_ids, val_text_ids = get_data_files(args)
# NOTE(review): `args` is passed first here, but the `get_dataloader` defined
# earlier in this notebook takes `args` last. This call relies on the version
# imported from `module.dataset` — confirm its signature.
train_dataloader, val_dataloader = get_dataloader(args, train_ids, val_ids, data, csv, all_texts, val_text_ids, class_names, token_weights)

# Loss: custom cross entropy plus reverse cross entropy, weighted by
# ce_weight and rce_weight.
args.criterion_list = ["custom_ce", "custom_rce"]
args.criterion_ratio = [args.ce_weight, args.rce_weight]
criterion = get_criterion(args)

# Optimizer choice is read by get_optimizer (not visible here).
args.optimizer = "adam"
optimizer = get_optimizer(args, model)

# scheduler
# steps_per_epoch is the number of batches in the training loader.
args.steps_per_epoch = len(train_dataloader)
args.scheduler = "plateau"
scheduler = get_scheduler(args, optimizer)

In [21]:
# Switch the model to training mode; as the last expression of the cell its
# return value (the module) is echoed below.
model.train()

DebertaV3Large(
  (feats): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 1024, padding_idx=0)
      (LayerNorm): LayerNorm((1024,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0): DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (key_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (value_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-07, elementwise_affine=True)
            

## SWA (Stochastic Weight Averaging)

In [23]:
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

# Averaged copy of the model; refreshed through update_parameters() in the
# training loop below.
swa_model = AveragedModel(model)

In [24]:
# Number of batches per epoch (same assignment as in the setup cell above).
args.steps_per_epoch = len(train_dataloader)

In [25]:
# Display the value.
args.steps_per_epoch

6227

In [41]:
# Print only the first (name, parameter) pair of the live model.
for param_name, param in model.named_parameters():
    print((param_name, param))
    break

('feats.embeddings.word_embeddings.weight', Parameter containing:
tensor([[-0.0481, -0.0435, -0.0472,  ..., -0.0504, -0.0499, -0.0345],
        [-0.0016, -0.0074,  0.0130,  ..., -0.0136, -0.0066, -0.0036],
        [-0.0035, -0.0076,  0.0116,  ..., -0.0137, -0.0063, -0.0046],
        ...,
        [-0.0476, -0.0571, -0.0349,  ..., -0.0569, -0.0481, -0.0336],
        [-0.0425, -0.0474, -0.0440,  ..., -0.0526, -0.0543, -0.0435],
        [-0.0376, -0.0404, -0.0396,  ..., -0.0565, -0.0464, -0.0312]],
       device='cuda:2', requires_grad=True))


In [46]:
# Print only the first (name, parameter) pair of the averaged model.
for param_name, param in swa_model.named_parameters():
    print((param_name, param))
    break

('module.feats.embeddings.word_embeddings.weight', Parameter containing:
tensor([[-0.0493, -0.0447, -0.0484,  ..., -0.0516, -0.0511, -0.0357],
        [-0.0016, -0.0078,  0.0130,  ..., -0.0137, -0.0064, -0.0036],
        [-0.0034, -0.0079,  0.0116,  ..., -0.0139, -0.0061, -0.0046],
        ...,
        [-0.0488, -0.0583, -0.0361,  ..., -0.0581, -0.0493, -0.0347],
        [-0.0437, -0.0486, -0.0452,  ..., -0.0538, -0.0555, -0.0447],
        [-0.0388, -0.0416, -0.0408,  ..., -0.0577, -0.0476, -0.0324]],
       device='cuda:2', requires_grad=True))


In [37]:
# Same peek at the averaged model's first parameter (cell run at another time).
for param_name, param in swa_model.named_parameters():
    print((param_name, param))
    break

('module.feats.embeddings.word_embeddings.weight', Parameter containing:
tensor([[-0.0502, -0.0456, -0.0493,  ..., -0.0525, -0.0520, -0.0366],
        [-0.0016, -0.0079,  0.0130,  ..., -0.0137, -0.0063, -0.0036],
        [-0.0034, -0.0080,  0.0117,  ..., -0.0141, -0.0059, -0.0046],
        ...,
        [-0.0497, -0.0593, -0.0371,  ..., -0.0591, -0.0503, -0.0357],
        [-0.0446, -0.0495, -0.0461,  ..., -0.0548, -0.0565, -0.0456],
        [-0.0397, -0.0425, -0.0417,  ..., -0.0586, -0.0485, -0.0334]],
       device='cuda:2', requires_grad=True))


In [30]:
# Cosine schedule used before SWA starts: one half cycle lasts
# `step_half_cycle` optimizer steps.
step_half_cycle = 5
scheduler = CosineAnnealingLR(optimizer, T_max=step_half_cycle, eta_min=1e-7)
# SWA begins after three half cycles.
swa_start = step_half_cycle * 3
# Named `swa_save_per_step` to match the variable the training loop reads
# (this was previously misspelled `swa_save_per_steps`).
swa_save_per_step = args.steps_per_epoch // 3

# SWA is normally applied per epoch, but here it is applied per step:
# the model is slow and heavy, so many epochs are not affordable.
swa_scheduler = SWALR(optimizer, swa_lr=1e-5, anneal_epochs=args.steps_per_epoch, anneal_strategy='cos')

In [32]:
# Small values for a quick experiment: SWA starts after step 15 and the
# averaged weights are refreshed every 2 steps.
swa_start = 15
swa_save_per_step = 2

In [None]:
# Plain training loop (no mixed precision, no gradient accumulation) that
# switches from the cosine scheduler to the SWA scheduler after `swa_start`.
# `outs`, `label` and `class_weight` stay defined afterwards and are reused
# by the loss-check cells below.
for step, batch in enumerate(train_dataloader):
    print(f'current step is {step}')
    # The train collate function yields four tensors per batch.
    tokens, mask, label, class_weight = (x.to(args.device) for x in batch)
    
    optimizer.zero_grad()
    outs = model(tokens, mask)
    loss = criterion(outs, label, class_weight=class_weight)
    loss.backward()
    optimizer.step()
    
    if (step + 1) > swa_start:
        # In the SWA phase, fold the current weights into the running average
        # every `swa_save_per_step` steps.
        if (step + 1) % swa_save_per_step == 0:
            print('working on swa!!!')
            swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()
        
    # Log the learning rate of every parameter group after the scheduler step.
    for param_group in optimizer.param_groups:
        print('lr', param_group['lr'])

In [None]:
# Update bn statistics for the swa_model at the end
# NOTE(review): `loader` is not defined in any visible cell — presumably
# `train_dataloader` was meant; confirm before running.
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
# NOTE(review): `test_input` is also undefined here, and the model's forward
# takes both tokens and a mask.
preds = swa_model(test_input)

# Test Laboratory

In [None]:
STOP!!!! YOU SHALL NOT PASS!!!

## scheduler

In [None]:
args.min_lr = 1e-07

# Linear warmup from 0 to args.lr, then a cosine decay from args.lr down to
# args.min_lr over the remaining steps.
total_steps = len(train_dataloader) * args.epochs
warmup_part = np.linspace(0, args.lr, args.warmup_steps)
phase = np.linspace(0, np.pi, total_steps - args.warmup_steps)
cosine_part = (np.cos(phase) * .5 + .5) * (args.lr - args.min_lr) + args.min_lr
lr_schedule = np.concatenate([warmup_part, cosine_part])

plt.plot(np.arange(len(lr_schedule)), lr_schedule)
plt.show()

## cross entropy check

In [None]:
# Swap the custom CE for nn.CrossEntropyLoss to compare the two.
args.criterion_list = ["crossentropy", "custom_rce"]
args.criterion_ratio = [args.ce_weight, args.rce_weight]
# Per-class weights for CrossEntropyLoss, moved to the device in half precision.
args.class_weight = torch.Tensor(token_weights).to(args.device).half()

criterion = get_criterion(args)            

In [None]:
# Re-evaluate the loss on the last batch left over from the training loop.
loss = criterion(outs, label, class_weight=class_weight)

In [None]:
# Display the loss value.
loss

In [None]:
# Display the loss value again (for comparison between criterion setups).
loss