In [1]:
import json
from PIL import Image

from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import clip
from transformers import CLIPProcessor, CLIPModel

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
json_path = '../data/data_train.json'
image_path = '../data/images/train/'

# The annotation file is read as JSON Lines: one JSON object per line.
# - The encoding is pinned to UTF-8 so parsing does not depend on the
#   platform's locale default.
# - Whitespace-only lines (e.g. a trailing newline at end of file) are skipped
#   instead of raising json.JSONDecodeError.
with open(json_path, 'r', encoding='utf-8') as f:
    input_data = [json.loads(line) for line in f if line.strip()]

In [2]:
# Run on the first CUDA device when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Pre-trained CLIP ViT-B/32 from the `clip` package, loaded with jit=False.
# `preprocess` is the image transform returned alongside the model.
# (Hugging Face alternative, not used here:
#  CLIPModel / CLIPProcessor .from_pretrained("openai/clip-vit-base-patch32"))
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [3]:
class image_label_dataset():
    """Map-style dataset pairing image files with CLIP-tokenized text.

    All texts are tokenized once, at construction time; images are opened and
    preprocessed lazily, one per ``__getitem__`` call.

    Relies on the notebook-level names ``clip``, ``preprocess`` and ``Image``.
    """

    def __init__(self, list_image_path, list_txt):
        """Store the image paths and tokenize the texts.

        Args:
            list_image_path: int-indexable sequence of image file paths.
            list_txt: texts passed to ``clip.tokenize``; presumably aligned
                index-for-index with ``list_image_path`` (not checked here).
        """
        self.image_path = list_image_path
        # Tokenize every text up front with CLIP's tokenizer.
        self.label = clip.tokenize(list_txt)

    def __len__(self):
        """Return the number of tokenized texts."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, tokenized_text)`` for position ``idx``."""
        # Open the file in a context manager so its handle is released as soon
        # as preprocessing is done, rather than whenever the image object is
        # garbage-collected. Preprocessing must happen inside the block because
        # Pillow reads pixel data lazily.
        with Image.open(self.image_path[idx]) as img:
            image = preprocess(img)
        label = self.label[idx]
        return image, label

In [4]:
# Rebuild each image location from the bare file name stored in the record
# (any directory prefix in the annotation is discarded) and collect the
# matching labels, both in record order.
list_image_path = [
    image_path + record['image_path'].rsplit('/', 1)[-1]
    for record in input_data
]
list_txt = [record['label'] for record in input_data]

In [5]:
# Wrap the (path, text) lists in the dataset, then serve them in shuffled
# mini-batches of four.
train_dataset = image_label_dataset(
    list_image_path=list_image_path,
    list_txt=list_txt,
)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
)

In [6]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if present, to float32 in place.

    Parameters whose ``grad`` is ``None`` (for example frozen weights, or when
    this is called before any backward pass) only have their data cast;
    previously such a parameter raised ``AttributeError``.

    Args:
        model: object whose ``parameters()`` yields tensors exposing ``.data``
            and ``.grad``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [7]:
# Adam over all model parameters. The learning rate is kept small, which is
# safer when fine-tuning on a new dataset.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=5e-5,
    betas=(0.9, 0.98),
    eps=1e-6,
    weight_decay=0.2,
)

# One cross-entropy criterion for the image logits and one for the text logits.
loss_img, loss_txt = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()


In [8]:
torch.cuda.empty_cache()

# Train the model
num_epochs = 30
for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in pbar:
        images, texts = batch

        # A batch holding a single image/text pair yields 1x1 logits, so the
        # cross-entropy below is identically 0 and its gradient is zero. An
        # optimizer step would then be driven only by weight decay and Adam's
        # momentum, and the progress bar would report a meaningless 0.0000.
        # Skip such batches (typically the trailing partial batch).
        # NOTE(review): the recorded run shows "Loss: 0.0000" at the end of
        # every epoch, which is consistent with a size-1 final batch — confirm
        # against len(train_dataset) % batch_size.
        if len(images) < 2:
            continue

        optimizer.zero_grad()

        images = images.to(device)
        texts = texts.to(device)

        # Forward pass
        logits_per_image, logits_per_text = model(images, texts)

        # Compute loss: the i-th image matches the i-th text, so the target
        # class for row i is i, in both directions; the two losses are averaged.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Backward pass
        total_loss.backward()
        if device == "cpu":
            optimizer.step()
        else:
            # Off-CPU: cast weights and grads to fp32 for the optimizer step,
            # then hand the model back to clip.model.convert_weights
            # (presumably restoring CLIP's reduced-precision weights — confirm
            # against the clip package).
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")

Epoch 0/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.80it/s]
Epoch 1/30, Loss: 0.0000: 100%|██████████| 170/170 [00:28<00:00,  5.88it/s]
Epoch 2/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.86it/s]
Epoch 3/30, Loss: 0.0000: 100%|██████████| 170/170 [00:28<00:00,  5.89it/s]
Epoch 4/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.83it/s]
Epoch 5/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.86it/s]
Epoch 6/30, Loss: 0.0000: 100%|██████████| 170/170 [00:28<00:00,  5.87it/s]
Epoch 7/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.73it/s]
Epoch 8/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.84it/s]
Epoch 9/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.79it/s]
Epoch 10/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.84it/s]
Epoch 11/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.81it/s]
Epoch 12/30, Loss: 0.0000: 100%|██████████| 170/170 [00:29<00:00,  5.77it/s]
Epoch 13/

In [9]:
import torch

# Persist only the model's state dict (not the full module) to "vit_model.pth".
torch.save(obj=model.state_dict(), f="vit_model.pth")