In [None]:
import ml_collections
import numpy as np
import yaml
import os
import torch
import scipy.io
import random
from src2.utilities import *
from src2.utilities_NSPDE import dataloader_nspde_2d, train_nspde
from src2.fusion_model import Fusion_NSPDE
from torch.utils.data import TensorDataset, DataLoader

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')). History will not be written to the database.


# Set Up

In [None]:
# Read the experiment configuration from YAML into an attribute-style ConfigDict.
config_dir = 'configs/example.yaml'
with open(config_dir, encoding='utf-8') as file:
    config = ml_collections.ConfigDict(yaml.safe_load(file))

# Set random seed for every RNG this notebook imports (torch CPU/CUDA, NumPy, random).
seed = config.seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
# Ask cuDNN for deterministic algorithms and switch off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Resolve the device: use the first GPU only when the config asks for "cuda"
# and CUDA is actually available; otherwise fall back to the CPU.
# Plain attribute assignment is used on purpose: ConfigDict.update() treats
# keyword arguments as config entries, so the former `allow_val_change=True`
# (a wandb option) was stored as a stray `allow_val_change` field.
use_cuda = config.device == "cuda" and torch.cuda.is_available()
config.device = "cuda:0" if use_cuda else "cpu"
DEVICE = torch.device(config.device)

# Load and Process Data


In [3]:
# Sizes the rest of the notebook relies on.
BATCH_SIZE = 3
T_TRAIN_STEPS = 201
N_TOTAL_SAMPLES = 1500

# Check training data: load the MATLAB file and convert the noise ('W') and
# solution ('sol') arrays to float32 tensors.
data_training = scipy.io.loadmat(r"data/public_data.mat")
W_raw = torch.tensor(data_training['W']).float()
Sol_raw = torch.tensor(data_training['sol']).float()
print('Raw Noise W shape: ', W_raw.shape)
print('Raw Solution Sol shape: ', Sol_raw.shape)

# Fail fast if the file disagrees with the constants above; later cells index
# samples along the first axis and take the initial condition from the last one.
assert W_raw.shape == Sol_raw.shape, (
    f"W {tuple(W_raw.shape)} and sol {tuple(Sol_raw.shape)} must have the same shape"
)
assert W_raw.shape[0] == N_TOTAL_SAMPLES, (
    f"expected {N_TOTAL_SAMPLES} samples, found {W_raw.shape[0]}"
)
assert W_raw.shape[-1] == T_TRAIN_STEPS, (
    f"expected {T_TRAIN_STEPS} time steps, found {W_raw.shape[-1]}"
)

Raw Noise W shape:  torch.Size([1500, 32, 32, 201])
Raw Solution Sol shape:  torch.Size([1500, 32, 32, 201])


In [None]:
# Train / validation / test split sizes and the seed that fixes the split.
N_TOTAL_SAMPLES = 1500
N_TRAIN = 1300
N_VAL = 100
N_TEST = 100
SEED_SPLIT = 42

# The three subsets must partition the dataset exactly.
assert N_TRAIN + N_VAL + N_TEST == N_TOTAL_SAMPLES, "split sizes must sum to N_TOTAL_SAMPLES"
assert W_raw.shape[0] == Sol_raw.shape[0] == N_TOTAL_SAMPLES, "data size does not match N_TOTAL_SAMPLES"

# Draw the permutation from a dedicated generator instead of reseeding the
# global RNG: torch.manual_seed(SEED_SPLIT) would override the seed taken
# from the config for everything that runs after this cell.
split_rng = torch.Generator().manual_seed(SEED_SPLIT)
indices = torch.randperm(N_TOTAL_SAMPLES, generator=split_rng)

# Consecutive, non-overlapping slices of the shuffled sample indices.
idx_train = indices[:N_TRAIN]
idx_val = indices[N_TRAIN:N_TRAIN + N_VAL]
idx_test = indices[N_TRAIN + N_VAL:N_TRAIN + N_VAL + N_TEST]

W_train_raw = W_raw[idx_train]
Sol_train_raw = Sol_raw[idx_train]

W_val_raw = W_raw[idx_val]
Sol_val_raw = Sol_raw[idx_val]

W_test_raw = W_raw[idx_test]
Sol_test_raw = Sol_raw[idx_test]

In [None]:
def create_dataloader(W, Sol, batch_size, shuffle):
    """Wrap noise and solution tensors in a DataLoader of (u0, xi, u) triples.

    Args:
        W: noise tensor; served unchanged as the second element of each batch.
        Sol: solution tensor with the same leading (sample) dimension as W;
            served unchanged as the third element (the label).
        batch_size: number of samples per batch.
        shuffle: whether the loader reshuffles samples every epoch.

    Returns:
        A single-process DataLoader whose batches are
        (Sol[..., 0], W, Sol) restricted to the batch's samples.
    """
    # The initial condition is the solution at the first index of the last axis.
    initial_state = Sol[..., 0]
    dataset = TensorDataset(initial_state, W, Sol)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )

In [None]:
# One loader per split; only the training loader reshuffles its samples.
train_loader = create_dataloader(
    W=W_train_raw, Sol=Sol_train_raw, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = create_dataloader(
    W=W_val_raw, Sol=Sol_val_raw, batch_size=BATCH_SIZE, shuffle=False
)
test_loader = create_dataloader(
    W=W_test_raw, Sol=Sol_test_raw, batch_size=BATCH_SIZE, shuffle=False
)

# Reports the configured split sizes (not measured from the loaders).
print(f"Train/Val/Test Dataloaders created: {N_TRAIN}/{N_VAL}/{N_TEST} samples.")

Train/Val/Test Dataloaders created: 1300/100/100 samples.


# Initialize Model

In [None]:
# === Model HYBRID (DLR + NSPDE) ===

# Architecture settings forwarded to Fusion_NSPDE below.
HIDDEN_CHANNELS = 32
MODES_X = 16
MODES_Y = 16
N_ITER_SOLVER = 1
SOLVER_MODE = 'diffeq'

# 1. Build the space-time grid handed to the model.
# The number of grid points is taken from the loaded data (time steps from
# T_TRAIN_STEPS, spatial sizes from W_raw) rather than repeated as literals,
# so the grid cannot drift away from the tensors served by the dataloaders.
T_FINAL = 0.020
N_X, N_Y = W_raw.shape[1], W_raw.shape[2]
T_points = torch.linspace(0, T_FINAL, T_TRAIN_STEPS).to(DEVICE)
X_points = torch.linspace(0, 1, N_X).to(DEVICE)
Y_points = torch.linspace(0, 1, N_Y).to(DEVICE)

# 2. Initialize the Model with the DLR Encoder
model = Fusion_NSPDE(
    dim=2,
    in_channels=1,
    noise_channels=1,
    hidden_channels=HIDDEN_CHANNELS,
    n_iter=N_ITER_SOLVER,
    modes1=MODES_X,
    modes2=MODES_Y,
    solver=SOLVER_MODE,
    T_points=T_points,
    X_points=X_points,
    Y_points=Y_points,
    device=DEVICE
).to(DEVICE)

print(model)

print(f"Hybrid Model initialized using {SOLVER_MODE}.")
# `use_dlr` may be absent on the model, hence the default of False.
if getattr(model, 'use_dlr', False):
    print(f"DLR Encoder is ACTIVE. Context channels: {model.context_channels}")
else:
    print("WARNING: DLR Encoder is NOT active.")

Pure U0 Indices: [1, 3, 4]
Xi/Mixed Indices: [0, 2, 5, 6, 7, 8, 9]
DLR Split Mode ACTIVE. U_feat: 6, Xi_feat: 15
Fusion_NSPDE(
  (dlr_encoder): LearnableDLREncoder(
    (physics_engine): ParabolicIntegrate_2d()
    (mlp): Sequential(
      (0): Linear(in_features=10, out_features=32, bias=True)
      (1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Linear(in_features=32, out_features=1, bias=True)
      (4): Tanh()
    )
  )
  (norm_u): InstanceNorm3d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (norm_xi): InstanceNorm3d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (lift): Linear(in_features=1, out_features=32, bias=True)
  (spde_func): SPDEFunc1d(
    (net_F): Sequential(
      (0): Conv2d(38, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): GroupNorm(4, 128, eps=1e-05, affine=True)
      (2): GLU(dim=1)
      (3): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    )
    (ne

# Load Checkpoint

In [None]:
# NOTE(review): the recorded run of this cell found the checkpoint under
# `src2/` (the package every project import uses), while the code pointed at
# `src/`, which silently falls through to the train-from-scratch branch.
CHECKPOINT_PATH = r'src2/best_fusion_model_trained.pth'
# How many skipped entries are listed before the report is truncated.
MAX_SKIPPED_SHOWN = 5


def load_compatible_weights(model, state_dict):
    """Copy into `model` every checkpoint tensor whose name and shape match.

    Args:
        model: module to update in place.
        state_dict: mapping from parameter/buffer name to tensor, as read
            from a checkpoint.

    Returns:
        (keys_loaded, keys_skipped, keys_missing) where
        keys_loaded lists the checkpoint keys copied into the model,
        keys_skipped lists human-readable reasons for ignored checkpoint keys
        (unknown name or shape mismatch), and
        keys_missing lists model keys that received no value from the
        checkpoint and therefore keep their current values.
    """
    model_dict = model.state_dict()

    pretrained_dict = {}
    keys_loaded = []
    keys_skipped = []

    for k, v in state_dict.items():
        if k not in model_dict:
            keys_skipped.append(f"{k} (Not found in current model)")
        elif v.shape != model_dict[k].shape:
            keys_skipped.append(f"{k} (Shape mismatch: ckpt {v.shape} vs model {model_dict[k].shape})")
        else:
            pretrained_dict[k] = v
            keys_loaded.append(k)

    keys_missing = [k for k in model_dict if k not in pretrained_dict]

    # Overlay the compatible tensors on the model's own state and load the
    # merged dictionary, so incompatible entries keep their current values.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)

    return keys_loaded, keys_skipped, keys_missing


if os.path.exists(CHECKPOINT_PATH):
    print(f"Found checkpoint: {CHECKPOINT_PATH}")
    # SECURITY: torch.load unpickles the file; only open checkpoints from a
    # trusted source (or pass weights_only=True where the installed PyTorch
    # supports it).
    state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    keys_loaded, keys_skipped, keys_missing = load_compatible_weights(model, state_dict)

    print(f"\nLoading Status:")
    print(f"✅ Loaded {len(keys_loaded)} layers successfully.")
    if len(keys_skipped) > 0:
        print(f"⚠️ Skipped {len(keys_skipped)} layers:")
        for msg in keys_skipped[:MAX_SKIPPED_SHOWN]:
            print(f"   - {msg}")
        n_hidden = len(keys_skipped) - MAX_SKIPPED_SHOWN
        if n_hidden > 0:
            print(f"   ... and {n_hidden} more")
    if keys_missing:
        print(f"⚠️ {len(keys_missing)} model tensors were not in the checkpoint and keep their current values.")

    print("\nThe model is ready for further training (fine-tuning).")
else:
    print(f"No checkpoint found at {CHECKPOINT_PATH}. Training will start from scratch.")

Found checkpoint: src2/best_fusion_model_trained.pth

Loading Status:
✅ Loaded 58 layers successfully.

The model is ready for further training (fine-tuning).


# Training

In [None]:
# Optimisation settings for this fine-tuning run.
LEARNING_RATE = 1e-4
EPOCHS = 20
PRINT_EVERY = 1
# NOTE(review): CHECKPOINT_FILE was passed to train_nspde below but never
# defined anywhere in this notebook, so a fresh kernel raised NameError here.
# It is written to its own file so the pretrained checkpoint is not
# overwritten; point it at the pretrained path instead if in-place resuming
# is what you want. Presumably train_nspde saves its best weights here — TODO confirm.
CHECKPOINT_FILE = r'src2/best_fusion_model_finetuned.pth'

# Loss object handed to the training loop (LpLoss comes from the project utilities).
myloss = LpLoss(size_average=False)

print("\nStarting model training…")

# The validation loader is what train_nspde receives as its `test_loader`;
# the held-out test split is not touched during training.
model_trained, losses_train, losses_val = train_nspde(
    model=model,
    train_loader=train_loader,
    test_loader=val_loader,
    device=DEVICE,
    myloss=myloss,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    plateau_patience=2,
    plateau_terminate=20,
    scheduler_gamma=0.5,
    print_every=PRINT_EVERY,
    checkpoint_file=CHECKPOINT_FILE
)

print("Training completed. Loading the best model")


Starting model training…
Training interrupted explicitly.
Training completed. Loading the best model
