# MoA Prediction: Novel Feature Engineering

This notebook demonstrates the novel feature engineering approaches implemented in Phase 2:

1. **Chemical Graph Features** with counterfactual substructure analysis
2. **Mechanism Tokens (MechTokens)** - ontology-aware embeddings
3. **Perturbational Biology Features** from LINCS L1000
4. **Protein Pocket Features** (optional)

These features represent the key innovations that differentiate our approach from existing MoA prediction methods.

In [None]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import networkx as nx

from moa.utils.config import Config
from moa.features.chemical import ChemicalFeatureExtractor
from moa.features.mechanism_tokens import MechTokenFeatureExtractor
from moa.features.perturbational import PerturbationalFeatureExtractor
from moa.features.feature_extractor import MultiModalFeatureExtractor
from moa.data.processors import DataProcessor

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
%matplotlib inline

## 1. Configuration and Data Setup

In [None]:
# Load configuration
config = Config('../configs/config.yaml')

# Enable all modalities for demonstration
config.set("scope.modalities.chemistry", True)
config.set("scope.modalities.targets", True)
config.set("scope.modalities.pathways", True)
config.set("scope.modalities.perturbation", True)

print("Configuration:")
print(f"Enabled modalities: {config.get('scope.modalities')}")
print(f"Counterfactual analysis: {config.get('features.chemistry.substructure_analysis.enable_counterfactual')}")
print(f"MechToken embedding dim: {config.get('features.mechanism_tokens.embedding_dim')}")

In [None]:
# Create sample dataset with diverse MoAs
sample_data = pd.DataFrame({
    'molecule_chembl_id': [
        'CHEMBL25', 'CHEMBL521', 'CHEMBL113', 'CHEMBL1200766', 'CHEMBL154',
        'CHEMBL6', 'CHEMBL1201585', 'CHEMBL744', 'CHEMBL1200960', 'CHEMBL1201'
    ],
    'canonical_smiles': [
        'CCO',  # Ethanol
        'CC(=O)OC1=CC=CC=C1C(=O)O',  # Aspirin
        'CN1C=NC2=C1C(=O)N(C(=O)N2C)C',  # Caffeine
        'CC(C)CC1=CC=C(C=C1)C(C)C(=O)O',  # Ibuprofen
        'CN(C)CCOC(C1=CC=CC=C1)C2=CC=CC=C2',  # Diphenhydramine
        'CC(C)(C)NCC(C1=CC(=C(C=C1)O)CO)O',  # Salbutamol
        'COC1=C(C=C2C(=C1)N=CN=C2NC3=CC(=C(C=C3)F)Cl)OCCCN4CCOCC4',  # Gefitinib
        'CC1=CC=C(C=C1)C(=O)C2=CC=C(N2C)CC(=O)O',  # Tolmetin
        'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F',  # Celecoxib
        'CN(C)CCCN1C2=CC=CC=C2SC3=C1C=C(C=C3)Cl'  # Chlorpromazine
    ],
    'mechanism_of_action': [
        'CNS depressant',
        'Cyclooxygenase inhibitor',
        'Adenosine receptor antagonist',
        'Cyclooxygenase inhibitor',
        'Histamine H1 receptor antagonist',
        'Beta-2 adrenergic receptor agonist',
        'EGFR tyrosine kinase inhibitor',
        'Cyclooxygenase inhibitor',
        'Cyclooxygenase-2 inhibitor',
        'Dopamine receptor antagonist'
    ],
    'target_chembl_id': [
        'CHEMBL1', 'CHEMBL230', 'CHEMBL1824', 'CHEMBL230', 'CHEMBL231',
        'CHEMBL210', 'CHEMBL203', 'CHEMBL230', 'CHEMBL230', 'CHEMBL217'
    ]
})

print(f"Created dataset with {len(sample_data)} compounds")
print("\nMoA distribution:")
print(sample_data['mechanism_of_action'].value_counts())

## 2. Chemical Graph Features with Counterfactual Analysis

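Our novel approach identifies candidate causal substructures counterfactually: each molecular fragment is removed, and the resulting change in MoA prediction performance becomes that fragment's importance score. Fragments with positive scores are treated as causal.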
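The scoring idea can be sketched as occlusion-style counterfactual analysis. This is a minimal sketch under stated assumptions: each compound is represented by a binary fragment-presence matrix, and a least-squares linear scorer stands in for the real MoA model. `fit_linear_scorer` and `counterfactual_scores` are hypothetical helpers, not the `ChemicalFeatureExtractor` API.

```python
import numpy as np

def fit_linear_scorer(X, y):
    """Least-squares linear scorer with a bias term (stand-in for a real MoA classifier)."""
    X1 = np.hstack([X, np.ones((X.shape[0], 1))])
    w, *_ = np.linalg.lstsq(X1, y, rcond=None)

    def predict(Z):
        return np.hstack([Z, np.ones((Z.shape[0], 1))]) @ w

    return predict

def counterfactual_scores(X, y, predict):
    """Importance of fragment j = mean increase in squared error on the compounds
    that carry fragment j when that fragment is deleted (its column zeroed)."""
    base_err = (predict(X) - y) ** 2
    scores = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        carriers = X[:, j] > 0
        if not carriers.any():
            continue                      # fragment never occurs: leave score at 0
        X_cf = X.copy()
        X_cf[:, j] = 0                    # counterfactual molecule without fragment j
        cf_err = (predict(X_cf) - y) ** 2
        # positive score: removing the fragment makes predictions worse
        scores[j] = float(np.mean(cf_err[carriers] - base_err[carriers]))
    return scores

# Toy data: 8 compounds x 3 fragments; the label depends only on fragment 0
X = np.array([[1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1],
              [0, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 1]], dtype=float)
y = X[:, 0].copy()
predict = fit_linear_scorer(X, y)
scores = counterfactual_scores(X, y, predict)
print(np.round(scores, 3))
```

Only fragment 0 receives a clearly positive score here, because deleting the other fragments leaves the predictions unchanged. A production version would score fragments per MoA with the trained model and real substructure decompositions.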

In [None]:
# Process data for feature extraction
processor = DataProcessor(config)
processed_data = processor.smiles_processor.process_smiles_column(sample_data)
processed_data = processor.label_processor.process_moa_labels(processed_data)

print(f"Processed {len(processed_data)} compounds")
n_moa_labels = len({moa for moas in processed_data['moa_list'] for moa in moas})
print(f"Created {n_moa_labels} unique MoA labels")

In [None]:
# Extract chemical features with counterfactual analysis
chemical_extractor = ChemicalFeatureExtractor(config)

smiles_list = processed_data['standardized_smiles'].tolist()

# Create binary label matrix for counterfactual analysis
from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(processed_data['moa_list'])
moa_names = list(mlb.classes_)

print(f"Extracting chemical features for {len(smiles_list)} compounds...")
chemical_features = chemical_extractor.extract_features(smiles_list, labels, moa_names)

print("\nChemical features extracted:")
for key, value in chemical_features.items():
    if isinstance(value, list):
        print(f"  {key}: {len(value)} items")
    elif isinstance(value, dict):
        print(f"  {key}: {len(value)} entries")
    else:
        print(f"  {key}: {type(value)}")

In [None]:
# Visualize molecular graphs
molecular_graphs = chemical_features['molecular_graphs']
valid_graphs = [g for g in molecular_graphs if g is not None]

print(f"Generated {len(valid_graphs)} valid molecular graphs")

if valid_graphs:
    # Show graph statistics
    node_counts = [g.x.shape[0] for g in valid_graphs]
    edge_counts = [g.edge_index.shape[1] for g in valid_graphs]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    ax1.hist(node_counts, bins=10, alpha=0.7)
    ax1.set_xlabel('Number of Nodes')
    ax1.set_ylabel('Count')
    ax1.set_title('Distribution of Graph Sizes (Nodes)')
    
    ax2.hist(edge_counts, bins=10, alpha=0.7)
    ax2.set_xlabel('Number of Edges')
    ax2.set_ylabel('Count')
    ax2.set_title('Distribution of Graph Sizes (Edges)')
    
    plt.tight_layout()
    plt.show()
    
    print(f"Average nodes per graph: {np.mean(node_counts):.1f}")
    print(f"Average edges per graph: {np.mean(edge_counts):.1f}")
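
The `g.x` / `g.edge_index` attributes suggest the graphs follow the PyTorch Geometric convention. Under that assumption, each undirected bond is stored as two directed edges, so the edge histogram counts roughly twice the number of bonds. The minimal sketch below builds such arrays for a molecule given as atom symbols and bond index pairs. `ATOM_VOCAB` and `to_graph_arrays` are hypothetical; a real featurizer would parse SMILES with RDKit and add richer atom and bond features.

```python
import numpy as np

ATOM_VOCAB = ["C", "N", "O", "S", "F", "Cl"]   # hypothetical, deliberately tiny vocabulary

def to_graph_arrays(atoms, bonds):
    """Return (x, edge_index) for a molecule given as atom symbols and bond index pairs."""
    # One-hot atom type per node; the extra last column means "other element"
    x = np.zeros((len(atoms), len(ATOM_VOCAB) + 1), dtype=np.float32)
    for i, symbol in enumerate(atoms):
        j = ATOM_VOCAB.index(symbol) if symbol in ATOM_VOCAB else len(ATOM_VOCAB)
        x[i, j] = 1.0
    # Store every undirected bond as two directed edges: row 0 = source, row 1 = target
    src = [u for u, v in bonds] + [v for u, v in bonds]
    dst = [v for u, v in bonds] + [u for u, v in bonds]
    edge_index = np.array([src, dst], dtype=np.int64).reshape(2, -1)
    return x, edge_index

# Ethanol (CCO): three heavy atoms, two single bonds
x, edge_index = to_graph_arrays(["C", "C", "O"], [(0, 1), (1, 2)])
print(x.shape, edge_index.shape)   # (3, 7) (2, 4)
```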

In [None]:
# Analyze counterfactual substructure scores
if 'counterfactual_scores' in chemical_features:
    cf_scores = chemical_features['counterfactual_scores']
    causal_fragments = chemical_features['causal_fragments']
    
    print(f"Counterfactual analysis completed for {len(cf_scores)} MoAs")
    
    # Show top causal fragments for each MoA
    for moa, fragments in list(causal_fragments.items())[:3]:
        print(f"\nTop causal fragments for '{moa}':")
        for i, (fragment, score) in enumerate(fragments[:5]):
            print(f"  {i+1}. {fragment}: {score:.4f}")
    
    # Visualize fragment importance distribution
    all_scores = []
    for moa_scores in cf_scores.values():
        all_scores.extend(moa_scores.values())
    
    plt.figure(figsize=(10, 6))
    plt.hist(all_scores, bins=30, alpha=0.7)
    plt.xlabel('Counterfactual Importance Score')
    plt.ylabel('Count')
    plt.title('Distribution of Fragment Importance Scores')
    plt.axvline(x=0, color='red', linestyle='--', alpha=0.7, label='Neutral')
    plt.legend()
    plt.show()
    
    print(f"\nFragment importance statistics:")
    print(f"  Mean: {np.mean(all_scores):.4f}")
    print(f"  Std: {np.std(all_scores):.4f}")
    print(f"  Positive scores (causal): {sum(1 for s in all_scores if s > 0)} / {len(all_scores)}")
else:
    print("Counterfactual analysis not performed (requires labels)")

## 3. Mechanism Tokens (MechTokens)

Novel ontology-aware embeddings that encode drug-target-pathway-MoA relationships as learned node embeddings over a biological ontology graph, aggregated into one token per compound.
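Why graph structure alone makes tokens informative can be shown with a parameter-free stand-in for the learned embeddings. This sketch swaps node2vec/GNN training for simple two-hop diffusion over the symmetrically normalized adjacency (with self-loops), so each node's token is a row of that propagated matrix. The toy node names and the helpers `build_adjacency`, `propagation_tokens` and `cosine` are hypothetical, not part of `MechTokenFeatureExtractor`.

```python
import numpy as np

def build_adjacency(nodes, edges):
    """Symmetric 0/1 adjacency matrix for an undirected ontology graph."""
    index = {name: i for i, name in enumerate(nodes)}
    A = np.zeros((len(nodes), len(nodes)))
    for u, v in edges:
        A[index[u], index[v]] = 1.0
        A[index[v], index[u]] = 1.0
    return A, index

def propagation_tokens(A, hops=2):
    """Rows of the normalized (A + I) raised to `hops`: one token per node."""
    A_hat = A + np.eye(len(A))                    # self-loops keep each node's own identity
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    A_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
    return np.linalg.matrix_power(A_norm, hops)   # mixes information from up to `hops` steps away

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical toy ontology: drug -> target -> MoA
nodes = ["drug:aspirin", "drug:ibuprofen", "drug:gefitinib",
         "target:PTGS2", "target:EGFR", "moa:COX inhibitor", "moa:EGFR inhibitor"]
edges = [("drug:aspirin", "target:PTGS2"), ("drug:ibuprofen", "target:PTGS2"),
         ("drug:gefitinib", "target:EGFR"),
         ("target:PTGS2", "moa:COX inhibitor"), ("target:EGFR", "moa:EGFR inhibitor")]

A, index = build_adjacency(nodes, edges)
H = propagation_tokens(A)
tokens = {n: H[index[n]] for n in nodes if n.startswith("drug:")}
print(cosine(tokens["drug:aspirin"], tokens["drug:ibuprofen"]))
print(cosine(tokens["drug:aspirin"], tokens["drug:gefitinib"]))
```

Drugs that share a target end up with similar tokens, while drugs in disconnected parts of the graph get orthogonal ones. A trained encoder would additionally compress such signals into `embedding_dim` dimensions.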

In [None]:
# Create sample data sources for mechanism tokens
data_sources = {
    'mechanisms': pd.DataFrame({
        'molecule_chembl_id': processed_data['molecule_chembl_id'].tolist(),
        'target_chembl_id': processed_data['target_chembl_id'].tolist(),
        'mechanism_of_action': processed_data['mechanism_of_action'].tolist()
    }),
    'targets': pd.DataFrame({
        'target_chembl_id': ['CHEMBL1', 'CHEMBL230', 'CHEMBL1824', 'CHEMBL210', 'CHEMBL203', 'CHEMBL231', 'CHEMBL217'],
        'pref_name': ['Alcohol dehydrogenase', 'Cyclooxygenase-2', 'Adenosine A2a receptor', 
                     'Beta-2 adrenergic receptor', 'EGFR', 'Histamine H1 receptor', 'Dopamine D2 receptor'],
        'target_type': ['PROTEIN'] * 7
    }),
    'pathways': pd.DataFrame({
        'stId': ['R-HSA-1234', 'R-HSA-5678', 'R-HSA-9012'],
        'displayName': ['Alcohol metabolism', 'Arachidonic acid metabolism', 'Adenosine signaling']
    }),
    'protein_pathways': pd.DataFrame({
        'uniprot_id': ['P00325', 'P23219', 'P29274'],
        'pathway_id': ['R-HSA-1234', 'R-HSA-5678', 'R-HSA-9012']
    })
}

print("Created sample data sources for mechanism tokens")

In [None]:
# Build mechanism tokens
mechtoken_extractor = MechTokenFeatureExtractor(config)

print("Building mechanism tokens...")
mechtoken_extractor.build_mechanism_tokens(data_sources)

# Analyze the ontology graph
ontology_graph = mechtoken_extractor.ontology_graph
print(f"\nOntology graph statistics:")
print(f"  Nodes: {ontology_graph.number_of_nodes()}")
print(f"  Edges: {ontology_graph.number_of_edges()}")

# Analyze node types
node_types = mechtoken_extractor.ontology_builder.node_types
type_counts = {}
for node_type in node_types.values():
    type_counts[node_type] = type_counts.get(node_type, 0) + 1

print(f"\nNode type distribution:")
for node_type, count in type_counts.items():
    print(f"  {node_type}: {count}")

In [None]:
# Visualize the ontology graph
plt.figure(figsize=(12, 8))

# Create a simplified view of the graph (fixed seed for a reproducible layout)
pos = nx.spring_layout(ontology_graph, k=2, iterations=50, seed=42)

# Color nodes by type
node_colors = []
color_map = {'drug': 'lightblue', 'target': 'lightgreen', 'pathway': 'orange', 'moa': 'pink'}

for node in ontology_graph.nodes():
    node_type = node_types.get(node, 'unknown')
    node_colors.append(color_map.get(node_type, 'gray'))

nx.draw(ontology_graph, pos, 
        node_color=node_colors, 
        node_size=300, 
        with_labels=False, 
        edge_color='gray', 
        alpha=0.7)

# Add legend
legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                             markerfacecolor=color, markersize=10, label=node_type.title())
                  for node_type, color in color_map.items()]
plt.legend(handles=legend_elements, loc='upper right')

plt.title('Biological Ontology Graph\n(Drug-Target-Pathway-MoA Relationships)')
plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Extract compound-specific mechanism tokens
compound_ids = processed_data['molecule_chembl_id'].tolist()
compound_tokens = mechtoken_extractor.extract_compound_tokens(compound_ids)

print(f"Extracted mechanism tokens for {len(compound_tokens)} compounds")

if compound_tokens:
    # Keep IDs and vectors aligned; some compounds may not receive a token
    token_ids = list(compound_tokens.keys())
    token_matrix = np.array([compound_tokens[cid] for cid in token_ids])
    print(f"Token dimension: {token_matrix.shape[1]}")
    
    # Compute pairwise cosine similarity between compound tokens
    from sklearn.metrics.pairwise import cosine_similarity
    similarity_matrix = cosine_similarity(token_matrix)
    
    # Label rows/columns with the compounds that actually have tokens
    plt.figure(figsize=(10, 8))
    sns.heatmap(similarity_matrix, 
                xticklabels=token_ids,
                yticklabels=token_ids,
                annot=True, fmt='.2f', cmap='viridis')
    plt.title('Compound Mechanism Token Similarity Matrix')
    plt.tight_layout()
    plt.show()
    
    print(f"\nToken similarity statistics:")
    # Exclude diagonal (self-similarity)
    off_diagonal = similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)]
    print(f"  Mean similarity: {np.mean(off_diagonal):.3f}")
    print(f"  Std similarity: {np.std(off_diagonal):.3f}")
    print(f"  Min similarity: {np.min(off_diagonal):.3f}")
    print(f"  Max similarity: {np.max(off_diagonal):.3f}")

## 4. Perturbational Biology Features

Gene expression signatures from LINCS L1000 are aggregated across cell lines into per-compound meta-signatures and mapped to pathway activity scores using GSVA/ssGSEA.
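Here is a simplified single-sample enrichment score in the spirit of ssGSEA. For one signature, genes are ranked from most up- to most down-regulated. A running sum then compares rank-weighted hits from the gene set against misses, and the difference is summed over all positions. Library implementations (e.g. the GSVA package) typically add a normalization step across samples, which this sketch omits. `ssgsea_score` and the toy gene sets are hypothetical.

```python
import numpy as np

def ssgsea_score(expression, gene_names, gene_set, alpha=0.25):
    """Single-sample enrichment score for one gene set (simplified ssGSEA)."""
    expression = np.asarray(expression, dtype=float)
    order = np.argsort(-expression)                          # most up-regulated first
    ranked = np.asarray(gene_names)[order]
    n = len(ranked)
    rank_weight = np.arange(n, 0, -1, dtype=float) ** alpha  # top gene gets n**alpha
    in_set = np.isin(ranked, list(gene_set))
    n_hit = int(in_set.sum())
    if n_hit == 0 or n_hit == n:
        return 0.0                                           # needs both hits and misses
    hit_w = np.where(in_set, rank_weight, 0.0)
    p_hit = np.cumsum(hit_w) / hit_w.sum()                   # weighted fraction of hits seen so far
    p_miss = np.cumsum(~in_set) / (n - n_hit)                # fraction of misses seen so far
    # ssGSEA sums the running difference over all positions (no max-deviation step)
    return float(np.sum(p_hit - p_miss))

genes = [f"G{i}" for i in range(6)]
signature = np.array([2.1, 1.5, 0.3, -0.2, -1.0, -2.4])      # toy differential expression
pathways = {"up_set": {"G0", "G1"}, "down_set": {"G4", "G5"}}
scores = {name: ssgsea_score(signature, genes, s) for name, s in pathways.items()}
print(scores)
```

Gene sets concentrated among induced genes score positive and those among repressed genes score negative, which is what the pathway heatmap further down visualizes.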

In [None]:
# Create sample LINCS signature metadata; each compound is profiled in all three cell lines
lincs_signatures = pd.DataFrame({
    'sig_id': [f'SIG_{i:04d}' for i in range(30)],
    'pert_iname': [name for name in ['compound_25', 'compound_521', 'compound_113'] for _ in range(10)],
    'cell_id': ['MCF7', 'PC3', 'A549'] * 10,
    'pert_time': ['24h'] * 30,
    'pert_dose': ['10 µM'] * 30
})

data_sources['lincs_signatures'] = lincs_signatures

print(f"Created sample LINCS signatures: {len(lincs_signatures)} signatures")
print(f"Unique compounds: {lincs_signatures['pert_iname'].nunique()}")
print(f"Cell lines: {lincs_signatures['cell_id'].unique()}")

In [None]:
# Extract perturbational features
perturbational_extractor = PerturbationalFeatureExtractor(config)

compound_names = ['compound_25', 'compound_521', 'compound_113']

print("Extracting perturbational features...")
perturbational_features = perturbational_extractor.extract_perturbational_features(
    lincs_signatures, compound_names
)

print("\nPerturbational features extracted:")
for key, value in perturbational_features.items():
    if isinstance(value, dict):
        print(f"  {key}: {len(value)} compounds")
        if value:
            sample_value = list(value.values())[0]
            if isinstance(sample_value, np.ndarray):
                print(f"    Dimension: {sample_value.shape}")
    elif isinstance(value, list):
        print(f"  {key}: {len(value)} items")
    else:
        print(f"  {key}: {type(value)}")

In [None]:
# Visualize meta-signatures
if 'meta_signatures' in perturbational_features:
    meta_signatures = perturbational_features['meta_signatures']
    gene_names = perturbational_features.get('gene_names', [])
    
    # Create signature matrix (one row per compound)
    signature_matrix = np.array([meta_signatures[compound] for compound in compound_names])
    
    print(f"Meta-signature matrix shape: {signature_matrix.shape}")
    
    # Visualize signature heatmap for at most the first 50 genes
    n_show = min(50, signature_matrix.shape[1])
    gene_labels = gene_names[:n_show] if len(gene_names) >= n_show else [f"G{i+1}" for i in range(n_show)]
    plt.figure(figsize=(12, 6))
    sns.heatmap(signature_matrix[:, :n_show], 
                xticklabels=gene_labels,
                yticklabels=compound_names,
                cmap='RdBu_r', center=0)
    plt.title(f'Gene Expression Meta-Signatures (First {n_show} Genes)')
    plt.xlabel('Genes')
    plt.ylabel('Compounds')
    plt.tight_layout()
    plt.show()
    
    # Show signature statistics
    print(f"\nSignature statistics:")
    print(f"  Mean expression: {np.mean(signature_matrix):.3f}")
    print(f"  Std expression: {np.std(signature_matrix):.3f}")
    print(f"  Min expression: {np.min(signature_matrix):.3f}")
    print(f"  Max expression: {np.max(signature_matrix):.3f}")

In [None]:
# Visualize pathway activity scores
if 'pathway_scores' in perturbational_features:
    pathway_scores = perturbational_features['pathway_scores']
    pathway_names = perturbational_features.get('pathway_names', [])
    
    # Create pathway score matrix (one row per compound)
    pathway_matrix = np.array([pathway_scores[compound] for compound in compound_names])
    
    print(f"Pathway score matrix shape: {pathway_matrix.shape}")
    
    # Use P1..Pn labels so the tick count always matches the number of score columns
    pathway_labels = [f"P{i+1}" for i in range(pathway_matrix.shape[1])]
    plt.figure(figsize=(12, 6))
    sns.heatmap(pathway_matrix, 
                xticklabels=pathway_labels,
                yticklabels=compound_names,
                cmap='RdBu_r', center=0)
    plt.title('Pathway Activity Scores')
    plt.xlabel('Pathways')
    plt.ylabel('Compounds')
    plt.tight_layout()
    plt.show()
    
    # Show top activated/suppressed pathways for each compound
    for i, compound in enumerate(compound_names):
        scores = pathway_matrix[i]
        top_activated = np.argsort(scores)[-3:][::-1]
        top_suppressed = np.argsort(scores)[:3]
        
        print(f"\n{compound}:")
        print(f"  Top activated pathways: {[f'P{j+1}' for j in top_activated]}")
        print(f"  Top suppressed pathways: {[f'P{j+1}' for j in top_suppressed]}")

## 5. Multi-Modal Feature Integration

Runs every enabled modality through the unified `MultiModalFeatureExtractor` interface in a single call and summarizes the resulting features per modality.

In [None]:
# Extract all features using the unified interface
multimodal_extractor = MultiModalFeatureExtractor(config)

print("Extracting all multi-modal features...")
all_features = multimodal_extractor.extract_all_features(processed_data, data_sources)

print("\nMulti-modal feature extraction completed!")
print(f"Extracted features from {len(all_features)} modalities:")

for modality, features in all_features.items():
    print(f"\n{modality.upper()}:")
    for feature_type, feature_data in features.items():
        if isinstance(feature_data, (dict, list)):
            print(f"  {feature_type}: {len(feature_data)} items")
        else:
            print(f"  {feature_type}: {type(feature_data)}")

In [None]:
# Feature dimensionality summary
print("Feature Dimensionality Summary:")
print("=" * 40)

total_features = 0

for modality, features in all_features.items():
    modality_features = 0
    
    if modality == "chemistry":
        if "molecular_descriptors" in features:
            desc_example = features["molecular_descriptors"][0]
            desc_count = sum(1 if not isinstance(v, np.ndarray) else len(v) for v in desc_example.values())
            modality_features += desc_count
            print(f"  Chemical descriptors: ~{desc_count}")
    
    elif modality == "mechanism_tokens":
        if "compound_tokens" in features:
            tokens = features["compound_tokens"]
            if tokens:
                token_dim = list(tokens.values())[0].shape[0]
                modality_features += token_dim
                print(f"  Mechanism tokens: {token_dim}")
    
    elif modality == "perturbation":
        if "meta_signatures" in features:
            meta_sigs = features["meta_signatures"]
            if meta_sigs:
                sig_dim = list(meta_sigs.values())[0].shape[0]
                modality_features += sig_dim
                print(f"  Gene signatures: {sig_dim}")
        
        if "pathway_scores" in features:
            pathway_scores = features["pathway_scores"]
            if pathway_scores:
                pathway_dim = list(pathway_scores.values())[0].shape[0]
                modality_features += pathway_dim
                print(f"  Pathway scores: {pathway_dim}")
    
    total_features += modality_features
    print(f"  {modality.title()} total: {modality_features}")

print(f"\nTotal feature dimensions: ~{total_features}")
print("\nNote: This excludes graph features which have variable dimensions")
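
The modalities cover different subsets of compounds. Here mechanism tokens exist for all ten compounds, but perturbational features only for three. A fused feature matrix therefore needs a policy for missing blocks. The minimal sketch below zero-fills missing modalities and records an availability mask that a downstream model could consume. It assumes every modality is keyed by the same compound ID; in this notebook the perturbational features use `compound_25`-style names rather than ChEMBL IDs, so a real pipeline would need an ID mapping first. `assemble_features` is a hypothetical helper.

```python
import numpy as np

def assemble_features(compound_ids, modalities):
    """Concatenate per-compound vectors from several modalities.

    `modalities` maps a modality name to {compound_id: 1-D array}. Missing
    blocks are zero-filled, and `mask[i, k]` is True when compound i has
    modality k."""
    names = list(modalities)
    dims = {}
    for name in names:
        vectors = list(modalities[name].values())
        dims[name] = len(vectors[0]) if vectors else 0
    X = np.zeros((len(compound_ids), sum(dims.values())))
    mask = np.zeros((len(compound_ids), len(names)), dtype=bool)
    for row, cid in enumerate(compound_ids):
        offset = 0
        for col, name in enumerate(names):
            vec = modalities[name].get(cid)
            if vec is not None:
                X[row, offset:offset + dims[name]] = vec
                mask[row, col] = True
            offset += dims[name]
    return X, mask, names

ids = ["CHEMBL25", "CHEMBL521", "CHEMBL113"]
modalities = {
    "mechanism_tokens": {cid: np.ones(4) for cid in ids},
    "perturbation": {"CHEMBL25": np.array([0.5, -0.5])},   # toy: only one compound profiled
}
X, mask, names = assemble_features(ids, modalities)
print(X.shape, names)
print(mask)
```

Keeping the mask separate from the zero-filled values lets a model distinguish "not measured" from "measured as zero".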

## 6. Summary and Next Steps

### Novel Features Implemented:

1. **Chemical Graph Features with Counterfactual Analysis**
   - Molecular graphs with rich node/edge features
   - Counterfactual substructure importance scoring
   - Causal fragment identification for MoA prediction

2. **Mechanism Tokens (MechTokens)**
   - Ontology-aware embeddings of drug-target-pathway-MoA relationships
   - Node2vec embeddings with hierarchical encoding
   - Compound-specific token aggregation

3. **Perturbational Biology Features**
   - LINCS L1000 gene expression meta-signatures
   - Pathway activity scores using GSVA/ssGSEA
   - Multi-cell line aggregation

4. **Protein Pocket Features (Optional)**
   - 3D binding site encoding using PointNet/3D CNN
   - AlphaFold structure integration
   - Druggability scoring

### Key Innovations:

- **Counterfactual Analysis**: To our knowledge, the first application of counterfactual analysis to molecular substructures for MoA prediction
- **MechTokens**: Novel ontology-aware embeddings that capture biological relationships
- **Multi-modal Integration**: Unified framework for combining diverse biological data types
- **Hierarchical Encoding**: Captures different levels of biological organization

### Next Steps (Phase 3):

1. **Model Architecture Development**
   - Graph Transformer for chemical features
   - Pathway Transformer for biological features
   - Hypergraph fusion layer for multi-modal integration

2. **Training Pipeline**
   - Multi-objective loss functions
   - Curriculum learning
   - Modality dropout for robustness

The feature engineering foundation is now complete and ready for model development!