In [6]:
import random
import copy
import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.utils import shuffle


In [2]:
# Read the sunspot sequence archive and wrap every stored split in a tensor.
data = np.load("../data/sunspot_sequences.npz")

(
    X_train, y_train,
    X_val, y_val,
    X_test, y_test,
) = (
    torch.tensor(data[split_key])
    for split_key in ("X_train", "y_train", "X_val", "y_val", "X_test", "y_test")
)



In [7]:
class Trainer:
    """Mini-batch trainer with early stopping for a ``torch.nn.Module``.

    The model is optimised with Adam against an MSE loss. The per-epoch loss
    history of the most recent :meth:`fit` call is kept on the instance as
    ``train_losses`` and ``val_losses``.
    """

    def __init__(self, model, lr=0.01, batch_size=32):
        """Create the optimiser and loss function for ``model``.

        Args:
            model: The module to train; its parameters are updated in place.
            lr: Learning rate handed to ``torch.optim.Adam``.
            batch_size: Number of samples used per optimisation step.
        """
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
        # Filled by fit(); initialised here so the attributes always exist.
        self.train_losses = []
        self.val_losses = []

    def train_one_epoch(self, X, y):
        """Run one shuffled pass over ``(X, y)`` and return the mean batch loss.

        Args:
            X: Input tensor, sliced along its first dimension into batches.
            y: Target tensor with the same first-dimension length as ``X``.

        Returns:
            float: Sum of the per-batch losses divided by the number of
            batches processed (a trailing partial batch counts as a batch).

        Raises:
            ValueError: If ``X`` is empty or ``X`` and ``y`` differ in length.
        """
        n_samples = len(X)
        if n_samples == 0:
            raise ValueError("Cannot train on an empty dataset.")
        if len(y) != n_samples:
            raise ValueError(
                f"X and y have inconsistent lengths: {n_samples} vs {len(y)}"
            )

        self.model.train()
        # Shuffle with torch's own RNG so the order follows torch.manual_seed
        # and the tensors never leave torch.
        perm = torch.randperm(n_samples)
        X, y = X[perm], y[perm]

        total_loss = 0.0
        n_batches = 0
        for i in range(0, n_samples, self.batch_size):
            batch_X = X[i:i + self.batch_size]
            batch_y = y[i:i + self.batch_size]
            # Training step
            self.optimizer.zero_grad()
            pred = self.model(batch_X)
            loss = self.criterion(pred, batch_y)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1
        # Divide by the batches actually run: `len(X) // batch_size` is zero
        # for small datasets and ignores the final partial batch.
        return total_loss / n_batches

    def evaluate(self, X, y):
        """Return the loss of the model on ``(X, y)`` in a single forward pass.

        The model is switched to eval mode and gradients are disabled.
        """
        self.model.eval()
        with torch.no_grad():
            pred = self.model(X)
            loss = self.criterion(pred, y).item()
        return loss

    def fit(self, X_train, y_train, X_val, y_val, max_epochs=1000, patience=50):
        """Train until ``max_epochs`` or until validation loss stops improving.

        After every epoch the validation loss is computed; the weights with
        the lowest validation loss seen so far are remembered and restored
        into the model before returning.

        Args:
            X_train, y_train: Training inputs and targets.
            X_val, y_val: Validation inputs and targets.
            max_epochs: Upper bound on the number of epochs.
            patience: Stop after this many consecutive epochs without a new
                best validation loss.

        Returns:
            tuple[list, list]: ``(train_losses, val_losses)``, one entry per
            completed epoch. The same lists are available afterwards as
            ``self.train_losses`` and ``self.val_losses``.
        """
        best_val_loss = float('inf')
        best_model = None
        no_improve = 0

        # Start a fresh history for this run and expose it on the instance.
        self.train_losses = []
        self.val_losses = []

        print(f"Starting training for model: {self.model._get_name()}")
        for epoch in range(max_epochs):
            train_loss = self.train_one_epoch(X_train, y_train)
            val_loss = self.evaluate(X_val, y_val)

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Deep copy: state_dict() returns references to live tensors.
                best_model = copy.deepcopy(self.model.state_dict())
                no_improve = 0
                print(f"Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} (improved)")
            else:
                no_improve += 1
                print(f"Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

            if no_improve >= patience:
                print("Early stopping triggered.")
                break

        # No snapshot exists if no epoch ran or the validation loss never
        # dropped below infinity (e.g. NaN); keep the current weights then.
        if best_model is not None:
            self.model.load_state_dict(best_model)
        return self.train_losses, self.val_losses

In [8]:
def analyze_and_save(model, trainer, X_train, y_train, X_val, y_val, X_test, y_test, save_dir="results", model_name="QNN"):
    """Evaluate ``model`` on all splits, print metrics, and save plot + weights.

    The function predicts on the train/val/test inputs, prints a table of MSE
    and MAE per split, writes a two-panel figure (loss curves and the first
    100 test predictions against the true values) to
    ``{save_dir}/{model_name}_results.pdf`` and saves the model's state dict
    to ``{save_dir}/{model_name}_best.pth``.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        trainer: Object whose ``train_losses`` / ``val_losses`` attributes
            hold the per-epoch loss history. If either attribute is missing,
            that curve is treated as empty instead of raising.
        X_train, y_train, X_val, y_val, X_test, y_test: Input and target
            tensors for each split.
        save_dir: Output directory, created if it does not exist.
        model_name: Prefix for the output files and the printed header.
    """
    os.makedirs(save_dir, exist_ok=True)

    def _to_numpy(tensor):
        # .cpu() makes the conversion work for tensors on other devices;
        # .detach() covers tensors that still track gradients.
        return tensor.detach().cpu().numpy()

    model.eval()
    with torch.no_grad():
        pred_train = _to_numpy(model(X_train))
        pred_val = _to_numpy(model(X_val))
        pred_test = _to_numpy(model(X_test))

    y_train, y_val, y_test = _to_numpy(y_train), _to_numpy(y_val), _to_numpy(y_test)

    # Each list is ordered [train, val, test].
    metrics = {
        "MSE": [mean_squared_error(y_train, pred_train),
                mean_squared_error(y_val, pred_val),
                mean_squared_error(y_test, pred_test)],
        "MAE": [mean_absolute_error(y_train, pred_train),
                mean_absolute_error(y_val, pred_val),
                mean_absolute_error(y_test, pred_test)],
    }

    print("\n" + "="*50)
    print(f"KẾT QUẢ {model_name.upper()}")
    print("="*50)
    print(f"{'':<10} {'Train':<10} {'Val':<10} {'Test':<10}")
    print("-"*50)
    for name in ["MSE", "MAE"]:
        print(f"{name:<10} {metrics[name][0]:.6f}    {metrics[name][1]:.6f}    {metrics[name][2]:.6f}")
    print("="*50)

    # The history lives on the trainer only if it recorded it; fall back to
    # empty curves rather than failing after the metrics were computed.
    train_losses = getattr(trainer, "train_losses", None)
    val_losses = getattr(trainer, "val_losses", None)
    if train_losses is None:
        train_losses = []
    if val_losses is None:
        val_losses = []

    figure_path = f"{save_dir}/{model_name}_results.pdf"
    weights_path = f"{save_dir}/{model_name}_best.pth"

    fig, (ax_loss, ax_pred) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        # Loss curve
        ax_loss.plot(train_losses, label="Train Loss", alpha=0.8)
        ax_loss.plot(val_losses, label="Val Loss", alpha=0.8)
        if len(train_losses) or len(val_losses):
            # A log axis is only meaningful when there is data to show.
            ax_loss.set_yscale('log')
        ax_loss.set_xlabel("Epoch")
        ax_loss.set_ylabel("MSE Loss")
        ax_loss.legend()
        ax_loss.set_title("Training Curve")

        # Prediction vs True (Test)
        ax_pred.plot(y_test[:100], 'b-', label="True", linewidth=2)
        ax_pred.plot(pred_test[:100], 'r--', label="Predict", linewidth=2)
        ax_pred.set_xlabel("Time")
        ax_pred.set_ylabel("Sunspot")
        ax_pred.legend()
        ax_pred.set_title("Prediction vs True (Test)")

        fig.tight_layout()
        fig.savefig(figure_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    torch.save(model.state_dict(), weights_path)

    print(f"Saved: {figure_path}")
    print(f"Saved model: {weights_path}")

