In [None]:
from datasets import load_dataset 
from datasets import load_metric

from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

from peft import LoraConfig, get_peft_model
from runlora.modeling import RunLoRAModel
from runlora import RunLoRACollection

import pandas as pd
import numpy as np
import torch
import torch.utils.benchmark as benchmark
import matplotlib.pyplot as plt
import gc
from functools import partial
import logging
# `filemode` only has an effect together with `filename`, so it is dropped here.
# `force=True` replaces any handlers already attached to the root logger, so the
# INFO-level benchmark messages are shown even when the cell is re-run or the
# environment pre-configured logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)

In [None]:
# Run everything on the CUDA device and make it the default device for newly
# created tensors (e.g. the `torch.tensor(...)` calls further down).
device = torch.device("cuda")
torch.set_default_device(device)

In [None]:
from transformers.models.vit.modeling_vit import ViTSelfAttention, ViTSelfOutput, ViTIntermediate, ViTOutput, ViTEmbeddings
from torch.nn import Linear
from runlora.modeling import RunLoRALinear

def report_hook(idx, module, input, output):
    """Forward hook that prints shape and dtype of a module's input and output.

    Args:
        idx: Label printed in front of each shape line (``hook_model`` binds
            the running index of the hooked module here).
        module: The hooked module; not used.
        input: Value handed to the hook as input. If it is a tuple, only its
            first element is reported.
        output: Value handed to the hook as output. If it is a tuple, only its
            first element is reported.
    """
    for value in (input, output):
        # Only the leading element of a tuple is inspected.
        first = value[0] if isinstance(value, tuple) else value
        print(idx, first.shape)
        print(first.dtype)
    print()

def hook_model(model, hook_func, target_classes, input_batch=None):
    """Run one no-grad forward pass with temporary forward hooks attached.

    Every sub-module of ``model`` that is an instance of ``target_classes``
    gets ``hook_func`` registered as a forward hook, with the module's running
    index (0, 1, 2, ...) bound as the hook's first argument. The hooks are
    always removed again, even if the forward pass raises.

    Args:
        model: Module to inspect; called as ``model(**input_batch)``.
        hook_func: Callable ``(idx, module, input, output)``.
        target_classes: Class or tuple of classes, as accepted by ``isinstance``.
        input_batch: Keyword arguments for the forward pass. When omitted, the
            notebook-level ``input_batch`` variable is used, which is what
            earlier callers relied on implicitly.

    Raises:
        NameError: If ``input_batch`` is omitted and no notebook-level
            ``input_batch`` exists.
    """
    if input_batch is None:
        try:
            input_batch = globals()["input_batch"]
        except KeyError:
            raise NameError("name 'input_batch' is not defined") from None

    handles = []
    try:
        targets = (m for m in model.modules() if isinstance(m, target_classes))
        for j, module in enumerate(targets):
            handles.append(module.register_forward_hook(partial(hook_func, j)))

        with torch.no_grad():
            model(**input_batch)
    finally:
        # Never leave hooks behind: they would fire on every later forward pass.
        for handle in handles:
            handle.remove()

In [None]:
# Benchmark configuration.
# min_run_time is forwarded to `blocked_autorange` in `bench_model` (and, halved,
# to RunLoRACollection); type_string only tags the output CSV file name;
# dtype is the precision the models and the input batch are cast to.
min_run_time = 40
type_string = 'fp32'
dtype = torch.float

In [None]:
def reset_memory(reset_stats=True):
    """Collect garbage and release cached CUDA memory.

    Args:
        reset_stats: When true, additionally reset CUDA's peak-memory
            statistics so later ``max_memory_*`` readings start afresh.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not reset_stats:
        return
    torch.cuda.reset_peak_memory_stats()

def bench_model(model, input_batch, min_run_time):
    """Time one forward+backward step of ``model`` and record peak CUDA memory.

    The statement ``model(**input_batch).loss.backward()`` is measured twice
    with ``blocked_autorange``: once as a warm-up and once for the reported
    numbers. Peak allocated/reserved memory is read before and after the second
    pass and reported as the difference.

    Args:
        model: Model called as ``model(**input_batch)``; its result must expose
            a ``.loss`` supporting ``.backward()``.
        input_batch: Keyword arguments for the model call.
        min_run_time: Passed to ``blocked_autorange`` for both passes.

    Returns:
        dict with ``mean_time_us``, ``max_mem_overhead_MB``,
        ``max_mem_res_overhead_MB`` and ``msrs/runs``
        ("<number of measurements>/<runs per measurement>").

    Raises:
        AssertionError: If the warm-up pass produced fewer than 10 measurements.
    """
    reset_memory()

    timer = benchmark.Timer(
        stmt='model(**input_batch).loss.backward()',
        globals={'input_batch': input_batch, 'model': model})

    # Warm-up pass; also checks that min_run_time yields enough measurements.
    warmup = timer.blocked_autorange(min_run_time=min_run_time)
    assert len(warmup.times) >= 10, \
        'Number of measurements is less than 10, increase min_run_time!'

    # Baseline for the peak-memory statistics, taken right after a reset.
    reset_memory()
    max_mem_prev = torch.cuda.max_memory_allocated()
    max_res_prev = torch.cuda.max_memory_reserved()

    # The measured pass.
    measure = timer.blocked_autorange(min_run_time=min_run_time)
    n_measurements = len(measure.times)
    logging.info("Computing mean with {} measurments, {} runs per measurment".format(
        n_measurements, measure.number_per_run))

    max_mem = torch.cuda.max_memory_allocated()
    max_res = torch.cuda.max_memory_reserved()

    # The timer keeps references to the model and batch through its globals.
    del timer
    reset_memory()

    mean_time_us = measure.mean * 1000000
    mem_overhead_mb = (max_mem - max_mem_prev) / 2**20
    res_overhead_mb = (max_res - max_res_prev) / 2**20

    logging.info("Mean time: {} us".format(mean_time_us))
    logging.info("Max Allocated Overhead: {} MB".format(mem_overhead_mb))
    logging.info("Max Reserved Overhead:{} MB".format(res_overhead_mb))
    logging.info("")

    return {
        'mean_time_us': mean_time_us,
        'max_mem_overhead_MB': mem_overhead_mb,
        'max_mem_res_overhead_MB': res_overhead_mb,
        'msrs/runs': f'{n_measurements}/{measure.number_per_run}',
    }

In [None]:
def report_params(model):
    """Count, print and return the total and trainable parameter counts.

    Args:
        model: Object whose ``parameters()`` yields items exposing ``numel()``
            and ``requires_grad``.

    Returns:
        Tuple ``(params, trainable_params)``.
    """
    params = 0
    trainable_params = 0
    for parameter in model.parameters():
        count = parameter.numel()
        params += count
        if parameter.requires_grad:
            trainable_params += count
    print(f'Params: {params}, Trainable Params: {trainable_params}')
    return params, trainable_params

In [None]:
# Only a slice of the food101 validation split is loaded ("validation[:1000]");
# the benchmark below uses just the first `batch_size` examples of it.
dataset = load_dataset("food101", split="validation[:1000]")

In [None]:
# Display the loaded dataset object.
dataset

In [None]:
# Preview the image of example 1 resized to 200x200
# (presumably a PIL image - confirm against the dataset features).
dataset[1]['image'].resize((200, 200))

In [None]:
# Build both directions of the class-name <-> class-index mapping that is
# handed to `from_pretrained` below.
labels = dataset.features["label"].names
label2id = {name: index for index, name in enumerate(labels)}
id2label = {index: name for index, name in enumerate(labels)}

# Spot-check one entry.
id2label[2]

In [None]:
model_checkpoint = "google/vit-base-patch16-224-in21k"
# `cache_dir` was used here without being assigned anywhere in the notebook,
# which raises NameError on a fresh kernel. None selects the library's default
# cache location; set a path here to override it.
cache_dir = None
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint, cache_dir=cache_dir)

In [None]:
# The compute_metrics function takes a named tuple as input:
# predictions, which are the logits of the model as NumPy arrays,
# and label_ids, which are the ground-truth labels as NumPy arrays.
def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions.

    Accuracy is computed directly with NumPy; the function no longer depends on
    a notebook-level ``metric`` object, which was never created in this
    notebook.

    Args:
        eval_pred: Object with ``predictions`` (logits, one row per example)
            and ``label_ids`` (ground-truth class indices).

    Returns:
        ``{"accuracy": float}`` - the fraction of rows whose arg-max class
        equals the label.
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    references = np.asarray(eval_pred.label_ids)
    return {"accuracy": float((predictions == references).mean())}

def collate_fn(examples):
    """Assemble a list of examples into one model-ready batch.

    Args:
        examples: Sequence of dicts, each with a ``"pixel_values"`` tensor and
            a ``"label"`` value.

    Returns:
        dict with ``"pixel_values"`` (the tensors stacked along a new leading
        dimension) and ``"labels"`` (a tensor built from the label values).
    """
    images, targets = [], []
    for example in examples:
        images.append(example["pixel_values"])
        targets.append(example["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

In [None]:
# Derive the resize/crop geometry and the normalisation statistics from the
# image processor so the inputs match what the checkpoint expects.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    # NOTE: max_size is recorded but not consumed by the transforms below.
    max_size = image_processor.size.get("longest_edge")
else:
    # Fail here with a clear message instead of a NameError on `size` below.
    raise ValueError(
        f"Unsupported image_processor.size specification: {image_processor.size!r}"
    )

val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_val(example_batch):
    """Add a ``"pixel_values"`` column to a batch by transforming its images.

    Each entry of ``example_batch["image"]`` is converted to RGB and passed
    through ``val_transforms``. The batch is modified in place and returned.
    """
    rgb_images = (image.convert("RGB") for image in example_batch["image"])
    example_batch["pixel_values"] = list(map(val_transforms, rgb_images))
    return example_batch

In [None]:
# Register preprocess_val as the dataset's transform, so accessed examples
# carry a "pixel_values" entry (applied lazily on access, per the `datasets`
# documentation of set_transform).
dataset.set_transform(preprocess_val)

In [None]:
batch_size = 100
# Slice the dataset once: every access runs the image transform over the whole
# slice, so indexing twice would preprocess the same images twice.
first_examples = dataset[:batch_size]
labels = torch.tensor(first_examples['label'])
pixel_values = torch.stack(first_examples['pixel_values']).to(dtype).cuda()
# The single fixed batch that all benchmarks and hook inspections reuse.
input_batch = {'pixel_values': pixel_values, 'labels': labels}

In [None]:
# Confirm the input batch has the configured precision.
pixel_values.dtype

In [None]:
# Baseline configuration: the plain checkpoint with no adapters attached,
# cast to the benchmark precision.
reset_memory()
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
model = model.to(dtype)

In [None]:
# List every parameter with its dtype to verify the cast above.
for name, param in model.named_parameters():
    print(name, param.dtype)

In [None]:
# Print input/output shapes and dtypes of every Linear layer for one forward pass.
# Alternative, coarser targets:
# target_classes = (ViTEmbeddings, ViTSelfAttention, ViTSelfOutput, ViTIntermediate, ViTOutput)
# `(Linear)` is just `Linear` in parentheses; the trailing comma makes it a real
# tuple, matching the tuple used for the RunLoRA model later on.
target_classes = (Linear,)
hook_model(model, report_hook, target_classes)

In [None]:
# Print parameter counts of the baseline model; the returned tuple is discarded.
_ = report_params(model)

In [None]:
# Collect one result record per configuration; this first one is the baseline.
# NOTE(review): unlike the later records this one carries no 'lora' key, so the
# baseline row will have a missing value in that column of the results table.
rows = []
reset_memory()
stats = bench_model(model, input_batch, min_run_time=min_run_time)
rows.append(stats)

In [None]:
# Show the baseline measurements.
stats

In [None]:
# Drop the notebook's reference to the baseline model and release cached GPU
# memory before the next configuration is loaded.
del model
reset_memory()

In [None]:
# Log the currently allocated GPU memory (bytes / 2**20) after the cleanup.
logging.info(f'Allocated GPU Memory: {torch.cuda.memory_allocated() / 2**20}MB')

In [None]:
# Fresh copy of the checkpoint that will be wrapped with PEFT LoRA below.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)

In [None]:
# LoRA hyper-parameters shared by the PEFT and the RunLoRA configurations.
random_seed = 42
lora_r = 32
lora_alpha = 32
lora_dropout = 0.
# Name patterns of the layers that receive adapters.
target_modules = ['query', 'key', 'value', 'dense']

In [None]:
# Sanity check: for a square hidden_size x hidden_size projection, the two
# rank-`lora_r` adapter factors must hold fewer weights than the projection
# itself. The width is read from the model config instead of hard-coding 768.
hidden_size = model.config.hidden_size
assert hidden_size * hidden_size > 2 * hidden_size * lora_r, \
    f'lora_r={lora_r} is too large for hidden_size={hidden_size}'

In [None]:
# Seed torch's RNG before the adapters are created.
torch.manual_seed(random_seed)

# PEFT LoRA configuration; "classifier" is listed in modules_to_save
# (presumably so the classification head stays trainable next to the
# adapters - see the PEFT documentation).
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    modules_to_save=["classifier"],
)
model = get_peft_model(model, lora_config)
# Cast after wrapping so the adapter weights get the benchmark precision too.
model = model.to(dtype)

In [None]:
# Display the module structure of the PEFT-wrapped model.
model

In [None]:
# List every parameter of the PEFT model with its dtype.
for name, param in model.named_parameters():
    print(name, param.dtype)

In [None]:
# Keep the trainable-parameter count; it is compared with RunLoRA's count later.
_, lora_tr_params = report_params(model)

In [None]:
# Print input/output shapes and dtypes of every Linear layer of the PEFT model.
# Alternative, coarser targets:
# target_classes = (ViTEmbeddings, ViTSelfAttention, ViTSelfOutput, ViTIntermediate, ViTOutput)
# `(Linear)` is just `Linear` in parentheses; the trailing comma makes it a real
# tuple, matching the tuple used for the RunLoRA model later on.
target_classes = (Linear,)
hook_model(model, report_hook, target_classes)

In [None]:
# Benchmark the PEFT LoRA model and record it under the label 'peft'.
reset_memory()
stats = bench_model(model, input_batch, min_run_time=min_run_time)
rows.append({**stats, 'lora': 'peft'})

In [None]:
# Show the PEFT measurements.
stats

In [None]:
# Drop the notebook's reference to the PEFT model and release cached GPU memory.
del model
reset_memory()

In [None]:
# Log the currently allocated GPU memory (bytes / 2**20) after the cleanup.
logging.info(f'Allocated GPU Memory: {torch.cuda.memory_allocated() / 2**20}MB')

In [None]:
# Fresh copy of the checkpoint, handed to RunLoRACollection.optimize_for_model
# below and deleted again afterwards.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)

In [None]:
# Ask RunLoRA for a per-criterion mapping for this model; only the 'flops'
# criterion is requested, and `run_lora_mapping['flops']` is what the RunLoRA
# model is built from further down. The collection gets half of the benchmark's
# min_run_time.
run_lora_collection = RunLoRACollection(min_run_time=min_run_time/2)
run_lora_mapping = \
    run_lora_collection.optimize_for_model(
        model,
        n_batch=batch_size,
        lora_r=lora_r,
        target_modules=target_modules,
        criterions=['flops'],
        # 224 x 224 -> conv kernel_size=(16, 16), stride=(16, 16)
        # i.e. 14 * 14 = 196 patches; 197 presumably adds the [CLS] token - confirm.
        sequence_length=197)

In [None]:
# Only `run_lora_mapping` is needed from here on; drop the model and the
# collection and release cached GPU memory.
del model, run_lora_collection
reset_memory()

In [None]:
# RunLoRA configuration: reload the checkpoint, cast it to the benchmark
# precision and wrap it using the 'flops' mapping obtained above, with the same
# rank, alpha and target modules as the PEFT run.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
model = model.to(dtype)
model = RunLoRAModel(model,
                     run_lora_mapping['flops'],
                     lora_r=lora_r,
                     lora_alpha=lora_alpha,
                     lora_dtype=dtype,
                     target_modules=target_modules)

In [None]:
# Inspect the mapping selected under the 'flops' criterion.
run_lora_mapping['flops']

In [None]:
# Every parameter except for the LoRA adapters is set to requires_grad=False;
# 'classifier' is passed as modules_to_save, mirroring the PEFT configuration.
model.prepare_for_finetuning(modules_to_save=['classifier'])

In [None]:
# List every parameter of the RunLoRA model with its dtype.
for name, param in model.named_parameters():
    print(name, param.dtype)

In [None]:
# Keep the trainable-parameter count for the comparison with the PEFT model.
_, runlora_tr_params = report_params(model)

In [None]:
# Print input/output shapes and dtypes for both plain Linear and RunLoRALinear layers.
target_classes = (Linear, RunLoRALinear)
# Alternative, coarser targets:
# target_classes = (ViTEmbeddings, ViTSelfAttention, ViTSelfOutput, ViTIntermediate, ViTOutput)
hook_model(model, report_hook, target_classes)

In [None]:
# Display the module structure of the RunLoRA-wrapped model.
model

In [None]:
# Both adapter implementations must train the same number of parameters,
# otherwise the timing comparison would not be like-for-like.
assert lora_tr_params == runlora_tr_params

In [None]:
# Benchmark the RunLoRA model and record it under the label 'runlora'.
reset_memory()
stats = bench_model(model, input_batch, min_run_time=min_run_time)
rows.append({**stats, 'lora': 'runlora'})

In [None]:
# Drop the last model, log the peak reserved GPU memory recorded so far
# (bytes / 2**20), then release cached memory and reset the peak statistics.
del model
logging.info(f'Max GPU Memory Reserved: {torch.cuda.max_memory_reserved() / 2**20} MB')
reset_memory()

In [None]:
# One row per benchmarked configuration, fastest first; ties are ordered by
# allocated-memory overhead.
df = pd.DataFrame.from_records(rows).sort_values(
    ['mean_time_us', 'max_mem_overhead_MB'], ascending=[True, True])
df

In [None]:
# Relative time saving of the first row versus the second and the third row.
# Rows are addressed by position, so this assumes `df` is still sorted by
# mean_time_us (fastest first) as done in the previous cell.
1 - df.mean_time_us.iloc[0] / df.mean_time_us.iloc[1], 1 - df.mean_time_us.iloc[0] / df.mean_time_us.iloc[2]

In [None]:
# Short model name: the last path component of the checkpoint id, cut at "-patch".
model_name = model_checkpoint.rsplit('/', 1)[-1].partition('-patch')[0]
# The file name encodes the LoRA rank, batch size and precision tag.
df.to_csv(f'{model_name}_r{lora_r}b{batch_size}_{type_string}.csv')