<a href="https://colab.research.google.com/github/PavlosPo/nlp-optimizers/blob/pavlos-playground/pytorch-experiments-1-fosi-adam/playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## FOSI Classes

In [46]:
# ! rm -rf ./fosi/
# !mkdir ./fosi/
# !unzip fosi.zip -d ./fosi/

## Working Example

In [47]:
# !pip install torchopt
# !pip install datasets
# !pip install evaluate

In [48]:
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataloader import default_collate
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from torch.optim import Adam
import torch.nn as nn
import torchopt
import functorch
import evaluate
import torch.nn.functional as F

from datasets import load_dataset
from fosi import fosi_adam_torch

# # Set device
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# # Load pre-trained DistilBERT model and tokenizer
# base_model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# base_model.to(device)
# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [49]:
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding

# Load pre-trained BERT model and tokenizer
bert_model_name = "bert-base-uncased"
num_classes = 1
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_model_name, num_labels=num_classes)
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

# # Freeze all parameters of the BERT model except for the last layer (classifier)
# for name, param in bert_model.named_parameters():
#     if 'classifier' not in name:  # Exclude parameters of the classifier layer
#         param.requires_grad = False

# Define a dense layer on top of the BERT model for classification
class BertClassifier(nn.Module):
    def __init__(self, bert_model, num_classes):
        super(BertClassifier, self).__init__()
        self.bert = bert_model
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
        probability = self.sigmoid(outputs.squeeze())
        return probability

# Instantiate the classifier
classifier = BertClassifier(bert_model, num_classes)

# Optionally, move the model to a GPU device if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier.to(device)

# Now you can train the entire model (BERT + dense layer) on your task-specific data
# Make sure to prepare your data (input_ids, attention_mask, labels) using the BERT tokenizer
# and use an appropriate loss function and optimizer for your task


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertClassifier(
  (bert): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, 

In [50]:

# Define a function to preprocess the dataset
def prepare_dataset(example):
    return tokenizer(example['sentence'], add_special_tokens=True, truncation=True, padding=True, return_tensors='pt')

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

dataset = load_dataset('glue', 'sst2').map(prepare_dataset, batched=True)
metric = evaluate.load("glue", "sst2")

# Split dataset into train and test sets, we use the train category because the test one has labels -1 only.
train_dataset = dataset['train'].select(range(0,500)).remove_columns(['sentence', 'idx']).rename_column('label', 'labels')
test_dataset = dataset['train'].select(range(500, 1000)).remove_columns(['sentence', 'idx']).rename_column('label', 'labels')


In [51]:
len(dataset['train']['label'])

67349

In [52]:
# Define data loaders
batch_size = 128
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)

In [53]:
for i in (testloader):
  print(i['input_ids'])
  break

tensor([[  101,  2130, 25591,  ...,     0,     0,     0],
        [  101,  2003,  1037,  ...,     0,     0,     0],
        [  101,  2437, 10556,  ...,     0,     0,     0],
        ...,
        [  101,  8562,  3238,  ...,     0,     0,     0],
        [  101,  2563,  1005,  ...,     0,     0,     0],
        [  101,  2009,  5078,  ...,     0,     0,     0]])


In [54]:
# With buffers

def loss_fn(functional_model, params, buffers, input_ids, attention_mask, labels):
    preds = functional_model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask)
    loss = nn.functional.binary_cross_entropy(preds.squeeze().to(torch.float32), labels.squeeze().to(torch.float32))
    return loss

# def softmax_output_fn(params, buffers, input_ids, attention_mask, labels):
#     logits = model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
#     logits_tensor = torch.tensor(logits)
#     softmax_output = F.softmax(logits_tensor, dim=1)
#     return softmax_output

In [55]:
# Without buffers

# def loss_fn(params, input_ids, attention_mask, labels):
#     loss = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
#     return loss

# def softmax_output_fn(params, input_ids, attention_mask, labels):
#     logits = model(params, input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
#     logits_tensor = torch.tensor(logits)
#     softmax_output = F.softmax(logits_tensor, dim=1)
#     return softmax_output

In [66]:

# Initialize optimizer and model parameters
# Those are needed in order for fosi_adam_to run in the loop below

data = next(iter(trainloader))  # get first batch of data

print(f"input_ids: {data['input_ids']}")
print(f"attention_mask: {data['attention_mask']}")
print(f"labels: {data['labels']}")

# Define optimizer
classifier.train()
base_optimizer = torchopt.adam(lr=0.01)
optimizer = fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01)
model, params, buffers = functorch.make_functional_with_buffers(model=classifier)
opt_state = optimizer.init(params)

input_ids: tensor([[  101, 16514,  2135,  ...,     0,     0,     0],
        [  101,  2053,  2047,  ...,     0,     0,     0],
        [  101,  2521, 24763,  ...,     0,     0,     0],
        ...,
        [  101,  2009, 11014,  ...,     0,     0,     0],
        [  101,  2003,  3432,  ...,     0,     0,     0],
        [  101,  2202,  2729,  ...,     0,     0,     0]])
attention_mask: tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
labels: tensor([1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1,
        0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
        0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1,
        1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 

In [67]:
# Train the model
for epoch in range(2):
    for i, data in enumerate(trainloader, 1):
        print("\n")
        print("*"*100)
        print("\n")

        input_ids = data['input_ids'].squeeze().to(device)
        attention_mask = data['attention_mask'].squeeze().to(device)
        labels = data['labels'].squeeze().to(device)
        print(f"Labels: \n{labels}\n")

        loss = loss_fn(functional_model=model, 
                       params=params, buffers=buffers, input_ids=input_ids,attention_mask=attention_mask, labels=labels)
        print(f"Step: {i}\n")
        print(f"Loss: {loss}\n")

        # func_optimizer.step(loss, params, inplace=True)

        # Calculate gradients
        grads = torch.autograd.grad(loss, params)

        # Update model parameters
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = torchopt.apply_updates(params, updates, inplace=True)

        # Should this method exists? Are buffers neseccary to update?
        # Update buffers based on updated parameters
        # buffers = torchopt.update_buffers(model, params, buffers)
    evaluation_results = []

    # Evaluate the model
    classifier.eval()
    with torch.no_grad():
        for i, data in enumerate(testloader):
            evaluation_result = {}

            input_ids = data['input_ids'].squeeze().to(device)
            attention_mask = data['attention_mask'].squeeze().to(device)
            labels = data['labels'].squeeze().to(device)

            preds = model(params, buffers, input_ids, attention_mask=attention_mask)

            predicted_labels = torch.round(preds)

            # Save the evaluation results
            evaluation_result['input_ids'] = input_ids.cpu().tolist()
            evaluation_result['attention_mask'] = attention_mask.cpu().tolist()
            evaluation_result['labels'] = labels.cpu().tolist()
            evaluation_result['preds'] = preds.cpu().tolist()
            evaluation_result['predicted_label'] = predicted_labels.cpu().tolist()

            evaluation_results.append(evaluation_result)
            metric.add_batch(predictions=predicted_labels, references=labels)

print(f'Epoch: {epoch}')
results = metric.compute()
print(f"Results: \n{results}\n")

print('Finished Training')




****************************************************************************************************


Labels: 
tensor([0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0,
        0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1,
        0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0,
        0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 1, 1, 1])

Step: 1

Loss: 0.6880813837051392



****************************************************************************************************


Labels: 
tensor([0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1,
        1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1,
        0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1,
        0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0

In [68]:
print(evaluation_results[0]['labels'])

[1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1]


In [69]:
print(evaluation_results[0].keys())

dict_keys(['input_ids', 'attention_mask', 'labels', 'preds', 'predicted_label'])


In [70]:
print(evaluation_results[0]['preds'])

[0.36563795804977417, 0.3529967665672302, 0.3821791708469391, 0.4057554602622986, 0.3534340560436249, 0.3860921263694763, 0.31445321440696716, 0.35191407799720764, 0.36207184195518494, 0.4616522789001465, 0.46333950757980347, 0.40192654728889465, 0.3814569115638733, 0.3836885094642639, 0.31321561336517334, 0.4732728600502014, 0.5138583183288574, 0.37511923909187317, 0.47060903906822205, 0.3644796907901764, 0.31288063526153564, 0.3747831881046295, 0.34744465351104736, 0.4050697088241577, 0.45658013224601746, 0.35638412833213806, 0.406047523021698, 0.46350979804992676, 0.372139573097229, 0.42184123396873474, 0.4497975707054138, 0.3508889675140381, 0.29819008708000183, 0.37678325176239014, 0.31483420729637146, 0.40384426712989807, 0.4338630139827728, 0.3745567202568054, 0.41909345984458923, 0.41355180740356445, 0.4067532420158386, 0.35739511251449585, 0.36872386932373047, 0.3335361182689667, 0.36576971411705017, 0.4096563160419464, 0.350468248128891, 0.35859355330467224, 0.364814192056655

In [71]:
print(evaluation_results[0]['predicted_label'])

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


In [62]:
import gc
gc.collect() # Python thing
torch.cuda.empty_cache() # PyTorch thing

In [63]:
# # Train the model
# for epoch in range(3):
#     for i, data in enumerate(trainloader, 0):
#         print("\n")
#         print("*"*100)
#         print("\n")

#         if i == 0: # Initialize optimizer and model parameters
#             print(f"input_ids: {data['input_ids']}")
#             print(f"attention_mask: {data['attention_mask']}")
#             print(f"labels: {data['labels']}")
#             # Define optimizer
#             base_optimizer = torchopt.adam(lr=0.01)
#             optimizer = fosi_adam_torch(base_optimizer, loss_fn, data, num_iters_to_approx_eigs=500, alpha=0.01)
#             model, params, buffers = functorch.make_functional_with_buffers(model=base_model)
#             # model, params = functorch.make_functional(model=base_model)
#             opt_state = optimizer.init(params)
#             # model.train()


#         input_ids = data['input_ids']
#         attention_mask = data['attention_mask']
#         labels = data['labels']
#         print(f"Labels: \n{labels}\n")

#         loss = loss_fn(params, buffers, input_ids, attention_mask, labels)
#         # loss = loss_fn(params, input_ids, attention_mask, labels)
#         print(f"Step: {i}\n")
#         print(f"Loss: {loss}\n")

#         print(f"Calculating Gradients\n")
#         grads = torch.autograd.grad(loss, params)
#         # print(f"Grads: \n{grads}\n")
#         print(f"Calculating updates in the model...\n")
#         updates, opt_state = optimizer.update(grads, opt_state, params)
#         print("Applying updates\n")
#         params = torchopt.apply_updates(params, updates, inplace=True)

#     # Evaluate the model
#     # model.eval()
#     with torch.no_grad():
#         for data in testloader:
#             input_ids = data['input_ids'].to(device)
#             attention_mask = data['attention_mask'].to(device)
#             labels = data['labels'].to(device)

#             logits = model(params, buffers, input_ids, attention_mask=attention_mask).logits
#             # logits = model(params, input_ids, attention_mask=attention_mask).logits
#             print(f"Logits: {logits}")
#             logits_tensor = torch.tensor(logits)
#             softmax_output = F.softmax(logits_tensor, dim=1)

#             print(f"SoftMax of Logits: {softmax_output}")
#             predicted_labels = torch.argmax(softmax_output, dim=1)
#             print(f"Argmax Predictions: {predicted_labels}")

#             metric.add_batch(predictions=predicted_labels, references=labels)


#     print(f'Epoch: {epoch}')
#     results = metric.compute()
#     print(f"Results: \n{results}\n")
#     # print(f'Correct: {correct}\n')
#     # print(f'Total: {total}\n')
#     # print(f'Predicted: {predicted}\n')
#     # print(f'outputs: {outputs}\n')
#     # print(f'labels: {labels}\n')
#     # print(f'Accuracy of the network on the test data: {100 * correct / total}%')
#     # model.train() # for the next epoch

# print('Finished Training')


In [64]:
# # Evaluate the model
# model.eval()
# correct = 0
# total = 0
# with torch.no_grad():
#     for data in testloader:
#         input_ids = data['input_ids'].to(device)
#         attention_mask = data['attention_mask'].to(device)
#         labels = data['labels'].to(device)
#         outputs = model(params, buffers, input_ids, attention_mask=attention_mask)
#         _, predicted = torch.max(outputs.logits, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

In [65]:
# # Define training loop
# for epoch in range(3):  # Adjust number of epochs as needed
#     model.train()
#     for i, batch in enumerate(train_dataset):
#         input_ids = batch['input_ids']
#         attention_mask = batch['attention_mask']
#         labels = batch['label']

#         #This is taking care automatically
#         # optimizer.zero_grad() This is taking care automatically

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()

#         # if (i + 1) % 100 == 0:
#         #     print(f'[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}')

#     # Evaluate on test set
#     model.eval()
#     total = 0
#     correct = 0
#     with torch.no_grad():
#         for batch in test_dataset:
#             input_ids = batch['input_ids']
#             attention_mask = batch['attention_mask']
#             labels = batch['label']
#             print("Labels: \n", labels)

#             outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
#             logits = outputs.logits
#             _, predicted = torch.max(logits, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     accuracy = correct / total
#     print(f'Test accuracy: {accuracy:.4f}')

# print('Finished Training')
