In [7]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # Changed from DenseSAGEConv
from torch_geometric.nn import global_mean_pool  # For sparse pooling
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import warnings
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore')

# Define data directory
data_dir = 'data/SCOP'  # Base directory for SCOP data

# Rest of the model code remains the same...

In [8]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

class SparseSCOPDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # Set up the root path first
        self.root = root

        # Now we can set up the class_info_path
        self.class_info_path = os.path.join(root, 'raw/class_info.csv')

        # Dictionary to map SCOP classes to indices
        self.class_mapping = {
            'a': 0,  # All-alpha
            'b': 1,  # All-beta
            'c': 2,  # Alpha/beta
            'd': 3,  # Alpha+beta
            'e': 4,  # Multi-domain
            'f': 5,  # Membrane
            'g': 6   # Small proteins
        }

        # Dictionary to map amino acids to indices
        self.amino_acids = {
            'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4,
            'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9,
            'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
            'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19,
            'UNK': 20  # Unknown amino acid
        }

        # Load class information
        if os.path.exists(self.class_info_path):
            self.class_info = pd.read_csv(self.class_info_path)
            print(f"Found class info file with {len(self.class_info)} entries")
        else:
            print(f"Warning: class_info.csv not found at {self.class_info_path}")
            self.class_info = None

        # Initialize the base class last
        super().__init__(root, transform, pre_transform, pre_filter)

    @property  # This decorator was missing!
    def raw_file_names(self):
        if not os.path.exists(os.path.join(self.root, 'raw')):
            return []
        return [f for f in os.listdir(os.path.join(self.root, 'raw'))
                if f.endswith('.pdb')]

    @property  # This decorator was missing!
    def processed_file_names(self):
        return ['data.pt']


    @property
    def raw_file_names(self):
        """List of raw file names in the dataset."""
        if not os.path.exists(os.path.join(self.root, 'raw')):
            return []
        return [f for f in os.listdir(os.path.join(self.root, 'raw'))
                if f.endswith('.pdb')]

    @property
    def processed_file_names(self):
        """List of processed file names in the dataset."""
        return ['data.pt']

    def download(self):
        """Download the dataset."""
        pass  # We already have the files

    def process(self):
        """Process the raw data into the internal format."""
        data_list = []
        parser = PDB.PDBParser(QUIET=True)

        for idx, row in self.class_info.iterrows():
            pdb_id = str(row['scop_id'])
            class_label = self.class_mapping[row['class']]

            try:
                # Load structure
                pdb_file = f"{pdb_id}.pdb"
                pdb_path = os.path.join(self.root, 'raw', pdb_file)
                structure = parser.get_structure('protein', pdb_path)
                model = structure[0]

                # Get residues and create features
                residues = list(model.get_residues())

                # Create node features
                node_features = []
                for residue in residues:
                    features = self._get_residue_features(residue)
                    node_features.append(features)

                # Create edges with 5Å cutoff
                edges = []
                for i in range(len(residues)):
                    for j in range(i+1, len(residues)):
                        if 'CA' in residues[i] and 'CA' in residues[j]:
                            ca_i = residues[i]['CA'].get_coord()
                            ca_j = residues[j]['CA'].get_coord()
                            dist = np.linalg.norm(ca_i - ca_j)
                            if dist < 5.0:  # 5Å cutoff
                                edges.append([i, j])
                                edges.append([j, i])  # Add both directions

                if len(edges) == 0:
                    continue

                # Create PyG Data object
                edge_index = torch.tensor(edges, dtype=torch.long).t()
                x = torch.tensor(node_features, dtype=torch.float)
                y = torch.tensor([class_label], dtype=torch.long)

                data = Data(
                    x=x,
                    edge_index=edge_index,
                    y=y,
                    num_nodes=len(residues)
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            except Exception as e:
                print(f"Error processing {pdb_id}: {str(e)}")
                continue

        if len(data_list) == 0:
            raise RuntimeError("No data was successfully processed!")

        torch.save(data_list, os.path.join(self.processed_dir, 'data.pt'))

    def get(self, idx):
        """Get a specific graph from the dataset."""
        if not hasattr(self, '_data_list'):
            processed_path = os.path.join(self.processed_dir, 'data.pt')
            if not os.path.exists(processed_path):
                print("Warning: Processed data file not found, running processing...")
                self.process()
            # Add weights_only=False to allow loading PyG Data objects
            self._data_list = torch.load(processed_path, weights_only=False)
        return self._data_list[idx]

    def len(self):
        """Return the number of graphs in the dataset."""
        processed_path = os.path.join(self.processed_dir, 'data.pt')
        if not os.path.exists(processed_path):
            print("Warning: Processed data file not found, running processing...")
            self.process()
        if not hasattr(self, '_data_list'):
            # Add weights_only=False here as well
            self._data_list = torch.load(processed_path, weights_only=False)
        return len(self._data_list)

    def _get_residue_features(self, residue):
        """Create feature vector for a residue."""
        # One-hot encode amino acid type
        aa_features = np.zeros(21)  # 20 standard amino acids + UNK
        aa_name = residue.get_resname()
        if aa_name in self.amino_acids:
            aa_features[self.amino_acids[aa_name]] = 1
        else:
            aa_features[self.amino_acids['UNK']] = 1

        # Get CA atom coordinates
        try:
            ca_atom = residue['CA']
            coords = ca_atom.get_coord()
        except:
            coords = np.zeros(3)

        # Combine features
        features = np.concatenate([
            aa_features,  # Amino acid identity (21)
            coords,      # 3D coordinates (3)
        ])

        return features

    @property
    def num_classes(self):
        """Return the number of classes in the dataset."""
        return len(self.class_mapping)

    @property
    def num_features(self):
        """Return the number of node features."""
        return 24  # 21 for amino acids + 3 for coordinates

In [9]:
import os.path as osp
import time
from math import ceil

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.loader import DenseDataLoader

NUM_CLASSES = 7  # SCOP main classes
# First, let's check protein sizes
# First, analyze protein sizes
processed_path = os.path.join(data_dir, 'processed/data.pt')
data_list = torch.load(processed_path, weights_only=False)
sizes = [data.num_nodes for data in data_list]

print(f"Protein size statistics:")
print(f"Min size: {min(sizes)}")
print(f"Max size: {max(sizes)}")
print(f"Mean size: {sum(sizes)/len(sizes):.1f}")
print(f"Median size: {sorted(sizes)[len(sizes)//2]}")
print(f"Number of proteins > 150 residues: {sum(1 for s in sizes if s > 150)}")

# Set max_nodes to 500 since we have proteins up to 1381 residues
max_nodes = 1400

# Create dataset
dataset = SparseSCOPDataset(
    root=data_dir,
    pre_filter=lambda data: data.num_nodes <= max_nodes
)

print(f"\nDataset size: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")


Protein size statistics:
Min size: 20
Max size: 1395
Mean size: 202.8
Median size: 165
Number of proteins > 150 residues: 1859
Found class info file with 3500 entries

Dataset size: 3420
Number of features: 24
Number of classes: 7


In [10]:
from torch_geometric.loader import DataLoader


# Prepare data loaders
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]

# Replace your existing DataLoader with this
train_loader = DataLoader(train_dataset,  batch_size=20)
val_loader = DataLoader(val_dataset,  batch_size=20)
test_loader = DataLoader(test_dataset,  batch_size=20)

print("\nDataset splits:")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")


Dataset splits:
Training samples: 2736
Validation samples: 342
Test samples: 342


In [11]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool

class GATNetwork(torch.nn.Module):
    def __init__(self, num_features, hidden_dim=64, num_classes=7, heads=4):
        super().__init__()

        # GAT layers with multiple attention heads
        self.conv1 = GATConv(num_features, hidden_dim, heads=heads, concat=True)
        self.conv2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True)
        self.conv3 = GATConv(hidden_dim * heads, hidden_dim, heads=1, concat=False)

        # Batch normalization
        self.bn1 = torch.nn.BatchNorm1d(hidden_dim * heads)
        self.bn2 = torch.nn.BatchNorm1d(hidden_dim * heads)
        self.bn3 = torch.nn.BatchNorm1d(hidden_dim)

        # Classification head
        self.lin1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, num_classes)

        # Dropout
        self.dropout = torch.nn.Dropout(0.3)

    def forward(self, x, edge_index, batch):
        # First GAT layer
        h1 = self.conv1(x, edge_index)
        h1 = self.bn1(h1)
        h1 = F.elu(h1)
        h1 = self.dropout(h1)

        # Second GAT layer
        h2 = self.conv2(h1, edge_index)
        h2 = self.bn2(h2)
        h2 = F.elu(h2)
        h2 = self.dropout(h2)

        # Third GAT layer
        h3 = self.conv3(h2, edge_index)
        h3 = self.bn3(h3)
        h3 = F.elu(h3)
        h3 = self.dropout(h3)

        # Global pooling
        out = global_mean_pool(h3, batch)

        # Classification head
        out = self.lin1(out)
        out = F.elu(out)
        out = self.dropout(out)
        out = self.lin2(out)

        return F.log_softmax(out, dim=-1)


def train(model, train_loader, optimizer, device):
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(output, data.y.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.num_graphs

    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(model, loader, device):
    model.eval()
    correct = 0
    total = 0

    for data in loader:
        data = data.to(device)
        pred = model(data.x, data.edge_index, data.batch).max(dim=1)[1]
        correct += int(pred.eq(data.y.view(-1)).sum())
        total += data.num_graphs

    return correct / total


def main(dataset, train_loader, val_loader, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = GATNetwork(
        num_features=dataset.num_features,
        hidden_dim=64,
        num_classes=dataset.num_classes
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.001,
        weight_decay=1e-4
    )

    best_val_acc = 0
    test_acc = 0

    for epoch in range(1, 151):
        train_loss = train(model, train_loader, optimizer, device)
        val_acc = test(model, val_loader, device)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = test(model, test_loader, device)

        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
              f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

    return model, best_val_acc, test_acc

# Usage remains the same as previous implementation
# main(dataset, train_loader, val_loader, test_loader)

In [12]:
# Prepare data loaders (as you had before)
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]

# Use DataLoader for sparse graphs
train_loader = DataLoader(train_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
test_loader = DataLoader(test_dataset, batch_size=20)

# Train the model
model, best_val_acc, test_acc = main(dataset, train_loader, val_loader, test_loader)

Epoch: 001, Train Loss: 1.9287, Val Acc: 0.2339, Test Acc: 0.2222
Epoch: 002, Train Loss: 1.9037, Val Acc: 0.2515, Test Acc: 0.2515
Epoch: 003, Train Loss: 1.8958, Val Acc: 0.2778, Test Acc: 0.2456
Epoch: 004, Train Loss: 1.8930, Val Acc: 0.2661, Test Acc: 0.2456
Epoch: 005, Train Loss: 1.8838, Val Acc: 0.2632, Test Acc: 0.2456
Epoch: 006, Train Loss: 1.8802, Val Acc: 0.2953, Test Acc: 0.2485
Epoch: 007, Train Loss: 1.7904, Val Acc: 0.3684, Test Acc: 0.3304
Epoch: 008, Train Loss: 1.6159, Val Acc: 0.3713, Test Acc: 0.3860
Epoch: 009, Train Loss: 1.5791, Val Acc: 0.4152, Test Acc: 0.3977
Epoch: 010, Train Loss: 1.5431, Val Acc: 0.4269, Test Acc: 0.4298
Epoch: 011, Train Loss: 1.5104, Val Acc: 0.3684, Test Acc: 0.4298
Epoch: 012, Train Loss: 1.4942, Val Acc: 0.3860, Test Acc: 0.4298
Epoch: 013, Train Loss: 1.4840, Val Acc: 0.4064, Test Acc: 0.4298
Epoch: 014, Train Loss: 1.4586, Val Acc: 0.3977, Test Acc: 0.4298
Epoch: 015, Train Loss: 1.4322, Val Acc: 0.4152, Test Acc: 0.4298
Epoch: 016

In [13]:
# Assuming your dataset is loaded
filtered_dataset = [data for data in dataset if data.num_nodes < 300]

# If you want to create a new dataset object
from torch_geometric.data import Dataset

class FilteredSCOPDataset(Dataset):
    def __init__(self, original_dataset):
        self.data_list = [data for data in original_dataset if data.num_nodes < 300]
        super().__init__(original_dataset.root)

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

# Create the filtered dataset
filtered_dataset = FilteredSCOPDataset(dataset)

# Verify the filtering
print(f"Original dataset size: {len(dataset)}")
print(f"Filtered dataset size: {len(filtered_dataset)}")

# Optional: Check distribution across classes
class_distribution = {}
for data in filtered_dataset:
    class_label = data.y.item()
    class_distribution[class_label] = class_distribution.get(class_label, 0) + 1

print("\nClass distribution in filtered dataset:")
for cls, count in class_distribution.items():
    print(f"Class {cls}: {count} proteins")

Original dataset size: 3420
Filtered dataset size: 2717

Class distribution in filtered dataset:
Class 0: 428 proteins
Class 6: 500 proteins
Class 2: 368 proteins
Class 3: 382 proteins
Class 1: 449 proteins
Class 4: 172 proteins
Class 5: 418 proteins


In [15]:
# Create the filtered dataset
from torch_geometric.loader import DataLoader

filtered_dataset = FilteredSCOPDataset(dataset)

# Shuffle and split the filtered dataset
filtered_dataset = filtered_dataset.shuffle()
n = (len(filtered_dataset) + 9) // 10
test_dataset = filtered_dataset[:n]
val_dataset = filtered_dataset[n:2 * n]
train_dataset = filtered_dataset[2 * n:]

# Create data loaders using the filtered datasets
train_loader = DataLoader(train_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
test_loader = DataLoader(test_dataset, batch_size=20)

# Print dataset sizes
print("\nFiltered Dataset splits:")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# The rest of your training loop remains the same
# You can use these loaders directly in your existing training script
print("\nDataset splits:")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

model, best_val_acc, test_acc = main(dataset, train_loader, val_loader, test_loader)


Filtered Dataset splits:
Training samples: 2173
Validation samples: 272
Test samples: 272

Dataset splits:
Training samples: 2173
Validation samples: 272
Test samples: 272
Epoch: 001, Train Loss: 1.8997, Val Acc: 0.2243, Test Acc: 0.1618
Epoch: 002, Train Loss: 1.8608, Val Acc: 0.2978, Test Acc: 0.2096
Epoch: 003, Train Loss: 1.8426, Val Acc: 0.3125, Test Acc: 0.2243
Epoch: 004, Train Loss: 1.8319, Val Acc: 0.3051, Test Acc: 0.2243
Epoch: 005, Train Loss: 1.7876, Val Acc: 0.3566, Test Acc: 0.2978
Epoch: 006, Train Loss: 1.6013, Val Acc: 0.4669, Test Acc: 0.3419
Epoch: 007, Train Loss: 1.5338, Val Acc: 0.4522, Test Acc: 0.3419
Epoch: 008, Train Loss: 1.5132, Val Acc: 0.4890, Test Acc: 0.4522
Epoch: 009, Train Loss: 1.4798, Val Acc: 0.4890, Test Acc: 0.4522
Epoch: 010, Train Loss: 1.4786, Val Acc: 0.4743, Test Acc: 0.4522
Epoch: 011, Train Loss: 1.4585, Val Acc: 0.5037, Test Acc: 0.4154
Epoch: 012, Train Loss: 1.4297, Val Acc: 0.5110, Test Acc: 0.4412
Epoch: 013, Train Loss: 1.4306, Val