# 🧩 Baseline Training — Full Dataset Model Performance

Before we study **dataset condensation**, it’s important to first train a model on the **full real dataset** and record its performance.  
This gives us a *baseline* ‚Äî a reference point against which we can later compare condensed (synthetic) datasets.

---

## 🎯 Objective

Train a standard deep learning model (e.g., LSTM or RNN) on the complete time-series dataset and evaluate its predictive performance.  
The resulting metrics (AUC, APR, and loss) serve as a **ground truth** for how well a model can perform when it has access to all real samples.

---

## ⚙️ What this notebook does

1. **Load and normalize** the real dataset  
   - Uses the data loader to create training, validation, and test splits.
   - Applies standardization or min–max normalization.

2. **Define and train** a simple baseline model  
   - We’ll use an LSTM (or RNN) suited for sequential / temporal data.
   - Trained using standard supervised learning on the full training data.

3. **Evaluate**  
   - Compute **AUC** (ROC), **APR** (area under the precision–recall curve), and **loss** on validation and test sets.  
   - Store the best model (based on validation AUC) for later comparison.

4. **Visualize**  
   - Plot loss curves and validation metrics across epochs to understand convergence and overfitting behavior.

---

## 🧠 Why a baseline matters

When we later perform **dataset condensation**, we’ll train new models on **synthetic data** generated through methods like *logit distribution matching*.  
By comparing the condensed model’s performance to this baseline, we can quantify:

- How much accuracy or AUC is lost when training on synthetic instead of real data.
- Whether condensation effectively captures the important information in the real dataset.

---

👉 **In short:**  
This notebook establishes a **performance benchmark** using real data.  
All future condensation experiments will be evaluated relative to this baseline.


In [1]:
### import modules
import os
import random
import numpy as np
import torch
import matplotlib.pyplot as plt


from loaders import get_loaders_time_series   # replace with actual module name if different
from utils_1 import get_device, prediction_binary
from models import LSTMClassifier   # replace if your LSTM class is in a different module


def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, the
    current CUDA device and all CUDA devices). Also exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic kernels with
    benchmark autotuning disabled, trading speed for reproducibility.
    """
    # Exported for child processes; it does not re-seed str hashing in the
    # interpreter that is already running.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Library-level generators (Python, NumPy, PyTorch CPU).
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators: current device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic algorithms, no per-shape autotuning.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


rand_seed = 42
set_seed(rand_seed)

In [None]:
# Select the compute device via the project helper and report the choice.
device = get_device()
print("Using device: {}".format(device))

In [None]:
# Print the current working directory (IPython shell escape -- notebook-only syntax,
# not valid in a plain .py script). Handy for checking where relative paths such as
# the `data_path` in the next cell are resolved from.
!pwd

In [None]:
# -------------------------
# Data loading: point `data_path` at the folder holding your pickles
# -------------------------
data_path = "../DATA/"   # <-- change to your data folder

# Build train / val / test loaders from the project helper.
# pre_process="std" presumably selects standardization -- confirm in `loaders`.
train_loader, val_loader, test_loader = get_loaders_time_series(
    path=data_path,
    train_batch=128,
    val_batch=256,
    test_batch=256,
    sampler=True,
    pre_process="std",
    ds_half=0,
)

# Peek at one training batch so the model's input size can be read off its shape.
batch_x, batch_y = next(iter(train_loader))
print("One batch x shape:", batch_x.shape, "y shape:", batch_y.shape)

# Assumed layout: (batch, seq_len, n_features) -- a batch with fewer than
# three dims will raise on the indexing below.
seq_len, n_features = batch_x.shape[1], batch_x.shape[2]
print(f"seq_len={seq_len}, n_features={n_features}")

In [5]:
# -------------------------
# Baseline LSTM classifier, loss and optimizer
# -------------------------
hidden_dim = 32

# LSTMClassifier takes (input_dim, hidden_dim, device, output_dim).
model = LSTMClassifier(
    input_dim=n_features,
    hidden_dim=hidden_dim,
    device=device,
    output_dim=1,   # one logit per sample for binary classification
)
model.to(device)

# The model outputs raw logits (no sigmoid inside), so the sigmoid is
# folded into the loss for numerical stability.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# -------------------------
# Training settings (simple)
# -------------------------
num_epochs = 20
# History lists: index 0 holds the pre-training evaluation, index k the result after epoch k.
train_losses = []
val_losses = []
val_aucs = []
val_aprs = []


def evaluate_train_loss(loader):
    """Return the mean per-batch loss of the global ``model`` over ``loader``.

    Runs under ``torch.no_grad()`` with ``model`` switched to eval mode and
    does not switch it back; the training loop calls ``model.train()`` at the
    start of every epoch. The average is over batches (not samples), the same
    convention the training loop uses for its running loss. An empty loader
    yields 0.0 instead of dividing by zero.

    Relies on the notebook globals ``model``, ``loss_fn`` and ``device``.
    """
    model.eval()
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(torch.float32).to(device)
            y = y.to(torch.float32).to(device)
            logits = model(x)[:, 0]   # first output column -> one logit per sample
            total += loss_fn(logits, y).item()
            n_batches += 1
    return total / max(1, n_batches)


# Initial evaluation (before training) so every curve starts at the untrained model.
train_losses.append(evaluate_train_loss(train_loader))
val_loss, val_auc, val_apr = prediction_binary(model, val_loader, loss_fn, device)
val_losses.append(val_loss); val_aucs.append(val_auc); val_aprs.append(val_apr)
print(
    f"Init — Train loss: {train_losses[-1]:.4f}, Val loss: {val_loss:.4f}, "
    f"Val AUC: {val_auc:.4f}, Val APR: {val_apr:.4f}"
)

In [None]:
# -------------------------
# Choose which metric to use for saving best model:
# Options: "val_auc" (higher is better) or "train_loss" (lower is better)
# -------------------------
monitor_metric = "val_auc"   # set to "train_loss" if you prefer
# Validate up front so a typo fails immediately, not after a full epoch of training.
if monitor_metric not in ("val_auc", "train_loss"):
    raise ValueError("monitor_metric must be 'val_auc' or 'train_loss'")
higher_is_better = monitor_metric == "val_auc"
best_metric = -float("inf") if higher_is_better else float("inf")
best_epoch = -1
best_path = "./best_model_state.pt"   # saved as state_dict (recommended)

# -------------------------
# Training loop with checkpointing
# -------------------------
for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0

    for x, y in train_loader:
        x = x.to(torch.float32).to(device)
        y = y.to(torch.float32).to(device)

        optimizer.zero_grad()
        # `[:, 0]` needs a 2-D output (B, k) and keeps the first column -> (B,);
        # a 1-D (B,) output would raise here.
        preds = model(x)[:, 0]
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of per-batch losses (same convention as the initial evaluation).
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---------- validation ----------
    val_loss, val_auc, val_apr = prediction_binary(model, val_loader, loss_fn, device)
    val_losses.append(val_loss)
    val_aucs.append(val_auc)
    val_aprs.append(val_apr)

    # ---------- checkpoint when the monitored metric strictly improves ----------
    current = val_auc if higher_is_better else avg_train_loss
    improved = current > best_metric if higher_is_better else current < best_metric
    if improved:
        best_metric = current
        best_epoch = epoch
        # save state_dict (more portable than saving the full model)
        torch.save(model.state_dict(), best_path)
        print(f"--> Saved new best model (epoch={epoch}) | {monitor_metric} = {best_metric:.4f}")

    print(
        f"Epoch {epoch}/{num_epochs} — Train loss: {avg_train_loss:.4f} | "
        f"Val loss: {val_loss:.4f} | Val AUC: {val_auc:.4f} | Val APR: {val_apr:.4f}"
    )



In [None]:
# -------------------------
# After training: load best model and evaluate on test set
# -------------------------
if best_epoch == -1:
    print("No checkpoint was saved during training. Evaluating final model.")
    best_model_state = None
else:
    print(f"\nLoading best model from epoch {best_epoch} (saved at '{best_path}').")
    # The checkpoint is a plain state_dict of tensors, so weights_only=True is enough.
    # It restricts unpickling to tensors and basic containers (requires PyTorch >= 1.13).
    best_model_state = torch.load(best_path, map_location=device, weights_only=True)
    model.load_state_dict(best_model_state)

# Ensure model is in eval mode for testing
model.eval()
test_loss, test_auc, test_apr = prediction_binary(model, test_loader, loss_fn, device)

print("\n====== Best model test evaluation ======")
metric_label = "Val AUC" if monitor_metric == "val_auc" else "Train Loss"
if best_epoch == -1:
    # best_metric still holds its +/-inf sentinel; don't report it as a real value.
    print(f"No best {metric_label} recorded; test metrics below are for the final model.")
else:
    print(f"Best {metric_label} (used for checkpoint) = {best_metric:.4f} (epoch {best_epoch})")
print(f"Test Loss: {test_loss:.4f} | Test AUC: {test_auc:.4f} | Test APR: {test_apr:.4f}")
print("========================================")

In [None]:
# -------------------------
# Simple plots
# -------------------------
# x = 0 is the pre-training evaluation; x = k is the value after epoch k.

# Training vs. validation loss.
plt.figure(figsize=(8,4))
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Val Loss")
plt.xlabel("Epoch (including init)")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss curve")
plt.grid(True)
plt.show()

# Validation ranking metrics (both are unitless scores).
plt.figure(figsize=(8,4))
plt.plot(val_aucs, label="Val AUC", marker='x')
plt.plot(val_aprs, label="Val APR", marker='o')
plt.xlabel("Epoch (including init)")
plt.ylabel("Score")
plt.legend()
plt.title("Validation metrics")
plt.grid(True)
plt.show()