In [2]:
"""
import torch, numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn import MSELoss
from model import FKAN 

psi = torch.load("psi_T1.pt") 
Y = np.load("Y_T1.npy")
Y_mean = np.load("Y_mean.npy")
Y_std  = np.load("Y_std.npy")
Yn = (Y - Y_mean) / (Y_std + 1e-6)
Yt = torch.tensor(Yn, dtype=torch.float32)

ds = TensorDataset(psi, Yt)
train_len = int(0.8 * len(ds))
val_len = len(ds) - train_len
train_ds, val_ds = random_split(ds, [train_len, val_len])

train_loader = DataLoader(train_ds, batch_size=1024, shuffle=True, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=1024, shuffle=False, pin_memory=True)

device= "cuda" if torch.cuda.is_available() else "cpu"
in_dim  = psi.shape[1]
out_dim = Yt.shape[1]
model = FKAN(in_dim, hid_feats=64, out_feats=out_dim, G=5).to(device)

opt  = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)
crit = MSELoss()
best, wait, patience = 1e9, 0, 5

for epoch in range(1, 101):
    model.train(); tr_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        loss = crit(model(xb), yb)
        loss.backward()
        opt.step()
        tr_loss += loss.item() * xb.size(0)
    tr_loss /= len(train_loader.dataset)

    model.eval(); va_loss = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            va_loss += crit(model(xb), yb).item() * xb.size(0)
    va_loss /= len(val_loader.dataset)

    print(f"[{epoch:03d}] train={tr_loss:.4f}  val={va_loss:.4f}")

    if va_loss < best - 1e-4:
        best, wait = va_loss, 0
        torch.save(model.state_dict(), "best_sugarnet.pt")
        print("  ✔️  saved best_sugarnet.pt")
    else:
        wait += 1
        if wait >= patience:
            print("⏹️  Early stop"); break

print("BEST val MSE:", best)
"""



'\nimport torch, numpy as np\nfrom torch.utils.data import DataLoader, TensorDataset, random_split\nfrom torch.nn import MSELoss\nfrom model import FKAN \n\npsi = torch.load("psi_T1.pt") \nY = np.load("Y_T1.npy")\nY_mean = np.load("Y_mean.npy")\nY_std  = np.load("Y_std.npy")\nYn = (Y - Y_mean) / (Y_std + 1e-6)\nYt = torch.tensor(Yn, dtype=torch.float32)\n\nds = TensorDataset(psi, Yt)\ntrain_len = int(0.8 * len(ds))\nval_len = len(ds) - train_len\ntrain_ds, val_ds = random_split(ds, [train_len, val_len])\n\ntrain_loader = DataLoader(train_ds, batch_size=1024, shuffle=True, pin_memory=True)\nval_loader   = DataLoader(val_ds,   batch_size=1024, shuffle=False, pin_memory=True)\n\ndevice= "cuda" if torch.cuda.is_available() else "cpu"\nin_dim  = psi.shape[1]\nout_dim = Yt.shape[1]\nmodel = FKAN(in_dim, hid_feats=64, out_feats=out_dim, G=5).to(device)\n\nopt  = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)\ncrit = MSELoss()\nbest, wait, patience = 1e9, 0, 5\n\nfor epoch i

In [10]:
import torch
import numpy as np
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn import SmoothL1Loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from model import FKAN

T = 288
LOW_K = int(T * 0.10)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(2025)
np.random.seed(2025)

def vectorize(gamma: torch.Tensor) -> torch.Tensor:
    B, K = gamma.shape
    k = (K - 1) // 2
    R0 = gamma[:, :1]
    R = gamma[:, 1 : k+1]
    I = gamma[:, k+1 : ]
    half = torch.complex(R, I)
    conj_half = torch.conj(torch.flip(half[:, 1:], dims=[1]))
    full_spec = torch.cat([R0, half, conj_half], dim=1)
    return torch.fft.ifft(full_spec, n=T).real

psi = torch.load("psi_T1.pt")
Y = np.load("Y_T1.npy")
Y_mean = np.load("Y_mean.npy")
Y_std = np.load("Y_std.npy")
Yn = (Y - Y_mean) / (Y_std + 1e-6)
Yt = torch.tensor(Yn, dtype=torch.float32)
target_dim = Yt.shape[1]

ds = TensorDataset(psi, Yt)
n_tr = int(0.8 * len(ds))
train_ds, val_ds = random_split(ds, [n_tr, len(ds)-n_tr])
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, pin_memory=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, pin_memory=True, num_workers=2)

in_feats = psi.shape[1]
freq_feats = 1 + 2 * LOW_K
model = FKAN(in_feats, hid_feats=128, out_feats=freq_feats, G=5).to(DEVICE)

opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
sched = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)
crit = SmoothL1Loss()

best_val = float('inf')
wait = 0
patience = 5

for epoch in range(1, 51):
    model.train()
    tr_loss = 0.0
    for x, y in train_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        opt.zero_grad()
        gamma_pred = model(x)
        y_pred = vectorize(gamma_pred)[:, :target_dim]
        loss = crit(y_pred, y)
        loss.backward()
        opt.step()
        tr_loss += loss.item() * x.size(0)
    tr_loss /= len(train_ds)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            gamma_pred = model(x)
            y_pred = vectorize(gamma_pred)[:, :target_dim]
            val_loss += crit(y_pred, y).item() * x.size(0)
    val_loss /= len(val_ds)
    sched.step(val_loss)
    print(f"[{epoch:02d}] train_loss={tr_loss:.4f}  val_loss={val_loss:.4f}")
    if val_loss < best_val - 1e-4:
        best_val = val_loss
        wait = 0
        torch.save(model.state_dict(), "best_fkan.pt")
        print("Saved best_fkan.pt")
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

print(f"Best validation loss: {best_val:.4f}")


[01] train_loss=0.3588  val_loss=0.3628
Saved best_fkan.pt
[02] train_loss=0.3585  val_loss=0.3624
Saved best_fkan.pt
[03] train_loss=0.3578  val_loss=0.3615
Saved best_fkan.pt
[04] train_loss=0.3565  val_loss=0.3595
Saved best_fkan.pt
[05] train_loss=0.3547  val_loss=0.3580
Saved best_fkan.pt
[06] train_loss=0.3530  val_loss=0.3563
Saved best_fkan.pt
[07] train_loss=0.3511  val_loss=0.3533
Saved best_fkan.pt
[08] train_loss=0.3493  val_loss=0.3529
Saved best_fkan.pt
[09] train_loss=0.3477  val_loss=0.3507
Saved best_fkan.pt
[10] train_loss=0.3455  val_loss=0.3476
Saved best_fkan.pt
[11] train_loss=0.3430  val_loss=0.3466
Saved best_fkan.pt
[12] train_loss=0.3410  val_loss=0.3461
Saved best_fkan.pt
[13] train_loss=0.3390  val_loss=0.3439
Saved best_fkan.pt
[14] train_loss=0.3367  val_loss=0.3407
Saved best_fkan.pt
[15] train_loss=0.3352  val_loss=0.3411
[16] train_loss=0.3328  val_loss=0.3388
Saved best_fkan.pt
[17] train_loss=0.3306  val_loss=0.3372
Saved best_fkan.pt
[18] train_loss=