In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from _models import (
    CoordinateCNN,
    CoordinateDataset,
    create_data_loaders,
    evaluate_model,
    plot_training_history,
    train_and_validate,
)

# Seed PyTorch's global RNG (it covers all devices) so that weight
# initialisation and DataLoader shuffling are reproducible on Restart & Run All.
# NOTE(review): only torch is seeded; if CoordinateDataset draws from `random`
# or numpy (not visible here), seed those as well.
SEED = 42
torch.manual_seed(SEED)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training split; its DataLoader is built in the next cell
dataset: CoordinateDataset = CoordinateDataset("./train")

In [None]:
# Loader / training hyper-parameters, kept together so they are easy to tune.
BATCH_SIZE = 32
NUM_EPOCHS = 200
# Passed to train_and_validate as `patience` — presumably the number of epochs
# without validation improvement before early stopping (TODO confirm in _models).
PATIENCE = 20

# Held-out splits, each read from its own directory
testDataset: CoordinateDataset = CoordinateDataset("./testImages")
valDataset: CoordinateDataset = CoordinateDataset("./val")

# Only the training loader is shuffled. Evaluation loaders keep a fixed order so
# validation/test losses are reproducible between runs (if evaluation averages
# per-batch losses, shuffling changes which samples land in the partial last batch).
train_loader: DataLoader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader: DataLoader = DataLoader(valDataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader: DataLoader = DataLoader(testDataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize the model on the selected device
model = CoordinateCNN().to(device)

# Mean-squared-error regression loss; Adam with PyTorch's default hyper-parameters
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Train with per-epoch validation; the returned history is later indexed by 'val_loss'
history = train_and_validate(
    device,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=NUM_EPOCHS,
    patience=PATIENCE,
)

In [None]:
# Visualise the loss curves recorded in `history` during training
plot_training_history(history)

# Score the model on the held-out test split.
# NOTE(review): this uses whatever weights `model` holds after train_and_validate
# returns — confirm whether that function restores the best-validation checkpoint,
# otherwise the two numbers below come from different weights.
test_loss = evaluate_model(device, model, test_loader, criterion)

print("Training completed!")
best_val_loss = min(history["val_loss"])
print(f"Best validation loss: {best_val_loss:.6f}")
print(f"Final test loss: {test_loss:.6f}")