In [41]:
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

In [42]:
class SCOPDataset(Dataset):
    """Residue-level protein graphs labelled with their SCOP class.

    Expected layout under ``root``:

    * ``raw/class_info.csv`` with at least the columns ``scop_id`` and
      ``class`` (one of the keys of ``class_mapping``);
    * ``raw/<scop_id>.pdb`` for every row of that CSV;
    * ``processed/data.pt``, written by :meth:`process`, holding a list of
      ``Data`` objects.

    Each graph has one node per residue whose name is a key of
    ``amino_acids``. Node features are a 21-way one-hot of the residue type
    followed by the CA coordinates (24 values). Two residues are connected
    (in both directions) when their CA atoms are closer than the threshold
    used by :meth:`_create_edges`.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        """Set up label/residue vocabularies and initialise the base dataset.

        Args:
            root: Dataset root directory (contains ``raw/`` and ``processed/``).
            transform: Optional per-access transform, forwarded to the base class.
            pre_transform: Optional transform applied once in :meth:`process`.
            pre_filter: Optional predicate; graphs for which it returns False
                are dropped in :meth:`process`.
        """
        # Everything process() reads must exist before the base initializer
        # runs, which is why super().__init__ is called last.
        self.root = root
        self.class_info_path = os.path.join(root, 'raw/class_info.csv')

        # Dictionary to map SCOP classes to indices
        self.class_mapping = {
            'a': 0,  # All-alpha
            'b': 1,  # All-beta
            'c': 2,  # Alpha/beta
            'd': 3,  # Alpha+beta
            'e': 4,  # Multi-domain
            'f': 5,  # Membrane
            'g': 6   # Small proteins
        }

        # Dictionary to map amino acids to indices
        self.amino_acids = {
            'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4,
            'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9,
            'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
            'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19,
            'UNK': 20  # Unknown amino acid
        }

        # Load class information
        if os.path.exists(self.class_info_path):
            self.class_info = pd.read_csv(self.class_info_path)
            print(f"Found class info file with {len(self.class_info)} entries")
        else:
            print(f"Warning: class_info.csv not found at {self.class_info_path}")
            self.class_info = None

        # pre_filter has to be forwarded: the base initializer assigns
        # self.pre_filter from its own argument, so setting the attribute here
        # and omitting it from this call would leave it at the default (None)
        # and the filter in process() would never run.
        # NOTE(review): data.pt files produced before this fix were built
        # without the filter; delete the processed directory to rebuild.
        super().__init__(root, transform, pre_transform, pre_filter)

    def _get_residue_features(self, residue):
        """
        Create feature vector for a residue.

        Args:
            residue: Biopython residue object

        Returns:
            numpy.ndarray: Feature vector of length 24 containing the one-hot
                          encoded amino acid type (21) and the CA coordinates
                          (3; zeros when the residue has no CA atom)
        """
        # One-hot encode amino acid type; unrecognised names map to UNK
        aa_features = np.zeros(len(self.amino_acids))
        aa_index = self.amino_acids.get(residue.get_resname(),
                                        self.amino_acids['UNK'])
        aa_features[aa_index] = 1

        # Get CA atom coordinates; only a missing atom is tolerated here,
        # any other failure should surface to the caller.
        try:
            coords = residue['CA'].get_coord()
        except KeyError:
            coords = np.zeros(3)

        # Amino acid identity (21) followed by 3D coordinates (3)
        return np.concatenate([aa_features, coords])

    def _create_edges(self, residues, threshold=5.0):
        """
        Create edges between residues based on spatial proximity.

        Args:
            residues: Iterable of Biopython residue objects
            threshold (float): CA-CA distance below which two residues are
                connected

        Returns:
            list: ``[i, j]`` index pairs; every undirected edge appears twice,
                  as ``[i, j]`` immediately followed by ``[j, i]``
        """
        # Look each CA position up once instead of once per residue pair.
        # Residues without a CA atom keep their index but get no edges.
        ca_coords = [res['CA'].get_coord() if 'CA' in res else None
                     for res in residues]

        edges = []
        for i, ca_i in enumerate(ca_coords):
            if ca_i is None:
                continue
            for j in range(i + 1, len(ca_coords)):
                ca_j = ca_coords[j]
                if ca_j is None:
                    continue
                if np.linalg.norm(ca_i - ca_j) < threshold:
                    edges.append([i, j])
                    edges.append([j, i])  # Add both directions

        return edges

    @property
    def raw_file_names(self):
        """List of ``.pdb`` file names found in ``<root>/raw`` (empty if absent)."""
        raw_dir = os.path.join(self.root, 'raw')
        if not os.path.exists(raw_dir):
            return []
        return [f for f in os.listdir(raw_dir) if f.endswith('.pdb')]

    @property
    def processed_file_names(self):
        """List of processed file names in the dataset."""
        return ['data.pt']

    def download(self):
        """Download the dataset (no-op: the raw files are expected on disk)."""
        pass  # We already have the files

    def _build_graph(self, parser, pdb_path, pdb_file, class_label):
        """Convert one PDB file into a ``Data`` graph.

        Returns None, after printing a warning, when the structure has no
        recognised residues or no residue pair within the edge threshold.
        """
        structure = parser.get_structure('protein', pdb_path)
        model = structure[0]  # Get first model

        # Only residues whose name is in the vocabulary become nodes
        residues = [res for res in model.get_residues()
                    if res.get_resname() in self.amino_acids]
        if not residues:
            print(f"Warning: No valid residues found in {pdb_file}")
            return None

        edges = self._create_edges(residues)
        if not edges:
            print(f"Warning: No edges created for {pdb_file}")
            return None

        # Stack into one ndarray first; building a tensor directly from a
        # list of ndarrays is very slow.
        node_features = np.array(
            [self._get_residue_features(res) for res in residues])

        return Data(
            x=torch.tensor(node_features, dtype=torch.float),
            edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(),
            y=torch.tensor([class_label], dtype=torch.long),
            num_nodes=len(residues)
        )

    def process(self):
        """Process the raw data into the internal format.

        Builds one graph per row of ``class_info.csv`` and saves the list to
        ``<processed_dir>/data.pt``. Rows whose PDB file is missing, whose
        class is unknown, that yield no residues/edges, that are rejected by
        ``pre_filter`` or that raise during parsing are skipped with a
        printed message.

        Raises:
            RuntimeError: If the raw directory, the PDB files or
                ``class_info.csv`` are missing, or if no protein could be
                processed.
        """
        raw_dir = os.path.join(self.root, 'raw')
        if not os.path.exists(raw_dir):
            raise RuntimeError(f"Raw directory not found at {raw_dir}")

        raw_files = self.raw_file_names
        print(f"Found {len(raw_files)} PDB files in raw directory")

        if not raw_files:
            raise RuntimeError("No PDB files found in raw directory")

        if not os.path.exists(self.class_info_path):
            raise RuntimeError(f"class_info.csv not found at {self.class_info_path}")

        # Create processed directory if it doesn't exist
        os.makedirs(self.processed_dir, exist_ok=True)

        # Load class information
        self.class_info = pd.read_csv(self.class_info_path)
        print(f"Processing {len(self.class_info)} entries from class_info.csv")

        data_list = []
        parser = PDB.PDBParser(QUIET=True)

        # Process each PDB file
        for idx, row in self.class_info.iterrows():
            pdb_id = str(row['scop_id'])
            pdb_file = f"{pdb_id}.pdb"
            pdb_path = os.path.join(raw_dir, pdb_file)

            # A single bad label should not abort the whole run; treat it
            # like every other per-protein problem and skip the entry.
            class_label = self.class_mapping.get(row['class'])
            if class_label is None:
                print(f"Warning: Unknown SCOP class {row['class']!r} "
                      f"for {pdb_id}, skipping")
                continue

            if not os.path.exists(pdb_path):
                print(f"Warning: PDB file not found: {pdb_path}")
                continue

            # Deliberate best-effort boundary: a malformed structure is
            # reported and skipped so the remaining proteins still get built.
            try:
                data = self._build_graph(parser, pdb_path, pdb_file, class_label)
                if data is None:
                    continue

                if self.pre_filter is not None and not self.pre_filter(data):
                    print(f"Warning: {pdb_file} filtered out by pre_filter")
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

                if idx % 10 == 0:
                    print(f"Processed {idx+1}/{len(self.class_info)} proteins")

            except Exception as e:
                print(f"Error processing {pdb_id}: {str(e)}")
                continue

        if not data_list:
            raise RuntimeError("No data was successfully processed!")

        print(f"Successfully processed {len(data_list)} proteins")

        # Save processed data
        torch.save(data_list, os.path.join(self.processed_dir, 'data.pt'))

    def _load_data_list(self):
        """Return the list of graphs, loading ``data.pt`` on first use.

        If the processed file is missing, :meth:`process` is run first. The
        list is cached on the instance as ``_data_list``.
        """
        if not hasattr(self, '_data_list'):
            processed_path = os.path.join(self.processed_dir, 'data.pt')
            if not os.path.exists(processed_path):
                print("Warning: Processed data file not found, running processing...")
                self.process()
            # weights_only=False is required to load PyG Data objects.
            # NOTE: this unpickles arbitrary objects - only load files that
            # were produced by process() or otherwise come from a trusted source.
            self._data_list = torch.load(processed_path, weights_only=False)
        return self._data_list

    def get(self, idx):
        """Get a specific graph from the dataset."""
        return self._load_data_list()[idx]

    def len(self):
        """Return the number of graphs in the dataset."""
        return len(self._load_data_list())

    @property
    def num_classes(self):
        """Return the number of classes in the dataset."""
        return len(self.class_mapping)

    @property
    def num_features(self):
        """Return the number of node features (21 residue types + 3 coordinates)."""
        return len(self.amino_acids) + 3

In [43]:
import os.path as osp
import time
from math import ceil

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.loader import DenseDataLoader
from torch_geometric.nn import DenseSAGEConv, dense_diff_pool

NUM_CLASSES = 7  # one output per top-level SCOP class

# Look at the size distribution of the already-processed proteins so that
# max_nodes (the dense padding size) can be chosen to cover every graph.
processed_path = osp.join(data_dir, 'processed/data.pt')
data_list = torch.load(processed_path, weights_only=False)
sizes = [data.num_nodes for data in data_list]

print(
    "Protein size statistics:",
    f"Min size: {min(sizes)}",
    f"Max size: {max(sizes)}",
    f"Mean size: {sum(sizes)/len(sizes):.1f}",
    # Upper median: the element at index len // 2 of the sorted sizes
    f"Median size: {sorted(sizes)[len(sizes)//2]}",
    f"Number of proteins > 150 residues: {sum(s > 150 for s in sizes)}",
    sep="\n",
)

# Pad every graph to 1400 nodes; per the statistics printed above, the
# largest processed protein has just under that many residues.
max_nodes = 1400

# Build the dataset: ToDense pads x/adj to max_nodes on access, and the
# pre_filter rejects anything that would not fit.
dataset = SCOPDataset(
    root=data_dir,
    transform=T.ToDense(max_nodes),
    pre_filter=lambda data: data.num_nodes <= max_nodes
)

print(
    f"\nDataset size: {len(dataset)}",
    f"Number of features: {dataset.num_features}",
    f"Number of classes: {dataset.num_classes}",
    sep="\n",
)

# Sanity-check one densified sample
sample = dataset[0]
print(
    "\nSample information:",
    f"Node features shape: {sample.x.shape}",
    f"Adjacency matrix shape: {sample.adj.shape}",
    f"Mask shape: {sample.mask.shape}",
    f"Label: {sample.y}",
    sep="\n",
)

Protein size statistics:
Min size: 20
Max size: 1395
Mean size: 202.8
Median size: 165
Number of proteins > 150 residues: 1859
Found class info file with 3500 entries

Dataset size: 3420
Number of features: 24
Number of classes: 7

Sample information:
Node features shape: torch.Size([1400, 24])
Adjacency matrix shape: torch.Size([1400, 1400])
Mask shape: torch.Size([1400])
Label: tensor([0])


In [44]:
# Shuffle once, then carve off the first tenth (rounded up) for testing, the
# second tenth for validation and keep the remainder for training.
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset, val_dataset, train_dataset = (
    dataset[:n],
    dataset[n:2 * n],
    dataset[2 * n:],
)

# One dense loader per split, all with the same batch size
test_loader, val_loader, train_loader = (
    DenseDataLoader(split, batch_size=20)
    for split in (test_dataset, val_dataset, train_dataset)
)

print(
    "\nDataset splits:",
    f"Training samples: {len(train_dataset)}",
    f"Validation samples: {len(val_dataset)}",
    f"Test samples: {len(test_dataset)}",
    sep="\n",
)


Dataset splits:
Training samples: 2736
Validation samples: 342
Test samples: 342


In [45]:
class GNN(torch.nn.Module):
    """Stack of three dense GraphSAGE layers, each followed by ReLU and batch norm.

    The outputs of all three layers are concatenated along the channel axis
    (``2 * hidden_channels + out_channels`` values per node). With
    ``lin=True`` a final linear layer plus ReLU maps that concatenation to
    ``out_channels``; otherwise the concatenation is returned as is.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True):
        super().__init__()

        # First two layers work at hidden width, the third at output width
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)
        self.bn3 = torch.nn.BatchNorm1d(out_channels)

        # Optional projection of the concatenated layer outputs
        concat_width = 2 * hidden_channels + out_channels
        self.lin = (torch.nn.Linear(concat_width, out_channels)
                    if lin is True else None)

    def bn(self, i, x):
        """Apply batch-norm layer ``bn{i}`` to a (batch, nodes, channels) tensor.

        BatchNorm1d is fed a 2D view, so batch and node axes are merged for
        the call and restored afterwards.
        """
        n_graphs, n_nodes, n_channels = x.size()
        norm_layer = getattr(self, f'bn{i}')
        flat = norm_layer(x.view(-1, n_channels))
        return flat.view(n_graphs, n_nodes, n_channels)

    def forward(self, x, adj, mask=None):
        """Run the three conv/ReLU/batch-norm stages and combine their outputs."""
        # Unpacking doubles as a check that x is three-dimensional
        _n_graphs, _n_nodes, _n_channels = x.size()

        hidden = x
        stage_outputs = []
        for stage in (1, 2, 3):
            conv = getattr(self, f'conv{stage}')
            hidden = self.bn(stage, conv(hidden, adj, mask).relu())
            stage_outputs.append(hidden)

        combined = torch.cat(stage_outputs, dim=-1)

        if self.lin is None:
            return combined
        return self.lin(combined).relu()


class Net(torch.nn.Module):
    """Two-level DiffPool graph classifier.

    Every pooling level pairs a GNN that produces cluster assignments with a
    GNN that produces node embeddings; ``dense_diff_pool`` coarsens the graph
    from both. After the second pooling a last embedding GNN runs, its node
    outputs are averaged and a two-layer head produces class log-probabilities.

    Sizes are taken from the module-level ``max_nodes``, ``dataset`` and
    ``NUM_CLASSES``.
    """

    def __init__(self):
        super().__init__()

        hidden = 64
        # Width of an embedding GNN's output: three concatenated layer outputs
        embed_width = 3 * hidden
        in_features = dataset.num_features

        # Level 1: reduce to a quarter of the padded node count
        n_clusters = ceil(0.25 * max_nodes)
        self.gnn1_pool = GNN(in_features, hidden, n_clusters)
        self.gnn1_embed = GNN(in_features, hidden, hidden, lin=False)

        # Level 2: reduce to a quarter again
        n_clusters = ceil(0.25 * n_clusters)
        self.gnn2_pool = GNN(embed_width, hidden, n_clusters)
        self.gnn2_embed = GNN(embed_width, hidden, hidden, lin=False)

        self.gnn3_embed = GNN(embed_width, hidden, hidden, lin=False)

        # Classification head
        self.lin1 = torch.nn.Linear(embed_width, hidden)
        self.lin2 = torch.nn.Linear(hidden, NUM_CLASSES)

    def forward(self, x, adj, mask=None):
        """Return ``(log_probs, l1 + l2, e1 + e2)``.

        The last two entries are the sums, over both pooling levels, of the
        third and fourth values returned by ``dense_diff_pool`` (presumably
        its auxiliary regularisation losses - confirm against the PyG docs).
        """
        # Level 1 works on the padded input, so the node mask is needed
        assignment = self.gnn1_pool(x, adj, mask)
        embedding = self.gnn1_embed(x, adj, mask)
        x, adj, l1, e1 = dense_diff_pool(embedding, adj, assignment, mask)

        # Level 2 works on the pooled graph and is called without a mask
        assignment = self.gnn2_pool(x, adj)
        embedding = self.gnn2_embed(x, adj)
        x, adj, l2, e2 = dense_diff_pool(embedding, adj, assignment)

        x = self.gnn3_embed(x, adj)

        # Graph-level readout: average over the remaining nodes
        graph_repr = x.mean(dim=1)
        logits = self.lin2(self.lin1(graph_repr).relu())
        return F.log_softmax(logits, dim=-1), l1 + l2, e1 + e2


# Pick the best available backend: CUDA first, then Apple MPS (only present
# in torch builds that expose torch.backends.mps), otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

model = Net()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``device``. Only the
    classification loss is optimised; the two extra values returned by the
    model are discarded. ``epoch`` is accepted but not used.

    Returns:
        The mean NLL loss per training sample for this pass.
    """
    model.train()
    weighted_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        log_probs, _, _ = model(batch.x, batch.adj, batch.mask)
        batch_loss = F.nll_loss(log_probs, batch.y.view(-1))
        batch_loss.backward()
        # Weight by batch size so a smaller final batch is averaged correctly
        weighted_loss += batch.y.size(0) * float(batch_loss)
        optimizer.step()

    return weighted_loss / len(train_dataset)


@torch.no_grad()
def test(loader):
    """Return the classification accuracy of the global ``model`` on ``loader``.

    Accuracy is the number of correct predictions divided by
    ``len(loader.dataset)``.
    """
    model.eval()
    n_correct = 0

    for batch in loader:
        batch = batch.to(device)
        log_probs = model(batch.x, batch.adj, batch.mask)[0]
        # Predicted class = index of the largest log-probability per graph
        _, predicted = log_probs.max(dim=1)
        n_correct += int(predicted.eq(batch.y.view(-1)).sum())

    return n_correct / len(loader.dataset)


# Train for 150 epochs. The test accuracy shown is the one measured at the
# epoch with the best validation accuracy so far, so it only changes when
# validation improves.
best_val_acc = 0
test_acc = 0
times = []
for epoch in range(1, 151):
    start = time.time()
    train_loss = train(epoch)
    val_acc = test(val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = test(test_loader)
    print(
        f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
        f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}'
    )
    # Epoch time includes training, evaluation and the progress print
    times.append(time.time() - start)

median_epoch_time = torch.tensor(times).median()
print(f"Median time per epoch: {median_epoch_time:.4f}s")


Epoch: 001, Train Loss: 1.8124, Val Acc: 0.3012, Test Acc: 0.2749
Epoch: 002, Train Loss: 1.7753, Val Acc: 0.3129, Test Acc: 0.3070
Epoch: 003, Train Loss: 1.7316, Val Acc: 0.3246, Test Acc: 0.3187
Epoch: 004, Train Loss: 1.7122, Val Acc: 0.3158, Test Acc: 0.3187
Epoch: 005, Train Loss: 1.6721, Val Acc: 0.3947, Test Acc: 0.3246
Epoch: 006, Train Loss: 1.6846, Val Acc: 0.4211, Test Acc: 0.3684
Epoch: 007, Train Loss: 1.6309, Val Acc: 0.3801, Test Acc: 0.3684
Epoch: 008, Train Loss: 1.6190, Val Acc: 0.4094, Test Acc: 0.3684
Epoch: 009, Train Loss: 1.5975, Val Acc: 0.3947, Test Acc: 0.3684
Epoch: 010, Train Loss: 1.5897, Val Acc: 0.4064, Test Acc: 0.3684
Epoch: 011, Train Loss: 1.5720, Val Acc: 0.3918, Test Acc: 0.3684
Epoch: 012, Train Loss: 1.5604, Val Acc: 0.4181, Test Acc: 0.3684
Epoch: 013, Train Loss: 1.5359, Val Acc: 0.4006, Test Acc: 0.3684
Epoch: 014, Train Loss: 1.5303, Val Acc: 0.3567, Test Acc: 0.3684
Epoch: 015, Train Loss: 1.5425, Val Acc: 0.4386, Test Acc: 0.3246
Epoch: 016

In [40]:
import torch
import numpy as np
from Bio import PDB

def analyze_protein_connectivity(pdb_path, cutoff=5.0):
    """Count CA-CA contacts in the first model of a PDB file.

    Args:
        pdb_path: Path to the PDB file to parse.
        cutoff (float): Two residues count as connected when their CA atoms
            are closer than this distance.

    Returns:
        dict with keys:
            ``num_residues``: number of residues in the first model (all
                residues, including those without a CA atom);
            ``num_edges``: number of unordered residue pairs within ``cutoff``;
            ``density``: ``num_edges`` divided by the number of possible
                residue pairs (0 when there are fewer than two residues);
            ``ca_ca_edges``: same value as ``num_edges`` (every edge counted
                here is a CA-CA contact; the key is kept for existing callers).
    """
    # Parse structure
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)
    model = structure[0]

    # Get residues
    residues = list(model.get_residues())
    n = len(residues)

    # Fetch each CA position once rather than once per residue pair;
    # residues without a CA atom cannot take part in any contact.
    ca_coords = [res['CA'].get_coord() for res in residues if 'CA' in res]

    # Count every unordered pair whose CA atoms are within the cutoff
    edges = sum(
        1
        for i, ca_i in enumerate(ca_coords)
        for ca_j in ca_coords[i + 1:]
        if np.linalg.norm(ca_i - ca_j) < cutoff
    )

    # Density is relative to all residues, not only those with a CA atom
    max_possible_edges = (n * (n-1)) / 2
    density = edges / max_possible_edges if max_possible_edges > 0 else 0

    return {
        'num_residues': n,
        'num_edges': edges,
        'density': density,
        'ca_ca_edges': edges
    }

# Report connectivity statistics at the 5.0 cutoff for every file listed
# in pdb_files.
for pdb_file in pdb_files:
    pdb_path = os.path.join('data/SCOP/raw', pdb_file)
    stats = analyze_protein_connectivity(pdb_path, cutoff=5.0)
    print(
        f"\nProtein: {pdb_file}",
        f"Number of residues: {stats['num_residues']}",
        f"Number of edges: {stats['num_edges']}",
        f"Edge density: {stats['density']:.3f}",
        f"CA-CA edges: {stats['ca_ca_edges']}",
        sep="\n",
    )




Protein: 400870.pdb
Number of residues: 128
Number of edges: 166
Edge density: 0.020
CA-CA edges: 166

Protein: 44527.pdb
Number of residues: 56
Number of edges: 64
Edge density: 0.042
CA-CA edges: 64

Protein: 181196.pdb
Number of residues: 106
Number of edges: 119
Edge density: 0.021
CA-CA edges: 119

Protein: 243737.pdb
Number of residues: 87
Number of edges: 101
Edge density: 0.027
CA-CA edges: 101

Protein: 271949.pdb
Number of residues: 405
Number of edges: 482
Edge density: 0.006
CA-CA edges: 482
