In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
%matplotlib inline
import os

# Choose which scaled-dot-product-attention (SDP) implementations PyTorch may use on CUDA:
# the memory-efficient and flash kernels are switched off and the plain "math"
# implementation is switched on.
# NOTE(review): the reason for forcing the math backend is not stated in this notebook — confirm.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

In [None]:
from model import diffusion as ModelTrain

### Define model parameters

In [None]:
# Dataset settings, handed to the trainer as `data_file`.
data_file = {
    "file_path": "@graph/@STgraph",
    "batch_size": 10,
    "max_data": 2000,
    "time_step": "any",
}

# Network hyper-parameters, handed to the trainer as `model_parameter`.
model_parameter = {
    "hid_features": 8,
    "num_node_features": 6,
    "num_output_features": 2,
    "num_heads": 8,
    "num_layers": 6,
    "dropout": 0.1,
    "kernel_size": 5,
}

# Optimisation settings, handed to the trainer as `train_parameter`.
train_parameter = {
    "epoch": 100,
    "lr": 0.001,
    "train_rate": 0.8,
}

In [None]:
# Build the trainer from the three configuration dictionaries and select the
# "SymmetricalResidualGAT" model type by name.
DeformTrainer = ModelTrain.network_diffusion_trainer(
    data_file=data_file,
    model_parameter=model_parameter,
    train_parameter=train_parameter,
    model_type="SymmetricalResidualGAT",
)

Convert a checkpoint saved with the old version of `GATConv`: copy each `...lin.(weight|bias)` entry to the `...lin_src.(weight|bias)` and `...lin_dst.(weight|bias)` entries required by the new version, and keep all other parameters unchanged.

import torch, re, os


# Checkpoint in the old parameter layout. The migration code that follows reads this
# file and writes the converted state dict next to it with a "_v2.pth" suffix.
save_model_path = "model/@model/20240510st01_SymmetricalResidualGAT.pth"


def convert_gat_state_dict(old_sd):
    """Translate a state dict from the old single-`lin` layout to `lin_src`/`lin_dst`.

    Every entry whose key ends in ``.lin.weight`` or ``.lin.bias`` is replaced by
    two entries, ``<prefix>.lin_src.<kind>`` and ``<prefix>.lin_dst.<kind>``, each
    holding its own clone of the original tensor. All other entries are carried
    over unchanged (same tensor object, no copy).

    Args:
        old_sd: Mapping from parameter name to tensor in the old layout.

    Returns:
        A new dict in the new layout; ``old_sd`` itself is not modified.
    """
    legacy_key = re.compile(r'\.lin\.(weight|bias)$')
    converted = {}

    for name, tensor in old_sd.items():
        hit = legacy_key.search(name)
        if hit is None:
            # Not a legacy projection parameter: pass it through untouched.
            converted[name] = tensor
            continue

        kind = hit.group(1)  # 'weight' or 'bias'
        prefix = name[: -len('.lin.' + kind)]
        # Source and destination projections both start from the old shared
        # values, but as independent copies.
        for branch in ('lin_src', 'lin_dst'):
            converted[f'{prefix}.{branch}.{kind}'] = tensor.clone()

    return converted

# One-off migration: if the old-layout checkpoint is present, convert it, load it
# into the freshly built model and save the converted copy as "<name>_v2.pth".
if os.path.exists(save_model_path):

    # Load on the CPU so the conversion does not depend on a GPU being available.
    raw_sd = torch.load(save_model_path, map_location='cpu')

    # Some checkpoints carry a leading 'module.' on their keys
    # (presumably saved from a wrapped/parallel model — confirm).
    # Strip ONLY that leading prefix: str.replace would also delete 'module.'
    # inside a name, e.g. 'module.my_module.weight' -> 'my_weight'.
    # The '' default keeps an empty checkpoint from raising StopIteration here.
    wrapper_prefix = 'module.'
    first_key = next(iter(raw_sd), '')
    if first_key.startswith(wrapper_prefix):
        raw_sd = {
            (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
            for k, v in raw_sd.items()
        }
    new_sd = convert_gat_state_dict(raw_sd)

    # strict=True: fail loudly if the converted keys do not match the model exactly.
    DeformTrainer.model.load_state_dict(new_sd, strict=True)
    torch.save(new_sd, save_model_path.replace('.pth', '_v2.pth'))


In [None]:
# Checkpoint used from here on. Load it when it exists; otherwise train from
# scratch and save the result to this path.
save_model_path = "model/@model/20240510st01_SymmetricalResidualGAT_v2.pth"
# Alternative: derive the file name from the trainer's model type.
#save_model_path = f"model/@model/20240510st01_{DeformTrainer.model_type}.pth"
if os.path.exists(save_model_path):
    print("Already found weights")
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
    # machine; load_state_dict copies the values into the model's own parameters.
    DeformTrainer.model.load_state_dict(torch.load(save_model_path, map_location="cpu"))
else:
    # Create the output directory up front so the final torch.save cannot fail
    # on a missing folder after training has already completed.
    checkpoint_dir = os.path.dirname(save_model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    # Record the configuration used for this run next to the checkpoint.
    DeformTrainer.write_model_parameter_into_txt(data_file,model_parameter,train_parameter,
                                                 file_path=save_model_path+".txt")
    print("Begin training")
    DeformTrainer.train()
    torch.save(DeformTrainer.model.state_dict(), save_model_path)

In [None]:
# Plot the train/test loss curves recorded by the trainer, if there are any.
# The history is looked up explicitly instead of relying on a caught
# AttributeError, so unrelated attribute errors are no longer silently hidden.
history = getattr(DeformTrainer, "history", None)
train_loss = history["train"] if history else []
test_loss = history["test"] if history else []

if len(train_loss) == 0 or len(test_loss) == 0:
    # No recorded losses (e.g. the weights were loaded from disk).
    print("Model load from local, not training")
else:
    # Use the recorded length for the x axis: it may differ from the configured
    # epoch count, and a mismatch would make plt.plot raise a ValueError.
    plt.plot(np.arange(len(train_loss)), train_loss, label="train")
    plt.plot(np.arange(len(test_loss)), test_loss, label="test")
    plt.title(f"Loss: {test_loss[-1]:.4e}")
    plt.xlabel("Epoch", fontsize=18)
    plt.ylabel("ds MSE loss", fontsize=18)
    plt.yscale("log")
    plt.legend()

In [None]:
# Run the model on one stored sample and build the arrays used by the plots below.
model = DeformTrainer.model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# The sample is a pickled object with `.x` / `.y` attributes, not a plain tensor or
# state dict, so it needs full unpickling: PyTorch >= 2.6 defaults to
# weights_only=True, which can reject such objects. Only load files you trust —
# full unpickling can execute arbitrary code.
# map_location="cpu" keeps the `.numpy()` calls below valid even if the sample
# was saved from a GPU.
# (`weights_only` exists since PyTorch 1.13; the SDP toggles at the top of this
# notebook already require a newer version.)
val_data = torch.load("@graph/@STgraph/0.pt", map_location="cpu", weights_only=False)

# The first two columns of x are used as the start coordinates and y as the
# target coordinates; their difference is the target displacement.
dxy_target = (val_data.y - val_data.x[:,:2]).numpy()
dxy_xy_target = val_data.y.numpy()
dxy_xy_start = val_data.x[:,:2].numpy()

val_data = val_data.to(device)
model.eval()  # inference mode for layers such as dropout
with torch.no_grad():
    dxy_predict = model(val_data)
    dxy_predict = dxy_predict.cpu().numpy()
# The model output is added to the start coordinates to get predicted coordinates.
dxy_xy_predict = dxy_predict + dxy_xy_start

In [None]:
# Compare predicted and reference displacements (position minus start position)
# in a single scatter plot.
predict_ds = dxy_xy_predict - dxy_xy_start
sample_end = val_data.y.cpu().numpy()
sample_start = val_data.x[:,:2].cpu().numpy()
simulat_ds = sample_end - sample_start

for series, tag in ((predict_ds, "Prediction"), (simulat_ds, "Simulation")):
    plt.scatter(series[:,0], series[:,1], label = tag, s = 3)

plt.axis("equal")
plt.legend()

In [None]:
# Side-by-side view of the start, target and predicted point positions.
fig, axs = plt.subplots(1, 3, figsize=(15, 5.2))

# One entry per panel, left to right: (points, legend label, colour).
panels = (
    (dxy_xy_start, "Start", "steelblue"),
    (dxy_xy_target, "Target", "orange"),
    (dxy_xy_predict, "Predict", "purple"),
)

for ax, (points, tag, colour) in zip(axs, panels):
    ax.scatter(points[:,0], points[:,1], label=tag, color=colour, s=2)
    ax.axis("equal")
    ax.legend()

plt.show()
