<a href="https://colab.research.google.com/github/PavlosPo/nlp-optimizers/blob/pavlos-playground/pytorch-experiments-1-fosi-adam/playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## FOSI Classes

In [1]:
# ! rm -rf ./fosi/
# !mkdir ./fosi/
# !unzip fosi.zip -d ./fosi/

Archive:  fosi.zip
 extracting: ./fosi/__init__.py      
   creating: ./fosi/jax_optim/
 extracting: ./fosi/jax_optim/__init__.py  
 extracting: ./fosi/jax_optim/extreme_spectrum_estimation.py  
 extracting: ./fosi/jax_optim/fosi_optimizer.py  
 extracting: ./fosi/jax_optim/lanczos_algorithm.py  
 extracting: ./fosi/jax_optim/lanczos_algorithm_sanity.py  
   creating: ./fosi/torch_optim/
 extracting: ./fosi/torch_optim/__init__.py  
 extracting: ./fosi/torch_optim/extreme_spectrum_estimation.py  
 extracting: ./fosi/torch_optim/fosi_optimizer.py  
 extracting: ./fosi/torch_optim/lanczos_algorithm.py  
 extracting: ./fosi/torch_optim/lanczos_algorithm_sanity.py  
 extracting: ./fosi/version.py       


## Working Example

In [None]:
# !pip install torchopt
# !pip install datasets
# !pip install evaluate



In [2]:
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataloader import default_collate
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from torch.optim import Adam
import torch.nn as nn
import torchopt
import functorch
import evaluate
import torch.nn.functional as F

from datasets import load_dataset
from fosi import fosi_adam_torch

# # Set device
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# # Load pre-trained DistilBERT model and tokenizer
# base_model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# base_model.to(device)
# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [4]:
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding

# Load pre-trained BERT model and tokenizer
bert_model_name = "bert-base-uncased"
num_classes = 1
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_model_name, num_labels=num_classes)
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

# # Freeze all parameters of the BERT model except for the last layer (classifier)
# for name, param in bert_model.named_parameters():
#     if 'classifier' not in name:  # Exclude parameters of the classifier layer
#         param.requires_grad = False

# Define a dense layer on top of the BERT model for classification
class BertClassifier(nn.Module):
    def __init__(self, bert_model, num_classes):
        super(BertClassifier, self).__init__()
        self.bert = bert_model
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
        probability = self.sigmoid(outputs)
        return probability

# Instantiate the classifier
classifier = BertClassifier(bert_model, num_classes)

# Optionally, move the model to a GPU device if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier.to(device)

# Now you can train the entire model (BERT + dense layer) on your task-specific data
# Make sure to prepare your data (input_ids, attention_mask, labels) using the BERT tokenizer
# and use an appropriate loss function and optimizer for your task


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return torch._C._cuda_getDeviceCount() > 0


BertClassifier(
  (bert): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, 

In [6]:

# Define a function to preprocess the dataset
def prepare_dataset(example):
    return tokenizer(example['sentence'], add_special_tokens=True, truncation=True, padding=True, return_tensors='pt')

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

dataset = load_dataset('glue', 'sst2').map(prepare_dataset, batched=True)
metric = evaluate.load("glue", "sst2")

# Split dataset into train and test sets, we use the train category because the test one has labels -1 only.
train_dataset = dataset['train'].select(range(0,500)).remove_columns(['sentence', 'idx']).rename_column('label', 'labels')
test_dataset = dataset['train'].select(range(500, 1000)).remove_columns(['sentence', 'idx']).rename_column('label', 'labels')


Map: 100%|██████████| 67349/67349 [00:06<00:00, 9627.95 examples/s] 


In [7]:
len(dataset['train']['label'])

67349

In [8]:
# Define data loaders
batch_size = 128
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)

In [9]:
for i in (testloader):
  print(i['input_ids'])
  break

tensor([[  101,  2054,  2017,  ...,     0,     0,     0],
        [  101,  1045,  2031,  ...,     0,     0,     0],
        [  101,  1037,  4438,  ...,     0,     0,     0],
        ...,
        [  101,  2659,  2006,  ...,     0,     0,     0],
        [  101,  8991,  4818,  ...,     0,     0,     0],
        [  101, 14153, 12369,  ...,     0,     0,     0]])


In [17]:
# With buffers

def loss_fn(functional_model, params, buffers, input_ids, attention_mask, labels):
    preds = functional_model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask)
    loss = nn.functional.binary_cross_entropy(preds.squeeze().to(torch.float32), labels.squeeze().to(torch.float32))
    return loss

# def softmax_output_fn(params, buffers, input_ids, attention_mask, labels):
#     logits = model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
#     logits_tensor = torch.tensor(logits)
#     softmax_output = F.softmax(logits_tensor, dim=1)
#     return softmax_output

In [None]:
# Without buffers

# def loss_fn(params, input_ids, attention_mask, labels):
#     loss = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
#     return loss

# def softmax_output_fn(params, input_ids, attention_mask, labels):
#     logits = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
#     logits_tensor = torch.tensor(logits)
#     softmax_output = F.softmax(logits_tensor, dim=1)
#     return softmax_output

In [18]:
# Train the model
classifier.train()
for epoch in range(2):
    for i, data in enumerate(trainloader, 0):
        print("\n")
        print("*"*100)
        print("\n")

        if i == 0 and epoch == 0: # Initialize optimizer and model parameters
            print(f"input_ids: {data['input_ids']}")
            print(f"attention_mask: {data['attention_mask']}")
            print(f"labels: {data['labels']}")
            # Define optimizer
            base_optimizer = torchopt.adam(lr=0.01)
            optimizer = fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01)
            # func_optimizer = torchopt.FuncOptimizer(fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01))
            model, params, buffers = functorch.make_functional_with_buffers(model=classifier)
            opt_state = optimizer.init(params)

        input_ids = data['input_ids'].squeeze().to(device)
        attention_mask = data['attention_mask'].squeeze().to(device)
        labels = data['labels'].squeeze().to(device)
        print(f"Labels: \n{labels}\n")

        loss = loss_fn(functional_model=model, 
                       params=params, buffers=buffers, input_ids=input_ids,attention_mask=attention_mask, labels=labels)
        print(f"Step: {i}\n")
        print(f"Loss: {loss}\n")

        # func_optimizer.step(loss, params, inplace=True)

        # Calculate gradients
        grads = torch.autograd.grad(loss, params)

        # Update model parameters
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = torchopt.apply_updates(params, updates, inplace=True)

        # Should this method exists? Are buffers neseccary to update?
        # Update buffers based on updated parameters
        # buffers = torchopt.update_buffers(model, params, buffers)
    evaluation_results = []

    # Evaluate the model
    with torch.no_grad():
        for i, data in enumerate(testloader):
            evaluation_result = {}

            input_ids = data['input_ids'].squeeze().to(device)
            attention_mask = data['attention_mask'].squeeze().to(device)
            labels = data['labels'].squeeze().to(device)

            preds = model(params, buffers, input_ids, attention_mask=attention_mask)

            predicted_labels = torch.round(preds)

            # Save the evaluation results
            evaluation_result['input_ids'] = input_ids.cpu().tolist()
            evaluation_result['attention_mask'] = attention_mask.cpu().tolist()
            evaluation_result['labels'] = labels.cpu().tolist()
            evaluation_result['preds'] = preds.cpu().tolist()
            evaluation_result['predicted_label'] = predicted_labels.cpu().tolist()

            evaluation_results.append(evaluation_result)
            metric.add_batch(predictions=predicted_labels, references=labels)

print(f'Epoch: {epoch}')
results = metric.compute()
print(f"Results: \n{results}\n")

print('Finished Training')




****************************************************************************************************


input_ids: tensor([[  101,  2892, 10689,  ...,     0,     0,     0],
        [  101, 26202,  2015,  ...,     0,     0,     0],
        [  101,  2804,  2000,  ...,     0,     0,     0],
        ...,
        [  101,  1997,  1996,  ...,     0,     0,     0],
        [  101,  3084,  2026,  ...,     0,     0,     0],
        [  101,  3073,  2049,  ...,     0,     0,     0]])
attention_mask: tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
labels: tensor([1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
        0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1,
        0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1,
        1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1,

  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')


Labels: 
tensor([1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
        0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1,
        0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1,
        1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1,
        0, 0, 0, 1, 1, 1, 0, 1])

Step: 0

Loss: 0.694150984287262



****************************************************************************************************


Labels: 
tensor([0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0,
        1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1,
        1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0,
        1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0,
        1

In [21]:
print(evaluation_results[0]['labels'])

[0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1]


In [22]:
print(evaluation_results[0].keys())

dict_keys(['input_ids', 'attention_mask', 'labels', 'preds', 'predicted_label'])


In [23]:
print(evaluation_results[0]['preds'])

[[0.32067814469337463], [0.27826035022735596], [0.2859336733818054], [0.38033556938171387], [0.3939129412174225], [0.2985920011997223], [0.40012797713279724], [0.3530847132205963], [0.30622410774230957], [0.3813495934009552], [0.33078932762145996], [0.42219704389572144], [0.34536466002464294], [0.4196912944316864], [0.4567902684211731], [0.2431776076555252], [0.3137512505054474], [0.42258739471435547], [0.3382858335971832], [0.3237937390804291], [0.29999300837516785], [0.35315102338790894], [0.326948344707489], [0.3300442397594452], [0.41341814398765564], [0.39553001523017883], [0.4028570353984833], [0.34860578179359436], [0.2644515335559845], [0.27199968695640564], [0.2460271269083023], [0.33638569712638855], [0.35194525122642517], [0.3956696093082428], [0.28749218583106995], [0.313027560710907], [0.25715571641921997], [0.3575066030025482], [0.4142695963382721], [0.33259639143943787], [0.4077953100204468], [0.43278172612190247], [0.26865363121032715], [0.35169243812561035], [0.4152421

In [24]:
print(evaluation_results[0]['predicted_label'])

[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]


In [25]:
import gc
gc.collect() # Python thing
torch.cuda.empty_cache() # PyTorch thing

In [None]:
# # Train the model
# for epoch in range(3):
#     for i, data in enumerate(trainloader, 0):
#         print("\n")
#         print("*"*100)
#         print("\n")

#         if i == 0: # Initialize optimizer and model parameters
#             print(f"input_ids: {data['input_ids']}")
#             print(f"attention_mask: {data['attention_mask']}")
#             print(f"labels: {data['labels']}")
#             # Define optimizer
#             base_optimizer = torchopt.adam(lr=0.01)
#             optimizer = fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01)
#             model, params, buffers = functorch.make_functional_with_buffers(model=base_model)
#             # model, params = functorch.make_functional(model=base_model)
#             opt_state = optimizer.init(params)
#             # model.train()


#         input_ids = data['input_ids']
#         attention_mask = data['attention_mask']
#         labels = data['labels']
#         print(f"Labels: \n{labels}\n")

#         loss = loss_fn(params, buffers, input_ids, attention_mask, labels)
#         # loss = loss_fn(params, input_ids, attention_mask, labels)
#         print(f"Step: {i}\n")
#         print(f"Loss: {loss}\n")

#         print(f"Calculating Gradients\n")
#         grads = torch.autograd.grad(loss, params)
#         # print(f"Grads: \n{grads}\n")
#         print(f"Calculating updates in the model...\n")
#         updates, opt_state = optimizer.update(grads, opt_state, params)
#         print("Applying updates\n")
#         params = torchopt.apply_updates(params, updates, inplace=True)

#     # Evaluate the model
#     # model.eval()
#     with torch.no_grad():
#         for data in testloader:
#             input_ids = data['input_ids'].to(device)
#             attention_mask = data['attention_mask'].to(device)
#             labels = data['labels'].to(device)

#             logits = model(params, buffers, input_ids, attention_mask=attention_mask).logits
#             # logits = model(params, input_ids, attention_mask=attention_mask).logits
#             print(f"Logits: {logits}")
#             logits_tensor = torch.tensor(logits)
#             softmax_output = F.softmax(logits_tensor, dim=1)

#             print(f"SoftMax of Logits: {softmax_output}")
#             predicted_labels = torch.argmax(softmax_output, dim=1)
#             print(f"Argmax Predictions: {predicted_labels}")

#             metric.add_batch(predictions=predicted_labels, references=labels)


#     print(f'Epoch: {epoch}')
#     results = metric.compute()
#     print(f"Results: \n{results}\n")
#     # print(f'Correct: {correct}\n')
#     # print(f'Total: {total}\n')
#     # print(f'Predicted: {predicted}\n')
#     # print(f'outputs: {outputs}\n')
#     # print(f'labels: {labels}\n')
#     # print(f'Accuracy of the network on the test data: {100 * correct / total}%')
#     # model.train() # for the next epoch

# print('Finished Training')


In [None]:
# # Evaluate the model
# model.eval()
# correct = 0
# total = 0
# with torch.no_grad():
#     for data in testloader:
#         input_ids = data['input_ids'].to(device)
#         attention_mask = data['attention_mask'].to(device)
#         labels = data['labels'].to(device)
#         outputs = model(params, buffers, input_ids, attention_mask=attention_mask)
#         _, predicted = torch.max(outputs.logits, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

In [None]:
# # Define training loop
# for epoch in range(3):  # Adjust number of epochs as needed
#     model.train()
#     for i, batch in enumerate(train_dataset):
#         input_ids = batch['input_ids']
#         attention_mask = batch['attention_mask']
#         labels = batch['label']

#         #This is taking care automatically
#         # optimizer.zero_grad() This is taking care automatically

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()

#         # if (i + 1) % 100 == 0:
#         #     print(f'[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}')

#     # Evaluate on test set
#     model.eval()
#     total = 0
#     correct = 0
#     with torch.no_grad():
#         for batch in test_dataset:
#             input_ids = batch['input_ids']
#             attention_mask = batch['attention_mask']
#             labels = batch['label']
#             print("Labels: \n", labels)

#             outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
#             logits = outputs.logits
#             _, predicted = torch.max(logits, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     accuracy = correct / total
#     print(f'Test accuracy: {accuracy:.4f}')

# print('Finished Training')
