In [None]:
# Import required packages.
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# Record the installed PyTorch version in the TORCH environment variable and echo it.
torch_version = torch.__version__
os.environ['TORCH'] = torch_version
print(torch_version)

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Plotting and array libraries used for visualization.
import matplotlib.pyplot as plt
import numpy as np


# Make the project root importable. The root is taken to be the parent of the
# current working directory (resolved to an absolute path).
project_root = os.path.abspath(os.pardir)
# Guard the append so that re-running this cell does not stack duplicate
# entries on sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

from models.utils import *
from models.data_utils.transform import *
from models.data_utils.dataset import FoamDataset
from models.model import GraphTransformer


In [None]:
# Example usage: build the foam dataset from files on disk.
# The dataset root can be overridden with the FOAM_DATA_ROOT environment
# variable; when it is unset, the local default path below is used.
base_path = os.environ.get("FOAM_DATA_ROOT", "D:\\Github\\datasets\\")
foam_names = ["foam0", "foam1", "foam2", "foam3", "foam4", "foam5"]
file_num = 14  # passed through to FoamDataset; presumably files per foam — TODO confirm
sizes = [64, 64]  # also used later as the grid size for get_gridX and image reshaping

# No pre-transform is applied when the dataset is built.
pre_transform = None

rawdata = FoamDataset(
    root=base_path,
    foam_names=foam_names,
    file_num=file_num,
    sizes=sizes,
    pre_transform=pre_transform,
)
# Show which device the first graph's node features live on.
print(rawdata[0].x.device)

In [None]:
# Single configuration dict shared by the model, transform, loader and optimizer cells.
hyperparameters = dict(
    # --- model architecture ---
    # Feature width is read from the second dimension of the first graph's `x`.
    num_features=rawdata[0].x.size(1),
    gnn_dim=128,
    gnn_layers=3,
    transformer_width=256,
    transformer_layers=4,
    transformer_heads=8,
    hidden_dim=512,
    output_dim=256,
    embedding_dim=32,
    # Position width is read from the second dimension of the first graph's `pos`.
    pos_dim=rawdata[0].pos.size(1),
    dropout=0.1,
    gnn_type='GCNConv',
    pool='cls',
    patch_rw_dim=16,
    # --- training / data ---
    # NOTE(review): 'num_epochs', 'gamma' and 'step_size' are not read by any
    # cell visible in this notebook section — confirm whether they are still needed.
    num_epochs=3000,
    lr=1e-4,
    gamma=0.9,
    step_size=500,
    batch_size=16,
    num_hops=1,
    n_patches=128,
)

In [None]:
# Graph-partition transform; patch count, hop count and patch random-walk
# dimension come from `hyperparameters`, the remaining options are fixed here.
transform = GraphPartitionTransform(
    n_patches=hyperparameters['n_patches'],
    metis=True,
    drop_rate=0.0,
    num_hops=hyperparameters['num_hops'],
    is_directed=False,
    patch_rw_dim=hyperparameters['patch_rw_dim'],
    patch_num_diff=-1,
)

In [None]:
# Attach the partition transform to the dataset; items fetched from `rawdata`
# after this point (including in later cells) go through it.
rawdata.transform = transform

# Materialise every item of the dataset into a plain Python list.
datasets = list(rawdata)

In [None]:
# Print the name and shape of every attribute stored on the first graph.
example_graph = datasets[0]
for attr_name in example_graph.keys():
    print(attr_name, example_graph[attr_name].shape)

In [None]:
# Shuffled mini-batch loader over the materialised graphs.
train_loader = DataLoader(
    datasets,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
)

In [None]:
# Pull one batch from the loader and print the shape of each of its attributes.
first = next(iter(train_loader))
for attr_name in first.keys():
    print(attr_name, first[attr_name].shape)

In [None]:
# Names of the `hyperparameters` entries forwarded as keyword arguments to the model.
model_arg_names = (
    'num_features',
    'gnn_dim',
    'gnn_layers',
    'transformer_width',
    'transformer_layers',
    'transformer_heads',
    'hidden_dim',
    'output_dim',
    'embedding_dim',
    'pos_dim',
    'dropout',
    'gnn_type',
    'pool',
    'patch_rw_dim',
)

# Build the model from those entries and move it to the selected device.
graph_transformer = GraphTransformer(
    **{name: hyperparameters[name] for name in model_arg_names}
).to(device)

In [None]:
def count_parameters(model):
    """Return the total number of elements across all parameters of ``model``.

    Every tensor yielded by ``model.parameters()`` is counted.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    return n_elements

total_params = count_parameters(graph_transformer)
# Report the count in millions (1e6). A 1024 * 1024 divisor is a byte-style
# unit and would understate a parameter count labelled "M" by roughly 5%.
print(f"Total params: {total_params}, {total_params / 1e6} M")

In [None]:
# Adam optimizer over all model parameters, using the configured learning rate.
optimizer = Adam(
    graph_transformer.parameters(),
    lr=hyperparameters['lr'],
)

In [None]:
# Train the model with an MSE objective and record the mean batch loss of each epoch.
# NOTE(review): hyperparameters['num_epochs'] is 3000 but this loop runs 1000
# epochs, and 'gamma' / 'step_size' are not used here (no LR scheduler) —
# confirm which values are intended.
loss_history = []
num_epochs = 1000
# Use the `tqdm` class imported in the first cell. Importing the `tqdm` module
# here would rebind the name `tqdm` from the class to the module.
with tqdm(total=num_epochs) as pbar:
    for epoch in range(num_epochs):
        graph_transformer.train()
        total_loss = 0.0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            # Flatten the model output before comparing it with data.y.
            out = graph_transformer(data).reshape(-1)
            loss = F.mse_loss(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Average of the per-batch losses for this epoch.
        epoch_loss = total_loss / len(train_loader)
        loss_history.append(epoch_loss)
        pbar.set_postfix({'loss': f"{epoch_loss:.6f}"})
        pbar.update(1)

In [None]:
# Plot the per-epoch training losses collected by the training loop.
# `plot_losses` is not defined in this notebook; presumably it comes from one of
# the star imports in the first cell — TODO confirm.
plot_losses(loss_history)

In [None]:
# Compare the model's predicted mixture with the reference one for a single sample.
sample = rawdata[1]
X = get_gridX(sizes, device=device)

graph_transformer.eval()
with torch.no_grad():
    # Predicted mixture parameters, evaluated on the grid and shaped as an image.
    weights, mus, kappas = graph_transformer.vmf_param(sample.to(device))
    predicted = multi_vmf(weights.squeeze(), mus.squeeze(), kappas.squeeze(), X)
    img_predict = predicted.cpu().numpy().reshape(sizes)

    # Reference parameters unpacked from the sample's target, rendered the same way.
    tgt_w, tgt_m, tgt_k = extract_param(sample.y)
    reference = multi_vmf(tgt_w, tgt_m, tgt_k, X)
    img_reference = reference.cpu().numpy().reshape(sizes)

    plot_outputs_3d(img_reference, img_predict, sizes)