In [None]:
import transformers
import accelerate
import peft

# Record library versions for reproducibility of this run.
for _label, _module in (
    ("Transformers", transformers),
    ("Accelerate", accelerate),
    ("PEFT", peft),
):
    print(f"{_label} version: {_module.__version__}")

In [None]:
# Hugging Face Hub id of the base ViT. Used below for the image processor, the
# transformers reference model (model_og), and to derive model_name for output
# and Hub repo naming.
# NOTE(review): the model actually wrapped with LoRA and trained is timm's
# 'vit_base_patch16_224.orig_in21k', not this checkpoint.
model_checkpoint = "google/vit-base-patch16-224-in21k"

In [None]:
from transformers import AutoImageProcessor

# Image processor that ships with the HF checkpoint.
# NOTE(review): in the visible notebook it is referenced only by commented-out
# code (the val Resize and the Trainer `tokenizer=` argument). The torchvision
# transforms below hard-code their own size and normalisation statistics.
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)

In [None]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Per-channel normalisation statistics (the standard ImageNet-1k values).
# NOTE(review): timm's default config for 'vit_base_patch16_224.orig_in21k' and
# the HF image processor for google/vit-base-patch16-224-in21k both normalise
# with mean = std = 0.5 per channel. Confirm these ImageNet values match what the
# loaded checkpoint was fine-tuned with.
ImageNet_mean = [0.485, 0.456, 0.406]
ImageNet_std = [0.229, 0.224, 0.225]
normalize = Normalize(mean=ImageNet_mean, std=ImageNet_std)

# Training pipeline: random-resized 224x224 crop + horizontal flip (augmentation),
# then tensor conversion and normalisation.
train_transforms = Compose([RandomResizedCrop(224), RandomHorizontalFlip(), ToTensor(), normalize])

# Evaluation pipeline: deterministic Resize(224) followed by a 224x224 centre crop.
# (Previously sketched alternative: Resize(image_processor.size["height"]).)
val_transforms = Compose([Resize(224), CenterCrop(224), ToTensor(), normalize])


def _attach_pixel_values(example_batch, transform):
    """Store ``transform(image.convert("RGB"))`` for each batch image under "pixel_values"."""
    pixel_values = []
    for img in example_batch["image"]:
        pixel_values.append(transform(img.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch


# NOTE(review): these batch-transform helpers (datasets-style `set_transform`
# callbacks) are not called in the visible notebook; ImageDataset receives the
# transforms directly instead.
def preprocess_train(example_batch):
    """Add augmented, normalised "pixel_values" to a batch dict (mutated and returned)."""
    return _attach_pixel_values(example_batch, train_transforms)


def preprocess_val(example_batch):
    """Add deterministic, normalised "pixel_values" to a batch dict (mutated and returned)."""
    return _attach_pixel_values(example_batch, val_transforms)

In [None]:
from loaders.ImageData import ImageDataset
# Root directory handed to ImageDataset; presumably it contains the "train" and
# "val" splits — TODO confirm against loaders.ImageData.
data_dir = "/bucket/npss/CottonPestClassification_v3a_os_lora/"

# NOTE: despite the "_loader" suffix these are dataset objects. They are passed
# to Trainer as train_dataset / eval_dataset, and Trainer does its own batching.
train_loader = ImageDataset(
    data_dir,
    split="train",
    transform=train_transforms,
)
val_loader = ImageDataset(
    data_dir,
    split="val",
    transform=val_transforms,
)
# Held-out "reporting" split, currently disabled:
# test_loader = ImageDataset(data_dir, split='reporting', transform=val_transforms)

# Class-name <-> index maps taken from the training split.
id2label = train_loader.id2label
label2id = train_loader.labels2id

In [None]:
# Unused alternative to ImageDataset: build the training split with timm's dataset factory.
# from timm.data import create_dataset

# dataset_train = create_dataset(
#         '',
#         root=data_dir,
#         split='train',
#         is_training=True,
#         batch_size=128,
#         seed=42,
#     )

In [None]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Args:
        model: any object with a ``named_parameters()`` method yielding
            ``(name, tensor)`` pairs, e.g. a ``torch.nn.Module``.

    Prints a single line
    ``trainable params: T || all params: A || trainable%: P`` and returns None
    (so a bare call at the end of a cell displays nothing extra).
    """
    all_param = 0
    trainable_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A model without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Reference transformers ViT classifier built from the same label maps.
# NOTE(review): in the visible notebook it is only referenced by the commented-out
# `# model_og,` line in the Trainer cell. Unless that toggle is used, loading it
# just costs memory.
model_og = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    # Tolerate a classification head whose shape differs from the checkpoint's
    # (e.g. when starting from an already fine-tuned checkpoint).
    ignore_mismatched_sizes=True,
)

In [None]:
from timm.models import create_model, load_checkpoint
import timm
# timm ViT-B/16 (in21k-pretrained per its tag). The head size comes from the
# dataset's label map, keeping it in sync with label2id/id2label instead of
# relying on a hard-coded class count.
model = timm.create_model(
    'vit_base_patch16_224.orig_in21k',
    pretrained=True,
    num_classes=len(id2label),
)

# Overwrite with the best weights from an earlier timm fine-tuning run.
# strict=False tolerates keys missing from / unexpected in the checkpoint.
# A tensor size mismatch (e.g. head size) still raises, per torch
# load_state_dict semantics.
# This call is the cell's last expression, so its return value is displayed.
# NOTE(review): recent timm versions return the incompatible-keys report; check it is empty.
# TODO: absolute user-specific path — move to a configurable constant.
timm_checkpoint_path = (
    '/home/ashishpapanai/timm-classify/output/train/'
    'vit_base_patch16_224.orig_in21k-timm-050524-OS/model_best.pth.tar'
)
load_checkpoint(model, timm_checkpoint_path, strict=False)

In [None]:
# Parameter count of the timm model before LoRA wrapping. This is the baseline
# for the count printed after get_peft_model in the next cell.
print_trainable_parameters(model)

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapters on the fused attention projection of every block (timm names
# it `attn.qkv`).
# A timm VisionTransformer's classifier is the module `head`; it has no
# `classifier` module. `head` must be listed in modules_to_save, otherwise the
# head stays frozen and is left out of the saved/pushed adapter.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["qkv"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["head"],
)

# Fail loudly if the head module is not where we expect. PEFT silently ignores
# modules_to_save entries that match nothing.
assert "head" in {name for name, _ in model.named_modules()}, "timm model has no `head` module"

# NOTE(review): get_peft_model injects the adapters into `model` in place (per
# PEFT docs). Re-run the model-creation cell before re-running this one.
lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

In [None]:
from transformers import TrainingArguments, Trainer


# Base name for run artifacts.
# NOTE(review): derived from the HF checkpoint id, although the model trained is
# timm's vit_base_patch16_224.orig_in21k.
model_name = model_checkpoint.split("/")[-1]
batch_size = 128

# The output directory is named after the dataset actually trained on. The old
# "-food101" suffix was a tutorial leftover. With push_to_hub=True and no
# hub_model_id, the Trainer also uses this directory name as the Hub repo name,
# so it now matches the repo pushed at the end of the notebook.
args = TrainingArguments(
    f"{model_name}-finetuned-lora-CottonPestClassification_v3a_npss",
    # Keep every field the dataset yields so collate_fn can read it.
    remove_unused_columns=False,
    # NOTE(review): renamed `eval_strategy` in newer transformers releases.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-3,
    per_device_train_batch_size=batch_size,
    # Effective train batch per device = batch_size * 4.
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    # Key returned by compute_metrics (reported as "eval_accuracy").
    metric_for_best_model="accuracy",
    push_to_hub=True,
    # Explicit, since the Trainer cannot infer label names from a timm model's forward().
    label_names=["labels"],
    report_to="none"
)

In [None]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score a Trainer evaluation pass with the module-level ``metric``.

    Args:
        eval_pred: Trainer evaluation output. ``predictions`` holds per-class
            scores with classes along axis 1; ``label_ids`` holds the targets.

    Returns:
        Whatever ``metric.compute`` returns for the arg-max class predictions.
    """
    scores, targets = eval_pred.predictions, eval_pred.label_ids
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=targets)

In [None]:
import torch

def collate_fn(examples):
    """Merge dataset items into the batch dict handed to the model/Trainer.

    Args:
        examples: sequence of dicts, each with an image tensor under "image" and
            a label under "labels" (key names presumably set by ImageDataset —
            confirm against loaders.ImageData).

    Returns:
        ``{"pixel_values": stacked image tensors, "labels": label tensor}``.
    """
    batch = {
        "pixel_values": torch.stack([ex["image"] for ex in examples]),
        "labels": torch.tensor([ex["labels"] for ex in examples]),
    }
    return batch

In [None]:
import torch.nn.functional as F


class TimmTrainer(Trainer):
    """Trainer whose loss computation works for timm models.

    The stock Trainer calls ``model(**inputs)``, i.e.
    ``model(pixel_values=..., labels=...)``. A timm model's forward takes a
    single image tensor and returns raw logits, and PeftModel passes the
    arguments through unchanged, so that call raises TypeError. This override
    feeds ``pixel_values`` positionally and computes cross-entropy against
    ``labels`` itself.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments that newer transformers versions pass
        # (e.g. num_items_in_batch).
        outputs = model(inputs["pixel_values"])
        # timm returns a logits tensor; HF models (e.g. model_og) return an
        # output object carrying .logits.
        logits = getattr(outputs, "logits", outputs)
        loss = F.cross_entropy(logits, inputs["labels"])
        # During evaluation the Trainer reads the logits from the dict part.
        return (loss, {"logits": logits}) if return_outputs else loss


trainer = TimmTrainer(
    lora_model,
    # model_og,
    args,
    train_dataset=train_loader,
    eval_dataset=val_loader,
    # tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)
train_results = trainer.train()

In [None]:
# Hub repository (under the ashishp-wiai namespace) for the trained LoRA model.
repo_name = f"ashishp-wiai/{model_name}-finetuned-lora-CottonPestClassification_v3a_npss"
# Upload lora_model. Assumes the kernel is authenticated with the Hub — TODO confirm.
# NOTE(review): per the PEFT docs this uploads the adapter weights (plus any
# modules_to_save), not the full base model.
lora_model.push_to_hub(repo_name)