In [None]:
from google.colab import drive

# Mount Google Drive so the Flickr30k images, captions and cached pickles stored
# under /content/drive/MyDrive are reachable from this Colab runtime.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install torchtext sentence-transformers transformers

Collecting torchtext
  Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting sentence-transformers
  Downloading sentence_transformers-3.2.1-py3-none-any.whl.metadata (10 kB)
Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m46.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading sentence_transformers-3.2.1-py3-none-any.whl (255 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m255.8/255.8 kB[0m [31m23.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchtext, sentence-transformers
Successfully installed sentence-transformers-3.2.1 torchtext-0.18.0


In [None]:
!pip install --upgrade tensorflow

Collecting tensorflow
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tensorboard<2.19,>=2.18 (from tensorflow)
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting keras>=3.5.0 (from tensorflow)
  Downloading keras-3.6.0-py3-none-any.whl.metadata (5.8 kB)
Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m615.3/615.3 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading keras-3.6.0-py3-none-any.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tensorboard-2.18.0-py3-none-any.whl (5.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m43.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorboard, keras, tensorflow
  At

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, ViTModel
from sentence_transformers import SentenceTransformer
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import os
import nltk
from nltk.translate.bleu_score import sentence_bleu
import os
import nltk
from nltk.translate.bleu_score import sentence_bleu

import pickle
from tqdm import tqdm
import h5py

In [None]:
# ---- Hyperparameters --------------------------------------------------------
EMBED_SIZE = 768     # word-embedding width given to ImageCaptioningModel(embed_size=...); equals the ViT-base hidden size
HIDDEN_SIZE = 512    # LSTM / feature-encoder width
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
SEED = 42

# Seed every RNG the notebook relies on: the torch.randperm train/test split,
# dropout, and the teacher-forcing coin flips drawn with `random`.
import random
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [None]:
def extract_vit_features(image, feature_extractor, vit_model):
    """Return the first-token ([CLS]) embedding a ViT model produces for ``image``.

    Args:
        image: anything ``feature_extractor`` accepts as ``images=`` (e.g. a PIL image).
        feature_extractor: callable returning a mapping of tensors when called as
            ``feature_extractor(images=..., return_tensors="pt")``.
        vit_model: model whose output exposes ``last_hidden_state``.

    Returns:
        ``last_hidden_state[:, 0, :]``, i.e. one vector per image in the batch.
    """
    with torch.no_grad():
        inputs = feature_extractor(images=image, return_tensors="pt")
        # Put the inputs where the model lives; a CUDA model rejects CPU tensors.
        model_device = getattr(vit_model, "device", None)
        if model_device is not None:
            inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}
        outputs = vit_model(**inputs)
        return outputs.last_hidden_state[:, 0, :]

In [None]:
class DataPreprocessor:
    """Build cached (ViT feature, caption index) tensors from an ``image,caption`` file.

    ``process_data()`` returns ``((word2idx, idx2word), train_data, test_data)``. Each
    split is ``(features, captions)`` with one row per caption:
    - ``features`` stacks the ViT [CLS] embedding of that caption's image.
    - ``captions`` holds ``max_len`` token indices (<START> words <END>, then <PAD>).

    Everything is pickled under ``cache_dir`` and reloaded on later calls when all three
    cache files exist.
    """

    def __init__(self, image_dir, captions_file, feature_extractor, max_len=50, cache_dir='cached_data',
                 vocab_size=10000, train_ratio=0.8, seed=42):
        """
        Args:
            image_dir: directory containing the image files named in ``captions_file``.
            captions_file: text file with a header line, then ``image_name,caption`` rows.
            feature_extractor: image processor called as
                ``feature_extractor(images=..., return_tensors="pt")``.
            max_len: padded/truncated caption length, counting <START> and <END>.
            cache_dir: directory for vocabulary.pkl, train_data.pkl and test_data.pkl.
            vocab_size: maximum vocabulary size including the four special tokens.
            train_ratio: fraction of *images* assigned to the training split.
            seed: seed for the train/test split so rebuilding the cache is reproducible.
        """
        self.image_dir = image_dir
        self.captions_file = captions_file
        self.max_len = max_len
        self.cache_dir = cache_dir
        self.feature_extractor = feature_extractor
        self.vocab_size = vocab_size
        self.train_ratio = train_ratio
        self.seed = seed
        self.vit_model = None

        os.makedirs(cache_dir, exist_ok=True)

        # Cache file paths
        self.vocab_cache = os.path.join(cache_dir, 'vocabulary.pkl')
        self.train_cache = os.path.join(cache_dir, 'train_data.pkl')
        self.test_cache = os.path.join(cache_dir, 'test_data.pkl')

        # Load the ViT model once; every image goes through this instance.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        print("Loading ViT model...")
        self.vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224').to(self.device)
        self.vit_model.eval()

    def _feature_dim(self):
        """Width of one feature vector: the ViT hidden size, or 768 if no model is loaded."""
        config = getattr(getattr(self, "vit_model", None), "config", None)
        return getattr(config, "hidden_size", 768)

    def extract_features(self, image_path):
        """Return the 1-D [CLS] embedding of one image, or a zero vector if it fails."""
        try:
            # `with` closes the file handle once the RGB copy has been made.
            with Image.open(image_path) as img:
                image = img.convert('RGB')

            with torch.no_grad():
                inputs = self.feature_extractor(images=image, return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.vit_model(**inputs)
                return outputs.last_hidden_state[:, 0, :].cpu().squeeze(0)
        except Exception as e:
            # Best effort: one unreadable file must not abort a multi-hour extraction run.
            print(f"Error processing {image_path}: {str(e)}")
            return torch.zeros(self._feature_dim())

    def process_data(self):
        """Return cached data when all three cache files exist; otherwise build and cache it."""
        if (os.path.exists(self.train_cache) and
                os.path.exists(self.test_cache) and
                os.path.exists(self.vocab_cache)):
            print("Loading cached data...")
            return self.load_cached_data()

        print("Processing data from scratch...")
        return self.create_and_cache_data()

    @staticmethod
    def _tokenize(caption):
        """Lower-case whitespace tokenisation shared by vocabulary building and encoding."""
        return caption.lower().split()

    def _read_captions(self):
        """Parse ``captions_file`` into ({image_name: [captions]}, [all captions in file order])."""
        print("Reading captions file...")
        img_captions = {}
        all_captions = []

        with open(self.captions_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()[1:]  # Skip header
        for line in tqdm(lines, desc="Reading captions"):
            parts = line.strip().split(',', 1)
            if len(parts) != 2:
                continue
            img_name = parts[0].strip()
            caption = parts[1].strip().strip('"\'')
            if caption:
                img_captions.setdefault(img_name, []).append(caption)
                all_captions.append(caption)
        return img_captions, all_captions

    def _encode_caption(self, caption, word2idx):
        """Map a caption to exactly ``max_len`` indices: <START> words <END>, then truncate or pad."""
        unk = word2idx['<UNK>']
        indices = [word2idx['<START>']]
        indices += [word2idx.get(word, unk) for word in self._tokenize(caption)]
        indices.append(word2idx['<END>'])
        indices = indices[:self.max_len]
        return indices + [word2idx['<PAD>']] * (self.max_len - len(indices))

    def _build_split(self, image_names, image_features, img_captions, word2idx):
        """Stack one (feature, encoded caption) row per caption of each image in ``image_names``."""
        feature_rows = []
        caption_rows = []
        for name in image_names:
            for caption in img_captions[name]:
                feature_rows.append(image_features[name])
                caption_rows.append(self._encode_caption(caption, word2idx))
        if not feature_rows:
            return (torch.empty(0, self._feature_dim()),
                    torch.empty(0, self.max_len, dtype=torch.long))
        return torch.stack(feature_rows), torch.tensor(caption_rows, dtype=torch.long)

    def create_and_cache_data(self):
        """Extract features for every captioned image, split by image, and pickle the splits."""
        img_captions, all_captions = self._read_captions()

        # Build vocabulary
        word2idx, idx2word = self.build_vocabulary(all_captions)

        # One ViT forward pass per image; all of its captions reuse that vector.
        print("\nProcessing images and creating batches...")
        image_features = {}
        for img_name in tqdm(img_captions.keys(), desc="Processing images"):
            image_path = os.path.join(self.image_dir, img_name)
            if os.path.exists(image_path):
                image_features[img_name] = self.extract_features(image_path)
            else:
                print(f"Warning: Image not found: {image_path}")

        if not image_features:
            raise RuntimeError(
                f"None of the images listed in {self.captions_file} were found in {self.image_dir}"
            )

        # Split by image, not by caption row. An image can have several captions, so a
        # row-level permutation would put the same image in both splits (test leakage).
        # A seeded generator makes the split reproducible across cache rebuilds.
        print("Splitting into train and test sets...")
        image_names = list(image_features)
        generator = torch.Generator().manual_seed(self.seed)
        order = torch.randperm(len(image_names), generator=generator).tolist()
        n_train = int(self.train_ratio * len(order))
        train_images = [image_names[i] for i in order[:n_train]]
        test_images = [image_names[i] for i in order[n_train:]]

        print("\nConverting to tensors...")
        train_data = self._build_split(train_images, image_features, img_captions, word2idx)
        test_data = self._build_split(test_images, image_features, img_captions, word2idx)

        # Cache the processed data
        print("Caching processed data...")
        with open(self.train_cache, 'wb') as f:
            pickle.dump(train_data, f)
        with open(self.test_cache, 'wb') as f:
            pickle.dump(test_data, f)

        print("\nProcessing completed!")
        print(f"Train set size: {len(train_data[0])}")
        print(f"Test set size: {len(test_data[0])}")

        return (word2idx, idx2word), train_data, test_data

    def build_vocabulary(self, captions):
        """Build word<->index maps from the most frequent words and pickle them.

        Indices 0-3 are <PAD>, <START>, <END>, <UNK>. The remaining slots, up to
        ``self.vocab_size`` in total, go to words by descending frequency. Ties keep
        first-seen order, because the sort is stable.
        """
        print("Building vocabulary...")
        word_freq = {}
        for caption in captions:
            for word in self._tokenize(caption):
                word_freq[word] = word_freq.get(word, 0) + 1

        word2idx = {'<PAD>': 0, '<START>': 1, '<END>': 2, '<UNK>': 3}
        for word, _ in sorted(word_freq.items(), key=lambda item: item[1], reverse=True):
            if len(word2idx) >= self.vocab_size:
                break
            word2idx[word] = len(word2idx)

        idx2word = {v: k for k, v in word2idx.items()}

        # Cache vocabulary
        with open(self.vocab_cache, 'wb') as f:
            pickle.dump((word2idx, idx2word), f)

        print(f"Vocabulary size: {len(word2idx)}")
        return word2idx, idx2word

    def load_cached_data(self):
        """Load the pickled vocabulary and splits written by ``create_and_cache_data``.

        NOTE: pickle.load can execute arbitrary code — only point ``cache_dir`` at files
        this pipeline produced.
        """
        print("Loading vocabulary...")
        with open(self.vocab_cache, 'rb') as f:
            vocab = pickle.load(f)

        print("Loading train data...")
        with open(self.train_cache, 'rb') as f:
            train_data = pickle.load(f)

        print("Loading test data...")
        with open(self.test_cache, 'rb') as f:
            test_data = pickle.load(f)

        print(f"Train set size: {len(train_data[0])}")
        print(f"Test set size: {len(test_data[0])}")

        return vocab, train_data, test_data

In [None]:

from torch.utils.data import Dataset, DataLoader, TensorDataset

# Flickr30k copy on Google Drive (requires the drive mount in the first cell).
DATA_ROOT = "/content/drive/MyDrive/Tech India/Preprocessed-Dataset/Rams-approach-preprocess/flickr30k"

# Initialize ViT feature extractor
print("Initializing ViT feature extractor...")
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# Build (or load from cache) the vocabulary and train/test feature tensors.
# The cache is kept on Drive next to the data, so the multi-hour feature extraction
# survives a Colab runtime reset. This is also the directory load_dataset() is pointed at.
preprocessor = DataPreprocessor(
    image_dir=os.path.join(DATA_ROOT, "Images"),
    captions_file=os.path.join(DATA_ROOT, "captions.txt"),
    feature_extractor=feature_extractor,
    cache_dir=os.path.join(DATA_ROOT, "cached_data"),
)
(word2idx, idx2word), (train_features, train_captions), (test_features, test_captions) = preprocessor.process_data()

train_dataset = TensorDataset(train_features, train_captions)
test_dataset = TensorDataset(test_features, test_captions)

# The tensors are already in memory, so extra worker processes only add IPC overhead;
# never spawn more workers than the machine has CPUs (Colab typically has very few).
num_loader_workers = min(16, os.cpu_count() or 1)
pin = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_loader_workers,
    pin_memory=pin,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=num_loader_workers,
    pin_memory=pin,
)

Initializing ViT feature extractor...
Using device: cuda
Loading ViT model...


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Processing data from scratch...
Reading captions file...


Reading captions: 100%|██████████| 158915/158915 [00:00<00:00, 751378.19it/s]


Building vocabulary...
Vocabulary size: 10000

Processing images and creating batches...


Processing images:   3%|▎         | 918/31783 [03:53<2:11:00,  3.93it/s]


KeyboardInterrupt: 

In [None]:
# Run preprocessing again. process_data() loads the pickled vocabulary and splits when
# all three cache files exist; otherwise it rebuilds them from scratch.
_vocab, _train_split, _test_split = preprocessor.process_data()
word2idx, idx2word = _vocab
train_features, train_captions = _train_split
test_features, test_captions = _test_split
del _vocab, _train_split, _test_split


Processing data from scratch...
Reading captions file...


Reading captions: 100%|██████████| 158915/158915 [00:00<00:00, 844690.69it/s]


Building vocabulary...
Vocabulary size: 10000

Processing images and creating batches...


Processing images:   4%|▍         | 1339/31783 [04:29<1:42:10,  4.97it/s]


KeyboardInterrupt: 

In [None]:
import pickle

def load_dataset(cache_dir='cached_data', batch_size=32, num_workers=None):
    """
    Load preprocessed data from pickle files and create DataLoaders.

    Args:
        cache_dir: directory containing vocabulary.pkl, train_data.pkl and test_data.pkl.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes. Defaults to min(16, CPU count), since the
            data is already in memory and more workers than CPUs only adds overhead.

    Returns:
        (raw_data, loaders):
        - raw_data has keys 'vocab', 'train_data' and 'test_data'.
        - loaders has keys 'train' (shuffled) and 'test'.

    Raises:
        FileNotFoundError: if any of the three pickle files is missing.

    NOTE: pickle.load can execute arbitrary code; only load caches this pipeline wrote.
    """
    # Imported here so the function does not depend on another cell having imported them.
    from torch.utils.data import DataLoader, TensorDataset

    vocab_path = os.path.join(cache_dir, 'vocabulary.pkl')
    train_path = os.path.join(cache_dir, 'train_data.pkl')
    test_path = os.path.join(cache_dir, 'test_data.pkl')

    missing = [p for p in (vocab_path, train_path, test_path) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f"Required pickle files not found: {missing}. Run preprocessing first."
        )

    print("Loading vocabulary...")
    with open(vocab_path, 'rb') as f:
        word2idx, idx2word = pickle.load(f)

    print("Loading train data...")
    with open(train_path, 'rb') as f:
        train_features, train_captions = pickle.load(f)

    print("Loading test data...")
    with open(test_path, 'rb') as f:
        test_features, test_captions = pickle.load(f)

    train_dataset = TensorDataset(train_features, train_captions)
    test_dataset = TensorDataset(test_features, test_captions)

    if num_workers is None:
        num_workers = min(16, os.cpu_count() or 1)
    # Pinned host memory only helps (and only avoids warnings) when copying to a GPU.
    pin = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin,
    )

    print(f"Vocabulary size: {len(word2idx)}")
    print(f"Train set size: {len(train_features)}")
    print(f"Test set size: {len(test_features)}")

    raw_data = {
        'vocab': (word2idx, idx2word),
        'train_data': (train_features, train_captions),
        'test_data': (test_features, test_captions),
    }

    loaders = {
        'train': train_loader,
        'test': test_loader,
    }

    return raw_data, loaders
# Load the cached vocabulary and feature tensors produced by DataPreprocessor and wrap
# them in DataLoaders. BATCH_SIZE comes from the config cell, so there is one source of truth.
CACHE_DIR = '/content/drive/MyDrive/Tech India/Preprocessed-Dataset/Rams-approach-preprocess/flickr30k/cached_data'
raw_data, loaders = load_dataset(cache_dir=CACHE_DIR, batch_size=BATCH_SIZE)

# Vocabulary and raw tensors
word2idx, idx2word = raw_data['vocab']
train_features, train_captions = raw_data['train_data']
test_features, test_captions = raw_data['test_data']

# DataLoaders
train_loader = loaders['train']
test_loader = loaders['test']

Loading vocabulary...
Loading train data...
Loading test data...
Vocabulary size: 10000
Train set size: 127131
Test set size: 31783




# Model

In [None]:
import torch.nn.functional as F
class AttentionLayer(nn.Module):
    """Additive attention of one query vector over a sequence of encoder outputs.

    Each position is scored as ``v . tanh(W [query; key])``. The scores are
    softmax-normalised over the sequence, and the context is the resulting weighted
    sum of ``encoder_outputs``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(2 * hidden_size, hidden_size)
        # Drawn with torch.rand, then re-initialised from N(0, 0.1).
        self.v = nn.Parameter(torch.rand(hidden_size))
        self.v.data.normal_(mean=0, std=0.1)

    def forward(self, hidden, encoder_outputs):
        """
        hidden: query of shape (batch_size, hidden_size) or (batch_size, 1, hidden_size)
        encoder_outputs: (batch_size, seq_len, hidden_size)

        Returns (context, attention_weights):
        - context: (batch_size, 1, hidden_size).
        - attention_weights: (batch_size, seq_len), summing to 1 over dim 1.
        """
        batch, seq_len, _ = encoder_outputs.size()

        query = hidden.unsqueeze(1) if hidden.dim() == 2 else hidden
        query = query.expand(-1, seq_len, -1)

        energy = torch.tanh(self.attention(torch.cat([query, encoder_outputs], dim=2)))

        v_batched = self.v.view(1, 1, -1).expand(batch, -1, -1)      # (batch, 1, hidden)
        scores = torch.bmm(v_batched, energy.transpose(1, 2)).squeeze(1)

        weights = F.softmax(scores, dim=1)
        context = weights.unsqueeze(1) @ encoder_outputs
        return context, weights


In [None]:
import random  # used for teacher forcing; imported here so the class does not rely on a later cell


class ImageCaptioningModel(nn.Module):
    """LSTM caption decoder conditioned on one encoded image feature vector per example.

    ``forward(images, captions)`` returns logits of shape
    (batch, captions.size(1) - 1, vocab_size): one prediction per token of
    ``captions[:, 1:]``.

    NOTE(review): the image is passed to the attention layer as a single position
    (``unsqueeze(1)``). The softmax therefore always gives weight 1, the context is
    just the encoded image, and the attention parameters receive zero gradient. Real
    attention would presumably need per-patch ViT features — TODO confirm intent.
    """

    def __init__(self, input_size, hidden_size, vocab_size, embed_size=256, num_layers=2, dropout_p=0.3):
        super(ImageCaptioningModel, self).__init__()

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.embed_size = embed_size

        # Image feature processing: input_size -> hidden_size
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_p)
        )

        # Word embeddings (index 0 is padding)
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.embed_dropout = nn.Dropout(dropout_p)

        # Project embeddings to hidden_size so they can be concatenated with the context
        self.embed_process = nn.Linear(embed_size, hidden_size)

        # Attention
        self.attention = AttentionLayer(hidden_size)

        # Decoder LSTM
        self.decoder_rnn = nn.LSTM(
            input_size=hidden_size * 2,  # Concatenated context and processed embedding
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout_p if num_layers > 1 else 0
        )

        # Output projection
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_size, vocab_size)
        )

        # Layer normalization
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, images, captions, teacher_forcing_ratio=0.5):
        """
        Args:
            images: (batch, input_size) image feature vectors.
            captions: (batch, seq_len) token indices. Column 0 is the first decoder input.
            teacher_forcing_ratio: probability of feeding the ground-truth next token
                instead of the model's argmax. It is drawn once per time step for the
                whole batch; pass 0.0 for evaluation.

        Returns:
            (batch, seq_len - 1, vocab_size) logits.
        """
        batch_size = images.size(0)
        max_length = captions.size(1) - 1  # no prediction is made after the last token
        device = images.device

        image_features = self.feature_encoder(images).unsqueeze(1)  # (batch, 1, hidden)

        # Allocate directly on the target device instead of creating on CPU and copying.
        outputs = torch.zeros(batch_size, max_length, self.vocab_size, device=device)

        decoder_input = captions[:, 0]  # Start tokens
        hidden = (
            torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device),
            torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device),
        )

        for t in range(max_length):
            embedded = self.embedding(decoder_input)           # (batch, embed_size)
            embedded = self.embed_dropout(embedded)
            embedded = self.embed_process(embedded).unsqueeze(1)  # (batch, 1, hidden)

            context, _ = self.attention(embedded, image_features)

            rnn_input = torch.cat((embedded, context), dim=2)
            output, hidden = self.decoder_rnn(rnn_input, hidden)

            logits = self.output_layer(self.layer_norm(output.squeeze(1)))
            outputs[:, t] = logits

            # Teacher forcing
            if random.random() < teacher_forcing_ratio and t < max_length - 1:
                decoder_input = captions[:, t + 1]
            else:
                decoder_input = logits.argmax(dim=1)

        return outputs

In [None]:

def evaluate_model(model, test_loader, criterion, device, pad_idx):
    model.eval()
    total_loss = 0
    total_word_accuracy = 0
    total_sentence_accuracy = 0
    num_batches = 0

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Evaluating")

        for images, captions in progress_bar:
            # Move to device
            images = images.to(device)
            captions = captions.to(device)

            # Get input and target sequences
            input_captions = captions[:, :-1]
            target_captions = captions[:, 1:]

            # Forward pass
            outputs = model(images, input_captions)
            outputs = outputs[:, :-1, :]

            # Reshape for loss calculation
            outputs_flat = outputs.reshape(-1, outputs.size(-1))
            targets_flat = target_captions.reshape(-1)

            # Calculate metrics
            loss = criterion(outputs_flat, targets_flat)
            word_acc, sent_acc = calculate_accuracy(outputs_flat, targets_flat, pad_idx)

            # Update metrics
            total_loss += loss.item()
            total_word_accuracy += word_acc
            total_sentence_accuracy += sent_acc
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'word_acc': f'{word_acc:.4f}',
                'sent_acc': f'{sent_acc:.4f}'
            })

    # Calculate averages
    avg_loss = total_loss / num_batches
    avg_word_acc = total_word_accuracy / num_batches
    avg_sent_acc = total_sentence_accuracy / num_batches

    return {
        'loss': avg_loss,
        'word_accuracy': avg_word_acc,
        'sentence_accuracy': avg_sent_acc
    }

In [None]:
def calculate_accuracy(outputs, targets, pad_idx, seq_len=None):
    """
    Calculate word-level and sentence-level accuracy, ignoring padding.

    outputs: logits with the vocabulary on the last dim, e.g. (batch_size * seq_len, vocab_size)
             or (batch_size, seq_len, vocab_size)
    targets: token indices shaped like ``outputs`` without its last dim
    pad_idx: target index excluded from both metrics
    seq_len: for flat targets, the length of each sentence, used to regroup them for the
             sentence metric. If omitted, 2-D+ targets use their last dim and 1-D targets
             are treated as a single sentence.

    Returns (word_accuracy, sentence_accuracy). A sentence counts as correct only when every
    non-padding position is predicted correctly. Both are 0 for empty input.
    """
    if targets.numel() == 0:
        return 0, 0

    predictions = outputs.argmax(dim=-1)
    mask = targets != pad_idx
    hits = (predictions == targets) & mask

    # Word-level accuracy over non-padding positions
    total_words = mask.sum().item()
    word_accuracy = hits.sum().item() / total_words if total_words > 0 else 0

    # Sentence-level accuracy: regroup positions into one row per sentence
    if seq_len is not None:
        row_shape = (-1, seq_len)
    elif targets.dim() >= 2:
        row_shape = (-1, targets.size(-1))
    else:
        row_shape = (1, -1)
    sentence_ok = (hits | ~mask).reshape(row_shape).all(dim=1)
    sentence_accuracy = sentence_ok.float().mean().item()

    return word_accuracy, sentence_accuracy

In [None]:
from torch.nn.parallel import DataParallel
from tqdm import tqdm
import threading
from queue import Queue
import random
from torch.cuda.amp import autocast, GradScaler

In [None]:
import random
def train_model(model, train_loader, criterion, optimizer, device, epoch, total_epochs, teacher_forcing_ratio=0.5):
    """Run one training epoch and return (average loss, token accuracy).

    Args:
        model: captioning model called as ``model(images, captions, teacher_forcing_ratio=...)``.
            It returns logits of shape (batch, seq_len - 1, vocab_size).
        train_loader: iterable of (images, captions) batches.
        criterion: loss over flattened logits vs. ``captions[:, 1:]``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        epoch, total_epochs: used only for progress display.
        teacher_forcing_ratio: forwarded to the model. It was previously accepted but
            ignored, so the model always used its own default.

    Returns:
        (avg_loss, avg_accuracy):
        - avg_loss is the mean loss over batches that completed.
        - avg_accuracy is the fraction of non-padding target tokens predicted correctly.
        Batches that raise are reported and skipped.
    """
    model.train()
    total_loss = 0.0
    total_words = 0
    correct_words = 0
    num_batches = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}")

    for i, (images, captions) in enumerate(progress_bar):
        try:
            images = images.to(device)
            captions = captions.to(device)
            targets = captions[:, 1:]

            outputs = model(images, captions, teacher_forcing_ratio=teacher_forcing_ratio)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))

            # Token accuracy; 0 is <PAD>, matching CrossEntropyLoss(ignore_index=0) in main()
            mask = targets != 0
            batch_correct = ((outputs.argmax(dim=2) == targets) & mask).sum().item()
            batch_total = mask.sum().item()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        except Exception as e:
            # Best effort: report and skip the batch rather than abort the whole epoch.
            print(f"\nError in batch {i}:")
            print(f"Exception: {str(e)}")
            continue

        total_loss += loss.item()
        total_words += batch_total
        correct_words += batch_correct
        num_batches += 1
        current_accuracy = correct_words / total_words if total_words > 0 else 0

        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{current_accuracy:.4f}'
        })

        if (i + 1) % 100 == 0:
            print(f"\nBatch {i+1}/{len(train_loader)}")
            print(f"Loss: {loss.item():.4f}")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    avg_accuracy = correct_words / total_words if total_words > 0 else 0

    return avg_loss, avg_accuracy



In [None]:
import random
import torch
import threading
from queue import Queue
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def train_model(model, train_loader, criterion, optimizer, device, epoch, total_epochs, teacher_forcing_ratio=0.5, num_threads=2):
    """Run one training epoch over ``train_loader`` and return (average loss, token accuracy).

    Args:
        model: captioning model called as ``model(images, captions, teacher_forcing_ratio=...)``.
            It returns logits of shape (batch, seq_len - 1, vocab_size).
        train_loader: iterable of (images, captions) batches.
        criterion: loss over flattened logits vs. ``captions[:, 1:]``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        epoch, total_epochs: used only for progress display.
        teacher_forcing_ratio: forwarded to the model (previously ignored).
        num_threads: kept for backward compatibility; ignored. The former thread-pool
            version busy-waited for each submitted batch before loading the next, so
            batches never overlapped anyway. Overlapping forward passes with
            ``optimizer.step()`` on the same parameters would not be safe.

    Returns:
        (avg_loss, avg_accuracy):
        - avg_loss is the mean loss over batches that completed.
        - avg_accuracy is the fraction of non-padding target tokens predicted correctly.
        Batches that raise are reported and skipped.
    """
    model.train()
    total_loss = 0.0
    total_words = 0
    correct_words = 0
    num_batches = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}")

    for i, (images, captions) in enumerate(progress_bar):
        try:
            images = images.to(device)
            captions = captions.to(device)
            targets = captions[:, 1:]

            outputs = model(images, captions, teacher_forcing_ratio=teacher_forcing_ratio)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))

            # Token accuracy; 0 is <PAD>, matching CrossEntropyLoss(ignore_index=0) in main()
            mask = targets != 0
            batch_correct = ((outputs.argmax(dim=2) == targets) & mask).sum().item()
            batch_total = mask.sum().item()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        except Exception as e:
            # Best effort: report and skip the batch rather than abort the whole epoch.
            print(f"\nError processing batch {i}:")
            print(f"Exception: {str(e)}")
            continue

        total_loss += loss.item()
        total_words += batch_total
        correct_words += batch_correct
        num_batches += 1

        running_loss = total_loss / num_batches
        current_accuracy = correct_words / total_words if total_words > 0 else 0
        progress_bar.set_postfix({
            'loss': f'{running_loss:.4f}',
            'acc': f'{current_accuracy:.4f}'
        })

        if (i + 1) % 100 == 0:
            print(f"\nBatch {i+1}/{len(train_loader)}")
            print(f"Loss: {running_loss:.4f}")
            print(f"Accuracy: {current_accuracy:.4f}")
            print(f"Total words: {total_words}")
            print(f"Correct words: {correct_words}")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    avg_accuracy = correct_words / total_words if total_words > 0 else 0

    return avg_loss, avg_accuracy


'\nmodel = YourModel().to(device)\ncriterion = torch.nn.CrossEntropyLoss(ignore_index=0)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\nfor epoch in range(total_epochs):\n    avg_loss, avg_accuracy = train_model(\n        model=model,\n        train_loader=train_loader,\n        criterion=criterion,\n        optimizer=optimizer,\n        device=device,\n        epoch=epoch+1,\n        total_epochs=total_epochs,\n        num_threads=2  # Adjust based on your system\n    )\n'

In [None]:
def main():
    """Train ImageCaptioningModel on `train_loader` and checkpoint after every epoch.

    Relies on names defined elsewhere in the notebook: ImageCaptioningModel,
    train_model, train_loader, word2idx, idx2word, HIDDEN_SIZE, EMBED_SIZE, EPOCHS.

    Side effects:
        * Writes `improved_model_epoch_{n}.pth` after every epoch.
        * Overwrites `best_model.pth` whenever the epoch's average training loss is
          the lowest seen so far. The inference/evaluation cells load
          `best_model.pth`, which previously was never written by training.
    """
    LEARNING_RATE = 0.001
    NUM_THREADS = 16  # Adjust based on your system

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = ImageCaptioningModel(
        input_size=768,  # ViT base size
        hidden_size=HIDDEN_SIZE,
        vocab_size=len(word2idx),
        embed_size=EMBED_SIZE,
        num_layers=2,
        dropout_p=0.3
    ).to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index
    # Reference torch.optim explicitly so this cell does not depend on an `optim` alias.
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_loss = float('inf')

    # Training loop
    print("Starting training...")
    for epoch in range(EPOCHS):
        loss, accuracy = train_model(
            model=model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch + 1,
            total_epochs=EPOCHS,
            num_threads=NUM_THREADS,
        )

        print(f"\nEpoch {epoch + 1}/{EPOCHS}")
        print(f"Average Loss: {loss:.4f}")
        print(f"Average Accuracy: {accuracy:.4f}")

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'accuracy': accuracy,
            'word2idx': word2idx,
            'idx2word': idx2word
        }

        # Per-epoch checkpoint
        torch.save(checkpoint, f'improved_model_epoch_{epoch+1}.pth')

        # "Best" is judged on training loss because no validation split is used here.
        if loss < best_loss:
            best_loss = loss
            torch.save(checkpoint, 'best_model.pth')
            print(f"New best loss {loss:.4f}; saved best_model.pth")

if __name__ == "__main__":
    # Report the failure in the cell output, then re-raise so the full traceback is kept.
    try:
        main()
    except Exception as err:
        print(f"Main error: {err!s}")
        raise

Using device: cuda
Starting training...


Epoch 1/10:   3%|▎         | 100/3973 [01:22<36:32,  1.77it/s, loss=5.4345, acc=0.1859]


Batch 100/3973
Loss: 5.4345
Accuracy: 0.1859
Total words: 46286
Correct words: 8605


Epoch 1/10:   5%|▌         | 200/3973 [02:40<2:18:31,  2.20s/it, loss=5.1775, acc=0.2087]


Batch 200/3973
Loss: 5.1775
Accuracy: 0.2087
Total words: 92146
Correct words: 19230


Epoch 1/10:   8%|▊         | 300/3973 [03:56<33:05,  1.85it/s, loss=5.0580, acc=0.2202]


Batch 300/3973
Loss: 5.0580
Accuracy: 0.2202
Total words: 138242
Correct words: 30447


Epoch 1/10:  10%|█         | 400/3973 [05:11<30:54,  1.93it/s, loss=4.9816, acc=0.2277]


Batch 400/3973
Loss: 4.9816
Accuracy: 0.2277
Total words: 184089
Correct words: 41918


Epoch 1/10:  13%|█▎        | 500/3973 [06:35<36:39,  1.58it/s, loss=4.9351, acc=0.2317]


Batch 500/3973
Loss: 4.9351
Accuracy: 0.2317
Total words: 229644
Correct words: 53217


Epoch 1/10:  15%|█▌        | 600/3973 [07:44<27:23,  2.05it/s, loss=4.8986, acc=0.2354]


Batch 600/3973
Loss: 4.8986
Accuracy: 0.2354
Total words: 275910
Correct words: 64941


Epoch 1/10:  18%|█▊        | 700/3973 [08:59<41:49,  1.30it/s, loss=4.8675, acc=0.2380]


Batch 700/3973
Loss: 4.8675
Accuracy: 0.2380
Total words: 322015
Correct words: 76648


Epoch 1/10:  20%|██        | 800/3973 [10:15<27:34,  1.92it/s, loss=4.8451, acc=0.2395]


Batch 800/3973
Loss: 4.8451
Accuracy: 0.2395
Total words: 367967
Correct words: 88141


Epoch 1/10:  23%|██▎       | 900/3973 [11:32<26:23,  1.94it/s, loss=4.8219, acc=0.2413]


Batch 900/3973
Loss: 4.8219
Accuracy: 0.2413
Total words: 413454
Correct words: 99767


Epoch 1/10:  25%|██▌       | 1000/3973 [12:53<28:43,  1.72it/s, loss=4.8011, acc=0.2429]


Batch 1000/3973
Loss: 4.8011
Accuracy: 0.2429
Total words: 459524
Correct words: 111624


Epoch 1/10:  28%|██▊       | 1100/3973 [14:12<1:29:17,  1.86s/it, loss=4.7848, acc=0.2441]


Batch 1100/3973
Loss: 4.7848
Accuracy: 0.2441
Total words: 505835
Correct words: 123472


Epoch 1/10:  30%|███       | 1200/3973 [15:28<22:35,  2.05it/s, loss=4.7659, acc=0.2455]


Batch 1200/3973
Loss: 4.7659
Accuracy: 0.2455
Total words: 551839
Correct words: 135495


Epoch 1/10:  33%|███▎      | 1300/3973 [16:53<1:12:05,  1.62s/it, loss=4.7510, acc=0.2464]


Batch 1300/3973
Loss: 4.7510
Accuracy: 0.2464
Total words: 598169
Correct words: 147375


Epoch 1/10:  35%|███▌      | 1400/3973 [18:10<25:15,  1.70it/s, loss=4.7366, acc=0.2474]


Batch 1400/3973
Loss: 4.7366
Accuracy: 0.2474
Total words: 643759
Correct words: 159274


Epoch 1/10:  38%|███▊      | 1500/3973 [19:23<1:12:51,  1.77s/it, loss=4.7250, acc=0.2483]


Batch 1500/3973
Loss: 4.7250
Accuracy: 0.2483
Total words: 689945
Correct words: 171298


Epoch 1/10:  40%|████      | 1600/3973 [20:38<22:01,  1.80it/s, loss=4.7140, acc=0.2489]


Batch 1600/3973
Loss: 4.7140
Accuracy: 0.2489
Total words: 736444
Correct words: 183288


Epoch 1/10:  43%|████▎     | 1700/3973 [21:54<19:45,  1.92it/s, loss=4.7012, acc=0.2500]


Batch 1700/3973
Loss: 4.7012
Accuracy: 0.2500
Total words: 782462
Correct words: 195617


Epoch 1/10:  45%|████▌     | 1800/3973 [23:10<18:02,  2.01it/s, loss=4.6880, acc=0.2511]


Batch 1800/3973
Loss: 4.6880
Accuracy: 0.2511
Total words: 829021
Correct words: 208129


Epoch 1/10:  48%|████▊     | 1900/3973 [24:26<17:10,  2.01it/s, loss=4.6759, acc=0.2519]


Batch 1900/3973
Loss: 4.6759
Accuracy: 0.2519
Total words: 875398
Correct words: 220523


Epoch 1/10:  50%|█████     | 2000/3973 [25:41<15:32,  2.12it/s, loss=4.6636, acc=0.2529]


Batch 2000/3973
Loss: 4.6636
Accuracy: 0.2529
Total words: 921113
Correct words: 232967


Epoch 1/10:  53%|█████▎    | 2100/3973 [27:06<36:40,  1.17s/it, loss=4.6534, acc=0.2537]


Batch 2100/3973
Loss: 4.6534
Accuracy: 0.2537
Total words: 967521
Correct words: 245446


Epoch 1/10:  55%|█████▌    | 2200/3973 [28:22<16:07,  1.83it/s, loss=4.6437, acc=0.2544]


Batch 2200/3973
Loss: 4.6437
Accuracy: 0.2544
Total words: 1013235
Correct words: 257773


Epoch 1/10:  58%|█████▊    | 2300/3973 [29:41<16:05,  1.73it/s, loss=4.6342, acc=0.2551]


Batch 2300/3973
Loss: 4.6342
Accuracy: 0.2551
Total words: 1059015
Correct words: 270164


Epoch 1/10:  60%|██████    | 2400/3973 [30:57<13:27,  1.95it/s, loss=4.6252, acc=0.2559]


Batch 2400/3973
Loss: 4.6252
Accuracy: 0.2559
Total words: 1104580
Correct words: 282694


Epoch 1/10:  63%|██████▎   | 2500/3973 [32:23<31:25,  1.28s/it, loss=4.6178, acc=0.2563]


Batch 2500/3973
Loss: 4.6178
Accuracy: 0.2563
Total words: 1150538
Correct words: 294888


Epoch 1/10:  65%|██████▌   | 2600/3973 [33:40<11:40,  1.96it/s, loss=4.6090, acc=0.2569]


Batch 2600/3973
Loss: 4.6090
Accuracy: 0.2569
Total words: 1196235
Correct words: 307263


Epoch 1/10:  68%|██████▊   | 2700/3973 [34:58<11:35,  1.83it/s, loss=4.6024, acc=0.2573]


Batch 2700/3973
Loss: 4.6024
Accuracy: 0.2573
Total words: 1242729
Correct words: 319711


Epoch 1/10:  70%|███████   | 2800/3973 [36:16<10:03,  1.94it/s, loss=4.5957, acc=0.2576]


Batch 2800/3973
Loss: 4.5957
Accuracy: 0.2576
Total words: 1289132
Correct words: 332077


Epoch 1/10:  73%|███████▎  | 2900/3973 [37:41<11:30,  1.55it/s, loss=4.5878, acc=0.2581]


Batch 2900/3973
Loss: 4.5878
Accuracy: 0.2581
Total words: 1334860
Correct words: 344554


Epoch 1/10:  76%|███████▌  | 3000/3973 [38:58<07:53,  2.06it/s, loss=4.5802, acc=0.2586]


Batch 3000/3973
Loss: 4.5802
Accuracy: 0.2586
Total words: 1381100
Correct words: 357160


Epoch 1/10:  78%|███████▊  | 3100/3973 [40:13<11:28,  1.27it/s, loss=4.5734, acc=0.2590]


Batch 3100/3973
Loss: 4.5734
Accuracy: 0.2590
Total words: 1427224
Correct words: 369721


Epoch 1/10:  81%|████████  | 3200/3973 [41:25<07:31,  1.71it/s, loss=4.5670, acc=0.2595]


Batch 3200/3973
Loss: 4.5670
Accuracy: 0.2595
Total words: 1473316
Correct words: 382299


Epoch 1/10:  83%|████████▎ | 3300/3973 [42:35<05:57,  1.88it/s, loss=4.5603, acc=0.2600]


Batch 3300/3973
Loss: 4.5603
Accuracy: 0.2600
Total words: 1519228
Correct words: 395018


Epoch 1/10:  86%|████████▌ | 3400/3973 [44:01<06:00,  1.59it/s, loss=4.5547, acc=0.2603]


Batch 3400/3973
Loss: 4.5547
Accuracy: 0.2603
Total words: 1565633
Correct words: 407568


Epoch 1/10:  88%|████████▊ | 3500/3973 [45:18<07:25,  1.06it/s, loss=4.5506, acc=0.2604]


Batch 3500/3973
Loss: 4.5506
Accuracy: 0.2604
Total words: 1612167
Correct words: 419852


Epoch 1/10:  91%|█████████ | 3600/3973 [46:37<03:12,  1.94it/s, loss=4.5450, acc=0.2607]


Batch 3600/3973
Loss: 4.5450
Accuracy: 0.2607
Total words: 1657903
Correct words: 432262


Epoch 1/10:  93%|█████████▎| 3700/3973 [47:55<02:25,  1.88it/s, loss=4.5389, acc=0.2611]


Batch 3700/3973
Loss: 4.5389
Accuracy: 0.2611
Total words: 1703370
Correct words: 444669


Epoch 1/10:  93%|█████████▎| 3703/3973 [48:02<09:10,  2.04s/it, loss=4.5388, acc=0.2610]

In [None]:
def generate_caption(image_path, model_path="best_model.pth", max_length=50):
    """
    Generate a caption for a single image file using a saved checkpoint.

    The checkpoint is rebuilt into an ImageCaptioningModel on CPU. The image is
    encoded with the pretrained 'google/vit-base-patch16-224' ViT, and the hidden
    state at token position 0 is used as the image feature. Words are then decoded
    greedily (argmax), feeding the whole token prefix back to the model each step.

    Args:
        image_path: path to an image file readable by PIL.
        model_path: checkpoint containing 'model_state_dict', 'word2idx' and 'idx2word'.
        max_length: maximum number of decoding steps.

    Returns:
        The caption as a space-separated string. The stop token ('<END>' or
        '<PAD>') is not included.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu')
    word2idx = checkpoint['word2idx']
    idx2word = checkpoint['idx2word']

    # NOTE(review): hidden_size is hard-coded and the remaining constructor args fall back
    # to ImageCaptioningModel's defaults; they must match the trained architecture
    # (training uses HIDDEN_SIZE / EMBED_SIZE / num_layers=2) — confirm.
    model = ImageCaptioningModel(
        input_size=768,  # ViT base size
        hidden_size=512,
        vocab_size=len(word2idx)
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
    vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224')

    # Close the underlying file handle once the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    with torch.no_grad():
        # Encode the image with ViT and keep the token-0 hidden state.
        inputs = feature_extractor(images=image, return_tensors="pt")
        outputs = vit_model(**inputs)
        image_features = outputs.last_hidden_state[:, 0, :]  # [1, 768]
        # The previous print passed a set literal and showed "{torch.Size(...)}".
        print(f"Image Embedding shape: {image_features.shape}")

        # Greedy decoding starting from the <START> token.
        current_token = torch.tensor([[word2idx['<START>']]])
        caption = []

        for _ in range(max_length):
            output = model(image_features, current_token)
            next_word_idx = output[0, -1].argmax().item()
            word = idx2word[next_word_idx]

            # Stop on end-of-sequence or padding.
            if word in ('<END>', '<PAD>'):
                break

            caption.append(word)

            # Append the predicted token; the full prefix is fed back next step.
            current_token = torch.cat([current_token, torch.tensor([[next_word_idx]])], dim=1)

    return ' '.join(caption)

# Example usage: caption one sample image with the path-based generate_caption.
if __name__ == "__main__":
    image_path = "/content/8192398089.jpg"  # Replace with your image path
    caption = generate_caption(image_path)
    print(f"\nGenerated caption: {caption}")

    # To caption several images, loop over their paths and guard each call:
    # print("\nGenerating captions for multiple images:")
    # for img_path in ["flickr30k/images/image1.jpg",
    #                  "flickr30k/images/image2.jpg",
    #                  "flickr30k/images/image3.jpg"]:
    #     try:
    #         caption = generate_caption(img_path)
    #         print(f"\nImage: {img_path}")
    #         print(f"Caption: {caption}")
    #     except Exception as e:
    #         print(f"Error processing {img_path}: {str(e)}")

  checkpoint = torch.load(model_path, map_location='cpu')
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Image Embedding shape: {torch.Size([1, 768])}

Generated caption: women black in and outfits in dance


In [None]:
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize
def load_model_and_tokenizers(model_path):
    """Rebuild the captioning model and its vocabulary maps from a checkpoint.

    Args:
        model_path: path to a checkpoint dict containing 'model_state_dict',
            'word2idx' and 'idx2word' (the keys written by the training loop).

    Returns:
        Tuple ``(model, word2idx, idx2word)``. Checkpoint tensors are loaded with
        ``map_location='cpu'``; ``eval()`` is NOT called here, callers do that.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(model_path, map_location='cpu')
    vocab_to_index = ckpt['word2idx']
    index_to_vocab = ckpt['idx2word']

    # NOTE(review): hidden_size is hard-coded to 512 and the other constructor args use
    # ImageCaptioningModel's defaults; they must match the trained architecture — confirm.
    captioner = ImageCaptioningModel(
        input_size=768,  # ViT base size
        hidden_size=512,
        vocab_size=len(vocab_to_index),
    )
    captioner.load_state_dict(ckpt['model_state_dict'])

    return captioner, vocab_to_index, index_to_vocab

def generate_caption(model, image_features, word2idx, idx2word, max_length=50):
    """Greedily decode a caption from precomputed image features.

    NOTE: this definition shadows the path-based ``generate_caption(image_path, ...)``
    defined in an earlier cell; once this cell runs, that signature is no longer bound.

    Args:
        model: captioning model exposing ``num_layers`` and ``hidden_size`` and called
            as ``model(features, tokens, (h, c)) -> (logits, (h, c))``.
        image_features: feature tensor for one image. It is unsqueezed along dim 0
            before every model call, and its ``device`` is used for every tensor
            created here.
        word2idx: vocabulary map; must contain '<START>'.
        idx2word: inverse vocabulary map.
        max_length: maximum number of decoding steps.

    Returns:
        List of predicted words. The stop token ('<END>' or '<PAD>') is not included.
    """
    model.eval()
    # Create all decoder tensors next to the features. Previously they were always on
    # CPU, which fails once the model and features have been moved to CUDA.
    device = image_features.device
    with torch.no_grad():
        decoder_input = torch.tensor([[word2idx['<START>']]], device=device)

        # Zero-initialised (h, c) state of shape (num_layers, batch=1, hidden_size).
        h = torch.zeros(model.num_layers, 1, model.hidden_size, device=device)
        c = torch.zeros(model.num_layers, 1, model.hidden_size, device=device)
        hidden = (h, c)

        caption = []

        for _ in range(max_length):
            # NOTE(review): the single-image cell calls model(features, tokens) and gets
            # logits only; confirm which signature ImageCaptioningModel actually has.
            output, hidden = model(image_features.unsqueeze(0), decoder_input, hidden)

            pred_idx = output[0, -1].argmax().item()
            pred_word = idx2word[pred_idx]

            # Stop on end-of-sequence or padding.
            if pred_word in ('<END>', '<PAD>'):
                break

            caption.append(pred_word)
            # Only the newest token is fed back; earlier context is carried in `hidden`.
            decoder_input = torch.tensor([[pred_idx]], device=device)

    return caption

def calculate_bleu_scores(references, hypothesis):
    """Calculate smoothed sentence-level BLEU-1, -2, -3 and -4 for one hypothesis.

    Args:
        references: either a single tokenized reference (list of str) or a list of
            tokenized references (list of lists/tuples of str).
        hypothesis: tokenized hypothesis (list of str).

    Returns:
        dict with keys 'bleu1', 'bleu2', 'bleu3', 'bleu4'.

    Raises:
        ValueError: if ``references`` is empty.
    """
    if not references:
        raise ValueError("calculate_bleu_scores: at least one reference is required")

    smoothing = SmoothingFunction().method1

    # A single tokenized reference (a sequence of strings) is wrapped into a list of references.
    if not isinstance(references[0], (list, tuple)):
        references = [references]

    # Uniform n-gram weights. BLEU-3 previously used 0.33 each (sum 0.99), which
    # slightly inflated the score instead of giving a true 1/3 split.
    weight_sets = {
        'bleu1': (1, 0, 0, 0),
        'bleu2': (0.5, 0.5, 0, 0),
        'bleu3': (1 / 3, 1 / 3, 1 / 3, 0),
        'bleu4': (0.25, 0.25, 0.25, 0.25),
    }

    return {
        name: sentence_bleu(references, hypothesis, weights=weights, smoothing_function=smoothing)
        for name, weights in weight_sets.items()
    }

def _strip_special_tokens(caption_indices, idx2word):
    """Map a sequence of token indices to words, dropping <START>/<END>/<PAD>/<UNK>."""
    special = {'<START>', '<END>', '<PAD>', '<UNK>'}
    # int() accepts both plain ints and 0-d tensors.
    words = (idx2word[int(idx)] for idx in caption_indices)
    return [word for word in words if word not in special]


def evaluate_model(model_path, test_dataloader, device):
    """Evaluate a saved captioning model on a test set with BLEU.

    Every sample in every batch is captioned on its own with greedy decoding, then
    scored against its own reference caption. Scoring uses sentence BLEU-1..4
    (averaged over samples) and corpus BLEU-1/4.

    Args:
        model_path: checkpoint path understood by ``load_model_and_tokenizers``.
        test_dataloader: yields ``(features, captions)`` batches. NOTE(review): assumes
            features are precomputed image features of shape (batch, 768) and captions
            are index sequences of shape (batch, seq_len) — confirm against test_dataset.
        device: torch device to run the model on.

    Returns:
        ``(avg_scores, corpus_scores)``: ``avg_scores`` has keys 'bleu1'..'bleu4' and
        ``corpus_scores`` has 'corpus_bleu1' and 'corpus_bleu4'. All values are NaN if
        no sample could be scored.
    """
    # Load model and tokenizers
    model, word2idx, idx2word = load_model_and_tokenizers(model_path)
    model = model.to(device)
    model.eval()

    all_bleu_scores = []
    all_references = []
    all_hypotheses = []
    num_failed = 0
    max_errors_shown = 5  # avoid flooding the output when every sample fails

    print("Generating captions and calculating BLEU scores...")
    for images, captions in tqdm(test_dataloader):
        images = images.to(device)
        # Caption each sample separately. The previous version passed the whole batch
        # to generate_caption as if it were one image, then scored that single
        # hypothesis against every caption in the batch.
        for features, caption_indices in zip(images, captions):
            try:
                # (768,) -> (1, 768), the same shape evaluate_single_image passes.
                hypothesis = generate_caption(model, features.unsqueeze(0), word2idx, idx2word)
                references = [_strip_special_tokens(caption_indices, idx2word)]
                scores = calculate_bleu_scores(references, hypothesis)
            except Exception as e:
                num_failed += 1
                if num_failed <= max_errors_shown:
                    print(f"Error processing sample: {str(e)}")
                continue

            all_bleu_scores.append(scores)
            # Stored for corpus BLEU: one list of references per hypothesis.
            all_references.append(references)
            all_hypotheses.append(hypothesis)

    if num_failed:
        print(f"\n{num_failed} sample(s) failed and were skipped.")

    if not all_bleu_scores:
        print("\nNo samples could be evaluated; returning NaN scores.")
        nan = float('nan')
        return (
            {name: nan for name in ('bleu1', 'bleu2', 'bleu3', 'bleu4')},
            {'corpus_bleu1': nan, 'corpus_bleu4': nan},
        )

    # Average sentence-level scores
    avg_scores = {
        name: np.mean([s[name] for s in all_bleu_scores])
        for name in ('bleu1', 'bleu2', 'bleu3', 'bleu4')
    }

    # Corpus-level BLEU
    corpus_bleu1 = corpus_bleu(all_references, all_hypotheses, weights=(1, 0, 0, 0))
    corpus_bleu4 = corpus_bleu(all_references, all_hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

    print("\nEvaluation Results:")
    print("\nAverage BLEU Scores:")
    print(f"BLEU-1: {avg_scores['bleu1']:.4f}")
    print(f"BLEU-2: {avg_scores['bleu2']:.4f}")
    print(f"BLEU-3: {avg_scores['bleu3']:.4f}")
    print(f"BLEU-4: {avg_scores['bleu4']:.4f}")

    print("\nCorpus BLEU Scores:")
    print(f"Corpus BLEU-1: {corpus_bleu1:.4f}")
    print(f"Corpus BLEU-4: {corpus_bleu4:.4f}")

    # Show a few example predictions
    print("\nSample Predictions:")
    for i in range(min(5, len(all_hypotheses))):
        print(f"\nImage {i+1}:")
        print(f"Generated: {' '.join(all_hypotheses[i])}")
        print(f"Reference: {' '.join(all_references[i][0])}")

    return avg_scores, {'corpus_bleu1': corpus_bleu1, 'corpus_bleu4': corpus_bleu4}

def evaluate_single_image(model, image_path, word2idx, idx2word, feature_extractor, vit_model, device):
    """Caption one image file end to end.

    The image is converted to RGB and preprocessed by ``feature_extractor``. It is
    then encoded by ``vit_model``, and the hidden state at token position 0 is
    handed to ``generate_caption``.

    Returns:
        The generated caption as a single space-separated string.
    """
    model.eval()
    with torch.no_grad():
        rgb_image = Image.open(image_path).convert('RGB')
        encoded = feature_extractor(images=rgb_image, return_tensors="pt")

        # Move every preprocessed tensor to the target device.
        encoded_on_device = {}
        for key, tensor in encoded.items():
            encoded_on_device[key] = tensor.to(device)

        first_token_features = vit_model(**encoded_on_device).last_hidden_state[:, 0, :]
        words = generate_caption(model, first_token_features, word2idx, idx2word)

    return ' '.join(words)

# Evaluate the saved checkpoint on the test split, then caption a few standalone images.
# NOTE(review): this cell rebinds the notebook globals `device`, `model`, `word2idx` and
# `idx2word`; cells run afterwards will see these values.
if __name__ == "__main__":
    # Parameters
    MODEL_PATH = 'best_model.pth'  # Path to your saved model
    BATCH_SIZE = 32
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load test dataset
    # NOTE(review): `test_dataset` (and DataLoader) are not defined in this cell; they
    # must come from another cell (hidden notebook state) — confirm before Run All.
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # stable order, so the printed sample predictions are repeatable
        num_workers=4
    )

    # Evaluate model (evaluate_model loads its own copy of the checkpoint)
    avg_scores, corpus_scores = evaluate_model(MODEL_PATH, test_loader, device)

    # Test on single images: the checkpoint and ViT encoder are loaded again here,
    # because evaluate_model returns only the scores.
    model, word2idx, idx2word = load_model_and_tokenizers(MODEL_PATH)
    model = model.to(device)
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
    vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224').to(device)

    # Test on some sample images
    test_images = [
        "/content/8192398089.jpg",
    ]

    print("\nTesting individual images:")
    for img_path in test_images:
        # Best-effort: a failing image is reported and the loop moves on.
        try:
            caption = evaluate_single_image(
                model, img_path, word2idx, idx2word,
                feature_extractor, vit_model, device
            )
            print(f"\nImage: {img_path}")
            print(f"Generated Caption: {caption}")
        except Exception as e:
            print(f"Error processing {img_path}: {str(e)}")

  checkpoint = torch.load(model_path, map_location='cpu')


NameError: name 'EncoderDecoder' is not defined