In [1]:
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util

import clip
import open_clip

from Utils.dataset import CustomDataset

  from tqdm.autonotebook import tqdm, trange


In [2]:
batch_size = 16
device = torch.device('cuda')

In [3]:
model, preprocess = clip.load("ViT-B/32", device=device)

In [4]:
train_dataset = CustomDataset('Dataset/train/')
test_dataset = CustomDataset('Dataset/test/')

In [5]:
train_dataloader = DataLoader(train_dataset, 
                              batch_size=batch_size,
                              shuffle=True)
test_dataloader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=True)

In [7]:
num_epochs = 10
temperature = 0.1

optimizer = torch.optim.Adam(model.parameters(), 
                             lr = 0.00001,
                             weight_decay=0.2,
                             betas=(0.9, 0.98),
                             eps = 0.001)

criterion = nn.CrossEntropyLoss()

In [13]:
model.train()
for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in pbar:
        optimizer.zero_grad()

        images, texts = batch

        imgs = torch.stack([preprocess(Image.open(image)) for image in images]).to(device)
        texs = clip.tokenize(texts).to(device)
        
        image_emb = model.encode_image(imgs)
        text_emb = model.encode_text(texs)

        logits = torch.matmul(image_emb, text_emb.T) * torch.exp(torch.tensor(temperature))
        ground_truth = torch.arange(len(batch[0]), device=device, dtype=torch.long)

        image_loss = criterion(logits, ground_truth)
        text_loss = criterion(logits.T, ground_truth)
        loss = (image_loss + text_loss) / 2
        
        loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {loss.item():.4f}")
    
    torch.save(model, f'Results/RN_Model_{epoch}.pt')

Epoch 0/10, Loss: 0.0022: 100%|██████████| 878/878 [36:40<00:00,  2.51s/it]
Epoch 1/10, Loss: 0.0006: 100%|██████████| 878/878 [36:05<00:00,  2.47s/it]
Epoch 2/10, Loss: 0.8809: 100%|██████████| 878/878 [36:32<00:00,  2.50s/it]
Epoch 3/10, Loss: 0.0166: 100%|██████████| 878/878 [35:21<00:00,  2.42s/it]
Epoch 4/10, Loss: 0.8491: 100%|██████████| 878/878 [35:23<00:00,  2.42s/it]
Epoch 5/10, Loss: 0.0002: 100%|██████████| 878/878 [35:12<00:00,  2.41s/it]
Epoch 6/10, Loss: 0.6509: 100%|██████████| 878/878 [35:09<00:00,  2.40s/it]
Epoch 7/10, Loss: 0.7090: 100%|██████████| 878/878 [35:18<00:00,  2.41s/it]
Epoch 8/10, Loss: 0.1095: 100%|██████████| 878/878 [35:17<00:00,  2.41s/it]
Epoch 9/10, Loss: 0.7095: 100%|██████████| 878/878 [35:09<00:00,  2.40s/it]
