In [1]:
import os
import torch
import numpy as np
import torch
import pickle
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import DataCollatorWithPadding
from src.base_model import BaseModel
from src.KAN_model import KANModel
from src.Mixed_model import MIXEDModel
from src.Augement import Augment
from src.trainer import TrainerCustom
from src.dataset import TextDataset
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)


### Model Parameters
# Pretrained checkpoint used for both the tokenizer and the models (DistilBERT).
language_model_name = "distilbert-base-uncased"
# Maximum sequence length: every premise/hypothesis pair is padded or truncated
# to this many tokens by the tokenizer calls further down.
length = 32

# Base path of the tokenized-dataset cache. The split name and `length` are
# inserted in front of the extension when the per-split files are written/read.
save_path = './data/processed_dataset.pkl'

### Training Arguments

# batch
batch_size = 128

# optim
# NOTE(review): 1e-7 is a very small step size; the logged training loss moves
# only slowly with it -- confirm this value is intended.
learning_rate = 1e-7
weight_decay = 0.01 # regularization strength (e.g. 0.01), useful with very small or very large amounts of data

# training
epochs = 10
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Fix the random seed (transformers.set_seed) so runs are repeatable.
set_seed(42)

In [2]:
def append_update(d1, d2):
    """Merge ``d2`` into ``d1`` in place, accumulating repeated keys.

    For every ``(key, value)`` pair of ``d2``:

    * key not yet in ``d1``      -> stored as ``d1[key] = value``
    * key holds a list in ``d1`` -> ``value`` is appended to that list
    * key holds anything else    -> replaced by ``[old_value, value]``

    Args:
        d1: Dictionary that is updated in place.
        d2: Dictionary whose items are merged into ``d1``.

    Returns:
        The same ``d1`` object, after the update.
    """
    for name, incoming in d2.items():
        # First occurrence of the key: keep the bare value.
        if name not in d1:
            d1[name] = incoming
            continue

        current = d1[name]
        if isinstance(current, list):
            # Already accumulating: extend the existing list in place.
            current.append(incoming)
        else:
            # Second occurrence: promote the scalar to a two-element list.
            d1[name] = [current, incoming]
    return d1


In [3]:
def save_data(filepath, dataloader):
    """Write the tensors held by a dataloader's dataset to ``filepath``.

    Only ``dataloader.dataset`` is read; the dataloader's batching and
    shuffling settings have no influence on what is written.

    Args:
        filepath: Destination handed straight to ``torch.save``.
        dataloader: Object exposing a ``dataset`` attribute that provides
            ``input_ids`` and ``attention_masks`` (sequences of tensors that
            can be stacked) and ``labels`` (stored exactly as held).
    """
    source = dataloader.dataset

    # Stack the per-example tensors so each field is stored as a single tensor;
    # the labels are written unchanged.
    payload = {
        'input_ids': torch.stack(source.input_ids),
        'attention_masks': torch.stack(source.attention_masks),
        'labels': source.labels,
    }
    torch.save(payload, filepath)

def load_data(filepath, batch_size=32, shuffle=True, num_workers=0, map_location=None):
    """Rebuild a ``DataLoader`` from a file written by ``save_data``.

    Args:
        filepath: Path of the saved dictionary with the keys ``'input_ids'``,
            ``'attention_masks'`` and ``'labels'``.
        batch_size: Batch size of the returned dataloader.
        shuffle: Whether the returned dataloader shuffles its examples.
        num_workers: Number of worker processes of the returned dataloader.
        map_location: Forwarded to ``torch.load``. Leave ``None`` to restore
            the tensors on the device they were saved from; pass e.g.
            ``'cpu'`` to open a cache written on a GPU machine when no GPU is
            available.

    Returns:
        A ``DataLoader`` over a ``TextDataset`` whose items are dictionaries
        with the keys ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
    """
    # NOTE: torch.load unpickles the file -- only open caches you created yourself.
    saved_data = torch.load(filepath, map_location=map_location)

    # Re-split the stored fields into one dictionary per example. The label is
    # squeezed so that size-1 dimensions are dropped from each stored label.
    tokenized_texts = [
        {'input_ids': input_id, 'attention_mask': attention_mask, 'labels': label.squeeze()}
        for input_id, attention_mask, label in zip(
            saved_data['input_ids'], saved_data['attention_masks'], saved_data['labels']
        )
    ]

    dataset = TextDataset(tokenized_texts)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

In [4]:
# Load the semantically augmented FEVER NLI dataset from the Hugging Face Hub.
# NOTE(review): the download cache lives under "../data" while `save_path`
# points into "./data" -- confirm the two different roots are intentional.
dataset = load_dataset("tommasobonomo/sem_augmented_fever_nli", cache_dir="../data/sem_augmented_fever", trust_remote_code=True)

In [5]:
## Show one raw training example (index 12), including its premise,
## hypothesis, label and annotation fields.
print(f"Sentence: {dataset['train'][12]}")

Sentence: {'id': '65960', 'premise': 'Whoopi Goldberg . From 1998 to 2002 , she was co-producer of the television game show Hollywood Squares .', 'hypothesis': 'Whoopi Goldberg co-produced an American dance tournament.', 'label': 'NEUTRAL', 'wsd': {'premise': [{'index': 0, 'text': 'Whoopi', 'pos': 'PROPN', 'lemma': 'Whoopi', 'bnSynsetId': 'O', 'wnSynsetOffset': 'O', 'nltkSynset': 'O'}, {'index': 1, 'text': 'Goldberg', 'pos': 'PROPN', 'lemma': 'Goldberg', 'bnSynsetId': 'O', 'wnSynsetOffset': 'O', 'nltkSynset': 'O'}, {'index': 2, 'text': '.', 'pos': 'PUNCT', 'lemma': '.', 'bnSynsetId': 'O', 'wnSynsetOffset': 'O', 'nltkSynset': 'O'}, {'index': 3, 'text': 'From', 'pos': 'ADP', 'lemma': 'from', 'bnSynsetId': 'O', 'wnSynsetOffset': 'O', 'nltkSynset': 'O'}, {'index': 4, 'text': '1998', 'pos': 'NUM', 'lemma': '1998', 'bnSynsetId': 'O', 'wnSynsetOffset': 'O', 'nltkSynset': 'O'}, {'index': 5, 'text': 'to', 'pos': 'ADP', 'lemma': 'to', 'bnSynsetId': 'O', 'wnSynsetOffset': 'O', 'nltkSynset': 'O'},

In [6]:
## Display the structure of the Hugging Face dataset: its splits, features and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 51086
    })
    validation: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2288
    })
    test: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2287
    })
})

In [7]:
'''
augemnter = Augment(dataset, "test")
augmented_dataset =  augemnter.apply()
augmented_dataset
'''

'\naugemnter = Augment(dataset, "test")\naugmented_dataset =  augemnter.apply()\naugmented_dataset\n'

### Metric Definition

The cross-entropy loss alone does not tell us how well the model actually classifies. We therefore define a standard function that computes two additional metrics:

- **Accuracy** metric
- **F1** metric

In [8]:
from datasets import load_metric

# Metrics

def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class of each
            example is the arg-max of its logits over the last axis.

    Returns:
        Dictionary with the keys ``"accuracy"`` and ``"f1"``.
    """
    accuracy_metric = load_metric("accuracy")
    f1_metric = load_metric("f1")

    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)

    accuracy_report = accuracy_metric.compute(predictions=predicted, references=gold)
    # 'macro' is passed as the F1 averaging mode.
    f1_report = f1_metric.compute(predictions=predicted, references=gold, average='macro')

    return {"accuracy": accuracy_report["accuracy"], "f1": f1_report["f1"]}

In [9]:
## Initialize the model
# Only the mixed model is instantiated; the alternative models are kept in the
# disabled (string-literal) block that follows.

MIXED_model = MIXEDModel(length, language_model_name, device)
'''
KAN_model = KANModel(length, language_model_name, device)
auto_model = AutoModelForSequenceClassification.from_pretrained(language_model_name,
                                                                   ignore_mismatched_sizes=True,
                                                                   output_attentions=False, output_hidden_states=False,
                                                                   num_labels=3) # number of the classes
base_model = BaseModel(device, length, language_model_name)

'''
# Tokenizer matching the pretrained checkpoint named by `language_model_name`.
tokenizer = AutoTokenizer.from_pretrained(language_model_name)

# Collator built on the same tokenizer; it is only referenced by the disabled
# Hugging Face `Trainer` configurations further down.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def tokenize_function(examples):
    """Encode a batch of premise/hypothesis pairs and map labels to class ids.

    The batch is modified in place: ``examples['label']`` is replaced by
    integer ids and every field produced by the tokenizer is added to it.

    Args:
        examples: Batch dictionary with the list-valued keys ``'premise'``,
            ``'hypothesis'`` and ``'label'`` (labels given as the strings
            ``'ENTAILMENT'``, ``'CONTRADICTION'`` or ``'NEUTRAL'``).

    Returns:
        The same ``examples`` object, after the update.

    Raises:
        KeyError: If a label is not one of the three known strings.
    """
    class_ids = {'ENTAILMENT': 0, 'CONTRADICTION': 1, 'NEUTRAL': 2}

    # Convert the string labels to their integer ids.
    numeric_labels = []
    for name in examples['label']:
        numeric_labels.append(class_ids[name])
    examples['label'] = numeric_labels

    # Encode premise and hypothesis together as a pair, padded/truncated to
    # `length` tokens.
    encoded = tokenizer(
        examples['premise'],
        examples['hypothesis'],
        max_length=length,
        padding='max_length',
        truncation=True,
    )

    # Merge the tokenizer output into the batch.
    examples.update(encoded)
    return examples

In [10]:
def tokenize_sense_function(examples):
    """Encode a batch of premise/hypothesis pairs and map labels to class ids.

    Intended as the word-sense-aware variant of ``tokenize_function``; the
    word-sense information is not used yet (see the TODO below), so the two
    functions currently perform the same steps.

    The batch is modified in place and also returned. Returning it matters:
    a mapping function that returns nothing leaves the mapped dataset without
    the tokenized fields.

    Args:
        examples: Batch dictionary with the list-valued keys ``'premise'``,
            ``'hypothesis'`` and ``'label'`` (labels given as the strings
            ``'ENTAILMENT'``, ``'CONTRADICTION'`` or ``'NEUTRAL'``).

    Returns:
        The same ``examples`` object, with integer labels and the tokenizer
        output merged in.

    Raises:
        KeyError: If a label is not one of the three known strings.
    """
    #TODO add word sense
    label_map = {
        'ENTAILMENT': 0,
        'CONTRADICTION': 1,
        'NEUTRAL': 2
    }
    # Map the labels
    examples['label'] = [label_map[label] for label in examples['label']]

    # Tokenize the premise and hypothesis as a pair, padded/truncated to `length`
    tokenized = tokenizer(
        examples['premise'],
        examples['hypothesis'],
        truncation=True,
        padding='max_length',
        max_length=length
    )

    # Add tokenized fields to the examples
    examples.update(tokenized)
    # Previously missing: without this the function returned None.
    return examples

In [11]:
# Tokenize the dataset. The hypothesis is passed as the tokenizer's second
# argument so that premise and hypothesis are encoded together as one pair,
# joined by the tokenizer's separator token.
# The map() calls below are disabled (kept inside a string literal); the only
# statement that runs in this cell is the final print of the raw dataset.

'''
print("Tokenize the dataset ...")
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_sense_dataset = dataset.map(tokenize_sense_function, batched=True)
tokenized_augmented_dataset = augmented_dataset.map(tokenize_function, batched=True)
tokenized_augmented_sense_dataset = augmented_dataset.map(tokenize_sense_function, batched=True)
'''
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 51086
    })
    validation: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2288
    })
    test: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2287
    })
})


In [12]:
'''
print(tokenized_dataset['train'][1].keys())
print(tokenized_dataset['train'][1]['label'])
print(tokenized_dataset['train'][1]['input_ids'])
'''

"\nprint(tokenized_dataset['train'][1].keys())\nprint(tokenized_dataset['train'][1]['label'])\nprint(tokenized_dataset['train'][1]['input_ids'])\n"

In [13]:
# Build (or load from cache) one DataLoader per split.
#
# For every split the tokenized examples are cached on disk. When the cache
# file is missing it is generated first; in both cases the split is then read
# back through load_data(), so KAN_dataset[split] is always a DataLoader built
# the same way, whether or not the cache already existed.
splits = ['train', 'validation', 'test']
KAN_dataset = {'train':[], 'validation': [], 'test':[]}

# String label -> class id (built once instead of once per example).
label_map = {
    'ENTAILMENT': 0,
    'CONTRADICTION': 1,
    'NEUTRAL': 2
}

# Cache files are named <root><split><length><ext>,
# e.g. ./data/processed_datasettrain32.pkl for save_path='./data/processed_dataset.pkl'.
cache_root, cache_ext = os.path.splitext(save_path)

for split in splits:
    cache_path = cache_root + split + str(length) + cache_ext

    if not os.path.isfile(cache_path):
        print('generating',split,'...')
        examples = []
        for i, data in enumerate(dataset[split], start=1):
            # Encode premise and hypothesis as one pair, padded/truncated to
            # `length` tokens, and move the resulting tensors to `device`.
            tokens = tokenizer(
                    data['premise'],
                    data['hypothesis'],
                    truncation=True,
                    padding='max_length',
                    max_length=length,
                    return_tensors='pt'
                    ).to(device)

            # Map the labels
            label = label_map[data['label']]

            tokens.update({'labels': torch.tensor(label).unsqueeze(0).to(device)})
            examples.append(tokens)
            if not i%1000: print(split,"step:",i)

        print('saving',split,'...')
        # Make sure the cache directory exists before writing into it.
        os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)
        tokenized_dataset = TextDataset(examples)
        # save_data() only reads `dataloader.dataset`, so this loader is just
        # the container it expects.
        dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        save_data(cache_path, dataloader)

    print('loading',split,'...')
    KAN_dataset[split] = load_data(cache_path, batch_size=batch_size, num_workers=0)


loading train ...
loading validation ...
loading test ...


## Model Training

To train a transformer model you can rely on the **Trainer** class of Huggingface (https://huggingface.co/docs/transformers/main_classes/trainer).

The Trainer class allows you to save many lines of code, and makes your code much more readable.

To initialize the Trainer class you have to define a **TrainingArguments** object.

In [14]:
# Arguments for the Hugging Face `Trainer`. They are only referenced by the
# disabled `Trainer(...)` configurations; the custom trainer that is actually
# run receives its optimizer settings directly.
training_args = TrainingArguments(
    output_dir="training_dir",                    # output directory [Mandatory]
    num_train_epochs=epochs,                      # total number of training epochs
    per_device_train_batch_size=batch_size,       # batch size per device during training
    warmup_steps=1000,                             # number of warmup steps for learning rate scheduler
    weight_decay=weight_decay,                    # strength of weight decay
    save_strategy="no",                           # do not write checkpoints during training
    learning_rate=learning_rate,                  # learning rate
    gradient_checkpointing=True                   # recompute activations in the backward pass to save memory
)

In [15]:
# Custom trainer (src.trainer.TrainerCustom) for the mixed model.
# Despite the variable name, the model being trained here is MIXED_model.
trainer_KAN = TrainerCustom(
model=MIXED_model,
# Adam over all model parameters, using the learning rate and weight decay
# from the configuration cell.
optimizer=torch.optim.Adam(MIXED_model.parameters(), lr=learning_rate,weight_decay=weight_decay),
# NOTE(review): `nn` is not imported in this notebook; add `from torch import nn`
# before re-enabling the next line.
#loss_function=nn.CrossEntropyLoss(),
# presumably the logging interval in steps -- the training log reports every 10 steps
log_steps=10
)

'''
trainer_base = Trainer(
   model=base_model,
   args=training_args,
   train_dataset=tokenized_dataset["train"],
   eval_dataset=tokenized_dataset["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics
)
trainer_auto = Trainer(
   model=auto_model,
   args=training_args,
   train_dataset=tokenized_dataset["train"],
   eval_dataset=tokenized_dataset["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics
)


trainer_sense_auto = Trainer(
   model=auto_model,
   args=training_args,
   train_dataset=tokenized_sense_dataset["train"],
   eval_dataset=tokenized_sense_dataset["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics
)

trainer_sense_base = Trainer(
   model=base_model,
   args=training_args,
   train_dataset=tokenized_sense_dataset["train"],
   eval_dataset=tokenized_sense_dataset["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics
)

trainer_sense_KAN = Trainer(
   model=KAN_model,
   args=training_args,
   train_dataset=tokenized_sense_dataset["train"],
   eval_dataset=tokenized_sense_dataset["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics
)
'''

'\ntrainer_base = Trainer(\n   model=base_model,\n   args=training_args,\n   train_dataset=tokenized_dataset["train"],\n   eval_dataset=tokenized_dataset["validation"],\n   tokenizer=tokenizer,\n   data_collator=data_collator,\n   compute_metrics=compute_metrics\n)\ntrainer_auto = Trainer(\n   model=auto_model,\n   args=training_args,\n   train_dataset=tokenized_dataset["train"],\n   eval_dataset=tokenized_dataset["validation"],\n   tokenizer=tokenizer,\n   data_collator=data_collator,\n   compute_metrics=compute_metrics\n)\n\n\ntrainer_sense_auto = Trainer(\n   model=auto_model,\n   args=training_args,\n   train_dataset=tokenized_sense_dataset["train"],\n   eval_dataset=tokenized_sense_dataset["validation"],\n   tokenizer=tokenizer,\n   data_collator=data_collator,\n   compute_metrics=compute_metrics\n)\n\ntrainer_sense_base = Trainer(\n   model=base_model,\n   args=training_args,\n   train_dataset=tokenized_sense_dataset["train"],\n   eval_dataset=tokenized_sense_dataset["validation"

In [16]:
# Let's Train ...
# Train the mixed model on the cached 'train' split and pass the 'validation'
# split as second argument, for `epochs` epochs.
trainer_KAN.train(KAN_dataset['train'], KAN_dataset['validation'], epochs=epochs)
'''
trainer_auto.train()
trainer_base.train()

trainer_sense_auto.train()
trainer_sense_base.train()
trainer_sense_KAN.train()
'''

Training ...
 Epoch  1
	[E:  1 @ step 0] current avg loss = 1.3335 in 0:00:06.272876
	[E:  1 @ step 10] current avg loss = 1.4147 in 0:01:08.735379
	[E:  1 @ step 20] current avg loss = 1.3827 in 0:02:13.407141
	[E:  1 @ step 30] current avg loss = 1.3740 in 0:03:17.694704
	[E:  1 @ step 40] current avg loss = 1.3522 in 0:04:22.538791
	[E:  1 @ step 50] current avg loss = 1.3378 in 0:05:26.768466
	[E:  1 @ step 60] current avg loss = 1.3285 in 0:06:31.221627
	[E:  1 @ step 70] current avg loss = 1.3147 in 0:07:35.724878
	[E:  1 @ step 80] current avg loss = 1.3090 in 0:08:40.194726
	[E:  1 @ step 90] current avg loss = 1.3074 in 0:09:44.399234
	[E:  1 @ step 100] current avg loss = 1.3113 in 0:10:49.097708
	[E:  1 @ step 110] current avg loss = 1.3093 in 0:11:53.717488
	[E:  1 @ step 120] current avg loss = 1.3062 in 0:12:58.350711
	[E:  1 @ step 130] current avg loss = 1.3030 in 0:14:03.279121
	[E:  1 @ step 140] current avg loss = 1.3045 in 0:15:07.492848
	[E:  1 @ step 150] current 

In [None]:
# Evaluate the model ...
#metrics = trainerKAN.evaluate()
'''
metrics = trainer_base.evaluate()
metrics = trainer_auto.evaluate()#average='weighted')
trainer_sense_auto.evaluate()
trainer_sense_base.evaluate()
trainer_sense_KAN.evaluate()
'''
#print(metrics)

"\nmetrics = trainer_base.evaluate()\nmetrics = trainer_auto.evaluate()#average='weighted')\ntrainer_sense_auto.evaluate()\ntrainer_sense_base.evaluate()\ntrainer_sense_KAN.evaluate()\n"

In [None]:
'''
trainer_augmented_auto = Trainer(
   model=auto_model,
   args=training_args,
   train_dataset=tokenized_augmented_datasets["train"],
   eval_dataset=tokenized_augmented_datasets["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)

trainer_augmented_base = Trainer(
   model=base_model,
   args=training_args,
   train_dataset=tokenized_augmented_datasets["train"],
   eval_dataset=tokenized_augmented_datasets["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)

trainer_augmented_KAN = Trainer(
   model=KAN_model,
   args=training_args,
   train_dataset=tokenized_augmented_datasets["train"],
   eval_dataset=tokenized_augmented_datasets["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)

trainer_augmented_sense_auto = Trainer(
   model=auto_model,
   args=training_args,
   train_dataset=tokenized_augmented_sense_datasets["train"],
   eval_dataset=tokenized_augmented_sense_datasets["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)

trainer_augmented_sense_base = Trainer(
   model=base_model,
   args=training_args,
   train_dataset=tokenized_augmented_sense_datasets["train"],
   eval_dataset=tokenized_augmented_sense_datasets["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)

trainer_augmented_sense_KAN = Trainer(
   model=KAN_model,
   args=training_args,
   train_dataset=tokenized_augmented_sense_datasets["train"],
   eval_dataset=tokenized_augmented_sense_datasets["validation"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)
'''

'\ntrainer_augmented_auto = Trainer(\n   model=auto_model,\n   args=training_args,\n   train_dataset=tokenized_augmented_datasets["train"],\n   eval_dataset=tokenized_augmented_datasets["validation"],\n   tokenizer=tokenizer,\n   data_collator=data_collator,\n   compute_metrics=compute_metrics,\n)\n\ntrainer_augmented_base = Trainer(\n   model=base_model,\n   args=training_args,\n   train_dataset=tokenized_augmented_datasets["train"],\n   eval_dataset=tokenized_augmented_datasets["validation"],\n   tokenizer=tokenizer,\n   data_collator=data_collator,\n   compute_metrics=compute_metrics,\n)\n\ntrainer_augmented_KAN = Trainer(\n   model=KAN_model,\n   args=training_args,\n   train_dataset=tokenized_augmented_datasets["train"],\n   eval_dataset=tokenized_augmented_datasets["validation"],\n   tokenizer=tokenizer,\n   data_collator=data_collator,\n   compute_metrics=compute_metrics,\n)\n\ntrainer_augmented_sense_auto = Trainer(\n   model=auto_model,\n   args=training_args,\n   train_datase

In [None]:
# Let's Train ...
'''
trainer_augmented_auto.train()
trainer_augmented_base.train()
trainer_augmented_KAN.train()

trainer_augmented_sense_auto.train()
trainer_augmented_sense_base.train()
trainer_augmented_sense_KAN.train()
'''

'\ntrainer_augmented_auto.train()\ntrainer_augmented_base.train()\ntrainer_augmented_KAN.train()\n\ntrainer_augmented_sense_auto.train()\ntrainer_augmented_sense_base.train()\ntrainer_augmented_sense_KAN.train()\n'

In [None]:
# Evaluate the model ...
'''
trainer_augmented_auto.evaluate()
trainer_augmented_base.evaluate()
trainer_augmented_KAN.evaluate()

trainer_augmented_sense_auto.evaluate()
trainer_augmented_sense_base.evaluate()
trainer_augmented_sense_KAN.evaluate()
'''

'\ntrainer_augmented_auto.evaluate()\ntrainer_augmented_base.evaluate()\ntrainer_augmented_KAN.evaluate()\n\ntrainer_augmented_sense_auto.evaluate()\ntrainer_augmented_sense_base.evaluate()\ntrainer_augmented_sense_KAN.evaluate()\n'

In [None]:
def save_kan_model(model, path, max_length, model_name, device):
    """Write a model checkpoint, together with its construction settings, to ``path``.

    The checkpoint is a dictionary holding the model's ``state_dict()`` under
    ``'model_state_dict'`` plus the three values needed to rebuild the model:
    ``'max_length'``, ``'model_name'`` and ``'device'``.

    Args:
        model: Model instance to save; only its ``state_dict()`` is stored.
        path (str): Destination file of the checkpoint.
        max_length (int): Maximum sequence length the model was built with.
        model_name (str): Name of the pretrained language model.
        device (str): Device the model was built for.
    """
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        max_length=max_length,
        model_name=model_name,
        device=device,
    )
    torch.save(checkpoint, path)

def load_kan_model(path, model_class=None, device=None):
    """Rebuild a model from a checkpoint written by ``save_kan_model``.

    Args:
        path (str): The path to load the checkpoint from.
        model_class: Class used to rebuild the model. It is called as
            ``model_class(max_length=..., model_name=..., device=...)``.
            Defaults to ``KANModel``; pass the class that was actually saved
            (e.g. ``MIXEDModel``) when the checkpoint holds a different model,
            otherwise the state dict will not match.
        device: Device to restore onto. When given, it is used both as
            ``map_location`` for ``torch.load`` and as the ``device`` argument
            of the model constructor, so a checkpoint written on a GPU machine
            can be opened on a CPU-only one. Defaults to the device recorded
            in the checkpoint.

    Returns:
        The rebuilt model with the stored weights loaded.
    """
    # NOTE: torch.load unpickles the file -- only open checkpoints you trust.
    checkpoint = torch.load(path, map_location=device)

    target_device = checkpoint['device'] if device is None else device
    # Resolved at call time so the default does not have to exist at definition time.
    builder = KANModel if model_class is None else model_class

    model = builder(
        max_length=checkpoint['max_length'],
        model_name=checkpoint['model_name'],
        device=target_device,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

In [None]:
# Persist the trained mixed model together with the settings needed to rebuild it.
# NOTE(review): the object saved here is MIXED_model, while load_kan_model as
# written above constructs a KANModel -- confirm the checkpoint can be restored
# into the intended class.
save_kan_model(MIXED_model, "mixed-lr.pk", length, language_model_name, device)