In [None]:
from datasets import load_dataset 
from datasets import load_metric
import evaluate

from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

from runlora.modeling import RunLoRAModel
from runlora import RunLoRACollection

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import gc

In [None]:
def report_params(model):
    """Print and return the total and trainable parameter counts of ``model``.

    A parameter is counted as trainable when its ``requires_grad`` flag is set.

    Returns:
        tuple[int, int]: ``(params, trainable_params)``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    print(f'Params: {total}, Trainable Params: {trainable}')
    return total, trainable

In [None]:
# Load the Food-101 dataset; its 'train' and 'validation' splits are used later on.
dataset = load_dataset("food101")

In [None]:
# Inspect the loaded splits.
dataset

In [None]:
# Preview one training image, resized to 200x200 for display only.
dataset["train"][10]['image'].resize((200, 200))

In [None]:
# Accuracy metric from the `evaluate` library; consumed by compute_metrics.
metric = evaluate.load("accuracy")

In [None]:
# Food-101 class names, plus name <-> integer-id lookups for the model config.
labels = dataset["train"].features["label"].names
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = {idx: name for idx, name in enumerate(labels)}

# Spot-check one mapping.
id2label[2]

In [None]:
# Hub checkpoint whose image-processor settings (size, mean/std) drive the
# torchvision transforms defined further down.
model_checkpoint = "google/vit-base-patch16-224-in21k"
# Where Hugging Face downloads are cached. This name used to be referenced here
# without being defined anywhere, so the cell raised NameError on a fresh kernel.
# None selects the library's default cache; set a path string to override it.
cache_dir = None
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint, cache_dir=cache_dir)

In [None]:
# Evaluation hook for the Trainer. `eval_pred` is a named tuple carrying
#   - predictions: the model's logits as a NumPy array, and
#   - label_ids:   the ground-truth labels as a NumPy array.
def compute_metrics(eval_pred):
    """Score arg-max class predictions against the reference labels with ``metric``."""
    logits = eval_pred.predictions
    predicted_class_ids = np.argmax(logits, axis=1)
    reference_ids = eval_pred.label_ids
    return metric.compute(predictions=predicted_class_ids, references=reference_ids)

def collate_fn(examples):
    """Combine a list of preprocessed examples into one Trainer batch.

    Each example provides a ``pixel_values`` tensor (all of equal shape, since they
    are stacked) and a ``label`` class id. Returns a dict with the stacked
    ``pixel_values`` and a ``labels`` tensor.
    """
    stacked_pixels = torch.stack([ex["pixel_values"] for ex in examples])
    label_tensor = torch.tensor([ex["label"] for ex in examples])
    return {"pixel_values": stacked_pixels, "labels": label_tensor}

In [None]:
# Normalization statistics are taken from the checkpoint's image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# The processor describes its target size either as an explicit
# {"height", "width"} pair or as {"shortest_edge"[, "longest_edge"]}.
# `max_size` is recorded for reference but is not passed to any transform.
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")
else:
    # Previously fell through silently, leaving `size`/`crop_size` unbound and
    # surfacing later as a NameError when building the transforms.
    raise ValueError(
        f"Unsupported image_processor.size format: {image_processor.size!r}; "
        "expected 'height'/'width' or 'shortest_edge' keys."
    )

# Training pipeline: random resized crop to the model input size and a random
# horizontal flip for augmentation, then tensor conversion and normalization.
train_transforms = Compose([
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Validation pipeline: deterministic resize + center crop, then the same
# tensor conversion and normalization as training.
val_transforms = Compose([
    Resize(size),
    CenterCrop(crop_size),
    ToTensor(),
    normalize,
])

def preprocess_train(example_batch):
    """Add augmented ``pixel_values`` to a batch of training examples.

    Every entry of ``example_batch["image"]`` is converted to RGB (presumably PIL
    images, given ``.convert``) and run through ``train_transforms``. The batch
    dict is updated in place and also returned, as ``set_transform`` expects.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = list(map(train_transforms, rgb_images))
    return example_batch

def preprocess_val(example_batch):
    """Add deterministic ``pixel_values`` to a batch of validation examples.

    Each image is converted to RGB and passed through ``val_transforms``; the
    batch dict is updated in place and returned.
    """
    pixel_values = []
    for img in example_batch["image"]:
        pixel_values.append(val_transforms(img.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

In [None]:
# split up training into training + validation
# splits = dataset["train"].train_test_split(test_size=0.1)
# train_ds = splits['train']
# val_ds = splits['test']

In [None]:
# Re-inspect the splits before choosing the train/validation sets.
dataset

In [None]:
# Use the dataset's own 'validation' split for evaluation; the 90/10 split of
# 'train' sketched in the earlier commented-out cell is not used.
train_ds = dataset['train']
val_ds = dataset['validation']

In [None]:
# Attach the batch preprocessing functions: augmentation for training,
# deterministic resize/center-crop for validation.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [None]:
# Build an image classifier from the checkpoint, configured with the
# Food-101 label mappings built earlier.
# NOTE(review): unlike the image-processor cell, no cache_dir is passed here,
# so the model weights may be cached in a different location — confirm intent.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes = True, # allow loading when the checkpoint's classifier shape differs (e.g. fine-tuning an already fine-tuned checkpoint)
)

In [None]:
# Baseline parameter counts before RunLoRA adaptation.
_ = report_params(model)

In [None]:
# Display the module hierarchy (useful for choosing `target_modules` names).
model

In [None]:
# RunLoRA / LoRA configuration.
random_seed = 42  # passed to torch.manual_seed before the model is wrapped
lora_r = 32  # LoRA rank, passed as lora_r to RunLoRAModel
lora_alpha = 32  # LoRA alpha, passed to RunLoRAModel (presumably scaling = alpha / r — confirm)
lora_dropout = 0.  # passed as lora_dropout; presumably 0 disables adapter dropout
target_modules = ['query', 'key', 'value', 'dense']  # substrings matched against nn.Linear module names

In [None]:
# Assign a RunLoRA forward/backward implementation pair to every nn.Linear whose
# qualified name contains one of `target_modules` (substring match).
run_lora_mapping = {}
run_lora_collection = RunLoRACollection()
for module_name, module in model.named_modules():
    if not isinstance(module, nn.Linear):
        continue
    if not any(trgt in module_name for trgt in target_modules):
        continue
    # Same pair for every layer; change this if the optimal forward/backward
    # functions differ between layers.
    run_lora_mapping[module_name] = run_lora_collection[('forward2', 'backward5')]

In [None]:
# Seed torch, presumably so any random adapter initialization inside
# RunLoRAModel is reproducible.
torch.manual_seed(random_seed)

# Wrap the model with RunLoRA adapters on the layers listed in run_lora_mapping,
# then prepare it for fine-tuning; modules_to_save=['classifier'] presumably keeps
# the classification head trainable — confirm in RunLoRA docs.
# NOTE(review): the Trainer is later given `model`, not `runlora_model`. That
# presumably relies on RunLoRAModel replacing layers of `model` in place (the later
# `model` display / report_params cells suggest so) — confirm.
runlora_model = RunLoRAModel(model,
                     run_lora_mapping,
                     lora_r=lora_r,
                     lora_alpha=lora_alpha,
                     lora_dropout=lora_dropout,
                     lora_dtype=torch.float,
                     target_modules=target_modules)
runlora_model.prepare_for_finetuning(modules_to_save=['classifier'])

In [None]:
# Inspect which module names were mapped to which forward/backward pair.
run_lora_mapping

In [None]:
# Number of linear layers selected for RunLoRA adaptation.
len(run_lora_mapping)

In [None]:
# Module tree after RunLoRA preparation (check whether target layers were replaced).
model

In [None]:
# model.base_model.model.classifier.weight.requires_grad, id(model.base_model.model.classifier.weight)

In [None]:
# model.base_model.model.classifier.original_module.weight.requires_grad, \
# model.base_model.model.classifier.modules_to_save.default.weight.requires_grad

In [None]:
# id(model.base_model.model.classifier.original_module.weight), id(model.base_model.model.classifier.modules_to_save.default.weight)

In [None]:
# Parameter counts after RunLoRA preparation; the trainable count is expected to
# drop to roughly adapters + classifier — compare with the baseline above.
_ = report_params(model)

In [None]:
# from transformers.models.vit.modeling_vit import ViTSelfAttention, ViTSelfOutput, ViTIntermediate, ViTOutput, ViTEmbeddings
# from transformers.activations import GELUActivation
# from runlora.modeling import RunLoRALinear
# from functools import partial

# def report_hook(idx, module, input, output):
#     if isinstance(input, tuple):
#         print(idx, input[0].shape)
#         print(input[0].dtype)
#     else:
#         print(idx, input.shape)
#         print(input.dtype)
#     if isinstance(output, tuple):
#         print(idx, output[0].shape)
#         print(output[0].dtype)
#     else:
#         print(idx, output.shape)
#         print(output.dtype)
#     print()

# def hook_model(model, hook_func, target_classes):

#     handles = []
#     j = 0
#     for module in model.modules():
#         if isinstance(module, target_classes):
#         # if isinstance(module, (ViTEmbeddings)):
#             handle = module.register_forward_hook(partial(hook_func, j))
#             handles.append(handle)
#             j+=1
    
#     return handles

In [None]:
# model = model.to(torch.half)

In [None]:
# # target_classes = (RunLoRALinear, GELUActivation)
# target_classes = (nn.Linear)
# handles = hook_model(model, report_hook, target_classes)

In [None]:
# Trainer configuration.
# NOTE: max_steps=10 is a smoke-test setting. Evaluation / best-model options are
# commented out, so presumably no evaluation runs during trainer.train() — confirm.
batch_size = 100
model_name = model_checkpoint.split('/')[-1]  # last path component of the Hub id
training_arguments = TrainingArguments(
    # Output path encodes checkpoint name, LoRA rank, batch size and adapter dtype (lora_dtype=torch.float).
    output_dir=f"./checkpoints/{model_name}_RunLoRA_r{lora_r}b{batch_size}_fp32/",
    # Keep the raw "image" column: preprocess_train/preprocess_val read it to build pixel_values.
    remove_unused_columns=False,
    # evaluation_strategy="epoch", # uncomment when fine-tuning
    # save_strategy="epoch", # uncomment when fine-tuning
    learning_rate=5e-3,
    per_device_train_batch_size=batch_size,
    # per_device_eval_batch_size=batch_size, # uncomment when fine-tuning
    gradient_accumulation_steps=4,  # presumably batch_size * 4 examples per device per optimizer step
    num_train_epochs=8,
    logging_steps=10,
    # load_best_model_at_end=True, # uncomment when fine-tuning
    # metric_for_best_model="accuracy", # uncomment when fine-tuning
    # label_names=["labels"], # uncomment when fine-tuning
    max_steps=10 # for testing, comment when fine-tuning
)

In [None]:
# Wire the (RunLoRA-prepared) model, transformed splits, batch collator and
# accuracy metric into a Hugging Face Trainer.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favor of
# `processing_class=` — confirm against the installed version.
# NOTE(review): eval_dataset/compute_metrics only take effect if an evaluation
# strategy is enabled in training_arguments (currently commented out).
trainer = Trainer(
    model,
    training_arguments,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [None]:
# Run training (capped by max_steps in training_arguments).
train_results = trainer.train()

In [None]:
# trainer.train(resume_from_checkpoint=True)

In [None]:
# Persist the trained model, log and save the training metrics, and save the
# trainer state (presumably under training_arguments.output_dir).
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()