In [2]:
import os
import json
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import clip
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
from sklearn.metrics import accuracy_score, confusion_matrix
import numpy as np
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt

In [3]:
# JSON manifests for the train / validation splits. Each manifest is a list whose
# entries carry an image path and that image's class label.
json_path_train = '/Users/asmitasengupta/finproj/image_data_train.json'
json_path_val = '/Users/asmitasengupta/finproj/image_data_val.json'
#image_path = '/Users/asmitasengupta/Downloads/Birds_25/train'

# Start from empty lists so both names exist before the files are read.
input_data, val_data = [], []

with open(json_path_train, 'r') as train_manifest:
    input_data = json.load(train_manifest)

with open(json_path_val, 'r') as val_manifest:
    val_data = json.load(val_manifest)


In [4]:
# Hugging Face CLIP checkpoint (ViT-B/32) and its paired processor.
# NOTE(review): `model` is rebound by the `clip.load(...)` call in a later cell,
# so this instance is replaced before training starts - confirm it is still needed.
hf_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(hf_checkpoint)
processor = CLIPProcessor.from_pretrained(hf_checkpoint)

In [5]:
# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [6]:
# Load the pre-trained ViT-B/32 CLIP weights onto `device` (non-JIT build),
# together with the matching image preprocessing transform.
model, preprocess = clip.load(
    "ViT-B/32",
    device=device,
    jit=False,
)

In [7]:
# Map-style dataset pairing each image with its tokenized caption
class image_title_dataset(torch.utils.data.Dataset):
    """Image/caption pairs for CLIP fine-tuning.

    Captions are tokenized once, at construction time, with ``clip.tokenize``.
    Images are opened lazily, one item at a time, and passed through the
    module-level CLIP ``preprocess`` transform.

    Args:
        list_image_path: Sequence of image file paths.
        list_txt: Captions aligned index-for-index with ``list_image_path``.
    """

    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path      # image paths, indexed like the captions
        self.title = clip.tokenize(list_txt)   # tokenized captions

    def __len__(self):
        """Number of samples, i.e. the number of tokenized captions."""
        return len(self.title)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, tokenized_caption)`` for sample ``idx``."""
        # The context manager releases the image file handle as soon as the
        # transform has produced its output, instead of leaving it to the GC.
        with Image.open(self.image_path[idx]) as raw_image:
            image = preprocess(raw_image)
        return image, self.title[idx]

In [8]:
# Training split: collect image paths and class-label captions from the manifest,
# then wrap them in a shuffled loader.
list_image_path = [entry['Image_path'] for entry in input_data]
list_txt = [entry['Class_label'] for entry in input_data]

dataset = image_title_dataset(list_image_path, list_txt)
train_dataloader = DataLoader(dataset, batch_size=50, shuffle=True)

# Validation split: same construction, but iterated in a fixed order.
list_image_path_v = [entry['Image_path'] for entry in val_data]
list_txt_v = [entry['Class_label'] for entry in val_data]

val_dataset = image_title_dataset(list_image_path_v, list_txt_v)
val_dataloader = DataLoader(val_dataset, batch_size=50, shuffle=False)

In [9]:
# Cast a model's parameters (and any existing gradients) to FP32 so the optimizer step runs in full precision.
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if present, to float32 in place.

    Parameters that have no gradient yet (for example frozen parameters, or
    parameters that took no part in the last backward pass) have
    ``grad is None``; they are converted without touching ``.grad``.

    Args:
        model: A ``torch.nn.Module`` whose parameters should be float32.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Guard against grad-less parameters; dereferencing None here would raise AttributeError.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [10]:
# On CPU the whole model is kept in float32.
if device == "cpu":
    model.float()

# Adam over every model parameter, with the hyperparameters used for fine-tuning.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=5e-5,
    betas=(0.9, 0.98),
    eps=1e-6,
    weight_decay=0.2,
)

# Separate cross-entropy criteria: one for the per-image logits, one for the per-text logits.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

In [None]:
## Train the model

num_epochs = 10  # No. of training epochs
checkpoint_path = 'models/trained_model.pth'

# torch.save does not create missing parent directories, so make sure it exists up front.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))

    # Iterate through the batches in the training data
    for batch in pbar:
        optimizer.zero_grad()  # Zero out gradients for the optimizer
        images, texts = batch  # Extract images and texts from the batch

        # Move images and texts to the specified device (CPU or GPU)
        images = images.to(device)
        texts = texts.to(device)

        # Forward pass through the model
        logits_per_image, logits_per_text = model(images, texts)

        # Targets are the batch positions 0..len(images)-1: row i of each logit
        # matrix is scored against index i (image i paired with text i).
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Backward pass and update the model's parameters
        total_loss.backward()

        if device == "cpu":
            # The model is already float32 on CPU, so step directly.
            optimizer.step()
        else:
            # Convert parameters/gradients to FP32, update, and convert the weights back.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        # Update the progress bar with the current (1-based) epoch and loss
        pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss.item():.4f}")

    # Checkpoint after every epoch; each save overwrites the previous one.
    # The destination is torch.save's second positional argument (`f`).
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
## Evaluating model performance on validation data

# Use the GPU when CUDA is available; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Freshly loaded ViT-B/32 CLIP model plus its image preprocessing transform.
# NOTE(review): the two commented-out statements below would restore the fine-tuned
# weights; while they stay disabled, this cell scores the stock pre-trained model.
model, preprocess = clip.load("ViT-B/32", device=device)

# Load the state dictionary of the trained model
#saved_model_state_dict = torch.load("models/trained_model.pth", map_location=device)

# Update the model's state dictionary with the saved state dictionary
#model.load_state_dict(saved_model_state_dict)

# Every class label of the Birds_25 dataset.
birds = [
    "Asian-Green-Bee-Eater", "Common-Rosefinch", "Hoopoe", "Indian-Roller", "Rufous-Treepie",
    "Brown-Headed-Barbet", "Common-Tailorbird", "House-Crow", "Jungle-Babbler", "Sarus-Crane",
    "Cattle-Egret", "Coppersmith-Barbet", "Indian-Grey-Hornbill", "Northern-Lapwing", "White-Breasted-Kingfisher",
    "Common-Kingfisher", "Forest-Wagtail", "Indian-Peacock", "Red-Wattled-Lapwing", "White-Breasted-Waterhen",
    "Common-Myna", "Gray-Wagtail", "Indian-Pitta", "Ruddy-Shelduck", "White-Wagtail",
]

# Sort in reverse alphabetical order; this ordering defines the class indices below.
birds.sort(reverse=True)
#print(birds)

# Class label -> integer class index (its position in the sorted `birds` list).
class_dict = {label: position for position, label in enumerate(birds)}

# Classify every validation image by picking the bird label whose text embedding
# is most similar to the image embedding.
predicted_labels, true_labels = [], []

# The class prompts are identical for every image, so tokenize, encode and
# normalize them once instead of once per image.
text = torch.cat([clip.tokenize(c) for c in birds]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Iterate over the whole validation manifest rather than a hard-coded count.
for image_json in val_data:

    # Path to the image file and the class label of the image
    image_path = image_json['Image_path']
    image_class = image_json['Class_label']

    # Preprocess the image and move it to the appropriate device (CPU/GPU);
    # the context manager closes the file once the tensor has been built.
    with Image.open(image_path) as raw_image:
        image = preprocess(raw_image).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled similarity between the normalized image and text features, as probabilities
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    values, indices = similarity[0].topk(1)

    true_labels.append(class_dict[image_class])
    # Store a plain Python int: the metric functions expect label values,
    # not a list of one-element (possibly CUDA) tensors.
    predicted_labels.append(indices.item())
    
# Compute evaluation metrics
# Passing the full label set keeps the confusion matrix len(birds) x len(birds) even
# when some class never occurs, so row i always corresponds to class index i.
all_class_ids = list(range(len(birds)))
accuracy = accuracy_score(true_labels, predicted_labels)
conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=all_class_ids)

# Per-class accuracy = correct predictions / samples of that true class.
# Classes without any sample yield NaN instead of triggering a divide-by-zero.
samples_per_class = conf_matrix.sum(axis=1)
class_wise_accuracy = np.divide(
    np.diag(conf_matrix),
    samples_per_class,
    out=np.full(len(all_class_ids), np.nan),
    where=samples_per_class > 0,
)

# Print evaluation metrics
print(f"Overall Accuracy: {accuracy:.4f}")
print("Class-wise Accuracy:")
for i, acc in enumerate(class_wise_accuracy):
    print(f"Class {i}: {acc:.4f}")

print("Confusion Matrix:")
print(conf_matrix)