In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Pick the compute device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load ray data from disk and reduce it to a normalised 2-D map
def load_ray_data(filename, sizes):
    """
    Load ray data from a raw float32 binary file and reduce it to a normalised map.

    The file content is reshaped to ``sizes``, summed over its last two axes and
    rescaled so that the mean of the resulting map equals ``1 / (4 * pi)``.

    Parameters:
    - filename (str): Path to the binary file containing float32 values.
    - sizes (sequence of int): Shape the flat file content is reshaped to. The
      first two entries are also used as the normalisation factor
      ``sizes[0] * sizes[1]``.

    Returns:
    - torch.Tensor: Float tensor holding the reduced map (shape ``sizes[:-2]``),
      placed on the module-level ``device``.

    Raises:
    - ValueError: If the number of values in the file does not match ``sizes``,
      or if the data sums to zero (normalising would produce NaN).
    """
    stdata = np.fromfile(filename, dtype=np.float32)

    # Fail with an explicit message instead of numpy's generic reshape error.
    expected_count = int(np.prod(sizes))
    if stdata.size != expected_count:
        raise ValueError(
            f"{filename}: expected {expected_count} float32 values for shape "
            f"{tuple(int(s) for s in sizes)}, found {stdata.size}"
        )
    stdata = stdata.reshape(sizes)

    # Collapse the last two axes, keeping the leading ones.
    stdata_pos = np.sum(stdata, axis=(-2, -1))

    total = np.sum(stdata_pos)
    if total == 0:
        # Dividing by zero would silently fill the training target with NaN.
        raise ValueError(f"{filename}: ray data sums to zero and cannot be normalised")

    stdata_pos = stdata_pos / total * sizes[0] * sizes[1] / math.pi / 4

    # The tensor is created directly on `device`; no further transfer is needed.
    return torch.tensor(stdata_pos, dtype=torch.float, device=device)



# Residual block
class ResidualBlock(nn.Module):
    """
    Residual block with three convolutions and a learnable residual scale.

    Main branch: conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv1x1 -> BN.
    The branch output is multiplied by the learnable scalar ``alpha``
    (initialised to 1), added to the shortcut and passed through a final ReLU.

    Parameters:
    - in_channels (int): Channels of the input feature map.
    - out_channels (int): Channels produced by the block.
    - stride (int): Stride of the first 3x3 convolution.
    - downsample (nn.Module or None): Optional module applied to the input to
      build the shortcut; when None the input itself is used.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Submodules are created in a fixed order (conv1, bn1, conv2, bn2, conv3, bn3).
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Extra 1x1 convolution at the end of the main branch
        self.conv3 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=1, stride=1, bias=False,
        )
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

        # Learnable scale applied to the main branch before the residual sum
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        shortcut = x if self.downsample is None else self.downsample(x)

        branch = self.conv1(x)
        branch = F.relu(self.bn1(branch))
        branch = self.conv2(branch)
        branch = F.relu(self.bn2(branch))
        branch = self.bn3(self.conv3(branch))

        # Residual connection weighted by the learnable scale
        return F.relu(self.alpha * branch + shortcut)

# EncoderResNet: ResNet-style encoder with self-attention and attention pooling
class EncoderResNet(nn.Module):
    """
    Encode a single-channel 2-D map into a fixed-size embedding vector.

    Pipeline: stem convolutions -> max-pool (stride 2) -> strided 3x3 conv
    (stride 2) -> four residual stages (the last three with stride 2) ->
    multi-head self-attention over the spatial positions -> attention pooling
    -> linear projection.

    Parameters:
    - embedding_dim (int): Size of the output embedding.

    Input:  tensor of shape (batch, 1, H, W).
    Output: tensor of shape (batch, embedding_dim).
    """

    def __init__(self, embedding_dim=64):
        super(EncoderResNet, self).__init__()
        # Running channel count consumed and updated by _make_layer
        self.in_channels = 64

        # Stem: a small 3x3 convolution followed by a 1x1 convolution
        self.initial_conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Max-pooling stage (halves the spatial resolution)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Strided convolution stage (halves the spatial resolution again)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Residual stages
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        # Self-attention across spatial positions
        self.attention = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

        # Attention pooling: this 1x1 convolution produces per-channel spatial
        # scores; the softmax over positions is applied in forward().
        # The convolution stays at index 0 of a Sequential so the state_dict key
        # remains "attention_pooling.0.weight".
        self.attention_pooling = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1, bias=False),
        )

        # Fully connected projection to the embedding
        self.fc = nn.Linear(512, embedding_dim)

    def _make_layer(self, out_channels, blocks, stride):
        """
        Build one residual stage made of ``blocks`` ResidualBlock modules.

        The first block applies ``stride`` and, when the stride or the channel
        count changes, uses a 1x1 convolution + BatchNorm shortcut. The
        remaining blocks keep resolution and channel count. ``self.in_channels``
        is updated to ``out_channels`` for the next stage.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        layers = []
        layers.append(ResidualBlock(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding of ``x`` (shape (batch, 1, H, W))."""
        x = self.initial_conv(x)

        x = self.maxpool(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Self-attention: treat every spatial position as one token
        b, c, h, w = x.size()
        tokens = x.flatten(2).permute(0, 2, 1)  # (batch, h*w, channels)
        tokens, _ = self.attention(tokens, tokens, tokens)
        x = tokens.permute(0, 2, 1).reshape(b, c, h, w)  # (batch, channels, h, w)

        # Attention pooling: the softmax is taken jointly over all spatial
        # positions and the resulting weights multiply the features.
        # Normalising over the width axis only and summing the weights
        # themselves yields the constant h for every channel, which makes the
        # embedding independent of the input.
        scores = self.attention_pooling(x)  # (batch, channels, h, w)
        pool_weights = F.softmax(scores.flatten(2), dim=-1).reshape(b, c, h, w)
        x = torch.sum(pool_weights * x, dim=(-2, -1))  # (batch, channels)

        x = self.fc(x)

        return x

# SphereGaussianMixture decoder module
class SphereGaussianMixture(nn.Module):
    """
    Decode an embedding into the parameters of a mixture of spherical lobes.

    An MLP maps the embedding to ``4 * num_spheres`` values that are turned into:
    - mixture weights (softmax over the lobes),
    - unit axes built from the angles ``theta`` in (0, pi) and ``phi`` in
      (0, 2*pi) as ``(cos(theta), sin(theta)cos(phi), sin(theta)sin(phi))``,
    - concentrations ``kappa`` in (0, 100].

    Parameters:
    - embedding_dim (int): Size of the input embedding.
    - hidden_dim (int): Width of the two hidden layers.
    - num_spheres (int): Number of mixture components.
    - dropout_rate (float): Dropout probability after each hidden layer.
    """

    def __init__(self, embedding_dim=64, hidden_dim=128, num_spheres=32, dropout_rate=0.2):
        super(SphereGaussianMixture, self).__init__()
        self.num_spheres = num_spheres

        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_spheres * 4)  # 4 parameters per sphere
        )

    def forward(self, embedding):
        """
        Parameters:
        - embedding (torch.Tensor): Tensor of shape (batch_size, embedding_dim).

        Returns:
        - (weights, axes, kappas) with shapes (batch_size, num_spheres),
          (batch_size, num_spheres, 3) and (batch_size, num_spheres). When
          batch_size is 1 the leading batch axis is removed from all three.
        """
        n = self.num_spheres
        out = self.fc(embedding)  # Shape: (batch_size, num_spheres * 4)

        # Mixture weights
        weights = F.softmax(out[:, :n], dim=-1)  # Shape: (batch_size, num_spheres)

        # Map the next 2 * num_spheres outputs to angles in radians
        theta_phi = torch.sigmoid(out[:, n: 3 * n])
        theta = theta_phi[:, :n] * math.pi
        phi = theta_phi[:, n:] * 2 * math.pi

        # Unit axes from spherical coordinates
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        cos_phi = torch.cos(phi)
        sin_phi = torch.sin(phi)
        axes = torch.stack((cos_theta, sin_theta * cos_phi, sin_theta * sin_phi), dim=-1)  # Shape: (batch_size, num_spheres, 3)

        # kappa = min(exp(logit), 100). The logit is clamped *before* the
        # exponential so that exp() can never overflow to inf; an inf followed
        # by a clamp would back-propagate 0 * inf = NaN.
        max_kappa = 100.0
        log_kappas = torch.clamp(out[:, 3 * n:], max=math.log(max_kappa))
        kappas = torch.clamp(torch.exp(log_kappas), max=max_kappa)  # Shape: (batch_size, num_spheres)

        # Drop only the batch axis (and only when it has size 1). A bare
        # squeeze() would also remove the sphere axis when num_spheres == 1.
        return weights.squeeze(0), axes.squeeze(0), kappas.squeeze(0)

def vmf_pdf(x, axes, kappas):
    """
    Evaluate von Mises-Fisher densities on the unit sphere for several lobes.

    For every point and every lobe this returns
    ``kappa / (2*pi*(1 - exp(-2*kappa))) * exp(kappa * (dot(x, axis) - 1))``.

    Parameters:
    - x (torch.Tensor): Query points, one per row (presumably unit vectors of
      shape (num_points, 3) — confirm against callers).
    - axes (torch.Tensor): 2-D tensor with one lobe axis per row.
    - kappas (torch.Tensor): Concentration of each lobe, broadcast against the
      (num_points, num_spheres) dot-product matrix.

    Returns:
    - torch.Tensor: Density values of shape (num_points, num_spheres).
    """
    # Below this concentration the closed-form normaliser is replaced by its
    # Taylor expansion, because kappa / (1 - exp(-2*kappa)) becomes 0/0 at kappa = 0.
    small_kappa_threshold = 1e-3
    is_small = kappas.abs() < small_kappa_threshold

    # torch.where evaluates both branches, so the closed form is computed on a
    # placeholder value (1) wherever the series is selected; this keeps NaN out
    # of both the forward value and the gradient.
    safe_kappas = torch.where(is_small, torch.ones_like(kappas), kappas)
    # -expm1(-2k) equals 1 - exp(-2k) without the cancellation error. For very
    # large kappa it is exactly 1, which reproduces the kappa / (2*pi) limit.
    denom = -torch.expm1(-2 * safe_kappas)
    closed_form = safe_kappas / (2 * math.pi * denom)

    # kappa / (2*pi*(1 - exp(-2*kappa))) = (1 + kappa + kappa^2/3 + O(kappa^4)) / (4*pi)
    series = (1 + kappas + kappas ** 2 / 3) / (4 * math.pi)

    norm_const = torch.where(is_small, series, closed_form)

    # dot(x, axis) - 1 for every point/lobe pair; Shape: (data_sizes, num_spheres)
    dot_products = torch.matmul(x, axes.transpose(0, 1)) - 1

    return norm_const * torch.exp(kappas * dot_products)

# Mixture of von Mises-Fisher distributions
def multi_vmf(weights, axes, kappas, w):
    """
    Evaluate a weighted mixture of von Mises-Fisher lobes at the points ``w``.

    Each lobe contributes
    ``weight * kappa / (2*pi*(1 - exp(-2*kappa))) * exp(kappa * (dot(w, axis) - 1))``.

    Parameters:
    - weights (torch.Tensor): Mixture weight of each lobe, shape (num_spheres,).
    - axes (torch.Tensor): 2-D tensor with one lobe axis per row.
    - kappas (torch.Tensor): Concentration of each lobe, shape (num_spheres,).
      Values are clamped to [1e-10, 1e5].
    - w (torch.Tensor): Query points, one per row (presumably unit vectors of
      shape (num_points, 3) — confirm against callers).

    Returns:
    - torch.Tensor: Mixture density per query point, shape (num_points,),
      clamped to [1e-10, 1e10].
    """
    # Keep the concentrations non-negative and bounded
    kappas = torch.clamp(kappas, min=1e-10, max=1e5)

    # For tiny kappa, 1 - exp(-2*kappa) rounds to exactly 0 in float32, so the
    # closed-form normaliser would be inf. Use its Taylor expansion there.
    small_kappa_threshold = 1e-3
    is_small = kappas < small_kappa_threshold

    # torch.where evaluates both branches; feed the closed form a placeholder
    # (1) where the series is selected so no inf/NaN reaches value or gradient.
    safe_kappas = torch.where(is_small, torch.ones_like(kappas), kappas)
    # -expm1(-2k) equals 1 - exp(-2k) without cancellation. It is exactly 1 for
    # very large kappa, which reproduces the kappa / (2*pi) limit.
    denom = -torch.expm1(-2 * safe_kappas)
    closed_form = safe_kappas / (2 * math.pi * denom)

    # kappa / (2*pi*(1 - exp(-2*kappa))) = (1 + kappa + kappa^2/3 + O(kappa^4)) / (4*pi)
    series = (1 + kappas + kappas ** 2 / 3) / (4 * math.pi)

    norm_const = torch.where(is_small, series, closed_form)

    # dot(w, axis) - 1 for every point/lobe pair; Shape: (data_sizes, num_spheres)
    dot_products = torch.matmul(w, axes.transpose(0, 1)) - 1

    # Weighted lobe densities, summed over the lobes
    weighted_exps = weights * norm_const * torch.exp(kappas * dot_products)  # Shape: (data_sizes, num_spheres)
    q = torch.sum(weighted_exps, dim=-1)  # Shape: (data_sizes,)
    q = torch.clamp(q, min=1e-10, max=1e10)  # Guard against extreme values
    return q

# Combined training loss (likelihood term + L1 + L2 reconstruction terms)
def kl_divergence_loss(weights, axes, kappas, raw_data, w, p, **kwargs):
    """
    Compute the weighted sum of three loss terms for a vMF mixture.

    - 'KL':  mean negative log of the mixture density at the points ``raw_data``.
    - 'Rec': mean absolute difference between the mixture evaluated at ``w``
             and the target values ``p``.
    - 'L2':  Euclidean norm of that same difference.

    Parameters:
    - weights, axes, kappas: Mixture parameters passed to ``multi_vmf``.
    - raw_data (torch.Tensor): Points used for the 'KL' term.
    - w (torch.Tensor): Points at which the mixture is compared with ``p``.
    - p (torch.Tensor): Target values, one per row of ``w``.
    - kwargs: Must contain ``kl_lambda``, ``l1_lambda`` and ``l2_lambda`` (the
      term weights). ``epoch_step`` is accepted but currently unused, so it is
      optional.

    Returns:
    - (loss, loss_dict): the weighted total as a tensor, and a dict with the
      Python-float values of the individual terms under 'KL', 'Rec' and 'L2'.
      A term whose weight is not positive is reported as 0.

    Raises:
    - KeyError: If one of the three lambda weights is missing from kwargs.
    """
    kl_lambda = kwargs['kl_lambda']
    l1_lambda = kwargs['l1_lambda']
    l2_lambda = kwargs['l2_lambda']

    # Placeholder for disabled terms, on the same device/dtype as the parameters
    zero = weights.new_zeros(())

    # The mixture is only evaluated for the terms that are actually enabled.
    if kl_lambda > 0:
        total_prob = multi_vmf(weights, axes, kappas, raw_data)
        kl_loss = -torch.log(total_prob + 1e-10).mean()
    else:
        kl_loss = zero

    rec_loss = zero
    l2_loss = zero
    if l1_lambda > 0 or l2_lambda > 0:
        residual = multi_vmf(weights, axes, kappas, w) - p
        if l1_lambda > 0:
            rec_loss = torch.abs(residual).mean()
        if l2_lambda > 0:
            l2_loss = torch.norm(residual, p=2)

    loss = kl_lambda * kl_loss + l1_lambda * rec_loss + l2_lambda * l2_loss
    loss_dict = {'KL': kl_loss.item(), 'Rec': rec_loss.item(), 'L2': l2_loss.item()}
    return loss, loss_dict

def smooth_curve(values, smoothing_factor=0.9):
    """
    Exponentially smooth a sequence of numbers.

    Each output is ``previous * smoothing_factor + (1 - smoothing_factor) * value``,
    where ``previous`` starts at the first input value.

    Parameters:
    - values (sequence of float): Values to smooth; must support ``len`` and indexing.
    - smoothing_factor (float): Weight given to the previous smoothed value.

    Returns:
    - list: Smoothed values, same length as ``values`` (empty for empty input).
    """
    # An empty history has nothing to seed the recursion with.
    if len(values) == 0:
        return []

    smoothed_values = []
    last = values[0]
    for value in values:
        last = last * smoothing_factor + (1 - smoothing_factor) * value
        smoothed_values.append(last)
    return smoothed_values
def plot_losses(train_losses, val_losses = None):
    """
    Plot smoothed loss curves on a logarithmic y-axis.

    Parameters:
    - train_losses (sequence of float): Training loss per epoch.
    - val_losses (sequence of float or None): Optional validation loss per epoch.
    """
    # Smooth every requested curve before any figure is created.
    curves = [(smooth_curve(train_losses), "Training Loss (Smoothed)", "blue")]
    if val_losses is not None:
        curves.append((smooth_curve(val_losses), "Validation Loss (Smoothed)", "red"))

    plt.figure(figsize=(10, 6))
    for series, legend_label, line_color in curves:
        plt.plot(series, label=legend_label, color=line_color)

    # Log scale for the y-axis
    plt.yscale("log")
    plt.xlabel("Epoch")
    plt.ylabel("Loss (Log Scale)")
    plt.title("Loss (Log Scale with Smoothing)")
    plt.legend()
    plt.show()

def _draw_z_phi_heatmap(fig, ax, image, title, extent, phi_ticks, phi_tick_labels, z_ticks, z_tick_labels):
    """Draw one (Z, Phi) heatmap with labelled ticks and its own colorbar on ``ax``."""
    im = ax.imshow(
        image,
        extent=extent,
        aspect='auto',
        origin='lower',
        cmap='viridis'
    )
    ax.set_title(title)
    ax.set_xlabel('Phi')
    ax.set_ylabel('Z')

    # Tick positions and their text labels
    ax.set_xticks(phi_ticks)
    ax.set_yticks(z_ticks)
    ax.set_xticklabels(phi_tick_labels)
    ax.set_yticklabels(z_tick_labels)

    # Show ticks and labels on the top and right edges as well
    ax.tick_params(top=True, right=True, labeltop=True, labelright=True)

    fig.colorbar(im, ax=ax, orientation='vertical')


def plot_outputs_ed(encoder,decoder, all_ray_data, X, sizes, device, file_indices=None):
    """
    For each ray-data sample, show the decoded mixture next to the sample as heatmaps.

    Every sample is passed through ``encoder`` and ``decoder``; the resulting
    mixture is evaluated at ``X`` with ``multi_vmf`` and displayed beside the
    sample itself.

    Parameters:
    - encoder: The trained encoder model.
    - decoder: The trained decoder model.
    - all_ray_data (list): 2-D ray-data tensors of shape (sizes[0], sizes[1]).
    - X (torch.Tensor): Points at which the mixture is evaluated; it must yield
      sizes[0] * sizes[1] values so the result can be reshaped into an image.
    - sizes: Sequence whose first two entries give the image height and width.
    - device (torch.device): Device the samples are moved to.
    - file_indices (iterable or None): Labels used in the plot titles, paired
      with ``all_ray_data``. Defaults to the positions 0, 1, 2, ... so the
      function does not depend on any notebook-level variable.
    """
    if file_indices is None:
        file_indices = range(len(all_ray_data))

    # Remember the current modes so they can be restored afterwards.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    # Axis ranges of the heatmaps
    z_min, z_max = -1, 1
    phi_min, phi_max = -np.pi, np.pi
    extent = [phi_min, phi_max, z_min, z_max]

    # Tick positions
    z_ticks_pos = np.linspace(z_min, z_max, 5)        # -1, -0.5, 0, 0.5, 1
    phi_ticks = np.linspace(phi_min, phi_max, 5)      # -pi, -pi/2, 0, pi/2, pi

    # Tick labels: Z as plain numbers, Phi converted to degrees
    z_tick_pos_labels = [f"{z:.2f}" for z in z_ticks_pos]
    phi_tick_labels = [f"{int(np.degrees(p))}°" for p in phi_ticks]

    try:
        with torch.no_grad():
            for i, ray_data in zip(file_indices, all_ray_data):
                # Add batch and channel axes for the encoder
                ray_data = ray_data.unsqueeze(0).unsqueeze(0)
                ray_data = ray_data.to(device)

                target_img = ray_data.cpu().numpy().reshape(sizes[0], sizes[1])

                # Encoder: extract the embedding
                embedding = encoder(ray_data)

                # Decoder: produce weights, axes and kappa parameters
                weights, axes, kappas = decoder(embedding)

                # Evaluate the mixture on the grid for visualisation
                predict_img = multi_vmf(weights, axes, kappas, X).cpu().numpy()

                # Reshape the flat prediction into image form
                try:
                    predict_img = predict_img.reshape((sizes[0], sizes[1]))
                except ValueError:
                    print(f"Error: Cannot reshape output for graph {i}. Check dimensions of X_output.")
                    continue

                # Prediction on the left, reference on the right
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
                _draw_z_phi_heatmap(
                    fig, ax1, predict_img, f'Prediction for Graph {i}',
                    extent, phi_ticks, phi_tick_labels, z_ticks_pos, z_tick_pos_labels,
                )
                _draw_z_phi_heatmap(
                    fig, ax2, target_img, f'Reference for Graph {i}',
                    extent, phi_ticks, phi_tick_labels, z_ticks_pos, z_tick_pos_labels,
                )

                plt.show()
                plt.close(fig)  # Release the figure once it has been displayed
    finally:
        # Restore the modes the models were in before this call
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

def _draw_z_phi_surface(ax, z_grid, phi_grid, image, title):
    """Draw ``image`` as a 3-D surface over the (Z, Phi) grid on ``ax``."""
    ax.plot_surface(z_grid, phi_grid, image, rstride=1, cstride=1, cmap='rainbow')
    ax.set_title(title)
    ax.set_xlabel('Z')
    ax.set_ylabel('Phi')
    ax.set_zlabel('Value')


def plot_outputs_3d_ed(encoder, decoder, all_ray_data, X, sizes, device):
    """
    For each ray-data sample, show the decoded mixture next to the sample as 3-D surfaces.

    Every sample is passed through ``encoder`` and ``decoder``; the resulting
    mixture is evaluated at ``X`` with ``multi_vmf`` and displayed beside the
    sample itself. Titles use the position of the sample in ``all_ray_data``.

    Parameters:
    - encoder: The trained encoder model.
    - decoder: The trained decoder model.
    - all_ray_data (list): 2-D ray-data tensors of shape (sizes[0], sizes[1]).
    - X (torch.Tensor): Points at which the mixture is evaluated; it must yield
      sizes[0] * sizes[1] values so the result can be reshaped into an image.
    - sizes: Sequence whose first two entries give the grid height and width.
    - device (torch.device): Device the samples are moved to.
    """
    # Remember the current modes so they can be restored afterwards.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    # Plotting grid: Z in [-1, 1] along the rows, Phi in [-pi, pi] along the columns
    z_in = np.linspace(-1, 1, sizes[0])
    phi_in = np.linspace(-np.pi, np.pi, sizes[1])
    Z, Phi = np.meshgrid(z_in, phi_in, indexing='ij')

    try:
        with torch.no_grad():
            for i, ray_data in enumerate(all_ray_data):
                # Add batch and channel axes for the encoder
                ray_data = ray_data.unsqueeze(0).unsqueeze(0)
                ray_data = ray_data.to(device)

                target_img = ray_data.cpu().numpy().reshape(sizes[0], sizes[1])

                # Encoder: extract the embedding
                embedding = encoder(ray_data)

                # Decoder: produce weights, axes and kappa parameters
                weights, axes, kappas = decoder(embedding)

                # Evaluate the mixture on the grid for visualisation
                predict_img = multi_vmf(weights, axes, kappas, X).cpu().numpy()

                # Reshape the flat prediction into grid form
                try:
                    predict_img = predict_img.reshape((sizes[0], sizes[1]))
                except ValueError:
                    print(f"Error: Cannot reshape output for graph {i}. Check dimensions of X_output.")
                    continue

                # Prediction on the left, reference on the right
                fig = plt.figure(figsize=(14, 6))
                ax1 = fig.add_subplot(121, projection='3d')
                ax2 = fig.add_subplot(122, projection='3d')
                _draw_z_phi_surface(ax1, Z, Phi, predict_img, f'Prediction for Graph {i}')
                _draw_z_phi_surface(ax2, Z, Phi, target_img, f'Reference for Graph {i}')

                plt.show()
                plt.close(fig)  # Release the figure once it has been displayed
    finally:
        # Restore the modes the models were in before this call
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)



In [None]:

# Data loading
# NOTE(review): the paths below are absolute and machine-specific; adjust them
# (or derive them from a configurable data directory) before running elsewhere.
data_dir = 'D:/Downloads/foamGNN/raw'
sizes = np.array([32, 64, 32, 64])

file_idx = [1]

# Alternative: load one histogram per entry of file_idx from data_dir
# all_ray_data = []
# for i in file_idx:
#     rayfile = os.path.join(data_dir, f'{i}/stdata.bin')
#     ray_data = load_ray_data(rayfile, sizes)
#     # ray_data = torch.log(ray_data + 1)
#     all_ray_data.append(ray_data)

datafile = "D:/Github/datasets/raw_data/foam0/2/stdataNonSpe.bin"
ray_data = load_ray_data(datafile, sizes)
all_ray_data = [ray_data]

# Build the evaluation grid X: one unit vector per (row, column) cell
i_idx = torch.arange(sizes[0], dtype=torch.float, device=device) / sizes[0]
j_idx = torch.arange(sizes[1], dtype=torch.float, device=device) / sizes[1]
# Explicit 'ij' indexing keeps the (sizes[0], sizes[1]) layout and avoids the
# torch.meshgrid warning about the unspecified indexing argument.
i_grid, j_grid = torch.meshgrid(i_idx, j_idx, indexing='ij')
pos_x = i_grid * 2 - 1                # first coordinate in [-1, 1)
pos_phi = (j_grid * 2 - 1) * np.pi    # azimuth in [-pi, pi)
pos_r = torch.sqrt(1 - pos_x**2)
pos_y = pos_r * torch.cos(pos_phi)
pos_z = pos_r * torch.sin(pos_phi)
X = torch.stack((pos_x, pos_y, pos_z), dim=-1).reshape(-1, 3).to(device)


# Raw samples: float32 records of 4 values each; column 0 is used as the first
# Cartesian coordinate and column 1 (shifted by -pi) as the azimuth.
filename = "D:/Github/datasets/raw_data/foam0/2/rawdataNonSpe.bin"
rawdata = np.fromfile(filename, dtype=np.float32)
rawdata = rawdata.reshape(-1, 4)
print(rawdata.shape)
x = rawdata[:,0]
# print(np.max(data[:,1]))
phi = rawdata[:,1]-np.pi
# Clip so that a value marginally outside [-1, 1] gives radius 0 instead of NaN
r = np.sqrt(np.clip(1 - x**2, 0, None))
y = r * np.cos(phi)
z = r * np.sin(phi)


# Keep at most the first 4096 samples
raw_X = np.column_stack((x, y, z))
raw_num = min(4096, raw_X.shape[0])
raw_X = raw_X[:raw_num, :]
print(raw_X.shape)

# Convert the samples to a tensor
raw_data = torch.tensor(raw_X, dtype=torch.float32, device=device)
all_raw_data = [raw_data]

(1010853, 4)
(4096, 3)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
# Training parameters
num_epochs = 5000  # number of passes of the training loop
kl_lambda = 1      # weight of the 'KL' term in kl_divergence_loss
l1_lambda = 10     # weight of the 'Rec' (mean absolute error) term
l2_lambda = 1      # weight of the 'L2' term

# Model definition
embedding_dim = 256  # size of the encoder output / decoder input
hidden_dim = 512     # width of the decoder's hidden layers
num_spheres = 64     # number of mixture components produced by the decoder

# NOTE(review): neither value is referenced by the visible training cell, which
# hard-codes lr=1e-4 and has the ExponentialLR scheduler commented out — confirm intent.
init_lr = 5e-4
decay_rate = 0.9999

In [None]:
# Initialise the models (encoder and decoder) and the optimiser
encoder = EncoderResNet(embedding_dim=embedding_dim).to(device)
decoder = SphereGaussianMixture(embedding_dim=embedding_dim, hidden_dim=hidden_dim, num_spheres=num_spheres).to(device)

# NOTE(review): the learning rate is hard-coded here; `init_lr` and `decay_rate`
# are not used because the scheduler below is disabled — confirm intent.
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
# scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)


# Loss histories (one entry per epoch)
loss_history = []
l1_loss_history = []
w_data =  X.clone().detach().to(device)
X_data = raw_data.clone().detach().to(device)  # NOTE(review): not used in this cell

# Prepare the per-sample tensors once instead of on every epoch. Distinct names
# are used so the notebook-level `ray_data` / `raw_data` are not rebound.
training_samples = []
for _, ray_sample, raw_sample in zip(file_idx, all_ray_data, all_raw_data):
    # Add batch and channel axes for the encoder: (1, 1, sizes[0], sizes[1])
    encoder_input = ray_sample.unsqueeze(0).unsqueeze(0).to(device)
    # Flattened target values compared against the mixture evaluated at w_data
    target_dist = encoder_input.reshape(-1)
    training_samples.append((encoder_input, target_dist, raw_sample))

if not training_samples:
    # Averaging over zero samples would otherwise fail with ZeroDivisionError.
    raise ValueError("No training samples: file_idx, all_ray_data and all_raw_data must be non-empty")

# Train, showing progress with tqdm
with tqdm(total=num_epochs, desc="Training Progress") as pbar:
    for epoch in range(num_epochs):
        encoder.train()
        decoder.train()
        totals = {'Loss': 0.0, 'KL': 0.0, 'Rec': 0.0, 'L2': 0.0}
        for encoder_input, target_dist, raw_sample in training_samples:
            # Encoder: extract the embedding
            embedding = encoder(encoder_input)

            # Decoder: produce weights, axes and kappa parameters
            weights, axes, kappas = decoder(embedding)

            # Compute the loss
            loss, loss_dict = kl_divergence_loss(weights, axes, kappas, raw_sample, w_data, target_dist,
                                                 kl_lambda=kl_lambda, l1_lambda=l1_lambda, l2_lambda=l2_lambda,
                                                 epoch_step=epoch)

            # Back-propagate and update the parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # scheduler.step()

            # Accumulate the loss terms
            totals['Loss'] += loss.item()
            for term in ('KL', 'Rec', 'L2'):
                totals[term] += loss_dict[term]

        # Average over the samples and record the epoch
        averages = {term: total / len(training_samples) for term, total in totals.items()}
        loss_history.append(averages['Loss'])
        l1_loss_history.append(averages['Rec'])
        pbar.set_postfix({term: f"{value:.6f}" for term, value in averages.items()})
        pbar.update(1)
# Plot the decay of the reconstruction loss
plot_losses(l1_loss_history)


In [None]:
# Show predicted vs. reference heatmaps for every loaded ray-data sample
plot_outputs_ed(
    encoder=encoder,
    decoder=decoder,
    all_ray_data=all_ray_data,
    X=X,
    sizes=sizes,
    device=device,
)

In [None]:
# Show predicted vs. reference 3-D surfaces for every loaded ray-data sample
plot_outputs_3d_ed(
    encoder=encoder,
    decoder=decoder,
    all_ray_data=all_ray_data,
    X=X,
    sizes=sizes,
    device=device,
)