In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import torch.optim as optim
import pandas as pd
import numpy as np
import os
import random
from sklearn.metrics import precision_score, recall_score, roc_auc_score, precision_recall_curve, roc_curve, auc
import matplotlib.pyplot as plt
import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import argparse

In [4]:
def load_data(input_dir, prefix, samples=None):
    """Read the four ``<prefix>.*.csv`` tables from ``input_dir``.

    Args:
        input_dir: directory containing the CSV files.
        prefix: common file-name prefix; the files read are
            ``<prefix>.node.csv``, ``<prefix>.edge.csv``,
            ``<prefix>.path.label.csv`` and ``<prefix>.phasing.csv``.
        samples: optional collection of ``sample`` values to keep. When falsy
            (``None`` or empty) every row is kept.

    Returns:
        Tuple ``(node_df, edge_df, path_df, phasing_df)``. Filtered frames keep
        their original index labels (no ``reset_index``).
    """
    # (file suffix, explicit column dtypes) in the order the tables are returned.
    table_specs = (
        ('node', {"chr": str, "graph_id": str, "node_id": int, "start_pos": np.uint64, "end_pos": np.uint64,
                  "weight": float, "length": int, "maxcov": float, "stddev": float, "indel_sum_cov": float,
                  "indel_ratio": float, "left_indel": int, "right_indel": int, "sample": str}),
        ('edge', {"chr": str, "graph_id": str, "source": int, "target": int, "start_pos": np.uint64,
                  "end_pos": np.uint64, "weight": float, "length": int, "sample": str}),
        ('path.label', {"chr": str, "graph_id": str, "path_id": str, "node_sequence": str, "splice_source": str,
                        "splice_target": str, "abundance": float, "label": int, "sample": str}),
        ('phasing', {"chr": str, "graph_id": str, "path_id": str, "node_sequence": str, "count": int, "sample": str}),
    )

    # Read every table first, so a missing file is reported before any filtering happens.
    tables = [
        pd.read_csv(os.path.join(input_dir, f'{prefix}.{suffix}.csv'), dtype=column_types)
        for suffix, column_types in table_specs
    ]

    if samples:
        tables = [table[table['sample'].isin(samples)] for table in tables]

    return tuple(tables)

In [5]:
# Directory holding the `<prefix>.*.csv` input tables. Set the PATHEM_NN_INPUT_DIR
# environment variable to point at another location instead of editing this path.
INPUT_DIR = os.environ.get('PATHEM_NN_INPUT_DIR', '/data/qzs23/projects/pathEm/aletsch-results/nnInput')
INPUT_PREFIX = 'polyester_refseq.full'
samples = ['polyester_test1_refseq_1']
node_df, edge_df, path_df, phasing_df = load_data(INPUT_DIR, INPUT_PREFIX, samples=samples)

In [1]:
# node_df.to_csv('node_df.csv')
# edge_df.to_csv('edge_df.csv')
# path_df.to_csv('path_df.csv')
# phasing_df.to_csv('phasing_df.csv')

NameError: name 'node_df' is not defined

In [6]:
def process_input_graph_to_input_data(node_df, edge_df, phasing_df, graph_id, min_nodes=5):
    """Build a graph ``Data`` object (x, edge_index, edge_attr) for one splice graph.

    Node features are ``node_feature_cols`` plus one extra column: the summed
    phasing ``count`` of every phasing row whose ``node_sequence`` visits the node.
    Edge features are ``edge_feature_cols`` plus the summed ``count`` of phasing
    rows that traverse the edge as a consecutive (source, target) node pair.

    Args:
        node_df, edge_df, phasing_df: node / edge / phasing tables. Node and edge
            tables must contain the feature columns listed below, including
            ``extPathSupport``.
        graph_id: value of the ``graph_id`` column selecting the graph.
        min_nodes: graphs with fewer nodes than this are skipped (default 5).

    Returns:
        ``Data(x, edge_index, edge_attr)``, where ``edge_index`` holds row
        positions (0..num_nodes-1) into ``x``. Returns ``None`` when the graph has
        fewer than ``min_nodes`` nodes or no edges.

    Raises:
        ValueError: if an edge references a node_id missing from the node table.
    """
    nodes = node_df[node_df['graph_id'] == graph_id]
    edges = edge_df[edge_df['graph_id'] == graph_id]
    phasing = phasing_df[phasing_df['graph_id'] == graph_id]

    if len(nodes) < min_nodes or len(edges) == 0:
        return None

    node_feature_cols = ['extPathSupport', 'weight', 'length', 'maxcov', 'stddev', 'indel_sum_cov','indel_ratio','left_indel','right_indel']
    edge_feature_cols = ['extPathSupport', 'weight', 'length']

    node_features = torch.tensor(nodes[node_feature_cols].values, dtype=torch.float)
    edge_features = torch.tensor(edges[edge_feature_cols].values, dtype=torch.float)

    # node_id -> row position within this graph's node table (i.e. the row order of x).
    node_index_lookup = {int(nid): pos for pos, nid in enumerate(nodes['node_id'].values)}

    # (source, target) node_id pairs, in the row order of edge_attr.
    edge_pairs = [(int(s), int(t)) for s, t in zip(edges['source'].values, edges['target'].values)]

    missing = sorted({nid for pair in edge_pairs for nid in pair} - set(node_index_lookup))
    if missing:
        raise ValueError(
            f"graph {graph_id!r}: edges reference node_ids not in the node table: {missing[:10]}"
        )

    # Remap endpoints from node_id to row position so edge_index indexes x correctly
    # even when node_ids are not exactly 0..num_nodes-1 in table order.
    edge_index = torch.tensor(
        [[node_index_lookup[s] for s, _ in edge_pairs],
         [node_index_lookup[t] for _, t in edge_pairs]],
        dtype=torch.long,
    )

    # (source, target) -> POSITIONAL edge index. DataFrame index labels must not be
    # used here: after filtering by graph_id they refer to rows of the full edge_df.
    # As before, a duplicated (source, target) pair maps to its last occurrence.
    edge_lookup = {pair: pos for pos, pair in enumerate(edge_pairs)}

    node_cov = [0.0] * len(nodes)
    edge_cov = [0.0] * len(edge_pairs)

    if 'node_sequence' in phasing.columns:
        # A missing 'count' column contributes 0, matching the old row.get default.
        counts = phasing['count'] if 'count' in phasing.columns else [0] * len(phasing)
        for node_seq_str, count_value in zip(phasing['node_sequence'], counts):
            if not isinstance(node_seq_str, str):
                continue  # skip NaN / missing sequences
            # e.g. "0,1,2" -> [0, 1, 2]; empty tokens (trailing commas) are ignored.
            node_list = [int(tok) for tok in node_seq_str.split(',') if tok.strip()]
            count_value = float(count_value)

            for nid in node_list:
                pos = node_index_lookup.get(nid)
                if pos is not None:
                    node_cov[pos] += count_value

            # Credit every consecutive node pair that is an actual edge of the graph.
            for pair in zip(node_list, node_list[1:]):
                pos = edge_lookup.get(pair)
                if pos is not None:
                    edge_cov[pos] += count_value

    node_phasing_coverage = torch.tensor(node_cov, dtype=torch.float).view(-1, 1)  # [num_nodes, 1]
    edge_phasing_coverage = torch.tensor(edge_cov, dtype=torch.float).view(-1, 1)  # [num_edges, 1]

    return Data(
        x=torch.cat([node_features, node_phasing_coverage], dim=1),
        edge_index=edge_index,
        edge_attr=torch.cat([edge_features, edge_phasing_coverage], dim=1),
    )

In [7]:
# Example graph used for the shape check and the RL prototype below.
GRAPH_ID = 'chrNC_000001.11.instance.0.0.2.0.0.polyester_test1_refseq_1'
data_obj = process_input_graph_to_input_data(
    node_df=node_df, edge_df=edge_df, phasing_df=phasing_df, graph_id=GRAPH_ID
)
# process_input_graph_to_input_data returns None for small or edgeless graphs; fail
# here with a clear message rather than with an AttributeError on data_obj.x later.
if data_obj is None:
    raise ValueError(f"graph {GRAPH_ID!r} was skipped (too few nodes or no edges)")

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GNNForRLAgent(nn.Module):
    """GAT encoder plus an edge MLP that emits one logit per edge.

    Pipeline: BatchNorm on node and edge features -> a stack of multi-head
    ``GATConv`` layers (each followed by ELU) -> for each edge, concatenate the
    source embedding, the target embedding and the normalized edge features, and
    score the result with a two-layer MLP.

    ``forward`` reads ``data.x``, ``data.edge_index`` and ``data.edge_attr`` and
    returns a 1-D tensor with one logit per column of ``edge_index``.

    NOTE(review): in training mode ``BatchNorm1d`` needs more than one row per
    batch, so a graph with a single edge will raise there.
    """

    def __init__(self, num_node_features, num_edge_features, hidden_dim, num_heads, num_gat_layers):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Feature normalization for the raw inputs.
        self.input_norm = nn.BatchNorm1d(num_node_features)
        self.edge_norm = nn.BatchNorm1d(num_edge_features)

        # Multi-head GAT concatenates heads, so every layer after the first
        # consumes hidden_dim * num_heads features. At least one layer is always built.
        gat_width = hidden_dim * num_heads
        layer_inputs = [num_node_features] + [gat_width] * (num_gat_layers - 1)
        self.gat_layers = nn.ModuleList([
            GATConv(in_dim, hidden_dim, heads=num_heads, edge_dim=num_edge_features)
            for in_dim in layer_inputs
        ])

        # Input: [source embedding | target embedding | edge features].
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * gat_width + num_edge_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, data):
        edge_index = data.edge_index
        node_h = self.input_norm(data.x)
        edge_h = self.edge_norm(data.edge_attr)

        for conv in self.gat_layers:
            node_h = F.elu(conv(node_h, edge_index, edge_h))

        # Gather endpoint embeddings per edge and score each edge.
        src_h = node_h[edge_index[0]]
        dst_h = node_h[edge_index[1]]
        pair_repr = torch.cat([src_h, dst_h, edge_h], dim=-1)
        return self.edge_mlp(pair_repr).squeeze(-1)

In [9]:
# Infer input widths from the example graph, then build a model for a shape check.
num_node_features, num_edge_features = data_obj.x.size(1), data_obj.edge_attr.size(1)

hidden_dim = 64
num_heads = 4
num_gat_layers = 3

model = GNNForRLAgent(
    num_node_features,
    num_edge_features,
    hidden_dim,
    num_heads,
    num_gat_layers,
)

# eval(): BatchNorm layers use running statistics during the debug pass below.
model.eval()

def debug_forward(model, data):
    """Replay the model's forward pass step by step, printing tensor shapes.

    Uses the model's own sub-modules (``input_norm``, ``edge_norm``,
    ``gat_layers``, ``edge_mlp``) in the same order as ``GNNForRLAgent.forward``,
    and returns the resulting per-edge logits.
    """
    rule = "-" * 50
    edge_index = data.edge_index

    print("=== Debug Forward Pass ===")
    print("Input node features shape:", data.x.shape)
    print("Input edge features shape:", data.edge_attr.shape)
    print("Input edge index shape:", edge_index.shape)
    print(rule)

    node_h = model.input_norm(data.x)
    edge_h = model.edge_norm(data.edge_attr)
    print("After normalization:")
    print("Node features shape:", node_h.shape)
    print("Edge features shape:", edge_h.shape)
    print(rule)

    for layer_no, conv in enumerate(model.gat_layers, start=1):
        node_h = F.elu(conv(node_h, edge_index, edge_h))
        print(f"After GAT layer {layer_no}:")
        print("Node features shape:", node_h.shape)
        print(rule)

    src_h = node_h[edge_index[0]]
    dst_h = node_h[edge_index[1]]
    print("Source node embeddings shape:", src_h.shape)
    print("Target node embeddings shape:", dst_h.shape)
    print(rule)

    pair_repr = torch.cat([src_h, dst_h, edge_h], dim=-1)
    print("Concatenated edge input shape:", pair_repr.shape)
    print(rule)

    logits = model.edge_mlp(pair_repr).squeeze(-1)
    print("Final edge logits shape:", logits.shape)
    print("Edge logits:", logits)
    print("=" * 50)

    return logits

# Shape-check the freshly built model on the example graph loaded above.
edge_logits = debug_forward(model=model, data=data_obj)

=== Debug Forward Pass ===
Input node features shape: torch.Size([83, 10])
Input edge features shape: torch.Size([81, 4])
Input edge index shape: torch.Size([2, 81])
--------------------------------------------------
After normalization:
Node features shape: torch.Size([83, 10])
Edge features shape: torch.Size([81, 4])
--------------------------------------------------
After GAT layer 1:
Node features shape: torch.Size([83, 256])
--------------------------------------------------
After GAT layer 2:
Node features shape: torch.Size([83, 256])
--------------------------------------------------
After GAT layer 3:
Node features shape: torch.Size([83, 256])
--------------------------------------------------
Source node embeddings shape: torch.Size([81, 256])
Target node embeddings shape: torch.Size([81, 256])
--------------------------------------------------
Concatenated edge input shape: torch.Size([81, 516])
--------------------------------------------------
Final edge logits shape: torch

In [None]:
class SpliceGraphEnv:
    """Episodic environment for walking a graph edge by edge.

    An action is a column index into ``data.edge_index``. Starting from
    ``start_node``, the agent repeatedly picks an unvisited outgoing edge. The
    episode ends when it reaches ``end_node``, picks an invalid edge (reward -1),
    or uses up ``max_steps`` without finishing (extra -0.5).

    NOTE(review): ``start_node``/``end_node`` are compared with the values stored
    in ``data.edge_index``, so they must use the same node numbering as that
    tensor. Confirm this against how the graph was built.
    """

    def __init__(
        self,
        data: "Data",
        start_node: int,
        end_node: int,
        ground_truth_path: list,
        max_steps: int = 50,
        step_penalty: float = -0.01
    ):
        self.data = data
        self.start_node = start_node
        self.end_node = end_node
        # Compared verbatim against path_history (edge indices) in _evaluate_transcript.
        self.ground_truth_path = ground_truth_path
        self.max_steps = max_steps
        self.step_penalty = step_penalty

        self.edge_index = data.edge_index
        self.num_edges = self.edge_index.shape[1]

        self.current_node = None
        self.visited_edges = None
        self.steps_taken = 0
        self.done = False
        self.path_history = []

        self.reset()

    def reset(self):
        """Return to ``start_node`` with empty history; returns the initial state."""
        self.current_node = self.start_node
        self.visited_edges = set()
        self.steps_taken = 0
        self.done = False
        self.path_history = []
        return self.get_state()

    def get_state(self):
        """Return a snapshot of the environment state.

        ``visited_edges`` is copied, so states stored in a trajectory keep the value
        they had when captured instead of aliasing the live, still-growing set.
        """
        return {
            "current_node": self.current_node,
            "visited_edges": set(self.visited_edges),
            "steps_taken": self.steps_taken,
            "data": self.data
        }

    def get_valid_actions(self):
        """Return unvisited edge indices leaving ``current_node``, in ascending order."""
        # Vectorized scan of the source row instead of a per-edge .item() loop.
        outgoing = torch.nonzero(self.edge_index[0] == self.current_node, as_tuple=False).view(-1).tolist()
        return [e_idx for e_idx in outgoing if e_idx not in self.visited_edges]

    def step(self, action_edge_idx):
        """Apply one action; returns ``(next_state, reward, done, info)``."""
        # An invalid or already-visited edge ends the episode with a penalty.
        valid_actions = self.get_valid_actions()
        if action_edge_idx not in valid_actions:
            reward = -1.0
            self.done = True
            next_state = self.get_state()
            return next_state, reward, self.done, {}

        self.visited_edges.add(action_edge_idx)
        self.path_history.append(action_edge_idx)

        # Move to the edge's target node.
        dst_node = self.edge_index[1][action_edge_idx].item()
        self.current_node = dst_node

        reward = self.step_penalty
        self.steps_taken += 1

        if self.current_node == self.end_node:
            # Reached the sink: score the assembled path against the ground truth.
            final_transcript_reward = self._evaluate_transcript()
            reward += final_transcript_reward
            self.done = True

        if self.steps_taken >= self.max_steps:
            # Out of steps without reaching end_node: penalize not finishing.
            if not self.done:
                reward -= 0.5
                self.done = True

        next_state = self.get_state()
        return next_state, reward, self.done, {}

    def _evaluate_transcript(self):
        """Terminal reward: 10.0 for an exact edge-sequence match, else -0.2."""
        # TODO: partial-match scoring. ground_truth_path is assumed to be a list of
        # edge indices; adapt if it is a list of nodes.
        if self.path_history == self.ground_truth_path:
            return 10.0  # perfect match reward
        else:
            return -0.2  # small negative if not perfect

In [None]:
class RLAgent:
    """REINFORCE agent: scores every edge with a GNN and samples among valid ones."""

    def __init__(
        self,
        gnn_model: nn.Module,
        optimizer: optim.Optimizer,
        gamma: float = 0.99
    ):
        self.model = gnn_model
        self.optimizer = optimizer
        self.gamma = gamma

    def select_action(self, state, valid_actions):
        """Sample an edge index from the policy restricted to ``valid_actions``.

        Args:
            state: dict with a ``"data"`` entry that is passed to the model.
            valid_actions: non-empty list of edge indices that may be chosen.

        Returns:
            ``(action_edge_idx, log_prob)``: the sampled index as a Python int and its
            log-probability as a differentiable tensor.

        Raises:
            ValueError: if ``valid_actions`` is empty (all logits would be -inf).
        """
        if len(valid_actions) == 0:
            raise ValueError("select_action requires at least one valid action")

        edge_logits = self.model(state["data"])

        mask = torch.zeros(edge_logits.shape[0], dtype=torch.bool, device=edge_logits.device)
        mask[valid_actions] = True

        # -inf removes invalid edges; passing logits (not softmax probs) lets
        # Categorical normalize in log space, which is numerically more stable.
        masked_logits = edge_logits.masked_fill(~mask, float('-inf'))
        dist = torch.distributions.Categorical(logits=masked_logits)
        action_edge_idx = dist.sample()
        log_prob = dist.log_prob(action_edge_idx)

        return action_edge_idx.item(), log_prob

    def update_policy(self, trajectory):
        """Apply one REINFORCE update from a finished episode.

        ``trajectory`` is a list of ``(state, action, reward, log_prob)`` tuples.
        Steps whose ``log_prob`` is ``None`` (e.g. a forced terminal step) still
        contribute their reward to the discounted returns but get no policy term.
        Returns are standardized when there are at least two of them; a
        single-step episode uses its raw return (its std would be NaN). A
        trajectory with no log-probs leaves the model untouched.
        """
        rewards = [step[2] for step in trajectory]

        # Discounted return-to-go for every step.
        returns = []
        G = 0.0
        for r in reversed(rewards):
            G = r + self.gamma * G
            returns.append(G)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32)

        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # Pair each step with its own return before dropping steps without log_prob,
        # so a None in the middle of the trajectory cannot shift the alignment.
        policy_terms = [
            -log_prob * R
            for (_, _, _, log_prob), R in zip(trajectory, returns)
            if log_prob is not None
        ]
        if not policy_terms:
            return

        policy_loss = torch.stack(policy_terms).sum()

        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

In [None]:
def train_agent(env: "SpliceGraphEnv", agent: "RLAgent", num_episodes: int = 500, log_every: int = 1):
    """Run ``num_episodes`` REINFORCE episodes of ``agent`` in ``env``.

    Each episode is rolled out until ``env`` reports done. If the current node has
    no valid actions, a forced invalid step (``env.step(-1)``) ends the episode and
    is recorded with ``log_prob=None``. After every episode,
    ``agent.update_policy(trajectory)`` is called.

    Args:
        env: environment exposing ``reset``, ``get_valid_actions`` and ``step``.
        agent: agent exposing ``select_action`` and ``update_policy``.
        num_episodes: number of episodes to run.
        log_every: print a progress line every ``log_every`` episodes; 0 disables it.

    Returns:
        List with the total (undiscounted) reward of each episode.
    """
    episode_rewards = []
    for episode in range(num_episodes):
        state = env.reset()
        trajectory = []  # (state, action, reward, log_prob)

        while True:
            valid_actions = env.get_valid_actions()

            # Dead end: force an invalid step so the env applies its penalty and ends.
            if not valid_actions:
                _, reward, _, _ = env.step(-1)
                trajectory.append((state, -1, reward, None))
                break

            action_edge_idx, log_prob = agent.select_action(state, valid_actions)
            next_state, reward, done, _ = env.step(action_edge_idx)
            trajectory.append((state, action_edge_idx, reward, log_prob))
            state = next_state

            if done:
                break

        agent.update_policy(trajectory)

        total_reward = sum(step[2] for step in trajectory)
        episode_rewards.append(total_reward)
        if log_every and (episode + 1) % log_every == 0:
            print(f"Episode {episode+1}/{num_episodes}, Total Reward: {total_reward:.3f}")

    print("Training done")
    return episode_rewards

In [None]:
# TODO: start_node =
# TODO: end_node =
# TODO: Provide a ground_truth_path (list of edges or nodes) used for reward shaping:

# TODO: env = SpliceGraphEnv(
#     data=data_obj,
#     start_node=start_node,
#     end_node=end_node,
#     ground_truth_path=ground_truth_path,
#     max_steps=100  # or any suitable limit
# )

num_node_features = data_obj.x.size(-1)
num_edge_features = data_obj.edge_attr.size(-1)

# Reuse the architecture hyperparameters defined for the debug model so the
# trained network matches the one whose shapes were checked there.
gnn_model = GNNForRLAgent(
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_gat_layers=num_gat_layers,
)

optimizer = optim.Adam(gnn_model.parameters(), lr=1e-3)

agent = RLAgent(
    gnn_model=gnn_model,
    optimizer=optimizer,
    gamma=0.99
)

# `env` only exists once the TODO block above is filled in; fail with an
# actionable message instead of a bare NameError.
if 'env' not in globals():
    raise NameError(
        "`env` is not defined: set start_node, end_node and ground_truth_path and "
        "construct SpliceGraphEnv in the TODO block above before training."
    )

train_agent(env, agent, num_episodes=50)

def assemble_transcript(env, agent):
    """Roll out one episode with ``agent`` in ``env`` and return the path taken.

    The rollout runs under ``torch.no_grad()`` because no gradients are needed
    here. Actions still come from ``agent.select_action``, i.e. they are sampled,
    not chosen greedily.

    NOTE(review): the model's train/eval mode is not changed here; switch the
    model to eval() beforehand if BatchNorm statistics should stay frozen.

    Returns:
        ``(edge_path, node_path)``: ``env.path_history`` (edge indices taken) and
        the node sequence ``[env.start_node, target of each taken edge...]``.
    """
    with torch.no_grad():
        state = env.reset()
        done = False

        while not done:
            valid_actions = env.get_valid_actions()
            if not valid_actions:
                # Dead end: let the env register the forced invalid step, then stop.
                env.step(-1)
                break

            action_idx, _ = agent.select_action(state, valid_actions)
            state, _, done, _ = env.step(action_idx)

    node_path = [env.start_node]
    node_path.extend(env.edge_index[1][edge_idx].item() for edge_idx in env.path_history)

    return env.path_history, node_path

# Roll out the trained policy once and report the assembled path.
final_edge_path, final_node_path = assemble_transcript(env=env, agent=agent)

print("Assembled transcript (edge indices):", final_edge_path)
print("Assembled transcript (node sequence):", final_node_path)