In [19]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets import LunarSeismicDataset, collate_fn, train_test_split_dataset
from models import SeismicEventPredictor, CNNAutoencoder

In [20]:
# Define model and training hyperparameters
num_epochs = 50
batch_size = 4
learning_rate = 1e-3

# Seed torch's global RNG so DataLoader shuffling and model weight init are
# reproducible under Restart & Run All. If train_test_split_dataset draws from
# a different RNG (numpy / `random`), that one must be seeded too — TODO confirm.
seed = 42
torch.manual_seed(seed);

In [None]:
# Load the pretrained autoencoder weights.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; the visible cells never move models to another device.
autoencoder_checkpoint = 'checkpoints/CNNAutoencoder/model_epoch_15.pth'
autoencoder = CNNAutoencoder()
autoencoder.load_state_dict(
    torch.load(autoencoder_checkpoint, map_location='cpu', weights_only=True)
)

In [22]:
# Load the Apollo 12 Grade-A training data together with its event catalog,
# then split it into train / test subsets.
lunar_data_dir = 'data/lunar/training/data/S12_GradeA'
lunar_catalog_file = 'data/lunar/training/catalogs/apollo12_catalog_GradeA_final.csv'
supervised_dataset = LunarSeismicDataset(data_dir=lunar_data_dir, catalog_file=lunar_catalog_file)
train_dataset, test_dataset = train_test_split_dataset(supervised_dataset)

# Both loaders share batch size and collate function; only training data is shuffled.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [23]:
# Build the predictor from the pretrained autoencoder, then set up a
# mean-squared-error loss and an Adam optimiser over predictor.parameters().
predictor = SeismicEventPredictor(autoencoder)

criterion = nn.MSELoss()
predictor_params = predictor.parameters()
optimizer = optim.Adam(predictor_params, lr=learning_rate)

In [None]:
# Training loop: each epoch runs one optimisation pass over `train_loader`,
# evaluates on `test_loader`, and saves a checkpoint. Reported losses are the
# mean over the batches actually processed (batches with `inputs is None` are skipped).
checkpoint_dir = './checkpoints/SeismicEventPredictor'


def _align_to_targets(outputs, targets):
    """Return `outputs` shaped like `targets` so the loss compares element-wise.

    A bare `squeeze()` turns a batch of one into a 0-d tensor, which MSELoss
    broadcasts against a 1-element target with a warning; it also mis-broadcasts
    when targets keep a trailing size-1 dim. When element counts match, reshape
    to the target shape; otherwise fall back to the original `squeeze()`.
    """
    if outputs.numel() == targets.numel():
        return outputs.reshape(targets.shape)
    return outputs.squeeze()


def _train_one_epoch(model, loader, loss_fn, opt):
    """Train `model` for one pass over `loader`; return mean batch loss (NaN if no batch ran)."""
    model.train()
    loss_sum, n_batches = 0.0, 0
    for inputs, targets in loader:
        if inputs is None:
            continue
        inputs = inputs.unsqueeze(1)  # Add channel dimension
        outputs = model(inputs)
        loss = loss_fn(_align_to_targets(outputs, targets), targets)

        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_sum += loss.item()
        n_batches += 1
    return loss_sum / n_batches if n_batches else float('nan')


def _evaluate(model, loader, loss_fn):
    """Evaluate `model` on `loader` without gradients; return mean batch loss (NaN if no batch ran)."""
    model.eval()
    loss_sum, n_batches = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            if inputs is None:
                continue
            inputs = inputs.unsqueeze(1)  # Add channel dimension
            outputs = model(inputs)
            loss_sum += loss_fn(_align_to_targets(outputs, targets), targets).item()
            n_batches += 1
    # Divide by processed batches only, so skipped batches don't bias the mean
    # and an empty loader doesn't raise ZeroDivisionError.
    return loss_sum / n_batches if n_batches else float('nan')


# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    train_loss = _train_one_epoch(predictor, train_loader, criterion, optimizer)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}')

    test_loss = _evaluate(predictor, test_loader, criterion)
    print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {test_loss:.4f}')

    # Save model checkpoint
    checkpoint_path = f'{checkpoint_dir}/model_epoch_{epoch+1}.pth'
    torch.save(predictor.state_dict(), checkpoint_path)
    print(f'Model saved to {checkpoint_path}')