In [None]:
#!pip install pyfaidx

In [None]:
#!pip install torch-geometric

In [None]:
# install torchsummary using pip
#!pip install pytorch-lightning


In [None]:
from functools import lru_cache

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import k_hop_subgraph

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

import pytorch_lightning as pl  # For displaying the summary of a model

from pyfaidx import Fasta


In [None]:
mut_filepath = "/kaggle/input/mutation/Mutation_and_healthy_data.csv"
# low_memory=False: parse in one pass instead of per-chunk (possibly mixed) dtype inference.
df = pd.read_csv(filepath_or_buffer=mut_filepath, low_memory=False)
df


In [None]:
# Number of records per chromosome / contig label.
df.loc[:, 'chr'].value_counts()


In [None]:
# Records with a missing reference allele.
df.loc[df['ref'].isna()]


In [None]:
# Contig accessions (NW_*) excluded from the analysis.
out_chroms = ['NW_015148969.2', 'NW_009646203.1', 'NW_009646206.1']
df = (
    df.loc[~df['chr'].isin(out_chroms)]
    .drop(columns='variation_id')
    .dropna()
)
df


In [None]:
# Sanity check: remaining missing values per column.
df.isna().sum()


In [None]:
# Records listing exactly five alternate alleles (four separating commas).
df.loc[df['alt'].map(lambda alleles: alleles.count(',') == 4)]


In [None]:
# Spot-check a few rows by index label.
# NOTE(review): the index has not been reset since the filtering above, so any label removed
# there raises KeyError here.
temp_df = df.loc[pd.Index([0, 10000, 16148, 1255788, 1145676])]
temp_df


In [None]:
# Keep records with a single alternate allele (no comma) and renumber rows 0..n-1;
# create_graph later uses these row labels as node ids.
df = (
    df.loc[df['alt'].map(lambda alleles: ',' not in alleles)]
    .reset_index(drop=True)
)
# Rename 'MT' to 'M' so f"chr{chrom}" produces 'chrM' for the genome lookup.
df['chr'] = df['chr'].replace({'MT': 'M'})
df['chr'].unique()


In [None]:
#df1 = df[df['clinical_significance'] == 'non-harmful'].iloc[:207150]
#df2 = df[df['clinical_significance'] == 'harmful'].iloc[:207150]
#df = pd.concat([df1, df2], axis=0).reset_index(drop=True)
#df


In [None]:
# Class balance of the target label.
df.loc[:, 'clinical_significance'].value_counts()


In [None]:
genome_path = "/kaggle/input/mutation/hg38.fa"
# rebuild=False: reuse an existing .fai index rather than re-indexing (pyfaidx option).
genome = Fasta(genome_path, rebuild=False)

# Sanity check: show the +/- W bp window around a 1-based position. The slice end is
# pos + W + 1, so the variant base sits at index W. This matches the 2*W + 1 window used by
# get_reference_and_mutated_sequence and process_mutations_for_lstm.
W = 10
pos = 74484579 - 1  # 1-based coordinate -> 0-based index
genome['chr14'][pos - W : pos + W + 1].seq


In [None]:
# Function to extract reference and mutated sequences
def get_reference_and_mutated_sequence(chrom, pos, ref, alt, window=10, fasta=None):
    """
    Extract the reference window around a variant and the mutated sequence(s) it implies.

    chrom:  Chromosome name as stored in the FASTA (e.g. 'chr11').
    pos:    1-based position of the variant's first reference base.
    ref:    Reference allele. '-' is treated as an empty allele (pure insertion).
    alt:    Alternate allele(s), comma-separated if several. '-' is treated as an empty
            allele (deletion).
    window: Number of bases kept on each side of ``pos``.
    fasta:  Object indexed as ``fasta[chrom][start:end].seq`` (e.g. a pyfaidx.Fasta).
            Defaults to the notebook-level ``genome``.

    Returns (ref_seq, mutated_seqs):
    - ref_seq is the slice [pos-1-window, pos+window) of the chromosome, with the variant's
      first reference base at index ``window``.
    - mutated_seqs has one entry per alternate allele, in order.

    The previous version always replaced exactly one base (ignoring ``ref``) and inserted '-'
    literally. Reference bases beyond ``pos + window`` are outside the window, so a REF longer
    than ``window + 1`` bases is truncated.
    NOTE(review): positions closer than ``window`` to the chromosome start are not handled.
    """
    source = genome if fasta is None else fasta
    start = int(pos) - 1  # 1-based -> 0-based
    ref_seq = source[chrom][start - window:start + window + 1].seq

    def as_bases(allele):
        # '-' is the placeholder for an empty allele, not a nucleotide.
        return '' if allele == '-' else allele

    # Replace exactly len(REF) bases so multi-base REF alleles are removed completely.
    # With an empty REF the ALT is inserted before the base at `pos`
    # (NOTE(review): confirm this matches the dataset's insertion-coordinate convention).
    prefix = ref_seq[:window]
    suffix = ref_seq[window + len(as_bases(ref)):]
    mutated_seqs = [prefix + as_bases(alt_allele) + suffix for alt_allele in alt.split(',')]

    return ref_seq, mutated_seqs

def one_hot_encode(sequence):
    """
    One-hot encode a DNA string as an int64 tensor of shape (len(sequence), 4).

    Columns are A, C, G, T. Matching is case-insensitive. Soft-masked FASTAs such as UCSC's
    hg38.fa store repeats in lowercase, and those bases previously encoded as all-zero rows.
    Any other symbol (N, '-', ...) becomes an all-zero row. An empty sequence yields a (0, 4)
    tensor.
    """
    mapping = {'A': [1, 0, 0, 0], 'C': [0, 1, 0, 0], 'G': [0, 0, 1, 0], 'T': [0, 0, 0, 1]}
    one_hot_seq = [mapping.get(base, [0, 0, 0, 0]) for base in sequence.upper()]
    # reshape keeps the (n, 4) layout even when the sequence is empty.
    return torch.tensor(one_hot_seq, dtype=torch.long).reshape(-1, 4)

# Worked example: one multi-allelic record ('-' is the empty-allele placeholder).
chrom = 'chr11'
pos = 6391633
ref = 'C'
alt = '-,CC,CCAACCCCCC'

# Reference window plus one mutated sequence per alternate allele, read from hg38.
ref_seq, mutated_seqs = get_reference_and_mutated_sequence(chrom, pos, ref, alt)
print(mutated_seqs)

# One-hot encode every sequence (A, C, G, T columns).
ref_one_hot = one_hot_encode(ref_seq)
mut_one_hots = list(map(one_hot_encode, mutated_seqs))

print("One-Hot Encoded Reference Sequence:", ref_one_hot)
for i, mut_one_hot in enumerate(mut_one_hots):
    print(f"One-Hot Encoded Mutated Sequence {i+1}:", mut_one_hot)


In [None]:
@lru_cache(maxsize=2048)
def get_sequence(chrom, start, end):
    """Return bases [start, end) (0-based) of chromosome ``chr{chrom}`` from ``genome`` as a str.

    Memoised with lru_cache, so all arguments must be hashable.
    """
    contig = genome[f"chr{chrom}"]
    return str(contig[start:end])

def process_mutations_for_lstm(df, window=10):
    """
    Build one-hot reference and mutated sequence tensors for every variant in ``df``.

    Expects columns 'chr', 'pos' (1-based), 'ref' and 'alt' (comma-separated alternates).
    '-' in 'ref' or 'alt' is treated as an empty allele. Previously it was inserted as a
    literal character, and a '-' REF was counted as one base.

    Side effect: adds 'start', 'end', 'ref_seq' and 'mut_seqs' columns to ``df`` in place.

    Returns (ref_encoded, mut_encoded), both float32 tensors from one_hot_encode_vectorized:
    - ref_encoded has one row per variant.
    - mut_encoded has one row per alternate allele, flattened in row order.
    The rows therefore align only when every variant has a single alternate allele.
    """
    def as_bases(allele):
        # '-' is the placeholder for an empty allele, not a nucleotide.
        return '' if allele == '-' else allele

    # 0-based half-open window [pos-1-window, pos+window): 2*window + 1 bases with the
    # variant's first reference base at index `window`.
    df['start'] = df['pos'] - window - 1
    df['end'] = df['pos'] + window
    df['ref_seq'] = df.apply(lambda row: get_sequence(row['chr'], row['start'], row['end']), axis=1)

    def generate_mut_seqs(row):
        ref_seq = row['ref_seq']
        prefix = ref_seq[:window]
        # Drop exactly len(REF) bases; an empty REF inserts ALT before the base at `pos`
        # (NOTE(review): confirm this matches the dataset's insertion-coordinate convention).
        suffix = ref_seq[window + len(as_bases(row['ref'])):]
        return [prefix + as_bases(alt) + suffix for alt in row['alt'].split(',')]

    df['mut_seqs'] = df.apply(generate_mut_seqs, axis=1)

    # Flatten into plain lists for encoding.
    ref_seqs = df['ref_seq'].tolist()
    mut_seqs = [seq for seqs in df['mut_seqs'] for seq in seqs]

    ref_encoded = one_hot_encode_vectorized(ref_seqs)
    mut_encoded = one_hot_encode_vectorized(mut_seqs)

    return ref_encoded, mut_encoded

def one_hot_encode_vectorized(sequences):
    """
    One-hot encode a list of DNA strings into a float32 tensor of shape (n, max_len, 4).

    Channels are A, C, G, T. Matching is case-insensitive, so soft-masked (lowercase) genome
    bases are encoded instead of silently becoming all-zero rows. Shorter sequences are
    right-padded. Padding and any other symbol (N, '-', ...) give all-zero rows. An empty list
    yields a (0, 0, 4) tensor instead of raising inside ``max()``.
    """
    if len(sequences) == 0:
        return torch.zeros((0, 0, 4), dtype=torch.float32)

    max_len = max(len(seq) for seq in sequences)
    one_hot = np.zeros((len(sequences), max_len, 4), dtype=np.float32)
    if max_len == 0:
        return torch.from_numpy(one_hot)

    # (n, max_len) array of single characters; 'N' padding matches no channel.
    seq_array = np.array([list(seq.upper().ljust(max_len, 'N')) for seq in sequences])
    for index, char in enumerate('ACGT'):
        one_hot[seq_array == char, index] = 1
    return torch.from_numpy(one_hot)

# Main execution: build the sequence tensors for every variant in df.
# (__name__ is "__main__" in a notebook kernel, so this block always runs.)
if __name__ == "__main__":
    ref_encoded, mut_encoded = process_mutations_for_lstm(df, window=10)
    print(f"Reference shape: {ref_encoded.shape}, Mutated shape: {mut_encoded.shape}")


In [None]:
# Encode 'chr' and 'variation_id' column
#variation_id_encoder = []
#for i, vid in enumerate(df['variation_id']):
#    if str(vid) == '0':
#        variation_id_encoder.append(0)
#    else:
#        variation_id_encoder.append(i + 1)
#print(variation_id_encoder)

#df.loc[:, 'variation_id_encoded'] = variation_id_encoder
#df['variation_id'] = df['variation_id'].astype(int)
# Fixed integer codes for non-numeric chromosome / contig labels. Every other label must be
# numeric for astype(int) to succeed.
df['chr_encoded'] = (
    df['chr']
    .replace({
        'X': 23,
        'Y': 24,
        'M': 25,
        'NW_009646206.1': 26,
        'NW_009646201.1': 27,
        'NW_009646203.1': 28,
        'NW_015148969.2': 29,
    })
    .astype(int)
)
df


In [None]:
# Fit one LabelEncoder per categorical column and store its integer codes in
# '<column>_encoded'. The fitted encoders stay in scope (e.g. for inverse_transform).
gene_encoder = LabelEncoder()
df['gene_encoded'] = gene_encoder.fit_transform(df['gene'])

ref_encoder = LabelEncoder()
df['ref_encoded'] = ref_encoder.fit_transform(df['ref'])

alt_encoder = LabelEncoder()
df['alt_encoded'] = alt_encoder.fit_transform(df['alt'])

variant_type_encoder = LabelEncoder()
df['variant_type_encoded'] = variant_type_encoder.fit_transform(df['variant_type'])

clinical_significance_encoder = LabelEncoder()
df['clinical_significance_encoded'] = clinical_significance_encoder.fit_transform(df['clinical_significance'])
df


In [None]:
# NOTE(review): disabled prototype kept as a bare string literal, so none of it executes
# (the cell just displays the string).
# If re-enabled it would fail as written: nodes are keyed by 'variation_id', a column
# dropped earlier in this notebook. The nested iterrows loops compare every pair of rows
# (O(n^2)), presumably why it is limited to the first 50 rows.
"""

import networkx as nx

data = df.iloc[:50]

# Create an empty graph
G = nx.Graph()

# Step 1: Add nodes to the graph
# Each node is identified by variation_id and has attributes like gene, chr, pos, etc.
for i, row in data.iterrows():
    G.add_node(row['variation_id'], gene=row['gene'], chr=row['chr'], pos=row['pos'],
               ref=row['ref'], alt=row['alt'], variant_type=row['variant_type'],
               clinical_significance=row['clinical_significance'])

# Step 2: Add edges based on the combined criteria
# Define a proximity threshold (e.g., 100 kb)
threshold = 100000  # 100 kb threshold

# Iterate through the dataset and construct edges
for i, row in data.iterrows():
    alt_alleles = set(row['alt'].split(','))
    genes = set(row['gene'].split(';'))  # Split multiple genes if present
    for j, other_row in data.iterrows():
        other_alt_alleles = set(other_row['alt'].split(','))
        other_genes = set(other_row['gene'].split(';'))

        # Check clinical significance, shared alternate alleles, shared genes, and proximity
        if (i != j and
            row['clinical_significance'] == other_row['clinical_significance'] and  # Same clinical significance
            len(alt_alleles & other_alt_alleles) > 0 and  # Shared alternate alleles
            len(genes & other_genes) > 0 and  # Shared genes
            row['chr'] == other_row['chr'] and  # Same chromosome
            abs(row['pos'] - other_row['pos']) < threshold):  # Proximity threshold

            # Add edge if all conditions are met
            G.add_edge(row['variation_id'], other_row['variation_id'])

# Step 3: Check the graph (optional)
print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")

"""

In [None]:
def create_graph(df, undirected=True):
    """
    Build a PyTorch Geometric graph with one node per row of ``df``.

    Node features (float32) are gene_encoded, chr_encoded, pos, ref_encoded, alt_encoded and
    variant_type_encoded. Node labels (int64) are clinical_significance_encoded.

    Within each gene, variants are chained in row order (each row -> the next row of the same
    gene). With ``undirected=True`` (default) every edge is also added reversed, so GCNConv
    propagates messages both ways. Pass False for the previous one-directional chain.

    Node ids are row *positions* (0..len(df)-1), so any index works, not only a RangeIndex.
    A graph without edges gets a well-formed (2, 0) edge_index.
    """
    feature_columns = ['gene_encoded', 'chr_encoded', 'pos', 'ref_encoded', 'alt_encoded', 'variant_type_encoded']
    # NOTE(review): raw genomic `pos` is used unscaled. Values above 2**24 are not exactly
    # representable in float32, and its magnitude dwarfs the other features.
    node_features = torch.tensor(df[feature_columns].values, dtype=torch.float)
    node_labels = torch.tensor(df['clinical_significance_encoded'].values, dtype=torch.long)

    # groupby(...).indices yields positional row indices (unlike .groups, which yields labels).
    sources, targets = [], []
    for positions in df.groupby('gene').indices.values():
        ordered = sorted(positions.tolist())
        sources.extend(ordered[:-1])
        targets.extend(ordered[1:])

    edge_index = torch.tensor([sources, targets], dtype=torch.long)  # shape (2, E); E may be 0
    if undirected:
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    return Data(x=node_features, edge_index=edge_index, y=node_labels)

# Usage
graph_data = create_graph(df)

# Check the prepared data
print("Node Features shape:", graph_data.x.shape)
print("Edge Index shape:", graph_data.edge_index.shape)
print("Node Labels shape:", graph_data.y.shape)


In [None]:
# Node feature matrix (one row per variant).
graph_data['x']

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(graph_data, max_nodes=None):
    """
    Draw a PyTorch Geometric graph with NetworkX/matplotlib; nodes are labelled and coloured by y.

    max_nodes: if given, only nodes 0..max_nodes-1 and the edges between them are drawn.
               spring_layout is expensive, so pass a limit for large graphs.
    """
    x = graph_data.x.cpu()
    y = graph_data.y.cpu()
    num_nodes = x.size(0) if max_nodes is None else min(x.size(0), max_nodes)

    G = nx.Graph()
    for i in range(num_nodes):
        G.add_node(i, features=x[i].numpy(), label=y[i].item())

    # Keep only edges whose endpoints are both drawn (all edges when there is no limit).
    edge_index = graph_data.edge_index.cpu().numpy()
    G.add_edges_from(
        (int(u), int(v))
        for u, v in zip(edge_index[0], edge_index[1])
        if u < num_nodes and v < num_nodes
    )

    pos = nx.spring_layout(G)
    node_labels = {n: G.nodes[n]['label'] for n in G.nodes}
    node_colors = [G.nodes[n]['label'] for n in G.nodes]

    fig, ax = plt.subplots(figsize=(10, 8))
    nx.draw(G, pos, ax=ax, with_labels=True, labels=node_labels, node_color=node_colors,
            cmap=plt.cm.viridis, node_size=500, font_size=10)
    ax.set_title('Graph Visualization')
    plt.show()

# Usage
#visualize_graph(graph_data)


In [None]:
import torch
from torch_geometric.data import Data, Batch
from torch.utils.data import Dataset


In [None]:
class Efficient_GCN_LSTM_Dataset(Dataset):
    """
    One sample per variant (node) of a single PyTorch Geometric graph.

    Each sample holds the node's ``num_hops``-hop neighbourhood (relabelled 0..k-1), its label,
    and its one-hot reference / mutated sequences.

    The previous version sliced ``x`` as if it held len(y) stacked graphs of ``num_nodes``
    nodes each. It holds one graph, so every idx > 0 gave an empty graph with negative edge
    indices.

    Row i of ``ref_encoded`` / ``mut_encoded`` must belong to node i (true when each variant
    has exactly one alternate allele); __init__ raises ValueError otherwise.
    NOTE(review): k_hop_subgraph scans the full edge_index per sample (O(E * num_hops)).
    """

    def __init__(self, graph_data, ref_encoded, mut_encoded, num_hops=1):
        self.x = graph_data.x
        self.edge_index = graph_data.edge_index
        self.y = graph_data.y
        self.ref_seq = ref_encoded
        self.mut_seq = mut_encoded
        self.num_nodes = graph_data.x.shape[0]  # Total number of nodes in the full graph
        self.num_hops = num_hops
        if not (len(self.ref_seq) == len(self.mut_seq) == self.num_nodes == self.y.shape[0]):
            raise ValueError(
                f"Misaligned inputs: {self.num_nodes} nodes, {self.y.shape[0]} labels, "
                f"{len(self.ref_seq)} reference and {len(self.mut_seq)} mutated sequences"
            )

    def __len__(self):
        return self.y.shape[0]  # one sample per node

    def __getitem__(self, idx):
        # Ego graph around node `idx`, with nodes relabelled to 0..k-1.
        subset, sub_edge_index, _, _ = k_hop_subgraph(
            int(idx), self.num_hops, self.edge_index,
            relabel_nodes=True, num_nodes=self.num_nodes,
        )
        return Data(
            x=self.x[subset],
            edge_index=sub_edge_index,
            y=self.y[idx].view(1),
            # Leading dim of 1 so batching concatenates into (batch, seq_len, channels).
            ref_seq=self.ref_seq[idx].unsqueeze(0),
            mut_seq=self.mut_seq[idx].unsqueeze(0),
            num_nodes=subset.numel(),
        )

def custom_collate(batch):
    """
    Collate a list of per-sample ``Data`` objects into one PyTorch Geometric ``Batch``.

    Batch.from_data_list already offsets each sample's edge_index by the node count of the
    preceding samples and builds the ``batch`` assignment vector, so nothing else is needed.
    The previous manual re-offset referenced an undefined name ``data`` (NameError) and,
    had it run, would have shifted the indices a second time.
    """
    return Batch.from_data_list(list(batch))

# One sample per variant node; custom_collate merges samples into a PyG Batch.
# NOTE(review): there is no held-out split; the evaluation cell also iterates this loader.
dataset = Efficient_GCN_LSTM_Dataset(graph_data, ref_encoded, mut_encoded)
loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=custom_collate,
)


In [None]:
class GCN_LSTM_Hybrid(nn.Module):
    """
    Binary classifier combining a two-layer GCN (mean-pooled per graph) with two LSTMs over
    the reference and mutated sequences.

    Sequences may be:
    - one-hot, shaped (batch, seq_len, seq_input_size). The default seq_input_size=4 matches
      the A/C/G/T encoding.
    - scalar-coded, shaped (batch, seq_len), which needs seq_input_size=1.

    The previous version always used input size 1 with unsqueeze(2), which breaks on
    (batch, seq_len, 4) one-hot tensors.

    forward returns sigmoid probabilities of shape (num_graphs, 1); pair it with BCELoss.
    seq_length is accepted for interface compatibility but unused (LSTMs take any length).
    """

    def __init__(self, num_node_features, gcn_hidden_channels, lstm_hidden_size, lstm_num_layers, seq_length,
                 seq_input_size=4):
        super(GCN_LSTM_Hybrid, self).__init__()

        # GCN layers
        self.conv1 = GCNConv(num_node_features, gcn_hidden_channels)
        self.conv2 = GCNConv(gcn_hidden_channels, gcn_hidden_channels)

        # LSTM layers for reference and mutated sequences
        self.lstm_ref = nn.LSTM(seq_input_size, lstm_hidden_size, num_layers=lstm_num_layers, batch_first=True)
        self.lstm_mut = nn.LSTM(seq_input_size, lstm_hidden_size, num_layers=lstm_num_layers, batch_first=True)

        # Fully connected head
        self.fc1 = nn.Linear(gcn_hidden_channels + 2 * lstm_hidden_size, 64)
        self.fc2 = nn.Linear(64, 1)  # Binary classification

    @staticmethod
    def _as_lstm_input(seq):
        # (batch, seq_len) scalar codes -> (batch, seq_len, 1); 3-D input passes through.
        seq = seq.float()
        return seq.unsqueeze(-1) if seq.dim() == 2 else seq

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # GCN forward pass
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)  # one vector per graph in the batch

        # Final hidden state of the last LSTM layer summarises each sequence.
        _, (h_ref, _) = self.lstm_ref(self._as_lstm_input(data.ref_seq))
        _, (h_mut, _) = self.lstm_mut(self._as_lstm_input(data.mut_seq))

        combined = torch.cat([x, h_ref[-1], h_mut[-1]], dim=-1)
        out = F.relu(self.fc1(combined))
        out = self.fc2(out)
        return torch.sigmoid(out)


In [None]:
# Seed torch so weight initialisation is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model sizes: input widths come from the data rather than hard-coded numbers.
num_node_features = dataset.x.shape[1]  # number of node feature columns
gcn_hidden_channels = 32
lstm_hidden_size = 64
lstm_num_layers = 2
seq_length = dataset.ref_seq.shape[1]  # padded sequence length; GCN_LSTM_Hybrid does not use it

model = GCN_LSTM_Hybrid(num_node_features, gcn_hidden_channels, lstm_hidden_size, lstm_num_layers, seq_length).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCELoss()  # the model already applies sigmoid


In [None]:
# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        out = model(batch)
        # view(-1) instead of squeeze(): squeeze() turns a size-1 final batch into a 0-d
        # tensor whose shape differs from batch.y, and BCELoss raises on size mismatch.
        loss = criterion(out.view(-1), batch.y.float())

        loss.backward()
        optimizer.step()

        # loss is a per-graph mean; weight by batch size for an exact epoch average.
        total_loss += loss.item() * batch.num_graphs

    # Divide by the number of samples this loader actually iterates.
    avg_loss = total_loss / len(loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}')


In [None]:
# Evaluation
def evaluate(model, data_loader, threshold=0.5):
    """
    Return the fraction of graphs in ``data_loader`` whose thresholded prediction equals y.

    Batches are moved to the model's device. Returns nan for an empty loader instead of
    dividing by zero.
    NOTE(review): the only loader in this notebook is the shuffled training loader, so
    evaluate(model, loader) measures training accuracy. Pass a held-out loader for a
    generalisation estimate.
    """
    model.eval()
    model_device = next(model.parameters()).device
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(model_device)
            out = model(batch)
            # view(-1) keeps a 1-D shape even for a size-1 batch.
            pred = (out.view(-1) > threshold).long()
            total_correct += int((pred == batch.y.view(-1)).sum())
            total_samples += batch.num_graphs
    return total_correct / total_samples if total_samples else float('nan')


accuracy = evaluate(model, loader)
print(f'Accuracy: {accuracy:.4f}')
