In [None]:
# ========================
# Standard library imports
# ========================
from datetime import datetime
import os
import pickle
import tempfile
from typing import Dict, List, Optional, Tuple
import warnings

# =========================
# Third-party library imports
# =========================
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import yaml

import flwr as fl
from flwr.common import (
    Context,
    Metrics,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.strategy import FedAvg, FedAvgM, FedOpt, FedProx

# =========================
# Local/project imports
# =========================
from configs.config_utils import AppConfig, load_config

# `data_trans` and `pytorch_ssim` are referenced by bare module name further down
# (e.g. `data_trans.prepare_initial_model`, `pytorch_ssim.SSIM`). Bind them explicitly
# here; the wildcard imports below run afterwards and still take precedence if they
# re-export the same names.
from scripts.data_utils import data_trans, pytorch_ssim
from scripts.data_utils.data_trans import (
    s_denormalize,
    s_normalize,
    s_normalize_none,
    v_denormalize,
    v_normalize,
)
import scripts.data_utils.pytorch_ssim  # module has side effects / functions

from scripts.diffusion_models.diffusion_model import *  # TODO: make explicit
from scripts.flwr.flwr_client import *                 # TODO: make explicit
from scripts.flwr.flwr_evaluation import get_evaluate_fn
from scripts.flwr.flwr_utils import *                  # TODO: make explicit
from scripts.pde_solvers.solver import FWIForward


In [None]:
### Load configuration & setup diffusion model ###
# Read the experiment configuration, force the 'spawn' multiprocessing start
# method, and pick the GPU when one is available (CPU otherwise).
config = load_config('configs/config_2clients.yml')
mp.set_start_method('spawn', force=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## setup forward solver ##
# Copy the forward-modelling settings out of the config into the plain dict
# that FWIForward takes as its first argument (key order kept as before).
forward_cfg = config.forward
ctx = {
    key: getattr(forward_cfg, key)
    for key in ('n_grid', 'nt', 'dx', 'nbc', 'dt', 'f', 'sz', 'gz', 'ng')
}

fwi_forward = FWIForward(
    ctx,
    device,
    normalize=True,
    v_denorm_func=v_denormalize,
    s_norm_func=s_normalize_none,
)

## setup diffusion model ##
# Same idea for the diffusion hyper-parameters: one dict, then split between
# the U-Net constructor and the GaussianDiffusion wrapper.
diffusion_cfg = config.diffusion
diffusion_args = {
    key: getattr(diffusion_cfg, key)
    for key in (
        'dim', 'dim_mults', 'flash_attn', 'channels',
        'image_size', 'timesteps', 'sampling_timesteps', 'objective',
    )
}

unet_model = Unet(
    dim=diffusion_args['dim'],
    dim_mults=diffusion_args['dim_mults'],
    flash_attn=diffusion_args['flash_attn'],
    channels=diffusion_args['channels'],
)

diffusion = GaussianDiffusion(
    unet_model,
    image_size=diffusion_args['image_size'],
    timesteps=diffusion_args['timesteps'],
    sampling_timesteps=diffusion_args['sampling_timesteps'],
    objective=diffusion_args['objective'],
)
diffusion = diffusion.to(device)

In [None]:
### hyper-parameters and configurations ###
client_data_list = []          # one seismic tensor per client, filled in below
final_parameters_store = {}    # handed to the server evaluate_fn; read back after the simulation
ssim_loss = pytorch_ssim.SSIM(window_size=11)
num_clients = config.experiment.num_clients
regularization = config.experiment.regularization
fed_rounds = config.federated.num_rounds
local_epochs = config.federated.local_epochs
local_lr = config.federated.local_lr

### Load & prepare data ###
velocity_data_path = config.path.velocity_data_path
client_seismic_data_path = config.path.client_seismic_data_path
gt_seismic_data_path = config.path.gt_seismic_data_path
model_path = config.path.model_path
output_path = config.path.output_path

# Load the pretrained diffusion weights. `map_location=device` remaps the stored
# tensors onto the device selected for this run, so a checkpoint written on a GPU
# machine still loads when the notebook falls back to CPU.
checkpoint = torch.load(model_path, weights_only=True, map_location=device)
# The checkpoint is either a dict wrapping the weights under 'model' or the
# state dict itself; `.get` with the checkpoint as default covers both.
state_dict = checkpoint.get('model', checkpoint)
diffusion.load_state_dict(state_dict)
diffusion.eval()
# Alias kept for the names used elsewhere in the notebook; it is the same object
# as `diffusion`, which is already in eval mode.
unwrapped_diffusion_model = diffusion
diffusion_state_dict = unwrapped_diffusion_model.state_dict()

# Per-client seismic data: client folders are 1-based (client1, client2, ...).
# `[0:1, :]` keeps only the first entry along axis 0 while preserving that axis.
for client_idx in range(num_clients):
    client_data = np.load(client_seismic_data_path + f"client{client_idx+1}/CF.npy")[0:1,:]
    client_data_list.append(torch.tensor(client_data).float().to(device))
for client_idx, client_tensor in enumerate(client_data_list):
    print(f"Data shape for Client {client_idx}: {client_tensor.shape}")

# Ground-truth seismic data (moved to `device`) and the velocity model
# (left on CPU here), each sliced the same way as the client data.
gt_seismic_data = torch.tensor(np.load(gt_seismic_data_path + "CF.npy")[0:1,:]).float().to(device)
vm_data = torch.tensor(np.load(velocity_data_path + "CF.npy")[0:1,:]).float()

# Starting model for the inversion, derived from the velocity model with the
# 'smoothed' option, then zero-padded by one cell on each side of the last two dims.
initial_model = data_trans.prepare_initial_model(vm_data, initial_type='smoothed',
                                                 sigma=config.forward.initial_sigma)
initial_model = F.pad(initial_model, (1, 1, 1, 1), "constant", 0)


In [None]:
### setup server functions and strategy ###
# The server-side evaluation uses the diffusion model loaded above. Bind it
# unconditionally so the name always exists when `get_evaluate_fn` is called.
server_diffusion_model = diffusion

evaluate_fn = get_evaluate_fn(
    model_shape=initial_model.shape,
    seismic_data=gt_seismic_data,
    mu_true=vm_data,
    fwi_forward=fwi_forward,
    data_trans=data_trans,
    ssim_loss=ssim_loss,
    device=device,
    diffusion_model=server_diffusion_model,
    total_rounds=config.federated.num_rounds,
    final_params_store=final_parameters_store,
    config = config
)

### setup strategy ###
strategy_classes = {
    "FedAvg": FedAvg,
    "FedProx": FedProx,
    "FedAvgM": FedAvgM,
    "FedOpt": FedOpt
}

strategy_name = config.experiment.strategy
if strategy_name not in strategy_classes:
    raise ValueError(
        f"Unknown strategy {strategy_name!r}; expected one of {sorted(strategy_classes)}"
    )
strategy_class = strategy_classes[strategy_name]

# Only pass strategy-specific keyword arguments to the strategy that accepts them:
# `server_momentum` belongs to FedAvgM only, and handing it to FedAvg / FedOpt /
# FedProx raises a TypeError in the constructor.
strategy_params = {}
if strategy_name == "FedAvgM":
    strategy_params["server_momentum"] = config.experiment.server_momentum
elif strategy_name == "FedProx":
    # FedProx requires `proximal_mu`.
    # NOTE(review): assumes the config exposes `experiment.proximal_mu` — confirm the field name.
    proximal_mu = getattr(config.experiment, "proximal_mu", None)
    if proximal_mu is None:
        raise ValueError("FedProx requires `proximal_mu`, but config.experiment does not provide it")
    strategy_params["proximal_mu"] = proximal_mu

strategy = strategy_class(
    fraction_fit=1.0,
    min_fit_clients=config.experiment.num_clients,
    min_available_clients=config.experiment.num_clients,
    evaluate_fn=evaluate_fn,
    fraction_evaluate=0.0,  # no client-side evaluation; evaluation goes through `evaluate_fn`
    on_fit_config_fn=fit_config_fn(config = config),
    initial_parameters=ndarrays_to_parameters(tensor_to_ndarrays(initial_model.to(device))),
    **strategy_params
)

### setup client function ###
client_fn_instance = client_fn_factory(client_data_list)
# On GPU each simulated client is given half a GPU; on CPU, two cores.
client_resources = {"num_cpus": 1, "num_gpus": 0.5} if device.type == 'cuda' else {"num_cpus": 2}
ray_init_args = {"include_dashboard": False}

### start federated learning simulation ###
history = fl.simulation.start_simulation(
    client_fn=client_fn_instance,
    num_clients=num_clients,
    config=fl.server.ServerConfig(num_rounds=fed_rounds),
    strategy=strategy,
    client_resources=client_resources,
    ray_init_args=ray_init_args
)

# The final model is read back from the store that was passed to `get_evaluate_fn`.
# Fail with an explicit message if nothing was recorded under "final_model".
if "final_model" not in final_parameters_store:
    raise RuntimeError(
        "Simulation finished but final_parameters_store has no 'final_model' entry"
    )
saved_ndarrays = final_parameters_store["final_model"]
final_model = ndarrays_to_tensor(saved_ndarrays, device)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
strategy_results = {
    'final_model': final_model.cpu().detach().numpy(),
    'metrics_history': history.metrics_centralized,
    'losses_history': history.losses_centralized,
    'config': config
}

filename = f"results/{config.experiment.strategy}_{config.experiment.regularization}_{config.experiment.scenario_flag}_{timestamp}.pkl"
# Create the output directory first so the results of the run are not lost to a
# missing-directory error at the very last step.
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, 'wb') as results_file:
    pickle.dump(strategy_results, results_file)

(raylet) [2025-08-05 15:58:02,025 E 41741 41775] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2025-08-05_13-23-30_316948_41493 is over 95% full, available space: 26610892800; capacity: 982821515264. Object creation will fail if spilling is required.
