### Fine-tuning CLIP

#### Imports

In [106]:
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

import clip
from transformers import CLIPProcessor, CLIPModel, BertTokenizer

#### Configs

In [107]:
class CFG:
    """Static hyper-parameters and runtime settings for the fine-tuning run."""

    # Pretrained checkpoint loaded for both the model and its processor.
    base_model = "openai/clip-vit-base-patch32"

    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Optimisation schedule.
    num_epochs = 1
    batch_size = 100
    learning_rate = 5e-5
    weight_decay = 0.0

    # Classification loss object kept alongside the other settings.
    loss_func = nn.CrossEntropyLoss()

#### Load and Prep Data

In [108]:
# Convert each PIL image to a float tensor with values in [0, 1].
# No Normalize step is applied here: shifting to [-1, 1] and then mapping
# straight back with x / 2 + 0.5 is a no-op, and a Lambda wrapping a Python
# lambda cannot be pickled if DataLoader worker processes are ever enabled
# on spawn-based platforms.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 training split; downloaded into ./data on first use.
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)

# Human-readable label names, indexed by the integer class id.
class_names = dataset.classes

Files already downloaded and verified


#### Initialize CLIP model

In [109]:
# Both objects come from the same checkpoint (CFG.base_model): the processor
# prepares text/image inputs, the model holds the pretrained weights.
processor = CLIPProcessor.from_pretrained(CFG.base_model)
model = CLIPModel.from_pretrained(CFG.base_model)

#### Train

In [None]:

# Fine-tune all CLIP parameters using the loss the model returns itself
# (requested with return_loss=True below).
model.to(CFG.device)
model.train()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CFG.learning_rate,
    weight_decay=CFG.weight_decay,
)

# Per-step training loss, kept so the curve can be inspected afterwards.
iter_losses = []
for epoch in range(1, CFG.num_epochs + 1):

    pbar = tqdm(dataloader, total=len(dataloader))
    for idx, batch in enumerate(pbar):
        optimizer.zero_grad()

        images, labels = batch
        # The caption paired with each image is simply its class name.
        # NOTE(review): with far more images per batch than classes, captions
        # repeat within a batch; the built-in contrastive loss presumably
        # treats those repeats as negatives -- confirm this is intended.
        texts = [class_names[i] for i in labels.tolist()]

        # The dataset transform yields float images already scaled to [0, 1].
        # The CLIP image processor rescales by 1/255 by default, which would
        # scale the pixels a second time, so hand it 8-bit values instead.
        # Assumes float batches are in [0, 1] -- TODO confirm if the
        # transform pipeline changes.
        if torch.is_floating_point(images):
            processor_images = (images * 255).round().clamp(0, 255).to(torch.uint8)
        else:
            processor_images = images

        inputs = processor(text=texts, images=processor_images,
                           return_tensors='pt', padding=True)
        # Keep the returned object rather than relying on in-place mutation.
        inputs = inputs.to(CFG.device)

        outputs = model(**inputs, return_loss=True)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        iter_losses.append(loss.item())