<a href="https://colab.research.google.com/github/PavlosPo/nlp-optimizers/blob/pavlos-playground/pytorch-experiments-1-fosi-adam/playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## FOSI Classes

In [14]:
! rm -rf ./fosi/
!mkdir ./fosi/
!unzip fosi.zip -d ./fosi/

Archive:  fosi.zip
 extracting: ./fosi/__init__.py      
   creating: ./fosi/jax_optim/
 extracting: ./fosi/jax_optim/__init__.py  
 extracting: ./fosi/jax_optim/extreme_spectrum_estimation.py  
 extracting: ./fosi/jax_optim/fosi_optimizer.py  
 extracting: ./fosi/jax_optim/lanczos_algorithm.py  
 extracting: ./fosi/jax_optim/lanczos_algorithm_sanity.py  
   creating: ./fosi/torch_optim/
 extracting: ./fosi/torch_optim/__init__.py  
 extracting: ./fosi/torch_optim/extreme_spectrum_estimation.py  
 extracting: ./fosi/torch_optim/fosi_optimizer.py  
 extracting: ./fosi/torch_optim/lanczos_algorithm.py  
 extracting: ./fosi/torch_optim/lanczos_algorithm_sanity.py  
 extracting: ./fosi/version.py       


## Working Example

In [1]:
!pip install torchopt
!pip install datasets
!pip install evaluate



In [24]:
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataloader import default_collate
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from torch.optim import Adam
import torch.nn as nn
import torchopt
import functorch
import evaluate
import torch.nn.functional as F

from datasets import load_dataset
from fosi import fosi_adam_torch

# Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load pre-trained DistilBERT model and tokenizer
base_model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
base_model.to(device)
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:

# Define a function to preprocess the dataset
def prepare_dataset(example):
    return tokenizer(example['sentence'], add_special_tokens=True, truncation=True, padding=True, return_tensors='pt')

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

dataset = load_dataset('glue', 'sst2').map(prepare_dataset, batched=True)
metric = evaluate.load("glue", "sst2")

# Split dataset into train and test sets, we use the train category because the test one has labels -1 only.
train_dataset = dataset['train'].remove_columns(['sentence', 'idx']).rename_column('label', 'labels')
test_dataset = dataset['train'].remove_columns(['sentence', 'idx']).rename_column('label', 'labels')


Downloading data:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [55]:
len(dataset['train']['label'])

67349

In [56]:
# Define data loaders
batch_size = 64
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)

In [57]:
for i in (testloader):
  print(i['input_ids'])
  break

tensor([[  101,  1010,  2003,  ...,     0,     0,     0],
        [  101,  1996,  6579,  ...,     0,     0,     0],
        [  101,  7494,  4519,  ...,     0,     0,     0],
        ...,
        [  101,  3328,  2000,  ...,     0,     0,     0],
        [  101,  1010,  2471,  ...,     0,     0,     0],
        [  101, 18691,  6447,  ...,     0,     0,     0]])


In [58]:
# With buffers

def loss_fn(params, buffers, input_ids, attention_mask, labels):
    loss = model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
    return loss

def softmax_output_fn(params, buffers, input_ids, attention_mask, labels):
    logits = model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
    logits_tensor = torch.tensor(logits)
    softmax_output = F.softmax(logits_tensor, dim=1)
    return softmax_output

In [59]:
# Without buffers

# def loss_fn(params, input_ids, attention_mask, labels):
#     loss = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
#     return loss

# def softmax_output_fn(params, input_ids, attention_mask, labels):
#     logits = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
#     logits_tensor = torch.tensor(logits)
#     softmax_output = F.softmax(logits_tensor, dim=1)
#     return softmax_output

In [None]:
# Train the model
for epoch in range(3):
    for i, data in enumerate(trainloader, 0):
        print("\n")
        print("*"*100)
        print("\n")

        if i == 0: # Initialize optimizer and model parameters
            print(f"input_ids: {data['input_ids']}")
            print(f"attention_mask: {data['attention_mask']}")
            print(f"labels: {data['labels']}")
            # Define optimizer
            base_optimizer = torchopt.adam(lr=0.01)
            optimizer = fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01)
            model, params, buffers = functorch.make_functional_with_buffers(model=base_model)
            # Initialize buffers based on initial parameters
            # buffers = initialize_buffers(model, params, buffers)
            opt_state = optimizer.init(params)

        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        labels = data['labels'].to(device)
        print(f"Labels: \n{labels}\n")

        loss = loss_fn(params, buffers, input_ids, attention_mask, labels)
        print(f"Step: {i}\n")
        print(f"Loss: {loss}\n")

        # Calculate gradients
        grads = torch.autograd.grad(loss, params)

        # Update model parameters
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = torchopt.apply_updates(params, updates, inplace=True)

        # Should this method exists? Are buffers neseccary to update?
        # Update buffers based on updated parameters
        # buffers = torchopt.update_buffers(model, params, buffers)

    # Evaluate the model
    with torch.no_grad():
        for data in testloader:
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            labels = data['labels'].to(device)

            logits = model(params, buffers, input_ids, attention_mask=attention_mask).logits
            logits_tensor = torch.tensor(logits)
            softmax_output = F.softmax(logits_tensor, dim=1)

            predicted_labels = torch.argmax(softmax_output, dim=1)
            metric.add_batch(predictions=predicted_labels, references=labels)

    print(f'Epoch: {epoch}')
    results = metric.compute()
    print(f"Results: \n{results}\n")

print('Finished Training')




****************************************************************************************************


input_ids: tensor([[  101,  2019, 23824,  ...,     0,     0,     0],
        [  101,  4248,  1011,  ...,     0,     0,     0],
        [  101, 12075,  2571,  ...,     0,     0,     0],
        ...,
        [  101,  2431,  1011,  ...,     0,     0,     0],
        [  101,  2066,  1996,  ...,     0,     0,     0],
        [  101, 11969,  2003,  ...,     0,     0,     0]])
attention_mask: tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
labels: tensor([1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1,
        0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1,
        1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1])
Returned ESE function. Lanczos order (m) is 20 .
Labels: 
tenso

  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')


[1;30;43mΗ έξοδος ροής περικόπηκε στις τελευταίες 5000 γραμμές.[0m
Loss: 0.6964955925941467



****************************************************************************************************


Labels: 
tensor([0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0,
        1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0,
        0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], device='cuda:0')

Step: 143

Loss: 0.6856685280799866



****************************************************************************************************


Labels: 
tensor([1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0,
        1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
        1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0], device='cuda:0')

Step: 144

Loss: 0.6722356677055359



****************************************************************************************************


Labels: 
tensor([1, 1, 0, 1,

In [None]:
input_ids = data['input_ids']
attention_mask = data['attention_mask']
labels = data['labels']
# Define optimizer
results = model(params, buffer, input_ids, attention_mask, labels).logits
print(results)
# Initialize buffers based on initial parameters
# buffers = initialize_buffers(model, params, buffers)
# opt_state = optimizer.init(params)

In [None]:
# # Train the model
# for epoch in range(3):
#     for i, data in enumerate(trainloader, 0):
#         print("\n")
#         print("*"*100)
#         print("\n")

#         if i == 0: # Initialize optimizer and model parameters
#             print(f"input_ids: {data['input_ids']}")
#             print(f"attention_mask: {data['attention_mask']}")
#             print(f"labels: {data['labels']}")
#             # Define optimizer
#             base_optimizer = torchopt.adam(lr=0.01)
#             optimizer = fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01)
#             model, params, buffers = functorch.make_functional_with_buffers(model=base_model)
#             # model, params = functorch.make_functional(model=base_model)
#             opt_state = optimizer.init(params)
#             # model.train()


#         input_ids = data['input_ids']
#         attention_mask = data['attention_mask']
#         labels = data['labels']
#         print(f"Labels: \n{labels}\n")

#         loss = loss_fn(params, buffers, input_ids, attention_mask, labels)
#         # loss = loss_fn(params, input_ids, attention_mask, labels)
#         print(f"Step: {i}\n")
#         print(f"Loss: {loss}\n")

#         print(f"Calculating Gradients\n")
#         grads = torch.autograd.grad(loss, params)
#         # print(f"Grads: \n{grads}\n")
#         print(f"Calculating updates in the model...\n")
#         updates, opt_state = optimizer.update(grads, opt_state, params)
#         print("Applying updates\n")
#         params = torchopt.apply_updates(params, updates, inplace=True)

#     # Evaluate the model
#     # model.eval()
#     with torch.no_grad():
#         for data in testloader:
#             input_ids = data['input_ids'].to(device)
#             attention_mask = data['attention_mask'].to(device)
#             labels = data['labels'].to(device)

#             logits = model(params, buffers, input_ids, attention_mask=attention_mask).logits
#             # logits = model(params, input_ids, attention_mask=attention_mask).logits
#             print(f"Logits: {logits}")
#             logits_tensor = torch.tensor(logits)
#             softmax_output = F.softmax(logits_tensor, dim=1)

#             print(f"SoftMax of Logits: {softmax_output}")
#             predicted_labels = torch.argmax(softmax_output, dim=1)
#             print(f"Argmax Predictions: {predicted_labels}")

#             metric.add_batch(predictions=predicted_labels, references=labels)


#     print(f'Epoch: {epoch}')
#     results = metric.compute()
#     print(f"Results: \n{results}\n")
#     # print(f'Correct: {correct}\n')
#     # print(f'Total: {total}\n')
#     # print(f'Predicted: {predicted}\n')
#     # print(f'outputs: {outputs}\n')
#     # print(f'labels: {labels}\n')
#     # print(f'Accuracy of the network on the test data: {100 * correct / total}%')
#     # model.train() # for the next epoch

# print('Finished Training')


In [None]:
model

In [None]:
# Evaluate the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        labels = data['labels'].to(device)
        outputs = model(params, buffers, input_ids, attention_mask=attention_mask)
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

In [None]:
labels

In [None]:
outputs.logits

In [None]:
_

In [None]:
predicted

In [None]:
correct

In [None]:
total

In [None]:
print(f'Accuracy of the network on the test data: {100 * correct / total}%')

In [None]:
# # Define training loop
# for epoch in range(3):  # Adjust number of epochs as needed
#     model.train()
#     for i, batch in enumerate(train_dataset):
#         input_ids = batch['input_ids']
#         attention_mask = batch['attention_mask']
#         labels = batch['label']

#         #This is taking care automatically
#         # optimizer.zero_grad() This is taking care automatically

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()

#         # if (i + 1) % 100 == 0:
#         #     print(f'[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}')

#     # Evaluate on test set
#     model.eval()
#     total = 0
#     correct = 0
#     with torch.no_grad():
#         for batch in test_dataset:
#             input_ids = batch['input_ids']
#             attention_mask = batch['attention_mask']
#             labels = batch['label']
#             print("Labels: \n", labels)

#             outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
#             logits = outputs.logits
#             _, predicted = torch.max(logits, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     accuracy = correct / total
#     print(f'Test accuracy: {accuracy:.4f}')

# print('Finished Training')
