# Fine tuning, for the n=2 case

In [1]:
import torch
import torch.nn as nn

# Needed for parallel 
from collections import OrderedDict

# For training 
from network_architecture_v2 import MyBertForSequenceClassification

# For fine tuning
from datasets import load_dataset #, load_metric
from transformers import BertTokenizer
from transformers import Trainer, TrainingArguments
import numpy as np

In [2]:
# Load dataset
dataset = load_dataset('glue', 'sst2')

# I believe this is the tokenizer I used... 
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding="max_length", 
                     max_length=128, truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)



Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

# Load the parallel model

This involves a bit more code

In [3]:
checkpoint_0 = torch.load('bert-save-2/model_checkpoint_0_batch_idx=80000')
checkpoint_1 = torch.load('bert-save-2/model_checkpoint_1_batch_idx=80000')

In [4]:
keys_0 = checkpoint_0['model_state'].keys()
keys_1 = checkpoint_1['model_state'].keys()

In [5]:
# Ugh, this is 
new_dict = OrderedDict()
keys_0 = checkpoint_0['model_state'].keys()
counter = 0
for key in keys_0:
    if 'parallel_nn' in key:
        split = key.split('.')
        split[1] = 'serial_nn'
        if int(split[2]) > counter:
            counter = int(split[2])
            
        split.insert(3, 'layer')
        new_key = '.'.join(split[1:])
        new_dict[new_key] = checkpoint_0['model_state'][key]
    else:
        new_key = key
        if 'close_nsp' in key:
            # print(key)
            split = key.split('.')
            split[0] = 'close_nn_nsp'
            new_key = '.'.join(split)
        if 'close_mlm' in key:
            # print(key)
            split = key.split('.')
            split[0] = 'close_nn_mlm'
            new_key = '.'.join(split)
        
        new_dict[new_key] = checkpoint_0['model_state'][key]
print(counter)

# Now for the remaining parts? 
keys_1 = checkpoint_1['model_state'].keys()
for key in keys_1:
    if 'parallel_nn' in key:
        split = key.split('.')
        split[1] = 'serial_nn'
        split[2] = str(int(split[2]) + counter + 1)
        split.insert(3, 'layer')
        
        new_key = '.'.join(split[1:])
        new_dict[new_key] = checkpoint_1['model_state'][key]
    else:
        new_dict[key] = checkpoint_1['model_state'][key]

16


In [6]:
model_parallel = torch.load('serialnet_bert_32')
model_parallel.load_state_dict(new_dict)

<All keys matched successfully>

In [7]:
training_parallel = MyBertForSequenceClassification(model_parallel)

# Now let's train

https://proceedings.neurips.cc/paper_files/paper/2023/file/095a6917768712b7ccc61acbeecad1d8-Supplemental-Conference.pdf for hyperparameters


In [8]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.988,
    adam_epsilon=1e-6,
    dataloader_drop_last=True,
    warmup_steps=100,
    weight_decay=1e-4,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
)


In [9]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).astype(np.float32).mean().item()
    return {"accuracy": accuracy}

In [10]:
# Initialize the Trainer
trainer = Trainer(
    model=training_parallel,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics
)


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [11]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.279,0.473321,0.8125
2,0.1471,0.490394,0.831019
3,0.0986,0.569,0.834491


TrainOutput(global_step=6312, training_loss=0.2275288417057166, metrics={'train_runtime': 6447.1197, 'train_samples_per_second': 31.339, 'train_steps_per_second': 0.979, 'total_flos': 0.0, 'train_loss': 0.2275288417057166, 'epoch': 3.0})

In [12]:
sum(p.numel() for p in training_parallel.parameters() if p.requires_grad)

251241218

In [13]:
training_parallel = MyBertForSequenceClassification(model_parallel)
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.1704,0.482051,0.822917
2,0.062,0.536374,0.826389
3,0.0384,0.645903,0.83912


TrainOutput(global_step=6312, training_loss=0.08186612888392773, metrics={'train_runtime': 6501.9192, 'train_samples_per_second': 31.075, 'train_steps_per_second': 0.971, 'total_flos': 0.0, 'train_loss': 0.08186612888392773, 'epoch': 3.0})

In [14]:
# Load dataset
dataset = load_dataset('glue', 'cola')

# I believe this is the tokenizer I used... 
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding="max_length", 
                     max_length=128, truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)



In [15]:
training_parallel = MyBertForSequenceClassification(model_parallel)


In [17]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.988,
    adam_epsilon=1e-6,
    dataloader_drop_last=True,
    warmup_steps=100,
    weight_decay=1e-4,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
)

# Initialize the Trainer
trainer = Trainer(
    model=training_parallel,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics
)


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [18]:
# For COLA
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.6187,0.6259,0.688477
2,0.553,0.616845,0.683594
3,0.4166,0.719184,0.676758


TrainOutput(global_step=801, training_loss=0.5481547378794234, metrics={'train_runtime': 613.5173, 'train_samples_per_second': 41.813, 'train_steps_per_second': 1.306, 'total_flos': 0.0, 'train_loss': 0.5481547378794234, 'epoch': 3.0})

# MRPC 

Reload models, then retrain 

In [19]:
model_parallel = torch.load('serialnet_bert_32')
model_parallel.load_state_dict(new_dict)

<All keys matched successfully>

In [20]:
training_parallel = MyBertForSequenceClassification(model_parallel)

In [22]:
# Load dataset
dataset = load_dataset('glue', 'mrpc')

# Tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples["sentence1"], 
        examples["sentence2"], 
        padding="max_length", 
        truncation=True,
        max_length=256
    )
    
tokenized_datasets = dataset.map(tokenize_function, batched=True)

tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "labels"])


In [23]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    adam_beta1=0.9,
    adam_beta2=0.988,
    adam_epsilon=1e-8,
    dataloader_drop_last=True,
    warmup_steps=5,
    weight_decay=1e-4,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
)

# Initialize the Trainer
trainer = Trainer(
    model=training_parallel,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics
)


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [24]:
# For MRPC
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.5666,0.586236,0.7025
2,0.6053,0.578408,0.7125
3,0.4867,0.647816,0.7125
4,0.346,0.716277,0.72
5,0.2461,0.817811,0.7275


TrainOutput(global_step=1145, training_loss=0.44530555676164585, metrics={'train_runtime': 1163.2257, 'train_samples_per_second': 15.767, 'train_steps_per_second': 0.984, 'total_flos': 0.0, 'train_loss': 0.44530555676164585, 'epoch': 5.0})