Install Libraries

In [None]:
%%capture

!pip install openai-clip
!pip install datasets
!pip install torch
!pip install tqdm

- Model: `openai-clip` provides the base CLIP model
- Dataset: loaded from Hugging Face via `datasets`
- Torch: modeling code


**Zero-shot classification performance of CLIP**

In [None]:
from dataclasses import dataclass
from typing import List, Optional
import numpy as np

import clip
import torch


@dataclass
class ModelConfig:
    """Identifies which CLIP checkpoint to load and how to load it."""
    # Model identifier handed to `clip.load` (this notebook uses "ViT-B/32").
    model_name: str
    # Forwarded as the `jit` argument of `clip.load`; off by default.
    enable_jit: bool = False

@dataclass
class InferenceConfig:
    """Settings for a zero-shot inference run."""
    # Model selection used to build the inference model.
    model_config: ModelConfig
    # Candidate class names that predictions are chosen from.
    labels: List[str]
    # Number of highest-scoring labels returned per image.
    top_k: int = 1
    # How many dataset samples the inference loop iterates over; defaults to None.
    num_of_inf_samples: Optional[int] = None

@dataclass
class EvalConfig:
    """Evaluation settings: an inference setup paired with the metric to report."""

    # Inference run whose predictions are evaluated.
    inference_config: InferenceConfig
    # Name of the metric to compute.
    metric_name: str


@dataclass
class DataConfig:
    """Names the dataset that the notebook loads."""

    # Dataset identifier passed to `load_dataset`.
    dataset_name: str

In [None]:
from datasets import load_dataset
from collections import Counter


class DatasetInfo:
    """Wrapper around the `train` split of a Hugging Face dataset.

    The split is loaded once in the constructor and reused by every method.
    The methods read the columns `id`, `image`, `masterCategory` and
    `subCategory`, so the dataset must provide them.
    """

    def __init__(self, dataset_name: str):
        """Load the `train` split of `dataset_name` and keep it on the instance."""
        self.dataset_name = dataset_name
        # Load a single time; the accessors below reuse this split instead of
        # calling `load_dataset` again on every access.
        self.dataset = load_dataset(self.dataset_name)['train']

    def get_dataset(self):
        """Return the split loaded in the constructor."""
        return self.dataset

    def get_labels(self):
        """Return the distinct `subCategory` values as a sorted list.

        The result is also stored on `self.labels`. Sorting makes the label
        order (and therefore every label index derived from it) stable;
        iterating a set of strings directly can give a different order in
        each interpreter session.
        """
        # key=str keeps the sort well-defined even if a value is not a string.
        self.labels = sorted(set(self.dataset['subCategory']), key=str)
        return self.labels

    def get_dataset_stats(self):
        """Print the number of unique ids and the category distributions."""
        num_of_samples = len(set(self.dataset['id']))
        print(f"num of samples: {num_of_samples}")
        print(f"masterCatergory {Counter(self.dataset['masterCategory'])}")
        print(f"subCatergory {Counter(self.dataset['subCategory'])}")

    def display(self, idx):
        """Show the image of sample `idx`, then print the full record.

        Relies on the `display` helper that IPython/Jupyter makes available
        inside notebook cells.
        """
        record = self.dataset[idx]

        # Example image
        display(record['image'])

        # Example data
        print(record)



In [None]:
class InferenceModel:
    """Zero-shot image classifier built on a pretrained CLIP model.

    Each label is turned into the prompt "a photo of <label>" and embedded
    with the text encoder. An image is classified by embedding it with the
    image encoder and ranking the labels by similarity to that embedding.
    """

    def __init__(self, model_name: str, enable_jit: bool, labels: List[str], top_k: int):
        """Load the CLIP model and remember the label set.

        Args:
            model_name: Model identifier passed to `clip.load`.
            enable_jit: Passed to `clip.load` as its `jit` argument.
            labels: Candidate class names, in the order their text features
                are (or will be) computed.
            top_k: Number of labels returned by `predict`.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pass the device explicitly so the weights and the inputs moved in
        # `preprocess_data` always live on the same device.
        self.model, self.preprocess = clip.load(model_name, device=self.device, jit=enable_jit)
        self.model.eval()  # Set the model to evaluation mode
        self.top_k = top_k
        self.labels = labels
        # Filled by `precompute_text_features`; `predict` computes it lazily.
        self.text_features = None

    def preprocess_data(self, image_data):
        """Apply the CLIP preprocessing, add a batch dimension and move to the device."""
        return self.preprocess(image_data).unsqueeze(0).to(self.device)

    def precompute_text_features(self, text_data) -> torch.Tensor:
        """Embed the prompt of every label and cache the normalized result.

        Args:
            text_data: Iterable of label names.

        Returns:
            The L2-normalized text features, one row per label. They are also
            stored on `self.text_features`.

        Raises:
            ValueError: If `text_data` is empty.
        """
        prompts = [f"a photo of {c}" for c in text_data]
        if not prompts:
            raise ValueError("text_data must contain at least one label")
        text_inputs = clip.tokenize(prompts).to(self.device)

        with torch.no_grad():
            text_features = self.model.encode_text(text_inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        self.text_features = text_features
        return self.text_features

    # Original (misspelled) name, kept so existing callers continue to work.
    precomoute_text_features = precompute_text_features

    def predict(self, data) -> List[str]:
        """Return the `top_k` most similar labels for one image, best first.

        If the text features have not been computed yet, they are computed
        from `self.labels` on the first call.
        """
        if getattr(self, "text_features", None) is None:
            self.precompute_text_features(self.labels)

        image_input = self.preprocess_data(data)

        # Calculate image features
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)

        # Normalize so the dot product below is a cosine similarity
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Calculate similarity between image and text features
        similarity = (100.0 * image_features @ self.text_features.T).softmax(dim=-1)

        # `topk` fails when asked for more entries than exist, so never ask
        # for more labels than there are text features.
        k = min(self.top_k, self.text_features.shape[0])
        _, indices = similarity[0].topk(k)

        return [self.labels[i] for i in indices.tolist()]

**Configs**

In [None]:
# Dataset config: which Hugging Face dataset to evaluate on.
dataset_name = 'ceyda/fashion-products-small'
dataset_config = DataConfig(dataset_name)

# Model config: which CLIP checkpoint to load.
model_name = "ViT-B/32"
model_config = ModelConfig(model_name)

# Load the dataset and derive the candidate labels from it.
dataset_obj = DatasetInfo(dataset_config.dataset_name)
labels = dataset_obj.get_labels()

# Inference config: classify against those labels, keep the single best
# prediction per image, and evaluate on the first 200 samples.
inference_config = InferenceConfig(model_config,
                                   labels,
                                   top_k=1,
                                   num_of_inf_samples=200)


In [None]:
# Print the sample count and the per-category distributions of the loaded split.
dataset_obj.get_dataset_stats()

In [None]:
# Build the zero-shot classifier from the model and inference settings.
inference_model = InferenceModel(
    model_name=model_config.model_name,
    enable_jit=model_config.enable_jit,
    labels=inference_config.labels,
    top_k=inference_config.top_k,
)

# TODO: make this step part of InferenceModel.__init__, controlled by a flag.
# Embed the text description of every class once, so that each prediction
# only has to encode the image.
text_features = inference_model.precomoute_text_features(dataset_obj.labels)

In [None]:
# Run zero-shot inference and collect each prediction next to its ground truth.
predict_label_list = []
true_label_list = []

# `num_of_inf_samples` is optional: None means "use the whole split". The value
# is also clamped to the dataset size so an over-large setting cannot index
# past the end.
dataset_size = len(dataset_obj.dataset)
requested_samples = inference_config.num_of_inf_samples
num_samples = dataset_size if requested_samples is None else min(requested_samples, dataset_size)

for idx in range(num_samples):
    example = dataset_obj.dataset[idx]
    true_label = example['subCategory']
    predict_label = inference_model.predict(example['image'])
    true_label_list.append(true_label)
    predict_label_list.append(predict_label)
    # Log every tenth sample as a quick sanity check.
    if idx % 10 == 0:
        print(f"Predicted: {predict_label}, Actual: {true_label}, for top_k = {inference_config.top_k}")

In [None]:
def eval_precision(true_label_list, predict_label_list, top_k: Optional[int] = None):
    """Compute top-k precision of the predictions and print it.

    A sample counts as a hit when its true label appears among the labels
    predicted for it. The precision is the share of hits over the samples
    actually passed in, independent of any global configuration.

    Args:
        true_label_list: Ground-truth label of each sample.
        predict_label_list: For each sample, the collection of predicted labels.
        top_k: Value shown in the printed summary. When omitted, the size of
            the largest prediction list is reported.

    Returns:
        A list with one entry per sample: 1 for a hit, 0 for a miss.

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(true_label_list) != len(predict_label_list):
        raise ValueError(
            f"got {len(true_label_list)} true labels but "
            f"{len(predict_label_list)} predictions"
        )

    eval_decision = [
        int(true_label in predicted)
        for true_label, predicted in zip(true_label_list, predict_label_list)
    ]

    if top_k is None:
        top_k = max((len(predicted) for predicted in predict_label_list), default=0)

    num_samples = len(eval_decision)
    # With no samples there is nothing to average; report NaN rather than divide by zero.
    precision = sum(eval_decision) / num_samples if num_samples else float("nan")

    print(f"Precision of Clip for (top_k = {top_k}) is {precision}")
    return eval_decision

# Score the predictions collected by the inference loop; the result holds one 0/1 hit flag per sample.
eval_metric = eval_precision(true_label_list, predict_label_list)

**Fine Tuning CLIP**

In [None]:
from torch.utils.data import random_split

# Split dataset into training (80%) and validation (20%) sets.
# A seeded generator keeps the split identical across kernel restarts, so the
# validation accuracy is comparable from run to run.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset_obj.dataset))
val_size = len(dataset_obj.dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset_obj.dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
from torchvision import transforms
from torch.utils.data import Dataset

# Class vocabulary for fine-tuning. Sorting gives every sub-category a stable
# index; iterating a set of strings directly can yield a different order in
# each interpreter session, which would change the class ids between runs.
sub_categories = sorted(set(dataset_obj.dataset['subCategory']), key=str)

# Define a custom dataset class
class TransformDataset(Dataset):
    """Map-style dataset yielding `(image_tensor, class_index)` pairs.

    Each item of the wrapped `data` must be a mapping with an `image` entry
    and a `subCategory` entry. The image is resized to 224x224, converted to
    a tensor and normalized with the statistics used by CLIP's preprocessing.

    Args:
        data: Indexable collection of records.
        categories: Ordered class names; the position of a name is its class
            index. Defaults to the notebook-level `sub_categories` list.
    """

    def __init__(self, data, categories=None):
        self.data = data
        if categories is None:
            categories = sub_categories
        # Name -> index table built once, so each item lookup is a dictionary
        # access rather than a list scan. `setdefault` keeps the first
        # position of a repeated name, matching `list.index`.
        self.label_to_idx = {}
        for position, name in enumerate(categories):
            self.label_to_idx.setdefault(name, position)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the transformed image and the integer class index of sample `idx`.

        Raises:
            ValueError: If the sample's `subCategory` is not a known category.
        """
        item = self.data[idx]
        image = item['image']
        # Normalize is configured with three-channel statistics, so make sure
        # images in other modes (e.g. grayscale) are converted to RGB first.
        if hasattr(image, 'convert'):
            image = image.convert('RGB')

        subcategory = item['subCategory']
        try:
            label = self.label_to_idx[subcategory]
        except KeyError:
            raise ValueError(f"{subcategory!r} is not a known sub-category") from None
        return self.transform(image), label

In [None]:
from torch.utils.data import DataLoader

# Create DataLoader for training and validation sets.
# The batch size is named once so both loaders stay in sync when it is tuned.
BATCH_SIZE = 32
train_loader = DataLoader(TransformDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(TransformDataset(val_dataset), batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import torch.nn as nn

# Modify the model to include a classifier for subcategories
class FineTuneClip(nn.Module):
    """Frozen CLIP image encoder followed by a trainable linear classifier.

    Image features are computed without gradient tracking, so only the
    classifier head receives gradients.

    Args:
        model: CLIP model exposing `encode_image` and `visual.output_dim`.
        num_classes: Number of output classes of the linear head.
    """

    def __init__(self, model, num_classes):
        super(FineTuneClip, self).__init__()
        self.model = model
        self.classifier = nn.Linear(model.visual.output_dim, num_classes)

    def train(self, mode: bool = True):
        """Switch the classifier's mode while keeping the backbone in eval mode.

        The backbone is never updated, so any mode-dependent layers it may
        contain (dropout, batch norm) should always behave as in inference.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, x):
        """Return class logits for a batch of preprocessed images."""
        with torch.no_grad():
            features = self.model.encode_image(x).float()  # Convert to float32
        return self.classifier(features)

In [None]:
# One output unit per distinct sub-category.
num_classes = len(sub_categories)
# Wrap the CLIP model already loaded for zero-shot inference and place the
# combined module on the same device.
model_finetune = FineTuneClip(model=inference_model.model, num_classes=num_classes)
model_finetune = model_finetune.to(inference_model.device)

In [None]:
import torch.optim as optim

# Loss: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over the classifier head's parameters only.
optimizer = optim.Adam(params=model_finetune.classifier.parameters(), lr=1e-4)

In [None]:
from tqdm import tqdm

# Number of epochs for training
num_epochs = 2
device = inference_model.device

# Training loop: one pass over the training set, then one over the validation set, per epoch.
for epoch in range(num_epochs):
    model_finetune.train()  # Set the model to training mode
    # Only the classifier head is optimized; keep the CLIP backbone in eval
    # mode so any mode-dependent layers it has behave as in inference.
    model_finetune.model.eval()
    running_loss = 0.0  # Sum of the batch losses seen so far in this epoch
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}, Loss: 0.0000")  # Initialize progress bar

    # The batch labels are named `targets` so the loop does not overwrite the
    # notebook-level `labels` list of class names.
    for step, (images, targets) in enumerate(pbar, start=1):
        images, targets = images.to(device), targets.to(device)  # Move the batch to the device (GPU/CPU)
        optimizer.zero_grad()  # Clear the gradients of all optimized variables
        outputs = model_finetune(images)  # Forward pass
        loss = criterion(outputs, targets)  # Calculate the loss
        loss.backward()  # Backward pass
        optimizer.step()  # Parameter update

        running_loss += loss.item()
        # Average over the batches processed so far, not over the whole epoch,
        # so the displayed value is meaningful while the epoch is running.
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/step:.4f}")

    epoch_loss = running_loss / max(len(train_loader), 1)  # Average loss for the epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Validation
    model_finetune.eval()  # Set the model to evaluation mode
    correct = 0  # Correctly classified validation samples
    total = 0  # Validation samples seen

    with torch.no_grad():  # Disable gradient calculation for validation
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            outputs = model_finetune(images)
            # Softmax does not change the ranking, so take the argmax of the logits directly.
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    if total:
        print(f'Validation Accuracy: {100 * correct / total}%')  # Print validation accuracy for the epoch
    else:
        print('Validation Accuracy: n/a (validation set is empty)')

# Save the fine-tuned model
torch.save(model_finetune.state_dict(), 'clip_finetuned.pth')  # Save the model's state dictionary
