In [1]:
import os
import pandas as pd
import numpy as np
import shutil


import torch
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, InMemoryDataset

from utils import load_and_filter_dataset, one_hot_encode_sequence, create_graph_from_sequence, summarize_dataset
from dataset import PiRNADataset

import networkx as nx
import matplotlib.pyplot as plt

In [2]:
# Build the raw-data path portably. A literal like "data\DASHR2..." contains the invalid
# escape sequence "\D" (Python warns about it) and only resolves as a path on Windows.
RAW_DATA_PATH = os.path.join("data", "DASHR2_GEO_hg38_sequenceTable_export.csv")
# NOTE(review): the summary printed by this call lists "rnaClass" as an RNA type with
# count 1, and the warning shows pd.read_csv(file_path, sep=None, header=None) inside
# utils — the CSV header row appears to be parsed as a data row. Fix in
# utils.load_and_filter_dataset (TODO confirm).
data = load_and_filter_dataset(RAW_DATA_PATH, save_filtered=True)

  data = load_and_filter_dataset("data\DASHR2_GEO_hg38_sequenceTable_export.csv", save_filtered=True)
  df = pd.read_csv(file_path, sep=None, header=None)


Dataset Summary (Before Filtering):
  Total entries: 65156
  RNA types: 13
rnaClass
piRNA           50397
snRNA            4509
miRNAprimary     1881
rRNA             1840
scRNA            1420
mir-3p            959
mir-5p            957
mir-5p3pno        897
tRNA              631
tRF3              631
tRF5              631
snoRNA            402
rnaClass            1
Name: count, dtype: int64
-----------------------------------------
Filtered Dataset:
  Total piRNAs with length 26–32: 50397
  Unique piRNA IDs: 23116
  Length range in filtered data: 26.0 - 32.0
Filtered data saved to data/filtered_piRNAs.csv


In [3]:
# 80/10/10 split: hold out 20% of rows, then halve the hold-out into val and test.
SPLIT_SEED = 42
train_df, temp_df = train_test_split(data, test_size=0.2, random_state=SPLIT_SEED)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=SPLIT_SEED)
# NOTE(review): the filtering summary reports 23116 unique piRNA IDs for 50397 rows, so
# rows sharing an ID can land in different splits. If those rows share a sequence, this
# leaks test information into training — consider a grouped split (e.g. sklearn
# GroupShuffleSplit on the ID column). TODO confirm.

print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}, Test samples: {len(test_df)}")

output_root = "dataset"
os.makedirs(output_root, exist_ok=True)

# NOTE(review): if PiRNADataset is a torch_geometric InMemoryDataset, it skips process()
# whenever its processed files already exist under `root`, so changing the split (or the
# seed) would silently reuse stale graphs. Set to True to wipe each split directory first.
REBUILD_DATASETS = False

# Keep every split's dataset; the loop variable `dataset` only holds the last one.
datasets = {}
for split_name, split_df in zip(['train', 'val', 'test'], [train_df, val_df, test_df]):
    split_path = os.path.join(output_root, split_name)
    if REBUILD_DATASETS and os.path.isdir(split_path):
        shutil.rmtree(split_path)
    os.makedirs(split_path, exist_ok=True)

    dataset = PiRNADataset(root=split_path, df=split_df)
    datasets[split_name] = dataset
    print(f"\n {split_name.capitalize()} set processed and saved to: {split_path}")

    summarize_dataset(dataset, split_name)

Train samples: 40317, Val samples: 5040, Test samples: 5040

 Train set processed and saved to: dataset\train

 Summary for train set:
Total samples: 40317
Sequence length: mean = 32.00, std = 0.00, min = 32, max = 32
Label distribution: {27.0: 5626, 31.0: 4619, 30.0: 8687, 28.0: 5138, 29.0: 8186, 32.0: 1271, 26.0: 6790}
Average nodes per graph: 32.00
Average edges per graph: 62.00
Node feature shapes: {(32, 5)}
Label shapes: {(1,)}

 Val set processed and saved to: dataset\val

 Summary for val set:
Total samples: 5040
Sequence length: mean = 32.00, std = 0.00, min = 32, max = 32
Label distribution: {30.0: 1081, 28.0: 635, 26.0: 893, 29.0: 962, 32.0: 181, 31.0: 606, 27.0: 682}
Average nodes per graph: 32.00
Average edges per graph: 62.00
Node feature shapes: {(32, 5)}
Label shapes: {(1,)}

 Test set processed and saved to: dataset\test

 Summary for test set:
Total samples: 5040
Sequence length: mean = 32.00, std = 0.00, min = 32, max = 32
Label distribution: {30.0: 1091, 28.0: 583, 2

In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GCNEncoder(torch.nn.Module):
    """Three-layer GCN that turns node features into one embedding per graph.

    Layer widths run in_channels -> hidden_channels -> hidden_channels -> out_channels,
    with ReLU after the first two convolutions and no activation after the last.

    Args:
        in_channels: node feature width (default 5).
        hidden_channels: width of both hidden GCN layers (default 32).
        out_channels: width of the returned graph embedding (default 64).
    """

    def __init__(self, in_channels=5, hidden_channels=32, out_channels=64):
        super().__init__()
        # The attribute names gcn1/gcn2/gcn3 determine the state_dict keys, so renaming
        # them would break previously saved checkpoints.
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, hidden_channels)
        self.gcn3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch=None):
        """Encode node features and mean-pool them into graph embeddings.

        Args:
            x: node feature matrix; assumed (num_nodes, in_channels) — TODO confirm.
            edge_index: edge list in the (2, num_edges) layout GCNConv expects.
            batch: optional node-to-graph assignment vector for mini-batches. When
                None, all nodes are treated as a single graph.

        Returns:
            Mean-pooled embeddings, one row per graph; a single row when batch is None.
        """
        h = x
        for conv in (self.gcn1, self.gcn2):
            h = F.relu(conv(h, edge_index))
        h = self.gcn3(h, edge_index)  # final layer: no activation before pooling

        if batch is None:
            # Single graph: average over all nodes, keeping a leading graph dimension.
            return h.mean(dim=0, keepdim=True)
        return global_mean_pool(h, batch)

In [5]:
import torch

def test_gcn_encoder(num_nodes, num_node_features, hidden_dim, out_channels, num_edges):
    x = torch.randn(num_nodes, num_node_features)
    edge_index = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.long)

    model = GCNEncoder(in_channels=num_node_features, hidden_channels=hidden_dim, out_channels=out_channels)
    out = model(x, edge_index) 
    
    return out.shape

# Sizes mirror the real piRNA graphs summarized above: 32 nodes with 5 features each and
# 62 edges (= 2 * (32 - 1), consistent with an undirected chain — presumably how the graph
# builder links adjacent nucleotides; confirm).
test_out_shape = test_gcn_encoder(num_nodes=32, num_node_features=5, hidden_dim=32, out_channels=64, num_edges=62)
# No batch vector is passed, so the encoder mean-pools all nodes into a single embedding.
assert test_out_shape == torch.Size([1, 64]), f"unexpected GCNEncoder output shape: {test_out_shape}"
print(test_out_shape)


torch.Size([1, 64])
