In [None]:
import os

# Restrict this process to GPU 1 by default, but respect a device selection that
# was already made outside the notebook (shell export, job scheduler, ...).
# This runs in its own cell ahead of `import torch` so the setting is in place
# before any CUDA initialisation can happen.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

In [None]:
import torch
import wandb
from neuralop import H1Loss, LpLoss, Trainer
from neuralop.models import FNO
from neuralop.training import AdamW
from src import dataset

In [None]:
# ================================================================
# 0. Setup
# ================================================================
# Fall back to the CPU unless a CUDA device is actually usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Bare expression: shown as the cell output.
device

In [None]:
# ================================================================
# 1. W&B Setup
# ================================================================
# Credentials are deliberately NOT set here: an API key must never be committed
# with the notebook, and assigning a literal to WANDB_API_KEY would also override
# valid credentials that are already configured. Provide the key from outside
# instead, e.g. `export WANDB_API_KEY=...` before starting Jupyter, or run
# `wandb login` once on this machine.
os.environ["WANDB_PROJECT"] = "grainlegumes_pino"
os.environ["WANDB_ENTITY"] = "Rinovative-Hub"

In [None]:
# ================================================================
# 2. Dataset
# ================================================================
# Loader settings; forwarded unchanged to `create_dataloaders` via ** below and
# reused later for the W&B run config.
dataloader_cfg = dict(
    batch_size=32,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
)

# NOTE(review): the training set and the OOD test set are read from the same
# file; presumably the OOD subset is carved out via `ood_fraction` -- confirm
# against the implementation of `create_dataloaders`.
SIMULATION_FILE = "../data/raw/lhs_var10_plog100_seed9/lhs_var10_plog100_seed9.pt"

train_loader, test_loaders, data_processor = dataset.dataset_base.create_dataloaders(
    dataset_cls=dataset.dataset_simulation.PermeabilityFlowDataset,
    path_train=SIMULATION_FILE,
    path_test_ood=SIMULATION_FILE,
    train_ratio=0.8,
    ood_fraction=0.2,
    **dataloader_cfg,
)

# ================================================================
# Debug info: split sizes and statistics of the first training batch
# ================================================================
print("\n=== Dataset Debug Info ===")
n_train = len(train_loader.dataset)
print(f"Train loader size: {n_train} samples")
n_eval = len(test_loaders["eval"].dataset)
print(f"Eval loader size:  {n_eval} samples")
n_ood = len(test_loaders["ood"].dataset)
print(f"OOD loader size:   {n_ood} samples")

# Pull a single batch; batches are dicts carrying the input under "x" and the
# target under "y". `x` and `y` are reused by the checks further down.
batch = next(iter(train_loader))
x = batch["x"]
y = batch["y"]

print("\n--- First Batch ---")
print(f"x shape: {x.shape}, y shape: {y.shape}")
x_mean, x_std = x.mean(), x.std()
print(f"Global x mean/std: {x_mean:.4e} / {x_std:.4e}")
y_mean, y_std = y.mean(), y.std()
print(f"Global y mean/std: {y_mean:.4e} / {y_std:.4e}")

# ================================================================
# Channel-wise statistics
# ================================================================
# Axis 1 is treated as the channel axis for both inputs and targets.
print("\n--- Channel-wise Stats (x) ---")
n_in_channels = x.shape[1]
for ch in range(n_in_channels):
    x_channel = x[:, ch]
    print(f"x[{ch}]: mean={x_channel.mean():.4e}, std={x_channel.std():.4e}")

print("\n--- Channel-wise Stats (y) ---")
n_out_channels = y.shape[1]
for ch in range(n_out_channels):
    y_channel = y[:, ch]
    print(f"y[{ch}]: mean={y_channel.mean():.4e}, std={y_channel.std():.4e}")

# ================================================================
# Normalizer values
# ================================================================
# Print the statistics stored in the data processor's input and output
# normalizers, flattened to 1-D for readability.
print("\n--- Normalizer means/stds ---")
in_norm = data_processor.in_normalizer
print("Input means:", in_norm.mean.flatten())
print("Input stds: ", in_norm.std.flatten())
out_norm = data_processor.out_normalizer
print("Output means:", out_norm.mean.flatten())
print("Output stds: ", out_norm.std.flatten())

# ================================================================
# Extra check 1: per-channel value range (min/max) to verify the
# channel assignment
# ================================================================
print("\n--- Field ranges per channel (x) ---")
for ch in range(x.shape[1]):
    x_channel = x[:, ch]
    cmin = x_channel.min().item()
    cmax = x_channel.max().item()
    print(f"x[{ch}] range: min={cmin:.4f}, max={cmax:.4f}")

# ================================================================
# Extra check 2: is the normalisation invertible?
# ================================================================
# Normalise the batch with the input normalizer, undo it by hand, and report
# the mean absolute reconstruction error.
x_cpu = x.cpu()
in_normalizer = data_processor.in_normalizer
x_norm = in_normalizer(x_cpu)

mean, std = in_normalizer.mean, in_normalizer.std

# Manual inverse of the normalisation.
# NOTE(review): this assumes the forward pass is exactly (x - mean) / std; if the
# normalizer adds an epsilon to std, a small non-zero residual is expected.
x_back = mean + std * x_norm

abs_diff = (x_cpu - x_back).abs()
err = abs_diff.mean().item()
print(f"\nInvertibility check error: {err:.3e}")

# ================================================================
# Extra check 3: does the loader really normalise?
# ================================================================
# Compare the first item taken straight from the dataset with the first item of
# the batch that came out of the DataLoader.
# NOTE(review): whether the loader output is normalised is exactly what this
# check probes, so neither sample is known to be raw/normalised up front. The two
# items are also only the same sample if the loader does not shuffle -- confirm.
first_item = next(iter(train_loader.dataset))
raw_sample = first_item["x"]
norm_sample = x[0]

print("\n--- Raw vs Normalized sample check ---")
raw_mean, raw_std = raw_sample.mean(), raw_sample.std()
print(f"Raw mean/std:    {raw_mean:.4f} / {raw_std:.4f}")
normed_mean, normed_std = norm_sample.mean(), norm_sample.std()
print(f"Normed mean/std: {normed_mean:.4f} / {normed_std:.4f}")

print("=========================================================\n")

# Move the data processor to the training device only after the CPU-side checks
# above; it is handed to the Trainer later.
data_processor = data_processor.to(device)

In [None]:
# ================================================================
# 3. Model
# ================================================================
# 2-D FNO (two entries in n_modes), built and moved to the training device in
# one go.
# NOTE(review): the channel counts are hard-coded to 4 -- presumably matching the
# x/y channel dimensions printed in the dataset cell; confirm.
model = FNO(
    in_channels=4,
    out_channels=4,
    hidden_channels=64,
    n_modes=(32, 32),
).to(device)

In [None]:
# ================================================================
# 4. Optimizer, Scheduler, Loss
# ================================================================
optimizer = AdamW(model.parameters(), lr=1e-2, weight_decay=1e-4)

# NOTE(review): T_max=30 is far shorter than the 500 epochs configured in the
# W&B cell below, so the cosine schedule will cycle many times during training --
# confirm that this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

# H1 is the training objective; H1 and L2 are both reported during evaluation.
h1loss = H1Loss(d=2)
l2loss = LpLoss(d=2, p=2)
train_loss = h1loss
eval_losses = dict(h1=h1loss, l2=l2loss)

In [None]:
# ================================================================
# 5. W&B Config & Init
# ================================================================
N_EPOCHS = 500

# Hyperparameters recorded with the run. Values are read back from the objects
# created in earlier cells so the logged config cannot drift from what is used.
optim_group = optimizer.param_groups[0]
config = dict(
    model="FNO",
    dataset="PermeabilityFlow",
    batch_size=dataloader_cfg["batch_size"],
    num_workers=dataloader_cfg["num_workers"],
    lr=optim_group["lr"],
    weight_decay=optim_group["weight_decay"],
    n_epochs=N_EPOCHS,
)

wandb.init(project="grainlegumes_pino", entity="Rinovative-Hub", config=config)

In [None]:
# ================================================================
# 6. Trainer
# ================================================================
# Collect the trainer settings in one place, then build the Trainer from them.
trainer_kwargs = dict(
    model=model,
    n_epochs=N_EPOCHS,
    wandb_log=True,
    device=device,
    mixed_precision=False,
    data_processor=data_processor,
    eval_interval=5,
    verbose=True,
)
trainer = Trainer(**trainer_kwargs)

In [None]:
# ================================================================
# 7. Training
# ================================================================
# Run training and always close the W&B run afterwards. Without the guard, an
# exception or a manual interrupt of the cell would skip `wandb.finish()` and
# leave the run open.
try:
    trainer.train(
        train_loader=train_loader,
        test_loaders=test_loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        training_loss=train_loss,
        eval_losses=eval_losses,
        save_best="eval_l2",  # type: ignore[arg-type]
        save_dir="../data/processed/model/test",
    )
except BaseException:
    # BaseException so that KeyboardInterrupt (stopping the cell) is covered too.
    # Mark the run as failed, then let the original error propagate.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()