In [19]:
import os
import cv2
import gc
import json
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer


In [20]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForPreTraining

In [21]:
# Step 1: Mount Google Drive to access the dataset.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [22]:
import os
import json
import cv2
import torch
import numpy as np
import pandas as pd
import albumentations as A
import torch.nn as nn
import torch.nn.functional as F
import timm
import itertools
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoConfig
import matplotlib.pyplot as plt

# Add get_lr function
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

# First, update the CFG class with optimized parameters
class CFG:
    debug = False
    # Dataset paths remain the same
    image_path1 = '/content/drive/MyDrive/Bangla Image dataset with caption/Flickr8k_Dataset/Flicker8k_Dataset'
    captions_path1 = '/content/drive/MyDrive/Bangla Image dataset with caption/Flickr8k_Dataset'
    image_path2 = '/content/drive/MyDrive/Bangla Image dataset with caption/BNATURE/Pictures'
    captions_path2 = '/content/drive/MyDrive/Bangla Image dataset with caption/BNATURE/caption/captions.json'
    image_path3 = '/content/drive/MyDrive/Bangla Image dataset with caption/Bangla Lekha 2.0/images'
    captions_path3 = '/content/drive/MyDrive/Bangla Image dataset with caption/Bangla Lekha 2.0/captions.json'

    # Optimized training parameters
    batch_size = 64  # Increased batch size
    gradient_accumulation_steps = 2  # Reduced accumulation steps
    num_workers = 4  # Increased workers
    pin_memory = True
    mixed_precision = True

    # Optimized learning rates
    image_encoder_lr = 2e-4  # Increased learning rate
    text_encoder_lr = 2e-4   # Increased learning rate
    head_lr = 5e-4          # Increased learning rate
    weight_decay = 0.01

    # Early stopping and scheduler settings
    patience = 2
    factor = 0.7
    epochs = 5
    warmup_ratio = 0.05

    # Model parameters
    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "csebuetnlp/banglabert"
    text_embedding = 768
    text_tokenizer = "csebuetnlp/banglabert"
    max_length = 128  # Reduced max length
    pretrained = True
    trainable = True
    temperature = 0.07  # Adjusted temperature
    size = 224
    projection_dim = 256
    dropout = 0.1

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Model parameters
    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "csebuetnlp/banglabert"
    text_embedding = 768
    text_tokenizer = "csebuetnlp/banglabert"
    max_length = 200
    pretrained = True
    trainable = True
    temperature = 1.0
    size = 224
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1

In [23]:
def load_captions():
    """Load captions with improved validation"""
    captions_list = []

    # Load Flickr8k dataset
    try:
        with open(os.path.join(CFG.captions_path1, 'BAN-Cap_captiondata.json'), 'r', encoding='utf-8') as f:
            captions_data1 = json.load(f)

        for entry in captions_data1:
            if isinstance(entry, dict) and 'caption_id' in entry and 'bengali_caption' in entry:
                filename = str(entry['caption_id']).split('#')[0]
                caption = str(entry['bengali_caption'])
                if caption and filename:
                    captions_list.append({
                        "image": filename.strip(),
                        "caption": caption.strip()
                    })
    except Exception as e:
        print(f"Error loading Flickr8k dataset: {str(e)}")

    # Load BNATURE dataset
    try:
        with open(CFG.captions_path2, 'r', encoding='utf-8') as f:
            captions_data2 = json.load(f)

        for entry in captions_data2:
            if isinstance(entry, dict) and 'caption_id' in entry and 'bengali_caption' in entry:
                filename = str(entry['caption_id'])
                caption = str(entry['bengali_caption'])
                if caption and filename:
                    captions_list.append({
                        "image": filename.strip(),
                        "caption": caption.strip()
                    })
    except Exception as e:
        print(f"Error loading BNATURE dataset: {str(e)}")

    # Load Bangla Lekha dataset with improved handling
    try:
        with open(CFG.captions_path3, 'r', encoding='utf-8') as f:
            captions_data3 = json.load(f)

        if isinstance(captions_data3, list):
            for entry in captions_data3:
                if isinstance(entry, dict) and 'filename' in entry and 'caption' in entry:
                    filename = str(entry['filename'])
                    caption = str(entry['caption'])
                    if caption and filename:
                        captions_list.append({
                            "image": filename.strip(),
                            "caption": caption.strip()
                        })
    except Exception as e:
        print(f"Error loading Bangla Lekha dataset: {str(e)}")

    df = pd.DataFrame(captions_list)
    df = df.dropna()
    df = df.drop_duplicates()
    df['id'] = df.index // 5

    print(f"Loaded {len(df)} valid caption entries")
    return df

In [24]:
class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            list(captions),
            padding='max_length',
            truncation=True,
            max_length=CFG.max_length,
            return_tensors='pt'
        )
        self.transforms = transforms

        self.valid_indices = []
        for idx in range(len(self.image_filenames)):
            try:
                image_found = False
                for path in [CFG.image_path1, CFG.image_path2, CFG.image_path3]:
                    if os.path.exists(os.path.join(path, self.image_filenames[idx])):
                        image_found = True
                        break

                if image_found:
                    self.valid_indices.append(idx)
            except Exception as e:
                continue

        print(f"Found {len(self.valid_indices)} valid images out of {len(image_filenames)}")

    def __getitem__(self, idx):
        try:
            actual_idx = self.valid_indices[idx]

            item = {
                'input_ids': self.encoded_captions['input_ids'][actual_idx],
                'attention_mask': self.encoded_captions['attention_mask'][actual_idx],
            }

            image_path = None
            for path in [CFG.image_path1, CFG.image_path2, CFG.image_path3]:
                if os.path.exists(os.path.join(path, self.image_filenames[actual_idx])):
                    image_path = path
                    break

            if image_path is None:
                raise FileNotFoundError(f"Image {self.image_filenames[actual_idx]} not found in any path")

            image = cv2.imread(os.path.join(image_path, self.image_filenames[actual_idx]))
            if image is None:
                raise ValueError(f"Failed to load image: {self.image_filenames[actual_idx]}")

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.transforms(image=image)['image']
            item['image'] = torch.tensor(image).permute(2, 0, 1).float()
            item['caption'] = self.captions[actual_idx]

            return item
        except Exception as e:
            print(f"Error processing item {idx}: {str(e)}")
            raise e

    def __len__(self):
        return len(self.valid_indices)

In [25]:


def build_loaders(dataframe, tokenizer, mode):
    """
    Build data loaders with error handling
    """
    transforms = get_transforms(mode=mode)

    try:
        dataset = CLIPDataset(
            image_filenames=dataframe["image"].values,
            captions=dataframe["caption"].values,
            tokenizer=tokenizer,
            transforms=transforms
        )

        # Custom collate function to handle potential None values
        def collate_fn(batch):
            # Filter out None values
            batch = [item for item in batch if item is not None]
            if len(batch) == 0:
                raise RuntimeError("Empty batch after filtering")

            return {
                'image': torch.stack([item['image'] for item in batch]),
                'input_ids': torch.stack([item['input_ids'] for item in batch]),
                'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
                'caption': [item['caption'] for item in batch]
            }

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            num_workers=CFG.num_workers,
            shuffle=True if mode == "train" else False,
            collate_fn=collate_fn,
            drop_last=True  # Drop incomplete batches
        )

        return dataloader

    except Exception as e:
        print(f"Error building dataloader: {str(e)}")
        raise e

In [26]:

class ImageEncoder(nn.Module):
    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)

from transformers import AutoTokenizer, AutoModel, AutoConfig

class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            self.model = AutoModel(config=AutoConfig.from_pretrained(model_name))

        for p in self.model.parameters():
            p.requires_grad = trainable
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x


In [27]:
class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = F.cross_entropy(logits, targets, reduction='none')
        images_loss = F.cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0
        return loss.mean()


In [28]:

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

In [29]:

def get_transforms(mode="train"):
    if mode == "train":
        return A.Compose([
            A.Resize(CFG.size, CFG.size, always_apply=True),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ])
    else:
        return A.Compose([
            A.Resize(CFG.size, CFG.size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ])


In [30]:

def make_train_valid_dfs():
    dataframe = load_captions()
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe

In [31]:

def build_loaders(dataframe, tokenizer, mode):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader



In [32]:
class AvgMeter:
    """
    Computes and stores the average and current value
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [33]:
import torch
import itertools
from tqdm import tqdm
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np

In [34]:
import torch
import itertools
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
import matplotlib.pyplot as plt
import numpy as np
import gc

# Memory optimizations
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.backends.cuda.enable_mem_efficient_sdp(True)

# Configure PyTorch memory allocator
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,garbage_collection_threshold:0.8,expandable_segments:True'


In [35]:
import torch
import itertools
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
import matplotlib.pyplot as plt
import numpy as np

# Set memory and speed optimizations
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.backends.cuda.enable_mem_efficient_sdp(True)

# Configure PyTorch memory allocator
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512,expandable_segments:True'

def compute_cosine_similarity(embeddings_a, embeddings_b, temperature=0.07):
    # Process in chunks to save memory
    chunk_size = 256
    num_chunks = (embeddings_a.size(0) + chunk_size - 1) // chunk_size
    similarity_chunks = []

    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, embeddings_a.size(0))
        chunk_a = embeddings_a[start_idx:end_idx]

        # Normalize chunks
        chunk_a = torch.nn.functional.normalize(chunk_a, p=2, dim=-1)
        chunk_b = torch.nn.functional.normalize(embeddings_b, p=2, dim=-1)

        # Compute similarity for chunk
        chunk_sim = torch.mm(chunk_a, chunk_b.t()) / temperature
        similarity_chunks.append(chunk_sim)

    return torch.cat(similarity_chunks, dim=0)

def compute_recall_at_k(similarities, k):
    argsort = torch.argsort(similarities, dim=-1, descending=True)
    diagonal = torch.arange(similarities.size(0), device=similarities.device)
    topk_indices = argsort[:, :k]
    recall_at_k = (topk_indices == diagonal.view(-1, 1)).any(dim=-1).float().mean()
    return recall_at_k.item()

In [None]:
def train(model, train_loader, optimizer, lr_scheduler, scaler, epoch):
    model.train()
    loss_meter = AvgMeter()
    similarity_meter = AvgMeter()

    # Dynamic batch processing based on epoch
    max_steps = len(train_loader) if epoch > 0 else len(train_loader) // 2

    tqdm_object = tqdm(train_loader, total=max_steps)

    for batch_idx, batch in enumerate(tqdm_object):
        if batch_idx >= max_steps:
            break

        batch = {k: v.to(CFG.device, non_blocking=True) for k, v in batch.items() if k != "caption"}

        # Gradient accumulation for larger effective batch size
        accumulation_steps = 4 if epoch == 0 else 2

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            loss = model(batch)
            # Calculate similarities for monitoring
            with torch.no_grad():
                image_features = model.image_encoder(batch["image"])
                text_features = model.text_encoder(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"]
                )
                image_embeddings = model.image_projection(image_features)
                text_embeddings = model.text_projection(text_features)

                similarity = F.cosine_similarity(
                    image_embeddings.unsqueeze(1),
                    text_embeddings.unsqueeze(0),
                    dim=-1
                ).mean()

            loss = loss / accumulation_steps

        scaler.scale(loss).backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            lr_scheduler.step()

        loss_meter.update(loss.item() * accumulation_steps, batch["image"].size(0))
        similarity_meter.update(similarity.item(), batch["image"].size(0))

        if batch_idx % 50 == 0:
            torch.cuda.empty_cache()

        tqdm_object.set_postfix(
            loss=f"{loss_meter.avg:.4f}",
            similarity=f"{similarity_meter.avg:.4f}",
            lr=f"{get_lr(optimizer):.6f}"
        )

    return loss_meter.avg, similarity_meter.avg

def validate(model, valid_loader):
    model.eval()
    loss_meter = AvgMeter()
    similarity_meter = AvgMeter()

    max_val_steps = len(valid_loader) // 2  # Evaluate on half the validation set

    image_embeddings_list = []
    text_embeddings_list = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(valid_loader)):
            if batch_idx >= max_val_steps:
                break

            batch = {k: v.to(CFG.device, non_blocking=True) for k, v in batch.items() if k != "caption"}

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                loss = model(batch)

                image_features = model.image_encoder(batch["image"])
                text_features = model.text_encoder(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"]
                )
                image_embeddings = model.image_projection(image_features)
                text_embeddings = model.text_projection(text_features)

                similarity = F.cosine_similarity(
                    image_embeddings.unsqueeze(1),
                    text_embeddings.unsqueeze(0),
                    dim=-1
                ).mean()

                image_embeddings_list.append(image_embeddings.cpu())
                text_embeddings_list.append(text_embeddings.cpu())

            loss_meter.update(loss.item(), batch["image"].size(0))
            similarity_meter.update(similarity.item(), batch["image"].size(0))

            if batch_idx % 20 == 0:
                torch.cuda.empty_cache()

    image_embeddings = torch.cat(image_embeddings_list)
    text_embeddings = torch.cat(text_embeddings_list)

    i2t_similarities = compute_cosine_similarity(image_embeddings, text_embeddings)
    t2i_similarities = i2t_similarities.t()

    metrics = {
        'val_similarity': similarity_meter.avg,
        'val_loss': loss_meter.avg
    }

    for k in [1, 5, 10]:
        metrics[f'image_to_text_recall@{k}'] = compute_recall_at_k(i2t_similarities, k)
        metrics[f'text_to_image_recall@{k}'] = compute_recall_at_k(t2i_similarities, k)

    return metrics

def plot_training_progress(metrics_history, epoch):
    plt.figure(figsize=(20, 10))

    # Plot 1: Losses
    plt.subplot(2, 2, 1)
    plt.plot(metrics_history['train_loss'], label='Train Loss', marker='o')
    plt.plot(metrics_history['val_loss'], label='Val Loss', marker='x')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot 2: Recalls
    plt.subplot(2, 2, 2)
    for k in [1, 5, 10]:
        plt.plot(
            [m[f'image_to_text_recall@{k}'] for m in metrics_history['val_metrics']],
            label=f'R@{k}',
            marker='o'
        )
    plt.title('Image-to-Text Recall')
    plt.xlabel('Epoch')
    plt.ylabel('Recall')
    plt.legend()
    plt.grid(True)

    # Plot 3: Similarities
    plt.subplot(2, 2, 3)
    plt.plot(metrics_history['train_similarity'], label='Train Similarity', marker='o')
    plt.plot([m['val_similarity'] for m in metrics_history['val_metrics']],
             label='Val Similarity', marker='x')
    plt.title('Cosine Similarities')
    plt.xlabel('Epoch')
    plt.ylabel('Similarity')
    plt.legend()
    plt.grid(True)

    # Plot 4: Learning Rate
    plt.subplot(2, 2, 4)
    plt.plot(metrics_history['learning_rates'], label='Learning Rate', marker='o')
    plt.title('Learning Rate Progress')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f'training_progress_epoch_{epoch+1}.png')
    plt.close()

def main():
    print("Starting optimized training...")

    train_df, valid_df = make_train_valid_dfs()
    tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)

    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel(
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding
    ).to(CFG.device)

    # Optimizer setup with warmup and cosine schedule
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(),
            model.text_projection.parameters()
        ), "lr": CFG.head_lr}
    ]

    optimizer = torch.optim.AdamW(params, weight_decay=CFG.weight_decay)
    scaler = torch.amp.GradScaler()

    num_training_steps = len(train_loader) * CFG.epochs
    num_warmup_steps = int(num_training_steps * 0.1)  # 10% warmup

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    # Initialize tracking
    best_val_loss = float('inf')
    best_recall = 0
    early_stopping_counter = 0

    metrics_history = {
        'train_loss': [],
        'val_loss': [],
        'train_similarity': [],
        'val_metrics': [],
        'learning_rates': []
    }

    for epoch in range(CFG.epochs):
        print(f"\nEpoch {epoch + 1}")

        # Training
        train_loss, train_similarity = train(
            model, train_loader, optimizer, scheduler, scaler, epoch
        )

        # Validation
        val_metrics = validate(model, valid_loader)

        # Update metrics history
        metrics_history['train_loss'].append(train_loss)
        metrics_history['val_loss'].append(val_metrics['val_loss'])
        metrics_history['train_similarity'].append(train_similarity)
        metrics_history['val_metrics'].append(val_metrics)
        metrics_history['learning_rates'].append(get_lr(optimizer))

        # Plot progress
        plot_training_progress(metrics_history, epoch)

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_metrics['val_loss']:.4f}")
        print(f"Train Similarity: {train_similarity:.4f}")
        print(f"Val Similarity: {val_metrics['val_similarity']:.4f}")
        for k in [1, 5, 10]:
            print(f"R@{k}: {val_metrics[f'image_to_text_recall@{k}']:.4f}")

        # Save best model
        current_recall = val_metrics['image_to_text_recall@5']
        if val_metrics['val_loss'] < best_val_loss or current_recall > best_recall:
            best_val_loss = min(val_metrics['val_loss'], best_val_loss)
            best_recall = max(current_recall, best_recall)

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'best_recall': best_recall,
                'metrics_history': metrics_history
            }, f'best_model_epoch_{epoch+1}.pth')

            print(f"Saved best model - Val Loss: {best_val_loss:.4f}, R@5: {best_recall:.4f}")
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= CFG.patience:
                print("Early stopping triggered")
                break

        # Memory cleanup
        torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    main()

Starting optimized training...
Loaded 88641 valid caption entries


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/586 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/528k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Found 70919 valid images out of 70919




Found 17722 valid images out of 17722


model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]


Epoch 1


100%|██████████| 554/554 [35:15<00:00,  3.82s/it, loss=10.1857, lr=0.000050, similarity=0.1569]
 50%|████▉     | 138/277 [04:05<04:07,  1.78s/it]


Train Loss: 10.1857
Val Loss: 3.8020
Train Similarity: 0.1569
Val Similarity: 0.2278
R@1: 0.0016
R@5: 0.0066
R@10: 0.0112
Saved best model - Val Loss: 3.8020, R@5: 0.0066

Epoch 2


 98%|█████████▊| 1084/1109 [25:00<00:35,  1.42s/it, loss=2.4700, lr=0.000195, similarity=0.2785]

In [None]:
import torch

# Define the Google Drive path for saving the model
save_path = "/content/drive/MyDrive/Bangla Image dataset with caption/banglaclipcombinedfinalnowok.pt"

# Define your model class (ensure this matches the architecture used during training)
model = CLIPModel().to(CFG.device)  # Replace CLIPModel with your actual model class

# Load the trained model from the saved state dictionary
checkpoint_path = "/content/best_model.pth"
model.load_state_dict(torch.load(checkpoint_path))

# Save the model's state dictionary to the specified Google Drive path
torch.save(model.state_dict(), save_path)

print(f"Model saved successfully at: {save_path}")


#Interface

In [None]:
def get_image_embeddings(dataframe, model_path):
    """
    Loads a model and generates image embeddings for the provided dataset.

    Parameters:
    - dataframe: DataFrame containing image file names.
    - model_path: Path to the saved model file.

    Returns:
    - model: Loaded CLIPModel instance.
    - image_embeddings: Tensor of image embeddings for the dataset.
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()

    # Prepare data loader
    transforms = get_transforms(mode="valid")
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=CFG.batch_size, num_workers=CFG.num_workers, shuffle=False
    )

    # Generate image embeddings
    image_embeddings_list = []
    with torch.no_grad():
        for batch in tqdm(dataloader, colour="yellow", desc="Generating image embeddings"):
            images = batch["image"].to(CFG.device)
            image_features = model.image_encoder(images)
            image_embeddings = model.image_projection(image_features)
            image_embeddings_list.append(image_embeddings)

    # Concatenate all embeddings
    image_embeddings = torch.cat(image_embeddings_list, dim=0)
    return model, image_embeddings


# Interface Function to Use find_matches
if __name__ == "__main__":
    # Ensure the validation DataFrame and model are ready
    _, valid_df = make_train_valid_dfs()

    # Generate the model and image embeddings
    model, image_embeddings = get_image_embeddings(valid_df, "best_model.pth")

    # Define the Bangla text query and display matches
    query = "কক্সবাজারের সমুদ্র সৈকত"  # Bangla query for "cycle"
    find_matches(
        model=model,
        image_embeddings=image_embeddings,
        query=query,
        image_filenames=valid_df["image"].values,
        n=9,
    )
