In [None]:
# 解压数据文件
!unzip -n RNA_design_public.zip

Archive:  RNA_design_public.zip
   creating: RNA_design_public/
  inflating: __MACOSX/._RNA_design_public  
   creating: RNA_design_public/RNAdesignv1/
  inflating: __MACOSX/RNA_design_public/._RNAdesignv1  
  inflating: RNA_design_public/public.ipynb  
  inflating: __MACOSX/RNA_design_public/._public.ipynb  
  inflating: RNA_design_public/.DS_Store  
  inflating: __MACOSX/RNA_design_public/._.DS_Store  
  inflating: RNA_design_public/RNAdesignv1/.DS_Store  
  inflating: __MACOSX/RNA_design_public/RNAdesignv1/._.DS_Store  
   creating: RNA_design_public/RNAdesignv1/train/
  inflating: __MACOSX/RNA_design_public/RNAdesignv1/._train  
   creating: RNA_design_public/RNAdesignv1/train/seqs/
  inflating: __MACOSX/RNA_design_public/RNAdesignv1/train/._seqs  
  inflating: RNA_design_public/RNAdesignv1/train/.DS_Store  
  inflating: __MACOSX/RNA_design_public/RNAdesignv1/train/._.DS_Store  
   creating: RNA_design_public/RNAdesignv1/train/coords/
  inflating: __MACOSX/RNA_design_public/RNAdesi

In [2]:
# 安装必要的包
!pip install numpy pandas biopython torch_geometric

Looking in indexes: https://mirrors.cloud.aliyuncs.com/pypi/simple
Collecting torch_geometric
  Downloading https://mirrors.cloud.aliyuncs.com/pypi/packages/03/9f/157e913626c1acfb3b19ce000b1a6e4e4fb177c0bc0ea0c67ca5bd714b5a/torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[33mDEPRECATION: pytorch-lightning 1.7.7 has a non-standard dependency specifier torch>=1.9.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[3

In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool
import torch_geometric
from Bio import SeqIO

In [None]:
# Train the model
# Hyperparameters and constants shared by every component below.
class Config:
    """Global settings, read directly as class attributes (e.g. ``Config.lr``)."""

    # --- reproducibility -------------------------------------------------
    seed = 42

    # --- hardware: prefer the GPU when PyTorch can see one ---------------
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # --- optimisation ----------------------------------------------------
    batch_size = 32
    lr = 0.01
    epochs = 20

    # --- data / model shape ----------------------------------------------
    seq_vocab = "AUCG"  # nucleotide alphabet; a base's position here is its class label
    coord_dims = 7  # number of backbone points stored per residue
    hidden_dim = 128
    k_neighbors = 5  # sequence neighbours linked on each side of a node

# Graph builder: turns one RNA structure (coordinates + sequence) into a PyG graph
class RNAGraphBuilder:
    @staticmethod
    def build_graph(coord, seq):
        """Convert backbone coordinates and a nucleotide sequence into a graph.

        Args:
            coord: numpy array with one leading row per residue (the loader
                documents it as [L, 7, 3]); NaNs should already be removed.
            seq: nucleotide string of length L over ``Config.seq_vocab``.

        Returns:
            ``Data`` with
              * ``x``          float32 [L, F], each residue's coordinates flattened,
              * ``edge_index`` long [2, E], undirected band graph linking every
                               residue to its ``Config.k_neighbors`` predecessors
                               and successors (each direction stored once),
              * ``y``          long [L], index of each base in ``Config.seq_vocab``,
              * ``num_nodes``  L.

        Raises:
            ValueError: if ``len(seq)`` differs from the number of coordinate
                rows, or ``seq`` contains a character outside ``Config.seq_vocab``.
        """
        num_nodes = coord.shape[0]
        # A mismatch would otherwise only surface later as a confusing
        # logits/labels size error inside the loss of some batch.
        if len(seq) != num_nodes:
            raise ValueError(
                f"sequence length {len(seq)} does not match {num_nodes} coordinate rows")

        # Node features: flatten each residue's backbone points into one vector.
        x = torch.tensor(coord.reshape(num_nodes, -1), dtype=torch.float32)

        # Edges: connect i -> j for every j with 0 < |i - j| <= k. Scanning both
        # sides from every node already yields both directions of each edge
        # exactly once; also appending the reverse edge would duplicate every
        # edge and double neighbour weights in GCN normalisation.
        k = Config.k_neighbors
        src, dst = [], []
        for i in range(num_nodes):
            for j in range(max(0, i - k), min(num_nodes, i + k + 1)):
                if j != i:
                    src.append(i)
                    dst.append(j)
        # Building from [src, dst] keeps shape [2, 0] when there are no edges
        # (single-residue input); PyG batching and GCNConv need the 2-row layout.
        edge_index = torch.tensor([src, dst], dtype=torch.long)

        # Node labels: position of each base in the vocabulary.
        vocab_index = {base: idx for idx, base in enumerate(Config.seq_vocab)}
        try:
            y = torch.tensor([vocab_index[c] for c in seq], dtype=torch.long)
        except KeyError as err:
            raise ValueError(f"unknown nucleotide {err.args[0]!r} in sequence") from err

        return Data(x=x, edge_index=edge_index, y=y, num_nodes=num_nodes)

# Dataset: loads every structure up front and converts it to a graph
class RNADataset(torch.utils.data.Dataset):
    """In-memory dataset of RNA graphs.

    Every regular, non-hidden file ``<id>.<ext>`` in ``coords_dir`` is read with
    ``np.load`` and paired with ``<id>.fasta`` in ``seqs_dir``; only the first
    FASTA record is used.
    """

    def __init__(self, coords_dir, seqs_dir):
        self.samples = []

        # os.listdir order is filesystem-dependent; sorting gives a fixed sample
        # order so a seeded random_split selects the same samples everywhere.
        for fname in sorted(os.listdir(coords_dir)):
            coord_path = os.path.join(coords_dir, fname)
            # Skip hidden files (e.g. macOS .DS_Store, which this archive ships)
            # and sub-directories: np.load cannot read them.
            if fname.startswith(".") or not os.path.isfile(coord_path):
                continue

            # Backbone coordinates, expected shape [L, 7, 3]; NaN entries become 0.
            coord = np.nan_to_num(np.load(coord_path), nan=0.0)

            # Matching sequence. The explicit `with` closes the file as soon as
            # the first record is read instead of relying on garbage collection.
            seq_id = os.path.splitext(fname)[0]
            with open(os.path.join(seqs_dir, f"{seq_id}.fasta")) as handle:
                seq = next(SeqIO.parse(handle, "fasta")).seq

            self.samples.append(RNAGraphBuilder.build_graph(coord, str(seq)))

    def __len__(self):
        """Number of graphs loaded."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the pre-built graph at position ``idx``."""
        return self.samples[idx]

# Simple GNN model: per-residue nucleotide classifier
class GNNModel(nn.Module):
    """Two-layer GCN that predicts a nucleotide class for every node.

    ``forward`` expects a PyG ``Data``/``Batch`` with ``x`` of width
    ``Config.coord_dims * 3`` and ``edge_index`` [2, E]; it returns raw
    (unnormalised) logits of shape [N, len(Config.seq_vocab)].
    """

    def __init__(self):
        super().__init__()
        # Feature encoder: flattened backbone coordinates -> hidden vector.
        # Width comes from Config.coord_dims (points per residue, 3 coords each)
        # instead of a hard-coded 7*3, so the two cannot drift apart.
        self.encoder = nn.Sequential(
            nn.Linear(Config.coord_dims * 3, Config.hidden_dim),
            nn.ReLU()
        )

        # Graph convolution layers
        self.conv1 = GCNConv(Config.hidden_dim, Config.hidden_dim)
        self.conv2 = GCNConv(Config.hidden_dim, Config.hidden_dim)

        # Classification head (kept as a Sequential so state_dict keys
        # `cls_head.0.*` stay compatible with saved checkpoints)
        self.cls_head = nn.Sequential(
            nn.Linear(Config.hidden_dim, len(Config.seq_vocab))
        )

    def forward(self, data):
        """Return per-node logits for the graph(s) in ``data``."""
        # Encode node features
        x = self.encoder(data.x)  # [N, hidden]

        # Message passing over the sequence-neighbour graph
        x = torch.relu(self.conv1(x, data.edge_index))
        x = torch.relu(self.conv2(x, data.edge_index))

        # Node classification
        return self.cls_head(x)  # [N, len(seq_vocab)]

# Training function
def train(model, loader, optimizer, criterion):
    """Run one training epoch over ``loader``.

    Args:
        model: module mapping a batch to per-node logits.
        loader: iterable of PyG batches exposing ``.y`` (one label per node).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking (logits, labels); assumed to use mean reduction
            (the default for ``nn.CrossEntropyLoss``).

    Returns:
        Node-weighted mean loss over the epoch, or 0.0 if ``loader`` is empty.
    """
    model.train()
    total_loss = 0.0
    total_nodes = 0
    for batch in loader:
        batch = batch.to(Config.device)
        optimizer.zero_grad()

        # Forward pass
        logits = model(batch)

        # Loss
        loss = criterion(logits, batch.y)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Batches hold graphs of different lengths, so weight each batch's mean
        # loss by its node count rather than averaging batch means directly.
        num_nodes = batch.y.size(0)
        total_loss += loss.item() * num_nodes
        total_nodes += num_nodes
    return total_loss / total_nodes if total_nodes else 0.0

# Evaluation function
def evaluate(model, loader):
    """Return per-node (per-nucleotide) classification accuracy on ``loader``.

    Returns 0.0 when the loader contains no nodes (e.g. an empty validation or
    test split produced from a very small dataset) instead of raising
    ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(Config.device)
            preds = model(batch).argmax(dim=1)  # predicted class index per node
            correct += (preds == batch.y).sum().item()
            total += batch.y.size(0)
    return correct / total if total else 0.0

# Main workflow
if __name__ == "__main__":
    # Seed the global torch RNG (model initialisation, DataLoader shuffling).
    torch.manual_seed(Config.seed)

    # Load the dataset
    full_dataset = RNADataset("./RNA_design_public/RNAdesignv1/train/coords", "./RNA_design_public/RNAdesignv1/train/seqs")

    # Split: 80% train; the remaining 20% is divided roughly in half between
    # validation and test (test receives the odd sample).
    train_size = int(0.8 * len(full_dataset))
    val_size = (len(full_dataset) - train_size) // 2
    test_size = len(full_dataset) - train_size - val_size
    # A dedicated seeded generator keeps the split stable even if other code
    # consumes the global RNG before this point.
    split_generator = torch.Generator().manual_seed(Config.seed)
    train_set, val_set, test_set = torch.utils.data.random_split(
        full_dataset, [train_size, val_size, test_size], generator=split_generator)

    # Data loaders
    train_loader = torch_geometric.loader.DataLoader(
        train_set, batch_size=Config.batch_size, shuffle=True)
    val_loader = torch_geometric.loader.DataLoader(val_set, batch_size=Config.batch_size)
    test_loader = torch_geometric.loader.DataLoader(test_set, batch_size=Config.batch_size)

    # Model, optimiser, loss
    model = GNNModel().to(Config.device)
    optimizer = optim.Adam(model.parameters(), lr=Config.lr)
    criterion = nn.CrossEntropyLoss()

    # Training loop. Starting below any attainable accuracy guarantees a
    # checkpoint is written in the first epoch, so the final load below never
    # fails or silently picks up a stale file from an earlier run.
    best_acc = float("-inf")
    for epoch in range(Config.epochs):
        train_loss = train(model, train_loader, optimizer, criterion)
        val_acc = evaluate(model, val_loader)

        print(f"Epoch {epoch+1}/{Config.epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Keep the weights with the best validation accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_gnn_model.pth")

    # Final test with the best checkpoint
    model.load_state_dict(
        torch.load("best_gnn_model.pth", weights_only=True, map_location=Config.device))
    test_acc = evaluate(model, test_loader)
    print(f"\nTest Accuracy: {test_acc:.4f}")