In [1]:
import os
import argparse
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import random_split
import clip
import loraclip
from torchvision import datasets, transforms
from transformers import ViTImageProcessor, ViTForImageClassification
from datasets import load_dataset
from encoder_utils import build_faiss_index, predict_with_faiss, compute_topk_accuracy, CLIPClassifier
from tqdm import tqdm
from dotenv import load_dotenv
from gym import Trainer
import wandb

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mandriy-suh[0m ([33mandriy-suh-private[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/asukh/.netrc


In [2]:
# Select the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pre-trained CLIP ViT-B/16 model on that device. The second return
# value is used below as the dataset `transform`.
clip_vit16_model, clip_preprocess = clip.load("ViT-B/16", device=device)

In [3]:
# Directory used as the dataset root; create it if it does not exist yet.
path_data = "./data"
os.makedirs(name=path_data, exist_ok=True)

In [4]:
# CIFAR-100 train and test sets, both run through the CLIP preprocessing transform.
cifar100_dataset_train = datasets.CIFAR100(root=path_data, train=True, download=True, transform=clip_preprocess)
cifar100_dataset_test = datasets.CIFAR100(root=path_data, train=False, download=True, transform=clip_preprocess)

# Hold out 20% of the official training set for validation.
train_size = int(len(cifar100_dataset_train) * 0.8)
val_size = len(cifar100_dataset_train) - train_size

# Seed the split explicitly. Without a fixed generator the train/validation
# membership changes on every kernel restart, so results are not reproducible
# and a model resumed in a new session may be validated on samples it has
# already trained on.
split_generator = torch.Generator().manual_seed(42)
cifar100_train_x, cifar100_val_x = random_split(
    cifar100_dataset_train,
    [train_size, val_size],
    generator=split_generator,
)

# Only the training loader is shuffled; validation and test keep a fixed order.
cifar100_loader_train = DataLoader(cifar100_train_x, batch_size=64, shuffle=True)
cifar100_loader_val = DataLoader(cifar100_val_x, batch_size=64, shuffle=False)
cifar100_loader_test = DataLoader(cifar100_dataset_test, batch_size=64, shuffle=False)

# Sanity check: sizes of the train / validation / test splits.
print(len(cifar100_train_x))
print(len(cifar100_val_x))
print(len(cifar100_dataset_test))

Files already downloaded and verified
Files already downloaded and verified
40000
10000
10000


## ViT B/16: Zero-shot prediction (pre-trained CLIP + FAISS index built from the training split)

In [None]:
# Build the FAISS index (and the matching label array) from the training
# loader, using the pre-trained CLIP model.
# NOTE(review): the saved output below reports 782 batches / 50000 entries,
# whereas the training split printed earlier has 40000 samples. The output
# looks stale (from a run on the full training set); re-run to confirm.
faiss_labels, faiss_index = build_faiss_index(
    model=clip_vit16_model,
    dataloader=cifar100_loader_train,
    device=device,
)

Building FAISS Index: 100%|██████████| 782/782 [41:41<00:00,  3.20s/it]

FAISS index built with 50000 entries.





In [None]:
# Query the FAISS index with the test set, requesting 5 candidates per sample
# and no distractor classes.
ground_truth, predictions = predict_with_faiss(
    model=clip_vit16_model,
    dataloader=cifar100_loader_test,
    faiss_index=faiss_index,
    faiss_labels=faiss_labels,
    top_k=5,
    distractor_classes=None,
    device=device,
)

In [14]:
# Top-1 accuracy of the FAISS predictions on the test set.
accuracy_top1 = compute_topk_accuracy(
    ground_truth,
    predictions,
    top_k=1,
)
print("top 1 accuracy", accuracy_top1)

top 1 accuracy 0.9354


In [15]:
# Top-2 accuracy of the FAISS predictions on the test set.
accuracy_top2 = compute_topk_accuracy(
    ground_truth,
    predictions,
    top_k=2,
)
print("top 2 accuracy", accuracy_top2)

top 2 accuracy 0.9747


## ViT B/16: Train CLIP Backbone

In [5]:
# Wrap the CLIP backbone in the classifier with fine-tuning enabled.
# clip.load() keeps the weights in half precision on CUDA (it only casts to
# float32 on CPU). Optimising fp16 weights directly with Adam is numerically
# unstable, so cast the backbone to float32 first. On CPU this is a no-op.
clip_vit16_cl_model = CLIPClassifier(clip_vit16_model.float(), fine_tune=True).to(device)

In [6]:
# Cross-entropy loss over the class logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
# NOTE(review): the run summary saved at the end of this notebook shows
# Train/Loss 4.605 (about ln(100), i.e. chance for 100 classes) and accuracy
# of about 1.8%. lr=1e-4 may be too high for fine-tuning the whole backbone;
# confirm with a longer run or a smaller learning rate.
optimizer = optim.Adam(params=clip_vit16_cl_model.parameters(), lr=1e-4)

# Bundle model, data, loss, optimizer and experiment-tracking settings.
trainer = Trainer(
    model=clip_vit16_cl_model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=cifar100_loader_train,
    val_loader=cifar100_loader_val,
    num_epochs=10,
    device=device,
    project_name="lora-project",
    experiment_name="clip_vit16_full",
    hf_repo_name="ansu0122",
)

In [None]:
# Start training with the Trainer configured above (num_epochs=10 was requested).
trainer.train()

In [8]:
# Finish the active Weights & Biases run; the run history/summary is printed below.
wandb.finish()

0,1
Train/Accuracy,▁█
Train/Loss,█▁

0,1
Train/Accuracy,0.01847
Train/Loss,4.60496
