In [None]:
# ✅ Install all required libraries
!pip install torch torchvision torchaudio --upgrade
!pip install transformers timm gradio tqdm Pillow -q




In [None]:

import csv
import math
import os
import zipfile

import gradio as gr
import requests
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm.notebook import tqdm
from transformers import DistilBertTokenizer, DistilBertModel

# Seed torch's RNGs (CPU and every CUDA device). This makes DataLoader shuffling,
# dropout masks and projection-head initialization repeatable across
# Restart & Run All. cuDNN kernels may still add minor nondeterminism.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available; models and batches below are moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [None]:

# Hyperparameters and model choices shared by every cell below.
config = dict(
    # --- optimization ---
    batch_size=128,
    epochs=15,
    lr=1e-4,
    # --- backbones ---
    image_encoder="resnet18",
    text_encoder="distilbert-base-uncased",
    # Feature sizes the backbones emit; they must match the encoders above.
    image_embedding_dim=512,
    text_embedding_dim=768,
    # Size of the shared image/text embedding space.
    projection_dim=256,
    # Log of the logit scale. contrastive_loss multiplies cosine similarities
    # by exp(temperature), so 1.0 scales them by e. It is NOT divided into the
    # logits like a classic softmax temperature.
    temperature=1.0,
)

In [None]:

# Download and unzip the Flickr8k dataset
!kaggle datasets download -d adityajn105/flickr8k -p ./data
with zipfile.ZipFile('./data/flickr8k.zip', 'r') as zip_ref:
    zip_ref.extractall('./data/flickr8k')

# Define paths
IMAGE_PATH = "./data/flickr8k/Images"
CAPTIONS_PATH = "./data/flickr8k/captions.txt"

Dataset URL: https://www.kaggle.com/datasets/adityajn105/flickr8k
License(s): CC0-1.0
flickr8k.zip: Skipping, found more recently modified local copy (use --force to force download)


In [None]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from the Flickr8k ``captions.txt`` file.

    The captions file is parsed as CSV: one header row, then ``image,caption``
    rows. An image may appear on several rows, one per caption. Rows are
    grouped by image name (in file order), so one dataset item is one image.

    Args:
        image_path: Directory containing the image files named in the CSV.
        captions_path: Path to the captions CSV.
        tokenizer: Tokenizer called as ``tokenizer(text, padding=..., ...)``.
        transform: Callable turning an RGB PIL image into a tensor.
        max_length: Token length each caption is padded/truncated to (default 64).
        random_caption: If False (default), always use an image's first caption.
            If True, pick one of its captions at random on each access, so all
            captions are seen across epochs.
    """

    def __init__(self, image_path, captions_path, tokenizer, transform,
                 max_length=64, random_caption=False):
        self.image_path = image_path
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length
        self.random_caption = random_caption

        # Group captions by image name; dict insertion order keeps file order.
        self.captions = {}
        # The csv module un-quotes captions that contain commas.
        # newline='' is what csv.reader expects from the file object.
        with open(captions_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip header
            for row in reader:
                if not row or not row[0].strip():
                    continue  # blank line
                img_name = row[0].strip()
                # Re-join extra fields in case an unquoted caption contained commas.
                caption = ','.join(row[1:]).strip()
                self.captions.setdefault(img_name, []).append(caption)

        self.image_files = list(self.captions.keys())

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, input_ids, attention_mask)`` for image ``idx``."""
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_path, img_name)

        # Context manager closes the file handle PIL opens lazily.
        # convert("RGB") guarantees 3 channels for grayscale/RGBA files.
        with Image.open(img_path) as img:
            image = self.transform(img.convert("RGB"))

        captions = self.captions[img_name]
        if self.random_caption:
            # Use torch's RNG so the choice follows torch.manual_seed.
            caption = captions[int(torch.randint(len(captions), (1,)))]
        else:
            caption = captions[0]

        encoded_caption = self.tokenizer(
            caption,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # return_tensors='pt' yields a batch of one; squeeze(0) drops only that dim.
        return (
            image,
            encoded_caption['input_ids'].squeeze(0),
            encoded_caption['attention_mask'].squeeze(0),
        )

# Standard ImageNet per-channel statistics, used to normalize input pixels.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize the short side to 256, center-crop to 224x224, then convert and normalize.
image_transform = Compose(
    [
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# The tokenizer checkpoint must match the text encoder's.
tokenizer = DistilBertTokenizer.from_pretrained(config["text_encoder"])

# One item per image; batches are reshuffled every epoch.
dataset = FlickrDataset(IMAGE_PATH, CAPTIONS_PATH, tokenizer, image_transform)
dataloader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    shuffle=True,
)

In [None]:
class ImageEncoder(nn.Module):
    """Image feature extractor backed by a timm model.

    The backbone is created with ``num_classes=0``, which tells timm not to
    attach a classification head. ``forward`` therefore returns the model's
    feature vector rather than class logits.

    Args:
        model_name: timm model identifier (defaults to ``config["image_encoder"]``).
        pretrained: Load pretrained weights when True.
        trainable: When False, every backbone parameter is frozen.
    """

    def __init__(self, model_name=config["image_encoder"], pretrained=True, trainable=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # One call toggles requires_grad on all parameters (freezes when trainable=False).
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        features = self.model(x)
        return features

class TextEncoder(nn.Module):
    """Text encoder backed by a pretrained DistilBERT.

    Args:
        model_name: Hugging Face model id (defaults to ``config["text_encoder"]``).
        trainable: When False, every DistilBERT parameter is frozen.
    """

    def __init__(self, model_name=config["text_encoder"], trainable=True):
        super().__init__()
        self.model = DistilBertModel.from_pretrained(model_name)
        self.model.requires_grad_(trainable)

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at position 0 (the [CLS] token) for each sequence."""
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_position = 0
        return hidden[:, cls_position, :]

class ProjectionHead(nn.Module):
    """Project encoder features into the shared embedding space.

    The layers are: Linear projection -> GELU -> Linear -> Dropout. A residual
    connection adds the linear projection back, then a LayerNorm is applied.

    Args:
        embedding_dim: Size of the incoming feature vector (last dim of ``x``).
        projection_dim: Size of the output embedding.
        dropout: Dropout probability on the non-linear branch. The default 0.1
            matches the previously hard-coded value.
    """

    def __init__(self, embedding_dim, projection_dim, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # Residual connection keeps the plain linear projection as a shortcut path.
        x = x + projected
        return self.layer_norm(x)

class CLIPModel(nn.Module):
    """Dual-encoder CLIP-style model.

    An ImageEncoder and a TextEncoder feed separate ProjectionHeads. Both heads
    map into a shared space of size ``config["projection_dim"]``.
    """

    def __init__(self):
        super().__init__()
        shared_dim = config["projection_dim"]
        # Keep this registration order: it fixes parameter order (and RNG use at init).
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(config["image_embedding_dim"], shared_dim)
        self.text_projection = ProjectionHead(config["text_embedding_dim"], shared_dim)

    def forward(self, image, text_ids, text_mask):
        """Return ``(image_embedding, text_embedding)`` for the given batch."""
        # Run both encoders first, then both heads, so dropout RNG use is unchanged.
        features = (self.image_encoder(image), self.text_encoder(text_ids, text_mask))
        return self.image_projection(features[0]), self.text_projection(features[1])

In [None]:
import torch.nn.functional as F

In [None]:
def contrastive_loss(image_embeddings, text_embeddings, temperature=config["temperature"]):
    """Symmetric CLIP-style (InfoNCE) loss over a batch of matched image/text pairs.

    Row ``i`` of ``image_embeddings`` is the positive for row ``i`` of
    ``text_embeddings``. Every other row in the batch is a negative.

    Args:
        image_embeddings: ``(batch, dim)`` image embeddings.
        text_embeddings: ``(batch, dim)`` text embeddings in the same batch order.
        temperature: Log of the logit scale; cosine similarities are multiplied
            by ``exp(temperature)``. It may be a Python number or a tensor (e.g.
            a learnable parameter). A tensor stays in the autograd graph.

    Returns:
        Scalar tensor: the mean of the image->text and text->image cross-entropies.
    """
    # Cosine similarity is the dot product of L2-normalized vectors.
    image_embeddings_norm = F.normalize(image_embeddings, p=2, dim=1)
    text_embeddings_norm = F.normalize(text_embeddings, p=2, dim=1)

    if torch.is_tensor(temperature):
        # torch.tensor(param) would copy and detach a learnable scale; .exp() keeps gradients.
        logit_scale = temperature.exp()
    else:
        logit_scale = math.exp(temperature)
    similarity_matrix = (image_embeddings_norm @ text_embeddings_norm.T) * logit_scale

    # Correct pairs lie on the diagonal. Build labels on the inputs' device, not
    # the global `device`, so CPU tensors also work on a CUDA machine.
    labels = torch.arange(similarity_matrix.shape[0], device=similarity_matrix.device)

    loss_img = F.cross_entropy(similarity_matrix, labels)
    loss_text = F.cross_entropy(similarity_matrix.T, labels)

    # Average the two directions.
    return (loss_img + loss_text) / 2

In [None]:
# Build the model on the target device. A single Adam optimizer with one
# learning rate covers every parameter (both encoders and both heads).
model = CLIPModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

# --- Training Loop ---
num_batches = len(dataloader)
for epoch in range(config["epochs"]):
    model.train()  # put dropout/normalization layers in training mode
    train_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch in pbar:
        # Move (images, input_ids, attention_mask) to the device, in that order.
        images, input_ids, attention_mask = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        image_embed, text_embed = model(images, input_ids, attention_mask)
        loss = contrastive_loss(image_embed, text_embed)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        pbar.set_postfix({"loss": batch_loss})

    # Unweighted mean of per-batch losses; a possibly smaller final batch counts fully.
    avg_loss = train_loss / num_batches
    print(f"End of Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

Epoch 1/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 1, Average Loss: 4.0868


Epoch 2/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 2, Average Loss: 3.5535


Epoch 3/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 3, Average Loss: 3.3417


Epoch 4/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 4, Average Loss: 3.1981


Epoch 5/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 5, Average Loss: 3.0894


Epoch 6/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 6, Average Loss: 3.0101


Epoch 7/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 7, Average Loss: 2.9559


Epoch 8/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 8, Average Loss: 2.9173


Epoch 9/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 9, Average Loss: 2.8694


Epoch 10/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 10, Average Loss: 2.8370


Epoch 11/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 11, Average Loss: 2.8068


Epoch 12/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 12, Average Loss: 2.7882


Epoch 13/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 13, Average Loss: 2.7729


Epoch 14/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 14, Average Loss: 2.7489


Epoch 15/15:   0%|          | 0/64 [00:00<?, ?it/s]

End of Epoch 15, Average Loss: 2.7289


In [None]:
# Free GPU memory still held by training leftovers before switching to inference.
import gc

# After the last optimizer.step(), the parameter gradients and Adam's
# per-parameter moment buffers are still on the GPU. Moving the model to the
# CPU and back frees neither, so drop them explicitly. Adam re-creates its
# state lazily if training is resumed.
model.zero_grad(set_to_none=True)
optimizer.state.clear()

# Tensors left over from the final training batch also keep GPU memory alive.
# pop(..., None) keeps this cell safe to re-run.
for _name in ("images", "input_ids", "attention_mask", "image_embed", "text_embed", "loss"):
    globals().pop(_name, None)

# Collect unreachable objects first, then release the now-unused cached blocks.
# Emptying the cache before collecting would leave those blocks cached.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

model.eval()

if torch.cuda.is_available():
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
    print(f"GPU Memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")

GPU Memory allocated: 1251.46 MB
GPU Memory cached: 1378.00 MB


In [None]:
def zero_shot_classify(image, text_labels):
    model.eval()

    # Preprocess the image
    image = image_transform(image).unsqueeze(0).to(device)

    # Tokenize the text labels
    text_labels = text_labels.split(',')
    text_tokens = tokenizer(text_labels, padding=True, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        # Get embeddings
        image_embed, text_embed = model(image, text_tokens["input_ids"], text_tokens["attention_mask"])

        # Normalize embeddings
        image_embed_norm = F.normalize(image_embed, p=2, dim=-1)
        text_embed_norm = F.normalize(text_embed, p=2, dim=-1)

        # Calculate similarity
        similarities = (image_embed_norm @ text_embed_norm.T).squeeze(0)

        # Convert to probabilities
        probs = F.softmax(similarities, dim=-1)

    # Return a dictionary of labels and their probabilities
    return {label: prob.item() for label, prob in zip(text_labels, probs)}


# Gradio UI: an image plus comma-separated candidate labels in, top-3 label scores out.
interface = gr.Interface(
    fn=zero_shot_classify,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, label="Enter Text Labels (comma-separated)")
    ],
    outputs=gr.Label(num_top_classes=3, label="Predictions"),
    title="Zero-Shot Image Classifier (Mini-CLIP)",
    description="Upload an image and provide a list of potential categories (e.g., 'a dog, a cat, a car'). The model will predict the most likely category without having been explicitly trained on it.",
    # Build example paths from IMAGE_PATH so they follow the dataset location.
    examples=[
        [os.path.join(IMAGE_PATH, "1000268201_693b08cb0e.jpg"), "a child playing, a dog running, a city street"],
        [os.path.join(IMAGE_PATH, "1001773457_577c3a7d70.jpg"), "a man on a mountain, a woman on a beach, a group of people indoors"]
    ]
)

# Launch the demo. debug=True keeps this cell running (blocking) so server
# errors and logs appear in the notebook; interrupt the kernel to stop it.
interface.launch(debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://e107ba445aa8cc01c1.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://e107ba445aa8cc01c1.gradio.live


