# VisionVerse Training

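This notebook demonstrates how to train the two VisionVerse models: the image captioning model (Section 1) and the image classification model (Section 2). Hyperparameters such as batch sizes, learning rates, epoch counts and the target `DEVICE` are read from the project's `config` module.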

In [None]:
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Make the project root importable so `src` and `config` resolve.
sys.path.append('../')

from src.captioning.model import CNNtoRNN
from src.classification.model import CNNModel
from src.utils.data_loader import ImageCaptionDataset, get_loader
from src.captioning.utils import save_checkpoint
from src.classification.utils import load_category_names
# NOTE(review): `load_vocabulary` (used in the captioning section) is not
# imported here; add the import from whichever module defines it.
from config import *

## 1. Image Captioning Training

In [None]:
 # Load the vocabulary and build the caption training loader.
# NOTE(review): `load_vocabulary` is not imported anywhere visible in this
# notebook; import it from the module that defines it before running this cell.
vocab = load_vocabulary('data/vocab.json')

# Fixed 224x224 resize, then normalization with the standard ImageNet
# per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# NOTE(review): these paths are resolved against the kernel's working
# directory, while `sys.path.append('../')` suggests the notebook sits one
# level below the project root -- confirm where `data/` actually lives.
train_loader = get_loader(
    root_folder='data/images',
    annotation_file='data/captions.txt',
    transform=transform,
    batch_size=CAPTION_BATCH_SIZE,
)

# Model, loss and optimizer. The model's output size is the vocabulary size.
model = CNNtoRNN(CAPTION_EMBED_SIZE, CAPTION_HIDDEN_SIZE, len(vocab), CAPTION_NUM_LAYERS).to(DEVICE)
# Positions holding the pad token contribute no loss (assumes 0 is the pad index).
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=CAPTION_LEARNING_RATE)

LOG_EVERY = 100  # print the running loss every LOG_EVERY batches

for epoch in range(CAPTION_NUM_EPOCHS):
    for idx, (imgs, captions) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        captions = captions.to(DEVICE)

        # NOTE(review): the loss flattens `outputs` to (-1, outputs.shape[2]) and
        # `captions` to (-1,), so both must cover the same number of positions.
        # Confirm CNNtoRNN's contract: if its decoder prepends the image feature
        # as an extra time step, the input here should be `captions[:-1]`.
        outputs = model(imgs, captions)
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % LOG_EVERY == 0:
            print(f"Epoch [{epoch+1}/{CAPTION_NUM_EPOCHS}], Step [{idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

# Save the trained model. Create the target directory first so a fresh
# checkout without `checkpoints/` does not fail at the very end of training.
os.makedirs('checkpoints', exist_ok=True)
save_checkpoint(model, optimizer, 'checkpoints/caption_model.pth')

## 2. Image Classification Training

In [None]:
# Load the category names and build the classification training loader.
categories = load_category_names('data/flower_labels.json')

# Training-time augmentation (random resized crop + horizontal flip), then
# normalization with the standard ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# ImageFolder derives one label per class sub-directory of the root folder.
train_dataset = datasets.ImageFolder(root='data/flowers/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=CLASSIFICATION_BATCH_SIZE, shuffle=True)

# The model's output size comes from `categories`, but the labels come from the
# folder structure. Every label must be a valid output index; otherwise
# CrossEntropyLoss fails mid-training (a device-side assert on CUDA).
assert len(train_dataset.classes) <= len(categories), (
    f"{len(train_dataset.classes)} class folders under 'data/flowers/train' "
    f"but only {len(categories)} entries in 'data/flower_labels.json'"
)

# Model, loss and optimizer.
model = CNNModel(CLASSIFICATION_ARCH, CLASSIFICATION_HIDDEN_UNITS, len(categories)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CLASSIFICATION_LEARNING_RATE)

LOG_EVERY = 100  # print the running loss every LOG_EVERY batches

for epoch in range(CLASSIFICATION_NUM_EPOCHS):
    for idx, (images, labels) in enumerate(train_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % LOG_EVERY == 0:
            print(f"Epoch [{epoch+1}/{CLASSIFICATION_NUM_EPOCHS}], Step [{idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

# Save the trained weights. Create the target directory first so torch.save
# does not fail when `checkpoints/` does not exist yet.
# NOTE(review): only the state_dict is saved. The index -> class-folder mapping
# lives in `train_dataset.class_to_idx`; inference code needs it to interpret
# predictions, so consider persisting it alongside the weights.
os.makedirs('checkpoints', exist_ok=True)
torch.save(model.state_dict(), 'checkpoints/classification_model.pth')