In [None]:
# Install necessary packages
!pip install transformers sentencepiece torch torchvision albumentations timm




In [None]:

!pip install datasets



In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torch
import albumentations as A
from transformers import AutoTokenizer
from PIL import Image
import timm
from tqdm import tqdm

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import albumentations as A
from torch.utils.data import Dataset as TorchDataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import timm
import cv2
import itertools


In [None]:

# Step 1: make the dataset folders under /content/drive readable by mounting
# Google Drive (a no-op with a notice if it is already mounted).
from google.colab import drive

drive.mount(mountpoint='/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Imports
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import albumentations as A
from torch.utils.data import Dataset as TorchDataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import timm
import cv2
import itertools
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Configuration Class
class CFG:
    model_name = "resnet50"
    text_encoder_model = "csebuetnlp/banglabert"
    pretrained = True
    trainable = True
    batch_size = 16
    size = 224
    image_embedding = 2048
    text_embedding = 768
    projection_dim = 512
    max_length = 2
    temperature = 0.07
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    head_lr = 1e-3
    weight_decay = 1e-4
    patience = 2
    factor = 0.5
    device = "cuda" if torch.cuda.is_available() else "cpu"
    epochs = 7

In [None]:
# Define paths to the image folders and caption files.
# Each corpus is described by four values: its image directory, its caption
# JSON file, the JSON key naming the image file, and the JSON key holding the caption.

# Corpus 1: BNATURE
image_folder1, caption_file1, image_files_key1, caption_key1 = (
    '/content/drive/MyDrive/Bangla Image dataset with caption/BNATURE/Pictures',
    '/content/drive/MyDrive/Bangla Image dataset with caption/BNATURE/caption/captions.json',
    'caption_id',
    'bengali_caption',
)

# Corpus 2: Bangla Lekha 2.0
image_folder2, caption_file2, image_files_key2, caption_key2 = (
    '/content/drive/MyDrive/Bangla Image dataset with caption/Bangla Lekha 2.0/images',
    '/content/drive/MyDrive/Bangla Image dataset with caption/Bangla Lekha 2.0/captions.json',
    'filename',
    'caption',
)

# Corpus 3: Flickr8k with BAN-Cap Bengali captions
image_folder3, caption_file3, image_files_key3, caption_key3 = (
    '/content/drive/MyDrive/Bangla Image dataset with caption/Flickr8k_Dataset/Flicker8k_Dataset',
    '/content/drive/MyDrive/Bangla Image dataset with caption/Flickr8k_Dataset/BAN-Cap_captiondata.json',
    'caption_id',
    'bengali_caption',
)

In [None]:
# Function to load a single dataset
def load_dataset(image_folder, caption_file, image_files_key, caption_key, id_separator="#"):
    """Read one caption JSON file and pair each caption with an existing image path.

    NOTE: this name shadows ``datasets.load_dataset`` imported in an earlier cell.

    Args:
        image_folder: directory containing the image files.
        caption_file: path to a UTF-8 JSON file holding a list of record dicts.
        image_files_key: record key whose value names the image file. Ids such
            as ``"123.jpg#0"`` (image name plus a caption-index suffix) are cut
            at ``id_separator`` before the path is built; without this the
            path never exists and the record is silently dropped.
        caption_key: record key holding the caption. If the value is a list of
            captions, each caption becomes its own (image, caption) pair so
            downstream tokenization always sees a single string.
        id_separator: separator whose suffix is stripped from the image id;
            pass None (or "") to use the id verbatim.

    Returns:
        dict with parallel lists: "image" (full image paths) and "caption".
    """
    with open(caption_file, 'r', encoding='utf-8') as f:
        captions_data = json.load(f)

    images, captions = [], []
    skipped = 0
    for item in captions_data:
        image_name = item[image_files_key]
        if id_separator:
            image_name = image_name.split(id_separator)[0]
        image_file = os.path.join(image_folder, image_name)
        if not os.path.exists(image_file):  # keep only records whose image exists
            skipped += 1
            continue

        raw_caption = item[caption_key]
        item_captions = raw_caption if isinstance(raw_caption, (list, tuple)) else [raw_caption]
        for caption in item_captions:
            images.append(image_file)
            captions.append(caption)

    if skipped:
        # Surface drops instead of hiding them: a wrong key or path would
        # otherwise silently shrink the dataset.
        print(f"{caption_file}: skipped {skipped} record(s) whose image file was not found")

    return {"image": images, "caption": captions}

# Load every corpus (BNATURE, Bangla Lekha 2.0, Flickr8k BAN-Cap), one after another
data1 = load_dataset(image_folder=image_folder1, caption_file=caption_file1,
                     image_files_key=image_files_key1, caption_key=caption_key1)
data2 = load_dataset(image_folder=image_folder2, caption_file=caption_file2,
                     image_files_key=image_files_key2, caption_key=caption_key2)
data3 = load_dataset(image_folder=image_folder3, caption_file=caption_file3,
                     image_files_key=image_files_key3, caption_key=caption_key3)

# Concatenate the parallel image/caption lists, preserving corpus order
all_images = [path for part in (data1, data2, data3) for path in part["image"]]
all_captions = [text for part in (data1, data2, data3) for text in part["caption"]]


In [None]:
class CLIPDataset(TorchDataset):
    """Map-style dataset pairing each image file with its caption for CLIP training.

    Each item is a dict with:
      - "image": float tensor (C, H, W) produced by ``transforms``
      - "input_ids" / "attention_mask": 1-D long tensors of length ``CFG.max_length``
      - "caption": the caption string that was tokenized

    Args:
        image_filenames: image paths, parallel to ``captions``.
        captions: caption strings. An entry that is a list/tuple (several
            captions for one image) is reduced to its first caption: handing
            the list itself to the tokenizer batch-encodes it and yields 2-D ids.
        tokenizer: Hugging Face tokenizer, called on one caption at a time.
        transforms: albumentations-style callable used as ``transforms(image=...)["image"]``.

    Raises:
        ValueError: if ``image_filenames`` and ``captions`` differ in length.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        if len(image_filenames) != len(captions):
            raise ValueError(
                f"image_filenames ({len(image_filenames)}) and captions "
                f"({len(captions)}) must have the same length"
            )
        self.image_filenames = image_filenames
        self.captions = captions
        self.tokenizer = tokenizer
        self.transforms = transforms

    def __getitem__(self, idx):
        caption = self.captions[idx]
        if isinstance(caption, (list, tuple)):
            if not caption:
                raise ValueError(f"Caption list at index {idx} is empty")
            caption = caption[0]

        # Tokenize to a fixed length so every sample has identical 1-D shapes
        encoded = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=CFG.max_length,
            return_tensors=None,   # plain lists; converted to tensors below
        )
        input_ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(encoded["attention_mask"], dtype=torch.long)

        assert input_ids.ndim == 1, f"Input IDs have incorrect shape {input_ids.shape}"
        assert attention_mask.ndim == 1, f"Attention Mask has incorrect shape {attention_mask.shape}"

        # cv2.imread signals unreadable/missing files by returning None rather than raising
        image_path = self.image_filenames[idx]
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)["image"]

        return {
            "image": torch.tensor(image).permute(2, 0, 1).float(),  # HWC -> CHW
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "caption": caption,
        }

    def __len__(self):
        return len(self.captions)


In [None]:
def custom_collate_fn(batch):
    """Merge a list of CLIPDataset samples into one batch dict.

    Images are stacked. ``input_ids`` / ``attention_mask`` are right-padded
    with 0 to the longest sequence in the batch. Captions stay a Python list.
    If any sample has non-1-D ids or mask, the shapes of every sample are
    printed and the AssertionError is re-raised.

    Args:
        batch: list of dicts with "image", "input_ids", "attention_mask", "caption".

    Returns:
        dict with "image" (stacked), "input_ids" / "attention_mask"
        ([batch_size, seq_length]) and "caption" (list).
    """
    def column(key):
        return [sample[key] for sample in batch]

    try:
        stacked_images = torch.stack(column("image"))

        for pos, sample in enumerate(batch):
            assert sample["input_ids"].ndim == 1, f"Input IDs of item {pos} have incorrect shape {sample['input_ids'].shape}"
            assert sample["attention_mask"].ndim == 1, f"Attention Mask of item {pos} has incorrect shape {sample['attention_mask'].shape}"

        padded_ids, padded_mask = (
            pad_sequence(column(key), batch_first=True, padding_value=0)
            for key in ("input_ids", "attention_mask")
        )
        caption_list = column("caption")
    except AssertionError as err:
        print(f"Error in collate function: {err}")
        for pos, sample in enumerate(batch):
            print(f"Item {pos} Input IDs Shape: {sample['input_ids'].shape}")
            print(f"Item {pos} Attention Mask Shape: {sample['attention_mask'].shape}")
        raise err

    return {
        "image": stacked_images,
        "input_ids": padded_ids,
        "attention_mask": padded_mask,
        "caption": caption_list,
    }


In [None]:

# Transforms
def get_transforms(size=None):
    """Build the resize + normalize pipeline applied to every image.

    Args:
        size: output height and width in pixels; defaults to ``CFG.size``.

    Returns:
        ``A.Compose`` that resizes to (size, size), then normalizes with
        ``A.Normalize``'s default mean/std after scaling by a 255 max pixel value.
    """
    size = CFG.size if size is None else size
    # p=1.0 applies each step unconditionally; it replaces the `always_apply=True`
    # flag, which is deprecated in recent albumentations releases.
    return A.Compose([
        A.Resize(size, size, p=1.0),
        A.Normalize(max_pixel_value=255.0, p=1.0),
    ])

In [None]:
# Image Encoder
class ImageEncoder(nn.Module):
    """timm backbone with its classifier removed, returning globally average-pooled features.

    Args:
        model_name: timm model name.
        pretrained: whether timm loads pretrained weights.
        trainable: value applied to ``requires_grad`` for every backbone parameter.
    """

    def __init__(self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,        # drop the classification head
            global_pool="avg",    # pool spatial features into one vector per image
        )
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        features = self.model(x)
        return features

In [None]:
class TextEncoder(nn.Module):
    """Hugging Face transformer returning the hidden state of one token per caption.

    Args:
        model_name: Hugging Face model id passed to ``AutoModel``.
        pretrained: if True, load pretrained weights. If False, build the same
            architecture from its config with freshly initialised weights.
            (Previously this flag was accepted but ignored.)
        trainable: value assigned to ``requires_grad`` on every encoder parameter.
        verbose: print input shapes on every forward call. This is a debugging
            aid and is off by default because forward runs once per batch.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained,
                 trainable=CFG.trainable, verbose=False):
        super().__init__()
        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            from transformers import AutoConfig  # only needed for the random-init path
            self.model = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
        for p in self.model.parameters():
            p.requires_grad = trainable
        # Position whose hidden state represents the whole caption (0, e.g. the CLS token)
        self.target_token_idx = 0
        self.verbose = verbose

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token ids.

        Args:
            input_ids: token ids, expected ``[batch_size, seq_length]``.
            attention_mask: same shape as ``input_ids``.

        Returns:
            ``[batch_size, hidden_dim]`` hidden states at ``target_token_idx``.
        """
        if self.verbose:
            print(f"TextEncoder Input IDs Shape: {input_ids.shape}")
            print(f"TextEncoder Attention Mask Shape: {attention_mask.shape}")

        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state  # [batch_size, seq_length, hidden_dim]
        return last_hidden_state[:, self.target_token_idx, :]


In [None]:
# Projection Head
class ProjectionHead(nn.Module):
    """Map encoder features into the shared embedding space.

    Computes ``LayerNorm(Dropout(FC(GELU(P(x)))) + P(x))``, where ``P`` is a
    linear projection to ``projection_dim``. This is a residual MLP around the projection.

    Args:
        embedding_dim: width of the incoming features.
        projection_dim: width of the output embedding.
        dropout: dropout probability applied after the second linear layer.
    """

    def __init__(self, embedding_dim, projection_dim=CFG.projection_dim, dropout=0.1):
        super().__init__()
        # Creation order is kept fixed so parameter initialisation and
        # state_dict keys match previously saved checkpoints.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        residual = self.projection(x)
        hidden = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(hidden + residual)


In [None]:
class CLIPModel(nn.Module):
    """Dual-encoder CLIP model whose forward pass returns the contrastive training loss.

    Args:
        temperature: divides the text-image logits and scales the soft-target similarities.
        verbose: print batch tensor shapes on every forward call. This is a
            debugging aid and is off by default because forward runs once per batch.
    """

    def __init__(self, temperature=CFG.temperature, verbose=False):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=CFG.image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=CFG.text_embedding)
        self.temperature = temperature
        self.verbose = verbose

    def forward(self, batch):
        """Compute the symmetric soft-target contrastive loss for one batch.

        Args:
            batch: dict with "image", "input_ids" and "attention_mask" tensors
                (moved to ``CFG.device`` by the caller).

        Returns:
            Scalar tensor: the mean over the batch of the averaged text and image losses.
        """
        if self.verbose:
            print(f"Image Shape: {batch['image'].shape}")
            print(f"Input IDs Shape: {batch['input_ids'].shape}")
            print(f"Attention Mask Shape: {batch['attention_mask'].shape}")

        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # logits[i, j]: similarity of text i with image j
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T

        # Soft targets from averaged in-modality similarities.
        # NOTE(review): logits divide by temperature while targets multiply by it;
        # the two only agree when temperature == 1. Confirm this is intended.
        targets = F.softmax((images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0
        return loss.mean()


In [None]:
# Cross Entropy
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability) targets.

    Args:
        preds: logits, normalised with log-softmax over the last dim.
        targets: target weights, same shape as ``preds``.
        reduction: "mean" averages the per-row losses; any other value
            returns them unreduced.

    Returns:
        Per-row losses (summed over dim 1), or their mean.
    """
    per_row = -(targets * F.log_softmax(preds, dim=-1)).sum(dim=1)
    if reduction == "mean":
        return per_row.mean()
    return per_row


In [None]:
# NOTE(review): this re-defines get_transforms from an earlier cell and
# silently shadows it; keep the two identical or delete one of them.
def get_transforms(size=None):
    """Build the resize + normalize pipeline applied to every image.

    Args:
        size: output height and width in pixels; defaults to ``CFG.size``.

    Returns:
        ``A.Compose`` that resizes to (size, size), then normalizes with
        ``A.Normalize``'s default mean/std after scaling by a 255 max pixel value.
    """
    size = CFG.size if size is None else size
    # p=1.0 applies each step unconditionally; it replaces the `always_apply=True`
    # flag, which is deprecated in recent albumentations releases.
    return A.Compose([
        A.Resize(size, size, p=1.0),  # Resize all images to size x size
        A.Normalize(max_pixel_value=255.0, p=1.0),
    ])


In [None]:
# Inspect Dataset Output
# Build the tokenizer and transforms here: otherwise they only exist inside
# main() (or in a later cell), and a fresh kernel raises NameError.
tokenizer = AutoTokenizer.from_pretrained(CFG.text_encoder_model)
transforms = get_transforms()
dataset = CLIPDataset(all_images, all_captions, tokenizer, transforms)
for i in range(min(5, len(dataset))):  # Check (up to) the first 5 samples
    sample = dataset[i]
    print(f"Sample {i} - Input IDs shape: {sample['input_ids'].shape}, Attention Mask shape: {sample['attention_mask'].shape}")


Sample 0 - Input IDs shape: torch.Size([2]), Attention Mask shape: torch.Size([2])
Sample 1 - Input IDs shape: torch.Size([2]), Attention Mask shape: torch.Size([2])
Sample 2 - Input IDs shape: torch.Size([2]), Attention Mask shape: torch.Size([2])
Sample 3 - Input IDs shape: torch.Size([2]), Attention Mask shape: torch.Size([2])
Sample 4 - Input IDs shape: torch.Size([2]), Attention Mask shape: torch.Size([2])


In [None]:
# Test Dataset: per-sample ids and mask should both be 1-D of length CFG.max_length
for sample_idx in range(5):
    item = dataset[sample_idx]
    ids_shape = item['input_ids'].shape
    mask_shape = item['attention_mask'].shape
    print(f"Sample {sample_idx} Input IDs Shape: {ids_shape}")
    print(f"Sample {sample_idx} Attention Mask Shape: {mask_shape}")


Sample 0 Input IDs Shape: torch.Size([2])
Sample 0 Attention Mask Shape: torch.Size([2])
Sample 1 Input IDs Shape: torch.Size([2])
Sample 1 Attention Mask Shape: torch.Size([2])
Sample 2 Input IDs Shape: torch.Size([2])
Sample 2 Attention Mask Shape: torch.Size([2])
Sample 3 Input IDs Shape: torch.Size([2])
Sample 3 Attention Mask Shape: torch.Size([2])
Sample 4 Input IDs Shape: torch.Size([2])
Sample 4 Attention Mask Shape: torch.Size([2])


In [None]:

# Test DataLoader
# `train_loader` is local to main(), so referencing it here fails on a fresh
# kernel. Build a preview loader over `dataset` with the training collate_fn instead.
preview_loader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=False, collate_fn=custom_collate_fn)
batch = next(iter(preview_loader))
print(f"Batch Input IDs Shape: {batch['input_ids'].shape}")  # Expected: [batch_size, CFG.max_length]
print(f"Batch Attention Mask Shape: {batch['attention_mask'].shape}")  # Expected: [batch_size, CFG.max_length]


Batch Index 0: Image Shape torch.Size([3, 224, 224]), Input IDs Length 70
Batch Index 1: Image Shape torch.Size([3, 224, 224]), Input IDs Length 70
Batch Index 2: Image Shape torch.Size([3, 224, 224]), Input IDs Length 70
Batch Index 3: Image Shape torch.Size([3, 224, 224]), Input IDs Length 70
Batch Input IDs Shape: torch.Size([4, 70])
Batch Attention Mask Shape: torch.Size([4, 70])


In [None]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, step, verbose=False):
    """Run one training epoch (mixed precision on CUDA) and return the mean batch loss.

    Args:
        model: module whose forward(batch) returns a scalar loss tensor.
        train_loader: iterable of batch dicts. The "caption" entry is dropped
            and every other value is moved to ``CFG.device``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: stepped after every batch when ``step == "batch"``.
        step: "batch" to step the scheduler per batch; any other value leaves
            scheduling to the caller.
        verbose: print the tensor shapes of every batch (debugging aid).

    Returns:
        float: average of the per-batch losses.
    """
    model.train()
    # Autocast/GradScaler only apply on CUDA. Derive the device type from
    # CFG.device instead of hard-coding 'cuda', so CPU runs use full precision
    # without AMP warnings.
    device_type = "cuda" if str(CFG.device).startswith("cuda") else "cpu"
    use_amp = device_type == "cuda"
    scaler = torch.amp.GradScaler(device_type, enabled=use_amp)
    loss_meter = 0.0

    for batch in tqdm(train_loader, total=len(train_loader)):
        if verbose:
            print(f"Batch Image Shape: {batch['image'].shape}")  # Expected: [batch_size, 3, CFG.size, CFG.size]
            print(f"Batch Input IDs Shape: {batch['input_ids'].shape}")  # Expected: [batch_size, CFG.max_length]
            print(f"Batch Attention Mask Shape: {batch['attention_mask'].shape}")  # Expected: [batch_size, CFG.max_length]

        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}

        with torch.amp.autocast(device_type=device_type, enabled=use_amp):
            loss = model(batch)

        optimizer.zero_grad()
        scaler.scale(loss).backward()  # scale() is a pass-through when AMP is disabled
        scaler.step(optimizer)
        scaler.update()

        if step == "batch":
            lr_scheduler.step()

        loss_meter += loss.item()

    return loss_meter / len(train_loader)




def valid_epoch(model, valid_loader):
    """Evaluate ``model`` over ``valid_loader`` without gradients.

    Args:
        model: module whose forward(batch) returns a scalar loss tensor.
        valid_loader: iterable of batch dicts. The "caption" entry is dropped
            and every other value is moved to ``CFG.device``.

    Returns:
        float: mean of the per-batch losses.
    """
    model.eval()
    total = 0
    with torch.no_grad():
        for raw_batch in tqdm(valid_loader, total=len(valid_loader)):
            on_device = {
                name: tensor.to(CFG.device)
                for name, tensor in raw_batch.items()
                if name != "caption"
            }
            total += model(on_device).item()
    return total / len(valid_loader)

# Main Function
def main(seed=42):
    """Train the Bangla CLIP model and keep the checkpoint with the best validation loss.

    Uses the module-level ``all_images`` / ``all_captions`` lists and splits
    them 80/20 into train/validation. It trains for ``CFG.epochs`` epochs and
    saves the best state_dict to ``best_clip_model_bangla.pt``.

    Args:
        seed: seed for the train/validation split. Fixing it makes the split
            reproducible, so an inference cell can rebuild exactly the held-out
            images. Previously the split was unseeded and unrecoverable.
    """
    tokenizer = AutoTokenizer.from_pretrained(CFG.text_encoder_model)
    transforms = get_transforms()
    dataset = CLIPDataset(all_images, all_captions, tokenizer, transforms)

    # Split dataset (seeded so the split can be reproduced later)
    train_size = int(0.8 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(seed)
    )

    # DataLoader with custom collate_fn
    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=4, collate_fn=custom_collate_fn)

    model = CLIPModel().to(CFG.device)
    # Separate learning rates per component; only the projection heads get weight decay
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()),
         "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.0)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=CFG.patience, factor=CFG.factor)
    step = "epoch"  # ReduceLROnPlateau is stepped once per epoch with the validation loss

    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        print(f"Epoch {epoch + 1}/{CFG.epochs}")
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        print(f"Train Loss: {train_loss:.4f}")
        valid_loss = valid_epoch(model, valid_loader)
        print(f"Validation Loss: {valid_loss:.4f}")

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), "best_clip_model_bangla.pt")
            print("Best model saved!")

        if step == "epoch":
            lr_scheduler.step(valid_loss)

# Execute the main function
if __name__ == "__main__":
    main()

Epoch 1/7


  0%|          | 0/2413 [00:00<?, ?it/s]


AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 50, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataset.py", line 420, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataset.py", line 420, in <listcomp>
    return [self.dataset[self.indices[idx]] for idx in indices]
  File "<ipython-input-281-f3c98c09fde2>", line 23, in __getitem__
    assert input_ids.ndim == 1, f"Input IDs have incorrect shape {input_ids.shape}"
AssertionError: Input IDs have incorrect shape torch.Size([2, 2])


# Inference

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm

# Assumes CFG, CLIPDataset, CLIPModel, get_transforms, custom_collate_fn and the
# combined `all_images` / `all_captions` lists from the cells above.
#
# Rebuild exactly the dataset main() trained on and split it with a fixed
# seed. This way the embeddings below come from held-out images, provided
# main() split with the same seeded generator. The previous version read an
# undefined `caption_file`, used bare filenames that cv2 could not open, and
# re-split randomly, so "validation" images overlapped the training set.
SPLIT_SEED = 42
MODEL_PATH = "best_clip_model_bangla.pt"  # same relative path main() saves to

tokenizer = AutoTokenizer.from_pretrained(CFG.text_encoder_model)
transforms = get_transforms()
dataset = CLIPDataset(all_images, all_captions, tokenizer, transforms)

train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
_, valid_dataset = torch.utils.data.random_split(
    dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Full image paths in the same order as the rows of `image_embeddings` below;
# find_matches maps its top-k indices back through this list.
image_files = [dataset.image_filenames[i] for i in valid_dataset.indices]


def get_image_embeddings(valid_dataset, model_path):
    """Load a trained CLIPModel and embed every image in ``valid_dataset``.

    Args:
        valid_dataset: dataset (or Subset) yielding CLIPDataset-style samples.
        model_path: path to a state_dict saved by ``main()``.

    Returns:
        (model, embeddings): the model in eval mode, and the projected image
        embeddings concatenated in dataset order on ``CFG.device``.
    """
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False,
                              num_workers=4, collate_fn=custom_collate_fn)

    model = CLIPModel().to(CFG.device)
    # weights_only=True: the checkpoint is a plain state_dict, so refuse to
    # unpickle arbitrary objects from the file.
    model.load_state_dict(torch.load(model_path, map_location=CFG.device, weights_only=True))
    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(CFG.device))
            valid_image_embeddings.append(model.image_projection(image_features))

    return model, torch.cat(valid_image_embeddings)


# Perform inference to get image embeddings from the validation set
model, image_embeddings = get_image_embeddings(valid_dataset, MODEL_PATH)


In [None]:
import matplotlib.pyplot as plt  # Add this import


In [None]:
def find_matches(model, image_embeddings, query, image_files, n=9,
                 image_folder=None, tokenizer=None, verbose=False):
    """Plot the ``n`` images whose embeddings are most similar to a text query.

    Args:
        model: trained CLIPModel; its ``text_encoder`` and ``text_projection`` embed the query.
        image_embeddings: projected image embeddings, one row per entry of
            ``image_files`` (row i must describe image_files[i]).
        query: text prompt.
        image_files: image paths aligned with ``image_embeddings``.
        n: number of matches to show, capped at the number of images.
        image_folder: optional directory joined in front of each entry of
            ``image_files``. Leave it None when the entries are already full
            paths. (This replaces a read of an undefined global ``image_folder``.)
        tokenizer: optional pre-loaded tokenizer, to avoid reloading it for every query.
        verbose: also print the encoded query and the normalised text embedding.
    """
    import math

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(CFG.text_encoder_model)

    encoded_query = tokenizer([query], return_tensors="pt", padding=True,
                              truncation=True, max_length=CFG.max_length)
    if verbose:
        print(f"Encoded Query: {encoded_query}")

    # return_tensors="pt" already yields tensors: move them to the device rather
    # than re-wrapping with torch.tensor(), which copies and emits a warning.
    batch = {key: values.to(CFG.device) for key, values in encoded_query.items()}

    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # Cosine similarity: L2-normalise both sides before the dot product
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    if verbose:
        print(f"Text Embeddings: {text_embeddings_n}")

    dot_similarity = text_embeddings_n @ image_embeddings_n.T
    k = min(n, image_embeddings_n.shape[0])  # topk fails if k exceeds the number of images
    if k <= 0:
        print("No images to match against.")
        return
    values, indices = torch.topk(dot_similarity.squeeze(0), k)
    matches = [image_files[idx] for idx in indices.tolist()]
    print(f"Top match values: {values}")

    # Near-square grid sized to k (3x3 for the default n=9)
    ncols = math.ceil(math.sqrt(k))
    nrows = math.ceil(k / ncols)
    _, axes = plt.subplots(nrows, ncols, figsize=(10, 10), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for match, ax in zip(matches, axes.flat):
        path = match if image_folder is None else os.path.join(image_folder, match)
        image = cv2.imread(path)
        if image is None:  # cv2.imread returns None instead of raising on unreadable files
            ax.set_title(f"missing: {os.path.basename(path)}", fontsize=8)
            continue
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    plt.show()

In [None]:
prompt = "কুকুর "  # Example prompt: "কুকুর" means "dog"
find_matches(model, image_embeddings, prompt, image_files)


In [None]:
from google.colab import files

# Checkpoint written by main(); the path is relative to the notebook's working
# directory. Change it if you saved the model under another name.
model_path = "best_clip_model_bangla.pt"

# Send the file to the browser as a download
files.download(filename=model_path)
