In [1]:
from pathlib import Path

from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import open_clip

from Utils.dataset import CustomDataset

In [2]:
# Seed torch's global RNG so DataLoader shuffling (and any stochastic layers) is reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED);

# Prefer a GPU when one is present; every later cell moves tensors/modules with `.to(device)`,
# so nothing else needs to change. Falls back to CPU exactly as before.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Architecture and pretrained-weights tag. `pretrained` also accepts a local checkpoint path.
# The same hyphenated name is used for both the model and the tokenizer so they can never
# drift apart (open_clip only accepts the legacy 'ViT-B/32' spelling by rewriting '/' to '-').
MODEL_NAME = 'ViT-B-32'
PRETRAINED = 'laion2b_s34b_b79k'

# Returns (model, train-time transform, eval-time transform); the eval transform is used for
# both splits below, the train-time one is discarded.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name=MODEL_NAME,
    pretrained=PRETRAINED,
)

tokenizer = open_clip.get_tokenizer(MODEL_NAME)
model.to(device);  # trailing ';' suppresses the module repr in the cell output

In [5]:
# Build both splits from their folders, applying the same CLIP preprocessing transform to each.
train_dataset, test_dataset = (
    CustomDataset('Dataset/train/', preprocess),
    CustomDataset('Dataset/test/', preprocess),
)

In [6]:
# One batch size for both splits. With an in-batch contrastive loss, every other item in a
# batch acts as a negative, so this value also sets the number of negatives per example.
BATCH_SIZE = 4

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation must be deterministic: with a contrastive objective, batch composition decides
# which pairs are negatives, so shuffling the test split would make its metrics vary run to run.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# CLIP-style optimizer. With plain Adam, `weight_decay` is an L2 penalty folded into the
# gradient (and rescaled by the adaptive step), and 0.2 would also pull gains, biases and the
# learned logit_scale toward zero. AdamW applies *decoupled* weight decay instead, and, as in the
# CLIP recipe, only weight matrices (tensors with >= 2 dims) are decayed.
trainable_params = [p for p in model.parameters() if p.requires_grad]
decay_params = [p for p in trainable_params if p.ndim >= 2]
no_decay_params = [p for p in trainable_params if p.ndim < 2]

optimizer = torch.optim.AdamW(
    [
        {'params': decay_params, 'weight_decay': 0.2},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ],
    lr=5e-5,
    betas=(0.9, 0.98),
    eps=1e-6,
)

# Symmetric contrastive loss: one cross-entropy over image->text logits, one over text->image.
img_loss = nn.CrossEntropyLoss()
txt_loss = nn.CrossEntropyLoss()

In [8]:
# How many full passes the training loop makes over train_dataloader.
num_epochs: int = 5

In [None]:
# Be explicit about train mode in case an earlier (re-)run left the model in eval mode.
model.train()

for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    # The dataset's third field (class labels) is not used by the contrastive objective.
    for images, texts, _labels in pbar:
        optimizer.zero_grad()

        # Tokenize the whole batch of captions in one call.
        texts = tokenizer(list(texts)).to(device)
        images = images.to(device)

        # open_clip's CLIP.forward returns (image_features, text_features, logit_scale) with
        # L2-normalised features -- NOT the (logits_per_image, logits_per_text) pair returned by
        # OpenAI's `clip` package -- so the similarity logits must be built here.
        image_features, text_features, logit_scale = model(images, texts)
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        # Contrastive targets: the i-th image in the batch matches the i-th caption.
        ground_truth = torch.arange(len(images), device=device)
        total_loss = (img_loss(logits_per_image, ground_truth) + txt_loss(logits_per_text, ground_truth)) / 2

        total_loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss.item():.4f}")

    # Checkpoint after every epoch. The filename keeps the original 0-based epoch index and
    # 'RN_' prefix (despite the ViT backbone) so existing loaders keep working. This pickles the
    # whole module; saving `model.state_dict()` would be more portable.
    save_dir = Path('CLIP/Results')
    save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model, save_dir / f'RN_Model_{epoch}.pt')