In [263]:
! pip install -U scikit-learn
! pip install wandb
! pip install tqdm
! pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
! pip install -U git+https://github.com/huggingface/transformers.git
! pip install -U git+https://github.com/huggingface/accelerate.git

Looking in indexes: https://download.pytorch.org/whl/cu117
Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-fjri10qj
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-fjri10qj
  Resolved https://github.com/huggingface/transformers.git to commit 9dc965bb404c2bb8e3c02eaa5eea6502af1aee1a
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting git+https://github.com/huggingface/accelerate.git
  Cloning https://github.com/huggingface/accelerate.git to /tmp/pip-req-build-oz0zl5aj
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/accelerate.git /tmp/pip-req-build-oz0zl5aj
  Resolved https://github.com/huggingface/accelerate.git to commit 653ba110d31c86d3527bb88bf6209441c176ce1

In [264]:
import pandas as pd
import numpy as np
import os
import gc
import random
import time
from tqdm import tqdm, trange

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# set a seed value
torch.manual_seed(42)

from datasets import load_dataset

import wandb

import transformers
from transformers import TrainingArguments, Trainer
from transformers import AdamW, EarlyStoppingCallback
from transformers import PreTrainedModel, PretrainedConfig
from transformers import XLMRobertaModel, XLMRobertaForSequenceClassification, XLMRobertaTokenizer
from huggingface_hub import login

In [265]:
TOKENIZER_TYPE = 'xlm-roberta-base'
MBERT_TYPE = 'xlm-roberta-base'
MODEL_TEACHER_TYPE = 'jalaluddin94/xlmr-nli-indoindo'
MODEL_PATH = '/kaggle/working/ResearchedModels/'
HF_MODEL_NAME = 'jalaluddin94/trf-learning-indojavanesenli-xlmr'

STUDENT_LRATE = 2e-5
LAMBDA_KLD = 0.5 # between 0.01 - 0.5
MAX_LEN = 512
NUM_EPOCHS = 5
BATCH_SIZE = 1
BATCH_NORM_EPSILON = 1e-5
LAMBDA_L2 = 3e-5

HF_TOKEN = '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'

NUM_CORES = os.cpu_count() - 2

In [None]:
login(token=HF_TOKEN)

In [266]:
# %env WANDB_API_KEY=97b170d223eb55f86fe1fbf9640831ad76381a74
# wandb.login()

In [267]:
# %env WANDB_LOG_MODEL='end'
# run = wandb.init(
#   project="javanese_nli",
#   notes="Experiment transfer learning on Bandyopadhyay's paper using XLMR",
#   name="trf-lrn-experiment-xlmr-epoch5-lamdakld0.5",
#   tags=["transferlearning", "bandyopadhyay", "xlmr"]
# )

In [268]:
os.environ["WANDB_AGENT_MAX_INITIAL_FAILURES"]="1024"
os.environ["WANDB_AGENT_DISABLE_FLAPPING"]="true"

In [269]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


## Data Preparation

Prepare Dataset for Student

In [270]:
df_train = pd.read_csv("/kaggle/input/dataset-indojavanesenli/indojavanesenli-train.csv", sep='\t')
df_train = df_train.sample(frac=1).reset_index(drop=True) #shuffle the data

df_train_student = pd.DataFrame()
df_train_student["premise"] = df_train["premise"]
df_train_student["hypothesis"] = df_train["jv_hypothesis_mongo"]
df_train_student["label"] = df_train["label"]
df_train_student.head()

Unnamed: 0,premise,hypothesis,label
0,"Esai ini, yang diterbitkan sebagai Undersea, m...",esai iki ngrupakne narasi babagan dalan neng d...,0
1,"""Pada tahun 2001, Komite Olimpiade Internasion...",semenjak 2001 tuan surup olimpiade uga dadi tu...,0
2,Hargeisa adalah kota terbesar kedua di Somalia...,hargesia dhekea kutha paling gedhe neng somali...,0
3,Tentunya Tiongkok akan menyajikan banyak peran...,durung ana rencana saka tiongkok kanggo unjuk ...,2
4,Leher nya bergerak maju dalam konfrontasi yang...,dheweke keweden.,2


In [271]:
df_valid = pd.read_csv("/kaggle/input/dataset-indojavanesenli/indojavanesenli-valid.csv", sep='\t')
df_valid = df_valid.sample(frac=1).reset_index(drop=True) #shuffle the data

df_valid_student = pd.DataFrame()
df_valid_student["premise"] = df_valid["premise"]
df_valid_student["hypothesis"] = df_valid["jv_hypothesis_mongo"]
df_valid_student["label"] = df_valid["label"]
df_valid_student.head()

Unnamed: 0,premise,hypothesis,label
0,Beliau adalah orang yang paling baik akhlaknya...,wong kuwi nduweni akhlak becik.,0
1,"Liga eSport Amerika Serikat (AS), Collegiate S...",turnamen tiktok cup arep dianakne ing udhar 20...,1
2,Ibu tiga anak ini juga membenarkan jika bagian...,kecelakaan tol entas kedadean.,2
3,Sekitar 23 juta orang di pesisir Indonesia dip...,taun 2050 diprediksi arep dadi taun kebecikan ...,2
4,Brivio yang juga pernah bekerja sama dengan Ro...,rossi nduweni kabisan adaptasi sing jaba biyasa.,0


In [272]:
df_test = pd.read_csv("/kaggle/input/dataset-indojavanesenli/indojavanesenli-test.csv", sep='\t')
df_test = df_test.sample(frac=1).reset_index(drop=True) #shuffle the data

df_test_student = pd.DataFrame()
df_test_student["premise"] = df_test["premise"]
df_test_student["premise"] = df_test_student["premise"].astype(str)
df_test_student["hypothesis"] = df_test["jv_hypothesis_mongo"]
df_test_student["hypothesis"] = df_test_student["hypothesis"].astype(str)
df_test_student["label"] = df_test["label"]
df_test_student.head()

Unnamed: 0,premise,hypothesis,label
0,Sebagian besar pembicara menganggap ini menjad...,sakanggonan gedhe pangomong nganggep iki dadi ...,0
1,Middlesbrough dan Bournemouth juga gagal mempe...,bournemouth tau rumangsakne kemumpangan.,1
2,GERD (Gastroesophageal Reflux Disease) adalah ...,gerd marakake alangan pernapasan.,0
3,Kalau dengan posisi ini suami mampu menahan ej...,laki ora bisa nglakoke hubungan seksual karo p...,2
4,Uni Soviet adalah negara sosialis yang pernah ...,uni soviet yaiku nagara komunis.,2


Prepare Dataset for Teacher

Dataset from teacher will be from "IndoNLI", and using Indonesian only.

In [273]:
df_train_t = pd.DataFrame()
df_train_t["premise"] = df_train["premise"]
df_train_t["hypothesis"] = df_train["hypothesis"]
df_train_t["label"] = df_train["label"]
df_train_t = df_train_t.sample(frac=1).reset_index(drop=True)
display(df_train_t)

Unnamed: 0,premise,hypothesis,label
0,"Jonan menyampaikan pernyataan itu, menanggapi ...",Tidak ada perubahan status kontrak karya ke iz...,2
1,"Pada awal tahun 2006, mantan koordinator Kampa...",Ryaas Rasyid adalah ekonom.,2
2,Vinales menjadi yang tercepat sepanjang dua ha...,"Pada tes hari kedua, catatan waktu Vinales di ...",2
3,Invasi Irak ke Kuwait disebabkan oleh kemeroso...,Irak banyak mengalami kesulitan.,1
4,"Dengan saran dari Alexander, ia pun bisa berko...",Ia tidak mempertimbangkan anjuran dari Alexander.,2
...,...,...,...
10325,Beragam penduduk asli mendiami Alaska selama r...,Orang Eropa telah tinggal ribuan tahun di daer...,2
10326,Selera Tiongkok yang tak pernah terpuaskan ter...,Produk KFC tidak sesuai dengan selera Tiongkok.,2
10327,"Pada tahun 1271, setelah sebulan pertempuran, ...",Baibar adalah seorang sultan.,1
10328,Malaysia mampu membuka skor pada menit ke-11 m...,Malaysia belum mendapat skor pada menit ke-30 ...,2


In [274]:
print("Count per class train:") 
print(df_train_t['label'].value_counts())

Count per class train:
0    3476
2    3439
1    3415
Name: label, dtype: int64


In [275]:
df_valid_t = pd.DataFrame()
df_valid_t["premise"] = df_valid["premise"]
df_valid_t["hypothesis"] = df_valid["hypothesis"]
df_valid_t["label"] = df_valid["label"]
df_valid_t = df_valid_t.sample(frac=1).reset_index(drop=True)
display(df_valid_t)

Unnamed: 0,premise,hypothesis,label
0,Lari sambung atau lari estafet adalah salah sa...,Lari estafet dilaksanakan dengan minimal 3 orang.,1
1,"Bagi Anda yang ingin melakukan wisata edukasi,...",Banyak pilihahn wisata Edukasi di Bogor.,1
2,"Pada 1865, Kapal Uap Sultana yang mengangkut 2...",Kapal Uap Sultana hanya beroperasi pada tahun ...,1
3,Saya menulis hal ini kini untuk memberitahu An...,Saya tidak mengingatkan diri saya.,2
4,"Selama perang kemerdekaan RI dari 1945-1949, h...",Pejabat Belanda NICA-KNIL sangat banyak.,1
...,...,...,...
2192,Kemudian kedua orang tua itu mencoba mengubah ...,Ia tidak mempunyai 2 orang tua.,2
2193,Bangunan ini digunakan untuk penjualan berbaga...,Pemilik dari bangunan ini adalah penguasa wila...,1
2194,"""Biarlah masyarakat bahasa memiliki kebebasan ...",Masyarakat tidak bebas untuk memilih menurut P...,2
2195,Zeltweg adalah kota yang terletak di Aichfeld ...,Zeltweg berada pada ketinggian 659 m.,0


In [276]:
print("Count per class valid:") 
print(df_valid_t['label'].value_counts())

Count per class valid:
0    807
2    749
1    641
Name: label, dtype: int64


In [277]:
df_test_t = pd.DataFrame()
df_test_t["premise"] = df_test["premise"]
df_test_t["hypothesis"] = df_test["hypothesis"]
df_test_t["label"] = df_test["label"]
df_test_t = df_test_t.sample(frac=1).reset_index(drop=True)
display(df_test_t)

Unnamed: 0,premise,hypothesis,label
0,Santa Fe adalah sebuah kotamadya di Vega de Gr...,Sungai Genil tidak mengaliri Granada.,2
1,"Pameran bertajuk ""Titanic—The Promise of Moder...","""Titanic-The Promise of Modernity"" adalah pame...",0
2,"Misalnya, dengan dicantumkannya Hak Asasi Manu...",Hak Asasi Manusia (HAM) sangat penting untuk o...,1
3,Konser bertajuk Cross Genre Music ini adalah g...,Terdapat 2 grup musisi untuk genre pop pada ko...,1
4,Pakaian formal yang dikenakan pejabat sipil (b...,Ketiak pejabat sipil bau.,1
...,...,...,...
2196,Mariah segera memeriksakan dirinya ke rumah sa...,Mariah memeriksakan dirinya ke rumah sakit.,0
2197,"Selama curah hujan, tetesan air menyerap dan m...",Tetesan air yang terserap ke dalam tanah berba...,1
2198,Berpetualang bersama teman di Gunung Batu Jong...,Banyak kegiatan asyik selain berpetualang.,1
2199,Purwokerto Selatan adalah sebuah kecamatan di ...,Purwokerto Utara adalah sebuah kecamatan di Ka...,1


In [278]:
print("Count per class test:") 
print(df_test_t['label'].value_counts())

Count per class test:
0    808
2    764
1    629
Name: label, dtype: int64


## Preprocessing

In [279]:
tokenizer = XLMRobertaTokenizer.from_pretrained(TOKENIZER_TYPE)

In [280]:
class CompDataset(Dataset):
    def __init__(self, df_teacher, df_student):
        self.df_data_teacher = df_teacher
        self.df_data_student = df_student
        
    def __getitem__(self, index):
        # Teacher
        sentence_teacher_1 = self.df_data_teacher.loc[index, 'premise']
        sentence_teacher_2 = self.df_data_teacher.loc[index, 'hypothesis']
        
        encoded_dict_teacher = tokenizer.encode_plus(
            sentence_teacher_1,
            sentence_teacher_2,
            add_special_tokens = True,
            max_length = MAX_LEN,
            truncation='longest_first',
            padding = 'max_length',
            return_attention_mask = True,
            return_tensors = 'pt'
        )
        
        padded_token_list_teacher = encoded_dict_teacher['input_ids'][0]
        att_mask_teacher = encoded_dict_teacher['attention_mask'][0]
        
        target_teacher = torch.tensor([self.df_data_teacher.loc[index, 'label']])
        lt_target_teacher = torch.LongTensor(target_teacher)
        onehot_encoded_lbl_teacher = F.one_hot(lt_target_teacher, num_classes=3) # 3 classes: entails, neutral, contradict
        
        # Student
        sentence_student_1 = self.df_data_student.loc[index, 'premise']
        sentence_student_2 = self.df_data_student.loc[index, 'hypothesis']
        
        encoded_dict_student = tokenizer.encode_plus(
            sentence_student_1,
            sentence_student_2,
            add_special_tokens = True,
            max_length = MAX_LEN,
            truncation='longest_first',
            padding = 'max_length',
            return_attention_mask = True,
            return_tensors = 'pt'
        )
        
        padded_token_list_student = encoded_dict_student['input_ids'][0]
        att_mask_student = encoded_dict_student['attention_mask'][0]
        
        target_student = torch.tensor([self.df_data_student.loc[index, 'label']])
        lt_target_student = torch.LongTensor(target_student)
        onehot_encoded_lbl_student = F.one_hot(lt_target_student, num_classes=3) # 3 classes: entails, neutral, contradict
        
        output = {
            "input_ids_teacher": padded_token_list_teacher, 
            "attention_mask_teacher": att_mask_teacher,
            "lbl_teacher": onehot_encoded_lbl_teacher,
            "input_ids_student": padded_token_list_student, 
            "attention_mask_student": att_mask_student,
            "lbl_student": onehot_encoded_lbl_student
        }
        
        return output
    
    def __len__(self):
        return len(self.df_data_teacher)

In [281]:
train_data_cmp = CompDataset(df_train_t, df_train_student)
valid_data_cmp = CompDataset(df_valid_t, df_valid_student)
test_data_cmp = CompDataset(df_test_t, df_test_student)

In [282]:
train_dataloader = DataLoader(train_data_cmp, batch_size = BATCH_SIZE)
valid_dataloader = DataLoader(valid_data_cmp, batch_size = BATCH_SIZE)
test_dataloader = DataLoader(test_data_cmp, batch_size = BATCH_SIZE)

## Model

Transfer Learning model as per Bandyopadhyay, D., et al (2022) paper, but using XLMR instead of mBERT

In [283]:
class TransferLearningPaper(PreTrainedModel):
    def __init__(self, config, lambda_kld, learningrate_student, batchnorm_epsilon = 1e-5):
        super(TransferLearningPaper, self).__init__(config)
        
        self.xlmr_model_teacher = XLMRobertaModel.from_pretrained(
            MODEL_TEACHER_TYPE, # using pretrained mBERT in INA language
            num_labels = 3,
            output_hidden_states=True
        )
        
        # Freeze teacher mBERT parameters
        for params_teacher in self.xlmr_model_teacher.parameters():
            params_teacher.requires_grad = False
    
        self.xlmr_model_student = XLMRobertaModel.from_pretrained(
            MBERT_TYPE,
            num_labels = 3,
            output_hidden_states=True
        )
        
        # Unfreeze student mBERT parameters
        for params_student in self.xlmr_model_student.parameters():
            params_student.requires_grad = True
        
        self.optimizer_student = AdamW(
            self.xlmr_model_student.parameters(), 
            lr=learningrate_student
        )
        
        self.linear = nn.Linear(config.hidden_size, 3)  # Linear layer
        self.batchnorm = nn.BatchNorm1d(config.hidden_size, eps=batchnorm_epsilon)
        self.softmax = nn.Softmax(dim=1)  # Softmax activation
        
        self.cross_entropy = nn.CrossEntropyLoss()
        self.kld = nn.KLDivLoss(reduction='batchmean')
        
        # Initialize the weights of the linear layer
        self.linear.weight.data.normal_(mean=0.0, std=0.02)
        self.linear.bias.data.zero_()
        
        self.lambda_kld = lambda_kld
    
    def forward(self, input_ids_teacher, attention_mask_teacher, lbl_teacher, input_ids_student, attention_mask_student, lbl_student):
        # the label is already one-hot encoded 
        self.xlmr_model_teacher.eval()
        self.xlmr_model_student.eval()
        
        lbl_teacher = lbl_teacher[:, 0, :]
        lbl_student = lbl_student[:, 0, :]
        
        with torch.no_grad():
            # Taking CLS token out of XLMR last hidden state
            outputs_teacher = self.xlmr_model_teacher(
                input_ids=input_ids_teacher, 
                attention_mask=attention_mask_teacher #, 
                #labels=lbl_teacher
            )
        
            # take CLS token of the last hidden state
            pooled_output_teacher = outputs_teacher.last_hidden_state[:, 0, :]
        
        # taking CLS token out of the student data without deleting the gradient
        outputs_student = self.xlmr_model_student(
            input_ids=input_ids_student, 
            attention_mask=attention_mask_student #, 
            #labels=lbl_student
        )
        
        pooled_output_student = outputs_student.last_hidden_state[:, 0, :]
        
        # FFNN
        batchnormed_logits = self.batchnorm(pooled_output_student)
        linear_output = self.linear(batchnormed_logits) # the output's logits
        softmax_linear_output = F.log_softmax(linear_output, dim=1)
        
        lbl_student = lbl_student.float()
        softmax_linear_output = softmax_linear_output.float()
        
        # Loss Computation
        cross_entropy_loss = self.cross_entropy(softmax_linear_output, lbl_student)
        total_kld = self.kld(F.log_softmax(pooled_output_student, dim=1), F.softmax(pooled_output_teacher, dim=1))
        joint_loss = cross_entropy_loss + (self.lambda_kld * total_kld )
        
        return {"loss": joint_loss, "logits": softmax_linear_output}
    
    def clear_grad(self):
        self.xlmr_model_student.train()
        self.optimizer_student.zero_grad()
    
    def backpro_compute(self, loss):
        loss.backward()
        
    def update_std_weights_and_clear_grad(self):
        self.optimizer_student.step()
        self.optimizer_student.zero_grad()
    
    def update_std_weights(self):
        self.optimizer_student.step()
    
    def update_param_student_model(self, loss):
        # Doing customized backpropagation for student's model
        self.xlmr_model_student.train()
        
        self.optimizer_student.zero_grad()
        loss.backward()
        self.optimizer_student.step()
        
    def upload_to_huggingface(self):
        self.xlmr_model_student.push_to_hub(HF_MODEL_NAME)
        tokenizer.push_to_hub(HF_MODEL_NAME)

In [284]:
config = PretrainedConfig(
    problem_type = "single_label_classification",
    id2label = {
        "0": "ENTAIL",
        "1": "NEUTRAL",
        "2": "CONTRADICTION"
    },
    label2id = {
        "ENTAIL": 0,
        "NEUTRAL": 1,
        "CONTRADICTION": 2
    },
    num_labels = 3,
    hidden_size = 768,
    name_or_path = "indojavanesenli-transfer-learning",
    finetuning_task = "indonesian-javanese natural language inference"
)
print(config)
transferlearning_model = TransferLearningPaper(
    config = config,
    lambda_kld = LAMBDA_KLD, # antara 0.01-0.5
    learningrate_student = STUDENT_LRATE,
    batchnorm_epsilon = BATCH_NORM_EPSILON
)
transferlearning_model = transferlearning_model.to(device)

PretrainedConfig {
  "_name_or_path": "indojavanesenli-transfer-learning",
  "finetuning_task": "indonesian-javanese natural language inference",
  "hidden_size": 768,
  "id2label": {
    "0": "ENTAIL",
    "1": "NEUTRAL",
    "2": "CONTRADICTION"
  },
  "label2id": {
    "CONTRADICTION": 2,
    "ENTAIL": 0,
    "NEUTRAL": 1
  },
  "problem_type": "single_label_classification",
  "transformers_version": "4.30.2"
}



Some weights of the model checkpoint at jalaluddin94/xlmr-nli-indoindo were not used when initializing XLMRobertaModel: ['classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.dense.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaModel were not initialized from the model checkpoint at jalaluddin94/xlmr-nli-indoindo and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inf

## Training

In [None]:
gc.collect()

Function to compute metrics

In [None]:
def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    labels = np.argmax(labels[:,0,:], axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='micro')
    precision = precision_score(y_true=labels, y_pred=pred, average='micro')
    f1 = f1_score(y_true=labels, y_pred=pred, average='weighted')

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1}

Manual training function

In [None]:
def train(the_model, train_data, pgb):
    the_model.train()
    
    batch_loss = 0
    
    for batch, data in enumerate(train_data):
        # Clear accumulated gradients
        the_model.clear_grad()
        
        input_ids_teacher = data["input_ids_teacher"].to(device)
        attention_mask_teacher = data["attention_mask_teacher"].to(device)
        lbl_teacher = data["lbl_teacher"].to(device)
        input_ids_student = data["input_ids_student"].to(device)
        attention_mask_student = data["attention_mask_student"].to(device)
        lbl_student = data["lbl_student"].to(device)
        
        output = the_model(
            input_ids_teacher = input_ids_teacher, 
            attention_mask_teacher = attention_mask_teacher,
            lbl_teacher = lbl_teacher,
            input_ids_student = input_ids_student, 
            attention_mask_student = attention_mask_student, 
            lbl_student = lbl_student
        )
        
        loss_model = output["loss"]
        batch_loss += loss_model
        
        # Backpropagation
        # the_model.update_param_student_model(loss_model) # uncomment to use ordinary backpro
        ## now using gradient accumulation technique
        the_model.backpro_compute(loss_model) # backward pass and gradient accumulation
        
        # Accumulate gradients for the desired number of mini-batches
        if(batch+1) % BATCH_SIZE == 0:
            # update weights
            the_model.update_std_weights_and_clear_grad()
        
        pgb.update(1 / len(train_data))
    
    # Make sure to update the weights for any remaining accumulated gradients
    if (batch+1) % BATCH_SIZE != 0:
        the_model.update_std_weights()
        
    training_loss = batch_loss / BATCH_SIZE
#     wandb.log({"train/loss": training_loss})
    
    return training_loss

In [None]:
def validate(the_model, valid_data):
    the_model.eval()
    
    batch_loss = 0
    
    eval_f1 = []
    eval_accuracy = []
    eval_precision = []
    eval_recall = []
    
    with torch.no_grad():
        for batch, data in enumerate(valid_data):
            input_ids_teacher = data["input_ids_teacher"].to(device)
            attention_mask_teacher = data["attention_mask_teacher"].to(device)
            lbl_teacher = data["lbl_teacher"].to(device)
            input_ids_student = data["input_ids_student"].to(device)
            attention_mask_student = data["attention_mask_student"].to(device)
            lbl_student = data["lbl_student"].to(device)

            output = the_model(
                input_ids_teacher = input_ids_teacher, 
                attention_mask_teacher = attention_mask_teacher, 
                lbl_teacher = lbl_teacher,
                input_ids_student = input_ids_student, 
                attention_mask_student = attention_mask_student,
                lbl_student = lbl_student
            )

            logits = output["logits"].cpu().detach().numpy()
            packed_val = logits, lbl_student.cpu().detach().numpy()
            metrics = compute_metrics(packed_val)
            
            eval_f1.append(metrics["f1_score"])
            eval_accuracy.append(metrics["accuracy"])
            eval_precision.append(metrics["precision"])
            eval_recall.append(metrics["recall"])
            
            loss_model = output["loss"]
            batch_loss += loss_model

            # t.update(1 / len(valid_data))
    
        eval_loss = batch_loss / BATCH_SIZE
#         wandb.log({
#             "eval/loss": eval_loss, 
#             "eval/f1_score": np.average(eval_f1), 
#             "eval/accuracy": np.average(eval_accuracy),
#             "eval/precision": np.average(eval_precision),
#             "eval/recall": np.average(eval_recall)
#         })
    
    out_metrics = {
        "eval/loss": eval_loss, 
        "eval/f1_score": np.average(eval_f1), 
        "eval/accuracy": np.average(eval_accuracy),
        "eval/precision": np.average(eval_precision),
        "eval/recall": np.average(eval_recall)
    }
    
    return eval_loss, out_metrics

In [None]:
def training_sequence(the_model, train_data, valid_data, epochs):
    track_train_loss = []
    track_val_loss = []
    
    pbar_format = "{l_bar}{bar} | Epoch: {n:.2f}/{total_fmt} [{elapsed}<{remaining}]"
    with tqdm(total=epochs, colour="blue", leave=True, position=0, bar_format=pbar_format) as t:
        for ep in range(epochs):
            training_loss = train(the_model, train_data, t)
            t.set_description(f"Evaluating... Train loss: {training_loss:.3f}")
            valid_loss, _ = validate(the_model, valid_data)

            track_train_loss.append(training_loss)
            track_val_loss.append(valid_loss)

            t.set_description(f"Train loss: {training_loss:.3f} Valid loss: {valid_loss:.3f}")

            if valid_loss < min(track_val_loss) or ep + 1 == 1:
                the_model.save_pretrained(
                    save_directory = MODEL_PATH + "indojavanesenli-transfer-learning"
                )

#             wandb.log({
#                 "train_loss/epoch": training_loss,
#                 "validation_loss/epoch": valid_loss
#             })
        
    return {
        "training_loss": track_train_loss,
        "validation_loss": track_val_loss
    }

In [None]:
training_result = training_sequence(transferlearning_model, train_dataloader, valid_dataloader, NUM_EPOCHS)

In [None]:
# wandb.finish()

In [None]:
transferlearning_model.upload_to_huggingface()