# Transfer Learning Approach for Cross-Lingual NLI

## Import Libraries and Setup Environment Variables

In [1]:
import pandas as pd
import numpy as np
import os
import gc
import random
import gdown
import time
from tqdm import tqdm, trange

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# set a seed value
torch.manual_seed(205)

from datasets import load_dataset

import wandb

import transformers
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
from transformers import BertTokenizer, BertForSequenceClassification #, XLMRobertaTokenizer, XLMRobertaForSequenceClassification
from transformers import AdamW

In [2]:
# MODEL_TYPE = 'xlm-roberta-base'
MODEL_TYPE = 'bert-base-multilingual-cased'
MODEL_PATH = 'D:/Training/Machine Learning/NLP/NLI/saved_models/Indo-Javanese-NLI/ResearchedModels/'

L_RATE = 3e-6
MAX_LEN = 512
NUM_EPOCHS = 6
BATCH_SIZE = 2
BATCH_NORM_EPSILON = 1e-5
LAMBDA_L2 = 2e-5

NUM_CORES = os.cpu_count() - 2

In [3]:
# %env WANDB_NOTEBOOK_NAME=/home/sagemaker-user/PPT/BERT_BiLSTM_Game_Review.ipynb
%env WANDB_API_KEY=97b170d223eb55f86fe1fbf9640831ad76381a74
wandb.login()

env: WANDB_API_KEY=97b170d223eb55f86fe1fbf9640831ad76381a74


[34m[1mwandb[0m: Currently logged in as: [33mjalaluddin-94[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
%env WANDB_PROJECT=javanese_nli
%env WANDB_LOG_MODEL='end'

env: WANDB_PROJECT=javanese_nli
env: WANDB_LOG_MODEL='end'


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


## Download and Prepare Dataset

### Download Dataset

In [6]:
# uri = "https://drive.google.com/uc?id=1aE9w2rqgW-j3PTgjnmHDjulNwp-Znb6i"
# output = "dataset/indo_java_nli_training.csv"
# if not os.path.exists("dataset/"):
#   os.makedirs("dataset/")
# gdown.download(url=uri, output=output, quiet=False, fuzzy=True)

In [7]:
# uri = "https://drive.google.com/uc?id=1YlQ9_8CvQbTSb5-2BjIfiYT-cy7pe6YM"
# output = "dataset/indo_java_nli_validation.csv"
# if not os.path.exists("dataset/"):
#   os.makedirs("dataset/")
# gdown.download(url=uri, output=output, quiet=False, fuzzy=True)

In [8]:
# uri = "https://drive.google.com/uc?id=1Zz_rHeI7fPUuA04zt9gCWyl5RYhrYPn0"
# output = "dataset/indo_java_nli_testing.csv"
# if not os.path.exists("dataset/"):
#   os.makedirs("dataset/")
# gdown.download(url=uri, output=output, quiet=False, fuzzy=True)

### Prepare Dataset for Student 

In [9]:
df_train = pd.read_csv("D:/Training/Machine Learning/Datasets/NLI/IndoJavaNLI/indonli-with-java-chatgpt-training-data.csv", sep='\t')
df_train = df_train.sample(frac=1).reset_index(drop=True) #shuffle the data

In [10]:
df_train_student = pd.DataFrame()
df_train_student["premise"] = df_train["premise"]
df_train_student["hypothesis"] = df_train["jv_hypothesis"]
df_train_student["label"] = df_train["label"]
df_train_student.head()

Unnamed: 0,premise,hypothesis,label
0,Dia membuat penampilan terakhirnya untuk klub ...,"""Isih lunga main kanggo klubé ing tanggal 20 N...",1
1,Pengelola Nama Domain Internet Indonesia (Pand...,"""Pandi ora rencana mlebu domain.""",2
2,Mayweather menepis anggapan bahwa McGregor yan...,Mayweather nganggukno McGregor minangka sumber...,2
3,Tinggal bagaimana Real Madrid menjaga keutuhan...,Alvaro Morata duwe jatah tampilan sing banyuak...,2
4,Sekte-sekte lain yang terpecah termasuk kelomp...,Grup Syiah lan sekte-sekte liyane sing penggal...,0


In [11]:
df_valid = pd.read_csv("D:/Training/Machine Learning/Datasets/NLI/IndoJavaNLI/indonli-with-java-chatgpt-validation-data.csv", sep='\t')
df_valid = df_valid.sample(frac=1).reset_index(drop=True) #shuffle the data

In [12]:
df_valid_student = pd.DataFrame()
df_valid_student["premise"] = df_valid["premise"]
df_valid_student["hypothesis"] = df_valid["jv_hypothesis"]
df_valid_student["label"] = df_valid["label"]
df_valid_student.head()

Unnamed: 0,premise,hypothesis,label
0,Lagu tersebut merupakan lagu latar untuk film ...,Lagu the Star diwenehi Carey.,1
1,"Hingga pertengahan tahun 1960-an, kimono masih...",Nanging ayeuna pakaian sehari-hari wadon Jepan...,1
2,Tim identifikasi Polres Kediri Kota masih mela...,Ratusan wong sing nggeleh menyang warung korba...,1
3,Dusun ini dikelilingi jalan semi aspal yang lu...,Dhukuh iki sisih ing wilayah angetan.,1
4,Reger belajar musik di Muenchen dan Wiesbaden ...,Reger mbelajar musik nang Wiesbaden.,0


In [13]:
df_test = pd.read_csv("D:/Training/Machine Learning/Datasets/NLI/IndoJavaNLI/indonli-with-java-chatgpt-testing-data.csv", sep='\t')
df_test = df_test.sample(frac=1).reset_index(drop=True) #shuffle the data

In [14]:
df_test_student = pd.DataFrame()
df_test_student["premise"] = df_test["premise"]
df_test_student["hypothesis"] = df_test["jv_hypothesis"]
df_test_student["label"] = df_test["label"]
df_test_student.head()

Unnamed: 0,premise,hypothesis,label
0,Saya mengunjungi acara gelar di Northampton Un...,Ora tau ana acara gelar ing Universitas Northa...,2
1,"Tayuban adalah desa di kecamatan Panjatan, Kul...",Tayuban iku sawijining desa.,0
2,Berita tentang penipuan perusahaan yang telah ...,"Warta babagan penipuan perusahaan, sing wis pi...",2
3,Kemunduran ini mengejutkan bagi tentara Austro...,Tentara Austro-Prusia yaiku tentara Denmark.,1
4,Fokus terlalu lama pada gadget saat bekerja bi...,Gadget bisa ngobati stres.,1


### Prepare Dataset for Teacher

Dataset from teacher will be from "IndoNLI", and using Indonesian only.

In [15]:
df_train_ds = load_dataset('indonli', split="train")
list_all_train = [df_train_ds['premise'], df_train_ds['hypothesis'], df_train_ds['label']]
df_train_t = pd.DataFrame(list_all_train).transpose()
df_train_t.columns = ['premise', 'hypothesis', 'label']
df_train_t = df_train_t.sample(frac=1).reset_index(drop=True)
display(df_train_t)

Found cached dataset indonli (C:/Users/sufin/.cache/huggingface/datasets/indonli/indonli/1.1.0/d34041bd1d1a555a4bcb4ffdb9fe904778da6f7c5343209fc1485dd68121cb62)


Unnamed: 0,premise,hypothesis,label
0,"Sebelumnya, Barisan Relawan Jalan Perubahan (B...",Bara JP berencana melaporkan Fahri Hamzah ke M...,0
1,Sebuah penelitian yang dilakukan oleh Intensiv...,Semua pasien Corona dirawat tanpa alat bantu p...,2
2,Dalam pengaturan smartphone ada sebuah setting...,Ada setting dimana pengguna tidak dapat melaku...,2
3,Grup musik ini beranggotakan 3 orang yaitu Cit...,Grup musik ini beranggotakan 3 orang.,0
4,"Totok mengatakan, perseroannya berkomitmen unt...",Perseroan toktok berkomitmen untuk memuaskan p...,0
...,...,...,...
10325,AeroConnect sendiri dapat diunduh melalui Appl...,Google PlayStore tidak menyediakan AeroConnect.,2
10326,Susy Susanti sukses mempersembahkan medali ema...,Susy Susanti tidak pernah mengikut Olimpiade.,2
10327,Jakarta (ANTARA News)-Setelah disibukkan denga...,Jennifer Lawrence samasekali tidak punya target.,1
10328,Sengketa proses bisa terjadi salah satunya seb...,Panwas kabupaten/kota juga disebut sebagai mah...,0


In [16]:
print("Count per class train:", df_train_t['label'].value_counts())

Count per class train: 0    3476
2    3439
1    3415
Name: label, dtype: int64


In [17]:
df_valid_ds = load_dataset('indonli', split="validation")
list_all_valid = [df_valid_ds['premise'], df_valid_ds['hypothesis'], df_valid_ds['label']]
df_valid_t = pd.DataFrame(list_all_valid).transpose()
df_valid_t.columns = ['premise', 'hypothesis', 'label']
df_valid_t = df_valid_t.sample(frac=1).reset_index(drop=True)
display(df_valid_t)

Found cached dataset indonli (C:/Users/sufin/.cache/huggingface/datasets/indonli/indonli/1.1.0/d34041bd1d1a555a4bcb4ffdb9fe904778da6f7c5343209fc1485dd68121cb62)


Unnamed: 0,premise,hypothesis,label
0,"Beberapa kali, anggota keluarganya yang dulu—K...",Anggota Keluarga Naito tidak mengenal Hyogo.,2
1,Semula SMMA ditetapkan sebagai cagar alam oleh...,Luas cagar alam SMMA sekarang melebihi 20 ha.,1
2,Alessandria adalah sebuah provinsi di regione ...,Alessandria terletak di Italia.,0
3,"""Ini karena adanya peningkatan pada tarif angk...",Peningkatan tarif kereta api memengaruhi harga...,1
4,"Selama di sekolah, dikisahkan Gomez kerap mena...",Gomez memperhatikan gurunya di sekolah pada ja...,1
...,...,...,...
2192,"Apalagi sekarang, visa tidak diperlukan untuk ...",Banyak turis menyalahgunakan keabsenan visa se...,1
2193,Dibangun dengan tinggi 36 meter dengan 4 tingk...,Bangunan monumen Jam Gadang ini dibangun denga...,0
2194,"""Kalau bunga acuan turun, 3-4 bulan ke depan b...","Bank yang menyesuaikan hanya Bank Indonesia,.",1
2195,Alasan yang diungkapkan PSSI tak lebih karena ...,Dua pemain tetap bisa bergabung dalam persiapa...,1


In [18]:
print("Count per class valid:", df_valid_t['label'].value_counts())

Count per class valid: 0    807
2    749
1    641
Name: label, dtype: int64


In [19]:
df_test_ds = load_dataset('indonli', split="test_lay")
list_all_test = [df_test_ds['premise'], df_test_ds['hypothesis'], df_test_ds['label']]
df_test_t = pd.DataFrame(list_all_test).transpose()
df_test_t.columns = ['premise', 'hypothesis', 'label']
df_test_t = df_test_t.sample(frac=1).reset_index(drop=True)
display(df_test_t)

Found cached dataset indonli (C:/Users/sufin/.cache/huggingface/datasets/indonli/indonli/1.1.0/d34041bd1d1a555a4bcb4ffdb9fe904778da6f7c5343209fc1485dd68121cb62)


Unnamed: 0,premise,hypothesis,label
0,Jordan Nagai adalah seorang aktor pengisi suar...,Jordan Nagai terkenal di kalangan anak-anak.,1
1,Sejak pemerintahan jatuh ke tangan China pada ...,China sangat kejam terutama pada Hong Kong.,1
2,Hak cipta karya yang dibuat oleh pekerja lepas...,Pekerja lepas AS tidak pernah khawatir karena ...,0
3,Sasando adalah sebuah alat musik dawai yang di...,"Dalam bahasa Rote, kata Sandu artinya bergetar.",0
4,Presiden Soekarno mengeluarkan Keputusan Presi...,Hari Kartini diperingati setiap tanggal 21 April.,0
...,...,...,...
2196,Apalagi sekarang sudah ada inisiatif masyaraka...,Terdapat ban yang disewakan oleh masyarakat.,0
2197,Kabupaten ini dibentuk berdasarkan Undang-Unda...,Undang-Undang Nomor 27 Tahun 2008 tidak memkar...,2
2198,Fly Emirates adalah sponsor AC Milan mulai dar...,Bwin.com adalah sponsor Milan sebelum Fly Emir...,0
2199,Sejarah sepak bola di Indonesia diawali dengan...,Sejarah sepak bola di Indonesia diawali oleh S...,1


In [20]:
print("Count per class test:", df_test_t['label'].value_counts())

Count per class test: 0    808
2    764
1    629
Name: label, dtype: int64


## Preprocessing

### Tokenization

In [21]:
# tokenizer = XLMRobertaTokenizer.from_pretrained(MODEL_TYPE)
tokenizer = BertTokenizer.from_pretrained(MODEL_TYPE)

In [22]:
class CompDataset(Dataset):
    def __init__(self, df):
        self.df_data = df
        
    def __getitem__(self, index):
        sentence1 = self.df_data.loc[index, 'premise']
        sentence2 = self.df_data.loc[index, 'hypothesis']
        
        encoded_dict = tokenizer.encode_plus(
            sentence1,
            sentence2,
            add_special_tokens = True,
            max_length = MAX_LEN,
            truncation='longest_first',
            padding = 'max_length',
            return_attention_mask = True,
            return_tensors = 'pt'
        )
        
        padded_token_list = encoded_dict['input_ids'][0]
        att_mask = encoded_dict['attention_mask'][0]
        
        target = torch.tensor(self.df_data.loc[index, 'label'])
        sample = {"input_ids": padded_token_list, "attention_mask": att_mask, "label": target}
        
        return sample
    
    def __len__(self):
        return len(self.df_data)
    

Tokenize student dataset

In [23]:
# train_data_student_cmp = CompDataset(df_train_student)
# valid_data_student_cmp = CompDataset(df_valid_student)
# test_data_student_cmp = CompDataset(df_test_student)

In [24]:
# train_student_loader = DataLoader(
#     train_data_student_cmp,
#     batch_size = BATCH_SIZE
# )

In [25]:
# validation_student_loader = DataLoader(
#     valid_data_student_cmp,
#     batch_size = BATCH_SIZE
# )

In [26]:
# test_student_loader = DataLoader(
#     test_data_student_cmp,
#     batch_size = BATCH_SIZE
# )

Tokenize teacher dataset

In [27]:
train_data_teacher_cmp = CompDataset(df_train_t)
valid_data_teacher_cmp = CompDataset(df_valid_t)
test_data_teacher_cmp = CompDataset(df_test_t)

In [28]:
train_teacher_loader = DataLoader(
    train_data_teacher_cmp,
    batch_size = BATCH_SIZE
)

In [29]:
validation_teacher_loader = DataLoader(
    valid_data_teacher_cmp,
    batch_size = BATCH_SIZE
)

In [30]:
test_teacher_loader = DataLoader(
    test_data_teacher_cmp,
    batch_size = BATCH_SIZE
)

## Hidden State Extraction

Define model for teacher and student

In [31]:
model_teacher = BertForSequenceClassification.from_pretrained(
    MODEL_TYPE, 
    num_labels = 3,
    output_hidden_states=True
)
# model_teacher = model_teacher.to(device)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model ch

In [32]:
# model_student = BertForSequenceClassification.from_pretrained(
#     MODEL_TYPE, 
#     num_labels = 3,
#     output_hidden_states=True
# )

In [33]:
def extract_final_hidden_state(the_model, the_dataloader):
    the_model.eval()
    
    context_vector = []
    labels = []
    
    with torch.no_grad():
        t = tqdm(enumerate(the_dataloader), total=len(the_dataloader), colour="green", leave=True, position=0)
        for batch, datas in t:
#             inpt_ids = datas["input_ids"].to(device)
#             att_mask = datas["attention_mask"].to(device)
#             label = datas["label"].to(device)
            inpt_ids = datas["input_ids"]
            att_mask = datas["attention_mask"]
            label = datas["label"]

            outputs = the_model(
                input_ids = inpt_ids,
                attention_mask = att_mask,
                labels = label
            )

            # Take last hidden state
    #         hidden_states = outputs.hidden_states[0] # torch.Size([1, 512, 768]) # this is the first layer
            last_hidden_state = outputs.hidden_states[-1:][0] # torch.Size([1, 512, 768]) # take the last hidden layer

            # Take CLS only
            cont_vec = last_hidden_state[0][0] # torch.Size([768])
            cont_vec = cont_vec[None, :] # torch.Size([1, 768])
            context_vector.append(cont_vec) 
            labels.append(label)

            t.set_description(f"Extracting CLS token vector [{batch+1}/{len(the_dataloader)}]...")
    
    return context_vector, labels

In [34]:
# context_vec_student, labels_student = extract_final_hidden_state(model_student, train_student_loader)

In [35]:
gc.collect()

1136

In [36]:
context_vec_train_teacher, labels_train_teacher = extract_final_hidden_state(model_teacher, train_teacher_loader)

Extracting CLS token vector [1381/5165]...:  27%|[32m████████▌                       [0m| 1381/5165 [51:40<2:21:18,  2.24s/it][0mBe aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Extracting CLS token vector [1504/5165]...:  29%|[32m█████████▎                      [0m| 1504/5165 [56:16<2:15:42,  2.22s/it][0mBe aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Extracting CLS token vector [2344/5165]...:  45%|[32m█████████████▌                [0m| 2344/5165 [1:28:15<1:49:07,  2.32s/it][0mBe aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the 

In [37]:
context_vec_valid_teacher, labels_valid_teacher = extract_final_hidden_state(model_teacher, validation_teacher_loader)

Extracting CLS token vector [1099/1099]...: 100%|[32m██████████████████████████████████[0m| 1099/1099 [42:59<00:00,  2.35s/it][0m


In [38]:
context_vec_test_teacher, labels_test_teacher = extract_final_hidden_state(model_teacher, test_teacher_loader)

Extracting CLS token vector [1101/1101]...: 100%|[32m██████████████████████████████████[0m| 1101/1101 [43:31<00:00,  2.37s/it][0m


In [39]:
class EmbeddingDataset(Dataset):
    def __init__(self, array_embed, array_lbl):
        self.embedding = array_embed
        self.label = array_lbl
    def __len__(self):
        return len(self.label)
    def __getitem__(self, idx):
        self.exact_data = self.embedding[idx]
        self.exact_label = self.label[idx]
        
        returned_dict = {
            "x": self.exact_data,
            "labels": self.exact_label
        }
        
        return returned_dict

In [40]:
dataset_train_teacher = EmbeddingDataset(context_vec_train_teacher, labels_train_teacher)

In [41]:
dataset_valid_teacher = EmbeddingDataset(context_vec_valid_teacher, labels_valid_teacher)

In [42]:
dataset_test_teacher = EmbeddingDataset(context_vec_test_teacher, labels_test_teacher)

In [43]:
# dataset_student = EmbeddingDataset(context_vec_student, labels_student)

## Classifier

Classifier for Teacher (scenario zero-shot from teacher directly to classify student data)

In [70]:
class ClassifierTeacher(nn.Module):
    def __init__(self, n_features, epsilon, n_labels):
        super(ClassifierTeacher, self).__init__()
        self.norm = nn.BatchNorm1d(
            num_features = n_features,
            eps = epsilon
        )
        self.fc = nn.Linear(
            in_features = n_features,
            out_features = n_labels
        )
        self.softm = nn.Softmax(
            dim=1
        )
    def forward(self, x):
        x = x[0,:,:]
        y = self.norm(x)
        y = self.fc(y)
        y = self.softm(y)
        
        return {"logits": y}

In [71]:
classifier_teacher = ClassifierTeacher(
    n_features = 768, # CLS token n_features
    epsilon = BATCH_NORM_EPSILON,
    n_labels = 3 # 3 labels: entails, neutral, and contradict
)
classifier_teacher = classifier_teacher.to(device)

## Training

Collect garbage

In [72]:
gc.collect()

2601

Function to compute metrics

In [73]:
def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='micro')
    precision = precision_score(y_true=labels, y_pred=pred, average='micro')
    f1 = f1_score(y_true=labels, y_pred=pred, average='weighted')

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1}

Training using Trainer from Huggingface

In [79]:
training_args_scenario_1 = TrainingArguments(
    output_dir=MODEL_PATH + "teacher-zero-shot/",
    evaluation_strategy="epoch",
    weight_decay=LAMBDA_L2,
    save_strategy="no",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    seed=101,
    learning_rate=L_RATE,
    report_to="wandb", # "none"
    run_name="scenario1-teacher-zero-shot"
)

Error in callback <function _WandbInit._resume_backend at 0x00000234748B5F70> (for pre_run_cell):


ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

Error in callback <function _WandbInit._pause_backend at 0x0000023420F4AA60> (for post_run_cell):


ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

In [78]:
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        data = inputs.get("x")
        # forward pass
        print("data", data)
        outputs = model(data)
        logits = outputs.get("logits")
        loss_func_teacher = nn.CrossEntropyLoss()
        loss = loss_func_teacher(logits, labels)
        # compute custom loss (suppose one has 3 labels with different weights)
#         loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
#         loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

Error in callback <function _WandbInit._resume_backend at 0x00000234748B5F70> (for pre_run_cell):


ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

Error in callback <function _WandbInit._pause_backend at 0x0000023420F4AA60> (for post_run_cell):


ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

In [76]:
trainer_scenario_1 = CustomTrainer(
    model=classifier_teacher,
    args=training_args_scenario_1,
    train_dataset=dataset_train_teacher,
    eval_dataset=dataset_valid_teacher,
    compute_metrics=compute_metrics,
)

In [77]:
trainer_scenario_1.train()

inputs {'x': tensor([[[-0.0740,  0.0902,  0.0220,  ...,  0.0896, -0.0318,  0.0350]],

        [[-0.0879, -0.1221, -0.1729,  ...,  0.1188,  0.1163, -0.1308]]],
       device='cuda:0')}


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 768])

In [None]:
trainer_scenario_1.evaluate()

Save the model

In [None]:
trainer_scenario_1.save_model(MODEL_PATH + "teacher-zero-shot/")

## Test

In [None]:
prediction_scenario_1 = trainer_scenario_1.predict(dataset_test_teacher)

In [None]:
print("Testing metrics:", prediction_scenario_1[2])

In [None]:
wandb.finish()