In [1]:
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataloader import default_collate
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from torch.optim import Adam
import torch.nn as nn
import torchopt
import functorch

from datasets import load_dataset
from fosi import fosi_adam_torch

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer and the pre-trained DistilBERT checkpoint with a
# 2-label sequence-classification head
checkpoint = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:

# Define a function to preprocess the dataset
def prepare_dataset(example):
    """Tokenize the ``sentence`` field of a (batched) dataset example.

    Sequences are truncated to the tokenizer's maximum length but are NOT
    padded here: the ``DataCollatorWithPadding`` used by the data loaders pads
    each batch on the fly, so padding every sentence to the model maximum at
    this stage would only inflate the compute spent on padding tokens.
    Plain Python lists are returned (no ``return_tensors``) because
    ``Dataset.map`` stores the result column-wise anyway, and un-padded,
    ragged batches cannot be packed into a single tensor.

    Args:
        example: Mapping with a ``'sentence'`` entry (a string, or a list of
            strings when ``map(..., batched=True)`` is used).

    Returns:
        The tokenizer output for the given sentence(s).
    """
    return tokenizer(example['sentence'], truncation=True)

# Collator that pads every batch to the longest sequence it contains
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

# Load the GLUE CoLA dataset and tokenize every split
dataset = load_dataset('glue', 'cola').map(prepare_dataset, batched=True)

# Drop the raw-text/index columns and expose the target under the `labels` key
train_dataset = dataset['train'].remove_columns(['sentence', 'idx']).rename_column('label', 'labels')

# GLUE's `test` split is published with hidden labels (all -1), so accuracy
# measured on it is meaningless; evaluate on the labelled `validation` split.
test_dataset = dataset['validation'].remove_columns(['sentence', 'idx']).rename_column('label', 'labels')


Map: 100%|██████████| 1043/1043 [00:00<00:00, 5588.21 examples/s]


In [3]:
# Display the training split's summary (feature names and row count)
train_dataset

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 8551
})

In [4]:
# Define data loaders
batch_size = 8
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

# Plain classification criterion (binding kept from the original cell)
loss_fn = torch.nn.CrossEntropyLoss()

# Split the module into a stateless callable plus explicit parameter/buffer
# tuples. This has to happen BEFORE the optimizer is built, because the loss
# handed to FOSI must evaluate the model for an arbitrary `params` tuple.
model, params, buffers = functorch.make_functional_with_buffers(model=model)


def fosi_loss_fn(params, batch):
    """Loss in the ``loss_fn(params, batch)`` form that FOSI calls.

    An ``nn.CrossEntropyLoss`` instance cannot be handed to FOSI directly: it
    would be invoked as ``criterion(params, batch)`` and has no way of running
    the model.

    Args:
        params: Parameter tuple of the functional model.
        batch: Collated batch providing ``'input_ids'``, ``'attention_mask'``
            and ``'labels'``.

    Returns:
        Cross-entropy loss of the model's logits against ``batch['labels']``.
    """
    logits = model(params, buffers, input_ids=batch['input_ids'], attention_mask=batch['attention_mask']).logits
    # Use the functional form rather than the global `loss_fn`: that name is
    # rebound to a different callable later in the notebook.
    return nn.functional.cross_entropy(logits, batch['labels'])


# Define optimizer
base_optimizer = torchopt.adam(lr=0.001)
optimizer = fosi_adam_torch(base_optimizer, fosi_loss_fn, next(iter(trainloader)), num_iters_to_approx_eigs=500, alpha=0.01)


  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')


Returned ESE function. Lanczos order (m) is 20 .


In [13]:
# List the parameter names recorded on the functional model wrapper
model.param_names

('distilbert.embeddings.word_embeddings.weight',
 'distilbert.embeddings.position_embeddings.weight',
 'distilbert.embeddings.LayerNorm.weight',
 'distilbert.embeddings.LayerNorm.bias',
 'distilbert.transformer.layer.0.attention.q_lin.weight',
 'distilbert.transformer.layer.0.attention.q_lin.bias',
 'distilbert.transformer.layer.0.attention.k_lin.weight',
 'distilbert.transformer.layer.0.attention.k_lin.bias',
 'distilbert.transformer.layer.0.attention.v_lin.weight',
 'distilbert.transformer.layer.0.attention.v_lin.bias',
 'distilbert.transformer.layer.0.attention.out_lin.weight',
 'distilbert.transformer.layer.0.attention.out_lin.bias',
 'distilbert.transformer.layer.0.sa_layer_norm.weight',
 'distilbert.transformer.layer.0.sa_layer_norm.bias',
 'distilbert.transformer.layer.0.ffn.lin1.weight',
 'distilbert.transformer.layer.0.ffn.lin1.bias',
 'distilbert.transformer.layer.0.ffn.lin2.weight',
 'distilbert.transformer.layer.0.ffn.lin2.bias',
 'distilbert.transformer.layer.0.output_laye

In [6]:
# Build the initial optimizer state from the current parameter tuple
opt_state = optimizer.init(params)

In [25]:
# def loss_fn(params, batch):
#     preds = model(params, batch['input_ids'], batch['attention_mask']).logits
#     loss = nn.CrossEntropyLoss()(preds, batch)
#     return loss

def loss_fn(params, buffers, input_ids, attention_mask, labels):
    """Cross-entropy loss of the functional model on one batch.

    Calls ``model`` with the given ``params`` and ``buffers`` on the tokenized
    inputs, then scores the returned ``logits`` against ``labels`` with a
    default-configured ``nn.CrossEntropyLoss``.

    Returns:
        The loss value produced by the criterion.
    """
    outputs = model(
        params,
        buffers=buffers,
        input_ids=input_ids,
        attention_mask=attention_mask,
    )
    criterion = nn.CrossEntropyLoss()
    return criterion(outputs.logits, labels)

In [26]:
def accuracy(params, buffers, input_ids, attention_mask, labels):
    """Count the examples in a batch that the model classifies correctly.

    Args:
        params: Parameter tuple of the functional model.
        buffers: Buffer tuple of the functional model.
        input_ids: Token ids of the batch.
        attention_mask: Attention mask matching ``input_ids``.
        labels: Ground-truth class index for every example in the batch.

    Returns:
        A 0-dim tensor holding the NUMBER of correct predictions (a count,
        not a ratio); the caller divides by the number of samples.
    """
    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        logits = model(params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask).logits
        predicted_class = torch.argmax(logits, dim=1)
        # Compare against the `labels` argument (previously a leaked global
        # `batch[1]` was used, ignoring the parameter entirely).
        return torch.sum(predicted_class == labels)

In [28]:
# Train the model, evaluating on the held-out loader after every epoch.
# Uses the five-argument `loss_fn` and `accuracy` helpers defined in the cells above.
for epoch in range(3):
    # Evaluation below switches the model to eval mode, so re-enable
    # training behaviour (e.g. dropout) at the start of every epoch.
    model.train()

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        input_ids = data['input_ids']
        attention_mask = data['attention_mask']
        label = data['labels']

        loss = loss_fn(params=params, buffers=buffers, input_ids=input_ids, attention_mask=attention_mask, labels=label)

        # Functional update: gradients w.r.t. the explicit parameter tuple
        # are transformed by the optimizer and applied in place.
        grads = torch.autograd.grad(loss, params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = torchopt.apply_updates(params, updates, inplace=True)

        # print statistics
        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

    # Evaluate deterministically and without tracking gradients.
    model.eval()
    correct = 0
    num_samples = 0
    with torch.no_grad():
        for batch in testloader:
            # The collated batch is dict-like, so fields are read by key.
            correct += int(accuracy(params, buffers, batch['input_ids'], batch['attention_mask'], batch['labels']))
            num_samples += batch['labels'].shape[0]
    acc = correct / num_samples
    print(f'Test accuracy: {acc}')

print('Finished Training')


KeyboardInterrupt: 

In [None]:
# # Define training loop
# for epoch in range(3):  # Adjust number of epochs as needed
#     model.train()
#     for i, batch in enumerate(train_dataset):
#         input_ids = batch['input_ids']
#         attention_mask = batch['attention_mask']
#         labels = batch['label']
        
#         # This is taken care of automatically
#         # optimizer.zero_grad()  # taken care of automatically

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()

#         # if (i + 1) % 100 == 0:
#         #     print(f'[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}')

#     # Evaluate on test set
#     model.eval()
#     total = 0
#     correct = 0
#     with torch.no_grad():
#         for batch in test_dataset:
#             input_ids = batch['input_ids']
#             attention_mask = batch['attention_mask']
#             labels = batch['label']
#             print("Labels: \n", labels)
            
#             outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
#             logits = outputs.logits
#             _, predicted = torch.max(logits, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     accuracy = correct / total
#     print(f'Test accuracy: {accuracy:.4f}')

# print('Finished Training')
