<a href="https://colab.research.google.com/github/PavlosPo/nlp-optimizers/blob/pavlos-playground/pytorch-experiments-1-fosi-adam/playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## FOSI Classes

In [1]:
! rm -rf ./fosi/
!mkdir ./fosi/
!unzip fosi.zip -d ./fosi/

Archive:  fosi.zip
 extracting: ./fosi/__init__.py      
   creating: ./fosi/jax_optim/
 extracting: ./fosi/jax_optim/__init__.py  
 extracting: ./fosi/jax_optim/extreme_spectrum_estimation.py  
 extracting: ./fosi/jax_optim/fosi_optimizer.py  
 extracting: ./fosi/jax_optim/lanczos_algorithm.py  
 extracting: ./fosi/jax_optim/lanczos_algorithm_sanity.py  
   creating: ./fosi/torch_optim/
 extracting: ./fosi/torch_optim/__init__.py  
 extracting: ./fosi/torch_optim/extreme_spectrum_estimation.py  
 extracting: ./fosi/torch_optim/fosi_optimizer.py  
 extracting: ./fosi/torch_optim/lanczos_algorithm.py  
 extracting: ./fosi/torch_optim/lanczos_algorithm_sanity.py  
 extracting: ./fosi/version.py       


## Working Example

In [2]:
!pip install torchopt
!pip install datasets
!pip install evaluate



In [50]:
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataloader import default_collate
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from torch.optim import Adam
import torch.nn as nn
import torchopt
import functorch
import evaluate
import torch.nn.functional as F

from datasets import load_dataset
from fosi import fosi_adam_torch

# Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load pre-trained DistilBERT model and tokenizer
base_model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
base_model.to(device)
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [62]:

# Define a function to preprocess the dataset
def prepare_dataset(example):
    return tokenizer(example['sentence'], add_special_tokens=True, truncation=True, padding=True, return_tensors='pt')

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

# Define a sample dataset (replace this with your custom dataset)
# Example: IMDB movie review dataset
dataset = load_dataset('glue', 'cola').map(prepare_dataset, batched=True)
metric = evaluate.load("glue", "cola")

# Split dataset into train and test sets, we use the train category because the test one has labels -1 only.
train_dataset = dataset['train'].select(range(100)).remove_columns(['sentence', 'idx']).rename_column('label', 'labels')
test_dataset = dataset['train'].select(range(100, 200)).remove_columns(['sentence', 'idx']).rename_column('label', 'labels')


Map:   0%|          | 0/8551 [00:00<?, ? examples/s]

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]

Map:   0%|          | 0/1063 [00:00<?, ? examples/s]

In [63]:
len(dataset['train']['label'])

8551

In [73]:
# Define data loaders
batch_size = 1
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)

In [74]:
for i in (testloader):
  print(i['input_ids'])
  break

tensor([[  101,  2562,  2115,  2677,  3844, 12347,  1010,  1996,  2062,  2198,
         20323,  1010,  7929,  1029,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0]])


In [75]:
# With buffers

def loss_fn(params, buffers, input_ids, attention_mask, labels):
    loss = model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
    return loss

def softmax_output_fn(params, buffers, input_ids, attention_mask, labels):
    logits = model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
    logits_tensor = torch.tensor(logits)
    softmax_output = F.softmax(logits_tensor, dim=1)
    return softmax_output

In [70]:
# Without buffers

# def loss_fn(params, input_ids, attention_mask, labels):
#     loss = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
#     return loss

# def softmax_output_fn(params, input_ids, attention_mask, labels):
#     logits = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
#     logits_tensor = torch.tensor(logits)
#     softmax_output = F.softmax(logits_tensor, dim=1)
#     return softmax_output

In [76]:
# Train the model
for epoch in range(3):
    for i, data in enumerate(trainloader, 0):
        print("\n")
        print("*"*100)
        print("\n")

        if i == 0: # Initialize optimizer and model parameters
            print(f"input_ids: {data['input_ids']}")
            print(f"attention_mask: {data['attention_mask']}")
            print(f"labels: {data['labels']}")
            # Define optimizer
            base_optimizer = torchopt.adam(lr=0.01)
            optimizer = fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01)
            model, params, buffers = functorch.make_functional_with_buffers(model=base_model)
            # model, params = functorch.make_functional(model=base_model)
            opt_state = optimizer.init(params)
            model.train()


        input_ids = data['input_ids']
        attention_mask = data['attention_mask']
        labels = data['labels']
        print(f"Labels: \n{labels}\n")

        loss = loss_fn(params, buffers, input_ids, attention_mask, labels)
        # loss = loss_fn(params, input_ids, attention_mask, labels)
        print(f"Step: {i}\n")
        print(f"Loss: {loss}\n")

        print(f"Calculating Gradients\n")
        grads = torch.autograd.grad(loss, params)
        # print(f"Grads: \n{grads}\n")
        print(f"Calculating updates in the model...\n")
        updates, opt_state = optimizer.update(grads, opt_state, params)
        print("Applying updates\n")
        params = torchopt.apply_updates(params, updates, inplace=False)

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        for data in testloader:
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            labels = data['labels'].to(device)

            logits = model(params, buffers, input_ids, attention_mask=attention_mask).logits
            # logits = model(params, input_ids, attention_mask=attention_mask).logits
            print(f"Logits: {logits}")
            logits_tensor = torch.tensor(logits)
            softmax_output = F.softmax(logits_tensor, dim=1)

            print(f"SoftMax of Logits: {softmax_output}")
            predicted_labels = torch.argmax(softmax_output, dim=1)
            print(f"Argmax Predictions: {predicted_labels}")

            metric.add_batch(predictions=predicted_labels, references=labels)


    print(f'Epoch: {epoch}')
    results = metric.compute()
    print(f"Results: \n{results}\n")
    # print(f'Correct: {correct}\n')
    # print(f'Total: {total}\n')
    # print(f'Predicted: {predicted}\n')
    # print(f'outputs: {outputs}\n')
    # print(f'labels: {labels}\n')
    # print(f'Accuracy of the network on the test data: {100 * correct / total}%')
    model.train() # for the next epoch

print('Finished Training')




****************************************************************************************************


input_ids: tensor([[  101,  1996, 15871,  2921,  1996,  8164,  7121,  2058,  1996,  4139,
          3240,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0]])
attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0]])
labels: tensor([1])
Returned ESE function. Lanczos order (m) is 20 .


  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')


Labels: 
tensor([1])

Step: 0

Loss: 0.6636227965354919

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 1

Loss: 0.11219566315412521

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 2

Loss: 7.986990567587782e-06

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 3

Loss: 0.0

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 4

Loss: 0.0

Calculati

  logits_tensor = torch.tensor(logits)


Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]])
Argmax Predictions: tensor([1])
Logits: tensor([[-2.1857,  2.6897]])
SoftMax of Logits: tensor([[0.0076, 0.9924]

  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')


Labels: 
tensor([1])

Step: 0

Loss: 0.5944719910621643

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 1

Loss: 0.3595026433467865

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 2

Loss: 0.0

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 3

Loss: 0.0

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([0])

Step: 4

Loss: 68.18515014648438

Calculating Gr

  logits_tensor = torch.tensor(logits)


Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-3.5832,  4.5755]])
SoftMax of Logits: tensor([[2.8614e-04, 9.9971e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([

  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')


Labels: 
tensor([1])

Step: 0

Loss: 0.5761680603027344

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 1

Loss: 0.3184877038002014

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 2

Loss: 0.0

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([0])

Step: 3

Loss: 78.21977233886719

Calculating Gradients

Calculating updates in the model...

Applying updates



****************************************************************************************************


Labels: 
tensor([1])

Step: 4

Loss: 0.0

Calculating Gr

  logits_tensor = torch.tensor(logits)


Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([[-4.0940,  3.4010]])
SoftMax of Logits: tensor([[5.5555e-04, 9.9944e-01]])
Argmax Predictions: tensor([1])
Logits: tensor([

In [45]:
model

FunctionalModuleWithBuffers(
  (stateless_model): DistilBertForSequenceClassification(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_aff

In [10]:
# Evaluate the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        labels = data['labels'].to(device)
        outputs = model(params, buffers, input_ids, attention_mask=attention_mask)
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

In [11]:
labels

tensor([1, 0, 0, 0])

In [12]:
outputs.logits

tensor([[-0.0085,  0.1657],
        [-0.0104,  0.1667],
        [-0.0045,  0.1649],
        [-0.0055,  0.1650]])

In [13]:
_

tensor([0.1657, 0.1667, 0.1649, 0.1650])

In [14]:
predicted

tensor([1, 1, 1, 1])

In [15]:
correct

22

In [16]:
total

32

In [17]:
print(f'Accuracy of the network on the test data: {100 * correct / total}%')

Accuracy of the network on the test data: 68.75%


In [18]:
# # Define training loop
# for epoch in range(3):  # Adjust number of epochs as needed
#     model.train()
#     for i, batch in enumerate(train_dataset):
#         input_ids = batch['input_ids']
#         attention_mask = batch['attention_mask']
#         labels = batch['label']

#         #This is taking care automatically
#         # optimizer.zero_grad() This is taking care automatically

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()

#         # if (i + 1) % 100 == 0:
#         #     print(f'[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}')

#     # Evaluate on test set
#     model.eval()
#     total = 0
#     correct = 0
#     with torch.no_grad():
#         for batch in test_dataset:
#             input_ids = batch['input_ids']
#             attention_mask = batch['attention_mask']
#             labels = batch['label']
#             print("Labels: \n", labels)

#             outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
#             logits = outputs.logits
#             _, predicted = torch.max(logits, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     accuracy = correct / total
#     print(f'Test accuracy: {accuracy:.4f}')

# print('Finished Training')
