## Installs

In [None]:
# Upgrade PyTorch/torchvision and install the Hugging Face libraries plus the
# notebook progress-bar dependencies (ipywidgets backs `tqdm.notebook`).
# NOTE(review): hf_transfer is only used by the Hub client when the environment
# variable HF_HUB_ENABLE_HF_TRANSFER=1 is set; no visible cell sets it — confirm
# whether this package is actually needed.
!pip install -U torch torchvision 
!pip install transformers datasets tqdm ipywidgets hf_transfer

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import CLIPModel
from datasets import load_from_disk
from tqdm.notebook import tqdm
import copy
import os

## Instantiate

In [None]:
# All training happens on the GPU; this notebook assumes CUDA is available.
device = "cuda"

# Pretrained CLIP ViT-B/32 in full precision. It is never trained directly:
# each fine-tuning run starts from a deep copy of this model.
# (nn.Module.to moves the weights in place and returns the module itself.)
reference = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    dtype=torch.float32,
).to(device)

# Dataset folder names; each is expected at data/<name>/{train,val}.
dataset_name = [
    "DTD",
    "EuroSAT",
    "GTSRB",
    "MNIST",
    "RESISC45",
    "Stanford_Cars",
    "SUN397",
    "SVHN",
]

In [None]:
# Training budget for a single fine-tuning run:
#   EPOCHS   - maximum number of epochs before the run ends unconditionally;
#   patience - consecutive epochs without a validation-loss improvement that
#              are tolerated before the run is stopped early.
EPOCHS, patience = 500, 30

## Fine-Tune Loop

In [None]:
# Inputs consumed by CLIPModel: used both to format the datasets and to move batches to `device`.
_CLIP_COLUMNS = ("pixel_values", "input_ids", "attention_mask")
# DataLoader settings shared by the train and val splits (only `shuffle` differs).
_LOADER_KWARGS = dict(batch_size=32, num_workers=12, pin_memory=True, persistent_workers=True)
# Independent fine-tuning runs per dataset; run i is saved as CLIP_<name>_<i>.pt.
RUNS_PER_DATASET = 5


def _to_device(batch):
    """Return the CLIP input tensors of `batch`, moved onto `device`, as model keyword arguments."""
    return {key: batch[key].to(device, non_blocking=True) for key in _CLIP_COLUMNS}


def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over `loader` and return the mean per-batch training loss.

    The progress bar is created with leave=False so finished bars are removed
    instead of accumulating one notebook widget per epoch.
    """
    model.train()
    running_loss = 0.0
    for batch in tqdm(loader, total=len(loader), leave=False):
        optimizer.zero_grad()
        loss = model(**_to_device(batch), return_loss=True).loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


@torch.no_grad()
def _evaluate(model, loader):
    """Return the mean per-batch CLIP loss (return_loss=True) over `loader`, without tracking gradients."""
    model.eval()
    running_loss = 0.0
    for batch in tqdm(loader, total=len(loader), leave=False):
        running_loss += model(**_to_device(batch), return_loss=True).loss.item()
    return running_loss / len(loader)


for name in dataset_name:
    train_dataset = load_from_disk(f"data/{name}/train")
    val_dataset = load_from_disk(f"data/{name}/val")

    train_dataset.set_format(type="torch", columns=list(_CLIP_COLUMNS))
    val_dataset.set_format(type="torch", columns=list(_CLIP_COLUMNS))

    train_loader = DataLoader(train_dataset, shuffle=True, **_LOADER_KWARGS)
    val_loader = DataLoader(val_dataset, shuffle=False, **_LOADER_KWARGS)

    os.makedirs(f"fine_tuned/{name}", exist_ok=True)
    for indice in range(RUNS_PER_DATASET):
        # Release the previous run's model, optimizer (AdamW moment buffers) and
        # gradients *before* deep-copying a fresh model, so two full models are
        # never resident on the GPU at the same time.
        model = optimizer = scheduler = None
        torch.cuda.empty_cache()

        # Fresh copy of the pretrained weights; freeze everything, then unfreeze
        # only the vision tower (text tower and logit scale stay fixed).
        model = copy.deepcopy(reference)
        model.requires_grad_(False)
        model.vision_model.requires_grad_(True)

        optimizer = torch.optim.AdamW(model.vision_model.parameters(), lr=1e-5, weight_decay=0.01)
        # Cosine decay over the full EPOCHS budget (experimental: remove if it does not help).
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        loss_stop_counter = 0
        best_val_loss = float("inf")

        for epoch in range(EPOCHS):
            avg_train_loss = _train_one_epoch(model, train_loader, optimizer)
            avg_val_loss = _evaluate(model, val_loader)

            scheduler.step()

            print(f"EPOCH: {epoch}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

            if avg_val_loss < best_val_loss:
                # New best: reset patience and checkpoint the full model state.
                best_val_loss = avg_val_loss
                loss_stop_counter = 0
                torch.save(model.state_dict(), f"fine_tuned/{name}/CLIP_{name}_{indice}.pt")
                print(f"Current Best EPOCH: {epoch} | Loss: {best_val_loss}")
            else:
                loss_stop_counter += 1
                # The best epoch was exactly `loss_stop_counter` epochs ago.
                print(f"Best EPOCH: {epoch - loss_stop_counter} | Loss: {best_val_loss}")

            if loss_stop_counter >= patience:
                print(f"Early stoppage on epoch {epoch}!")
                break