# Transfer Learning Approach for Cross-Lingual NLI

## Import Libraries and Setup Environment Variables

In [1]:
! pip install -U scikit-learn
! pip install wandb
! pip install tqdm
! pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
! pip install -U git+https://github.com/huggingface/transformers.git
! pip install -U git+https://github.com/huggingface/accelerate.git

Collecting scikit-learn
  Downloading scikit_learn-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m77.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scikit-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 1.2.2
    Uninstalling scikit-learn-1.2.2:
      Successfully uninstalled scikit-learn-1.2.2
Successfully installed scikit-learn-1.3.0
[0mLooking in indexes: https://download.pytorch.org/whl/cu117
[0mCollecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-u1to7gfg
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-u1to7gfg
  Resolved https://github.com/huggingface/transformers.git to commit 2bd7a27a671fd1d98059124024f580f8f5c0f3b5
  Installing buil

In [2]:
import pandas as pd
import numpy as np
import os
import gc
import random
import time
from tqdm import tqdm, trange

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# set a seed value
torch.manual_seed(42)

from datasets import load_dataset

# import wandb

import transformers
from transformers import TrainingArguments, Trainer
from transformers import AdamW, EarlyStoppingCallback
from transformers import PreTrainedModel, PretrainedConfig
from transformers import BertTokenizer, BertModel
from huggingface_hub import login

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [3]:
TOKENIZER_TYPE = 'bert-base-multilingual-cased'
MBERT_TYPE = 'bert-base-multilingual-cased'
MODEL_TEACHER_TYPE = 'jalaluddin94/nli_mbert'
MODEL_PATH = '/kaggle/working/ResearchedModels/'
HF_MODEL_NAME = 'jalaluddin94/trf-learning-indojavanesenli-mbert'

STUDENT_LRATE = 2e-5
LAMBDA_KLD = 0.5 # between 0.01 - 0.5
MAX_LEN = 512
NUM_EPOCHS = 10
BATCH_SIZE = 22
BATCH_NORM_EPSILON = 1e-5
LAMBDA_L2 = 3e-5

HF_TOKEN = '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'

NUM_CORES = os.cpu_count() - 2

In [4]:
login(token=HF_TOKEN)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [5]:
# %env WANDB_API_KEY=97b170d223eb55f86fe1fbf9640831ad76381a74
# wandb.login()

In [6]:
# %env WANDB_LOG_MODEL='end'
# run = wandb.init(
#   project="javanese_nli",
#   notes="Experiment transfer learning on Bandyopadhyay's paper",
#   name="trf-learning-experiment-epoch5-lamdakld0.5",
#   tags=["transferlearning", "bandyopadhyay"]
# )

In [7]:
os.environ["WANDB_AGENT_MAX_INITIAL_FAILURES"]="1024"

In [8]:
os.environ["WANDB_AGENT_DISABLE_FLAPPING"]="true"

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


## Download and Prepare Dataset

### Download Dataset

In [10]:
# uri = "https://drive.google.com/uc?id=1aE9w2rqgW-j3PTgjnmHDjulNwp-Znb6i"
# output = "dataset/indo_java_nli_training.csv"
# if not os.path.exists("dataset/"):
#   os.makedirs("dataset/")
# gdown.download(url=uri, output=output, quiet=False, fuzzy=True)

In [11]:
# uri = "https://drive.google.com/uc?id=1YlQ9_8CvQbTSb5-2BjIfiYT-cy7pe6YM"
# output = "dataset/indo_java_nli_validation.csv"
# if not os.path.exists("dataset/"):
#   os.makedirs("dataset/")
# gdown.download(url=uri, output=output, quiet=False, fuzzy=True)

In [12]:
# uri = "https://drive.google.com/uc?id=1Zz_rHeI7fPUuA04zt9gCWyl5RYhrYPn0"
# output = "dataset/indo_java_nli_testing.csv"
# if not os.path.exists("dataset/"):
#   os.makedirs("dataset/")
# gdown.download(url=uri, output=output, quiet=False, fuzzy=True)

### Prepare Dataset for Student 

In [13]:
df_train = pd.read_csv("/kaggle/input/dataset-indojavanesenli/indojavanesenli-train.csv", sep='\t')
df_train = df_train.sample(frac=1).reset_index(drop=True) #shuffle the data

In [14]:
df_train_student = pd.DataFrame()
df_train_student["premise"] = df_train["premise"]
df_train_student["hypothesis"] = df_train["jv_hypothesis_mongo"]
df_train_student["label"] = df_train["label"]
df_train_student.head()

Unnamed: 0,premise,hypothesis,label
0,Bukit Bendera merupakan sebuah kawasan parleme...,akeh negara kanggonan neng malaysia.,1
1,"Sebelumnya, rangkaian perayaan HUT ke-261 Kota...",kutha jakarta luwih tuwa dibandingna kutha yog...,1
2,Adveksi memungkinkan awan akan menyebar dan be...,mendung-mendung entuk ngobah saka segaran menu...,0
3,"Negara ini belum diakui oleh satu negarapun, d...",pamerentah mali membenci azawad.,1
4,"Pada penghujung 1970-an, bandar udara ini dipe...",pesawat jet bisa nganggo landasan pacu sadurun...,2


In [15]:
df_valid = pd.read_csv("/kaggle/input/dataset-indojavanesenli/indojavanesenli-valid.csv", sep='\t')
df_valid = df_valid.sample(frac=1).reset_index(drop=True) #shuffle the data

In [16]:
df_valid_student = pd.DataFrame()
df_valid_student["premise"] = df_valid["premise"]
df_valid_student["hypothesis"] = df_valid["jv_hypothesis_mongo"]
df_valid_student["label"] = df_valid["label"]
df_valid_student.head()

Unnamed: 0,premise,hypothesis,label
0,Kreator Rooms bisa memutuskan apakah meeting a...,rooms memungkinkan penyelenggaraan meeting sin...,0
1,"Dari tahun 1930 hingga 1970, Trofi Jules Rimet...",trofi jules rimet nduweni bot 7 kg.,2
2,"Ketika orang mendengar kata perempuan saja, sa...",wong tau krungu tembung wadon.,0
3,Mereka akan menyediakan kereta otomatis untuk ...,richmond ana ing kanada.,0
4,Lima pengedar sabu jaringan Malaysia digulung ...,kancan iki memasukkan sabu karo bungkus teh ci...,0


In [17]:
df_test = pd.read_csv("/kaggle/input/dataset-indojavanesenli/indojavanesenli-test.csv", sep='\t')
df_test = df_test.sample(frac=1).reset_index(drop=True) #shuffle the data

In [18]:
df_test_student = pd.DataFrame()
df_test_student["premise"] = df_test["premise"]
df_test_student["premise"] = df_test_student["premise"].astype(str)
df_test_student["hypothesis"] = df_test["jv_hypothesis_mongo"]
df_test_student["hypothesis"] = df_test_student["hypothesis"].astype(str)
df_test_student["label"] = df_test["label"]
df_test_student.head()

Unnamed: 0,premise,hypothesis,label
0,"Selama curah hujan, tetesan air menyerap dan m...",tetesan banyu sing terserap menyang ing lemah ...,1
1,Tata cara pengibaran Bendera Pusaka disusunnya...,pasukan pengibaran bendera pusaka mung awak sa...,2
2,"Seperti diberitakan Ace Showbiz, Joo Hyuk meni...",joo hyuk ngalami kecelakaan ing dina selasa.,2
3,Dua puluh turnamen Piala Dunia telah dimenangk...,brasil arep memenangkan piala donya sabanjure.,1
4,Desa Mopait adalah salah satu Desa di Kecamata...,desa paling amba neng kecamatan loyalan dudu d...,2


### Prepare Dataset for Teacher

Dataset from teacher will be from "IndoNLI", and using Indonesian only.

In [19]:
df_train_t = pd.DataFrame()
df_train_t["premise"] = df_train["premise"]
df_train_t["hypothesis"] = df_train["hypothesis"]
df_train_t["label"] = df_train["label"]
df_train_t = df_train_t.sample(frac=1).reset_index(drop=True)
display(df_train_t)

Unnamed: 0,premise,hypothesis,label
0,Ketinggian yang tinggi itu berfungsi untuk men...,Iklim yang tidak didingkan akan sangat panas.,0
1,Di sudut ada seorang gadis dengan jilbab dan j...,Gadis dengan jilban dan jin tersebut terlihat ...,2
2,Peluncuran layanan video call terbaru Facebook...,Zoom merasa tersaingi oleh Messenger Rooms.,1
3,"Jadi Pansuslah yang berhak memotong, menambah,...",Sudah banyak pasal-pasal dalam RUU itu yang di...,1
4,"Saat ini, katak bertopi bisa ditemukan di Aust...",Katak bertopi bisa ditemukan di Indonesia.,2
...,...,...,...
10325,Presiden Joko Widodo dijadwalkan akan melantik...,Ada lebih dari 5 orang yang akan Presiden Joko...,1
10326,Kemaharajaan Majapahit adalah sebuah kerajaan ...,Hayam Wuruk adalah pemimpin Kemaharajaan Majap...,1
10327,Menurut dr. Esther Kristiningrum dari Departem...,Pembentukan vitamin D3 tidak membutuhkan katal...,2
10328,"George Meredith, OM (12 Februari 1828—18 Mei 1...",George Meredith adalah seorang novelis berbang...,2


In [20]:
print("Count per class train:") 
print(df_train_t['label'].value_counts())

Count per class train:
0    3476
2    3439
1    3415
Name: label, dtype: int64


In [21]:
df_valid_t = pd.DataFrame()
df_valid_t["premise"] = df_valid["premise"]
df_valid_t["hypothesis"] = df_valid["hypothesis"]
df_valid_t["label"] = df_valid["label"]
df_valid_t = df_valid_t.sample(frac=1).reset_index(drop=True)
display(df_valid_t)

Unnamed: 0,premise,hypothesis,label
0,Perang Dunia I membuat perubahan besar pada pe...,Perang dunia I membuat perubahan kecil pada pe...,2
1,"Chris Anderson, penulis The Long Tail, dan Guy...",Chris Anderson tidak menulis buku The Long Tail.,2
2,"Nama kota ini diambil dari Tarif bin Malik, pe...",Pemimpin ekspedisi umat Islam pertama ke Spany...,2
3,"Sebelumnya, Barisan Relawan Jalan Perubahan (B...",Fahri terlibat aksi makar.,0
4,Bola bumi pertama kali dibuat oleh astronomi Y...,Astronomi Yunani lah yang pertama kali membuat...,0
...,...,...,...
2192,"Sampai saat ini, foto tersebut sudah mendapat ...",Beragam komentar diberikan oleh netizen pada f...,0
2193,Berkunjung ke Bukittinggi terasa belum lengkap...,Mengunjungi Jam Gadang memberi rasa senang.,1
2194,"Sebaliknya, hard rock lebih berasal dari blues...",Pop dimainkan lebih keras dari hard rock.,1
2195,"Perry dan Snoop Dogg menampilkan ""California G...",Perry dan Snoop Dogg tidak tampil di acara MTV...,2


In [22]:
print("Count per class valid:") 
print(df_valid_t['label'].value_counts())

Count per class valid:
0    807
2    749
1    641
Name: label, dtype: int64


In [23]:
df_test_t = pd.DataFrame()
df_test_t["premise"] = df_test["premise"]
df_test_t["hypothesis"] = df_test["hypothesis"]
df_test_t["label"] = df_test["label"]
df_test_t = df_test_t.sample(frac=1).reset_index(drop=True)
display(df_test_t)

Unnamed: 0,premise,hypothesis,label
0,"Setelah ekdisis, dalam waktu satu atau dua jam...",Pertumbuhan hewan dibatasi oleh kekakuan eksos...,0
1,Warna parhelion akhirnya menjadi putih saat be...,Parhelion tidak berwarna.,2
2,Sirkuit ini memiliki panjang 0.533 mi (0.858 k...,Sirkuit ini digunakan untuk lomba MotoGP dan F1.,1
3,Ia menyebut Franco tidak memberi kesempatan pa...,Franco merupakan orang yang buruk perilakunya.,0
4,"Mereka akan pulang, bukan sebagai pemenang ata...",Mereka akan pulang.,0
...,...,...,...
2196,Kendati vegetasi laut hanya memiliki proporsi ...,Kemampuan menyimpan karbon vegetasi laut lebih...,0
2197,Tanaman yang tak banyak memerlukan air justru ...,Tanaman yang memerlukan air populer.,2
2198,"Ketika Millican meninggal, kami kira bahwa itu...",Kami tidak bertemu dengan Odi.,2
2199,"Namun, mereka tidak melakukan ekspansi hingga ...",Mereka adalah negara Inggris yang melakukan ek...,1


In [24]:
print("Count per class test:") 
print(df_test_t['label'].value_counts())

Count per class test:
0    808
2    764
1    629
Name: label, dtype: int64


## Preprocessing

### Tokenization

In [25]:
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_TYPE)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

In [26]:
class CompDataset(Dataset):
    def __init__(self, df_teacher, df_student):
        self.df_data_teacher = df_teacher
        self.df_data_student = df_student
        
    def __getitem__(self, index):
        # Teacher
        sentence_teacher_1 = self.df_data_teacher.loc[index, 'premise']
        sentence_teacher_2 = self.df_data_teacher.loc[index, 'hypothesis']
        
        encoded_dict_teacher = tokenizer.encode_plus(
            sentence_teacher_1,
            sentence_teacher_2,
            add_special_tokens = True,
            max_length = MAX_LEN,
            truncation='longest_first',
            padding = 'max_length',
            return_attention_mask = True,
            return_tensors = 'pt'
        )
        
        padded_token_list_teacher = encoded_dict_teacher['input_ids'][0]
        att_mask_teacher = encoded_dict_teacher['attention_mask'][0]
        tok_type_id_teacher = encoded_dict_teacher['token_type_ids'][0]
        
        target_teacher = torch.tensor([self.df_data_teacher.loc[index, 'label']])
        lt_target_teacher = torch.LongTensor(target_teacher)
        onehot_encoded_lbl_teacher = F.one_hot(lt_target_teacher, num_classes=3) # 3 classes: entails, neutral, contradict
        
        # Student
        sentence_student_1 = self.df_data_student.loc[index, 'premise']
        sentence_student_2 = self.df_data_student.loc[index, 'hypothesis']
        
        encoded_dict_student = tokenizer.encode_plus(
            sentence_student_1,
            sentence_student_2,
            add_special_tokens = True,
            max_length = MAX_LEN,
            truncation='longest_first',
            padding = 'max_length',
            return_attention_mask = True,
            return_tensors = 'pt'
        )
        
        padded_token_list_student = encoded_dict_student['input_ids'][0]
        att_mask_student = encoded_dict_student['attention_mask'][0]
        tok_type_id_student = encoded_dict_student['token_type_ids'][0]
        
        target_student = torch.tensor([self.df_data_student.loc[index, 'label']])
        lt_target_student = torch.LongTensor(target_student)
        onehot_encoded_lbl_student = F.one_hot(lt_target_student, num_classes=3) # 3 classes: entails, neutral, contradict
        
        output = {
            "input_ids_teacher": padded_token_list_teacher, 
            "attention_mask_teacher": att_mask_teacher,
            "token_type_ids_teacher": tok_type_id_teacher,
            "lbl_teacher": onehot_encoded_lbl_teacher,
            "input_ids_student": padded_token_list_student, 
            "attention_mask_student": att_mask_student,
            "token_type_ids_student": tok_type_id_student,
            "lbl_student": onehot_encoded_lbl_student
        }
        
        return output
    
    def __len__(self):
        return len(self.df_data_teacher)

Tokenize dataset

In [27]:
train_data_cmp = CompDataset(df_train_t, df_train_student)
valid_data_cmp = CompDataset(df_valid_t, df_valid_student)
test_data_cmp = CompDataset(df_test_t, df_test_student)

Create dataloader

In [28]:
train_dataloader = DataLoader(train_data_cmp, batch_size = BATCH_SIZE)
valid_dataloader = DataLoader(valid_data_cmp, batch_size = BATCH_SIZE)
test_dataloader = DataLoader(test_data_cmp, batch_size = BATCH_SIZE)

## Model

Transfer Learning model as per Bandyopadhyay, D., et al (2022) paper

In [29]:
class TransferLearningPaper(PreTrainedModel):
    def __init__(self, config, lambda_kld, learningrate_student, batchnorm_epsilon = 1e-5):
        super(TransferLearningPaper, self).__init__(config)
        
        self.bert_model_teacher = BertModel.from_pretrained(
            MODEL_TEACHER_TYPE, # using pretrained mBERT in INA language
            num_labels = 3,
            output_hidden_states=True
        )
        
        # Freeze teacher mBERT parameters
        for params_teacher in self.bert_model_teacher.parameters():
            params_teacher.requires_grad = False
    
        self.bert_model_student = BertModel.from_pretrained(
            MBERT_TYPE,
            num_labels = 3,
            output_hidden_states=True
        )
        
        # Unfreeze student mBERT parameters
        for params_student in self.bert_model_student.parameters():
            params_student.requires_grad = True
        
        self.optimizer_student = AdamW(
            self.bert_model_student.parameters(), 
            lr=learningrate_student
        )
        
        self.linear = nn.Linear(config.hidden_size, 3)  # Linear layer
        self.batchnorm = nn.BatchNorm1d(config.hidden_size, eps=batchnorm_epsilon)
        self.softmax = nn.Softmax(dim=1)  # Softmax activation
        
        self.cross_entropy = nn.CrossEntropyLoss()
        self.kld = nn.KLDivLoss(reduction='batchmean')
        
        # Initialize the weights of the linear layer
        self.linear.weight.data.normal_(mean=0.0, std=0.02)
        self.linear.bias.data.zero_()
        
        self.lambda_kld = lambda_kld
    
    def forward(self, input_ids_teacher, attention_mask_teacher, token_type_ids_teacher, input_ids_student, attention_mask_student, token_type_ids_student, lbl_student):
        # the label is already one-hot encoded 
        self.bert_model_teacher.eval()
        self.bert_model_student.eval()
        
        with torch.no_grad():
            # Taking CLS token out of mBERT last hidden state
            outputs_teacher = self.bert_model_teacher(
                input_ids=input_ids_teacher, 
                attention_mask=attention_mask_teacher, 
                token_type_ids=token_type_ids_teacher
            )
        
            # take CLS token of the last hidden state
            pooled_output_teacher = outputs_teacher[0][:, 0, :]
        
        # taking CLS token out of the student data without deleting the gradient
        outputs_student = self.bert_model_student(
            input_ids=input_ids_student, 
            attention_mask=attention_mask_student, 
            token_type_ids=token_type_ids_student
        )
        
        pooled_output_student = outputs_student[0][:, 0, :]
        
        # FFNN
        batchnormed_logits = self.batchnorm(pooled_output_student)
        linear_output = self.linear(batchnormed_logits) # the output's logits
        softmax_linear_output = F.log_softmax(linear_output, dim=1)
        
        lbl_student = lbl_student[:,0,:].float()
        softmax_linear_output = softmax_linear_output.float()
        
        # Loss Computation
        cross_entropy_loss = self.cross_entropy(softmax_linear_output, lbl_student)
        total_kld = self.kld(F.log_softmax(pooled_output_student, dim=1), F.softmax(pooled_output_teacher, dim=1))
        joint_loss = cross_entropy_loss + (self.lambda_kld * total_kld )
        
        return {"loss": joint_loss, "logits": softmax_linear_output}
    
    def clear_grad(self):
        self.bert_model_student.train()
        self.optimizer_student.zero_grad()
    
    def backpro_compute(self, loss):
        loss.backward()
        
    def update_std_weights_and_clear_grad(self):
        self.optimizer_student.step()
        self.optimizer_student.zero_grad()
    
    def update_std_weights(self):
        self.optimizer_student.step()
    
    def update_param_student_model(self, loss):
        # Doing customized backpropagation for student's model
        self.bert_model_student.train()
        
        self.optimizer_student.zero_grad()
        loss.backward()
        self.optimizer_student.step()
        
    def upload_to_huggingface(self):
        self.bert_model_student.push_to_hub(HF_MODEL_NAME)
        tokenizer.push_to_hub(HF_MODEL_NAME)

In [30]:
config = PretrainedConfig(
    problem_type = "single_label_classification",
    id2label = {
        "0": "ENTAIL",
        "1": "NEUTRAL",
        "2": "CONTRADICTION"
    },
    label2id = {
        "ENTAIL": 0,
        "NEUTRAL": 1,
        "CONTRADICTION": 2
    },
    num_labels = 3,
    hidden_size = 768,
    name_or_path = "indojavanesenli-transfer-learning",
    finetuning_task = "indonesian-javanese natural language inference"
)
print(config)
transferlearning_model = TransferLearningPaper(
    config = config,
    lambda_kld = LAMBDA_KLD, # antara 0.01-0.5
    learningrate_student = STUDENT_LRATE,
    batchnorm_epsilon = BATCH_NORM_EPSILON
)
transferlearning_model = transferlearning_model.to(device)

PretrainedConfig {
  "_name_or_path": "indojavanesenli-transfer-learning",
  "finetuning_task": "indonesian-javanese natural language inference",
  "hidden_size": 768,
  "id2label": {
    "0": "ENTAIL",
    "1": "NEUTRAL",
    "2": "CONTRADICTION"
  },
  "label2id": {
    "CONTRADICTION": 2,
    "ENTAIL": 0,
    "NEUTRAL": 1
  },
  "problem_type": "single_label_classification",
  "transformers_version": "4.32.0.dev0"
}



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/711M [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]



## Training

Collect garbage

In [31]:
gc.collect()

173

Function to compute metrics

In [32]:
def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    labels = np.argmax(labels[:,0,:], axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='micro')
    precision = precision_score(y_true=labels, y_pred=pred, average='micro')
    f1 = f1_score(y_true=labels, y_pred=pred, average='weighted')

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1}

Manual training function

In [33]:
def train(the_model, train_data, pgb):
    the_model.train()
    
    batch_loss = 0
    
    for batch, data in enumerate(train_data):
        # Clear accumulated gradients
        the_model.clear_grad()
        
        input_ids_teacher = data["input_ids_teacher"].to(device)
        attention_mask_teacher = data["attention_mask_teacher"].to(device)
        token_type_ids_teacher = data["token_type_ids_teacher"].to(device)
        input_ids_student = data["input_ids_student"].to(device)
        attention_mask_student = data["attention_mask_student"].to(device)
        token_type_ids_student = data["token_type_ids_student"].to(device)
        lbl_student = data["lbl_student"].to(device)
        
        output = the_model(
            input_ids_teacher = input_ids_teacher, 
            attention_mask_teacher = attention_mask_teacher, 
            token_type_ids_teacher = token_type_ids_teacher, 
            input_ids_student = input_ids_student, 
            attention_mask_student = attention_mask_student, 
            token_type_ids_student = token_type_ids_student, 
            lbl_student = lbl_student
        )
        
        loss_model = output["loss"]
        batch_loss += loss_model
        
        # Backpropagation
        # the_model.update_param_student_model(loss_model) # uncomment to use ordinary backpro
        ## now using gradient accumulation technique
        the_model.backpro_compute(loss_model) # backward pass and gradient accumulation
        
        # Accumulate gradients for the desired number of mini-batches
        if(batch+1) % BATCH_SIZE == 0:
            # update weights
            the_model.update_std_weights_and_clear_grad()
        
        pgb.update(1 / len(train_data))
    
    # Make sure to update the weights for any remaining accumulated gradients
    if (batch+1) % BATCH_SIZE != 0:
        the_model.update_std_weights()
        
    training_loss = batch_loss / BATCH_SIZE
#     wandb.log({"train/loss": training_loss})
    
    return training_loss

In [34]:
def validate(the_model, valid_data):
    the_model.eval()
    
    batch_loss = 0
    
    eval_f1 = []
    eval_accuracy = []
    eval_precision = []
    eval_recall = []
    
    with torch.no_grad():
        for batch, data in enumerate(valid_data):
            input_ids_teacher = data["input_ids_teacher"].to(device)
            attention_mask_teacher = data["attention_mask_teacher"].to(device)
            token_type_ids_teacher = data["token_type_ids_teacher"].to(device)
            input_ids_student = data["input_ids_student"].to(device)
            attention_mask_student = data["attention_mask_student"].to(device)
            token_type_ids_student = data["token_type_ids_student"].to(device)
            lbl_student = data["lbl_student"].to(device)

            output = the_model(
                input_ids_teacher = input_ids_teacher, 
                attention_mask_teacher = attention_mask_teacher, 
                token_type_ids_teacher = token_type_ids_teacher, 
                input_ids_student = input_ids_student, 
                attention_mask_student = attention_mask_student, 
                token_type_ids_student = token_type_ids_student, 
                lbl_student = lbl_student
            )

            logits = output["logits"].cpu().detach().numpy()
            packed_val = logits, lbl_student.cpu().detach().numpy()
            metrics = compute_metrics(packed_val)
            
            eval_f1.append(metrics["f1_score"])
            eval_accuracy.append(metrics["accuracy"])
            eval_precision.append(metrics["precision"])
            eval_recall.append(metrics["recall"])
            
            loss_model = output["loss"]
            batch_loss += loss_model
    
        eval_loss = batch_loss / BATCH_SIZE
#         wandb.log({
#             "eval/loss": eval_loss, 
#             "eval/f1_score": np.average(eval_f1), 
#             "eval/accuracy": np.average(eval_accuracy),
#             "eval/precision": np.average(eval_precision),
#             "eval/recall": np.average(eval_recall)
#         })
    
    out_metrics = {
        "eval/loss": eval_loss, 
        "eval/f1_score": np.average(eval_f1), 
        "eval/accuracy": np.average(eval_accuracy),
        "eval/precision": np.average(eval_precision),
        "eval/recall": np.average(eval_recall)
    }
    
    return eval_loss, out_metrics

In [35]:
def training_sequence(the_model, train_data, valid_data, epochs):
    track_train_loss = []
    track_val_loss = []
    
    pbar_format = "{l_bar}{bar} | Epoch: {n:.2f}/{total_fmt} [{elapsed}<{remaining}]"
    with tqdm(total=epochs, colour="blue", leave=True, position=0, bar_format=pbar_format) as t:
        for ep in range(epochs):
            training_loss = train(the_model, train_data, t)
            t.set_description(f"Evaluating... Train loss: {training_loss:.3f}")
            valid_loss, _ = validate(the_model, valid_data)

            track_train_loss.append(training_loss)
            track_val_loss.append(valid_loss)

            t.set_description(f"Train loss: {training_loss:.3f} Valid loss: {valid_loss:.3f}")

            if valid_loss < min(track_val_loss) or ep + 1 == 1:
                the_model.save_pretrained(
                    save_directory = MODEL_PATH + "indojavanesenli-transfer-learning"
                )

#             wandb.log({
#                 "train_loss/epoch": training_loss,
#                 "validation_loss/epoch": valid_loss
#             })
        
    return {
        "training_loss": track_train_loss,
        "validation_loss": track_val_loss
    }

In [36]:
training_result = training_sequence(transferlearning_model, train_dataloader, valid_dataloader, NUM_EPOCHS)

  0%|[34m          [0m | Epoch: 0.04/10 [00:32<2:00:29]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
  2%|[34m▏         [0m | Epoch: 0.15/10 [01:53<1:59:36]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
  2%|[34m▏         [0m | Epoch: 0.25/10 [03:01<1:58:04]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
  3%|[34m▎         [0m | Epoch: 0.28/10 [03:22<1:57:57]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence p

In [37]:
# wandb.finish()

In [38]:
transferlearning_model.upload_to_huggingface()

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/711M [00:00<?, ?B/s]