In [1]:
from google.colab import drive

# Mount Google Drive; the checkpoint cell later writes under /content/drive/MyDrive.
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install datasets



In [3]:

# Step 2: Import necessary libraries
import os
import zipfile
from datasets import load_dataset
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from transformers import AutoTokenizer, AutoModel, ViTModel, ViTFeatureExtractor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import itertools

In [4]:

# Step 2: Import necessary libraries
import os
import zipfile
from datasets import load_dataset
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from transformers import AutoTokenizer, AutoModel, ViTModel, ViTFeatureExtractor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import itertools

In [5]:
# Step 1: Import necessary libraries
import os
import zipfile
import requests
from datasets import load_dataset
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from transformers import AutoTokenizer, AutoModel, ViTModel, ViTFeatureExtractor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import itertools

In [6]:
# Step 2: Download the dataset manually
# Remote archive location, local download target, and extraction directory.
dataset_url = (
    "https://huggingface.co/datasets/Mansuba/Banglafinal/resolve/main/"
    "bangla_combined_image_caption_dataset.zip"
)
zip_path, output_dir = "/content/Banglafinal.zip", "/content/Banglafinal_unzipped"
os.makedirs(output_dir, exist_ok=True)  # no-op when the directory already exists

In [7]:
# Download the dataset using requests (skipped when the archive is already present)
if not os.path.exists(zip_path):
    print("Downloading dataset...")
    # Stream into a temporary file and move it into place only after the whole archive
    # has been written: an interrupted or failed download must never leave a truncated
    # file at zip_path, because the existence check above would then skip re-downloading.
    partial_path = zip_path + ".part"
    with requests.get(dataset_url, stream=True, timeout=60) as response:
        # Fail loudly on 403/404/5xx instead of saving an HTML error page as the "zip".
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    os.replace(partial_path, zip_path)
    print(f"Dataset downloaded to {zip_path}")


In [None]:
# Step 3: Unzip the dataset
print("Unzipping dataset...")
# Extract every member of the downloaded archive into output_dir.
with zipfile.ZipFile(zip_path, mode="r") as archive:
    archive.extractall(path=output_dir)
print(f"Dataset extracted to {output_dir}")

Unzipping dataset...


In [None]:
# Step 2: Define dataset paths
output_dir = "/content/Banglafinal_unzipped"

# Step 3: Load dataset using Hugging Face's `load_from_disk`
# The import cells above only bring in `load_dataset`; without this import the call
# below raises NameError on a fresh kernel.
from datasets import load_from_disk

print("Loading dataset from Arrow format...")
# NOTE(review): if the archive was saved as a DatasetDict, `len(dataset)` is the number
# of splits and a split (e.g. dataset["train"]) must be selected first — confirm.
dataset = load_from_disk(output_dir)
print(f"Dataset loaded with {len(dataset)} records.")

In [None]:
column_names = dataset.column_names  # expected to include "image" and "caption"
print(column_names)

In [None]:
first_record = dataset[0]  # inspect one raw record before wrapping it in a PyTorch Dataset
print(first_record)

In [None]:
import PIL
from PIL import Image
from io import BytesIO
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class ImageTextDataset(Dataset):
    """Wraps records with "image" and "caption" keys as (image_tensor, caption) pairs.

    Each image is converted to RGB, resized to ``image_size x image_size``, turned into
    a tensor and normalized per channel with mean 0.5 / std 0.5, i.e. to [-1, 1].

    Args:
        dataset: Indexable supporting ``len()`` whose items are mappings with "image"
            and "caption" keys (e.g. a Hugging Face ``Dataset``).
        image_size (int): Side length of the square output image.
    """

    def __init__(self, dataset, image_size=224):
        self.dataset = dataset
        # RGB conversion is done in _load_image, *before* resizing, instead of in a
        # transforms.Lambda after Resize: PIL resizes palette ("P") images with
        # nearest-neighbour only, and a lambda cannot be pickled by spawn-based workers.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.dataset)

    def _load_image(self, image_data, index):
        """Convert one stored image value into an RGB PIL image.

        Accepts a PIL image, raw encoded bytes, a file path, or an undecoded
        Hugging Face ``{"bytes": ..., "path": ...}`` dict.

        Raises:
            ValueError: if the value is none of the supported formats.
        """
        if isinstance(image_data, PIL.Image.Image):  # already decoded
            return image_data.convert("RGB")
        if isinstance(image_data, dict):  # undecoded HF `Image` feature
            if image_data.get("bytes") is not None:
                image_data = image_data["bytes"]
            elif image_data.get("path"):
                image_data = image_data["path"]
        if isinstance(image_data, (bytes, bytearray)):  # encoded image bytes
            with Image.open(BytesIO(image_data)) as img:
                return img.convert("RGB")
        if isinstance(image_data, str):  # file path; close the file once converted
            with Image.open(image_data) as img:
                return img.convert("RGB")
        raise ValueError(f"Unsupported image format at index {index}: {type(image_data)}")

    def __getitem__(self, index):
        # Fetch the record once; indexing a HF Dataset twice may decode the image twice.
        record = self.dataset[index]
        image = self._load_image(record["image"], index)
        return self.transform(image), record["caption"]


In [None]:
# Step 5: Create PyTorch dataset
image_size = 224  # Set desired image size (matches the 224x224 ViT checkpoint used below)
full_dataset = ImageTextDataset(dataset, image_size)
print(f"Dataset size: {len(full_dataset)}")

# Step 6: Split dataset into training and validation sets
# A seeded generator makes the split reproducible across kernel restarts, so re-runs
# (and the retrieval section at the end) see the same train/validation images.
split_seed = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(split_seed)
)

# Step 7: Select the device first so pinned memory is only requested when a GPU exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

batch_size = 32
# os.cpu_count() may return None, which DataLoader rejects; fall back to the main process.
num_workers = os.cpu_count() or 0
pin_memory = device.type == "cuda"

# Step 8: Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")


In [None]:


# Load the pretrained image and text encoders.
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor (same from_pretrained /
# __call__ interface); the variable name is kept so later cells keep working.
from transformers import ViTImageProcessor

vit_model_name = 'google/vit-base-patch16-224'
vit_feature_extractor = ViTImageProcessor.from_pretrained(vit_model_name)
vit_model = ViTModel.from_pretrained(vit_model_name).to(device)
image_embedding_size = vit_model.config.hidden_size  # width of the ViT hidden states

bangla_bert_name = 'sagorsarker/bangla-bert-base'
text_tokenizer = AutoTokenizer.from_pretrained(bangla_bert_name)
text_model = AutoModel.from_pretrained(bangla_bert_name).to(device)
text_embedding_size = text_model.config.hidden_size  # width of the BERT hidden states

In [None]:
# Step 9: Define Projection class
class Projection(nn.Module):
    """Two-layer MLP head that maps an encoder embedding into the shared space.

    Computes ``fc2(dropout(relu(fc1(x))))``; the dropout (p=0.3) is only active
    in train mode.

    Args:
        input_size (int): Width of the incoming embedding.
        output_size (int): Width of both the hidden layer and the output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Submodule names and creation order are kept: they define the saved
        # state_dict keys and the order in which weights are randomly initialised.
        self.fc1 = nn.Linear(input_size, output_size)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(output_size, output_size)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

# Both modalities are projected to the same width so they can be compared by cosine similarity.
shared_embedding_size = 512
# Built in the same order as before (image first), so random initialisation is unchanged.
image_projector, text_projector = [
    Projection(in_features, shared_embedding_size).to(device)
    for in_features in (image_embedding_size, text_embedding_size)
]


In [None]:
# Step 10: Define contrastive loss
def contrastive_loss(image_proj, text_proj, margin=0.2):
    """Bidirectional hinge (triplet-ranking) contrastive loss with in-batch negatives.

    Row i of ``image_proj`` and row i of ``text_proj`` are a matching pair; every other
    row of the batch acts as a negative. Each matching pair's cosine similarity must
    exceed that of every mismatched pair, in both directions (image->caption and
    caption->image), by at least ``margin``; shortfalls are penalised linearly.

    A loss of the form ``1 - mean(cos(image_i, text_i)) + margin`` has no negatives.
    It is minimised by mapping every input to the same vector, and a constant
    ``margin`` has no effect on its gradients. This formulation avoids both problems.

    Args:
        image_proj (torch.Tensor): (B, D) image embeddings.
        text_proj (torch.Tensor): (B, D) text embeddings, row-aligned with ``image_proj``.
        margin (float): Required similarity gap between a positive and each negative.

    Returns:
        torch.Tensor: Scalar loss, averaged over all B*(B-1) negatives in both
        directions. It is 0 for a batch of one, which has no negatives.
    """
    image_proj = F.normalize(image_proj, dim=-1)
    text_proj = F.normalize(text_proj, dim=-1)
    similarity = image_proj @ text_proj.t()         # (B, B); [i, j] = cos(image_i, text_j)
    positives = similarity.diagonal().unsqueeze(1)  # (B, 1); cos of each matching pair

    # image_i must prefer text_i over every text_j (compare along rows) ...
    cost_image_to_text = (margin + similarity - positives).clamp(min=0)
    # ... and text_j must prefer image_j over every image_i (compare along columns).
    cost_text_to_image = (margin + similarity - positives.t()).clamp(min=0)

    # The diagonal compares each positive with itself; it is not a negative.
    batch_size = similarity.size(0)
    is_positive = torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
    cost_image_to_text = cost_image_to_text.masked_fill(is_positive, 0.0)
    cost_text_to_image = cost_text_to_image.masked_fill(is_positive, 0.0)

    num_negatives = max(batch_size * (batch_size - 1), 1)
    return (cost_image_to_text.sum() + cost_text_to_image.sum()) / (2 * num_negatives)

In [None]:
# Step 11: Set up optimizer and learning rate scheduler
# One parameter group per component with its own learning rate; the two projection
# heads share a group that also overrides the default weight decay.
projection_parameters = itertools.chain(image_projector.parameters(), text_projector.parameters())
params = [
    dict(params=vit_model.parameters(), lr=1e-4),
    dict(params=text_model.parameters(), lr=1e-5),
    dict(params=projection_parameters, lr=1e-3, weight_decay=1e-3),
]
# Groups that do not set weight_decay themselves use this default.
optimizer = optim.AdamW(params, weight_decay=1e-4)
# Scale every group's lr by 0.8 when the validation loss plateaus (patience=2).
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.8)


In [None]:
# Step 12: Training and validation loop
import matplotlib.pyplot as plt

num_epochs = 10
best_val_loss = float("inf")
early_stopping_counter = 0
early_stopping_patience = 3
# A matched (image, caption) pair counts as "correct" when its cosine similarity exceeds this.
accuracy_threshold = 0.9

train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
recall_at_1_scores = []


def _encode_batch(images, captions):
    """Embed one batch of (image, caption) pairs into the shared projection space.

    Returns ``(image_embeddings, text_embeddings)``; row i of both comes from pair i.
    """
    images = images.to(device)
    inputs = text_tokenizer(captions, return_tensors='pt', padding="max_length", max_length=32, truncation=True).to(device)
    text_embeddings = text_projector(text_model(**inputs).last_hidden_state[:, 0, :])
    image_embeddings = image_projector(vit_model(pixel_values=images)["last_hidden_state"][:, 0, :])
    return image_embeddings, text_embeddings


def _set_train_mode(training):
    """Put both encoders and both projection heads into train (True) or eval (False) mode."""
    for module in (vit_model, text_model, image_projector, text_projector):
        module.train(training)


def _plot_curves(curves, ylabel, title):
    """Draw one figure with a line per ``(values, label, color, marker)`` over 1-based epochs."""
    plt.figure(figsize=(10, 5))
    for values, label, color, marker in curves:
        plt.plot(range(1, len(values) + 1), values, label=label, color=color, marker=marker)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()


for epoch in range(num_epochs):
    # ---- training ----
    _set_train_mode(True)
    total_train_loss = 0.0
    correct_train_predictions = 0
    total_train_predictions = 0

    for images, captions in tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}/{num_epochs}", colour="green"):
        optimizer.zero_grad()
        image_embeddings, text_embeddings = _encode_batch(images, captions)
        loss = contrastive_loss(image_embeddings, text_embeddings)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        # Threshold-based accuracy on the matched pairs; a metric needs no autograd graph.
        with torch.no_grad():
            matched_similarity = F.cosine_similarity(image_embeddings, text_embeddings)
        correct_train_predictions += (matched_similarity > accuracy_threshold).sum().item()
        total_train_predictions += matched_similarity.size(0)

    avg_train_loss = total_train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)
    train_accuracy = correct_train_predictions / total_train_predictions
    train_accuracies.append(train_accuracy)

    # ---- validation ----
    _set_train_mode(False)
    total_val_loss = 0.0
    correct_val_predictions = 0
    total_val_predictions = 0
    recall_at_1 = 0

    with torch.no_grad():
        for images, captions in tqdm(val_dataloader, desc=f"Validation Epoch {epoch+1}/{num_epochs}", colour="blue"):
            image_embeddings, text_embeddings = _encode_batch(images, captions)
            val_loss = contrastive_loss(image_embeddings, text_embeddings)
            total_val_loss += val_loss.item()

            # (B, B) matrix: [i, j] = cos(image_i, caption_j); the diagonal holds matched pairs.
            cosine_sim = F.cosine_similarity(image_embeddings.unsqueeze(1), text_embeddings.unsqueeze(0), dim=-1)

            # In-batch Recall@1: does image i rank its own caption first among this batch's
            # captions? Candidates are the current batch, not the whole validation set.
            targets = torch.arange(cosine_sim.size(0), device=cosine_sim.device)
            recall_at_1 += (cosine_sim.argmax(dim=1) == targets).sum().item()

            # Threshold accuracy on matched pairs only, as in training. Thresholding the
            # full matrix would also count mismatched pairs and could exceed 1.0.
            correct_val_predictions += (cosine_sim.diagonal() > accuracy_threshold).sum().item()
            total_val_predictions += cosine_sim.size(0)

    avg_val_loss = total_val_loss / len(val_dataloader)
    val_losses.append(avg_val_loss)
    recall_at_1_score = recall_at_1 / len(val_dataloader.dataset)
    recall_at_1_scores.append(recall_at_1_score)
    val_accuracy = correct_val_predictions / total_val_predictions
    val_accuracies.append(val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Recall@1: {recall_at_1_score:.4f}")

    lr_scheduler.step(avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break

# Plot Training and Validation Losses
_plot_curves(
    [(train_losses, "Train Loss", "blue", "o"), (val_losses, "Val Loss", "green", "x")],
    ylabel="Loss",
    title="Training and Validation Losses",
)

# Plot Training and Validation Accuracies
_plot_curves(
    [(train_accuracies, "Train Accuracy", "red", "o"), (val_accuracies, "Val Accuracy", "purple", "x")],
    ylabel="Accuracy",
    title="Training and Validation Accuracies",
)

# Plot Recall@1 Scores
_plot_curves(
    [(recall_at_1_scores, "Recall@1", "orange", "s")],
    ylabel="Recall@1",
    title="Validation Recall@1 Scores",
)


In [None]:
# Specify the directory to save the models
save_directory = "/content/drive/MyDrive/Bangla Image dataset with caption"
# torch.save does not create missing folders; make sure the Drive folder exists first.
os.makedirs(save_directory, exist_ok=True)

# Save the models and other components.
# NOTE(review): these are the weights from the last epoch run, which may not be the
# best-validation-loss epoch when early stopping triggered.
components_to_save = {
    "vit_model.pth": vit_model,
    "text_model.pth": text_model,
    "image_projector.pth": image_projector,
    "text_projector.pth": text_projector,
    "optimizer.pth": optimizer,
    "lr_scheduler.pth": lr_scheduler,
}
for filename, component in components_to_save.items():
    torch.save(component.state_dict(), os.path.join(save_directory, filename))


# Inference: text-to-image retrieval

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm import tqdm

# Put every module used for retrieval into evaluation mode (disables dropout)
for _inference_module in (vit_model, image_projector, text_model, text_projector):
    _inference_module.eval()

# Function to create image embeddings from tensors preprocessed like the training data
def create_image_embeddings(images, device):
    """
    Generates projected embeddings for a batch of dataset image tensors.

    The tensors are expected to be preprocessed exactly as during training
    (resized and normalized to [-1, 1] by the dataset). They are therefore fed
    to the ViT directly as `pixel_values`, the same path the training loop uses.

    Args:
        images (torch.Tensor): Batch of images, normalized to [-1, 1].
        device (torch.device): Device to perform computation.

    Returns:
        torch.Tensor: Projected image embeddings, one row per image.
    """
    with torch.no_grad():
        # Do not run these tensors through vit_feature_extractor: its default
        # preprocessing (rescaling + normalization) would be applied a second time
        # and the model would see inputs that differ from training.
        pixel_values = images.to(device)
        image_embeddings = vit_model(pixel_values=pixel_values).last_hidden_state[:, 0, :]
        image_projection = image_projector(image_embeddings)
    return image_projection

# Precompute and store image embeddings for every image in a dataset
def precompute_image_embeddings(dataset, model, projector, device, batch_size=32):
    """
    Precomputes embeddings for all images in the dataset, in dataset order.

    Images are encoded in batches with the supplied `model` and `projector`. The
    dataset tensors are fed directly as ViT `pixel_values`, exactly as the
    training loop does.

    Args:
        dataset (Dataset): PyTorch Dataset whose items are (image_tensor, caption) pairs.
        model (torch.nn.Module): Vision transformer whose output has `last_hidden_state`.
        projector (torch.nn.Module): Projection head for image embeddings.
        device (torch.device): Device to perform computation.
        batch_size (int): Number of images encoded per forward pass.

    Returns:
        list[torch.Tensor]: One CPU embedding per image, index-aligned with `dataset`.
    """
    # Collate only the image tensors; captions are not needed here.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda items: torch.stack([item[0] for item in items]),
    )
    embeddings_list = []
    with torch.no_grad():
        for images in tqdm(loader, desc="Processing Training Images"):
            cls_tokens = model(pixel_values=images.to(device)).last_hidden_state[:, 0, :]
            # Iterating a (B, D) tensor yields its B rows; store them on the CPU.
            embeddings_list.extend(projector(cls_tokens).cpu())
    return embeddings_list

# Image retrieval function
def image_retrieval_function(input_query, image_embeddings_list, dataset, n=5, display=False):
    """
    Retrieves the top N images most similar to the input text query.

    Args:
        input_query (str): Text query for image retrieval.
        image_embeddings_list (list[torch.Tensor]): Precomputed image embeddings,
            index-aligned with `dataset`.
        dataset (Dataset): PyTorch Dataset whose items are (image_tensor, caption) pairs.
        n (int): Number of top images to retrieve; values <= 0 return no indices.
        display (bool): Whether to display retrieved images.

    Returns:
        numpy.ndarray: Indices of the top N most similar images, best first.
    """
    # Guard n <= 0 explicitly: the slice [-0:] below would otherwise return every index.
    if n <= 0 or len(image_embeddings_list) == 0:
        return np.array([], dtype=np.int64)

    with torch.no_grad():
        # Generate the (1, D) projected text embedding
        inputs = text_tokenizer(input_query, return_tensors='pt', padding="max_length", max_length=32, truncation=True).to(device)
        text_embeddings = text_model(**inputs).last_hidden_state[:, 0, :]
        text_projection = text_projector(text_embeddings)

        # Score every image in one batched call instead of one device copy per image.
        image_matrix = torch.stack(list(image_embeddings_list)).to(device)  # (N, D)
        similarity_scores = F.cosine_similarity(image_matrix, text_projection, dim=-1).cpu().numpy()

    # Get indices of the top N most similar images, highest score first
    top_indices = np.argsort(similarity_scores)[-n:][::-1]

    if display:
        # Display the top N images with their similarity scores
        for index in top_indices:
            image_tensor = dataset[index][0]  # Fetch image tensor
            # Assumes the dataset normalized with mean=std=0.5 (values in [-1, 1]);
            # map back to [0, 1] so imshow does not clip the image.
            image_to_show = ((image_tensor + 1) / 2).clamp(0, 1).permute(1, 2, 0)
            plt.imshow(image_to_show.numpy())
            plt.title(f"Similarity Score: {similarity_scores[index]:.4f}")
            plt.axis('off')
            plt.show()
    return top_indices

# Embed every training image once, so each query only costs one text forward pass.
image_embeddings_list_train = precompute_image_embeddings(
    train_dataset,
    vit_model,
    image_projector,
    device,
)

# Example usage.
# NOTE(review): the text encoder is presumably fine-tuned on Bangla captions, so an
# English query like this one is unlikely to retrieve meaningful matches.
query = "A description of the image"
retrieved_indices = image_retrieval_function(
    query,
    image_embeddings_list_train,
    train_dataset,
    n=5,
    display=True,
)
