<a href="https://colab.research.google.com/github/an-eve/nlp-nli-idioms/blob/main/Fine_Tuning_BERT_on_IMPLI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning BERT on the IMPLI dataset and exploring how performance depends on the amount of training data

In [1]:
from huggingface_hub import notebook_login

# Interactive Hugging Face login. The later `push_to_hub(..., private=True)` uploads
# need an authenticated session (no HF_TOKEN secret is configured, see output below).
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [2]:
! pip install datasets
! pip install -U accelerate
! pip install -U transformers

In [3]:
import torch
import numpy as np
import pandas as pd
import os
import copy
import datetime
import csv

from transformers import (BertTokenizer,
                          AutoModelForSequenceClassification,
                          Trainer,
                          TrainingArguments)
from datasets import (Dataset,
                      load_dataset,
                      concatenate_datasets,
                      load_metric,
                      ClassLabel,
                      Features)

In [4]:
from google.colab import drive
# Mount Google Drive; `base_dir` below points into it, so all run output folders persist there.
drive.mount('/content/drive', force_remount=True, timeout_ms = 0)

In [5]:
base_dir = '/content/drive/My Drive/'  # root for every evaluation/training output folder

In [6]:
BATCH_SIZE = 32  # per-device batch size, used for both training and evaluation

In [7]:
# Informational only: DEVICE is not passed to the Trainer in the visible cells (Trainer places the model itself).
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)

## Model

In [8]:
# Uncased BERT vocabulary; do_lower_case=True matches the uncased checkpoint loaded below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [9]:
def tokenize_function(examples):
    """Encode a batch of (premise, hypothesis) pairs as BERT sentence-pair inputs.

    Every pair is padded to the tokenizer's maximum length and truncated if longer,
    so all rows share the same sequence length.
    """
    premises = examples["premise"]
    hypotheses = examples["hypothesis"]
    return tokenizer(premises, hypotheses, truncation=True, padding="max_length")

In [10]:
# Starting point for every run: a 2-label BERT previously fine-tuned on MNLI
# (presumably 0 = entailment, 1 = non-entailment, matching the ClassLabel below; verify in the model config).
model = AutoModelForSequenceClassification.from_pretrained("an-eve/bert-base-uncased-mnli-2-labels")

In [11]:
# GLUE/MNLI metric (accuracy). `load_metric` is deprecated (see the warning below).
metric = load_metric('glue', "mnli")
metric_name = "accuracy"  # key used by the Trainer to pick the best checkpoint

  metric = load_metric('glue', "mnli")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [12]:
def compute_metrics(eval_pred):
    """Convert Trainer logits to predicted class ids and score them with the GLUE/MNLI metric."""
    logits, references = eval_pred
    predicted_classes = np.argmax(logits, axis=1)  # highest-scoring class per row
    return metric.compute(predictions=predicted_classes, references=references)

## Cleaning the files

In [14]:
def transfer_lines(input_file_path, output_file_path, search_expression):
    """Move TSV rows that mention `search_expression` exactly once into another file.

    A row is moved when the case-insensitive number of occurrences of
    `search_expression`, summed over all of its columns, equals 1. Moved rows are
    appended to `output_file_path`. All other rows are written back to
    `input_file_path`, replacing its previous contents.

    Args:
        input_file_path: TSV file to scan; rewritten in place without the moved rows.
        output_file_path: TSV file the matching rows are appended to (created if missing).
        search_expression: substring to search for, compared case-insensitively.
    """
    needle = search_expression.lower()

    # Single pass over parsed records. The previous version compared reader.line_num
    # (physical lines read) with an enumerate() record index. Those diverge as soon as
    # a quoted field contains a newline, so the wrong rows were dropped.
    moved_rows = []
    kept_rows = []
    with open(input_file_path, 'r', newline='') as input_file:
        for row in csv.reader(input_file, delimiter='\t'):
            occurrences = sum(column.lower().count(needle) for column in row)
            if occurrences == 1:
                moved_rows.append(row)
            else:
                kept_rows.append(row)

    with open(output_file_path, 'a', newline='') as output_file:
        csv.writer(output_file, delimiter='\t').writerows(moved_rows)

    with open(input_file_path, 'w', newline='') as input_file:
        csv.writer(input_file, delimiter='\t').writerows(kept_rows)

In [None]:
# Move every training pair that mentions the idiom "piss off" exactly once out of the
# train files and append it to the corresponding test files (presumably so this idiom
# is absent from fine-tuning data).
input_file_path = '/content/train_ne.tsv'
output_file_path = '/content/test_ne_wo_idioms_col.tsv'


input_file_path1 = '/content/train_e.tsv'
output_file_path1 = '/content/test_e_wo_idioms_col.tsv'

search_expression = 'piss off'

transfer_lines(input_file_path, output_file_path, search_expression)
transfer_lines(input_file_path1, output_file_path1, search_expression)

In [None]:
def count_lines(file_path):
    """Return how many lines `file_path` contains (a trailing line without a newline counts)."""
    total = 0
    with open(file_path, 'r', newline='') as handle:
        for _ in handle:
            total += 1
    return total

In [None]:
# Report line counts for the non-entailment test/train files after the transfer.
for file_path in ('/content/test_ne_wo_idioms_col.tsv', '/content/train_ne.tsv'):
    num_lines = count_lines(file_path)
    print(f'The number of lines in {file_path} is: {num_lines}')

The number of lines in /content/test_ne_wo_idioms_col.tsv is: 963
The number of lines in /content/train_ne.tsv is: 6584


In [None]:
# Report line counts for the entailment test/train files after the transfer.
for file_path in ('/content/test_e_wo_idioms_col.tsv', '/content/train_e.tsv'):
    num_lines = count_lines(file_path)
    print(f'The number of lines in {file_path} is: {num_lines}')

The number of lines in /content/test_e_wo_idioms_col.tsv is: 2387
The number of lines in /content/train_e.tsv is: 13785


In [None]:
def remove_identical_rows(file_path):
    """Remove duplicate rows from a TSV file in place, keeping each row's first occurrence.

    Rows are compared as whole tuples of column values. The surviving rows keep
    their original relative order.
    """
    with open(file_path, 'r', newline='') as file:
        reader = csv.reader(file, delimiter='\t')
        # dict.fromkeys de-duplicates while preserving insertion order. A set would
        # write rows back in hash order, which changes between interpreter sessions
        # (string hash randomization), making the cleaned files non-reproducible.
        unique_rows = list(dict.fromkeys(tuple(row) for row in reader))

    # Write unique rows back to the original file
    with open(file_path, 'w', newline='') as file:
        writer = csv.writer(file, delimiter='\t')
        writer.writerows(unique_rows)


In [None]:
# De-duplicate the non-entailment test file, which received appended rows above.
file_path = '/content/test_ne_wo_idioms_col.tsv'
remove_identical_rows(file_path)

In [None]:
# De-duplicate the non-entailment training file.
file_path ='/content/train_ne.tsv'
remove_identical_rows(file_path)

In [None]:
# De-duplicate the entailment test file, which received appended rows above.
file_path = '/content/test_e_wo_idioms_col.tsv'
remove_identical_rows(file_path)

In [None]:
# De-duplicate the entailment training file.
file_path = '/content/train_e.tsv'
remove_identical_rows(file_path)

## Loading and arranging the IMPLI data

Preprocessing

In [None]:
# Raw IMPLI TSV files in the project repository. The test file names match the
# outputs of the cleaning cells above (presumably those files, committed to the repo).
url_train_ne = "https://github.com/an-eve/nlp-nli-idioms/raw/main/dataset/train_ne.tsv"
url_train_e = "https://github.com/an-eve/nlp-nli-idioms/raw/main/dataset/train_e.tsv"
url_test_ne = "https://github.com/an-eve/nlp-nli-idioms/raw/main/dataset/test_ne_wo_idioms_col.tsv"
url_test_e = "https://github.com/an-eve/nlp-nli-idioms/raw/main/dataset/test_e_wo_idioms_col.tsv"

In [None]:
IMPLI_COLUMNS = ['premise', 'hypothesis', 'label']


def _load_impli_tsv(url):
    """Load one headerless IMPLI TSV file as a single `Dataset` with the three IMPLI columns."""
    return load_dataset('csv', data_files=url, delimiter='\t', column_names=IMPLI_COLUMNS, split='train')


train_ne_data = _load_impli_tsv(url_train_ne)
train_e_data = _load_impli_tsv(url_train_e)
test_ne_data = _load_impli_tsv(url_test_ne)
test_e_data = _load_impli_tsv(url_test_e)

Downloading data:   0%|          | 0.00/458k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data:   0%|          | 0.00/793k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data:   0%|          | 0.00/61.9k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data:   0%|          | 0.00/113k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Disabled: the test TSVs are loaded with the same three column names as the training
# files (premise, hypothesis, label), so there is no `idiom` column to drop.
#test_ne_data = test_ne_data.remove_columns("idiom")
#test_e_data = test_e_data.remove_columns("idiom")

In [None]:
def modify_label_ne(example):
    """Set the example's label to 1 (non-entailment) in place and return the same dict."""
    example.update(label=1)
    return example

def modify_label_e(example):
    """Set the example's label to 0 (entailment) in place and return the same dict."""
    example.update(label=0)
    return example

# Each IMPLI file holds a single class, so the label is overwritten by source file:
# 1 for the non-entailment files, 0 for the entailment files.
train_ne_data = train_ne_data.map(modify_label_ne)
train_e_data = train_e_data.map(modify_label_e)
test_ne_data = test_ne_data.map(modify_label_ne)
test_e_data = test_e_data.map(modify_label_e)

Map:   0%|          | 0/6584 [00:00<?, ? examples/s]

Map:   0%|          | 0/13785 [00:00<?, ? examples/s]

Map:   0%|          | 0/963 [00:00<?, ? examples/s]

Map:   0%|          | 0/2387 [00:00<?, ? examples/s]

In [None]:
# Make `label` a ClassLabel (0 = entailment, 1 = non-entailment, consistent with the
# relabelling above) and give all four splits the same schema, so they stay
# compatible for concatenate_datasets below.
new_features = train_ne_data.features.copy()
new_features['label'] = ClassLabel(num_classes = 2, names=["entailment", "non-entailment"])

train_ne_data = train_ne_data.cast(new_features)
train_e_data = train_e_data.cast(new_features)
test_ne_data = test_ne_data.cast(new_features)
test_e_data = test_e_data.cast(new_features)

Casting the dataset:   0%|          | 0/6584 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/13785 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/963 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2387 [00:00<?, ? examples/s]

In [None]:
# Sanity check: schema and row count of each split.
print(train_ne_data)
print(train_e_data)
print(test_ne_data)
print(test_e_data)

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 6584
})
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 13785
})
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 963
})
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 2387
})


In [None]:
# Sanity check: first two examples of each split.
print(train_ne_data[:2], '\n')
print(train_e_data[:2], '\n')
print(test_ne_data[:2], '\n')
print(test_e_data[:2])

{'premise': ['13 , 1991, in the fourth week of the U.S.-led air war against Iraq, a British Tornado warplane dropped a bomb that was intended to take out a key river bridge at Fallujah.', "16) of the evacuation of Attica that the Athenians took it so badly because it was like leaving one's polis; this is on the face of it a paradox because they were going from their country demes to the polis."], 'hypothesis': ['13 , 1991, in the fourth week of the U.S.-led air war against Iraq, a British Tornado warplane gave decisive news that was intended to take out a key river bridge at Fallujah.', "16) of the evacuation of Attica that the Athenians took it so badly because it was like leaving one's polis; this is Encountering a paradox because they were going from their country demes to the polis."], 'label': [1, 1]} 

{'premise': ["( 11–12 February 1778) as if to add insult to injury, Leopold received Mozart's letter telling him that he had not yet finished his commissions for the Dutchman :", '

Combining the entailment and non-entailment data

In [None]:
# Merge both classes and shuffle with a fixed seed, so the contiguous shards taken below are class-mixed and reproducible.
train_data = concatenate_datasets([train_ne_data, train_e_data])
train_data = train_data.shuffle(seed=128)

In [None]:
# Inspect the combined training set: schema/size, then its first four rows.
print(train_data, '\n')
train_data[:4]

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 20369
}) 



{'premise': ['I used to play with them for hours on end and they tore a big hole in the back of my coat during one rough and tumble.',
  '( I know that in wildlife programmes on TV the professional naturalists never interfere with natural savagery, but since the size of the moggy population in the area is unnatural, I felt it was right to lend a hand . )',
  "‘ He's the type of player who blows hot and cold.",
  'A NINE - year - old boy who went on a six - month crime spree and carried out more than 50 thefts, went back to school yesterday after police were forced to let him walk free.'],
 'hypothesis': ['I used to play with them for hours on end and they tore a big hole in the back of my coat during one rough activity.',
  '( I know that in wildlife programmes on TV the professional naturalists never interfere with natural savagery, but since the size of the moggy population in the area is unnatural, I felt it was right to shake hands . )',
  "‘ He's the type of player who behaves inc

In [None]:
# Combined (both classes) test set, shuffled with the same fixed seed.
test_data = concatenate_datasets([test_ne_data, test_e_data])
test_data = test_data.shuffle(seed=128)

In [None]:
# Inspect the combined test set: schema/size, then its first four rows.
print(test_data, '\n')
test_data[:4]

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 3350
}) 



{'premise': ['The two went hand in hand until the later nineteenth century.',
  'More and more researchers in natural language processing are investing their efforts in dictionaries and lexicons, and efforts are being made to use machine - readable dictionaries instead of constructing lexicons from scratch’.',
  "Articles in this October 2005 issue: Question And Answer Travel and Advertising Scam European E-Marketplaces Libyan Relations With Europe Trading Problems Report 2006 - the European Year of Workers' mobility Consultations Information Roundup Libyan Relations With Europe Libya is coming in from the cold.",
  'Since the SADS - L is normally used in a face - to - face interview situation these items are mostly in the form of questions and here the original wording was retained.'],
 'hypothesis': ['The two held hands until the later nineteenth century.',
  'More and more researchers in natural language processing are investing their efforts in dictionaries and lexicons, and effort

Dividing the training set into nine contiguous shards (cumulative unions of these form the growing training sets)

In [None]:
num_shards = 9

# Split the shuffled training set into 9 contiguous, non-overlapping shards.
train_sets = [train_data.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards)]

In [None]:
print(train_sets[0])  # first shard, roughly 1/9 of the training data

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 2264
})


Uploading to the Hugging Face Hub

In [None]:
# Publish the three test splits (combined, non-entailment only, entailment only) as private Hub datasets.
test_data.push_to_hub("an-eve/test_idioms", private=True)
test_ne_data.push_to_hub("an-eve/test_non_entailment_idioms", private=True)
test_e_data.push_to_hub("an-eve/test_entailment_idioms", private=True)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/434 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/429 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/430 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/an-eve/test_entailment_idioms/commit/ee4ed6dc4c091e501968c5d9484db3e699842055', commit_message='Upload dataset', commit_description='', oid='ee4ed6dc4c091e501968c5d9484db3e699842055', pr_url=None, pr_revision=None, pr_num=None)

In [None]:
# Publish cumulative training sets: train_sets_k holds the first k shards (k = 1..num_shards).
for shard_count in range(1, num_shards + 1):
    cumulative_subset = concatenate_datasets(train_sets[:shard_count])
    cumulative_subset.push_to_hub(f"an-eve/train_sets_{shard_count}_idioms", private=True)


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/452 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/454 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/453 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/455 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/456 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/14 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/456 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/16 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/454 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/19 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/454 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/21 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/438 [00:00<?, ?B/s]

## Loading datasets and tokenization

In [13]:
# Reload the combined test split from the Hub (the output shows a fallback to the local cache) and tokenize it.
test_data = load_dataset("an-eve/test_idioms", split='train')
tokenized_test_dataset = test_data.map(tokenize_function, batched=True)

Using the latest cached version of the dataset since an-eve/test_idioms couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/an-eve___test_idioms/default/0.0.0/5dfb32c27635d48fca6d50b41308aa4189a3700f (last modified on Thu Jan 11 23:57:22 2024).


In [14]:
# Reload and tokenize the non-entailment-only test split.
test_ne_data = load_dataset("an-eve/test_non_entailment_idioms", split='train')
tokenized_test_ne_dataset = test_ne_data.map(tokenize_function, batched=True)

Using the latest cached version of the dataset since an-eve/test_non_entailment_idioms couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/an-eve___test_non_entailment_idioms/default/0.0.0/223bb56bc60297f983567884ef898f456efbdafa (last modified on Thu Jan 11 23:57:30 2024).


In [15]:
# Reload and tokenize the entailment-only test split.
test_e_data = load_dataset("an-eve/test_entailment_idioms", split='train')
tokenized_test_e_dataset = test_e_data.map(tokenize_function, batched=True)

Using the latest cached version of the dataset since an-eve/test_entailment_idioms couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/an-eve___test_entailment_idioms/default/0.0.0/ee4ed6dc4c091e501968c5d9484db3e699842055 (last modified on Thu Jan 11 23:57:39 2024).


In [16]:
num_shards = 9

# train_sets[k] is the Hub dataset "train_sets_{k+1}_idioms", i.e. the first k+1 shards.
train_sets = [
    load_dataset(f"an-eve/train_sets_{shard_idx + 1}_idioms", split='train')
    for shard_idx in range(num_shards)
]

Using the latest cached version of the dataset since an-eve/train_sets_1_idioms couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/an-eve___train_sets_1_idioms/default/0.0.0/aa59f7f5549d67e09dbe46575ccd468bb9c8b32e (last modified on Thu Jan 11 23:58:37 2024).
Using the latest cached version of the dataset since an-eve/train_sets_2_idioms couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/an-eve___train_sets_2_idioms/default/0.0.0/a50cf50cdca7cc9c00db2df7f3ed5a626c8fb0a4 (last modified on Thu Jan 11 23:58:43 2024).
Using the latest cached version of the dataset since an-eve/train_sets_3_idioms couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/an-eve___train_sets_3_idioms/default/0.0.0/b0f1498885bc4bccdf46cb8811674d794469eaec (last modifie

In [17]:
# Tokenize each cumulative training set; index k matches train_sets[k].
tokenized_train_sets = [
    train_sets[shard_idx].map(tokenize_function, batched=True)
    for shard_idx in range(num_shards)
]

In [18]:
tokenized_train_sets[2]  # third cumulative set (first 3 shards), now with BERT input columns

Dataset({
    features: ['premise', 'hypothesis', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 6791
})

## Fine-tuning BERT on IMPLI gradually increasing the amount of data

Evaluation on the test set before any fine-tuning on IMPLI (MNLI-tuned baseline)

In [21]:
eval_folder = "".join([base_dir, "BERT-test-idioms-full-", str(datetime.datetime.now().timestamp())])
if not os.path.exists(eval_folder):
    os.mkdir(eval_folder)

# Baseline: score the MNLI-tuned model, untouched, on the combined IMPLI test split.
args_test = TrainingArguments(eval_folder, per_device_eval_batch_size=BATCH_SIZE)
eval_trainer = Trainer(
    model=model,
    args=args_test,
    train_dataset=tokenized_test_dataset,  # not used by evaluate()
    eval_dataset=tokenized_test_dataset,
    compute_metrics=compute_metrics,
)

metrics = eval_trainer.evaluate()
eval_trainer.log_metrics("eval", metrics)
eval_trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.5991
  eval_loss               =     1.7657
  eval_runtime            = 0:00:25.93
  eval_samples_per_second =    129.155
  eval_steps_per_second   =      2.043


In [22]:
eval_folder = "".join([base_dir, "BERT-test-idioms-full-ne-", str(datetime.datetime.now().timestamp())])
if not os.path.exists(eval_folder):
    os.mkdir(eval_folder)

# Baseline on the non-entailment-only test split.
args_test = TrainingArguments(eval_folder, per_device_eval_batch_size=BATCH_SIZE)
eval_trainer = Trainer(
    model=model,
    args=args_test,
    train_dataset=tokenized_test_ne_dataset,  # not used by evaluate()
    eval_dataset=tokenized_test_ne_dataset,
    compute_metrics=compute_metrics,
)

metrics = eval_trainer.evaluate()
eval_trainer.log_metrics("eval", metrics)
eval_trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.2233
  eval_loss               =     3.4711
  eval_runtime            = 0:00:07.28
  eval_samples_per_second =    132.184
  eval_steps_per_second   =      2.196


In [23]:
eval_folder = "".join([base_dir, "BERT-test-idioms-full-e-", str(datetime.datetime.now().timestamp())])
if not os.path.exists(eval_folder):
    os.mkdir(eval_folder)

# Baseline on the entailment-only test split.
args_test = TrainingArguments(eval_folder, per_device_eval_batch_size=BATCH_SIZE)
eval_trainer = Trainer(
    model=model,
    args=args_test,
    train_dataset=tokenized_test_e_dataset,  # not used by evaluate()
    eval_dataset=tokenized_test_e_dataset,
    compute_metrics=compute_metrics,
)

metrics = eval_trainer.evaluate()
eval_trainer.log_metrics("eval", metrics)
eval_trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.7507
  eval_loss               =     1.0776
  eval_runtime            = 0:00:17.91
  eval_samples_per_second =    133.248
  eval_steps_per_second   =      2.121


Evaluation on the test set after fine-tuning on progressively larger training sets (each run starts again from the MNLI-tuned model)

In [19]:
# Untouched copy of the MNLI-tuned model, so every data-size run below starts from the same weights.
model_mnli = copy.deepcopy(model)

In [24]:
# Training function for convenience

def train_eval_func(model, ind, type=None):
    """Fine-tune `model` on cumulative training set `ind`, then evaluate it on a test split.

    Args:
        model: sequence-classification model; it is trained in place.
        ind: index into `tokenized_train_sets` (0-based; the output folder uses ind+1).
        type: test split to evaluate on: "full" (both classes), "ne"
            (non-entailment only) or "e" (entailment only). None is treated as
            "e", matching the previous fallback.

    Returns:
        The metrics dict from the final `trainer.evaluate()` call.

    Raises:
        ValueError: if `type` is not "full", "ne", "e" or None.
    """
    import dataclasses
    import inspect

    eval_sets = {
        "full": tokenized_test_dataset,
        "ne": tokenized_test_ne_dataset,
        "e": tokenized_test_e_dataset,
    }
    split = "e" if type is None else type
    if split not in eval_sets:
        # Previously any unknown value silently evaluated on the entailment split.
        raise ValueError(f"type must be one of {sorted(eval_sets)} or None, got {type!r}")
    eval_set = eval_sets[split]

    folder = base_dir + f"BERT-test-idioms-{type}-{ind+1}-" + str(datetime.datetime.now().timestamp())
    os.makedirs(folder, exist_ok=True)

    # `evaluation_strategy` was renamed to `eval_strategy` and later removed, and the
    # setup cell installs the newest transformers, so use whichever name is accepted.
    arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
    strategy_key = "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"

    args = TrainingArguments(
        output_dir=folder,
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=1,
        weight_decay=0.01,
        # NOTE(review): best-model selection uses the *test* split as validation data.
        # With a single epoch there is only one checkpoint, so nothing leaks today,
        # but it would if num_train_epochs were raised.
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
        push_to_hub=False,
        **{strategy_key: "epoch"},
    )

    trainer_kwargs = {
        "model": model,
        "args": args,
        "train_dataset": tokenized_train_sets[ind],
        "eval_dataset": eval_set,
        "compute_metrics": compute_metrics,
    }
    # Trainer's `tokenizer` argument was renamed to `processing_class`; pass it under the supported name.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer_kwargs[tokenizer_key] = tokenizer

    trainer = Trainer(**trainer_kwargs)
    trainer.train()

    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    return metrics

In [25]:
# Evaluation function for convenience

def eval_func(model, ind, type=None):
    """Evaluate `model` as-is (no training) on one test split and save the metrics.

    Args:
        model: sequence-classification model to score.
        ind: run index, used only in the output folder name (as ind+1).
        type: test split: "full" (both classes), "ne" (non-entailment only) or
            "e" (entailment only). None is treated as "e", matching the previous
            fallback.

    Returns:
        The metrics dict produced by `Trainer.evaluate()`.

    Raises:
        ValueError: if `type` is not "full", "ne", "e" or None.
    """
    eval_sets = {
        "full": tokenized_test_dataset,
        "ne": tokenized_test_ne_dataset,
        "e": tokenized_test_e_dataset,
    }
    split = "e" if type is None else type
    if split not in eval_sets:
        # Previously any unknown value silently evaluated on the entailment split.
        raise ValueError(f"type must be one of {sorted(eval_sets)} or None, got {type!r}")
    eval_set = eval_sets[split]

    folder = base_dir + f"BERT-test-idioms-{type}-{ind+1}-" + str(datetime.datetime.now().timestamp())
    os.makedirs(folder, exist_ok=True)

    args_test = TrainingArguments(folder,
         per_device_eval_batch_size=BATCH_SIZE)

    # Evaluation only, so no train_dataset is needed.
    eval_trainer = Trainer(
        model=model,
        args=args_test,
        eval_dataset=eval_set,
        compute_metrics=compute_metrics)

    metrics = eval_trainer.evaluate()

    eval_trainer.log_metrics("eval", metrics)
    eval_trainer.save_metrics("eval", metrics)
    return metrics

In [26]:
num_shards = 9

# Data-size sweep: for each cumulative training set (1..9 shards), fine-tune a fresh
# copy of the MNLI-tuned model while evaluating on the combined test split, then
# score the same model on the non-entailment-only and entailment-only splits.
for i in range(num_shards):

    # Fresh copy each iteration so runs are independent rather than chained.
    model = copy.deepcopy(model_mnli)

    train_eval_func(model, i, type="full")
    eval_func(model, i, type="ne")
    eval_func(model, i, type="e")

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.602985,0.693731


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.586979,0.735821


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.615916,0.757612


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.606347,0.779104


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.588173,0.8


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.498443,0.818209


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.469387,0.82597


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2481,0.462853,0.84


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2567,0.414177,0.854925


In [34]:
torch.cuda.empty_cache()  # release cached GPU memory held by PyTorch's allocator after the sweep

## Evaluating on MNLI after fine-tuning on IMPLI