In [1]:
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, DataCollatorWithPadding
import torch
import evaluate
from torch.utils.data import DataLoader
from typing import Tuple

class CustomDataLoader:
    """
    Prepare tokenized, dynamically padded DataLoaders for a GLUE-style task.

    The task's train split and its validation split are concatenated and re-split
    (stratified by ``label``) into a new train/test pair. Subsets of ``range_to_select``
    rows are then drawn: rows [0, n) and [n, 2n) of the new train side become the
    train and validation sets, and rows [0, n) of the new test side become the test set.
    """

    # Fraction of the pooled data placed on the test side of ``train_test_split`` (about 1/6).
    # The literal is kept unchanged so seeded splits remain reproducible.
    TEST_FRACTION = 0.1666666666666

    def __init__(self, dataset_from: str = 'glue', model_name: str = 'bert-base-uncased', dataset_task: str = 'sst2', seed_num: int = 42, range_to_select: int = 500, batch_size: int = 8):
        """
        Custom DataLoader class for preparing and loading datasets.

        Args:
            dataset_from (str, optional): Name of the dataset. Defaults to 'glue'.
            model_name (str, optional): Name of the pretrained tokenizer/model checkpoint. Defaults to 'bert-base-uncased'.
            dataset_task (str, optional): Name of the dataset task. Defaults to 'sst2'.
            seed_num (int, optional): Seed for the stratified train/test split. Defaults to 42.
            range_to_select (int, optional): Number of rows in each of the train/val/test subsets. Defaults to 500.
            batch_size (int, optional): Batch size. Defaults to 8.

        Raises:
            KeyError: If ``dataset_task`` has no entry in ``task_to_keys``.
        """
        self.dataset_from = dataset_from
        self.model_name = model_name
        self.dataset_task = dataset_task
        self.seed_num = seed_num
        self.range_to_select = range_to_select
        self.batch_size = batch_size

        self.GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
        # Text column(s) fed to the tokenizer for each task; None means single-sentence input.
        self.task_to_keys = {
            "cola": ("sentence", None),
            "mnli": ("premise", "hypothesis"),
            "mnli-mm": ("premise", "hypothesis"),
            "mrpc": ("sentence1", "sentence2"),
            "qnli": ("question", "sentence"),
            "qqp": ("question1", "question2"),
            "rte": ("sentence1", "sentence2"),
            "sst2": ("sentence", None),
            "stsb": ("sentence1", "sentence2"),
            "wnli": ("sentence1", "sentence2"),
        }
        if self.dataset_task not in self.task_to_keys:
            raise KeyError(f"Unsupported dataset_task {self.dataset_task!r}; expected one of {self.GLUE_TASKS}")
        self.sentence1_key, self.sentence2_key = self.task_to_keys[self.dataset_task]

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # padding=True pads each collated batch to its longest sequence.
        self.data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer, padding=True)
        self.metric = evaluate.load(self.dataset_from, self.dataset_task)

    def get_custom_data_loaders(self) -> "Tuple[DataLoader, DataLoader, DataLoader, evaluate.Metric]":
        """
        Get custom data loaders for training, validation and testing.

        Only the training loader shuffles; evaluation order does not affect the
        aggregated metrics, and not shuffling keeps it from consuming RNG state.

        Returns:
            Tuple[DataLoader, DataLoader, DataLoader, evaluate.Metric]: train loader, val loader, test loader, and metric.
        """
        dataset = load_dataset(self.dataset_from, self.dataset_task).map(self._prepare_dataset, batched=True)
        # GLUE's MNLI config has no plain 'validation' split, only matched/mismatched ones.
        validation_split = 'validation_matched' if self.dataset_task == 'mnli' else 'validation'
        pooled = concatenate_datasets([dataset["train"], dataset[validation_split]])
        split = pooled.train_test_split(test_size=self.TEST_FRACTION, seed=self.seed_num, stratify_by_column='label')

        n = self.range_to_select
        train_dataset = self._finalize_split(split['train'].select(range(n)))
        # Validation rows come from the train side but do not overlap the train subset.
        val_dataset = self._finalize_split(split['train'].select(range(n, 2 * n)))
        test_dataset = self._finalize_split(split['test'].select(range(n)))

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.data_collator)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=self.data_collator)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=self.data_collator)
        return train_loader, val_loader, test_loader, self.metric

    def _finalize_split(self, split):
        """Drop the 'idx' and raw-text columns and rename 'label' to 'labels' (the name the collator/trainer use)."""
        text_columns = [col for col in split.column_names if col in self.task_to_keys[self.dataset_task]]
        return split.remove_columns(['idx'] + text_columns).rename_column('label', 'labels')

    def _prepare_dataset(self, examples) -> dict:
        """
        Tokenize a batch of examples (single sentence or sentence pair, depending on the task).

        Args:
            examples: Batched examples mapping column names to lists of strings.

        Returns:
            dict: Tokenizer output, truncated to the model's maximum length.
        """
        if self.sentence2_key is None:
            return self.tokenizer(examples[self.sentence1_key], truncation=True)
        return self.tokenizer(examples[self.sentence1_key], examples[self.sentence2_key], truncation=True)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, mean_absolute_error, roc_auc_score, matthews_corrcoef
from torch.utils.tensorboard import SummaryWriter  # Import SummaryWriter for TensorBoard logging

class Logger:
    """
    Record per-epoch training/validation metrics, print them, and mirror them to TensorBoard.

    Metrics use scikit-learn defaults (binary averaging for F1/precision/recall), so
    ``labels`` and ``preds`` are expected to be binary class labels.
    """

    def __init__(self):
        # Per-epoch training history, one entry per log_train_epoch call.
        self.train_losses = []
        self.train_f1_scores = []
        self.train_accuracies = []
        self.train_precisions = []
        self.train_recalls = []
        self.train_maes = []
        self.train_auc_roc = []
        self.train_mcc = []
        # Per-epoch validation history; only loss and F1 are kept (they are what plot_metrics draws).
        self.validation_losses = []
        self.validation_f1_scores = []
        # SummaryWriter() with no arguments picks its own log directory.
        self.writer = SummaryWriter()

    @staticmethod
    def _compute_metrics(labels, preds) -> dict:
        """Return every reported metric keyed by its display/TensorBoard name."""
        try:
            # If preds are hard 0/1 labels, ROC AUC reduces to balanced accuracy.
            auc_roc = roc_auc_score(labels, preds)
        except ValueError:
            # roc_auc_score raises when `labels` holds a single class; report NaN instead of aborting the epoch.
            auc_roc = float('nan')
        return {
            'F1 Score': f1_score(labels, preds, zero_division=0),
            'Accuracy': accuracy_score(labels, preds),
            'Precision': precision_score(labels, preds, zero_division=0),
            'Recall': recall_score(labels, preds, zero_division=0),
            'MAE': mean_absolute_error(labels, preds),
            'AUC ROC': auc_roc,
            'MCC': matthews_corrcoef(labels, preds),
        }

    def _write_scalars(self, split: str, loss, metrics: dict, step=None) -> None:
        """Write loss and metrics under '<name>/<split>' tags; step=None matches add_scalar's default."""
        self.writer.add_scalar(f'Loss/{split}', loss, step)
        for name, value in metrics.items():
            self.writer.add_scalar(f'{name}/{split}', value, step)

    @staticmethod
    def _format(loss, metrics: dict) -> str:
        """Render 'Loss: x, F1 Score: y, ...' for console output."""
        parts = [f"Loss: {loss}"] + [f"{name}: {value}" for name, value in metrics.items()]
        return ", ".join(parts)

    def log_train_epoch(self, epoch, loss, preds, labels):
        """Store, write and print training metrics for one epoch."""
        # Compute first so a metric failure cannot leave the history lists with different lengths.
        metrics = self._compute_metrics(labels, preds)
        self.train_losses.append(loss)
        self.train_f1_scores.append(metrics['F1 Score'])
        self.train_accuracies.append(metrics['Accuracy'])
        self.train_precisions.append(metrics['Precision'])
        self.train_recalls.append(metrics['Recall'])
        self.train_maes.append(metrics['MAE'])
        self.train_auc_roc.append(metrics['AUC ROC'])
        self.train_mcc.append(metrics['MCC'])

        self._write_scalars('train', loss, metrics, epoch)
        print(f"Epoch {epoch}: {self._format(loss, metrics)}")

    def log_validation_epoch(self, epoch, loss, preds, labels):
        """Store (loss, F1), write and print validation metrics for one epoch."""
        metrics = self._compute_metrics(labels, preds)
        self.validation_losses.append(loss)
        self.validation_f1_scores.append(metrics['F1 Score'])

        self._write_scalars('validation', loss, metrics, epoch)
        print(f"Validation at epoch {epoch}: {self._format(loss, metrics)}")

    def log_test_metrics(self, test_loss, test_preds, test_labels):
        """Write and print test-set metrics (no step index)."""
        metrics = self._compute_metrics(test_labels, test_preds)
        self._write_scalars('test', test_loss, metrics)
        # Flush so the test scalars reach disk even if close() is not called afterwards.
        self.writer.flush()
        print(f"Test Metrics: {self._format(test_loss, metrics)}")

    @staticmethod
    def _plot_series(ax, epochs, train_values, val_values, name: str) -> None:
        """Plot a training curve and, when a matching validation history exists, its validation curve."""
        ax.plot(epochs, train_values, label=f'Training {name}')
        # The validation lists stay empty when no val loader is used; plotting them against
        # `epochs` would raise a shape mismatch inside matplotlib.
        has_validation = bool(val_values) and len(val_values) == len(train_values)
        if has_validation:
            ax.plot(epochs, val_values, label=f'Validation {name}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(name)
        ax.set_title(f'Training and Validation {name}' if has_validation else f'Training {name}')
        ax.legend()

    def plot_metrics(self):
        """Show loss and F1 curves side by side."""
        epochs = range(1, len(self.train_losses) + 1)
        fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(15, 5))
        self._plot_series(loss_ax, epochs, self.train_losses, self.validation_losses, 'Loss')
        self._plot_series(f1_ax, epochs, self.train_f1_scores, self.validation_f1_scores, 'F1 Score')
        fig.tight_layout()
        plt.show()

    def close(self):
        """Flush and close the TensorBoard writer."""
        self.writer.close()


In [3]:
import torch
from tqdm import tqdm
from fosi import fosi_adam_torch
import copy
import torchopt
import functorch
from torch.utils.data import DataLoader
from torch import Tensor
from typing import Tuple
from logger import Logger  # Import the modified Logger class for logging

class CustomTrainer:
    """
    Fine-tune ``original_model`` with FOSI-wrapped Adam through a functional (stateless) view of it.

    ``make_functional_with_buffers`` hands back the module's own Parameter objects, and
    ``torchopt.apply_updates(..., inplace=True)`` mutates them, so ``original_model`` itself
    ends up holding the trained weights. Per-epoch metrics are sent to ``Logger``.

    NOTE(review): the loss is binary cross-entropy on the flattened model output, so the model
    is assumed to emit one probability per example. ``BertClassifier`` is defined outside this
    notebook; confirm its output shape.
    """

    def __init__(self, original_model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader = None, epochs: int = 1,
                 lr: float = 0.01, num_iters_to_approx_eigs: int = 500, alpha: float = 0.01):
        """
        Args:
            original_model: Module to train; its parameters are updated in place.
            train_loader: Yields dict-like batches with 'input_ids', 'attention_mask' and 'labels'.
            val_loader: Optional loader evaluated after every epoch.
            epochs: Number of passes over ``train_loader``.
            lr: Learning rate of the base ``torchopt.adam`` optimizer.
            num_iters_to_approx_eigs: Forwarded to ``fosi_adam_torch``.
            alpha: Forwarded to ``fosi_adam_torch``.
        """
        self.original_model = original_model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.lr = lr
        self.num_iters_to_approx_eigs = num_iters_to_approx_eigs
        self.alpha = alpha
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Built in train(), or lazily on the first forward pass (e.g. test() before train()).
        self.functional_model = None
        self.params = None
        self.buffers = None
        self.logger = Logger()

    def train(self) -> "Tuple[Callable, Tuple[Tensor, ...], Tuple[Tensor, ...]]":
        """
        Run ``self.epochs`` epochs of training, validating after each epoch when a val loader is set.

        Returns:
            (functional_model, params, buffers) after training.
        """
        self.original_model.to(self.device)
        self.original_model.train()
        # Build the functional view once, after moving to the device; each build deep-copies the model.
        self.functional_model, self.params, self.buffers = self.make_functional_with_buffers(self.original_model)

        base_optimizer = torchopt.adam(lr=self.lr)
        fosi_batch = next(iter(self.train_loader))
        # FOSI receives an adapter with the (params, batch) -> loss signature; see _fosi_loss_fn.
        optimizer = fosi_adam_torch(base_optimizer, self._fosi_loss_fn, fosi_batch,
                                    num_iters_to_approx_eigs=self.num_iters_to_approx_eigs, alpha=self.alpha)
        opt_state = optimizer.init(self.params)

        num_batches = len(self.train_loader)
        for epoch in range(self.epochs):
            self.original_model.train()
            epoch_loss = 0.0
            epoch_preds, epoch_labels = [], []
            progress_bar = tqdm(enumerate(self.train_loader, 1), total=num_batches)
            for step, batch in progress_bar:
                progress_bar.set_description(f'Epoch {epoch+1}/{self.epochs}, Step {step}/{num_batches}')
                input_ids, attention_mask, labels = self._unpack_batch(batch)

                # Loss with the params from the previous step.
                loss, _ = self.loss_fn(self.params, self.buffers, input_ids, attention_mask, labels)
                batch_loss = loss.item()
                epoch_loss += batch_loss

                grads = torch.autograd.grad(loss, self.params)
                updates, opt_state = optimizer.update(grads, opt_state, self.params)
                self.params = torchopt.apply_updates(self.params, updates, inplace=True)
                progress_bar.set_postfix(loss=batch_loss)

                # Predictions with the freshly updated params; no autograd graph is needed.
                with torch.no_grad():
                    preds = self._functional_forward(self.params, self.buffers, input_ids, attention_mask)
                epoch_preds.extend(self._to_hard_predictions(preds))
                epoch_labels.extend(labels.cpu().numpy())

            epoch_loss /= num_batches  # average loss per batch
            self.logger.log_train_epoch(epoch + 1, epoch_loss, epoch_preds, epoch_labels)

            if self.val_loader is not None:
                validation_loss, validation_preds, validation_labels = self.validate()
                self.logger.log_validation_epoch(epoch + 1, validation_loss, validation_preds, validation_labels)

        self.logger.close()
        return self.functional_model, self.params, self.buffers

    def loss_fn(self, params: Tuple[Tensor, ...], buffers: Tuple[Tensor, ...], input_ids: Tensor, attention_mask: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Binary cross-entropy of the functional model's output against ``labels``.

        Returns:
            (loss, preds) where ``preds`` is the raw model output.
        """
        preds = self._functional_forward(params, buffers, input_ids, attention_mask)
        # reshape(-1) instead of squeeze() so a batch of one still compares 1-D to 1-D.
        loss = torch.nn.functional.binary_cross_entropy(preds.reshape(-1).to(torch.float32), labels.reshape(-1).to(torch.float32))
        return loss, preds

    def _fosi_loss_fn(self, params: Tuple[Tensor, ...], batch) -> Tensor:
        """
        Adapter for FOSI: FOSI's torch examples define ``loss_fn(params, batch) -> scalar loss``,
        whereas ``loss_fn`` above takes unpacked tensors and returns ``(loss, preds)``.
        TODO(review): confirm this calling convention against the installed ``fosi`` version.
        """
        input_ids, attention_mask, labels = self._unpack_batch(batch)
        loss, _ = self.loss_fn(params, self.buffers, input_ids, attention_mask, labels)
        return loss

    def _functional_forward(self, params, buffers, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Call the cached functional model, building it on first use and mirroring ``original_model``'s train/eval mode."""
        if getattr(self, 'functional_model', None) is None:
            self.functional_model, self.params, self.buffers = self.make_functional_with_buffers(self.original_model)
        stateless = getattr(self.functional_model, 'stateless_module', None)
        if stateless is not None:
            # The stateless copy froze the train/eval flag at deepcopy time; keep dropout etc. in sync.
            stateless.train(self.original_model.training)
        return self.functional_model(params, buffers, input_ids=input_ids, attention_mask=attention_mask)

    def _unpack_batch(self, batch) -> Tuple[Tensor, Tensor, Tensor]:
        """Move input_ids, attention_mask and (flattened) labels of a collated batch to ``self.device``."""
        # No squeeze(): it would drop the batch dimension when a batch holds a single example.
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        labels = batch['labels'].reshape(-1).to(self.device)
        return input_ids, attention_mask, labels

    @staticmethod
    def _to_hard_predictions(preds: Tensor):
        """Round probabilities to 0/1 and return them as a flat float32 numpy array."""
        return torch.round(preds.detach()).reshape(-1).to(torch.float32).cpu().numpy()

    def _evaluate(self, loader: DataLoader, description: str) -> Tuple[float, list, list]:
        """Average loss, hard predictions and labels over ``loader`` without tracking gradients."""
        total_loss = 0.0
        all_preds, all_labels = [], []
        num_batches = len(loader)
        progress_bar = tqdm(enumerate(loader, 1), total=num_batches)
        with torch.no_grad():
            for step, batch in progress_bar:
                progress_bar.set_description(f'{description} {step}/{num_batches}')
                input_ids, attention_mask, labels = self._unpack_batch(batch)
                loss, preds = self.loss_fn(self.params, self.buffers, input_ids, attention_mask, labels)
                total_loss += loss.item()
                progress_bar.set_postfix(loss=total_loss / step)
                all_preds.extend(self._to_hard_predictions(preds))
                all_labels.extend(labels.cpu().numpy())
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        return avg_loss, all_preds, all_labels

    def validate(self) -> Tuple[float, list, list]:
        """
        Per-epoch validation pass (not the final test) in eval mode.

        Returns:
            (average validation loss, predictions, labels). The previous train/eval mode is restored afterwards.
        """
        was_training = self.original_model.training
        self.original_model.eval()
        try:
            return self._evaluate(self.val_loader, 'Validating')
        finally:
            self.original_model.train(was_training)

    def test(self, test_loader: DataLoader):
        """
        Evaluate on ``test_loader`` in eval mode and log the test metrics.

        Returns:
            (functional_model, params, buffers).
        """
        self.test_loader = test_loader
        # Harmless if train() already moved it; needed when test() runs first.
        self.original_model.to(self.device)
        self.original_model.eval()
        test_loss, test_preds, test_labels = self._evaluate(self.test_loader, 'Testing')
        self.logger.log_test_metrics(test_loss, test_preds, test_labels)
        # train() already closed the logger; close again so scalars written here are flushed too.
        self.logger.close()
        return self.functional_model, self.params, self.buffers

    def make_functional_with_buffers(self, mod, new_params_values=None, new_buffers_values=None, disable_autograd_tracking=False):
        """
        Given a module, return a functional version of it that takes parameters and buffers as arguments.

        Args:
            mod: A PyTorch module.
            new_params_values: Default parameter tuple used by ``fmodel`` when it is called with ``None``;
                falls back to ``mod``'s own parameters.
            new_buffers_values: Same as above, for buffers.
            disable_autograd_tracking: If True, the returned parameters are detached.

        Returns:
            A tuple (fmodel, params, buffers), where:
            - fmodel(new_params_values, new_buffers_values, *args, **kwargs) runs ``mod``'s forward with
              the given tensors. Its first two positional slots are params/buffers, so model inputs
              must be passed by keyword (or after them).
            - params is a tuple of ``mod``'s parameters (the same objects, not copies).
            - buffers is a tuple of ``mod``'s buffers.
            ``fmodel.stateless_module`` exposes the meta-device copy so its train/eval mode can be set.

        Adapted from the PyTorch functorch migration guide:
        Repo Link: https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf
        Original Docs: https://pytorch.org/docs/stable/func.migrating.html#function-transforms
        """
        params_dict = dict(mod.named_parameters())
        params_names = params_dict.keys()
        params_values = tuple(params_dict.values())

        buffers_dict = dict(mod.named_buffers())
        buffers_names = buffers_dict.keys()
        buffers_values = tuple(buffers_dict.values())

        # Structure-only copy: real tensors are supplied on every call via functional_call.
        stateless_mod = copy.deepcopy(mod)
        stateless_mod.to('meta')

        def fmodel(new_params_values=new_params_values, new_buffers_values=new_buffers_values, *args, **kwargs):
            if new_params_values is None:
                new_params_values = params_values
            if new_buffers_values is None:
                new_buffers_values = buffers_values
            new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
            new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
            return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args=args, kwargs=kwargs)

        fmodel.stateless_module = stateless_mod

        if disable_autograd_tracking:
            params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)

        return fmodel, params_values, buffers_values



In [4]:
from model import BertClassifier
from dataset import CustomDataLoader
from training import CustomTrainer
from evaluation import CustomEvaluator
from utils import set_seed

# Interactive configuration: press Enter at each prompt to accept the default.
# NOTE(review): CustomDataLoader and CustomTrainer are imported from the project modules at the
# top of this cell, which shadow the classes defined in the cells above. The recorded TypeError
# from loss_fn may come from a module version that is out of sync with the cell code; keep them aligned.
dataset_from = input("Enter the dataset you want to use (e.g., 'glue'): ") or 'glue'
model_name = input("Enter the model name (e.g., 'bert-base-uncased'): ") or 'bert-base-uncased'
dataset_task = input("Enter the dataset task (e.g., 'sst2'): ") or 'sst2'
seed_num = int(input("Enter the seed number (default is 42): ") or '42')
epochs = int(input("Enter the number of epochs (default is 2): ") or '2')

# Rows per train/val/test subset and batch size; raise RANGE_TO_SELECT for a fuller run.
RANGE_TO_SELECT = 100
BATCH_SIZE = 8

# Set seed for reproducibility
set_seed(seed_num)

# MNLI is 3-way, STS-B is a regression target, every other GLUE task is binary.
if dataset_task.startswith("mnli"):
    num_classes = 3
elif dataset_task == "stsb":
    num_classes = 1
else:
    num_classes = 2

original_model = BertClassifier(model_name=model_name, num_classes=num_classes)

custom_dataloader = CustomDataLoader(
    dataset_from=dataset_from,
    model_name=model_name,
    dataset_task=dataset_task,
    seed_num=seed_num,
    range_to_select=RANGE_TO_SELECT,
    batch_size=BATCH_SIZE,
)
train_loader, val_loader, test_loader, metric = custom_dataloader.get_custom_data_loaders()

trainer = CustomTrainer(original_model, train_loader, val_loader, epochs=epochs)
functional_model, params, buffers = trainer.train()

# Assign the result: as the cell's last expression, the returned (fmodel, params, buffers)
# tuple would otherwise be displayed, dumping every parameter tensor into the output.
functional_model, params, buffers = trainer.test(test_loader=test_loader)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Returned ESE function. Lanczos order (m) is 20 .


Epoch 1/10, Step 1/13:   0%|          | 0/13 [00:00<?, ?it/s]


TypeError: CustomTrainer.loss_fn() takes 6 positional arguments but 7 were given