In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

import numpy as np
import scipy.io as sio
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split

In [34]:
# 加载数据
data_mat = sio.loadmat('Indian_pines_corrected.mat')
labels_mat = sio.loadmat('Indian_pines_gt.mat')
data = data_mat['indian_pines_corrected']
labels = labels_mat['indian_pines_gt']

# 数据形状
print('Data shape:', data.shape)
print('Labels shape:', labels.shape)

Data shape: (145, 145, 200)
Labels shape: (145, 145)


In [35]:
# 重塑数据
height, width, bands = data.shape
data_reshaped = data.reshape(-1, bands)
labels_reshaped = labels.reshape(-1)

# 数据标准化
scaler = StandardScaler()
data_reshaped = scaler.fit_transform(data_reshaped)

# 去除为0的标签（未标记的像素）
masked_data = data_reshaped[labels_reshaped > 0]
masked_labels = labels_reshaped[labels_reshaped > 0] - 1  # 标签从0开始

# 查看类别数量
num_classes = np.max(masked_labels) + 1
print('Number of classes:', num_classes)

Number of classes: 16


In [36]:
class HSI_Dataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.from_numpy(data).float()
        self.labels = torch.from_numpy(labels).long()
        self.indices = np.arange(len(self.labels))
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # 将光谱数据视为“图像”的通道
        sample = self.data[idx]
        sample = sample.unsqueeze(1).unsqueeze(2)  # shape: [bands, 1, 1]
        label = self.labels[idx]
        index = self.indices[idx]  # 获取样本索引
        return sample, label, index

# 划分训练和测试集
train_data, test_data, train_labels, test_labels = train_test_split(
    masked_data, masked_labels, test_size=0.8, random_state=42, stratify=masked_labels)

# 创建数据集
train_dataset = HSI_Dataset(train_data, train_labels)
test_dataset = HSI_Dataset(test_data, test_labels)

# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [37]:
# CNN编码器，用于提取光谱特征
class CNNEncoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CNNEncoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x形状: [batch_size, in_channels, height, width]
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # 展平
        return x  # 返回形状: [batch_size, out_channels]

# GAT编码器，用于捕获空间关系
class GATEncoder(nn.Module):
    def __init__(self, in_channels, out_channels, heads=8):
        super(GATEncoder, self).__init__()
        self.gat1 = GATConv(in_channels, 64, heads=heads, concat=True)
        self.gat2 = GATConv(64 * heads, out_channels, heads=1, concat=False)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        # x形状: [num_nodes, in_channels]
        x = self.relu(self.gat1(x, edge_index))
        x = self.relu(self.gat2(x, edge_index))
        return x  # 返回形状: [num_nodes, out_channels]

# 加权特征融合模型
class WeightedFeatureFusion(nn.Module):
    def __init__(self, cnn_in_channels, gat_in_channels, num_classes):
        super(WeightedFeatureFusion, self).__init__()
        self.cnn_encoder = CNNEncoder(cnn_in_channels, 128)
        self.gat_encoder = GATEncoder(gat_in_channels, 128)
        self.cnn_weight = nn.Parameter(torch.tensor(0.5))
        self.gat_weight = nn.Parameter(torch.tensor(0.5))
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, cnn_x, data):
        # cnn_x 形状：[num_nodes, in_channels, 1, 1]
        cnn_feat = self.cnn_encoder(cnn_x)  # 形状：[num_nodes, 128]
        gat_feat = self.gat_encoder(data.x, data.edge_index)  # 形状：[num_nodes, 128]

        # 融合特征
        fused_feat = self.cnn_weight * cnn_feat + self.gat_weight * gat_feat
        out = self.classifier(fused_feat)  # 形状：[num_nodes, num_classes]
        return out

In [38]:
device = torch.device('cpu')

model = WeightedFeatureFusion(cnn_in_channels=bands, gat_in_channels=bands, num_classes=num_classes)
model=model.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [39]:
# 创建图数据
from torch_geometric.utils import grid

# 获取图的边索引
edge_index = grid(height, width)[0]

# 所有节点的特征
node_features = torch.from_numpy(data_reshaped).float()  # 形状：[num_nodes, feature_dim]

# 节点标签（包括未标记的节点）
all_labels = torch.from_numpy(labels_reshaped - 1).long()  # 减 1 是让标签从 0 开始

# 创建图数据对象
graph_data = Data(x=node_features.to(device), edge_index=edge_index.to(device))

# 划分训练集和测试集
labeled_indices = torch.where(all_labels >= 0)[0]

train_indices, test_indices = train_test_split(
    labeled_indices.cpu().numpy(),
    test_size=0.2,
    random_state=42,
    stratify=all_labels[labeled_indices].cpu()
)

train_mask = torch.zeros(all_labels.size(0), dtype=torch.bool)
test_mask = torch.zeros(all_labels.size(0), dtype=torch.bool)
train_mask[train_indices] = True
test_mask[test_indices] = True

masked_labels = all_labels  # 包含未标记的节点，值为 -1

In [40]:
num_epochs = 50

# 将节点特征调整为 CNN 输入的形状
cnn_x = node_features.unsqueeze(2).unsqueeze(3).to(device)  # 形状：[num_nodes, in_channels, 1, 1]

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    outputs = model(cnn_x, graph_data)  # 输出形状：[num_nodes, num_classes]

    # 只计算训练集上的损失
    loss = criterion(outputs[train_mask], masked_labels[train_mask].to(device))
    loss.backward()
    optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

    # 验证
    model.eval()
    with torch.no_grad():
        outputs = model(cnn_x, graph_data)
        _, predicted = torch.max(outputs[test_mask], 1)
        total = test_mask.sum().item()
        correct = (predicted.cpu() == masked_labels[test_mask]).sum().item()
        print('Test Accuracy: {:.2f}%'.format(100 * correct / total))

Epoch [1/50], Loss: 2.7828
Test Accuracy: 13.94%
Epoch [2/50], Loss: 2.5378
Test Accuracy: 19.76%
Epoch [3/50], Loss: 2.3375
Test Accuracy: 19.17%
Epoch [4/50], Loss: 2.1514
Test Accuracy: 18.79%
Epoch [5/50], Loss: 1.9868
Test Accuracy: 18.60%
Epoch [6/50], Loss: 1.8577
Test Accuracy: 18.74%
Epoch [7/50], Loss: 1.7739
Test Accuracy: 19.64%
Epoch [8/50], Loss: 1.7236
Test Accuracy: 20.31%
Epoch [9/50], Loss: 1.6786
Test Accuracy: 20.62%
Epoch [10/50], Loss: 1.6302
Test Accuracy: 20.67%
Epoch [11/50], Loss: 1.5841
Test Accuracy: 21.17%
Epoch [12/50], Loss: 1.5458
Test Accuracy: 22.33%
Epoch [13/50], Loss: 1.5181
Test Accuracy: 23.31%
Epoch [14/50], Loss: 1.4977
Test Accuracy: 23.73%
Epoch [15/50], Loss: 1.4696
Test Accuracy: 24.54%
Epoch [16/50], Loss: 1.4302
Test Accuracy: 24.73%
Epoch [17/50], Loss: 1.3910
Test Accuracy: 24.64%
Epoch [18/50], Loss: 1.3592
Test Accuracy: 24.80%
Epoch [19/50], Loss: 1.3372
Test Accuracy: 25.07%
Epoch [20/50], Loss: 1.3205
Test Accuracy: 25.47%
Epoch [21

KeyboardInterrupt: 