In [2]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
from PIL import Image
from random import randint, choice
from src.models.base_network import BaseNetwork
from src.models.siamese_network import SiameseNetwork
from src.models.loss_functions import ContrastiveLoss
from src.training.trainer import Trainer
from src.evaluation.metrics import compute_accuracy
from src.training.data_loader import SiameseDataset

def create_dummy_data(num_classes=3, num_images_per_class=5, image_size=(224, 224),
                      dataset_dir="dummy_dataset"):
    """
    Create a small on-disk dataset of random RGB images grouped by class.

    Images are saved as ``<dataset_dir>/class_<k>/img_<i>.jpg``. Directories
    are created when missing; files already present at those paths are
    written over.

    Args:
        num_classes: Number of class sub-directories to create.
        num_images_per_class: Number of images written into each class
            directory.
        image_size: Two leading entries are used as the first and second
            array dimensions (rows, columns) of each generated image.
        dataset_dir: Root directory that receives the class folders.

    Returns:
        A tuple ``(image_paths, labels)`` of equal-length lists, where
        ``labels[i]`` is the integer class id of ``image_paths[i]``.
    """
    rows, cols = image_size[0], image_size[1]
    os.makedirs(dataset_dir, exist_ok=True)
    image_paths = []
    labels = []

    for class_id in range(num_classes):
        class_dir = os.path.join(dataset_dir, f"class_{class_id}")
        os.makedirs(class_dir, exist_ok=True)

        for img_id in range(num_images_per_class):
            img_path = os.path.join(class_dir, f"img_{img_id}.jpg")
            # randint's upper bound is exclusive: 256 is required so that the
            # full uint8 range 0..255 can actually be sampled.
            pixels = np.random.randint(0, 256, (rows, cols, 3), dtype=np.uint8)
            Image.fromarray(pixels).save(img_path)
            image_paths.append(img_path)
            labels.append(class_id)

    return image_paths, labels

# Write the random images to disk and collect their paths and class ids.
image_paths, labels = create_dummy_data()

# The only preprocessing step is conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Wrap the image list in the project's pair dataset and batch it two at a time.
dummy_dataset = SiameseDataset(image_paths, labels, transform=transform)
dummy_loader = DataLoader(dataset=dummy_dataset, batch_size=2, shuffle=True)

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model, loss and optimizer.
base_network = BaseNetwork()
siamese_model = SiameseNetwork(base_network)
siamese_model = siamese_model.to(device)
criterion = ContrastiveLoss(margin=1.0)
optimizer = torch.optim.Adam(params=siamese_model.parameters(), lr=0.001)

# Short training run over the dummy loader; prints the mean batch loss per epoch.
print("Training on dummy data...")
for epoch in range(2):  # Minimal training: 2 epochs
    siamese_model.train()
    epoch_loss = 0.0

    for img1, img2, label in dummy_loader:
        img1, img2, label = img1.to(device), img2.to(device), label.to(device)

        # Forward pass
        embedding1, embedding2 = siamese_model(img1, img2)
        distances = siamese_model.compute_distance(embedding1, embedding2)

        # NOTE(review): the recorded run prints distances with shape (N,) and
        # labels with shape (N, 1). Fed to an element-wise loss, those shapes
        # broadcast to (N, N) and pair every distance with every label.
        # Align the label shape with the distances when the element counts
        # match — confirm ContrastiveLoss does not already reshape internally.
        if label.numel() == distances.numel():
            label = label.reshape(distances.shape)
        loss = criterion(distances, label)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() extracts a Python float so the running sum holds no graph.
        epoch_loss += loss.item()

    print(f"Epoch {epoch + 1}: Loss = {epoch_loss / len(dummy_loader):.4f}")

# Sanity check: print pair distances next to the ground-truth labels for a
# single batch, with the model in eval mode and gradients disabled.
print("\nTesting predictions on dummy data...")
siamese_model.eval()
with torch.no_grad():
    # Predict on a single batch (nothing happens if the loader yields none).
    first_batch = next(iter(dummy_loader), None)
    if first_batch is not None:
        img1, img2, label = first_batch
        img1 = img1.to(device)
        img2 = img2.to(device)
        embedding1, embedding2 = siamese_model(img1, img2)
        distances = siamese_model.compute_distance(embedding1, embedding2)

        # label is never moved to the device, so .numpy() is called on it directly.
        print(f"Distances: {distances.cpu().numpy()}, Ground Truth: {label.numpy()}")


Training on dummy data...
Epoch 1: Loss = 3.8385
Epoch 2: Loss = 0.2934

Testing predictions on dummy data...
Distances: [0.9118988 0.5342101], Ground Truth: [[1.]
 [0.]]
