In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ModelNetDataset(Dataset):
    """Point-cloud classification dataset read from per-category folders of ``.xyz`` files.

    Files are looked up under ``<root_dir>/<category>/<split>/``. Every
    sub-directory of ``root_dir`` is treated as a category; categories are
    sorted by name and mapped to consecutive integer labels.
    """

    def __init__(self, root_dir, split='train', num_points=1024, single_view=True):
        """Index every ``.xyz`` file of the requested split.

        Args:
            root_dir: Directory containing one sub-directory per category.
            split: Name of the split sub-directory inside each category.
            num_points: Stored on the instance; not used when loading samples.
            single_view: When True, keep only the first file (in sorted order)
                of each model, where the model id is the file name up to its
                last underscore.
        """
        self.root_dir = root_dir
        self.split = split
        self.num_points = num_points
        self.single_view = single_view
        self.categories = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.cat_to_idx = {cat: idx for idx, cat in enumerate(self.categories)}
        self.data_paths = []
        self.labels = []

        # Walk every category folder; categories without this split are skipped.
        for cat in self.categories:
            cat_dir = os.path.join(root_dir, cat, split)
            if not os.path.exists(cat_dir):
                continue

            # Sorted so that the sample order is reproducible across runs.
            files = sorted(f for f in os.listdir(cat_dir) if f.endswith('.xyz'))

            if single_view:
                files = self._first_view_per_model(files)

            cat_idx = self.cat_to_idx[cat]
            self.data_paths.extend(os.path.join(cat_dir, f) for f in files)
            self.labels.extend([cat_idx] * len(files))

    @staticmethod
    def _first_view_per_model(files):
        """Return the first file of each model, preserving the order of ``files``.

        The model id is compared exactly (not by prefix), so ``chair_1`` and
        ``chair_10`` are distinct models. A file name without any underscore
        is treated as its own model.
        """
        first_view = {}
        for name in files:
            model_id, sep, _ = name.rpartition('_')
            first_view.setdefault(model_id if sep else name, name)
        return list(first_view.values())

    def __len__(self):
        """Number of indexed files."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(point_cloud, label)`` as tensors.

        The point cloud is parsed with ``np.loadtxt`` and cast to float32; the
        label is a scalar ``torch.long`` tensor.
        """
        point_cloud = np.loadtxt(self.data_paths[idx]).astype(np.float32)
        label = self.labels[idx]
        point_cloud = torch.from_numpy(point_cloud)
        label = torch.tensor(label, dtype=torch.long)

        return point_cloud, label

def get_data_loaders(root_dir='modelnetdata', batch_size=32, num_workers=None):
    """Create the train/test ``DataLoader`` pair and choose the compute device.

    Args:
        root_dir: Dataset root handed to ``ModelNetDataset``.
        batch_size: Batch size used by both loaders.
        num_workers: Loader worker processes; ``None`` selects 0 on Windows
            and 4 on every other platform.

    Returns:
        ``(train_loader, test_loader, device)``. The train loader shuffles and
        drops the last incomplete batch; the test loader does neither.
    """
    # Platform-dependent default for the number of worker processes.
    if num_workers is None:
        num_workers = 4 if os.name != 'nt' else 0

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Pinned host memory is only enabled when running on the GPU.
    pin_memory = (device.type == 'cuda')

    train_dataset, test_dataset = (
        ModelNetDataset(root_dir, split=split_name) for split_name in ('train', 'test')
    )

    summary_lines = (
        f"训练集大小: {len(train_dataset)}",
        f"测试集大小: {len(test_dataset)}",
        f"类别数量: {len(train_dataset.categories)}",
        f"使用设备: {device}",
        f"数据加载进程数: {num_workers}",
    )
    for line in summary_lines:
        print(line)

    # Options shared by both loaders.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_options)
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **shared_options)

    return train_loader, test_loader, device

# Smoke test: build the loaders and print the shape of one training batch.
if __name__ == "__main__":
    train_loader, test_loader, device = get_data_loaders(batch_size=32)
    for points, labels in train_loader:

        points = points.to(device)
        labels = labels.to(device)
        
        print(f"Batch points shape: {points.shape}")
        print(f"Batch labels shape: {labels.shape}")
        break  # only the first batch is inspected

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch of shape ``[B, N, 3]`` is flattened to ``[B*N, 3]`` and paired
    with a batch-index vector before being handed to ``model(points, batch)``.

    Args:
        model: Module called as ``model(points, batch)`` that returns one row
            of class scores per point cloud.
        train_loader: Iterable yielding ``(points, labels)`` batches.
        optimizer: Optimiser stepping the model parameters.
        criterion: Loss callable ``criterion(outputs, labels)``; assumed to
            return the mean loss over the batch (the default reduction).
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged per sample. ``(0.0, 0.0)`` when the
        loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for points, labels in train_loader:
        points, labels = points.to(device), labels.to(device)
        num_clouds, pts_per_cloud = points.size(0), points.size(1)
        # Point i of cloud b gets batch index b; built directly on the target device.
        batch = torch.arange(num_clouds, device=device).repeat_interleave(pts_per_cloud)
        points = points.reshape(-1, 3)  # [B*N, 3]

        optimizer.zero_grad()
        out = model(points, batch)
        loss = criterion(out, labels)

        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so that dividing by the
        # sample count below yields a true per-sample average.
        loss_sum += loss.item() * num_clouds
        pred = out.argmax(dim=1)
        correct += pred.eq(labels).sum().item()
        total += num_clouds

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

def test(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for points, labels in test_loader:
            points, labels = points.to(device), labels.to(device)
            batch = torch.arange(points.size(0)).repeat_interleave(points.size(1)).to(device)
            points = points.view(-1, 3)
            
            out = model(points, batch)
            loss = criterion(out, labels)
            
            total_loss += loss.item()
            pred = out.argmax(dim=1)
            correct += pred.eq(labels).sum().item()
            total += labels.size(0)
    
    return total_loss / total, correct / total

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_max, scatter_mean, scatter_add
from torch_geometric.nn import radius_graph
from torch_geometric.utils import softmax
from torch_sparse import SparseTensor
import torch_geometric.nn as gnn


class EfficientLocalModule(nn.Module):
    """Assigns each node a structural-importance score in (0, 1).

    The score is produced by a small MLP fed with the node position and a
    distance-weighted local degree derived from the supplied edge list.
    """

    def __init__(self):
        super().__init__()

        # Importance-scoring MLP; the 4 inputs are 3 (position) + 1 (weighted local degree).
        scorer_layers = [
            nn.Linear(4, 32),
            nn.LayerNorm(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.importance_net = nn.Sequential(*scorer_layers)

    def compute_weighted_local_degree(self, edge_index, pos, num_nodes):
        """Return a per-node weighted degree normalised by its maximum.

        Args:
            edge_index: Pair of index tensors ``(row, col)`` describing edges.
            pos: Node positions, indexed by ``row`` / ``col``.
            num_nodes: Length of the returned vector.

        Returns:
            Tensor of length ``num_nodes``; the maximum is taken over every
            node passed in, so values are relative to the whole input.
        """
        agg_idx, peer_idx = edge_index

        # Edge weight decays with length: the shorter the edge, the larger the weight.
        closeness = torch.exp(-torch.norm(pos[agg_idx] - pos[peer_idx], dim=1))

        # Base degree: sum of edge weights accumulated on the first edge index.
        degree = scatter_add(closeness, agg_idx, dim=0, dim_size=num_nodes)

        # Degree of the node at the other edge end, weighted by the same edge weight.
        neighbor_degree = scatter_add(
            degree[peer_idx] * closeness, agg_idx, dim=0, dim_size=num_nodes
        )

        # Combine own degree with half the (degree-relative) neighbour degree.
        combined = degree + 0.5 * neighbor_degree / (degree + 1e-6)

        # Scale by the maximum; the epsilon guards against an all-zero vector.
        return combined / (combined.max() + 1e-6)

    def forward(self, edge_index, pos, batch):
        """Score every node; ``batch`` is accepted but not used here.

        Returns:
            Tensor with one sigmoid score per node, shape ``[num_nodes, 1]``.
        """
        degree_feature = self.compute_weighted_local_degree(edge_index, pos, pos.size(0))

        # Feature vector per node: position followed by the weighted degree.
        node_features = torch.cat([pos, degree_feature.unsqueeze(1)], dim=1)

        return self.importance_net(node_features)

class StructuralEdgeConv(nn.Module):
    """Edge convolution over a radius graph, weighted by structural importance.

    Edge features are built from node features and edge length, scaled by the
    importance score of the node they are aggregated onto, and pooled with a
    learned softmax mix of max- and mean-aggregation.
    """

    def __init__(self, in_channels, out_channels, radius=0.1, max_num_neighbors=16):
        """
        Args:
            in_channels: Width of the incoming node features.
            out_channels: Width of the produced node features.
            radius: Radius passed to ``radius_graph``.
            max_num_neighbors: Neighbour cap passed to ``radius_graph``.
        """
        super().__init__()
        self.radius = radius
        self.max_num_neighbors = max_num_neighbors

        # Edge-feature network; input is [x_a, x_b - x_a, edge length].
        edge_in = in_channels * 2 + 1
        self.edge_nn = nn.Sequential(
            nn.Linear(edge_in, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Local-structure module that scores every node.
        self.structural_module = EfficientLocalModule()

        # Feature transform applied after aggregation.
        self.transform = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Learnable logits mixing max- and mean-aggregation.
        self.aggregate_weight = nn.Parameter(torch.ones(2))

    def forward(self, x, pos, batch):
        """Return new node features of width ``out_channels``.

        Args:
            x: Node features, one row per node.
            pos: Node positions used to build the radius graph.
            batch: Graph index of every node.
        """
        num_nodes = x.size(0)

        # Ball-neighbourhood graph on the positions.
        edge_index = radius_graph(
            pos, r=self.radius, batch=batch, max_num_neighbors=self.max_num_neighbors
        )

        # Per-node structural-importance scores.
        scores = self.structural_module(edge_index, pos, batch)

        agg_idx, peer_idx = edge_index
        edge_length = torch.norm(pos[agg_idx] - pos[peer_idx], dim=-1).unsqueeze(-1)

        # Learn edge features from the node feature, the feature difference and the length.
        messages = self.edge_nn(
            torch.cat([x[agg_idx], x[peer_idx] - x[agg_idx], edge_length], dim=1)
        )

        # Scale every edge by the importance of the node it is aggregated onto.
        messages = messages * scores[agg_idx]

        # Adaptive aggregation: softmax-weighted blend of max and mean pooling.
        mix = F.softmax(self.aggregate_weight, dim=0)
        pooled_max = scatter_max(messages, agg_idx, dim=0, dim_size=num_nodes)[0]
        pooled_mean = scatter_mean(messages, agg_idx, dim=0, dim_size=num_nodes)
        blended = mix[0] * pooled_max + mix[1] * pooled_mean

        return self.transform(blended)

class StructuralGNN(nn.Module):
    """Point-cloud classifier built from three ``StructuralEdgeConv`` stages.

    The stages use radii 0.1, 0.2 and 0.4; their outputs (64, 96 and 128
    channels) are concatenated, max-pooled per graph and classified by an MLP.
    """

    def __init__(self, in_channels=3, num_classes=40):
        """
        Args:
            in_channels: Width of the per-point input features.
            num_classes: Number of output classes.
        """
        super().__init__()

        # Per-point feature lifting.
        self.feat_extract = nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
        )

        # Multi-scale structure-aware convolutions.
        self.edge_conv1 = StructuralEdgeConv(64, 64, radius=0.1)
        self.edge_conv2 = StructuralEdgeConv(64, 96, radius=0.2)
        self.edge_conv3 = StructuralEdgeConv(96, 128, radius=0.4)

        # Linear projections for the residual connections.
        self.res1 = nn.Linear(64, 96)
        self.res2 = nn.Linear(96, 128)

        # Layer norms stabilising the output of every stage.
        self.stabilize = nn.ModuleList(
            [nn.LayerNorm(64), nn.LayerNorm(96), nn.LayerNorm(128)]
        )

        # Classification head; 288 = 64 + 96 + 128 concatenated channels.
        self.classifier = nn.Sequential(
            nn.Linear(288, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x, batch):
        """Return class scores, one row per graph in ``batch``.

        Args:
            x: Per-point input; a copy of it is also used as the point positions.
            batch: Graph index of every point.
        """
        pos = x.clone()
        norm1, norm2, norm3 = self.stabilize

        # Lift the raw input to 64 channels.
        feats = self.feat_extract(x)

        # Stage 1 has no residual; stages 2 and 3 add a projected skip connection.
        level1 = norm1(self.edge_conv1(feats, pos, batch))
        level2 = norm2(self.edge_conv2(level1, pos, batch) + self.res1(level1))
        level3 = norm3(self.edge_conv3(level2, pos, batch) + self.res2(level2))

        # Fuse the three scales, then max-pool over the points of each graph.
        fused = torch.cat([level1, level2, level3], dim=1)
        pooled = scatter_max(fused, batch, dim=0)[0]

        return self.classifier(pooled)

In [None]:
import os
import torch
import torch_cluster
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm  
import numpy as np
from datetime import datetime

def train(config):
    """Train ``StructuralGNN`` and keep the checkpoint with the best test accuracy.

    Args:
        config: Dict with the keys ``'data_dir'``, ``'batch_size'``,
            ``'epochs'``, ``'lr'`` and ``'weight_decay'``. ``'num_workers'`` is
            optional; when absent the loader default is used.

    Returns:
        ``(history, best_acc)`` where ``history`` maps ``'train_loss'``,
        ``'train_acc'``, ``'test_loss'``, ``'test_acc'`` and ``'lr'`` to
        per-epoch lists and ``best_acc`` is the best test accuracy seen.

    Side effects:
        Creates ``runs/<timestamp>/`` and writes ``best_model.pth`` there every
        time the test accuracy improves. Stops early after 25 epochs without
        improvement.
    """

    # Data augmentation
    def random_point_dropout(pos, max_dropout_ratio=0.2):
        """Zero out a random subset of points in every cloud of ``pos`` ([B, N, 3]).

        Each cloud draws its own dropout ratio in ``[0, max_dropout_ratio)``;
        dropped points are set to the origin rather than removed.
        """
        device = pos.device
        batch_size, num_points, _ = pos.shape

        dropout_ratio = torch.rand(batch_size, device=device) * max_dropout_ratio
        point_masks = torch.rand(batch_size, num_points, device=device)

        # Keep a point when its random draw does not exceed (1 - ratio) of its cloud.
        keep = (point_masks <= (1 - dropout_ratio).unsqueeze(1)).float()

        return pos * keep.unsqueeze(-1)

    def random_scale_point_cloud(pos, scale_low=0.8, scale_high=1.2):
        """Scale every cloud of ``pos`` by one random factor in ``[scale_low, scale_high)``."""
        device = pos.device
        batch_size = pos.shape[0]
        scales = torch.rand(batch_size, 1, 1, device=device) * (scale_high - scale_low) + scale_low
        return pos * scales

    save_dir = os.path.join('runs', datetime.now().strftime('%Y%m%d_%H%M%S'))
    os.makedirs(save_dir, exist_ok=True)

    train_loader, test_loader, device = get_data_loaders(
        root_dir=config['data_dir'],
        batch_size=config['batch_size'],
        num_workers=config.get('num_workers')
    )

    # Build the model
    model = StructuralGNN(
        in_channels=3,
        num_classes=40
    ).to(device)

    # Report the model size
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model Parameters: {num_params/1e6:.2f}M")

    # Optimiser
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )

    # Learning-rate schedule.
    # NOTE(review): both schedulers wrap the same optimizer and the construction
    # order below is kept as-is, because each constructor touches the optimizer's
    # learning rate. The one-cycle schedule is sized for every epoch but is only
    # stepped after the warm-up epochs, so it never reaches its final step.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config['lr'],
        epochs=config['epochs'],
        steps_per_epoch=len(train_loader),
        pct_start=0.15,        # warm-up fraction of the cycle
        div_factor=10,         # initial lr = max_lr / 10
        final_div_factor=1e4,  # final_div_factor passed to OneCycleLR
        anneal_strategy='cos'  # cosine annealing
    )

    warmup_epochs = 5
    warmup_scheduler = optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.01,
        end_factor=1.0,
        total_iters=warmup_epochs * len(train_loader)
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Training record
    history = {
        'train_loss': [], 'train_acc': [],
        'test_loss': [], 'test_acc': [],
        'lr': []
    }

    # Early-stopping state
    best_acc = 0.0
    patience = 25
    no_improve = 0

    print(f"\nTraining Configuration:")
    print(f"Epochs: {config['epochs']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Learning Rate: {config['lr']}")
    print(f"Weight Decay: {config['weight_decay']}")
    print(f"Device: {device}\n")

    num_batches = len(train_loader)

    for epoch in range(config['epochs']):
        # Training phase
        model.train()

        # The linear warm-up drives the first epochs, the one-cycle schedule the rest.
        current_scheduler = warmup_scheduler if epoch < warmup_epochs else scheduler

        train_loss = 0
        correct = 0
        total = 0

        for batch_idx, (points, labels) in enumerate(train_loader):
            # Augment, then move to the device and flatten to [B*N, 3]
            points = random_point_dropout(points)
            points = random_scale_point_cloud(points)
            points = points.to(device)
            labels = labels.to(device)
            batch = torch.arange(points.size(0), device=device).repeat_interleave(points.size(1))
            points = points.reshape(-1, 3)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(points, batch)
            loss = criterion(outputs, labels)

            # Backward pass with gradient-norm clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            current_scheduler.step()

            # Running statistics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Progress line. The learning rate is read from the optimizer so it
            # is correct no matter which scheduler is currently active.
            if (batch_idx + 1) % 50 == 0 or (batch_idx + 1) == num_batches:
                current_lr = optimizer.param_groups[0]['lr']
                print(f'Epoch [{epoch+1}/{config["epochs"]}] '
                      f'Batch [{batch_idx+1}/{num_batches}] '
                      f'Loss: {train_loss/(batch_idx+1):.4f} '
                      f'Acc: {100.*correct/total:.2f}% '
                      f'LR: {current_lr:.6f}')

        # Epoch-level training metrics
        train_loss = train_loss / num_batches
        train_acc = correct / total

        # Evaluation
        test_loss, test_acc = test(model, test_loader, criterion, device)

        # Update the history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        print(f'\nEpoch {epoch+1} Summary:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

        # Checkpoint on improvement, otherwise count towards early stopping
        if test_acc > best_acc:
            best_acc = test_acc
            no_improve = 0
            print(f'Saving best model with accuracy: {best_acc:.4f}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'config': config,
            }, os.path.join(save_dir, 'best_model.pth'))
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f'\nEarly stopping at epoch {epoch+1} after {patience} epochs without improvement')
                break

    print(f'\nTraining completed! Best test accuracy: {best_acc:.4f}')
    return history, best_acc

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Imported locally so the function does not depend on another cell's imports;
    # the visualisation code draws samples with the `random` module.
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Hyper-parameters consumed by train().
config = {
    'data_dir': 'modelnetdata',
    'batch_size': 32,
    'num_workers': 4,
    'epochs': 300,  # increased number of training epochs
    'lr': 0.001,    # raised base learning rate
    'weight_decay': 5e-4  # increased weight decay
}

# Seed the generators before the model and loaders are created inside train().
set_seed()
history, best_acc = train(config)

In [None]:
# Plot the training curves recorded in `history`
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

# Loss curves (left panel)
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['test_loss'], label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Testing Loss')
plt.legend()

# Accuracy curves (right panel)
plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='Train Acc')
plt.plot(history['test_acc'], label='Test Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Testing Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
import os

def evaluate_model(model_path, test_loader, device):
    """Evaluate a saved checkpoint on the whole test set.

    Args:
        model_path: Path of a checkpoint holding a ``'model_state_dict'`` entry.
        test_loader: Iterable yielding ``(points, labels)`` batches.
        device: Device used for loading the checkpoint and running the model.

    Returns:
        Dict with ``'loss'`` (per-sample average), ``'accuracy'``,
        ``'predictions'`` and ``'labels'`` (NumPy arrays). Loss and accuracy
        are ``0.0`` when the loader yields no samples.
    """
    num_classes = 40

    # Load the best model. map_location lets a checkpoint saved on the GPU be
    # opened on a CPU-only machine as well.
    checkpoint = torch.load(model_path, map_location=device)
    model = StructuralGNN(in_channels=3, num_classes=num_classes).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Evaluation mode
    model.eval()
    criterion = nn.CrossEntropyLoss()

    # Metric accumulators
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    print("开始在完整测试集上评估...")
    with torch.no_grad():
        for points, labels in test_loader:
            points, labels = points.to(device), labels.to(device)
            num_clouds = points.size(0)
            batch = torch.arange(num_clouds, device=device).repeat_interleave(points.size(1))
            points = points.reshape(-1, 3)

            # Forward pass
            outputs = model(points, batch)
            loss = criterion(outputs, labels)

            # Statistics; the batch-mean loss is weighted by the batch size so a
            # smaller final batch does not skew the average.
            loss_sum += loss.item() * num_clouds
            pred = outputs.argmax(dim=1)
            correct += pred.eq(labels).sum().item()
            total += num_clouds

            # Keep the predictions for the per-class report
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Overall metrics
    avg_loss = loss_sum / total if total else 0.0
    accuracy = correct / total if total else 0.0

    print("\n测试集评估结果:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Overall Accuracy: {accuracy:.4f}")
    print(f"Correct Predictions: {correct}/{total}")

    # Per-class accuracy; classes without test samples are skipped
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    for class_idx in range(num_classes):
        class_mask = (all_labels == class_idx)
        if np.sum(class_mask) > 0:
            class_acc = np.mean(all_preds[class_mask] == all_labels[class_mask])
            print(f"Class {class_idx} Accuracy: {class_acc:.4f}")

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'predictions': all_preds,
        'labels': all_labels
    }

# Run the evaluation
if __name__ == "__main__":
    # Use the most recently modified run directory that actually contains a
    # checkpoint, instead of an absolute path that exists on one machine only.
    runs_dir = 'runs'
    run_dirs = [
        os.path.join(runs_dir, d)
        for d in os.listdir(runs_dir)
        if os.path.isfile(os.path.join(runs_dir, d, 'best_model.pth'))
    ]
    if not run_dirs:
        raise FileNotFoundError(f"No best_model.pth found under '{runs_dir}'")
    latest_run = max(run_dirs, key=os.path.getmtime)
    model_path = os.path.join(latest_run, 'best_model.pth')

    # Load the test data
    _, test_loader, device = get_data_loaders(
        root_dir='modelnetdata',
        batch_size=32,
        num_workers=4
    )

    # Evaluate the model
    results = evaluate_model(model_path, test_loader, device)

    # Draw the confusion matrix
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt

    cm = confusion_matrix(results['labels'], results['predictions'])
    plt.figure(figsize=(15, 15))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

In [None]:
import plotly.graph_objects as go
import numpy as np
import torch

class DynamicGraphVisualizer:
    """Writes interactive Plotly views of the radius graph of every edge-conv layer.

    Forward hooks on the three ``structural_module`` sub-modules capture the
    importance scores, which are used to colour the points.

    NOTE: every instance registers additional forward hooks on ``model`` and
    never removes them.
    """

    def __init__(self, model, device):
        """
        Args:
            model: Network exposing ``feat_extract`` and ``edge_conv1..3``.
            device: Stored on the instance.
        """
        self.model = model
        self.device = device
        self.activation = {}

        # Capture the output of an intermediate module under `name`
        def get_activation(name):
            def hook(model, input, output):
                self.activation[name] = output
            return hook

        self.model.edge_conv1.structural_module.register_forward_hook(get_activation('layer1_importance'))
        self.model.edge_conv2.structural_module.register_forward_hook(get_activation('layer2_importance'))
        self.model.edge_conv3.structural_module.register_forward_hook(get_activation('layer3_importance'))

    def create_graph_figure(self, points, edge_index, centrality, title):
        """Build a 3-D figure of the points plus a random sample of at most 800 edges.

        Args:
            points: Tensor of point coordinates; columns 0..2 are plotted as x, y, z.
            edge_index: Tensor whose two rows hold the end-point indices of each edge.
            centrality: One score per point, used as marker colour.
            title: Figure title.

        Returns:
            The ``plotly.graph_objects.Figure``.
        """
        points_np = points.detach().cpu().numpy()
        edge_index_np = edge_index.detach().cpu().numpy()
        # Flatten so the colour array is 1-D even when scores arrive as [N, 1].
        centrality_np = centrality.detach().cpu().numpy().reshape(-1)

        fig = go.Figure()

        # Point cloud, coloured by the score
        fig.add_trace(go.Scatter3d(
            x=points_np[:, 0],
            y=points_np[:, 1],
            z=points_np[:, 2],
            mode='markers',
            marker=dict(
                size=2,
                color=centrality_np,
                colorscale='Viridis',
                opacity=0.7
            ),
            name='Points'
        ))

        # Randomly sample edges to keep the figure light
        num_edges = edge_index_np.shape[1]
        sample_size = min(800, num_edges)

        x_lines = []
        y_lines = []
        z_lines = []

        if sample_size > 0:
            sample_idx = np.random.choice(num_edges, sample_size, replace=False)
            start_points = points_np[edge_index_np[0][sample_idx]]
            end_points = points_np[edge_index_np[1][sample_idx]]

            # A None entry after each segment keeps the segments disconnected.
            for start, end in zip(start_points, end_points):
                x_lines.extend([start[0], end[0], None])
                y_lines.extend([start[1], end[1], None])
                z_lines.extend([start[2], end[2], None])

        fig.add_trace(go.Scatter3d(
            x=x_lines,
            y=y_lines,
            z=z_lines,
            mode='lines',
            line=dict(
                color='rgba(0,0,255,0.6)',
                width=4
            ),
            name='Edges'
        ))

        fig.update_layout(
            title=dict(text=title, x=0.5),
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                aspectmode='data'
            ),
            showlegend=True
        )

        return fig

    def visualize_sample(self, points, batch, class_id, sample_idx, save_dir):
        """Write ``layer1_graph.html`` .. ``layer3_graph.html`` for one sample.

        Args:
            points: Point coordinates of a single sample.
            batch: Graph index of every point.
            class_id: Label shown in the figure titles.
            sample_idx: Sample identifier shown in the figure titles.
            save_dir: Existing directory the HTML files are written to.
        """
        self.model.eval()

        layers = (
            (1, self.model.edge_conv1),
            (2, self.model.edge_conv2),
            (3, self.model.edge_conv3),
        )

        with torch.no_grad():
            pos = points.clone()
            feats = self.model.feat_extract(points)

            for layer_no, conv in layers:
                # Rebuild the graph the layer works on, for drawing.
                edge_index = radius_graph(
                    pos,
                    r=conv.radius,
                    batch=batch,
                    max_num_neighbors=conv.max_num_neighbors
                )
                # A single forward pass both fires the hook and yields the
                # features fed to the next layer.
                feats = conv(feats, pos, batch)
                importance = self.activation[f'layer{layer_no}_importance']

                fig = self.create_graph_figure(
                    points,
                    edge_index,
                    importance,
                    f'Class {class_id} - Sample {sample_idx} - Layer {layer_no}'
                )
                fig.write_html(os.path.join(save_dir, f'layer{layer_no}_graph.html'))
            
def visualize_random_samples(test_loader, model, device, num_samples=3):
    """Write per-layer graph figures for random samples of the first test batch.

    Output goes to ``visualization_results/sample_<idx>_class_<label>/``.

    Args:
        test_loader: Iterable yielding ``(points, labels)`` batches.
        model: Network handed to ``DynamicGraphVisualizer``.
        device: Device the selected samples are moved to.
        num_samples: Upper bound on the number of samples drawn from the batch.
    """
    visualizer = DynamicGraphVisualizer(model, device)

    output_root = "visualization_results"
    os.makedirs(output_root, exist_ok=True)

    # Only the first batch is consumed; the loop exits after one iteration.
    for points, labels in test_loader:
        # Random subset of positions within the batch
        chosen = torch.randperm(points.size(0))[:num_samples]

        for idx in chosen:
            cloud = points[idx].to(device)  # [N, 3]
            label = labels[idx].item()

            # A single sample forms one graph, so every point gets index 0.
            graph_ids = torch.zeros(cloud.size(0), dtype=torch.long, device=device)

            target_dir = os.path.join(output_root, f'sample_{idx}_class_{label}')
            os.makedirs(target_dir, exist_ok=True)

            visualizer.visualize_sample(cloud, graph_ids, label, idx, target_dir)

        break

    

In [None]:
# map_location lets a checkpoint saved on the GPU be opened on a CPU-only machine too.
checkpoint = torch.load(model_path, map_location=device)
model = StructuralGNN(in_channels=3, num_classes=40).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Visualise random samples
visualize_random_samples(test_loader, model, device, num_samples=3)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import json

class PointCloudVisualizer:
    """Plots which points the highest-scoring point connects to at each layer radius.

    Forward hooks on the three ``structural_module`` sub-modules capture the
    importance scores.

    NOTE: every instance registers additional forward hooks on ``model`` and
    never removes them.
    """

    def __init__(self, model, device):
        """
        Args:
            model: Network exposing ``feat_extract`` and ``edge_conv1..3``.
            device: Stored on the instance.
        """
        self.model = model
        self.device = device
        self.activation = {}

        # Capture the output of an intermediate module under `name`
        def get_activation(name):
            def hook(model, input, output):
                self.activation[name] = output
            return hook

        self.model.edge_conv1.structural_module.register_forward_hook(get_activation('layer1_importance'))
        self.model.edge_conv2.structural_module.register_forward_hook(get_activation('layer2_importance'))
        self.model.edge_conv3.structural_module.register_forward_hook(get_activation('layer3_importance'))

    def visualize_connections(self, points, batch, class_id, sample_idx, save_dir):
        """Save a PNG and a JSON summary of the multi-layer connections of one sample.

        The centre point is the one with the largest layer-1 importance score.

        Args:
            points: Point coordinates of a single sample.
            batch: Graph index of every point.
            class_id: Label used in the title and file names.
            sample_idx: Sample identifier used in the title and file names.
            save_dir: Existing directory the files are written to.
        """
        self.model.eval()

        with torch.no_grad():
            pos = points.clone()
            x = self.model.feat_extract(points)

            # One forward pass through layer 1 fires the hook holding its scores.
            _ = self.model.edge_conv1(x, pos, batch)
            # reshape(-1) keeps a 1-D vector even for a one-point cloud.
            importance1 = self.activation['layer1_importance'].reshape(-1)

            # The point with the largest importance score becomes the centre
            center_idx = importance1.argmax().item()

            # Neighbours of the centre in the radius graph of each layer. The
            # graphs depend on the positions only, so no further forward pass
            # through the network is required.
            neighbors_per_layer = []
            for conv in (self.model.edge_conv1, self.model.edge_conv2, self.model.edge_conv3):
                edge_index = radius_graph(
                    pos,
                    r=conv.radius,
                    batch=batch,
                    max_num_neighbors=conv.max_num_neighbors
                )
                linked = edge_index[1][edge_index[0] == center_idx]
                neighbors_per_layer.append(linked.cpu().numpy())
            neighbors1, neighbors2, neighbors3 = neighbors_per_layer

            # Light grey by default; layers are painted from the outermost to the
            # innermost so that the smaller-radius colours win on overlap.
            colors = np.ones((points.size(0), 3)) * 0.8
            colors[neighbors3] = np.array([0.2, 0.6, 1.0])  # layer 3: lightest
            colors[neighbors2] = np.array([0.1, 0.4, 0.8])  # layer 2: medium
            colors[neighbors1] = np.array([0.0, 0.2, 0.6])  # layer 1: darkest
            colors[center_idx] = np.array([1.0, 0.0, 0.0])  # centre point in red

            # 3-D scatter plot
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(111, projection='3d')

            points_np = points.cpu().numpy()
            ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2],
                       c=colors, s=20)

            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')

            importance_score = importance1[center_idx].item()
            ax.set_title(f'Class {class_id} - Sample {sample_idx}\n' +
                         f'Center Point Importance: {importance_score:.3f}\n' +
                         f'Connections: Layer1={len(neighbors1)}, ' +
                         f'Layer2={len(neighbors2)}, Layer3={len(neighbors3)}')

            fig.savefig(os.path.join(save_dir,
                        f'connections_class{class_id}_sample{sample_idx}_imp{importance_score:.3f}.png'),
                        bbox_inches='tight', dpi=300)
            # Close explicitly so figures do not pile up across samples.
            plt.close(fig)

            connection_info = {
                'class_id': class_id,
                'sample_idx': sample_idx,
                'center_point_idx': center_idx,
                'importance_score': importance_score,
                'num_connections': {
                    'layer1': len(neighbors1),
                    'layer2': len(neighbors2),
                    'layer3': len(neighbors3)
                }
            }

            with open(os.path.join(save_dir,
                      f'info_class{class_id}_sample{sample_idx}.json'), 'w') as f:
                json.dump(connection_info, f, indent=4)

In [None]:
def visualize_random_samples_from_different_classes(test_loader, model, device, num_samples=3):
    """Plot the connections of one random sample from each of several random classes.

    The whole test loader is read once to group the samples by label. Output
    goes to ``visualization_results_connections/``.

    Args:
        test_loader: Iterable yielding ``(points, labels)`` batches.
        model: Network handed to ``PointCloudVisualizer``.
        device: Device the selected samples are moved to.
        num_samples: Number of distinct classes to draw; capped at the number
            of classes present in the loader.
    """
    visualizer = PointCloudVisualizer(model, device)
    save_dir = "visualization_results_connections"
    os.makedirs(save_dir, exist_ok=True)

    # Group every sample by its label (keys keep first-seen order).
    class_data = {}
    for points, labels in test_loader:
        for cloud, label in zip(points, labels):
            class_data.setdefault(label.item(), []).append(cloud)

    # Pick distinct classes at random; never ask for more classes than exist,
    # which would make random.sample raise.
    num_picks = min(num_samples, len(class_data))
    selected_classes = random.sample(list(class_data.keys()), num_picks)

    for i, class_id in enumerate(selected_classes):
        # One random sample of this class
        sample_data = random.choice(class_data[class_id])
        sample_data = sample_data.to(device)

        # A single sample forms one graph, so every point gets index 0.
        sample_batch = torch.zeros(sample_data.size(0), dtype=torch.long, device=device)

        visualizer.visualize_connections(
            sample_data,
            sample_batch,
            class_id,
            i,
            save_dir
        )

In [None]:
# map_location lets a checkpoint saved on the GPU be opened on a CPU-only machine too.
checkpoint = torch.load(model_path, map_location=device)
model = StructuralGNN(in_channels=3, num_classes=40).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Visualise random samples drawn from different classes
visualize_random_samples_from_different_classes(test_loader, model, device, num_samples=3)