In [1]:
import os
import argparse
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import random_split
from torchvision import datasets, transforms
from transformers import ViTImageProcessor, ViTForImageClassification
from datasets import load_dataset
from encoder_utils import build_faiss_index, predict_with_faiss, compute_topk_accuracy, CLIPClassifier
from tqdm import tqdm
from dotenv import load_dotenv
from gym import Trainer
import wandb
import clip
import lora_clip

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mandriy-suh[0m ([33mandriy-suh-private[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/asukh/.netrc


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load CLIP ViT-B/16 on the selected device. The call's two return values are
# unpacked as (model, preprocessing transform).
clip_vit16_model, clip_preprocess = clip.load("ViT-B/16", device=device)

In [3]:
# Root directory used for dataset downloads; created if missing and left
# untouched when it already exists.
path_data = "./data"
os.makedirs(name=path_data, exist_ok=True)

In [4]:
# CIFAR-100 train/test sets, preprocessed with the transform returned by clip.load.
cifar100_dataset_train = datasets.CIFAR100(root=path_data, train=True, download=True, transform=clip_preprocess)
cifar100_dataset_test = datasets.CIFAR100(root=path_data, train=False, download=True, transform=clip_preprocess)

# Hold out 20% of the training set for validation.
# The split is driven by a dedicated, seeded generator so that the same samples
# land in train/val on every kernel restart; without it the FAISS index and the
# validation metrics of separate runs are computed on different subsets.
train_size = int(len(cifar100_dataset_train) * 0.8)
val_size = len(cifar100_dataset_train) - train_size
split_generator = torch.Generator().manual_seed(42)
cifar100_train_x, cifar100_val_x = random_split(
    cifar100_dataset_train,
    [train_size, val_size],
    generator=split_generator,
)

# Only the training loader is shuffled; validation and test keep dataset order.
cifar100_loader_train = DataLoader(cifar100_train_x, batch_size=128, shuffle=True)
cifar100_loader_val = DataLoader(cifar100_val_x, batch_size=128, shuffle=False)
cifar100_loader_test = DataLoader(cifar100_dataset_test, batch_size=128, shuffle=False)

## ViT B/16: Zero-shot prediction (FAISS index over training-set embeddings)

In [7]:
# Build a FAISS index from the training loader using the pretrained CLIP model.
# The helper's two return values are unpacked as (labels, index); both are
# handed to predict_with_faiss when the test split is scored.
# NOTE(review): presumably each indexed vector is an image embedding stored
# alongside its class label — confirm in encoder_utils.
faiss_labels, faiss_index = build_faiss_index(
    dataloader=cifar100_loader_train,
    model=clip_vit16_model,
    device=device
)

Building FAISS Index: 100%|██████████| 625/625 [15:01<00:00,  1.44s/it]


FAISS index built with 40000 entries.


In [8]:
# Score the test loader against the FAISS index built from the training split.
# The two return values are unpacked as (ground_truth, predictions) and are
# consumed by compute_topk_accuracy.
# NOTE(review): top_k=5 presumably keeps five candidate labels per test sample,
# and distractor_classes=None presumably disables distractors — confirm in
# encoder_utils.
ground_truth, predictions = predict_with_faiss(
    dataloader=cifar100_loader_test,
    model=clip_vit16_model,
    faiss_index=faiss_index,
    faiss_labels=faiss_labels,
    device=device,
    top_k=5,
    distractor_classes=None
)

Predicting with FAISS: 100%|██████████| 157/157 [06:10<00:00,  2.36s/it]


In [9]:
# Top-1 accuracy of the FAISS predictions on the test split.
accuracy_top1 = compute_topk_accuracy(ground_truth, predictions, top_k=1)
print('top 1 accuracy', accuracy_top1)

top 1 accuracy 0.6912


In [10]:
# Top-2 accuracy computed from the same ground_truth / predictions pair.
accuracy_top2 = compute_topk_accuracy(ground_truth, predictions, top_k=2)
print('top 2 accuracy', accuracy_top2)

top 2 accuracy 0.8046


## ViT B/16: Train LoRA Adapter R=4

In [3]:
# Build the rank-4 LoRA model from the pretrained CLIP weights ('vision' is
# passed as the third argument) and move it to the target device.
r = 4
clip_vit16_lora_r4 = lora_clip.build_LoRA_model(clip_vit16_model.state_dict(), r, 'vision')
clip_vit16_lora_r4 = clip_vit16_lora_r4.to(device)

# Wrap the LoRA model in the classifier (fine_tune=True), cast it to float32
# and place it on the device.
clip_vit16_lora_r4_cls = CLIPClassifier(clip_vit16_lora_r4, fine_tune=True)
clip_vit16_lora_r4_cls = clip_vit16_lora_r4_cls.float().to(device)

# Report how many parameters are trainable.
lora_clip.print_trainable_parameters(clip_vit16_lora_r4_cls)

Model loaded
Unexpected keys: ['lora_text_projection']
 
trainable params: 1448548 || all params: 150524005 || trainable%: 0.9623368711189952


In [None]:
# Optimiser and loss for the rank-4 classifier: Adam at lr=1e-4 with
# cross-entropy.
optimizer = optim.Adam(clip_vit16_lora_r4_cls.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Trainer wiring: train/val loaders from the CIFAR-100 split, 10 epochs.
# NOTE(review): hf_repo_name / experiment_name / project_name are presumably
# used for Hugging Face uploads and W&B logging — confirm in gym.Trainer.
trainer = Trainer(
    model=clip_vit16_lora_r4_cls,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=cifar100_loader_train,
    val_loader=cifar100_loader_val,
    num_epochs=10,
    device=device,
    project_name="lora-project",
    experiment_name="clip_vit16_lora_r4",
    hf_repo_name="ansu0122/vit-lora"
)

In [None]:
# Train the model with the trainer configured above (num_epochs=10).
trainer.train()

In [None]:
# Mark the active Weights & Biases run as finished before the next experiment.
# NOTE(review): the run is presumably opened inside Trainer — confirm.
wandb.finish()

## ViT B/16: Train LoRA Adapter R=8

In [None]:
# Build a rank-8 LoRA model from the pretrained CLIP weights ('vision' is
# passed as the third argument) and move it to the target device.
# NOTE(review): the variables keep their `_r4` suffix although r is 8 here;
# later cells reference these exact names, so they are left as they are.
r = 8
clip_vit16_lora_r4 = lora_clip.build_LoRA_model(clip_vit16_model.state_dict(), r, 'vision')
clip_vit16_lora_r4 = clip_vit16_lora_r4.to(device)

# Wrap the LoRA model in the classifier (fine_tune=True), cast it to float32
# and place it on the device.
clip_vit16_lora_r4_cls = CLIPClassifier(clip_vit16_lora_r4, fine_tune=True)
clip_vit16_lora_r4_cls = clip_vit16_lora_r4_cls.float().to(device)

# Report how many parameters are trainable.
lora_clip.print_trainable_parameters(clip_vit16_lora_r4_cls)

In [None]:
# Setup the trainer
# Loss and optimiser for the classifier built in the previous cell:
# cross-entropy with Adam at lr=1e-4.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(clip_vit16_lora_r4_cls.parameters(), lr=1e-4)

# NOTE(review): hf_repo_name is the same repository the rank-4 run uses —
# confirm that Trainer keeps uploads from different experiments apart.
trainer = Trainer(
    model=clip_vit16_lora_r4_cls,
    train_loader=cifar100_loader_train,
    val_loader=cifar100_loader_val,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=10,
    hf_repo_name="ansu0122/vit-lora",
    # Derive the run name from the active LoRA rank so this run is not
    # labelled (and mixed up with) the rank-4 experiment.
    experiment_name=f"clip_vit16_lora_r{r}",
    project_name="lora-project"
)

In [None]:
# Train the model with the trainer configured above (num_epochs=10).
trainer.train()

In [None]:
# Mark the active Weights & Biases run as finished.
# NOTE(review): the run is presumably opened inside Trainer — confirm.
wandb.finish()

## Release Memory

In [None]:
import gc
import torch

# Drop every notebook-level name that still references the LoRA model's
# parameters: the trainer, the optimizer (it holds the parameter list), the
# classifier wrapper and the LoRA model itself. While any of them is alive the
# GPU memory cannot be returned.
# pop(..., None) keeps the cell re-runnable: names that are already gone are
# skipped instead of raising NameError.
for _stale in ("trainer", "optimizer", "clip_vit16_lora_r4_cls", "clip_vit16_lora_r4"):
    globals().pop(_stale, None)
del _stale

# Collect reference cycles first so their tensors are actually freed, then hand
# the cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [40]:
import torch
import gc
# Run a full garbage collection so unreachable objects are freed, then release
# PyTorch's cached, currently unused GPU memory.
gc.collect()
torch.cuda.empty_cache()

In [41]:
import gc
import torch


def release_global_tensors(namespace):
    """Remove every ``torch.Tensor`` binding from ``namespace``.

    Args:
        namespace: A mutable mapping of names to objects, e.g. ``globals()``.

    Returns:
        The sorted list of names that were removed.
    """
    # Collect the names first so the mapping is not mutated while it is being
    # scanned, and so no loop variable keeps a tensor alive afterwards.
    tensor_names = sorted(
        name for name, value in list(namespace.items())
        if isinstance(value, torch.Tensor)
    )
    for name in tensor_names:
        del namespace[name]
    return tensor_names


# Look the objects up directly instead of eval()-ing every global name, then
# collect garbage before asking the allocator to release its cache.
release_global_tensors(globals())
gc.collect()
torch.cuda.empty_cache()

In [16]:
# Show the driver-level view of GPU memory and utilisation.
!nvidia-smi

Sun Feb 16 23:24:41 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  |   00000000:62:00.0 Off |                  N/A |
| 30%   44C    P8             24W /  350W |    1548MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# Convert PyTorch's byte counters to MiB before printing them.
allocated_mb = torch.cuda.memory_allocated() / 1024 / 1024
reserved_mb = torch.cuda.memory_reserved() / 1024 / 1024

print(f"Memory allocated: {allocated_mb:.2f} MB")
print(f"Memory reserved: {reserved_mb:.2f} MB")

Memory allocated: 0.00 MB
Memory reserved: 0.00 MB
