In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models.segmentation import deeplabv3_resnet50
import torch.nn.functional as F

import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json
import os
from collections import defaultdict
import pickle
from tqdm import tqdm
import random

In [None]:
import nltk
from nltk.tokenize import word_tokenize
import re
# Fetch NLTK data packages up front; quiet=True suppresses the download log.
# NOTE(review): 'stopwords' is not referenced in the visible cells — confirm it is still needed.
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)

print("All libraries imported successfully!")

All libraries imported successfully!


In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
from io import BytesIO
import requests
import re
from nltk.tokenize import word_tokenize
from datasets import load_dataset

class COCODataset(Dataset):
    """Captioning dataset that reads samples from the HuggingFace ``HuggingFaceM4/COCO`` dataset.

    ``dataset[i]`` returns ``(image, caption_tokens, i)``; ``caption_tokens`` is a
    1-D tensor of exactly ``max_caption_length`` vocabulary indices.
    """

    # Seconds to wait for an image download before falling back to a placeholder,
    # so a stalled connection cannot hang a DataLoader worker indefinitely.
    request_timeout = 10

    def __init__(self, vocab, transform=None, max_caption_length=20, split="train[:1%]"):
        """
        COCO Dataset using HuggingFace for captioning

        Args:
            vocab: mapping from token string to integer index. '<start>', '<end>',
                '<pad>' and '<unk>' are looked up when encoding captions.
            transform: optional callable applied to the loaded PIL image.
            max_caption_length: fixed number of tokens in every encoded caption.
            split: split expression forwarded to ``load_dataset``.
        """
        self.vocab = vocab
        self.transform = transform
        self.max_caption_length = max_caption_length

        # Load COCO via Hugging Face (lazy loading)
        self.dataset = load_dataset("HuggingFaceM4/COCO", split=split)

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, caption_tokens, idx)`` for sample ``idx``.

        The image is downloaded from the sample's URL; if the download or decoding
        fails, a random 224x224 RGB placeholder is used instead so that one bad
        URL does not abort an epoch.
        """
        # NOTE(review): assumes each sample exposes sample['image']['url'] and a
        # string sample['caption'] — confirm against the dataset schema.
        sample = self.dataset[idx]
        image_url = sample['image']['url']
        caption = self.preprocess_caption(sample['caption'])

        # Load image from URL
        try:
            response = requests.get(image_url, timeout=self.request_timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        except Exception:
            # Deliberate best-effort fallback; `Exception` (not a bare except)
            # lets KeyboardInterrupt / SystemExit propagate.
            image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

        if self.transform:
            image = self.transform(image)

        # Process caption
        caption_tokens = self.caption_to_tokens(caption)

        return image, caption_tokens, idx  # using idx as image_id here

    def preprocess_caption(self, caption):
        """Lower-case the caption, drop punctuation and trim surrounding whitespace."""
        caption = caption.lower()
        caption = re.sub(r'[^\w\s]', '', caption)
        return caption.strip()

    def caption_to_tokens(self, caption):
        """Encode a caption as a fixed-length tensor of vocabulary indices.

        The sequence is ``<start> w1 ... wn <end>`` followed by ``<pad>`` up to
        ``max_caption_length``. Over-long captions are truncated so that the last
        token is still ``<end>``. Tokens missing from the vocabulary map to the
        '<unk>' index (or 0 if the vocabulary has no '<unk>').
        """
        tokens = ['<start>'] + word_tokenize(caption) + ['<end>']

        if len(tokens) > self.max_caption_length:
            # Keep room for the terminating <end> marker.
            tokens = tokens[:self.max_caption_length - 1] + ['<end>']

        # A non-positive count yields no padding, so the length never exceeds the limit.
        tokens += ['<pad>'] * (self.max_caption_length - len(tokens))

        unk_index = self.vocab.get('<unk>', 0)
        token_indices = [self.vocab.get(token, unk_index) for token in tokens]
        return torch.tensor(token_indices)


In [None]:
class VocabularyBuilder:
    """Builds word <-> index lookup tables from captions, keeping only frequent words.

    Indices 0-3 are reserved for '<pad>', '<start>', '<end>' and '<unk>'.
    """

    def __init__(self, min_freq=2):
        specials = ('<pad>', '<start>', '<end>', '<unk>')
        self.word2idx = {token: position for position, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))
        self.min_freq = min_freq
        self.word_freq = defaultdict(int)

    def build_vocab(self, captions):
        """Count tokens in ``captions`` and register those seen at least ``min_freq`` times.

        Counts accumulate in ``self.word_freq`` across calls. Prints the resulting
        vocabulary size and returns ``(word2idx, idx2word)``.
        """
        # Pass 1: tally lower-cased tokens.
        for text in captions:
            for token in word_tokenize(text.lower()):
                self.word_freq[token] += 1

        # Pass 2: hand out the next free index to every sufficiently frequent new word.
        next_index = len(self.word2idx)
        for token, count in self.word_freq.items():
            if count < self.min_freq or token in self.word2idx:
                continue
            self.word2idx[token] = next_index
            self.idx2word[next_index] = token
            next_index += 1

        print(f"Vocabulary size: {len(self.word2idx)}")
        return self.word2idx, self.idx2word

# Mock COCO data creation (for demonstration)
def create_mock_coco_data():
    """Build a tiny COCO-style annotation dict and write it to data/mock_coco.json.

    The structure holds two 224x224 image records and four caption annotations
    (two per image). The `data` directory is created if it does not exist.

    Returns:
        The dict that was serialized, with 'images' and 'annotations' lists.
    """
    image_files = {1: 'image1.jpg', 2: 'image2.jpg'}
    caption_rows = (
        (1, 1, 'A cat sitting on a chair'),
        (2, 1, 'Orange cat relaxing indoors'),
        (3, 2, 'A dog running in the park'),
        (4, 2, 'Happy golden retriever playing outside'),
    )

    mock_data = {
        'images': [
            {'id': image_id, 'file_name': file_name, 'height': 224, 'width': 224}
            for image_id, file_name in image_files.items()
        ],
        'annotations': [
            {'id': ann_id, 'image_id': image_id, 'caption': text}
            for ann_id, image_id, text in caption_rows
        ],
    }

    # Persist next to the notebook so later cells can read it back.
    os.makedirs('data', exist_ok=True)
    with open('data/mock_coco.json', 'w') as handle:
        json.dump(mock_data, handle)

    return mock_data

# Create mock data: writes data/mock_coco.json and keeps the same dict in memory.
mock_data = create_mock_coco_data()
print("Mock COCO data created!")


Mock COCO data created!


In [None]:
class CNNEncoder(nn.Module):
    """Image encoder: ResNet-50 without its classifier, then a linear projection and batch norm.

    Maps a batch of images to ``(batch, embed_size)`` feature vectors.
    """

    def __init__(self, embed_size=256):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep every child module except the last one (the classification layer).
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.linear = nn.Linear(backbone.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        # The backbone runs without gradient tracking, so only the projection
        # and batch-norm layers receive gradients.
        with torch.no_grad():
            pooled = self.resnet(images)
        flat = pooled.reshape(pooled.size(0), -1)
        projected = self.linear(flat)
        return self.bn(projected)

In [None]:
class AttentionDecoder(nn.Module):
    """LSTM caption decoder that feeds a gated copy of the image encoding at every step.

    At each time step the LSTM cell input is the concatenation of the current
    word embedding and the image encoding scaled by a sigmoid gate computed from
    the previous hidden state.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, dropout=0.5):
        """
        Args:
            attention_dim: size of the per-image encoding passed to ``forward``.
            embed_dim: word-embedding size.
            decoder_dim: LSTM hidden/cell state size.
            vocab_size: number of output classes (vocabulary entries).
            dropout: dropout probability applied before the output layer.
        """
        super(AttentionDecoder, self).__init__()

        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        # NOTE(review): `attention` is not used by forward(); kept so existing
        # checkpoints keep the same parameter names.
        self.attention = nn.Linear(decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=dropout)
        # The cell consumes [word embedding ; gated image encoding], so its input
        # width is embed_dim + attention_dim (not embed_dim alone).
        self.decode_step = nn.LSTMCell(embed_dim + attention_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(attention_dim, decoder_dim)
        self.init_c = nn.Linear(attention_dim, decoder_dim)
        self.f_beta = nn.Linear(decoder_dim, attention_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Teacher-forced decoding of a batch of captions.

        Args:
            encoder_out: ``(batch, attention_dim)`` image encodings.
            encoded_captions: ``(batch, max_len)`` token indices.
            caption_lengths: ``(batch, 1)`` true caption lengths (including the
                final ``<end>`` position).

        Returns:
            ``(predictions, encoded_captions, decode_lengths, sort_ind)`` where
            ``predictions`` is ``(batch, max(decode_lengths), vocab_size)`` with
            zeros after each caption's last decoded step, ``encoded_captions`` is
            reordered by descending length, ``decode_lengths`` is a list of
            ``length - 1`` per caption and ``sort_ind`` is the permutation that
            was applied to the batch.
        """
        batch_size = encoder_out.size(0)
        vocab_size = self.vocab_size

        # Sort input data by caption lengths so the active samples at step t
        # are always a prefix of the batch.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)

        # Initialize LSTM state
        h = self.init_h(encoder_out)
        c = self.init_c(encoder_out)

        # We won't decode at the <end> position, since we've finished generating
        decode_lengths = (caption_lengths - 1).tolist()
        max_steps = max(decode_lengths)

        # Allocate on the same device/dtype as the inputs so the model works on GPU.
        predictions = encoder_out.new_zeros(batch_size, max_steps, vocab_size)

        for t in range(max_steps):
            batch_size_t = sum(length > t for length in decode_lengths)

            # Attention mechanism (simplified): a sigmoid gate on the global encoding.
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * encoder_out[:batch_size_t]

            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t])
            )

            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds

        return predictions, encoded_captions, decode_lengths, sort_ind


In [None]:
class CaptioningModel(nn.Module):
    """CNN encoder + LSTM decoder for image captioning.

    The image feature vector is fed to the LSTM as the first time step, followed
    by the embedded caption tokens.
    """

    def __init__(self, vocab_size, embed_size=256, hidden_size=512, num_layers=1):
        """
        Args:
            vocab_size: number of vocabulary entries (embedding rows / output classes).
            embed_size: width shared by the image feature vector and word embeddings.
            hidden_size: LSTM hidden state size.
            num_layers: number of stacked LSTM layers.
        """
        super(CaptioningModel, self).__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # CNN Encoder
        self.encoder = CNNEncoder(embed_size)

        # RNN Decoder
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, images, captions):
        """Teacher-forced forward pass.

        Args:
            images: batch of images accepted by the encoder.
            captions: ``(batch, T)`` token indices.

        Returns:
            ``(batch, T, vocab_size)`` logits. Step 0 is produced from the image
            feature alone; step ``i > 0`` is produced after consuming
            ``captions[:, i - 1]``.
        """
        # Encode images
        features = self.encoder(images)

        # Prepare captions (remove last token for input)
        captions_input = captions[:, :-1]
        embeddings = self.embedding(captions_input)

        # Concatenate image features and caption embeddings along the time axis
        features = features.unsqueeze(1)
        inputs = torch.cat((features, embeddings), dim=1)

        # LSTM forward pass
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)

        return outputs

    def generate_caption(self, image, vocab, max_length=20):
        """Greedily decode a caption for a single image.

        Args:
            image: one image tensor without a batch dimension; it is moved to the
                model's device before encoding.
            vocab: token -> index mapping; must contain '<start>'. '<end>' falls
                back to index 2 if absent.
            max_length: maximum number of tokens to generate.

        Returns:
            List of predicted token indices (includes the '<end>' index if it
            was produced).

        The model is switched to eval mode for decoding and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        # Build every tensor on the model's device so decoding also works on GPU.
        device = next(self.parameters()).device
        end_id = vocab.get('<end>', 2)
        caption = []

        try:
            with torch.no_grad():
                # Encode image
                features = self.encoder(image.unsqueeze(0).to(device))

                # Start with <start> token, shape (1, 1)
                inputs = torch.tensor([[vocab['<start>']]], device=device)
                hidden = None

                for _ in range(max_length):
                    embedded = self.embedding(inputs)

                    if hidden is None:
                        # First step: image feature followed by <start>, as in training
                        lstm_input = torch.cat([features.unsqueeze(1), embedded], dim=1)
                    else:
                        lstm_input = embedded

                    output, hidden = self.lstm(lstm_input, hidden)

                    # Greedy choice from the last time step
                    pred = self.linear(output[:, -1, :])
                    predicted_id = pred.argmax(dim=1)
                    token_id = predicted_id.item()
                    caption.append(token_id)

                    # Stop if <end> token
                    if token_id == end_id:
                        break

                    # Update input for next iteration
                    inputs = predicted_id.unsqueeze(0)
        finally:
            # Do not silently leave a model that was training in eval mode.
            self.train(was_training)

        return caption

In [None]:
class UNet(nn.Module):
    """Four-level U-Net: encoder, bottleneck and a decoder with skip connections.

    Produces per-pixel class logits with ``n_classes`` channels at the input
    resolution (the four 2x poolings are mirrored by four 2x transposed convs).
    """

    def __init__(self, n_channels=3, n_classes=21):  # 21 classes for Pascal VOC
        super().__init__()

        self.n_channels = n_channels
        self.n_classes = n_classes

        level_widths = (64, 128, 256, 512)

        # Contracting path: encoder1..encoder4, channel width doubling per level.
        fan_in = n_channels
        for level, width in enumerate(level_widths, start=1):
            setattr(self, f'encoder{level}', self.conv_block(fan_in, width))
            fan_in = width

        # Deepest stage: 512 -> 1024 channels.
        self.bottleneck = self.conv_block(fan_in, 2 * fan_in)

        # Expanding path, registered deepest-first: upconv4, decoder4, ..., upconv1, decoder1.
        # Each decoder sees the upsampled map concatenated with the matching skip,
        # hence 2 * width input channels.
        for level in range(len(level_widths), 0, -1):
            width = level_widths[level - 1]
            setattr(self, f'upconv{level}', nn.ConvTranspose2d(2 * width, width, 2, stride=2))
            setattr(self, f'decoder{level}', self.conv_block(2 * width, width))

        # 1x1 convolution to class logits.
        self.final_conv = nn.Conv2d(level_widths[0], n_classes, 1)

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> batch norm -> ReLU) stages that keep the spatial size."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder: remember each level's output for the skip connections,
        # then halve the resolution before the next level.
        skips = []
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            x = encoder(x)
            skips.append(x)
            x = F.max_pool2d(x, 2)

        x = self.bottleneck(x)

        # Decoder: upsample, concatenate the matching skip (deepest first), refine.
        stages = (
            (self.upconv4, self.decoder4),
            (self.upconv3, self.decoder3),
            (self.upconv2, self.decoder2),
            (self.upconv1, self.decoder1),
        )
        for upconv, decoder in stages:
            x = upconv(x)
            x = torch.cat((x, skips.pop()), dim=1)
            x = decoder(x)

        return self.final_conv(x)

In [None]:
class MaskRCNNHead(nn.Module):
    """Mask head: four 3x3 convs, one 2x transposed conv, then a 1x1 per-class predictor.

    The output has ``num_classes`` channels at twice the input's spatial size.
    """

    def __init__(self, in_channels=256, num_classes=80):
        super().__init__()

        hidden = 256  # channel width of every intermediate layer
        self.conv1 = nn.Conv2d(in_channels, hidden, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden, hidden, 3, padding=1)
        self.conv3 = nn.Conv2d(hidden, hidden, 3, padding=1)
        self.conv4 = nn.Conv2d(hidden, hidden, 3, padding=1)
        self.deconv = nn.ConvTranspose2d(hidden, hidden, 2, stride=2)
        self.predictor = nn.Conv2d(hidden, num_classes, 1)

    def forward(self, x):
        # Every layer except the final predictor is followed by a ReLU.
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.deconv):
            x = F.relu(layer(x))
        return self.predictor(x)

In [None]:
class IntegratedModel(nn.Module):
    """Multi-task network: a shared ResNet-50 trunk feeding a caption LSTM and a segmentation decoder."""

    def __init__(self, vocab_size, num_seg_classes=21):
        """
        Args:
            vocab_size: number of vocabulary entries for the caption branch.
            num_seg_classes: number of output channels of the segmentation branch.
        """
        super(IntegratedModel, self).__init__()

        caption_embed_dim = 256
        caption_hidden_dim = 512

        # Shared CNN backbone
        # NOTE(review): `pretrained=True` is the legacy torchvision argument; newer
        # releases prefer `weights=...` — switch once the minimum version allows it.
        resnet = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])  # Remove avgpool and fc

        # Caption branch
        self.caption_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The pooled image vector is prepended to the word embeddings as the first
        # LSTM time step, so it must have the embedding width (not the hidden width).
        self.caption_fc = nn.Linear(2048, caption_embed_dim)
        self.caption_embedding = nn.Embedding(vocab_size, caption_embed_dim)
        self.caption_lstm = nn.LSTM(caption_embed_dim, caption_hidden_dim, batch_first=True)
        self.caption_output = nn.Linear(caption_hidden_dim, vocab_size)

        # Segmentation branch: four (conv, 2x bilinear upsample) stages = 16x upsampling
        self.seg_decoder = nn.Sequential(
            nn.Conv2d(2048, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, num_seg_classes, 3, padding=1)
        )

    def _caption_image_features(self, shared_features):
        """Pool the backbone feature map and project it to the embedding width: (B, 256)."""
        pooled = self.caption_pool(shared_features).flatten(1)
        return self.caption_fc(pooled)

    def forward(self, images, captions=None, mode='both'):
        """Run the requested branch(es) on a batch.

        Args:
            images: ``(batch, 3, H, W)`` image batch.
            captions: ``(batch, T)`` token indices; the caption branch is skipped
                when this is None.
            mode: 'caption', 'segment' or 'both'.

        Returns:
            dict that may contain
            'captions': ``(batch, T, vocab_size)`` logits — step 0 comes from the
                image feature alone, step ``i > 0`` follows ``captions[:, i - 1]``;
            'segmentation': ``(batch, num_seg_classes, H, W)`` logits.
        """
        # Extract shared features
        shared_features = self.backbone(images)

        outputs = {}

        if mode in ['caption', 'both'] and captions is not None:
            # Image feature becomes the first time step of the LSTM input
            image_step = self._caption_image_features(shared_features).unsqueeze(1)
            caption_embeddings = self.caption_embedding(captions[:, :-1])
            lstm_input = torch.cat([image_step, caption_embeddings], dim=1)

            lstm_out, _ = self.caption_lstm(lstm_input)
            outputs['captions'] = self.caption_output(lstm_out)

        if mode in ['segment', 'both']:
            seg_logits = self.seg_decoder(shared_features)
            # The decoder upsamples 16x while the ResNet-50 trunk downsamples 32x,
            # so resize the logits to the input resolution to line up with
            # full-size target masks.
            target_size = tuple(images.shape[-2:])
            if tuple(seg_logits.shape[-2:]) != target_size:
                seg_logits = F.interpolate(
                    seg_logits, size=target_size, mode='bilinear', align_corners=False
                )
            outputs['segmentation'] = seg_logits

        return outputs

    def generate_caption(self, image, vocab, max_length=20):
        """Greedily decode a caption for a single image.

        Args:
            image: one ``(3, H, W)`` image tensor; it is moved to the model's device.
            vocab: token -> index mapping; must contain '<start>'. '<end>' falls
                back to index 2 if absent.
            max_length: maximum number of tokens to generate.

        Returns:
            List of predicted token indices (includes the '<end>' index if it
            was produced).

        The model is switched to eval mode for decoding and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        # Build every tensor on the model's device so decoding also works on GPU.
        device = next(self.parameters()).device
        end_id = vocab.get('<end>', 2)
        caption = []

        try:
            with torch.no_grad():
                shared_features = self.backbone(image.unsqueeze(0).to(device))
                caption_features = self._caption_image_features(shared_features)

                # Generate caption word by word
                hidden = None
                input_word = torch.tensor([[vocab['<start>']]], device=device)

                for _ in range(max_length):
                    word_embed = self.caption_embedding(input_word)

                    if hidden is None:
                        # First step: image feature followed by <start>, as in training
                        lstm_input = torch.cat([caption_features.unsqueeze(1), word_embed], dim=1)
                    else:
                        lstm_input = word_embed

                    output, hidden = self.caption_lstm(lstm_input, hidden)
                    word_logits = self.caption_output(output[:, -1, :])
                    predicted_word = word_logits.argmax(dim=1)
                    token_id = predicted_word.item()

                    caption.append(token_id)

                    if token_id == end_id:
                        break

                    input_word = predicted_word.unsqueeze(0)
        finally:
            # Do not silently leave a model that was training in eval mode.
            self.train(was_training)

        return caption


In [None]:
def train_integrated_model(model, train_loader, criterion_caption, criterion_seg, optimizer, device):
    """Train ``model`` for one pass over ``train_loader`` on the combined caption + segmentation loss.

    Args:
        model: callable as ``model(images, captions, mode='both')`` returning a dict
            that may contain 'captions' logits ``(batch, steps, vocab)`` and
            'segmentation' logits.
        train_loader: iterable of ``(images, captions, seg_masks)`` batches;
            ``seg_masks`` may be None.
        criterion_caption: loss taking ``(N, vocab)`` logits and ``(N,)`` targets.
        criterion_seg: loss taking segmentation logits and masks.
        optimizer: optimizer over the model parameters.
        device: device the batch tensors are moved to.

    Returns:
        Mean combined loss per batch as a float (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, (images, captions, seg_masks) in enumerate(train_loader):
        num_batches += 1
        images = images.to(device)
        captions = captions.to(device)
        seg_masks = seg_masks.to(device) if seg_masks is not None else None

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, captions, mode='both')

        # Calculate losses
        caption_loss = None
        seg_loss = None

        if 'captions' in outputs:
            caption_targets = captions[:, 1:]  # Remove <start> token
            caption_logits = outputs['captions']
            # The model may emit an extra leading step (produced from the image
            # feature before any word is read). Keep only the trailing steps so
            # each logit lines up with the token it should predict.
            first_step = max(caption_logits.size(1) - caption_targets.size(1), 0)
            caption_logits = caption_logits[:, first_step:, :]
            caption_loss = criterion_caption(
                caption_logits.reshape(-1, caption_logits.size(-1)),
                caption_targets.reshape(-1)
            )

        if 'segmentation' in outputs and seg_masks is not None:
            seg_loss = criterion_seg(outputs['segmentation'], seg_masks)

        # Combined loss; a batch that produced no loss term is skipped rather
        # than calling .backward() on a plain number.
        loss_terms = [term for term in (caption_loss, seg_loss) if term is not None]
        if loss_terms:
            total_batch_loss = sum(loss_terms)
            total_batch_loss.backward()
            optimizer.step()
            total_loss += total_batch_loss.item()

        if batch_idx % 10 == 0:
            caption_value = caption_loss.item() if caption_loss is not None else 0.0
            seg_value = seg_loss.item() if seg_loss is not None else 0.0
            print(f'Batch {batch_idx}, Caption Loss: {caption_value:.4f}, Seg Loss: {seg_value:.4f}')

    return total_loss / num_batches if num_batches else 0.0

In [None]:
def evaluate_model(model, val_loader, vocab, device):
    """Compute mean caption and segmentation cross-entropy over ``val_loader``.

    Args:
        model: callable as ``model(images, captions, mode='both')`` returning a dict
            that may contain 'captions' logits ``(batch, steps, vocab)`` and
            'segmentation' logits.
        val_loader: iterable of ``(images, captions, seg_masks)`` batches;
            ``seg_masks`` may be None.
        vocab: token -> index mapping; if it contains '<pad>', padded caption
            positions are excluded from the caption loss (matching a training
            criterion built with ``ignore_index=vocab['<pad>']``).
        device: device the batch tensors are moved to.

    Returns:
        ``(mean_caption_loss, mean_seg_loss)`` as floats, averaged over batches
        (both 0.0 for an empty loader).
    """
    model.eval()
    total_caption_loss = 0.0
    total_seg_loss = 0.0
    num_batches = 0

    # -100 is F.cross_entropy's default ignore_index, i.e. "ignore nothing" here.
    pad_index = vocab.get('<pad>', -100) if vocab else -100

    with torch.no_grad():
        for images, captions, seg_masks in val_loader:
            num_batches += 1
            images = images.to(device)
            captions = captions.to(device)
            seg_masks = seg_masks.to(device) if seg_masks is not None else None

            outputs = model(images, captions, mode='both')

            # Calculate metrics (simplified)
            if 'captions' in outputs:
                caption_targets = captions[:, 1:]  # drop <start>
                caption_logits = outputs['captions']
                # Drop any extra leading step (emitted from the image feature
                # alone) so logits and targets have the same number of steps.
                first_step = max(caption_logits.size(1) - caption_targets.size(1), 0)
                caption_logits = caption_logits[:, first_step:, :]
                caption_loss = F.cross_entropy(
                    caption_logits.reshape(-1, caption_logits.size(-1)),
                    caption_targets.reshape(-1),
                    ignore_index=pad_index
                )
                total_caption_loss += caption_loss.item()

            if 'segmentation' in outputs and seg_masks is not None:
                seg_loss = F.cross_entropy(outputs['segmentation'], seg_masks)
                total_seg_loss += seg_loss.item()

    if num_batches == 0:
        return 0.0, 0.0
    return total_caption_loss / num_batches, total_seg_loss / num_batches

In [None]:
# Download the 'punkt_tab' NLTK data package.
# NOTE(review): presumably required by word_tokenize on recent NLTK releases — confirm,
# and consider moving this next to the other nltk.download calls at the top.
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [None]:
def visualize_results(model, image, vocab, idx2word, device, seg_classes=None):
    """Caption and segment one image, then plot image, mask and caption side by side.

    Args:
        model: object exposing ``eval()``, ``generate_caption(image, vocab)`` and
            ``model(batch, mode='segment')`` returning a dict with 'segmentation' logits.
        image: ``(C, H, W)`` image tensor.
        vocab: token -> index mapping forwarded to ``generate_caption``.
        idx2word: index -> token mapping used to decode the caption.
        device: device the image is moved to before both model calls.
        seg_classes: currently unused.

    Returns:
        ``(caption, seg_mask)``: the decoded caption string (special tokens removed)
        and the per-pixel argmax class tensor on CPU.
    """
    model.eval()

    # Move once so captioning and segmentation both see the image on the model's device.
    image = image.to(device)

    # Generate caption
    caption_ids = model.generate_caption(image, vocab)
    caption_words = [idx2word.get(idx, '<unk>') for idx in caption_ids]
    caption = ' '.join(word for word in caption_words if word not in ('<start>', '<end>', '<pad>'))

    # Generate segmentation
    with torch.no_grad():
        outputs = model(image.unsqueeze(0), mode='segment')
        seg_pred = outputs['segmentation'].squeeze(0).cpu()
        seg_mask = seg_pred.argmax(dim=0)

    # Plot results
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image, min-max scaled to [0, 1] for display. A constant image has
    # zero range; it is shown as all zeros instead of dividing by zero.
    img_np = image.detach().permute(1, 2, 0).cpu().numpy()
    img_np = img_np - img_np.min()
    peak = img_np.max()
    if peak > 0:
        img_np = img_np / peak
    axes[0].imshow(img_np)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Segmentation mask
    axes[1].imshow(seg_mask, cmap='tab20')
    axes[1].set_title('Segmentation Mask')
    axes[1].axis('off')

    # Caption
    axes[2].text(0.5, 0.5, f'Caption: {caption}',
                horizontalalignment='center', verticalalignment='center',
                transform=axes[2].transAxes, fontsize=12, wrap=True)
    axes[2].set_title('Generated Caption')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    return caption, seg_mask

# Example usage and setup
def setup_training():
    """Create the demo vocabulary and the three models, placed on the available device.

    Returns:
        dict with keys 'device', 'vocab', 'idx2word', 'captioning_model',
        'segmentation_model' and 'integrated_model'.
    """
    # Prefer the GPU when one is available.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")

    # Tiny mock corpus used only to get a vocabulary for the demo.
    sample_captions = [
        "a cat sitting on a chair",
        "a dog running in the park",
        "people walking on the street",
        "cars driving on the road",
    ]
    word2idx, idx2word = VocabularyBuilder().build_vocab(sample_captions)

    # Model configuration
    vocab_size = len(word2idx)
    embed_size, hidden_size, num_seg_classes = 256, 512, 21

    # Build the networks, then move each one to the chosen device.
    networks = {
        'captioning_model': CaptioningModel(vocab_size, embed_size, hidden_size),
        'segmentation_model': UNet(n_channels=3, n_classes=num_seg_classes),
        'integrated_model': IntegratedModel(vocab_size, num_seg_classes),
    }
    for network in networks.values():
        network.to(device)

    print("Models initialized and moved to device!")

    return {
        'device': device,
        'vocab': word2idx,
        'idx2word': idx2word,
        **networks,
    }
# Build the vocabulary and models once at cell execution; `setup_dict` is read
# by example_training_setup() further down.
setup_dict = setup_training()
print("\nSetup complete! You can now:")
print("1. Load your COCO dataset")
print("2. Create data loaders")
print("3. Start training with the integrated model")
print("4. Evaluate on validation data")
print("5. Visualize results")




Using device: cpu
Vocabulary size: 7
Models initialized and moved to device!

Setup complete! You can now:
1. Load your COCO dataset
2. Create data loaders
3. Start training with the integrated model
4. Evaluate on validation data
5. Visualize results


In [None]:
def example_training_setup():
    """Assemble optimizer, loss functions and hyperparameters for the integrated model.

    Reads 'device', 'integrated_model' and 'vocab' from the module-level
    ``setup_dict``.

    Returns:
        dict with keys 'optimizer', 'criterion_caption', 'criterion_seg',
        'batch_size' and 'num_epochs'.
    """
    # Hyperparameters
    learning_rate, batch_size, num_epochs = 0.001, 32, 10

    device = setup_dict['device']
    network = setup_dict['integrated_model']
    pad_index = setup_dict['vocab']['<pad>']

    # Padding positions do not contribute to the caption loss.
    caption_criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
    seg_criterion = nn.CrossEntropyLoss()

    adam = optim.Adam(network.parameters(), lr=learning_rate)

    for line in (
        "Training setup complete!",
        f"Learning rate: {learning_rate}",
        f"Batch size: {batch_size}",
        f"Number of epochs: {num_epochs}",
    ):
        print(line)

    return dict(
        optimizer=adam,
        criterion_caption=caption_criterion,
        criterion_seg=seg_criterion,
        batch_size=batch_size,
        num_epochs=num_epochs,
    )

# Instantiate the optimizer and loss functions for the integrated model held in setup_dict.
training_setup = example_training_setup()

# Checklist of the work that remains after this demo setup.
print("\n" + "="*50)
print("IMPLEMENTATION COMPLETE!")
print("="*50)
print("Next steps:")
print("1. Replace mock data with real COCO dataset")
print("2. Implement proper data loaders")
print("3. Run training loop")
print("4. Add evaluation metrics (BLEU, IoU, etc.)")
print("5. Implement model checkpointing")
print("6. Add tensorboard logging")

Training setup complete!
Learning rate: 0.001
Batch size: 32
Number of epochs: 10

IMPLEMENTATION COMPLETE!
Next steps:
1. Replace mock data with real COCO dataset
2. Implement proper data loaders
3. Run training loop
4. Add evaluation metrics (BLEU, IoU, etc.)
5. Implement model checkpointing
6. Add tensorboard logging


In [None]:
class RealCOCODataset(torch.utils.data.Dataset):
    """COCO dataset yielding ``(image, caption_tokens, seg_mask)`` triples.

    NOTE(review): `COCO` and `coco_mask` are not imported in the visible part of
    the notebook — presumably `from pycocotools.coco import COCO` and
    `from pycocotools import mask as coco_mask`; confirm they are imported
    before this class is used.
    """

    def __init__(self, root_dir, ann_file, vocab, transform=None, max_caption_length=20):
        """
        Real COCO Dataset implementation

        Args:
            root_dir: directory containing the image files.
            ann_file: annotation file passed to ``COCO``.
            vocab: token -> index mapping used to encode captions.
            transform: optional callable applied to the PIL image.
            max_caption_length: fixed number of tokens in every encoded caption.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.max_caption_length = max_caption_length
        self.vocab = vocab

        # Initialize COCO API
        self.coco = COCO(ann_file)
        self.img_ids = list(self.coco.imgs.keys())

        # Filter images that have both captions and segmentation annotations
        # NOTE(review): this requires a single annotation file whose entries carry
        # both 'caption' and 'segmentation' keys — verify the file passed in does.
        self.valid_img_ids = []
        for img_id in self.img_ids:
            ann_ids = self.coco.getAnnIds(imgIds=[img_id])
            anns = self.coco.loadAnns(ann_ids)

            has_caption = any('caption' in ann for ann in anns)
            has_segmentation = any('segmentation' in ann for ann in anns)

            if has_caption and has_segmentation:
                self.valid_img_ids.append(img_id)

        print(f"Found {len(self.valid_img_ids)} images with both captions and segmentation")

    def __len__(self):
        """Number of images that have both a caption and a segmentation annotation."""
        return len(self.valid_img_ids)

    def __getitem__(self, idx):
        """Return ``(image, caption_tokens, seg_mask)`` for the idx-th valid image.

        ``seg_mask`` is always a long tensor of category ids. When the transform
        turns the image into a tensor of a different spatial size, the mask is
        resized to match with nearest-neighbour interpolation.
        """
        img_id = self.valid_img_ids[idx]

        # Load image
        img_info = self.coco.loadImgs(img_id)[0]
        image_path = os.path.join(self.root_dir, img_info['file_name'])
        image = Image.open(image_path).convert('RGB')

        # Get annotations
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        anns = self.coco.loadAnns(ann_ids)

        # Get caption
        captions = [ann['caption'] for ann in anns if 'caption' in ann]
        if captions:
            caption = random.choice(captions)  # Random caption if multiple
        else:
            caption = "no caption available"

        # Create segmentation mask; always a long tensor so loss functions and the
        # default collate see a consistent dtype with or without a transform.
        seg_mask = torch.from_numpy(
            self.create_segmentation_mask(anns, img_info['height'], img_info['width'])
        ).long()

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        # Keep the mask aligned with a resized image tensor. Nearest-neighbour
        # keeps the values valid class ids.
        # NOTE(review): random geometric augmentations in `transform` (e.g. a
        # horizontal flip) are NOT mirrored on the mask — use a joint
        # image/mask transform if such augmentations are enabled.
        if isinstance(image, torch.Tensor) and image.dim() == 3:
            target_size = tuple(image.shape[-2:])
            if tuple(seg_mask.shape) != target_size:
                seg_mask = F.interpolate(
                    seg_mask[None, None].float(), size=target_size, mode='nearest'
                )[0, 0].long()

        # Process caption
        caption_tokens = self.caption_to_tokens(caption)

        return image, caption_tokens, seg_mask

    def create_segmentation_mask(self, anns, height, width):
        """Rasterize annotations into a ``(height, width)`` uint8 array of category ids.

        Pixels not covered by any annotation stay 0; later annotations overwrite
        earlier ones where they overlap.
        """
        mask = np.zeros((height, width), dtype=np.uint8)

        for ann in anns:
            if 'segmentation' in ann:
                category_id = ann['category_id']
                if isinstance(ann['segmentation'], list):
                    # Polygon format: flat [x0, y0, x1, y1, ...] lists
                    for seg in ann['segmentation']:
                        poly = np.array(seg).reshape(-1, 2)
                        cv2.fillPoly(mask, [poly.astype(np.int32)], category_id)
                else:
                    # RLE format
                    rle = coco_mask.frPyObjects(ann['segmentation'], height, width)
                    m = coco_mask.decode(rle)
                    mask[m > 0] = category_id

        return mask

    def caption_to_tokens(self, caption):
        """Encode a caption as a fixed-length tensor of vocabulary indices.

        The sequence is ``<start> w1 ... wn <end>`` padded with ``<pad>`` up to
        ``max_caption_length``. Over-long captions are truncated so that the last
        token is still ``<end>``. Unknown words map to the '<unk>' index (or 0 if
        the vocabulary has no '<unk>').
        """
        # NOTE(review): whitespace splitting keeps punctuation attached to words;
        # confirm this matches how the vocabulary was tokenized.
        words = caption.lower().strip().split()
        tokens = ['<start>'] + words + ['<end>']

        # Pad or truncate
        if len(tokens) > self.max_caption_length:
            # Keep the terminating <end> marker when cutting the caption short.
            tokens = tokens[:self.max_caption_length - 1] + ['<end>']
        else:
            tokens.extend(['<pad>'] * (self.max_caption_length - len(tokens)))

        # Convert to indices
        unk_index = self.vocab.get('<unk>', 0)
        token_indices = [self.vocab.get(token, unk_index) for token in tokens]
        return torch.tensor(token_indices)


# Step 2: Data Loading Setup


In [None]:
def create_data_loaders(vocab, batch_size=32, num_workers=4, pin_memory=None):
    """Create train and validation data loaders

    Currently backed by random dummy data (1000 training and 200 validation
    samples); see the commented block below for wiring in the real dataset.

    Args:
        vocab: token -> index mapping; only its size is used for the dummy captions.
        batch_size: batch size of both loaders.
        num_workers: number of DataLoader worker processes.
        pin_memory: whether to pin host memory; ``None`` (default) enables it only
            when CUDA is available, since pinning is pointless on CPU-only hosts.

    Returns:
        ``(train_loader, val_loader)``; the training loader shuffles, the
        validation loader does not.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # Define transforms (used by the real datasets below)
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    # Create datasets (you'll need to download COCO dataset first)
    # train_dataset = RealCOCODataset(
    #     root_dir='path/to/coco/train2017',
    #     ann_file='path/to/coco/annotations/instances_train2017.json',
    #     vocab=vocab,
    #     transform=train_transform
    # )

    # val_dataset = RealCOCODataset(
    #     root_dir='path/to/coco/val2017',
    #     ann_file='path/to/coco/annotations/instances_val2017.json',
    #     vocab=vocab,
    #     transform=val_transform
    # )

    # For now, create dummy datasets for demonstration
    class DummyDataset(torch.utils.data.Dataset):
        """Random (image, caption, seg_mask) samples with the shapes the models expect."""

        def __init__(self, size=1000, vocab_size=len(vocab)):
            self.size = size
            self.vocab_size = vocab_size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            # Dummy data: 3x224x224 image, 20-token caption, 21-class mask
            image = torch.randn(3, 224, 224)
            caption = torch.randint(0, self.vocab_size, (20,))
            seg_mask = torch.randint(0, 21, (224, 224))
            return image, caption, seg_mask

    train_dataset = DummyDataset(1000)
    val_dataset = DummyDataset(200)

    # Create data loaders; only the shuffle flag differs between the two.
    loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    return train_loader, val_loader

# Step 3: Enhanced Training Loop with Logging


In [None]:
def _caption_seg_losses(model, batch, criterion_caption, criterion_seg, device, seg_weight):
    """Run one forward pass on ``batch`` and return (total, caption, seg) losses."""
    images, captions, seg_masks = batch
    images = images.to(device)
    captions = captions.to(device)
    seg_masks = seg_masks.to(device)

    outputs = model(images, captions, mode='both')

    # Targets are the captions without their first position (the <start>
    # token); logits are flattened to (N, vocab) to match the flat targets.
    caption_logits = outputs['captions']
    caption_targets = captions[:, 1:]
    caption_loss = criterion_caption(
        caption_logits.reshape(-1, caption_logits.size(-1)),
        caption_targets.reshape(-1)
    )

    seg_loss = criterion_seg(outputs['segmentation'], seg_masks)

    total_loss = caption_loss + seg_weight * seg_loss
    return total_loss, caption_loss, seg_loss


def train_model_with_logging(model, train_loader, val_loader, vocab, idx2word,
                            num_epochs=10, device='cuda', lr=0.001,
                            seg_weight=0.5, checkpoint_path='best_model.pth'):
    """Train the joint captioning/segmentation model with logging and checkpointing.

    Each epoch runs a training pass (Adam, gradient-norm clipping at 1.0),
    a validation pass under ``torch.no_grad()``, prints the averaged losses,
    saves a checkpoint whenever the average validation loss improves, steps a
    ``StepLR`` scheduler (step_size=5, gamma=0.1) and, every second epoch,
    calls ``generate_sample_results``.

    Args:
        model: Module called as ``model(images, captions, mode='both')`` that
            returns a dict with ``'captions'`` and ``'segmentation'`` entries.
        train_loader: Sized iterable of ``(images, captions, seg_masks)`` batches.
        val_loader: Sized iterable of ``(images, captions, seg_masks)`` batches.
        vocab: Mapping that contains ``'<pad>'``; that id is ignored by the
            caption loss. Stored in the checkpoint.
        idx2word: Id-to-word mapping; stored in the checkpoint and passed to
            ``generate_sample_results``.
        num_epochs: Number of epochs to run.
        device: Device the model and batches are moved to.
        lr: Learning rate for Adam.
        seg_weight: Weight of the segmentation loss in the combined loss
            ``caption_loss + seg_weight * seg_loss``.
        checkpoint_path: File the best checkpoint is written to.

    Returns:
        tuple: ``(train_losses, val_losses)``, the per-epoch average combined
        losses.

    Raises:
        ValueError: If either loader reports zero batches (the per-epoch
            averages would otherwise divide by zero).
    """
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    if num_train_batches == 0 or num_val_batches == 0:
        raise ValueError(
            "train_loader and val_loader must each yield at least one batch "
            f"(got {num_train_batches} train / {num_val_batches} val batches)"
        )

    # Setup
    model.to(device)
    criterion_caption = torch.nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])
    criterion_seg = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # Training history
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        # Training phase
        model.train()
        train_loss = 0.0
        train_caption_loss = 0.0
        train_seg_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            total_loss, caption_loss, seg_loss = _caption_seg_losses(
                model, batch, criterion_caption, criterion_seg, device, seg_weight
            )

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accumulate losses
            train_loss += total_loss.item()
            train_caption_loss += caption_loss.item()
            train_seg_loss += seg_loss.item()

            # Print progress
            if batch_idx % 50 == 0:
                print(f'Batch {batch_idx}/{num_train_batches}, '
                      f'Loss: {total_loss.item():.4f}, '
                      f'Caption: {caption_loss.item():.4f}, '
                      f'Seg: {seg_loss.item():.4f}')

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_caption_loss = 0.0
        val_seg_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                total_loss, caption_loss, seg_loss = _caption_seg_losses(
                    model, batch, criterion_caption, criterion_seg, device, seg_weight
                )

                val_loss += total_loss.item()
                val_caption_loss += caption_loss.item()
                val_seg_loss += seg_loss.item()

        # Calculate average losses (per batch)
        avg_train_loss = train_loss / num_train_batches
        avg_val_loss = val_loss / num_val_batches
        avg_train_caption = train_caption_loss / num_train_batches
        avg_val_caption = val_caption_loss / num_val_batches
        avg_train_seg = train_seg_loss / num_train_batches
        avg_val_seg = val_seg_loss / num_val_batches

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        print(f'\nEpoch {epoch+1} Results:')
        print(f'Train Loss: {avg_train_loss:.4f} (Caption: {avg_train_caption:.4f}, Seg: {avg_train_seg:.4f})')
        print(f'Val Loss: {avg_val_loss:.4f} (Caption: {avg_val_caption:.4f}, Seg: {avg_val_seg:.4f})')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
                'vocab': vocab,
                'idx2word': idx2word
            }, checkpoint_path)
            print(f'Best model saved with validation loss: {best_val_loss:.4f}')

        # Step scheduler
        scheduler.step()

        # Generate sample results every 2 epochs
        if (epoch + 1) % 2 == 0:
            generate_sample_results(model, val_loader, vocab, idx2word, device)

    return train_losses, val_losses

# Step 4: Evaluation Metrics


In [None]:
def _ids_to_words(token_ids, idx2word):
    """Map token ids to words, dropping unknown ids and special tokens."""
    special_tokens = ('<start>', '<end>', '<pad>')
    words = []
    for token in token_ids:
        # Ids may be Python ints, NumPy scalars or 0-d tensors; convert to a
        # plain Python value so the dictionary lookup is done by value.
        key = token.item() if hasattr(token, 'item') else token
        if key in idx2word and idx2word[key] not in special_tokens:
            words.append(idx2word[key])
    return words


def calculate_metrics(model, test_loader, vocab, idx2word, device):
    """Compute an average sentence-level BLEU and a foreground IoU over ``test_loader``.

    Captions: for every image, ``model.generate_caption(image, vocab)`` is
    decoded through ``idx2word`` (special tokens and unknown ids removed) and
    scored against the decoded ground-truth caption with NLTK's
    ``sentence_bleu``. Samples whose ground-truth caption decodes to nothing
    are skipped.

    Segmentation: predictions are the argmax over dim 1 of
    ``model(images, mode='segment')['segmentation']``. Class id 0 is treated
    as background. The IoU is accumulated over the whole loader as

        intersection = #pixels where pred == gt and gt != 0
        union        = #pixels(pred != 0) + #pixels(gt != 0) - intersection

    For binary 0/1 masks this is the ordinary foreground IoU. Unlike a
    product of class ids, it stays within [0, 1] for multi-class masks.

    Args:
        model: Object providing ``eval()``, ``generate_caption(image, vocab)``
            and ``__call__(images, mode='segment')``.
        test_loader: Iterable of ``(images, captions, seg_masks)`` batches.
        vocab: Vocabulary forwarded to ``model.generate_caption``.
        idx2word: Mapping from token id to word.
        device: Device the images and masks are moved to.

    Returns:
        tuple: ``(avg_bleu, iou)``; each is 0 when nothing could be scored.
    """
    model.eval()

    # For caption evaluation (simplified BLEU)
    from nltk.translate.bleu_score import sentence_bleu
    bleu_scores = []

    # For segmentation evaluation
    intersection = 0
    union = 0

    with torch.no_grad():
        for images, captions, seg_masks in test_loader:
            images = images.to(device)
            seg_masks = seg_masks.to(device)

            # Caption evaluation, one image at a time
            for i in range(images.size(0)):
                generated_words = _ids_to_words(
                    model.generate_caption(images[i], vocab), idx2word
                )
                gt_words = _ids_to_words(captions[i].cpu().tolist(), idx2word)

                if gt_words:
                    bleu_scores.append(sentence_bleu([gt_words], generated_words))

            # Segmentation evaluation
            outputs = model(images, mode='segment')
            pred_masks = outputs['segmentation'].argmax(dim=1)

            # Count label agreement instead of multiplying class ids, which is
            # only meaningful for 0/1 masks.
            gt_foreground = seg_masks > 0
            batch_intersection = ((pred_masks == seg_masks) & gt_foreground).sum().item()
            batch_union = (
                (pred_masks > 0).sum().item()
                + gt_foreground.sum().item()
                - batch_intersection
            )

            intersection += batch_intersection
            union += batch_union

    # Calculate final metrics
    avg_bleu = np.mean(bleu_scores) if bleu_scores else 0
    iou = intersection / union if union > 0 else 0

    print(f"Evaluation Results:")
    print(f"Average BLEU Score: {avg_bleu:.4f}")
    print(f"IoU Score: {iou:.4f}")

    return avg_bleu, iou

# Step 5: Sample Generation Function


In [None]:
def generate_sample_results(model, data_loader, vocab, idx2word, device, num_samples=3):
    """Print a generated caption and predicted mask shape for a few batches.

    Puts ``model`` in eval mode and, for the first image of each of the first
    ``num_samples`` batches from ``data_loader``, prints the caption produced
    by ``model.generate_caption`` (ids missing from ``idx2word`` shown as
    ``<unk>``, special tokens removed) and the shape of the argmax
    segmentation prediction. Returns nothing.
    """
    hidden_tokens = ('<start>', '<end>', '<pad>')
    model.eval()

    with torch.no_grad():
        for sample_no, (batch_images, _captions, _masks) in enumerate(data_loader, start=1):
            if sample_no > num_samples:
                break

            # Only the first image of each batch is shown.
            first_image = batch_images[0].to(device)

            # Caption: decode ids, then strip the special tokens.
            token_ids = model.generate_caption(first_image, vocab)
            decoded = (idx2word.get(token, '<unk>') for token in token_ids)
            generated_caption = ' '.join(
                word for word in decoded if word not in hidden_tokens
            )

            # Segmentation: add a batch dim for the forward pass, then take
            # the argmax over the leading dim of the un-batched output.
            seg_logits = model(first_image.unsqueeze(0), mode='segment')['segmentation']
            seg_pred = seg_logits.squeeze(0).argmax(0).cpu()

            print(f"\nSample {sample_no}:")
            print(f"Generated Caption: {generated_caption}")
            print(f"Segmentation shape: {seg_pred.shape}")


# Step 6: Main Execution Function


In [6]:
def main():
    """Run the whole pipeline: setup, data loading, training, then evaluation.

    Returns:
        tuple: ``(model, train_losses, val_losses)``, the model and the loss
        histories returned by ``train_model_with_logging``.
    """
    print("Starting complete training pipeline...")

    # Pull the pre-built components out of the dict returned by setup_training().
    components = setup_training()  # Your existing function
    vocab, idx2word = components['vocab'], components['idx2word']
    model, device = components['integrated_model'], components['device']

    # Data loaders
    print("Creating data loaders...")
    loaders = create_data_loaders(vocab, batch_size=16)
    train_loader, val_loader = loaders

    # Training
    print("Starting training...")
    history = train_model_with_logging(
        model,
        train_loader,
        val_loader,
        vocab,
        idx2word,
        num_epochs=5,
        device=device,
    )
    train_losses, val_losses = history

    # Evaluation on the validation loader; calculate_metrics prints the
    # scores itself, so the returned values are not used further here.
    print("Evaluating model...")
    bleu, iou = calculate_metrics(model, val_loader, vocab, idx2word, device)

    print("\nTraining completed successfully!")
    return model, train_losses, val_losses


    """Instructions for downloading COCO dataset"""
    instructions = """
    To use real COCO dataset, follow these steps:
    
    1. Create a directory structure:
       coco/
       ├── train2017/
       ├── val2017/
       └── annotations/
           ├── instances_train2017.json
           ├── instances_val2017.json
           ├── captions_train2017.json
           └── captions_val2017.json
    
    2. Download from: https://cocodataset.org/#download
       - 2017 Train images [118K/18GB]
       - 2017 Val images [5K/1GB]
       - 2017 Train/Val annotations [241MB]
    
    3. Install pycocotools:
       pip install pycocotools
    
    4. Update the dataset paths in create_data_loaders()
    """