In [None]:
# Import required libraries
import os
import torch
import pickle
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# Import custom modules
from encoder import EncoderCNN
from decoder import DecoderRNN
from vocab import Vocabulary


In [None]:
# Select the compute device: CUDA when torch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Resolve the model hyperparameters.
# The combined checkpoint, when present, stores them alongside the weights.
model_path = './models/final_model.pth'
if not os.path.exists(model_path):
    # No combined checkpoint: fall back to fixed defaults. `vocab_size` and
    # `checkpoint` are intentionally NOT defined on this path; the vocabulary
    # cell derives the size from the vocabulary itself.
    embed_size = 256
    hidden_size = 512
    print("Using default parameters and loading individual model files...")
else:
    # map_location keeps tensors loadable on machines without the device
    # they were saved from.
    checkpoint = torch.load(model_path, map_location=device)
    embed_size = checkpoint['embed_size']
    hidden_size = checkpoint['hidden_size']
    vocab_size = checkpoint['vocab_size']

    print("Model parameters loaded:")
    print(f"Embed size: {embed_size}")
    print(f"Hidden size: {hidden_size}")
    print(f"Vocabulary size: {vocab_size}")


In [None]:
# Load vocabulary
vocab_file = "vocab.pkl"
# SECURITY: pickle.load can execute arbitrary code while deserializing.
# Only open a vocab.pkl that you created yourself or obtained from a
# trusted source.
with open(vocab_file, 'rb') as f:
    vocab = pickle.load(f)

# `vocab_size` exists only when the combined checkpoint was found by the
# configuration cell. Look it up in the module namespace explicitly:
# locals() only coincides with that namespace at the top level of a cell.
if 'vocab_size' not in globals():
    vocab_size = len(vocab)
elif vocab_size != len(vocab):
    # The decoder is sized from the checkpoint value, so a different
    # vocabulary file would map predicted indices to the wrong (or missing)
    # words. Surface the mismatch instead of failing later.
    print(f"Warning: checkpoint vocabulary size ({vocab_size}) does not match "
          f"{vocab_file} ({len(vocab)} entries)")

print(f"Vocabulary loaded with {vocab_size} words")


In [None]:
# Build the encoder/decoder pair from the hyperparameters resolved earlier
# and move both onto the selected device.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

encoder.to(device)
decoder.to(device)

# Restore the trained weights.
if not os.path.exists(model_path):
    # No combined checkpoint on disk: read the per-network weight files.
    encoder.load_state_dict(torch.load('./models/encoder-3.pkl', map_location=device))
    decoder.load_state_dict(torch.load('./models/decoder-3.pkl', map_location=device))
else:
    # `checkpoint` was loaded by the configuration cell from the same path.
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Inference only: switch both networks to evaluation mode.
encoder.eval()
decoder.eval()

print("Models loaded and set to evaluation mode!")


In [None]:
# Preprocessing applied to every image before it is fed to the encoder.
# Per-channel statistics used for normalization; these are the values the
# torchvision documentation lists for its ImageNet-pretrained models.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# NOTE(review): Resize with a single int scales the shorter side to 224 and
# keeps the aspect ratio, so non-square inputs stay non-square. Confirm the
# encoder accepts that (no crop is applied here).
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

def clean_sentence(output, idx2word, pad_idx=0, end_idx=1, punct_idx=18):
    """Convert a sequence of word indices into a readable sentence.

    Args:
        output: Iterable of word indices. Each element is passed through
            ``int()`` so that plain ints and single-element tensors both work
            as dictionary keys.
        idx2word: Mapping from integer index to word string.
        pad_idx: Index that is skipped entirely. Defaults to 0.
            NOTE(review): the original comment calls this ``<pad>``; confirm
            against the vocabulary definition.
        end_idx: Index that terminates the sentence. Defaults to 1.
        punct_idx: Index of the token that is attached to the previous word
            without a separating space. Defaults to 18.
            NOTE(review): presumably sentence-ending punctuation in this
            vocabulary; confirm.

    Returns:
        The words joined by single spaces, with leading/trailing whitespace
        removed. Empty string when no words precede the end token.

    Raises:
        KeyError: If an index (other than ``pad_idx``/``end_idx``) is not in
            ``idx2word``.
    """
    pieces = []
    for raw_idx in output:
        idx = int(raw_idx)
        if idx == pad_idx:
            continue
        if idx == end_idx:
            break
        word = idx2word[idx]
        # The punctuation token hugs the preceding word; everything else is
        # separated by one space.
        pieces.append(word if idx == punct_idx else " " + word)
    return "".join(pieces).strip()

def generate_caption(image_path):
    """Generate a caption for a single image.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (normally a path to
            an image file).

    Returns:
        Tuple ``(image, caption)`` where ``image`` is the RGB-converted PIL
        image and ``caption`` is the string produced by ``clean_sentence``.
    """
    # Open inside a context manager so the underlying file handle is released
    # as soon as the pixels have been copied; convert() returns an independent
    # image, so it stays valid after the source is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Add a leading batch dimension and move to the model's device.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        features = encoder(image_tensor).unsqueeze(1)
        sampled_ids = decoder.sample(features)

    caption = clean_sentence(sampled_ids, vocab.idx2word)
    return image, caption


In [None]:
# Output area that captures whatever the upload handler prints or plots
output_widget = widgets.Output()

# File picker: images only, one file at a time
upload_widget = widgets.FileUpload(accept='image/*', multiple=False, description='Upload Image')

def on_upload_change(change):
    """Caption the uploaded image and render the result in ``output_widget``.

    Registered as an observer on ``upload_widget``; the ``change`` payload is
    not inspected, the widget's current value is read directly.

    Args:
        change: Change notification passed by the widget framework (unused).
    """
    with output_widget:
        clear_output(wait=True)

        uploads = upload_widget.value
        if not uploads:
            return

        # ipywidgets 7.x exposes the value as {filename: file_record};
        # ipywidgets 8.x exposes it as a tuple of file records.
        if isinstance(uploads, dict):
            uploaded_file = next(iter(uploads.values()))
        else:
            uploaded_file = uploads[0]

        try:
            # Decode straight from memory (PIL accepts file-like objects)
            # instead of writing a fixed-name temp file that could overwrite,
            # and then delete, a user's file in the working directory.
            # bytes() normalizes the memoryview that 8.x provides.
            image_buffer = io.BytesIO(bytes(uploaded_file['content']))

            # Generate caption
            image, caption = generate_caption(image_buffer)

            # Display results
            plt.figure(figsize=(10, 8))
            plt.imshow(image)
            plt.axis('off')
            plt.title(f'Generated Caption: {caption}', fontsize=14, pad=20)
            plt.tight_layout()
            plt.show()

            print(f"\n📸 Caption: {caption}")

        except Exception as e:
            # Keep the widget usable: report the problem in the output area
            # rather than letting the callback fail.
            print(f"Error processing image: {e}")

# Call on_upload_change whenever the upload widget's `value` trait changes
# (i.e. when a file has been selected).
upload_widget.observe(on_upload_change, names='value')

# Render the upload button followed by the output area that the handler
# writes its figure and caption into.
display(upload_widget, output_widget)


In [None]:
# Test with a sample image if available: caption the first candidate file
# that exists and can be processed, then stop.
sample_images = ['test.jpg', 'sample.jpg', 'example.jpg']
found_sample = False

for img_path in sample_images:
    if not os.path.exists(img_path):
        continue

    found_sample = True
    print(f"Testing with {img_path}...")
    try:
        image, caption = generate_caption(img_path)

        plt.figure(figsize=(10, 6))
        plt.imshow(image)
        plt.axis('off')
        plt.title(f'Generated Caption: {caption}', fontsize=12, pad=15)
        plt.show()

        print(f"Caption: {caption}\n")
        break

    except Exception as e:
        # Report and move on to the next candidate.
        print(f"Error with {img_path}: {e}")
else:
    # Reached only when no candidate was captioned successfully. Distinguish
    # "nothing on disk" from "files exist but all of them failed".
    if found_sample:
        print("Sample images were found but none could be captioned; see the errors above.")
    else:
        print("No sample images found. Please upload an image using the widget above.")


In [None]:
# Process multiple images from a directory
def process_image_directory(image_dir, max_images=5):
    """Caption and display up to ``max_images`` images from a directory.

    Files are selected by extension (.jpg, .jpeg, .png, .bmp; case-insensitive)
    and processed in sorted filename order so repeated runs pick the same
    images. A failure on one image is reported and does not stop the others.

    Args:
        image_dir: Path of the directory to scan (not recursive).
        max_images: Upper bound on the number of images processed.

    Returns:
        None. Results are shown as matplotlib figures; status and errors are
        printed.
    """
    # isdir (not exists) so that a regular file path is rejected here instead
    # of raising from os.listdir.
    if not os.path.isdir(image_dir):
        print(f"Directory {image_dir} not found.")
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    image_files = sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(image_extensions)
    )

    if not image_files:
        print(f"No image files found in {image_dir}")
        return

    selected_files = image_files[:max_images]
    print(f"Processing {len(selected_files)} images...")

    for img_file in selected_files:
        img_path = os.path.join(image_dir, img_file)
        try:
            image, caption = generate_caption(img_path)

            plt.figure(figsize=(8, 6))
            plt.imshow(image)
            plt.axis('off')
            plt.title(f'{img_file}: {caption}', fontsize=10)
            plt.tight_layout()
            plt.show()

        except Exception as e:
            print(f"Error processing {img_file}: {e}")

# Example usage (uncomment and modify path as needed)
# process_image_directory('./test_images/', max_images=3)


In [None]:
# Advanced caption generation with confidence scores
def generate_caption_with_confidence(image_path, max_len=20):
    """Greedy-decode a caption and report the mean per-token probability.

    At each step the most probable token is chosen; its softmax probability
    is recorded as that token's confidence. Decoding stops at the end token
    (index 1) or after ``max_len`` steps.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (normally a path).
        max_len: Maximum number of decoding steps.

    Returns:
        Tuple ``(image, caption, avg_confidence)``: the RGB PIL image, the
        cleaned caption string, and the arithmetic mean of the chosen tokens'
        probabilities (including the end token's; 0 if no step ran).
    """
    # Release the file handle as soon as the pixels are copied; the converted
    # image is independent of the source.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    sampled_ids = []
    confidences = []

    with torch.no_grad():
        features = encoder(image_tensor).unsqueeze(1)

        # Manual sampling with confidence: the first LSTM input is the image
        # feature; later inputs are embeddings of the previously chosen token.
        inputs = features
        states = None

        for _ in range(max_len):
            hiddens, states = decoder.lstm(inputs, states)
            outputs = decoder.linear(hiddens.squeeze(1))

            # Get probabilities
            probs = torch.nn.functional.softmax(outputs, dim=1)
            predicted_prob, predicted_idx = probs.max(1)

            # .item() requires a single element, i.e. a batch of one image.
            token_id = predicted_idx.item()
            sampled_ids.append(token_id)
            confidences.append(predicted_prob.item())

            if token_id == 1:  # <end> token
                break

            inputs = decoder.embed(predicted_idx).unsqueeze(1)

    caption = clean_sentence(sampled_ids, vocab.idx2word)
    avg_confidence = sum(confidences) / len(confidences) if confidences else 0

    return image, caption, avg_confidence

# Example usage function
def test_with_confidence(image_path):
    """Test caption generation with confidence score."""
    if os.path.exists(image_path):
        try:
            image, caption, confidence = generate_caption_with_confidence(image_path)
            
            plt.figure(figsize=(10, 6))
            plt.imshow(image)
            plt.axis('off')
            plt.title(f'Caption: {caption}\nConfidence: {confidence:.2f}', fontsize=12, pad=15)
            plt.show()
            
            print(f"Caption: {caption}")
            print(f"Average Confidence: {confidence:.4f}")
            
        except Exception as e:
            print(f"Error: {e}")
    else:
        print(f"Image {image_path} not found")

# Uncomment to test with a sample image
# test_with_confidence('test.jpg')


In [None]:
# Display model information and statistics
print("=== Model Information ===")
# NOTE(review): the architecture labels below are hard-coded text; confirm
# they match the definitions in encoder.py / decoder.py.
print("Encoder: CNN (ResNet-50 based)")
print("Decoder: LSTM")
print(f"Embedding Size: {embed_size}")
print(f"Hidden Size: {hidden_size}")
print(f"Vocabulary Size: {vocab_size}")
print(f"Device: {device}")

# Count model parameters. Both sums include only parameters with
# requires_grad=True, so both are labelled "Trainable" below.
encoder_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
decoder_params = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
total_params = encoder_params + decoder_params

print("\n=== Model Parameters ===")
print(f"Trainable Encoder Parameters: {encoder_params:,}")
print(f"Trainable Decoder Parameters: {decoder_params:,}")
print(f"Total Trainable Parameters: {total_params:,}")
