In [None]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from dataset import FlickrDataset, MyCollate
from vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
import os 
import pickle

In [None]:
def train_one_epoch(encoder, decoder, loader, criterion, optimizer, device, pad_idx):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        encoder: Image encoder, called as ``encoder(imgs)``.
        decoder: Caption decoder, called as ``decoder(features, captions)``.
            Its output is reshaped with ``shape[2]`` as the class dimension and
            flattened against ``captions[:, 1:]``, so it presumably returns
            (batch, caption_len - 1, vocab) logits — TODO confirm in model.py.
        loader: Iterable of ``(imgs, captions)`` batches that supports ``len()``.
        criterion: Loss called with flattened logits and flattened target ids.
        optimizer: Optimizer stepping the encoder/decoder parameters.
        device: Device every batch is moved to.
        pad_idx: Not referenced in this body; padding is presumably masked by
            ``criterion`` (e.g. ``ignore_index``).

    Returns:
        The sum of per-batch loss values divided by ``len(loader)``.
    """
    for module in (encoder, decoder):
        module.train()

    running_total = 0
    for batch_imgs, batch_caps in loader:
        batch_imgs = batch_imgs.to(device)
        batch_caps = batch_caps.to(device)

        logits = decoder(encoder(batch_imgs), batch_caps)
        vocab_dim = logits.shape[2]
        # Targets are the captions without their first token, flattened.
        targets = batch_caps[:, 1:].reshape(-1)
        loss = criterion(logits.reshape(-1, vocab_dim), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_total += loss.item()

    return running_total / len(loader)

In [None]:
def validate(encoder, decoder, loader, criterion, device, pad_idx):
    """Evaluate the captioning model on ``loader`` and return the mean batch loss.

    Both modules are switched to eval mode and no gradients are recorded.

    Args:
        encoder: Image encoder, called as ``encoder(imgs)``.
        decoder: Caption decoder, called as ``decoder(features, captions)``;
            its output is flattened against ``captions[:, 1:]``, so it presumably
            returns (batch, caption_len - 1, vocab) logits — TODO confirm.
        loader: Iterable of ``(imgs, captions)`` batches that supports ``len()``.
        criterion: Loss called with flattened logits and flattened target ids.
        device: Device every batch is moved to.
        pad_idx: Not referenced in this body; padding is presumably masked by
            ``criterion`` (e.g. ``ignore_index``).

    Returns:
        The sum of per-batch loss values divided by ``len(loader)``.
    """
    encoder.eval()
    decoder.eval()
    val_loss = 0.0

    with torch.no_grad():
        for imgs, captions in loader:
            imgs, captions = imgs.to(device), captions.to(device)

            # Same forward pass as train_one_epoch: the encoder only sees images.
            features = encoder(imgs)
            outputs = decoder(features, captions)

            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions[:, 1:].reshape(-1))
            val_loss += loss.item()

    return val_loss / len(loader)


In [None]:
def debug_dataset_and_files(images_path, captions_file):
    """Debug function to check dataset integrity.

    Prints a report covering:

    * the number of image files in ``images_path``;
    * the number of caption entries;
    * how many caption entries name an image that exists;
    * the first few caption lines verbatim.

    Args:
        images_path: Directory expected to contain the image files.
        captions_file: Text file with one ``image<sep>caption`` entry per line.
            ``<sep>`` is a comma or a tab, and the line is split on whichever
            appears first. A first line starting with ``image,caption`` is
            treated as a CSV header and skipped. Blank lines are ignored.

    Returns:
        True if at least one caption entry references an existing image.
        False if either path is missing, the captions file cannot be read, or
        no entry matches an image.
    """
    print("🔍 Debugging dataset...")

    # Check if paths exist
    if not os.path.exists(images_path):
        print(f"❌ Images directory does not exist: {images_path}")
        return False

    if not os.path.exists(captions_file):
        print(f"❌ Captions file does not exist: {captions_file}")
        return False

    # Count images in directory. The list keeps listdir order for the sample
    # printout; the set makes the per-caption lookup below O(1).
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
    image_files = [name for name in os.listdir(images_path)
                   if name.lower().endswith(image_extensions)]
    image_file_set = set(image_files)

    print(f"📁 Found {len(image_files)} image files in {images_path}")
    if image_files:
        print(f"   Sample files: {image_files[:3]}")

    # Check captions file format
    captions_count = 0
    valid_captions = 0
    sample_lines = []
    max_samples = 5  # number of data lines shown verbatim / checked loudly

    try:
        with open(captions_file, "r", encoding='utf-8') as f:
            for i, line in enumerate(f):
                stripped = line.strip()
                if i == 0 and stripped.lower().startswith('image,caption'):
                    print("   📄 Detected CSV format with header")
                    continue  # Skip header
                if not stripped:
                    continue  # Blank lines are not caption entries

                captions_count += 1
                # Counted on data lines only, so the window is the same size
                # with or without a header.
                is_early_line = captions_count <= max_samples
                if is_early_line:
                    sample_lines.append(stripped)

                # Split on whichever delimiter appears first. A comma inside a
                # tab-separated caption, or a tab inside a CSV caption, then
                # stays part of the caption instead of corrupting the image name.
                cut_points = [pos for pos in (stripped.find(','), stripped.find('\t'))
                              if pos != -1]
                if not cut_points:
                    continue
                image_name = stripped[:min(cut_points)]

                # Check if corresponding image exists
                if image_name in image_file_set:
                    valid_captions += 1
                elif is_early_line:  # Show mismatches for first few
                    print(f"   ⚠️  Image not found for: {image_name}")

        print(f"📝 Found {captions_count} caption entries")
        print(f"✅ Valid image-caption pairs: {valid_captions}")
        print("📋 Sample caption file format:")
        for sample in sample_lines:
            print(f"   {sample}")

    except Exception as e:
        # Best-effort diagnostic: report the failure instead of raising.
        print(f"❌ Error reading captions file: {e}")
        return False

    return valid_captions > 0

In [None]:
def _read_caption_texts(captions_file):
    """Return the lower-cased caption text of every entry in ``captions_file``.

    Each line is expected to be ``image<sep>caption`` with ``<sep>`` a comma or
    a tab. The line is split on whichever appears first, so commas inside a
    tab-separated caption are kept. A first line starting with
    ``image,caption`` is skipped as a CSV header. Lines with neither delimiter
    are ignored.
    """
    texts = []
    with open(captions_file, "r", encoding='utf-8') as f:
        for i, line in enumerate(f):
            stripped = line.strip()
            # Skip header if present
            if i == 0 and stripped.lower().startswith('image,caption'):
                continue
            cut_points = [pos for pos in (stripped.find(','), stripped.find('\t'))
                          if pos != -1]
            if cut_points:
                texts.append(stripped[min(cut_points) + 1:].lower())
    return texts


def _train_val_sizes(num_samples):
    """Return ``(train_size, val_size)`` for a 90/10 split, or 80/20 below 10 samples.

    The validation split always gets at least one sample. ``num_samples`` must
    therefore be >= 2 for the training split to be non-empty.
    """
    val_fraction = 0.2 if num_samples < 10 else 0.1
    val_size = max(1, int(val_fraction * num_samples))
    return num_samples - val_size, val_size


def main(
    images_path="/home/sahil_duwal/Projects/ImageCap/flickr8k/images",
    captions_file="/home/sahil_duwal/Projects/ImageCap/flickr8k/captions.txt",
    save_dir="checkpoints",
    num_epochs=5,
    seed=42,
):
    """Train the EncoderCNN / DecoderRNN captioning model.

    Steps:

    1. Sanity-check the data files.
    2. Build the vocabulary and pickle it to ``save_dir/vocab.pkl``.
    3. Split the dataset into train/validation.
    4. Train for ``num_epochs`` epochs, saving the encoder/decoder weights
       whenever the validation loss improves.

    Args:
        images_path: Directory containing the images. The default is a
            machine-specific absolute path; pass your own elsewhere.
        captions_file: Caption file (``image,caption`` CSV or tab-separated).
        save_dir: Output directory for the vocabulary and best checkpoints.
        num_epochs: Number of training epochs.
        seed: Seed for torch's global RNG (parameter init, DataLoader
            shuffling) and for the train/validation split. Pass ``None`` to
            leave both unseeded.
    """
    os.makedirs(save_dir, exist_ok=True)

    if seed is not None:
        torch.manual_seed(seed)

    # Debug dataset first
    if not debug_dataset_and_files(images_path, captions_file):
        print("❌ Dataset debugging failed. Please fix the issues above.")
        return

    # Transform
    # NOTE(review): there is no transforms.Normalize step. Confirm whether
    # EncoderCNN expects ImageNet-normalized input (e.g. a pretrained backbone).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Build vocab
    print("\n🏗️  Building vocabulary...")
    all_captions = _read_caption_texts(captions_file)
    print(f"📚 Collected {len(all_captions)} captions for vocabulary")

    if not all_captions:
        print("❌ No captions found! Check your captions file format.")
        return

    vocab = Vocabulary(freq_threshold=5)
    vocab.build_vocab(all_captions)
    print(f"📖 Built vocabulary with {len(vocab)} words")

    # Save vocab for inference
    with open(os.path.join(save_dir, "vocab.pkl"), "wb") as f:
        pickle.dump(vocab, f)

    # Create dataset
    print("\n📦 Creating dataset...")
    full_dataset = FlickrDataset(images_path, captions_file, vocab, transform=transform)
    num_samples = len(full_dataset)

    if num_samples == 0:
        print("❌ Dataset is empty! This means:")
        print("   1. No matching image-caption pairs found")
        print("   2. Check that image filenames in captions.txt match actual files")
        print("   3. Verify the captions file format (should be: filename,caption or filename\\tcaption)")
        return

    print(f"✅ Successfully loaded {num_samples} samples")

    # With a single sample, the mandatory validation sample would leave the
    # training split empty, and DataLoader would get batch_size=0.
    if num_samples < 2:
        print("❌ Need at least 2 samples so both train and validation splits are non-empty.")
        return

    # Dataset split
    if num_samples < 10:
        print("⚠️  Very small dataset! Using 80-20 split instead of 90-10")
    train_size, val_size = _train_val_sizes(num_samples)
    print(f"📊 Train size: {train_size}, Validation size: {val_size}")

    # NOTE(review): FlickrDataset may yield one item per caption — verify.
    # If it does, captions of the same image can land in both splits.
    split_generator = (torch.Generator().manual_seed(seed)
                       if seed is not None else torch.default_generator)
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], generator=split_generator
    )

    pad_idx = vocab.stoi["<PAD>"]

    # Create data loaders
    print("\n🔄 Creating data loaders...")
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=min(32, len(train_dataset)),  # Adjust batch size for small datasets
        shuffle=True,
        collate_fn=MyCollate(pad_idx=pad_idx)
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=min(32, len(val_dataset)),  # Adjust batch size for small datasets
        shuffle=False,
        collate_fn=MyCollate(pad_idx=pad_idx)
    )

    print(f"🔄 Train loader: {len(train_loader)} batches")
    print(f"🔄 Val loader: {len(val_loader)} batches")

    # Model setup
    print("\n🤖 Setting up model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️  Using device: {device}")

    encoder = EncoderCNN(embed_size=256).to(device)
    decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

    # Padding positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(list(decoder.parameters()) + list(encoder.parameters()), lr=3e-4)

    # Training loop
    print("\n🚀 Starting training...")
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        print(f"\n📈 Epoch [{epoch+1}/{num_epochs}]")

        train_loss = train_one_epoch(encoder, decoder, train_loader, criterion, optimizer, device, pad_idx)
        val_loss = validate(encoder, decoder, val_loader, criterion, device, pad_idx)

        print(f"   Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(encoder.state_dict(), os.path.join(save_dir, "best_encoder.pth"))
            torch.save(decoder.state_dict(), os.path.join(save_dir, "best_decoder.pth"))
            print("   ✅ Saved new best model!")

    print("\n🎉 Training finished!")

In [None]:
# Start training only when this module (or notebook kernel) is the top-level program.
_RUN_AS_SCRIPT = __name__ == "__main__"
if _RUN_AS_SCRIPT:
    main()