In [1]:
pip install numpy<2

zsh:1: no such file or directory: 2
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install --upgrade matplotlib

Note: you may need to restart the kernel to use updated packages.


In [3]:
import os
import re
import random
import numpy as np
import matplotlib.pyplot as plt
import io
import logging
import math
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from collections import defaultdict
import gradio as gr

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torchvision import models, transforms
from sklearn.model_selection import train_test_split

# Hugging Face imports
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Metrics imports
try:
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    from nltk.translate.meteor_score import meteor_score
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('wordnet', quiet=True)
    METRICS_AVAILABLE = True
except ImportError:
    print("NLTK not installed. BLEU and METEOR metrics will not be available.")
    METRICS_AVAILABLE = False

try:
    from skimage.metrics import peak_signal_noise_ratio as psnr
    from skimage.metrics import structural_similarity as ssim
    IMAGE_METRICS_AVAILABLE = True
except ImportError:
    print("scikit-image not installed. PSNR and SSIM metrics will not be available.")
    IMAGE_METRICS_AVAILABLE = False

# Configure 
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set seed 
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) if torch.cuda.is_available() else None
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything()

# Global tokenizer
tokenizer = None

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [28]:
DATASET_PATH = "/Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K"
print(f"Using dataset path: {DATASET_PATH}")

IMAGE_DIR = os.path.join(DATASET_PATH, "Images")
if not os.path.exists(IMAGE_DIR):
    IMAGE_DIR = os.path.join(DATASET_PATH, "Flickr8k_Dataset")
    if not os.path.exists(IMAGE_DIR):
        print(f"Warning: Could not find image directory in {DATASET_PATH}")

CAPTION_FILE = os.path.join(DATASET_PATH, "captions.txt")
if not os.path.exists(CAPTION_FILE):
    alt_caption_files = [
        os.path.join(DATASET_PATH, "Flickr8k.token.txt"),
        os.path.join(DATASET_PATH, "Flickr_8k.lemma.token.txt"),
        os.path.join(DATASET_PATH, "Flickr8k_text", "Flickr8k.token.txt")
    ]
    
    for alt_file in alt_caption_files:
        if os.path.exists(alt_file):
            CAPTION_FILE = alt_file
            print(f"Found caption file at: {CAPTION_FILE}")
            break
    else:
        print(f"Warning: Could not find caption file in {DATASET_PATH}")

WORK_DIR = os.path.join(DATASET_PATH, "output")
FEATURE_DIR = os.path.join(WORK_DIR, "features")
CHECKPOINT_DIR = os.path.join(WORK_DIR, "checkpoints")
os.makedirs(WORK_DIR, exist_ok=True)
os.makedirs(FEATURE_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Define image transforms for ResNet
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Define image transforms for diffusion model
diffusion_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # Normalize to [-1, 1]
])

Using dataset path: /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K


In [30]:
# Initialize tokenizer function
def initialize_tokenizer():
    global tokenizer
    print("Initializing GPT-2 tokenizer...")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")
    return tokenizer

# Load captions
def load_captions(caption_file):
    """Load captions from Flickr8k captions file"""
    captions_dict = defaultdict(list)
    
    if not os.path.exists(caption_file):
        print(f"Caption file not found: {caption_file}")
        return captions_dict
        
    try:
        with open(caption_file, 'r', encoding='utf-8') as f:
            first_line = f.readline().strip()
            f.seek(0)  # Go back to start of file
            
            if '#' in first_line and '\t' in first_line:  
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    parts = line.split('\t')
                    if len(parts) == 2:
                        img_name = parts[0].split('#')[0]  
                        caption = parts[1].strip()
                        captions_dict[img_name].append(caption)
            else:  
                next(f, None)  
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    parts = line.split(',', 1)
                    if len(parts) == 2:
                        img_name = parts[0].strip()
                        caption = parts[1].strip()
                        captions_dict[img_name].append(caption)
                        
        print(f"Loaded captions for {len(captions_dict)} images")
    except Exception as e:
        print(f"Error loading captions: {e}")
    
    return captions_dict

initialize_tokenizer()

# Sample captions
test_captions = load_captions(CAPTION_FILE)
sample_keys = list(test_captions.keys())[:3]
print(f"Sample of captions:")
for key in sample_keys:
    print(f"Image: {key}, Caption: {test_captions[key][0]}")

Initializing GPT-2 tokenizer...
Tokenizer vocabulary size: 50257
Loaded captions for 8091 images
Sample of captions:
Image: 1000268201_693b08cb0e.jpg, Caption: A child in a pink dress is climbing up a set of stairs in an entry way .
Image: 1001773457_577c3a7d70.jpg, Caption: A black dog and a spotted dog are fighting
Image: 1002674143_1b742ab4b8.jpg, Caption: A little girl covered in paint sits in front of a painted rainbow with her hands in a bowl .


In [32]:
class FlickrDataset(Dataset):
    def __init__(self, image_dir, captions_dict, transform=None):
        self.image_dir = image_dir
        self.captions_dict = captions_dict
        self.transform = transform
        self.image_files = list(captions_dict.keys())
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        
        # Check if image exists
        if not os.path.exists(img_path):
            for ext in ['.jpg', '.jpeg', '.png']:
                alt_path = os.path.join(self.image_dir, os.path.splitext(img_name)[0] + ext)
                if os.path.exists(alt_path):
                    img_path = alt_path
                    break
            else:
                return self.__getitem__((idx + 1) % len(self))
        
        # Load and transform image
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
            
        caption = random.choice(self.captions_dict[img_name])
        
        return image, caption, img_name

class FeatureCaptionDataset(Dataset):
    def __init__(self, feature_dir, captions_dict, tokenizer, max_len=50):
        self.feature_dir = feature_dir
        self.captions_dict = captions_dict
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_files = list(captions_dict.keys())
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        feature_path = os.path.join(self.feature_dir, img_name.replace('.jpg', '.npy'))
        
        if not os.path.exists(feature_path):
            base_name = os.path.splitext(img_name)[0]
            for ext in ['.npy']:
                alt_path = os.path.join(self.feature_dir, base_name + ext)
                if os.path.exists(alt_path):
                    feature_path = alt_path
                    break
            else:
                return self.__getitem__((idx + 1) % len(self))
        
        # Load features
        feature = torch.tensor(np.load(feature_path), dtype=torch.float32)
        
        # Get random caption
        caption = random.choice(self.captions_dict[img_name])
        
        # Tokenize caption
        tokens = self.tokenizer(
            caption,
            padding="max_length",
            max_length=self.max_len,
            truncation=True,
            return_tensors="pt"
        )
        
        return feature, tokens.input_ids.squeeze(0), tokens.attention_mask.squeeze(0), caption

print("Dataset classes created.")

Dataset classes created.


In [34]:
# Extract features using ResNet50
def extract_features(dataset, feature_dir, batch_size=32):
    """Extract features from images using ResNet50"""
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    # Create ResNet model
    resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    feature_extractor = nn.Sequential(*list(resnet.children())[:-1]).to(device)
    feature_extractor.eval()
    
    print(f"Extracting features for {len(dataset)} images...")
    for images, _, img_names in tqdm(dataloader):
        images = images.to(device)
        
        # Extract features
        with torch.no_grad():
            features = feature_extractor(images).squeeze(-1).squeeze(-1)
        
        # Save features
        for i, img_name in enumerate(img_names):
            feature_path = os.path.join(feature_dir, os.path.splitext(img_name)[0] + '.npy')
            np.save(feature_path, features[i].cpu().numpy())
    
    print("Feature extraction complete")
if len(test_captions) > 0:
    mini_dataset = FlickrDataset(IMAGE_DIR, {k: test_captions[k] for k in sample_keys}, transform=transform)

    need_extraction = False
    for key in sample_keys:
        feature_path = os.path.join(FEATURE_DIR, os.path.splitext(key)[0] + '.npy')
        if not os.path.exists(feature_path):
            need_extraction = True
            break
    
    if need_extraction and len(mini_dataset) > 0:
        extract_features(mini_dataset, FEATURE_DIR, batch_size=2)
        print("Sample features extracted")
    else:
        print("Features already exist for sample images")

Features already exist for sample images


In [36]:
class GPT2CaptionModel(nn.Module):
    def __init__(self, feature_dim=2048, embed_dim=768):
        super().__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
        self.project = nn.Linear(feature_dim, embed_dim)
        self.gpt2.resize_token_embeddings(len(tokenizer))
    
    def forward(self, features, input_ids, attention_mask):
        prefix = self.project(features).unsqueeze(1)
        token_embed = self.gpt2.transformer.wte(input_ids)
        embed = torch.cat([prefix, token_embed], dim=1)
        extended_mask = torch.cat([
            torch.ones((input_ids.shape[0], 1), dtype=attention_mask.dtype, device=features.device),
            attention_mask
        ], dim=1)
        labels = F.pad(input_ids, (1, 0), value=-100)
        outputs = self.gpt2(
            inputs_embeds=embed,
            attention_mask=extended_mask,
            labels=labels
        )
        
        return outputs.loss, outputs.logits
    
    def generate_caption(self, feature_tensor, max_len=30, temperature=0.9, top_k=40):
        """Generate a caption for an image feature tensor with fixed decoding"""
        self.eval()
        
        feature_tensor = feature_tensor.to(next(self.parameters()).device)
        if feature_tensor.dim() == 1:
            feature_tensor = feature_tensor.unsqueeze(0)
            
        if feature_tensor.size(0) > 1:
            captions = []
            for i in range(feature_tensor.size(0)):
                single_caption = self.generate_caption(
                    feature_tensor[i].unsqueeze(0), 
                    max_len=max_len, 
                    temperature=temperature
                )
                captions.append(single_caption)
            return captions[0] if len(captions) == 1 else captions
        
        with torch.no_grad():
            prefix = self.project(feature_tensor).unsqueeze(1)
            
            # Start with a special token
            bos_token_id = tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else tokenizer.eos_token_id
            if bos_token_id is None:
                bos_token_id = 50256  
                
            input_ids = torch.tensor([[bos_token_id]]).to(feature_tensor.device)
            
            generated_text = ""
            
            for _ in range(max_len):
                token_embed = self.gpt2.transformer.wte(input_ids)
                if input_ids.size(1) == 1:
                    inputs_embeds = torch.cat([prefix, token_embed], dim=1)
                    attention_mask = torch.ones((1, 2), device=feature_tensor.device)
                else:
                    inputs_embeds = token_embed
                    attention_mask = torch.ones((1, input_ids.size(1)), device=feature_tensor.device)
                outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :] / temperature
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
                next_token_logits.fill_(-float('inf'))
                next_token_logits.scatter_(1, top_k_indices, top_k_logits)

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                
                input_ids = torch.cat([input_ids, next_token], dim=1)
                
                token_str = tokenizer.decode([next_token.item()], skip_special_tokens=True)
                generated_text += token_str
                
                if next_token.item() == tokenizer.eos_token_id or token_str.strip() == ".":
                    break
  
            if not generated_text.strip():
                return "A photo."
                
            return generated_text

mini_caption_model = GPT2CaptionModel().to(device)
print(f"Created caption model with {sum(p.numel() for p in mini_caption_model.parameters())} parameters")

# Load a sample feature to test caption generation
if os.path.exists(os.path.join(FEATURE_DIR, os.path.splitext(sample_keys[0])[0] + '.npy')):
    sample_feature = torch.tensor(
        np.load(os.path.join(FEATURE_DIR, os.path.splitext(sample_keys[0])[0] + '.npy')), 
        dtype=torch.float32
    ).to(device)
    
    # Generate a sample caption
    sample_caption = mini_caption_model.generate_caption(sample_feature, temperature=0.8)
    print(f"Sample image: {sample_keys[0]}")
    print(f"Actual caption: {test_captions[sample_keys[0]][0]}")
    print(f"Generated caption: {sample_caption}")

Created caption model with 126013440 parameters
Sample image: 1000268201_693b08cb0e.jpg
Actual caption: A child in a pink dress is climbing up a set of stairs in an entry way .
Generated caption: In this section, we provide a simple way to generate a string that will fit in your current state.


In [38]:
# Training function for caption model
def train_caption_model(model, train_loader, val_loader=None, epochs=5, lr=1e-4):
    """Train the caption model"""
    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    best_val_loss = float('inf')
    
    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        
        for batch_idx, (feature, input_ids, attention_mask, _) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            feature = feature.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass
            loss, _ = model(feature, input_ids, attention_mask)
            
            # Backward pass
            loss.backward()
            
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # Update weights
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            # Print progress
            if batch_idx % 50 == 0:
                print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
                
                # Generate a sample caption to check progress
                if batch_idx % 200 == 0:
                    with torch.no_grad():
                        sample_caption = model.generate_caption(feature[0].unsqueeze(0))
                        print(f"Sample caption: {sample_caption}")
        
        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")
        
        # Validation
        if val_loader is not None:
            model.eval()
            val_loss = 0
            val_batches = 0
            
            with torch.no_grad():
                for feature, input_ids, attention_mask, _ in tqdm(val_loader, desc="Validation"):
                    feature = feature.to(device)
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)
                    
                    loss, _ = model(feature, input_ids, attention_mask)
                    val_loss += loss.item()
                    val_batches += 1
            
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
            print(f"Validation Loss: {avg_val_loss:.4f}")
            
            print("Sample validation captions:")
            for i in range(min(3, len(feature))):
                with torch.no_grad():
                    caption = model.generate_caption(feature[i].unsqueeze(0))
                    print(f"Generated: {caption}")
            
            # Update learning rate
            scheduler.step(avg_val_loss)
            
            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = os.path.join(CHECKPOINT_DIR, "caption_model_best.pt")
                torch.save(model.state_dict(), best_model_path)
                print(f"New best model saved to {best_model_path}")
        
        # Save checkpoint
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"caption_model_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")
    
    # Save final model
    final_path = os.path.join(CHECKPOINT_DIR, "caption_model_final.pt")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved to {final_path}")
    
    if val_loader is not None and os.path.exists(os.path.join(CHECKPOINT_DIR, "caption_model_best.pt")):
        model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, "caption_model_best.pt")))
        print("Loaded best model based on validation loss")
    
    return model

def evaluate_caption_model(model, dataloader, num_samples=10):
    """Evaluate the caption model and print example outputs"""
    model.eval()
    
    samples = []
    metrics = {'bleu1': [], 'bleu4': []}
    
    if not METRICS_AVAILABLE:
        print("NLTK metrics not available for evaluation.")
    
    with torch.no_grad():
        for batch in dataloader:
            feature, _, _, true_caption = batch
            
            for i in range(feature.size(0)):
                single_feature = feature[i].unsqueeze(0).to(device)
                single_caption = true_caption[i]
                
                generated_caption = model.generate_caption(
                    single_feature, 
                    temperature=0.8,
                    top_k=40
                )
                
                # Print debugging info
                print(f"Feature shape: {single_feature.shape}")
                print(f"Generated caption: '{generated_caption}'")
                
                # Store sample
                samples.append({
                    'true': single_caption,
                    'generated': generated_caption
                })
                
                if METRICS_AVAILABLE and generated_caption.strip():
                    # BLEU score
                    reference = [single_caption.split()]
                    candidate = generated_caption.split()
                    
                    # Skip empty candidates
                    if not candidate:
                        continue
                    
                    # BLEU-1
                    bleu1 = sentence_bleu(reference, candidate, 
                                         weights=(1, 0, 0, 0),
                                         smoothing_function=SmoothingFunction().method1)
                    metrics['bleu1'].append(bleu1)
                    
                    # BLEU-4
                    bleu4 = sentence_bleu(reference, candidate,
                                         weights=(0.25, 0.25, 0.25, 0.25),
                                         smoothing_function=SmoothingFunction().method1)
                    metrics['bleu4'].append(bleu4)
                
                if len(samples) >= num_samples:
                    break
            
            if len(samples) >= num_samples:
                break
    
    # Print results
    print("\nCaption Generation Examples:")
    for i, sample in enumerate(samples):
        print(f"Example {i+1}:")
        print(f"True: {sample['true']}")
        print(f"Generated: {sample['generated']}")
        print()
    
    # Print metrics if available
    if METRICS_AVAILABLE and len(metrics['bleu1']) > 0:
        print("Metrics:")
        print(f"BLEU-1: {sum(metrics['bleu1']) / len(metrics['bleu1']):.4f}")
        print(f"BLEU-4: {sum(metrics['bleu4']) / len(metrics['bleu4']):.4f}")
    else:
        print("No metrics available or all generated captions were empty")
    
    return samples, metrics

print("Improved training and evaluation functions defined")

Improved training and evaluation functions defined


In [40]:
# Visualization functions
def visualize_caption_results(samples):
    """Create a visualization grid of caption results"""
    plt.figure(figsize=(12, len(samples) * 2))
    
    for i, sample in enumerate(samples):
        plt.subplot(len(samples), 1, i + 1)
        plt.text(0.5, 0.5, f"True: {sample['true']}\nGenerated: {sample['generated']}", 
                 ha='center', va='center', fontsize=12)
        plt.axis('off')
    
    plt.tight_layout()
    return plt

def visualize_denoising_results(samples):
    """Create a visualization grid of denoising results"""
    plt.figure(figsize=(12, len(samples) * 4))
    
    for i, sample in enumerate(samples):
        original = sample['original'].permute(1, 2, 0).numpy()
        original = (original + 1) / 2  # Convert from [-1,1] to [0,1]
        
        noisy = sample['noisy'].permute(1, 2, 0).numpy()
        noisy = (noisy + 1) / 2  # Convert from [-1,1] to [0,1]
        
        denoised = sample['denoised'].permute(1, 2, 0).numpy()
        denoised = (denoised + 1) / 2  # Convert from [-1,1] to [0,1]
        
        # Original image
        plt.subplot(len(samples), 3, i * 3 + 1)
        plt.imshow(np.clip(original, 0, 1))
        plt.title('Original')
        plt.axis('off')
        
        # Noisy image
        plt.subplot(len(samples), 3, i * 3 + 2)
        plt.imshow(np.clip(noisy, 0, 1))
        plt.title('Noisy')
        plt.axis('off')
        
        # Denoised image
        plt.subplot(len(samples), 3, i * 3 + 3)
        plt.imshow(np.clip(denoised, 0, 1))
        plt.title('Denoised')
        plt.axis('off')
    
    plt.tight_layout()
    return plt

def tensor_to_pil(tensor):
    """Convert a tensor to PIL image for display"""
    tensor = tensor.cpu().clone()
    tensor = tensor.squeeze(0)
    
    if tensor.min() < 0:
        tensor = (tensor + 1) / 2
    
    # Convert to PIL
    tensor = tensor.clamp(0, 1)
    tensor = tensor.permute(1, 2, 0).numpy() * 255
    return Image.fromarray(tensor.astype('uint8'))

In [42]:
def main():
    print("Starting image captioning and denoising pipeline...")
    
    print("Loading captions...")
    captions_dict = load_captions(CAPTION_FILE)
    if len(captions_dict) == 0:
        print("No captions found. Exiting.")
        return
    

    print("Creating dataset for feature extraction...")
    dataset = FlickrDataset(IMAGE_DIR, captions_dict, transform=transform)
    
    if len(os.listdir(FEATURE_DIR)) < len(dataset) * 0.8:
        print("Extracting features...")
        extract_features(dataset, FEATURE_DIR)
    else:
        print("Features already exist, skipping extraction.")
    
    # Initialize GPT-2 tokenizer (globally)
    print("Initializing tokenizer...")
    initialize_tokenizer()
    
    print("Creating feature-caption dataset...")

    print("Creating feature-caption dataset...")
    feature_dataset = FeatureCaptionDataset(FEATURE_DIR, captions_dict, tokenizer)
    
    # Split data
    train_indices, val_indices = train_test_split(
        range(len(feature_dataset)), test_size=0.1, random_state=42
    )
    
    # Create DataLoader
    train_loader = DataLoader(
        [feature_dataset[i] for i in train_indices],
        batch_size=32,
        shuffle=True,
        num_workers=0 
    )
    
    val_loader = DataLoader(
        [feature_dataset[i] for i in val_indices],
        batch_size=16,  
        shuffle=False,
        num_workers=0
    )
    
    print("Initializing caption model...")
    caption_model = GPT2CaptionModel().to(device)
   
    pretrained_path = os.path.join(CHECKPOINT_DIR, "caption_model_final.pt")
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained model from {pretrained_path}")
        caption_model.load_state_dict(torch.load(pretrained_path))
    else:
        print("Training caption model...")
        caption_model = train_caption_model(
            caption_model,
            train_loader,
            val_loader,
            epochs=3,
            lr=5e-5
        )
    
    print("Evaluating caption model...")
    caption_samples, caption_metrics = evaluate_caption_model(
        caption_model,
        val_loader,
        num_samples=5
    )
    
    print("Creating dataset for denoising model...")
    image_dataset = FlickrDataset(IMAGE_DIR, captions_dict, transform=diffusion_transform)
    
    # Split data
    train_indices, val_indices = train_test_split(
        range(len(image_dataset)), test_size=0.1, random_state=42
    )
    
    denoising_train_loader = DataLoader(
        [image_dataset[i] for i in train_indices],
        batch_size=16,
        shuffle=True,
        num_workers=0  
    )
    
    denoising_val_loader = DataLoader(
        [image_dataset[i] for i in val_indices],
        batch_size=16,
        shuffle=False,
        num_workers=0
    )
    
    # Train denoiser model
    print("Initializing denoiser model...")
    denoiser_model = SimpleDenoiser().to(device)
    
    # Check if pretrained model exists
    pretrained_denoiser_path = os.path.join(CHECKPOINT_DIR, "denoiser_model_final.pt")
    if os.path.exists(pretrained_denoiser_path):
        print(f"Loading pretrained denoiser model from {pretrained_denoiser_path}")
        denoiser_model.load_state_dict(torch.load(pretrained_denoiser_path))
    else:
        print("Training denoiser model...")
        denoiser_model = train_denoiser(
            denoiser_model,
            denoising_train_loader,
            denoising_val_loader,
            epochs=3,
            lr=1e-4
        )
    
    # Evaluate denoiser model
    print("Evaluating denoiser model...")
    denoising_samples, denoising_metrics = evaluate_denoiser(
        denoiser_model,
        denoising_val_loader,
        num_samples=5
    )
    
    print("Pipeline complete!")
    
    # Return models and evaluation results
    return {
        'caption_model': caption_model,
        'denoiser_model': denoiser_model,
        'caption_samples': caption_samples,
        'caption_metrics': caption_metrics,
        'denoising_samples': denoising_samples,
        'denoising_metrics': denoising_metrics
    }

In [44]:
# SimpleDenoiser model
class SimpleDenoiser(nn.Module):
    def __init__(self, in_channels=3, hidden_channels=64):
        super().__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        
        # Middle
        self.middle = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels*2, hidden_channels*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_channels*2, hidden_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, kernel_size=3, padding=1),
            nn.Tanh()  # Output in [-1, 1] range
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x

def add_noise_to_image(image_tensor, noise_level=0.1):
    """Add random noise to image tensor"""
    noise = torch.randn_like(image_tensor) * noise_level
    noisy_image = image_tensor + noise
    return torch.clamp(noisy_image, -1, 1)

def train_denoiser(model, train_loader, val_loader=None, epochs=5, lr=1e-4):
    """Train the denoiser model"""
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    best_val_loss = float('inf')
    
    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        
        for batch_idx, (images, _, _) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            images = images.to(device)
            
            # Add noise to images
            noisy_images = add_noise_to_image(images, noise_level=0.1)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass
            denoised_images = model(noisy_images)
            
            # Calculate loss
            loss = criterion(denoised_images, images)
            
            # Backward pass
            loss.backward()
            
            # Update weights
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
     
            if batch_idx % 50 == 0:
                print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
        
        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")
        
        # Validation
        if val_loader is not None:
            model.eval()
            val_loss = 0
            val_batches = 0
            
            with torch.no_grad():
                for images, _, _ in tqdm(val_loader, desc="Validation"):
                    images = images.to(device)
                    
                    noisy_images = add_noise_to_image(images, noise_level=0.1)
                    
                    # Forward pass
                    denoised_images = model(noisy_images)
                    
                    # Calculate loss
                    loss = criterion(denoised_images, images)
                    val_loss += loss.item()
                    val_batches += 1
            
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
            print(f"Validation Loss: {avg_val_loss:.4f}")
            
            # Update learning rate
            scheduler.step(avg_val_loss)
            
            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = os.path.join(CHECKPOINT_DIR, "denoiser_model_best.pt")
                torch.save(model.state_dict(), best_model_path)
                print(f"New best model saved to {best_model_path}")
        
        # Save checkpoint
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"denoiser_model_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")
    
    # Save final model
    final_path = os.path.join(CHECKPOINT_DIR, "denoiser_model_final.pt")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved to {final_path}")
    
    # Load best model if available
    if val_loader is not None and os.path.exists(os.path.join(CHECKPOINT_DIR, "denoiser_model_best.pt")):
        model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, "denoiser_model_best.pt")))
        print("Loaded best model based on validation loss")
    
    return model

# Evaluation function for denoiser model
def evaluate_denoiser(model, dataloader, num_samples=5):
    """Evaluate the denoiser model and display example outputs"""
    model.eval()
    
    samples = []
    metrics = {'psnr': [], 'ssim': []}
    
    if not IMAGE_METRICS_AVAILABLE:
        print("Image metrics not available for evaluation.")
    
    with torch.no_grad():
        for images, _, _ in dataloader:
            images = images.to(device)
            
            # Add noise to images
            noisy_images = add_noise_to_image(images, noise_level=0.1)
            
            # Denoise
            denoised_images = model(noisy_images)
            
            # Calculate metrics if available
            if IMAGE_METRICS_AVAILABLE:
                for i in range(images.size(0)):
                    # Convert tensors to numpy arrays for metric calculation
                    clean_np = images[i].cpu().permute(1, 2, 0).numpy()
                    clean_np = (clean_np + 1) / 2  # Convert from [-1,1] to [0,1]
                    
                    denoised_np = denoised_images[i].cpu().permute(1, 2, 0).numpy()
                    denoised_np = (denoised_np + 1) / 2  # Convert from [-1,1] to [0,1]
                    
                    # Calculate PSNR
                    psnr_val = psnr(clean_np, denoised_np, data_range=1.0)
                    metrics['psnr'].append(psnr_val)
                    
                    # Calculate SSIM
                    ssim_val = ssim(clean_np, denoised_np, data_range=1.0, multichannel=True)
                    metrics['ssim'].append(ssim_val)
            
            # Store sample images
            for i in range(min(images.size(0), num_samples - len(samples))):
                samples.append({
                    'original': images[i].cpu(),
                    'noisy': noisy_images[i].cpu(),
                    'denoised': denoised_images[i].cpu()
                })
            
            if len(samples) >= num_samples:
                break
    
    # Print metrics 
    if IMAGE_METRICS_AVAILABLE and len(metrics['psnr']) > 0:
        print("Image Denoising Metrics:")
        print(f"PSNR: {sum(metrics['psnr']) / len(metrics['psnr']):.4f} dB")
        print(f"SSIM: {sum(metrics['ssim']) / len(metrics['ssim']):.4f}")
    
    return samples, metrics

In [46]:
# Setup Gradio interface
def setup_gradio_interface(caption_model, denoiser_model):
    def process_image(input_image):
        input_image_pil = Image.fromarray(input_image).convert('RGB')
        
        input_tensor = transform(input_image_pil).unsqueeze(0).to(device)
        
        # Extract features
        with torch.no_grad():
            resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            feature_extractor = nn.Sequential(*list(resnet.children())[:-1]).to(device)
            feature_extractor.eval()
            features = feature_extractor(input_tensor).squeeze(-1).squeeze(-1)
        
        # Generate caption
        caption = caption_model.generate_caption(features, temperature=0.7)
        
        # Process for denoising
        noisy_tensor = diffusion_transform(input_image_pil).unsqueeze(0).to(device)
        noisy_tensor = add_noise_to_image(noisy_tensor, noise_level=0.2)
        
        # Denoise
        with torch.no_grad():
            denoised_tensor = denoiser_model(noisy_tensor)
        
        # Convert tensors to images
        noisy_image = tensor_to_pil(noisy_tensor)
        denoised_image = tensor_to_pil(denoised_tensor)
        
        return caption, noisy_image, denoised_image
    
    # Create Gradio interface
    interface = gr.Interface(
        fn=process_image,
        inputs=gr.Image(),
        outputs=[
            gr.Textbox(label="Generated Caption"),
            gr.Image(label="Noisy Image"),
            gr.Image(label="Denoised Image")
        ],
        title="Image Captioning and Denoising Demo",
        description="Upload an image to generate a caption and see denoising in action."
    )
    
    return interface

In [48]:
if __name__ == "__main__":
    USE_SMALL_DATASET = True
    
    if USE_SMALL_DATASET:
        print("Testing with a small dataset...")
        test_captions = load_captions(CAPTION_FILE)
        sample_keys = list(test_captions.keys())[:100]  
        small_captions = {k: test_captions[k] for k in sample_keys}
        
        dataset = FlickrDataset(IMAGE_DIR, small_captions, transform=transform)
        
        for key in sample_keys:
            feature_path = os.path.join(FEATURE_DIR, os.path.splitext(key)[0] + '.npy')
            if not os.path.exists(feature_path):
                extract_features(dataset, FEATURE_DIR, batch_size=8)
                break
        
        # Initialize tokenizer
        initialize_tokenizer()
        
        # Create dataset with features and captions
        feature_dataset = FeatureCaptionDataset(FEATURE_DIR, small_captions, tokenizer)
        
        # Split data
        train_indices, val_indices = train_test_split(
            range(len(feature_dataset)), test_size=0.2, random_state=42
        )
        
        # Create DataLoader
        train_loader = DataLoader(
            [feature_dataset[i] for i in train_indices],
            batch_size=8,
            shuffle=True
        )
        
        val_loader = DataLoader(
            [feature_dataset[i] for i in val_indices],
            batch_size=4
        )
        
        caption_model = GPT2CaptionModel().to(device)
        
        # Training will take time even with small dataset
        print("Training caption model (this may take some time even with small dataset)...")
        caption_model = train_caption_model(
            caption_model,
            train_loader,
            val_loader,
            epochs=5,  
            lr=5e-5
        )
        
        # Evaluate caption model
        print("Evaluating caption model...")
        caption_samples, caption_metrics = evaluate_caption_model(
            caption_model,
            val_loader,
            num_samples=5
        )
        
        # Display results
        if caption_samples:
            print("\nCaption Results:")
            for i, sample in enumerate(caption_samples):
                print(f"Example {i+1}:")
                print(f"True: {sample['true']}")
                print(f"Generated: {sample['generated']}")
                print()
        
        # Skip denoiser training for this quick test
        print("Skipping denoiser training for this quick test")
    else:
        results = main()
        if 'caption_model' in results and 'denoiser_model' in results:
            interface = setup_gradio_interface(results['caption_model'], results['denoiser_model'])
            interface.launch(share=True)
        else:
            print("Model training failed.")

Testing with a small dataset...
Loaded captions for 8091 images
Initializing GPT-2 tokenizer...
Tokenizer vocabulary size: 50257
Training caption model (this may take some time even with small dataset)...


Epoch 1/5:   0%|                                         | 0/10 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Epoch 1/5:  10%|███▎                             | 1/10 [00:01<00:10,  1.17s/it]

Batch 0/10, Loss: 7.7736
Sample caption: (Editor's note: This story was originally published on Nov.


Epoch 1/5: 100%|████████████████████████████████| 10/10 [00:06<00:00,  1.55it/s]


Epoch 1/5, Average Loss: 2.4130


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 10.80it/s]


Validation Loss: 0.9414
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_1.pt


Epoch 2/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.57it/s]

Batch 0/10, Loss: 0.9679
Sample caption: A photo.


Epoch 2/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.78it/s]


Epoch 2/5, Average Loss: 0.8113


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.43it/s]


Validation Loss: 0.8487
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_2.pt


Epoch 3/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.60it/s]

Batch 0/10, Loss: 0.6994
Sample caption: A photo.


Epoch 3/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.78it/s]


Epoch 3/5, Average Loss: 0.6358


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.92it/s]


Validation Loss: 0.8222
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_3.pt


Epoch 4/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.58it/s]

Batch 0/10, Loss: 0.6532
Sample caption: A photo.


Epoch 4/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 4/5, Average Loss: 0.4920


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 11.58it/s]


Validation Loss: 0.8840
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_4.pt


Epoch 5/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.58it/s]

Batch 0/10, Loss: 0.4393
Sample caption: A photo.


Epoch 5/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch 5/5, Average Loss: 0.3509


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.17it/s]


Validation Loss: 0.9505
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_5.pt
Final model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_final.pt
Loaded best model based on validation loss
Evaluating caption model...
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'

Caption Generation Examples:
Example 1:
True: two dogs running around
Generated: A photo.

Example 2:
True: A dog and a tennis ball .
Generated: A photo.

Example 3:
True: Two people wearing yellow jackets cross-country sk

In [50]:
if __name__ == "__main__":
    # For testing with a small dataset
    USE_SMALL_DATASET = True
    
    if USE_SMALL_DATASET:
        print("Testing with a small dataset...")
        test_captions = load_captions(CAPTION_FILE)
        sample_keys = list(test_captions.keys())[:100]  # Use only 100 images
        small_captions = {k: test_captions[k] for k in sample_keys}
        
        # Create dataset for feature extraction
        dataset = FlickrDataset(IMAGE_DIR, small_captions, transform=transform)
        
        # Extract features if needed
        for key in sample_keys:
            feature_path = os.path.join(FEATURE_DIR, os.path.splitext(key)[0] + '.npy')
            if not os.path.exists(feature_path):
                extract_features(dataset, FEATURE_DIR, batch_size=8)
                break
        
        # Initialize tokenizer
        initialize_tokenizer()
        
        # Create dataset with features and captions
        feature_dataset = FeatureCaptionDataset(FEATURE_DIR, small_captions, tokenizer)
        
        # Split data
        train_indices, val_indices = train_test_split(
            range(len(feature_dataset)), test_size=0.2, random_state=42
        )
        
        # Create DataLoader
        train_loader = DataLoader(
            [feature_dataset[i] for i in train_indices],
            batch_size=8,
            shuffle=True
        )
        
        val_loader = DataLoader(
            [feature_dataset[i] for i in val_indices],
            batch_size=4
        )
        
        # Train caption model (with limited epochs)
        caption_model = GPT2CaptionModel().to(device)
        

        print("Training caption model (this may take some time even with small dataset)...")
        caption_model = train_caption_model(
            caption_model,
            train_loader,
            val_loader,
            epochs=5,  
            lr=5e-5
        )
        
        # Evaluate caption model
        print("Evaluating caption model...")
        caption_samples, caption_metrics = evaluate_caption_model(
            caption_model,
            val_loader,
            num_samples=5
        )
        
        # Display results
        if caption_samples:
            print("\nCaption Results:")
            for i, sample in enumerate(caption_samples):
                print(f"Example {i+1}:")
                print(f"True: {sample['true']}")
                print(f"Generated: {sample['generated']}")
                print()
        
        print("Skipping denoiser training for this quick test")
    else:
        # Run full training
        results = main()
        
        # Check if models were created successfully
        if 'caption_model' in results and 'denoiser_model' in results:
            # Set up and launch Gradio interface
            interface = setup_gradio_interface(results['caption_model'], results['denoiser_model'])
            interface.launch(share=True)
        else:
            print("Model training failed.")

Testing with a small dataset...
Loaded captions for 8091 images
Initializing GPT-2 tokenizer...
Tokenizer vocabulary size: 50257
Training caption model (this may take some time even with small dataset)...


Epoch 1/5:  10%|███▎                             | 1/10 [00:00<00:07,  1.15it/s]

Batch 0/10, Loss: 7.3777
Sample caption: The following was written by Richard A.


Epoch 1/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.70it/s]


Epoch 1/5, Average Loss: 2.2931


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.44it/s]


Validation Loss: 0.8904
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_1.pt


Epoch 2/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.55it/s]

Batch 0/10, Loss: 0.9253
Sample caption: A photo.


Epoch 2/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch 2/5, Average Loss: 0.8196


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 11.97it/s]


Validation Loss: 0.8100
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_2.pt


Epoch 3/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.61it/s]

Batch 0/10, Loss: 0.7606
Sample caption: A photo.


Epoch 3/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch 3/5, Average Loss: 0.6370


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 13.02it/s]


Validation Loss: 0.7837
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_3.pt


Epoch 4/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.64it/s]

Batch 0/10, Loss: 0.6139
Sample caption: A photo.


Epoch 4/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 4/5, Average Loss: 0.4797


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 11.54it/s]


Validation Loss: 0.8464
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_4.pt


Epoch 5/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.62it/s]

Batch 0/10, Loss: 0.5601
Sample caption: A photo.


Epoch 5/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.78it/s]


Epoch 5/5, Average Loss: 0.3309


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.30it/s]


Validation Loss: 0.9566
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_5.pt
Final model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_final.pt
Loaded best model based on validation loss
Evaluating caption model...
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'

Caption Generation Examples:
Example 1:
True: two dogs running around
Generated: A photo.

Example 2:
True: A dog and a tennis ball .
Generated: A photo.

Example 3:
True: Two people wearing yellow jackets cross-country sk

In [51]:
def setup_gradio_interface(caption_model, denoiser_model):
    def process_image(input_image):
        input_image_pil = Image.fromarray(input_image).convert('RGB')
        
        input_tensor = transform(input_image_pil).unsqueeze(0).to(device)
        
        with torch.no_grad():
            resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            feature_extractor = nn.Sequential(*list(resnet.children())[:-1]).to(device)
            feature_extractor.eval()
            features = feature_extractor(input_tensor).squeeze(-1).squeeze(-1)
        
        # Generate caption
        caption = caption_model.generate_caption(features, temperature=0.7)
        
        # Process for denoising
        noisy_tensor = diffusion_transform(input_image_pil).unsqueeze(0).to(device)
        noisy_tensor = add_noise_to_image(noisy_tensor, noise_level=0.2)
        
        # Denoise
        with torch.no_grad():
            denoised_tensor = denoiser_model(noisy_tensor)
        
        # Convert tensors to images
        noisy_image = tensor_to_pil(noisy_tensor)
        denoised_image = tensor_to_pil(denoised_tensor)
        
        return caption, noisy_image, denoised_image
    
    # Create Gradio interface
    interface = gr.Interface(
        fn=process_image,
        inputs=gr.Image(),
        outputs=[
            gr.Textbox(label="Generated Caption"),
            gr.Image(label="Noisy Image"),
            gr.Image(label="Denoised Image")
        ],
        title="Image Captioning and Denoising Demo",
        description="Upload an image to generate a caption and see denoising in action."
    )
    
    return interface

if __name__ == "__main__":
    USE_SMALL_DATASET = True
    
    if USE_SMALL_DATASET:
        print("Testing with a small dataset...")
        test_captions = load_captions(CAPTION_FILE)
        sample_keys = list(test_captions.keys())[:100] 
        small_captions = {k: test_captions[k] for k in sample_keys}
        
        # Create dataset for feature extraction
        dataset = FlickrDataset(IMAGE_DIR, small_captions, transform=transform)
        
        # Extract features if needed
        for key in sample_keys:
            feature_path = os.path.join(FEATURE_DIR, os.path.splitext(key)[0] + '.npy')
            if not os.path.exists(feature_path):
                extract_features(dataset, FEATURE_DIR, batch_size=8)
                break
        
        # Initialize tokenizer
        initialize_tokenizer()
        
        # Create dataset with features and captions
        feature_dataset = FeatureCaptionDataset(FEATURE_DIR, small_captions, tokenizer)
        
        # Split data
        train_indices, val_indices = train_test_split(
            range(len(feature_dataset)), test_size=0.2, random_state=42
        )
        
        # Create DataLoader
        train_loader = DataLoader(
            [feature_dataset[i] for i in train_indices],
            batch_size=8,
            shuffle=True
        )
        
        val_loader = DataLoader(
            [feature_dataset[i] for i in val_indices],
            batch_size=4
        )
        
        # Train caption model (with limited epochs)
        caption_model = GPT2CaptionModel().to(device)
        
        # Training will take time even with small dataset
        print("Training caption model (this may take some time even with small dataset)...")
        caption_model = train_caption_model(
            caption_model,
            train_loader,
            val_loader,
            epochs=5,  
            lr=5e-5
        )
        
        # Evaluate caption model
        print("Evaluating caption model...")
        caption_samples, caption_metrics = evaluate_caption_model(
            caption_model,
            val_loader,
            num_samples=5
        )
        
        # Display results
        if caption_samples:
            print("\nCaption Results:")
            for i, sample in enumerate(caption_samples):
                print(f"Example {i+1}:")
                print(f"True: {sample['true']}")
                print(f"Generated: {sample['generated']}")
                print()
        
        # Skip denoiser training for this quick test
        print("Skipping denoiser training for this quick test")
    else:
        # Run full training
        results = main()
        
        # Check if models were created successfully
        if 'caption_model' in results and 'denoiser_model' in results:
            # Set up and launch Gradio interface
            interface = setup_gradio_interface(results['caption_model'], results['denoiser_model'])
            interface.launch(share=True)
        else:
            print("Model training failed.")

Testing with a small dataset...
Loaded captions for 8091 images
Initializing GPT-2 tokenizer...
Tokenizer vocabulary size: 50257
Training caption model (this may take some time even with small dataset)...


Epoch 1/5:   0%|                                         | 0/10 [00:00<?, ?it/s]

Batch 0/10, Loss: 7.4206


Epoch 1/5:  10%|███▎                             | 1/10 [00:01<00:11,  1.29s/it]

Sample caption: 
A couple weeks ago I wrote a blog post about the very high energy that comes with having a strong personal connection with an organisation.


Epoch 1/5: 100%|████████████████████████████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 1/5, Average Loss: 2.1546


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 10.26it/s]


Validation Loss: 0.9640
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_1.pt


Epoch 2/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.62it/s]

Batch 0/10, Loss: 1.0241
Sample caption: A photo.


Epoch 2/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.71it/s]


Epoch 2/5, Average Loss: 0.7573


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.04it/s]


Validation Loss: 0.8717
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_2.pt


Epoch 3/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.57it/s]

Batch 0/10, Loss: 0.8254
Sample caption: A photo.


Epoch 3/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 3/5, Average Loss: 0.5858


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.62it/s]


Validation Loss: 0.8715
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
New best model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_best.pt
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_3.pt


Epoch 4/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.58it/s]

Batch 0/10, Loss: 0.6071
Sample caption: A photo.


Epoch 4/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch 4/5, Average Loss: 0.4349


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 12.60it/s]


Validation Loss: 0.9420
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_4.pt


Epoch 5/5:  10%|███▎                             | 1/10 [00:00<00:05,  1.60it/s]

Batch 0/10, Loss: 0.5112
Sample caption: A photo.


Epoch 5/5: 100%|████████████████████████████████| 10/10 [00:05<00:00,  1.72it/s]


Epoch 5/5, Average Loss: 0.2921


Validation: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 13.10it/s]


Validation Loss: 1.1021
Sample validation captions:
Generated: A photo.
Generated: A photo.
Generated: A photo.
Checkpoint saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_epoch_5.pt
Final model saved to /Users/nehaeshwaragari/Documents/Deep Learning/Project 3/Flickr8K/output/checkpoints/caption_model_final.pt
Loaded best model based on validation loss
Evaluating caption model...
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'
Feature shape: torch.Size([1, 2048])
Generated caption: 'A photo.'

Caption Generation Examples:
Example 1:
True: A blond dog and a black and white dog run in a dirt field .
Generated: A photo.

Example 2:
True: A black and white dog catches a toy in midair .
Generated: A photo.

Example 3

In [52]:
def setup_gradio_interface(caption_model, denoiser_model):
    def process_image(input_image):
        input_image_pil = Image.fromarray(input_image).convert('RGB')
        
        input_tensor = transform(input_image_pil).unsqueeze(0).to(device)
        
        with torch.no_grad():
            resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            feature_extractor = nn.Sequential(*list(resnet.children())[:-1]).to(device)
            feature_extractor.eval()
            features = feature_extractor(input_tensor).squeeze(-1).squeeze(-1)
        
        # Generate caption
        caption = caption_model.generate_caption(features, temperature=0.7)
        
        # Process for denoising
        noisy_tensor = diffusion_transform(input_image_pil).unsqueeze(0).to(device)
        noisy_tensor = add_noise_to_image(noisy_tensor, noise_level=0.2)
        
        # Denoise
        with torch.no_grad():
            denoised_tensor = denoiser_model(noisy_tensor)
        
        # Convert tensors to images
        noisy_image = tensor_to_pil(noisy_tensor)
        denoised_image = tensor_to_pil(denoised_tensor)
        
        return caption, noisy_image, denoised_image
    
    # Create Gradio interface
    interface = gr.Interface(
        fn=process_image,
        inputs=gr.Image(),
        outputs=[
            gr.Textbox(label="Generated Caption"),
            gr.Image(label="Noisy Image"),
            gr.Image(label="Denoised Image")
        ],
        title="Image Captioning and Denoising Demo",
        description="Upload an image to generate a caption and see denoising in action."
    )
    
    return interface

In [53]:
if __name__ == "__main__":
    caption_model = GPT2CaptionModel().to(device)
    denoiser_model = SimpleDenoiser().to(device)
    
    # Launch the interface
    interface = setup_gradio_interface(caption_model, denoiser_model)
    interface.launch(share=True)

2025-04-20 15:52:38,649 - INFO - HTTP Request: GET http://127.0.0.1:7860/gradio_api/startup-events "HTTP/1.1 200 OK"
2025-04-20 15:52:38,658 - INFO - HTTP Request: HEAD http://127.0.0.1:7860/ "HTTP/1.1 200 OK"


* Running on local URL:  http://127.0.0.1:7860


2025-04-20 15:52:39,217 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-04-20 15:52:39,353 - INFO - HTTP Request: GET https://api.gradio.app/v3/tunnel-request "HTTP/1.1 200 OK"


* Running on public URL: https://e94784f02a6e0e894b.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


2025-04-20 15:52:40,532 - INFO - HTTP Request: HEAD https://e94784f02a6e0e894b.gradio.live "HTTP/1.1 200 OK"
