# Imports

In [3]:
import torch
from torch import optim
from torchvision import transforms
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split
import pandas as pd
import cv2 as cv
import os
import sys

In [4]:
# Add the root directory to the system path
sys.path.append(os.path.abspath(".."))

# Activate autoreload
%load_ext autoreload
%autoreload 2

from src.dataset import Vocabulary, FlickrDataset, build_weights_matrix, preprocess_data
from src.trainer import ImageCaptionTrainer
from src.model import ImageCaptionModel, PreTrainedMobileNetV3, ScratchRNN
from src.tester import ImageCaptionTester

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocess Data

## Data Splitting

In [5]:
# Split by unique image ids so that all captions of one image end up in exactly
# one subset (train, validation or test).
df = pd.read_csv("../data/cleaned/flickr8k_cleaned_data.csv")
unique_images = df['image'].unique()

# Hold out 15% of the images, then divide that hold-out evenly:
# ~85% train / ~7.5% validation / ~7.5% test.
# NOTE: the second `test_size` is a fraction of `temp_ids`, not of the full
# dataset. Using 0.075 here would leave only ~1% of all images for testing.
train_ids, temp_ids = train_test_split(unique_images, test_size=0.15, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

train_df = df[df['image'].isin(train_ids)].reset_index(drop=True)
val_df = df[df['image'].isin(val_ids)].reset_index(drop=True)
test_df = df[df['image'].isin(test_ids)].reset_index(drop=True)

# One row per test image, holding the list of all its cleaned captions
# (this grouped frame is what the evaluation dataset is built from).
test_grouped = test_df.groupby('image')['caption_clean'].apply(list).reset_index()

## Vocabulary and Dataloaders

In [6]:
# Frequency threshold handed to the Vocabulary constructor
min_freq = 1

# The vocabulary is built from the training captions only
train_captions = train_df['caption_clean'].tolist()
vocab = Vocabulary(min_freq)
vocab.build_vocabulary(train_captions)

In [7]:
# Image preprocessing: resize to 224x224, convert to a tensor, then normalize
# each channel (the mean/std values match the ImageNet statistics commonly used
# with torchvision pretrained models).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Dataset / loader settings
image_path = "../data/raw/Images"
max_tokens = 100  # presumably the per-caption token cap — confirm in FlickrDataset
batch_size = 128

# Every dataset shares the same image folder, vocabulary, transform and token cap;
# only the test dataset uses the grouped frame and is flagged as evaluation data.
common_kwargs = {"transform": transform, "max_tokens": max_tokens}
train_dataset = FlickrDataset(image_path, train_df, vocab, **common_kwargs)
val_dataset = FlickrDataset(image_path, val_df, vocab, **common_kwargs)
test_dataset = FlickrDataset(image_path, test_grouped, vocab, is_eval=True, **common_kwargs)

# preprocess_data returns one loader per dataset, in the order passed in
train_loader, val_loader, test_loader = preprocess_data(
    train_dataset,
    val_dataset,
    test_dataset,
    batch_size,
)

# Test loader: one sample per batch, fixed order, loaded in the main process
test_loader = DataLoader(
    dataset=test_dataset,
    num_workers=0,
    shuffle=False,
    batch_size=1,
)

# Model Setup

In [8]:
# Model configuration
EMBED_SIZE = 300       # embed_size of both the encoder and the decoder
HIDDEN_SIZE = 256      # hidden_size of the decoder
NUM_LAYERS = 1         # num_layers of the decoder
DROPOUT = 0.5          # dropout_rate of both the encoder and the decoder

# Training configuration
EPOCHS = 30            # epochs passed to trainer.fit
PATIENCE = 5           # patience passed to trainer.fit
LEARNING_RATE = 3e-4   # learning rate of the Adam optimizer
CLIP_NORM = 5.0        # clip_norm passed to the trainer



In [None]:
# Build the embedding weights for the vocabulary from the pretrained vector file.
# embedding_dim is tied to EMBED_SIZE instead of a separate literal because the
# resulting matrix is handed to the decoder, which is built with
# embed_size=EMBED_SIZE; the two values must always agree.
# NOTE(review): the file name suggests 300-d vectors, so EMBED_SIZE presumably
# has to stay 300 while this file is used — confirm.
# The second return value is not used in this notebook.
weights, _ = build_weights_matrix(
    vocab=vocab,
    pretrained_embeddings_path="../embeddings/wiki-news-300d-1M.vec",
    embedding_dim=EMBED_SIZE,
)

In [None]:
# Image encoder: pretrained MobileNetV3 with fine-tuning enabled
encoder = PreTrainedMobileNetV3(
    embed_size=EMBED_SIZE,
    dropout_rate=DROPOUT,
    fine_tune=True,
)

# Caption decoder: recurrent network in GRU mode, initialised with the
# pretrained embedding weights
decoder = ScratchRNN(
    vocab=vocab,
    embed_size=EMBED_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    dropout_rate=DROPOUT,
    is_gru=True,
    pretrained_embeddings=weights,
)

# Full captioning model: encoder followed by decoder
model = ImageCaptionModel(cnn=encoder, rnn=decoder)

# Training

In [14]:
# Cross-entropy averaged over the targets, skipping <PAD> positions
pad_idx = vocab.stoi["<PAD>"]
loss_fn = torch.nn.CrossEntropyLoss(reduction="mean", ignore_index=pad_idx)

# Adam over every model parameter
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Trainer wiring: model, loss, optimizer, device and the clip_norm setting
trainer = ImageCaptionTrainer(
    model=model,
    loss_function=loss_fn,
    optimizer=optimizer,
    clip_norm=CLIP_NORM,
    device=device,
)

In [None]:
# Run training with early stopping.
# NOTE(review): epsilon is presumably the minimum improvement that counts as
# progress, and checkpoint_dir the folder where checkpoints are saved — confirm
# in ImageCaptionTrainer.fit.
trainer.fit(
    epochs=EPOCHS,
    patience=PATIENCE,
    epsilon=1e-3,
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_dir='../checkpoints',
)

# Evaluation

In [None]:
# Load the stored model weights from best_model.pth (presumably written by
# trainer.fit into the same checkpoint folder — confirm), mapped onto `device`.
checkpoint = torch.load('../checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Put the model on the target device (GPU or CPU) and switch to evaluation mode
model.to(device)
model.eval()

# Run the tester over the test loader and keep the returned metric scores
tester = ImageCaptionTester(model, device, vocab)
resultados = tester.test(test_loader)
metrics = resultados

print("\n--- Final Metrics ---")
for metric, score in resultados.items():
    print(f"{metric}: {score:.4f}")

# Settings written to the log next to the metrics
hyperparameters = dict(
    EMBED_SIZE=EMBED_SIZE,
    HIDDEN_SIZE=HIDDEN_SIZE,
    NUM_LAYERS=NUM_LAYERS,
    DROPOUT=DROPOUT,
    EPOCHS=EPOCHS,
    PATIENCE=PATIENCE,
    LEARNING_RATE=LEARNING_RATE,
    CLIP_NORM=CLIP_NORM,
)

tester.write_log_txt(hyperparameters, resultados, "../checkpoints/")

In [17]:
# Pick one test image (by row position) and rebuild its path inside the raw
# image folder from the file name stored in the `image_path` column.
# The global `image_path` (the image folder string used to build the datasets)
# is deliberately left untouched here, so cells can be re-run in any order.
image_idx = 10
image_files = test_df["image_path"]
image_name = os.path.basename(image_files.iloc[image_idx])
image_literal_path = os.path.join("../data/raw/Images", image_name)


In [None]:
# Show an example for the selected test image, passing the same preprocessing
# transform used for the datasets (presumably displays the image with its
# generated caption — confirm in ImageCaptionTester.show_example).
tester.show_example(transform, image_literal_path)