In [43]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # Changed from DenseSAGEConv
from torch_geometric.nn import global_mean_pool  # For sparse pooling
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import warnings
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import pandas as pd
from itertools import groupby
from collections import Counter
import warnings
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore')

# Define data directory
data_dir = 'data/SCOP'  # Base directory for SCOP data

# Rest of the model code remains the same...

In [44]:
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
from Bio.PDB import NeighborSearch, Selection
import pandas as pd
import warnings
import networkx as nx
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional
warnings.filterwarnings('ignore')

class SparseSCOPDataset(Dataset):

    def __init__(self, root: str, transform=None, pre_transform=None, pre_filter=None):
        self.root = root
        self.class_info_path = os.path.join(root, 'raw/class_info.csv')

        # Add normalization parameters
        self.coord_mean: Optional[float] = None
        self.coord_std: Optional[float] = None
        self.mass_mean: Optional[float] = None
        self.mass_std: Optional[float] = None
        self.dist_mean: Optional[float] = None
        self.dist_std: Optional[float] = None
        self.count_mean: Optional[float] = None
        self.count_std: Optional[float] = None

        # Feature indices for easy access
        self.feature_indices = {
            'aa_onehot': slice(0, 21),
            'coords': slice(21, 24),
            'mass': 24,
            'avg_dist': 25,
            'max_dist': 26,
            'neighbor_count': 27
        }

        # Dictionary to map SCOP classes to indices
        self.class_mapping = {
            'a': 0,  # All-alpha
            'b': 1,  # All-beta
            'c': 2,  # Alpha/beta
            'd': 3,  # Alpha+beta
            'e': 4,  # Multi-domain
            'f': 5,  # Membrane
            'g': 6   # Small proteins
        }

        # Dictionary to map amino acids to indices and their properties
        self.amino_acids = {
            'ALA': {'index': 0, 'mass': 89.1, 'name': 'Alanine'},
            'ARG': {'index': 1, 'mass': 174.2, 'name': 'Arginine'},
            'ASN': {'index': 2, 'mass': 132.1, 'name': 'Asparagine'},
            'ASP': {'index': 3, 'mass': 133.1, 'name': 'Aspartic Acid'},
            'CYS': {'index': 4, 'mass': 121.2, 'name': 'Cysteine'},
            'GLN': {'index': 5, 'mass': 146.2, 'name': 'Glutamine'},
            'GLU': {'index': 6, 'mass': 147.1, 'name': 'Glutamic Acid'},
            'GLY': {'index': 7, 'mass': 75.1, 'name': 'Glycine'},
            'HIS': {'index': 8, 'mass': 155.2, 'name': 'Histidine'},
            'ILE': {'index': 9, 'mass': 131.2, 'name': 'Isoleucine'},
            'LEU': {'index': 10, 'mass': 131.2, 'name': 'Leucine'},
            'LYS': {'index': 11, 'mass': 146.2, 'name': 'Lysine'},
            'MET': {'index': 12, 'mass': 149.2, 'name': 'Methionine'},
            'PHE': {'index': 13, 'mass': 165.2, 'name': 'Phenylalanine'},
            'PRO': {'index': 14, 'mass': 115.1, 'name': 'Proline'},
            'SER': {'index': 15, 'mass': 105.1, 'name': 'Serine'},
            'THR': {'index': 16, 'mass': 119.1, 'name': 'Threonine'},
            'TRP': {'index': 17, 'mass': 204.2, 'name': 'Tryptophan'},
            'TYR': {'index': 18, 'mass': 181.2, 'name': 'Tyrosine'},
            'VAL': {'index': 19, 'mass': 117.1, 'name': 'Valine'},
            'UNK': {'index': 20, 'mass': 0.0, 'name': 'Unknown'}
        }

        # Load class information before calling super().__init__()
        if os.path.exists(self.class_info_path):
            self.class_info = pd.read_csv(self.class_info_path)
            print(f"Found class info file with {len(self.class_info)} entries")
        else:
            print(f"Warning: class_info.csv not found at {self.class_info_path}")
            self.class_info = None

        # Initialize the base class after setting up our attributes
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """List of raw file names."""
        if not os.path.exists(os.path.join(self.root, 'raw')):
            return []
        return [f for f in os.listdir(os.path.join(self.root, 'raw'))
                if f.endswith('.pdb')]

    @property
    def processed_file_names(self):
        """List of processed file names."""
        return ['data.pt']

    def download(self):
        """Download function is not needed as we assume data is present."""
        pass

    def process(self):
        """Process the raw data into the internal format."""
        if self.class_info is None:
            raise RuntimeError("No class information available. Cannot process data.")

        data_list = []
        parser = PDB.PDBParser(QUIET=True)

        # Collect statistics for normalization
        all_coords = []
        all_masses = []
        all_distances = []
        all_counts = []

        for idx, row in self.class_info.iterrows():
            pdb_id = str(row['scop_id'])
            class_label = self.class_mapping[row['class']]

            try:
                # Load structure
                pdb_file = f"{pdb_id}.pdb"
                pdb_path = os.path.join(self.root, 'raw', pdb_file)
                structure = parser.get_structure('protein', pdb_path)
                model = structure[0]

                # Get residues and create features
                residues = list(model.get_residues())

                # Create neighbor search for the entire structure
                atom_list = Selection.unfold_entities(model, 'A')
                ns = NeighborSearch(atom_list)

                # Process residues and collect statistics
                node_features = []
                for residue in residues:
                    center = residue['CA'].get_coord() if 'CA' in residue else None
                    if center is not None:
                        neighbors = ns.search(center, 5.0, level='R')
                        features = self._get_residue_features(residue, neighbors)
                        node_features.append(features)

                        # Collect statistics
                        all_coords.extend(features[21:24])
                        all_masses.append(features[24])
                        all_distances.extend(features[25:27])
                        all_counts.append(features[27])

                # Create edges with 5Å cutoff
                edges = []
                for i in range(len(residues)):
                    for j in range(i+1, len(residues)):
                        if 'CA' in residues[i] and 'CA' in residues[j]:
                            ca_i = residues[i]['CA'].get_coord()
                            ca_j = residues[j]['CA'].get_coord()
                            dist = np.linalg.norm(ca_i - ca_j)
                            if dist < 5.0:
                                edges.append([i, j])
                                edges.append([j, i])

                if len(edges) == 0:
                    continue

                # Create PyG Data object
                edge_index = torch.tensor(edges, dtype=torch.long).t()
                x = torch.tensor(node_features, dtype=torch.float)
                y = torch.tensor([class_label], dtype=torch.long)

                data = Data(x=x, edge_index=edge_index, y=y, num_nodes=len(residues))

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            except Exception as e:
                print(f"Error processing {pdb_id}: {str(e)}")
                continue

        if len(data_list) == 0:
            raise RuntimeError("No data was successfully processed!")

        # Calculate normalization parameters
        all_coords = np.array(all_coords)
        all_masses = np.array(all_masses)
        all_distances = np.array(all_distances)
        all_counts = np.array(all_counts)

        # Save normalization parameters
        self.coord_mean = float(np.mean(all_coords))
        self.coord_std = float(np.std(all_coords))
        self.mass_mean = float(np.mean(all_masses))
        self.mass_std = float(np.std(all_masses))
        self.dist_mean = float(np.mean(all_distances))
        self.dist_std = float(np.std(all_distances))
        self.count_mean = float(np.mean(all_counts))
        self.count_std = float(np.std(all_counts))

        # Create processed directory if it doesn't exist
        os.makedirs(self.processed_dir, exist_ok=True)

        # Save both the processed data and normalization parameters
        torch.save({
            'data_list': data_list,
            'normalization': {
                'coord_mean': self.coord_mean,
                'coord_std': self.coord_std,
                'mass_mean': self.mass_mean,
                'mass_std': self.mass_std,
                'dist_mean': self.dist_mean,
                'dist_std': self.dist_std,
                'count_mean': self.count_mean,
                'count_std': self.count_std
            }
        }, os.path.join(self.processed_dir, 'data.pt'))

    def len(self):
        """Return the number of graphs in the dataset."""
        if not hasattr(self, '_data_list'):
            processed_path = os.path.join(self.processed_dir, 'data.pt')
            if not os.path.exists(processed_path):
                self.process()

            saved_data = torch.load(processed_path, weights_only=False)
            self._data_list = saved_data['data_list']

            # Load normalization parameters
            norm_params = saved_data['normalization']
            self.coord_mean = norm_params['coord_mean']
            self.coord_std = norm_params['coord_std']
            self.mass_mean = norm_params['mass_mean']
            self.mass_std = norm_params['mass_std']
            self.dist_mean = norm_params['dist_mean']
            self.dist_std = norm_params['dist_std']
            self.count_mean = norm_params['count_mean']
            self.count_std = norm_params['count_std']

        return len(self._data_list)

    def get(self, idx):
        """Get a specific graph from the dataset."""
        if not hasattr(self, '_data_list'):
            self.len()  # This will load the data
        return self._data_list[idx]

    def _get_residue_features(self, residue, neighbors):
        """Create expanded feature vector for a residue."""
        # One-hot encode amino acid type
        aa_features = np.zeros(21)  # 20 standard amino acids + UNK
        aa_name = residue.get_resname()
        aa_info = self.amino_acids.get(aa_name, self.amino_acids['UNK'])
        aa_features[aa_info['index']] = 1

        # Get mass
        mass = aa_info['mass']

        # Get CA atom coordinates
        try:
            ca_atom = residue['CA']
            coords = ca_atom.get_coord()
        except:
            coords = np.zeros(3)

        # Calculate neighborhood features
        if neighbors:
            neighbor_distances = []
            for neighbor in neighbors:
                if neighbor != residue and 'CA' in neighbor:
                    dist = np.linalg.norm(coords - neighbor['CA'].get_coord())
                    neighbor_distances.append(dist)

            avg_neighbor_dist = np.mean(neighbor_distances) if neighbor_distances else 0
            max_neighbor_dist = np.max(neighbor_distances) if neighbor_distances else 0
            neighbor_count = len(neighbor_distances)
        else:
            avg_neighbor_dist = 0
            max_neighbor_dist = 0
            neighbor_count = 0

        # Combine all features
        features = np.concatenate([
            aa_features,          # Amino acid identity (21)
            coords,              # 3D coordinates (3)
            [mass],             # Mass (1)
            [avg_neighbor_dist], # Average neighbor distance (1)
            [max_neighbor_dist], # Maximum neighbor distance (1)
            [neighbor_count]     # Number of neighbors (1)
        ])

        return features

    def get_amino_acid_name(self, features: torch.Tensor) -> str:
        """Get amino acid name from one-hot encoded features."""
        aa_idx = torch.argmax(features[self.feature_indices['aa_onehot']]).item()
        for aa, info in self.amino_acids.items():
            if info['index'] == aa_idx:
                return aa
        return 'UNK'

    def get_feature_info(self, data: Data, node_idx: int) -> Dict[str, float]:
        """Get denormalized feature information for a specific node."""
        features = data.x[node_idx]
        aa_name = self.get_amino_acid_name(features)

        return {
            'amino_acid': aa_name,
            'full_name': self.amino_acids[aa_name]['name'],
            'coordinates': [
                self.denormalize_feature(features[21].item(), 'x'),
                self.denormalize_feature(features[22].item(), 'y'),
                self.denormalize_feature(features[23].item(), 'z')
            ],
            'mass': self.denormalize_feature(features[24].item(), 'mass'),
            'avg_neighbor_dist': self.denormalize_feature(features[25].item(), 'avg_dist'),
            'max_neighbor_dist': self.denormalize_feature(features[26].item(), 'max_dist'),
            'neighbor_count': round(self.denormalize_feature(features[27].item(), 'neighbor_count'))
        }

    def denormalize_feature(self, value: float, feature_name: str) -> float:
        """Denormalize a single feature value."""
        if feature_name in ['x', 'y', 'z']:
            return value * self.coord_std + self.coord_mean
        elif feature_name == 'mass':
            return value * self.mass_std + self.mass_mean
        elif feature_name in ['avg_dist', 'max_dist']:
            return value * self.dist_std + self.dist_mean
        elif feature_name == 'neighbor_count':
            return value * self.count_std + self.count_mean
        else:
            raise ValueError(f"Unknown feature name: {feature_name}")

    def get_graph_stats(self, data: Data) -> Dict[str, float]:
        """Get statistical information about the protein graph."""
        G = nx.Graph()
        edge_index = data.edge_index.numpy()
        for i in range(edge_index.shape[1]):
            G.add_edge(edge_index[0, i], edge_index[1, i])

        return {
            'num_nodes': G.number_of_nodes(),
            'num_edges': G.number_of_edges(),
            'average_degree': sum(dict(G.degree()).values()) / G.number_of_nodes(),
            'density': nx.density(G),
            'is_connected': nx.is_connected(G),
            'average_clustering': nx.average_clustering(G),
            'average_shortest_path_length': nx.average_shortest_path_length(G) if nx.is_connected(G) else float('inf')
        }

    def visualize_protein(self, data: Data,
                          color_by: str = 'amino_acid',
                          node_size: int = 100,
                          with_labels: bool = True,
                          figure_size: Tuple[int, int] = (12, 8)) -> None:
        """
        Visualize protein structure as a graph.

        Args:
            data: PyG Data object
            color_by: Feature to color nodes by ('amino_acid', 'mass', 'neighbor_count')
            node_size: Size of nodes in the visualization
            with_labels: Whether to show node labels
            figure_size: Size of the figure (width, height)
        """
        # Convert to networkx
        G = nx.Graph()
        edge_index = data.edge_index.numpy()

        # Add edges
        for i in range(edge_index.shape[1]):
            G.add_edge(edge_index[0, i], edge_index[1, i])

        # Prepare node colors and labels
        node_colors = []
        node_labels = {}

        for node in G.nodes():
            info = self.get_feature_info(data, node)

            if color_by == 'amino_acid':
                # Use a hash of amino acid name for color
                color = hash(info['amino_acid']) % 20 / 20.0
                node_colors.append(plt.cm.tab20(color))
            elif color_by == 'mass':
                node_colors.append(info['mass'])
            elif color_by == 'neighbor_count':
                node_colors.append(info['neighbor_count'])

            if with_labels:
                node_labels[node] = info['amino_acid']

        # Create visualization
        plt.figure(figsize=figure_size)
        pos = nx.spring_layout(G)

        nx.draw(G, pos,
                node_color=node_colors,
                node_size=node_size,
                with_labels=with_labels,
                labels=node_labels if with_labels else None,
                cmap=plt.cm.viridis if color_by != 'amino_acid' else None)

        if color_by != 'amino_acid':
            sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis)
            sm.set_array([])
            plt.colorbar(sm, label=color_by.replace('_', ' ').title())

        plt.title(f'Protein Structure Graph (colored by {color_by})')
        plt.show()



In [45]:
#import os
#os.makedirs(os.path.join(data_dir, 'processed'), exist_ok=True)  # Make sure processed dir exists
#os.remove(os.path.join(data_dir, 'processed/data.pt'))  # Remove old processed file

dataset = SparseSCOPDataset(root=data_dir)
data = dataset[0]  # Should work now

Found class info file with 3500 entries


Processing...


Error processing 76386: [Errno 2] No such file or directory: 'data/SCOP/raw/76386.pdb'
Error processing 191866: [Errno 2] No such file or directory: 'data/SCOP/raw/191866.pdb'
Error processing 246439: [Errno 2] No such file or directory: 'data/SCOP/raw/246439.pdb'
Error processing 103823: [Errno 2] No such file or directory: 'data/SCOP/raw/103823.pdb'
Error processing 81092: [Errno 2] No such file or directory: 'data/SCOP/raw/81092.pdb'
Error processing 73992: [Errno 2] No such file or directory: 'data/SCOP/raw/73992.pdb'
Error processing 364407: [Errno 2] No such file or directory: 'data/SCOP/raw/364407.pdb'
Error processing 80047: [Errno 2] No such file or directory: 'data/SCOP/raw/80047.pdb'
Error processing 308345: [Errno 2] No such file or directory: 'data/SCOP/raw/308345.pdb'
Error processing 309195: [Errno 2] No such file or directory: 'data/SCOP/raw/309195.pdb'
Error processing 164734: [Errno 2] No such file or directory: 'data/SCOP/raw/164734.pdb'
Error processing 60948: [Errn

Done!


In [46]:
stats = dataset.get_graph_stats(data)
print(stats)

{'num_nodes': 177, 'num_edges': 214, 'average_degree': 2.4180790960451977, 'density': 0.013739085772984078, 'is_connected': True, 'average_clustering': 0.002071563088512241, 'average_shortest_path_length': 24.853428351309706}


In [47]:
#dataset = SparseSCOPDataset(root=data_dir)
#data = dataset[0]

# Get info about a specific residue
info = dataset.get_feature_info(data, node_idx=0)
print(info)

# Visualize the protein in different ways
#dataset.visualize_protein(data, color_by='amino_acid')
#dataset.visualize_protein(data, color_by='mass')
#dataset.visualize_protein(data, color_by='neighbor_count')

# Get graph statistics


{'amino_acid': 'GLU', 'full_name': 'Glutamic Acid', 'coordinates': [10233.289051144287, 8598.223906333711, 5686.039750469833], 'mass': 4408.534576972759, 'avg_neighbor_dist': 12.091510059127168, 'max_neighbor_dist': 13.52793481138863, 'neighbor_count': 10}


In [65]:
import os.path as osp
import time
from math import ceil

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.loader import DenseDataLoader

NUM_CLASSES = 7  # SCOP main classes

# Load the processed data
processed_path = os.path.join(data_dir, 'processed/data.pt')
saved_data = torch.load(processed_path, weights_only=False)
data_list = saved_data['data_list']  # Access the data_list from the dictionary

# Analyze protein sizes
sizes = [data.num_nodes for data in data_list]

print(f"Protein size statistics:")
print(f"Min size: {min(sizes)}")
print(f"Max size: {max(sizes)}")
print(f"Mean size: {sum(sizes)/len(sizes):.1f}")
print(f"Median size: {sorted(sizes)[len(sizes)//2]}")
print(f"Number of proteins > 150 residues: {sum(1 for s in sizes if s > 150)}")

# Set max_nodes to match maximum size found
max_nodes = max(sizes) + 20  # Add some buffer

# Create dataset
dataset = SparseSCOPDataset(root=data_dir)
print(f"\nDataset size: {len(dataset)}")
print(f"Number of features: {dataset[0].num_features}")
print(f"Number of classes: {NUM_CLASSES}")

Protein size statistics:
Min size: 20
Max size: 1523
Mean size: 204.2
Median size: 165
Number of proteins > 150 residues: 1863
Found class info file with 3500 entries

Dataset size: 3424
Number of features: 28
Number of classes: 7


In [66]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

# Load the processed data
processed_path = os.path.join(data_dir, 'processed/data.pt')
saved_data = torch.load(processed_path, weights_only=False)
data_list = saved_data['data_list']  # Access the data_list from the dictionary

# Reverse class mapping for readable labels
class_mapping_reverse = {
    0: 'a (All-alpha)',
    1: 'b (All-beta)',
    2: 'c (Alpha/beta)',
    3: 'd (Alpha+beta)',
    4: 'e (Multi-domain)',
    5: 'f (Membrane)',
    6: 'g (Small proteins)'
}

# Separate nodes by class
nodes_by_class = {}
for data in data_list:
    class_label = data.y.item()
    if class_label not in nodes_by_class:
        nodes_by_class[class_label] = []
    nodes_by_class[class_label].append(data.num_nodes)

# Create the histogram
plt.figure(figsize=(12, 6))

# Box plot
plt.boxplot([nodes_by_class[key] for key in sorted(nodes_by_class.keys())],
            labels=[class_mapping_reverse[key] for key in sorted(nodes_by_class.keys())])

plt.title('Number of Nodes per SCOP Class', fontsize=16)
plt.xlabel('SCOP Class', fontsize=12)
plt.ylabel('Number of Nodes (Residues)', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()

# Save the plot
plt.savefig('nodes_per_class_boxplot.png')
plt.close()

# Print some statistics
print("Node count statistics per class:")
for class_label, nodes in nodes_by_class.items():
    print(f"\n{class_mapping_reverse[class_label]}:")
    print(f"  Count: {len(nodes)}")
    print(f"  Min nodes: {min(nodes)}")
    print(f"  Max nodes: {max(nodes)}")
    print(f"  Mean nodes: {np.mean(nodes):.2f}")
    print(f"  Median nodes: {np.median(nodes):.2f}")

Node count statistics per class:

a (All-alpha):
  Count: 500
  Min nodes: 39
  Max nodes: 656
  Mean nodes: 172.21
  Median nodes: 142.50

b (All-beta):
  Count: 500
  Min nodes: 53
  Max nodes: 548
  Mean nodes: 170.79
  Median nodes: 135.00

c (Alpha/beta):
  Count: 500
  Min nodes: 44
  Max nodes: 815
  Mean nodes: 255.96
  Median nodes: 247.00

d (Alpha+beta):
  Count: 424
  Min nodes: 56
  Max nodes: 640
  Mean nodes: 183.87
  Median nodes: 164.00

e (Multi-domain):
  Count: 500
  Min nodes: 66
  Max nodes: 1523
  Mean nodes: 396.43
  Median nodes: 358.00

f (Membrane):
  Count: 500
  Min nodes: 25
  Max nodes: 746
  Mean nodes: 184.99
  Median nodes: 146.00

g (Small proteins):
  Count: 500
  Min nodes: 20
  Max nodes: 157
  Mean nodes: 62.37
  Median nodes: 56.00


In [67]:
processed_path = os.path.join(data_dir, 'processed/data.pt')
saved_data = torch.load(processed_path, weights_only=False)
data_list = saved_data['data_list']  # Access data_list from the dictionary

# Count proteins per class and above 300 nodes
class_counts = {}
above_300_counts = {}

for data in data_list:
    class_label = data.y.item()
    if class_label not in class_counts:
        class_counts[class_label] = 0
        above_300_counts[class_label] = 0

    class_counts[class_label] += 1
    if data.num_nodes > 300:
        above_300_counts[class_label] += 1

print("\nTotal proteins per class:")
for cls, count in class_counts.items():
    print(f"Class {cls}: {count} total, {above_300_counts[cls]} above 300 nodes ({above_300_counts[cls]/count*100:.2f}%)")


Total proteins per class:
Class 0: 500 total, 71 above 300 nodes (14.20%)
Class 1: 500 total, 50 above 300 nodes (10.00%)
Class 2: 500 total, 131 above 300 nodes (26.20%)
Class 3: 424 total, 42 above 300 nodes (9.91%)
Class 4: 500 total, 328 above 300 nodes (65.60%)
Class 5: 500 total, 81 above 300 nodes (16.20%)
Class 6: 500 total, 0 above 300 nodes (0.00%)


In [83]:
import torch
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class SparseGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_dim=64, num_classes=7):
        super().__init__()
        self.conv1 = SAGEConv(num_features, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.conv3 = SAGEConv(hidden_dim, hidden_dim)

        self.bn1 = torch.nn.BatchNorm1d(hidden_dim)
        self.bn2 = torch.nn.BatchNorm1d(hidden_dim)
        self.bn3 = torch.nn.BatchNorm1d(hidden_dim)

        self.lin1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, num_classes)

        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        batch = data.batch if hasattr(data, 'batch') else torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # First GraphSAGE layer
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # Second GraphSAGE layer
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout(x)

        # Third GraphSAGE layer
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)

        # Custom pooling implementation
        unique_batches = torch.unique(batch)
        pooled_features = []

        for b in unique_batches:
            mask = (batch == b)
            graph_features = x[mask]
            pooled_features.append(torch.mean(graph_features, dim=0))

        x = torch.stack(pooled_features)

        # MLP head
        x = self.lin1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.lin2(x)

        return x

def train_model(train_loader, val_loader, test_loader, device):
    model = SparseGNN(num_features=28, num_classes=7).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5, verbose=True)

    best_val_acc = 0
    test_acc = 0
    patience = 10
    no_improve = 0
    best_model_state = None

    for epoch in range(1, 151):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for data in train_loader:
            try:
                data = data.to(device)
                optimizer.zero_grad()
                out = model(data)
                loss = criterion(out, data.y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Add gradient clipping
                optimizer.step()

                pred = out.argmax(dim=1)
                correct += pred.eq(data.y).sum().item()
                total += data.num_graphs
                total_loss += loss.item() * data.num_graphs

            except Exception as e:
                print(f"Error in batch: {str(e)}")
                continue

        if total == 0:  # Skip epoch if all batches failed
            continue

        train_loss = total_loss / total
        train_acc = correct / total
        val_acc = test_model(model, val_loader, device)

        # Update learning rate
        scheduler.step(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = test_model(model, test_loader, device)
            best_model_state = model.state_dict()
            no_improve = 0
        else:
            no_improve += 1

        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, '
              f'Test Acc: {test_acc:.4f}, '
              f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

        # Early stopping
        if no_improve >= patience:
            print("Early stopping triggered")
            break

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, best_val_acc, test_acc

@torch.no_grad()
def test_model(model, loader, device):
    model.eval()
    total_correct = 0
    total_examples = 0

    for data in loader:
        try:
            data = data.to(device)
            pred = model(data).max(dim=1)[1]
            total_correct += pred.eq(data.y).sum().item()
            total_examples += data.num_graphs
        except Exception as e:
            print(f"Error in testing batch: {str(e)}")
            continue

    if total_examples == 0:
        return 0.0

    return total_correct / total_examples

def create_data_loaders(train_dataset, val_dataset, test_dataset, batch_size=8):
    # Custom collate function to handle varying batch sizes
    def collate_fn(data_list):
        return Batch.from_data_list(data_list)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            drop_last=True,
                            collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             drop_last=True,
                             collate_fn=collate_fn)
    return train_loader, val_loader, test_loader

In [84]:
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
import os

# Load the dataset and create splits
processed_path = os.path.join(data_dir, 'processed/data.pt')
saved_data = torch.load(processed_path, weights_only=False)
data_list = saved_data['data_list']

# Create indices for random split
n = len(data_list)
indices = torch.randperm(n)
train_size = int(0.8 * n)
val_size = int(0.1 * n)

# Split indices
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]
test_indices = indices[train_size + val_size:]

# Create dataset splits
train_dataset = [data_list[i] for i in train_indices]
val_dataset = [data_list[i] for i in val_indices]
test_dataset = [data_list[i] for i in test_indices]

# Create data loaders with proper batch handling
batch_size = 8  # Reduced batch size for better stability
train_loader, val_loader, test_loader = create_data_loaders(
    train_dataset, val_dataset, test_dataset, batch_size=batch_size
)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else
                      'mps' if torch.backends.mps.is_available() else
                      'cpu')
print(f"Using device: {device}")

# Print dataset statistics
print(f"Number of training graphs: {len(train_dataset)}")
print(f"Number of validation graphs: {len(val_dataset)}")
print(f"Number of test graphs: {len(test_dataset)}")

# Train the model
model, best_val_acc, test_acc = train_model(train_loader, val_loader, test_loader, device)

print(f'\nFinal results:')
print(f'Best validation accuracy: {best_val_acc:.4f}')
print(f'Test accuracy: {test_acc:.4f}')

# Save the best model
torch.save(model.state_dict(), 'best_model.pt')

Using device: mps
Number of training graphs: 2739
Number of validation graphs: 342
Number of test graphs: 343
Error in batch: The shape of the mask [1491] at index 0 does not match the shape of the indexed tensor [1490, 64] at index 0
Error in batch: The shape of the mask [1378] at index 0 does not match the shape of the indexed tensor [1377, 64] at index 0
Error in batch: The shape of the mask [1003] at index 0 does not match the shape of the indexed tensor [1002, 64] at index 0


KeyboardInterrupt: 