# VisionVerse Exploration

This notebook demonstrates the basic functionality of the VisionVerse project, including image captioning, classification, and generation.


In [None]:
# Notebook setup: all imports plus the shared compute device.
import sys

import matplotlib.pyplot as plt
import torch
from PIL import Image

# The project's `src` package is imported below, so the parent directory has to
# be on sys.path first. The membership check keeps re-running this cell from
# appending the same entry again.
if '../' not in sys.path:
    sys.path.append('../')

from src.captioning.model import CNNtoRNN
from src.captioning.utils import load_vocabulary
from src.classification.model import CNNModel
from src.classification.utils import load_category_names, process_image
from src.generation.generator import generate_image

# Later cells move models and tensors with `.to(DEVICE)`; define it here so a
# fresh kernel (Restart & Run All) does not fail with a NameError.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): later cells also reference CAPTION_EMBED_SIZE,
# CAPTION_HIDDEN_SIZE, CAPTION_NUM_LAYERS, CLASSIFICATION_ARCH,
# CLASSIFICATION_HIDDEN_UNITS, GENERATION_ITERATIONS and
# GENERATION_LEARNING_RATE, none of which are defined or imported in this
# notebook. Presumably they live in a project config module — TODO import them
# here (with the values used at training time) before running the cells below.

## 1. Image Captioning

In [None]:
"# Load the captioning model",
    "vocab = load_vocabulary('data/vocab.json')",
    "caption_model = CNNtoRNN(CAPTION_EMBED_SIZE, CAPTION_HIDDEN_SIZE, len(vocab), CAPTION_NUM_LAYERS).to(DEVICE)"
    "caption_model.load_state_dict(torch.load('checkpoints/caption_model.pth')
    
    "caption_model.eval()",

    "# Function to generate caption",
    "def generate_caption(image_path):",
    "    image = process_image(image_path).to(DEVICE)",
    "    caption = caption_model.caption_image(image, vocab)",
    "    return ' '.join(caption)",
  
    "# Test the captioning model",
    "test_image_path = 'data/test_image.jpg'",
    "caption = generate_caption(test_image_path)",
   
    "plt.imshow(Image.open(test_image_path))",
    "plt.axis('off')",
    "plt.title(f\"Caption: {caption}\")",
    "plt.show()"

## 2. Image Classification

In [None]:
"# Load the classification model",
    "categories = load_category_names('data/flower_labels.json')",
    "classify_model = CNNModel(CLASSIFICATION_ARCH, CLASSIFICATION_HIDDEN_UNITS, len(categories)).to(DEVICE)",
    "classify_model.load_state_dict(torch.load('checkpoints/classification_model.pth'))",
    "classify_model.eval()",
  
    "# Function to classify image",
    "def classify_image(image_path):",
    "    image = process_image(image_path).to(DEVICE)",
    "    with torch.no_grad():",
    "        output = classify_model(image)",
    "        _, predicted = torch.max(output, 1)",
    "    return categories[str(predicted.item() + 1)]",
 
    "# Test the classification model",
    "test_image_path = 'data/test_flower.jpg'",
    "classification = classify_image(test_image_path)",
    
    "plt.imshow(Image.open(test_image_path))",
    "plt.axis('off')",
    "plt.title(f\"Classification: {classification}\")",
    "plt.show()"

## 3. Image Generation

In [None]:
"# Generate an image from text",
    "text_prompt = \"A serene lake surrounded by mountains at sunset\"",
    "generated_image = generate_image(text_prompt, iterations=GENERATION_ITERATIONS, lr=GENERATION_LEARNING_RATE)",
    
    "plt.imshow(generated_image)",
    "plt.axis('off')",
    "plt.title(f\"Generated from: {text_prompt}\")",
    "plt.show()"