In [2]:
from tqdm import tqdm

import clip
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizerFast

from Utils.dataset import CustomDataset

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [4]:
train_dataset = CustomDataset('Dataset/train/')
test_dataset = CustomDataset('Dataset/test/')

In [5]:
train_dataloader = DataLoader(train_dataset, 
                              batch_size=4,
                              shuffle=True)
test_dataloader = DataLoader(test_dataset,
                             batch_size=4,
                             shuffle=True)

In [None]:
model_name = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
tokenizer = CLIPTokenizerFast.from_pretrained(model_name)

In [6]:
optimizer = torch.optim.Adam(model.parameters(), 
                             lr = 5e-5,
                             betas = (0.9, 0.98),
                             eps = 1e-6,
                             weight_decay=0.2)
img_loss = nn.CrossEntropyLoss()
txt_loss = nn.CrossEntropyLoss()

In [7]:
num_epochs = 5
device = torch.device('cpu')

In [13]:
for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total= len(train_dataloader))
    for batch in pbar:
        optimizer.zero_grad()

        images, texts, labels = batch
        texts = torch.cat([clip.tokenize(text) for text in texts])

        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)
        
        ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
        total_loss = (img_loss(logits_per_image,ground_truth) + txt_loss(logits_per_text,ground_truth))/2

        total_loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")

Epoch 0/5, Loss: 1.3863:   2%|▏         | 208/10500 [14:06<10:53:00,  3.81s/it]