In [16]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
from collections import Counter

# ---------- Settings ----------
# Path of the serialized graph object that is passed to torch.load below.
GRAPH_PATH = "/home/zihend1/Genesis/KNOT/data/gene_graph.pt"
# Name of the label being predicted.
# NOTE(review): not referenced anywhere in the visible code -- confirm it is still needed.
LABEL_TYPE = "druggability_tier"
# Number of full-graph training epochs run by the main loop.
NUM_EPOCHS = 100
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if USE_GPU else "cpu")

# ---------- Model definition ----------
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network that emits per-node class logits.

    Layer order: GCNConv -> BatchNorm1d -> ReLU -> Dropout -> GCNConv.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the hidden node representation.
        out_channels: Number of logits produced per node.
        dropout: Dropout probability applied to the hidden activations.
            Defaults to 0.5, the value that was previously hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        # Attribute names and construction order are kept stable so that
        # state_dict keys and parameter initialisation order do not change.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return raw (un-normalised) logits for every node.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed through to both GCNConv layers.
        """
        hidden = self.conv1(x, edge_index)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        # No softmax here: the training code feeds these logits to cross_entropy.
        return self.conv2(hidden, edge_index)

# ---------- Load graph ----------
# The file is expected to hold a pickled graph object (it exposes .x,
# .edge_index and .y) rather than a plain tensor/state-dict, so full
# unpickling is requested explicitly: PyTorch >= 2.6 defaults to
# weights_only=True, which rejects arbitrary pickled classes.
# SECURITY: full unpickling can execute arbitrary code -- only load files
# that come from a trusted source.
data = torch.load(GRAPH_PATH, weights_only=False)
print(f"✅ Loaded graph with shape: x={data.x.shape}, edge_index={data.edge_index.shape}")

# ---------- Clean features ----------
# Replace NaN feature entries with 0 so they cannot propagate through the model.
nan_mask = torch.isnan(data.x)
if nan_mask.any():
    print(f"⚠️ NaNs in features: {nan_mask.sum().item()}")
    data.x[nan_mask] = 0

# ---------- Label encoding ----------
# Bring the labels into a NumPy array regardless of how they were stored.
if isinstance(data.y, torch.Tensor):
    y_np = data.y.cpu().numpy()
else:
    y_np = np.array(data.y)

# String labels: the literal "-1" marks an unlabeled node and is kept as -1;
# every other label is mapped to a contiguous integer id by LabelEncoder.
if y_np.dtype.type in (np.str_, np.object_):
    known_mask = y_np != "-1"
    le = LabelEncoder()
    le.fit(y_np[known_mask])
    y_encoded = np.full_like(y_np, fill_value=-1, dtype=int)
    y_encoded[known_mask] = le.transform(y_np[known_mask])
    data.y = torch.tensor(y_encoded, dtype=torch.long)
    print(f"🔢 Encoded classes: {list(le.classes_)}")
else:
    data.y = torch.tensor(y_np, dtype=torch.long)
    print("✅ Labels already encoded as integers")

# Fail with a clear message instead of an opaque "max() of empty tensor" error.
labeled_y = data.y[data.y != -1]
if labeled_y.numel() == 0:
    raise ValueError("No labeled nodes found: every entry of data.y is -1")

# Class ids are assumed to start at 0, so the count is the largest id plus one.
num_classes = int(labeled_y.max().item()) + 1
print(f"🎯 Number of classes: {num_classes}")

# ---------- Data split ----------
def split_masks(y, val_ratio=0.1, test_ratio=0.2, seed=42):
    """Build stratified train/validation/test boolean masks over labeled nodes.

    Entries of ``y`` equal to -1 are treated as unlabeled and are excluded
    from all three masks.

    Args:
        y: 1-D integer label tensor with one entry per node.
        val_ratio: Fraction of the labeled nodes assigned to validation.
        test_ratio: Fraction of the labeled nodes assigned to test.
        seed: Seed that makes the split reproducible.

    Returns:
        A ``(train_mask, val_mask, test_mask)`` tuple of ``torch.bool``
        tensors, each with ``len(y)`` entries.
    """
    labels = y.cpu().numpy()
    num_nodes = len(labels)
    known = np.flatnonzero(labels != -1)

    # A dedicated generator keeps the split reproducible without reseeding
    # NumPy's process-wide random state as a side effect. The same instance
    # is shared by both calls, so the second split continues its stream.
    rng = np.random.RandomState(seed)
    holdout_ratio = val_ratio + test_ratio

    train_idx, temp_idx = train_test_split(
        known,
        test_size=holdout_ratio,
        stratify=labels[known],
        random_state=rng,
    )
    # Split the held-out part again; test_size is relative to that part.
    val_idx, test_idx = train_test_split(
        temp_idx,
        test_size=test_ratio / holdout_ratio,
        stratify=labels[temp_idx],
        random_state=rng,
    )

    def to_mask(ids):
        # Vectorised scatter instead of an O(N * len(ids)) membership scan.
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[torch.as_tensor(ids, dtype=torch.long)] = True
        return mask

    return to_mask(train_idx), to_mask(val_idx), to_mask(test_idx)

# Attach the stratified split masks to the graph object.
data.train_mask, data.val_mask, data.test_mask = split_masks(data.y)
print(f"📊 Train: {data.train_mask.sum().item()}, Val: {data.val_mask.sum().item()}, Test: {data.test_mask.sum().item()}")

# ---------- Class weights ----------
# "balanced" weights counteract class imbalance in the training labels.
train_labels = data.y[data.train_mask].cpu().numpy()
present_classes = np.unique(train_labels)
balanced_weights = compute_class_weight(class_weight='balanced', classes=present_classes, y=train_labels)

# cross_entropy requires exactly one weight per output class. Scatter the
# computed weights into a vector of length num_classes so that a class id
# missing from the training split (e.g. non-contiguous integer labels) does
# not cause a size mismatch. Missing classes keep a neutral weight of 1; they
# never occur as a training target, so the value does not affect the loss.
class_weights = torch.ones(num_classes, dtype=torch.float)
class_weights[torch.as_tensor(present_classes, dtype=torch.long)] = torch.as_tensor(balanced_weights, dtype=torch.float)
class_weights = class_weights.to(DEVICE)

# ---------- Model initialisation ----------
model = GCN(in_channels=data.x.shape[1], hidden_channels=64, out_channels=num_classes).to(DEVICE)
# Move the whole graph (features, edges, labels, masks) to the training device.
data = data.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# ---------- Training ----------
def train():
    """Run one full-graph optimisation step and return the loss as a float.

    Uses the module-level ``model``, ``optimizer``, ``data`` and
    ``class_weights``. The class-weighted cross-entropy loss is computed on
    the nodes selected by ``data.train_mask`` only.
    """
    model.train()
    optimizer.zero_grad()

    train_nodes = data.train_mask
    logits = model(data.x, data.edge_index)
    step_loss = F.cross_entropy(
        logits[train_nodes],
        data.y[train_nodes],
        weight=class_weights,
    )

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

@torch.no_grad()
def evaluate(print_details=False):
    """Compute accuracy on the train, validation and test splits.

    Uses the module-level ``model``, ``data`` and ``num_classes``.

    Args:
        print_details: When true, also print the first few logits, the
            first predictions and labels, and the predicted-class histogram
            over all labeled nodes.

    Returns:
        A dict keyed by ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``
        holding the accuracy of each split.
    """
    model.eval()
    logits = model(data.x, data.edge_index)
    pred = logits.argmax(dim=1)

    if print_details:
        print(f"📤 Output logits (first 5):\n{logits[:5]}")
        print(f"🎯 Predicted classes (first 20): {pred[:20].tolist()}")
        print(f"🧾 Ground truth labels (first 20): {data.y[:20].tolist()}")
        # Histogram of predictions restricted to nodes that carry a label.
        labeled_nodes = data.y != -1
        histogram = Counter(pred[labeled_nodes].tolist())
        print("📊 Class distribution of predictions:")
        for class_id in range(num_classes):
            print(f"  Class {class_id}: {histogram[class_id]} nodes")

    def split_accuracy(split_name):
        # Fraction of nodes in the split whose prediction equals the label.
        selected = getattr(data, split_name)
        hits = (pred[selected] == data.y[selected]).sum().item()
        return hits / selected.sum().item()

    return {
        split_name: split_accuracy(split_name)
        for split_name in ('train_mask', 'val_mask', 'test_mask')
    }

# ---------- Main loop ----------
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    is_last_epoch = epoch == NUM_EPOCHS
    # Evaluation is a full forward pass over the graph, so run it only on the
    # epochs that are actually reported (every 10th and the final one).
    # train() switches the model back to training mode on its next call.
    if epoch % 10 == 0 or is_last_epoch:
        results = evaluate(print_details=is_last_epoch)
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Train: {results['train_mask']:.3f} | Val: {results['val_mask']:.3f} | Test: {results['test_mask']:.3f}")


  data = torch.load(GRAPH_PATH)


✅ Loaded graph with shape: x=torch.Size([19032, 504]), edge_index=torch.Size([2, 29514972])
🔢 Encoded classes: ['Tier 1', 'Tier 2', 'Tier 3A', 'Tier 3B']
🎯 Number of classes: 4
📊 Train: 3017, Val: 431, Test: 863
Epoch 010 | Loss: 1.3264 | Train: 0.257 | Val: 0.278 | Test: 0.260
Epoch 020 | Loss: 1.3158 | Train: 0.337 | Val: 0.348 | Test: 0.335
Epoch 030 | Loss: 1.3141 | Train: 0.358 | Val: 0.390 | Test: 0.355
Epoch 040 | Loss: 1.3047 | Train: 0.423 | Val: 0.425 | Test: 0.401
Epoch 050 | Loss: 1.2992 | Train: 0.371 | Val: 0.381 | Test: 0.357
Epoch 060 | Loss: 1.2971 | Train: 0.389 | Val: 0.392 | Test: 0.359
Epoch 070 | Loss: 1.2880 | Train: 0.291 | Val: 0.304 | Test: 0.263
Epoch 080 | Loss: 1.2814 | Train: 0.349 | Val: 0.385 | Test: 0.348
Epoch 090 | Loss: 1.2772 | Train: 0.381 | Val: 0.392 | Test: 0.377
📤 Output logits (first 5):
tensor([[-0.5885, -0.7711, -0.2612, -0.1269],
        [-0.3180, -0.2921, -0.3181, -0.1037],
        [-0.1031, -0.6414, -0.1241, -0.4198],
        [-0.7087, -0