# D2Cell-pred demo (E. coli)

In [11]:
import pandas as pd
from model import D2Cell_Model
from dataset import D2CellDataset
import torch


### Inputs
- `organism` (`str`): organism/strain name (e.g., `"E. coli"`)
- `product` (`str`): target product/metabolite ID (e.g., `"ala__D_c"`)
- `knock_gene` (`List[str]`): list of genes to knock out (e.g., `["B0002", "B0123"]`; use `[]` if none)
- `over_gene` (`List[str]`): list of genes to overexpress (e.g., `["B0351"]`; use `[]` if none)

### Outputs
- Final decision (printed to the cell output, not returned):
 - `"It's a target gene."` **or**
 - `"It's not a target gene."`


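Download the [model parameters](https://drive.google.com/file/d/1XPvHyERKNqgMAqjVy3yobzZaYQEuL84_/view?usp=sharing) and unzip them under the `D2Cell` directory. The prediction code below loads the weights from `../../save_model/ecoli_model/ecoli_D2Cell_pred_model.pth` (relative to this notebook).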

In [12]:
def _lookup_pert_indices(pert_df, genes, kind):
    """Translate gene IDs into perturbation indices using a mapping table.

    Args:
        pert_df: DataFrame with a 'pert_gene' column (gene IDs) and an
            'index' column (perturbation indices).
        genes: Iterable of gene IDs, a single gene ID string, or None.
            Blank entries are ignored; IDs are stripped and lower-cased
            before matching.
        kind: Short label ('knock out' / 'overexpression') used in the
            error message.

    Returns:
        List of plain Python ints, one per non-blank gene, in input order.

    Raises:
        ValueError: If a gene ID has no row in ``pert_df``.
    """
    if isinstance(genes, str):
        # A bare string would otherwise be iterated character by character.
        genes = [genes]

    indices = []
    for gene in genes or []:
        gene_id = gene.strip().lower()
        if not gene_id:
            continue
        matches = pert_df.loc[pert_df['pert_gene'] == gene_id, 'index']
        if matches.empty:
            raise ValueError(f'gene not in GEM ({kind}): {gene!r}')
        # Cast to a built-in int so that str(list) below yields "[3, 7]"
        # rather than "[np.int64(3), ...]" on NumPy >= 2.
        indices.append(int(matches.iloc[0]))
    return indices


def predict_demo(organism, product, knock_gene, over_gene):
    """Run the D2Cell-pred model for one product / gene-modification combination.

    The product and genes are translated to indices through the E. coli
    mapping CSVs, a small CSV (`predict_demo.csv`) is written to the current
    working directory and fed through `D2CellDataset`, and the trained model
    is evaluated on it. The final decision is printed; nothing is returned.

    Args:
        organism (str): Organism name. Only 'E. coli' has mapping files and
            model weights wired up here.
        product (str): Metabolite ID, matched against the 'met_id' column of
            `ecoli_product_idx.csv`.
        knock_gene (List[str]): Genes to knock out ([] or [''] for none).
        over_gene (List[str]): Genes to overexpress ([] or [''] for none).

    Raises:
        ValueError: If the organism is unsupported, or the product or any
            gene is missing from the mapping tables.
    """
    # Fail early: every path below (mapping CSVs, stoichiometric matrix,
    # checkpoint) is E. coli specific.
    if organism != 'E. coli':
        raise ValueError(
            f"Unsupported organism: {organism!r} (only 'E. coli' is available in this demo)"
        )

    # Select device: use GPU if available, otherwise CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    config = {'hidden_size': 128,
              'num_gnn_layers': 1,
              'device': device,
              'num_met': 2000
              }

    # Mapping tables: metabolite/product index and perturbation (gene) indices
    data_dir = '../../Data/D2Cell-pred Data/Ecoli'
    product_df = pd.read_csv(f'{data_dir}/ecoli_product_idx.csv')
    knock_df = pd.read_csv(f'{data_dir}/pert2idx-knock.csv')
    over_df = pd.read_csv(f'{data_dir}/pert2idx-overexpress.csv')

    # Convert product/metabolite ID to its index in GEM
    product_matches = product_df.loc[product_df['met_id'] == product, 'index in gem']
    if product_matches.empty:
        raise ValueError(f'product not in GEM: {product!r}')
    product_index = int(product_matches.iloc[0])

    # Build perturbation index list: knock-outs first, then overexpressions
    gene_modification_list = (
        _lookup_pert_indices(knock_df, knock_gene, 'knock out')
        + _lookup_pert_indices(over_df, over_gene, 'overexpression')
    )

    # One query row, written four times.
    # NOTE(review): 'inf_label_01' = 5 looks like a placeholder label and the
    # 4x duplication presumably satisfies the dataset/dataloader; confirm
    # against D2CellDataset.
    row = pd.DataFrame({'inf_label_01': 5,
                        'product': product,
                        'index in gem': product_index,
                        'pert index': str(gene_modification_list)}, index=[0])
    pd.concat([row] * 4).to_csv('predict_demo.csv', index=False)

    data = D2CellDataset('predict_demo.csv',
                         f'{data_dir}/iML1515_S.txt',
                         f'{data_dir}/IgnoreMets_iML1515.csv')
    dataloader, edge_index, edge_weight = data.get_dataloader()
    test_loader = dataloader['test_loader']

    # Initialize model and load trained weights
    model = D2Cell_Model(config, edge_index, edge_weight)
    model.to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load('../../save_model/ecoli_model/ecoli_D2Cell_pred_model.pth',
                            map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Run inference and collect the predicted class of every row
    predict_list = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for batch in test_loader:
            batch = batch.to(device)
            output = model(batch)
            _, predicted = torch.max(output, 1)
            predict_list.extend(predicted.cpu().tolist())

    # Majority vote over duplicated rows
    votes_for_1 = predict_list.count(1)
    votes_for_0 = predict_list.count(0)

    # Print final decision.
    # NOTE(review): class 1 is reported as "not a target" and a tie falls
    # through to "target"; this mapping is kept as found - confirm it against
    # the label encoding used in training.
    print('-'*30)
    if votes_for_1 > votes_for_0:
        print("It's not a target gene.")
    else:
        print("It's a target gene.")

# D2Cell-pred Interactive Demo

### This interactive UI allows you to run predictions using the D2Cell-pred model directly within the notebook.

#### **Input Parameters:**

#### 1.  **Strain:** The target organism (currently supports *E. coli*).
#### 2.  **Target Product:** The metabolite ID of the desired product (e.g., `ala__D_c` for D-Alanine).
#### 3.  **Knock-out Genes:** A list of gene IDs to be removed (separated by semicolons `;`).
#### 4.  **Overexpression Genes:** A list of gene IDs to be overexpressed (separated by semicolons `;`).

#### **Note on IDs:**
> The input fields require specific **ID formats** corresponding to the iML1515 metabolic model:
>
> - **Gene IDs:** Use the gene ID (e.g., `b0002`) found in **`iML1515_Genes.tsv`**. Gene IDs are lower-cased before lookup, so `B0002` works as well.
> - **Product IDs:** Use the specific metabolite ID (e.g., `ala__D_c`) found in **`ecoli_product_idx.csv`**.


In [13]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# ------------------------------------------------------------
# D2Cell-pred interactive demo UI (Jupyter Notebook)
# Requirements:
#   1) Run this cell in a Jupyter environment
#   2) Make sure your core function `predict_demo(...)` is defined ABOVE
#
# Expected signature:
#   predict_demo(organism, product, knock_gene_list, over_gene_list)
# Example:
#   predict_demo("E. coli", "ala__D_c", ["B0002"], [])
# ------------------------------------------------------------


# Input widgets for the demo form.
# Every labelled input sets description_width='initial' so that long labels
# such as 'Overexpression genes:' are shown in full instead of being cut off
# by the default fixed label width.

# Organism/strain dropdown
strain_ui = widgets.Dropdown(
    options=['E. coli'],
    value='E. coli',
    description='Strain:',
    style={'description_width': 'initial'}
)

# Target product/metabolite id input
product_ui = widgets.Text(
    value='ala__D_c',
    description='Target product:',
    placeholder='e.g., ala__D_c',
    style={'description_width': 'initial'}
)

# Knock-out gene(s) input (semicolon-separated)
knock_ui = widgets.Text(
    value='B0002',
    description='Knock-out genes:',
    placeholder='Separate by semicolons ;',
    style={'description_width': 'initial'}
)

# Overexpression gene(s) input (semicolon-separated)
over_ui = widgets.Text(
    value='',
    description='Overexpression genes:',
    placeholder='Separate by semicolons ;',
    style={'description_width': 'initial'}
)

# Run button
run_btn = widgets.Button(
    description='Run prediction',
    button_style='primary',  # 'success', 'info', 'warning', 'danger' or ''
    icon='flask'
)


# Output area to display print statements / errors
output_area = widgets.Output()

# Button click behavior
def on_btn_click(b):
    """
    Callback triggered when the user clicks the run button.

    - Reads the strain, product and gene strings from the widgets
    - Splits the gene strings on ';', dropping blank entries
    - Calls predict_demo(...) and shows everything it prints (or the error
      it raises) inside output_area

    Args:
        b: The clicked button widget (passed by ipywidgets; unused).
    """
    with output_area:
        clear_output()  # Clear previous output

        # Read user inputs
        s_val = strain_ui.value
        p_val = product_ui.value.strip()

        # Split input strings by ';' and remove empty entries
        k_val = [x.strip() for x in knock_ui.value.split(';') if x.strip()]
        o_val = [x.strip() for x in over_ui.value.split(';') if x.strip()]

        # Print the parsed configuration
        print(
            "🔬 Setting up parameters...\n"
            f"Strain: {s_val}\n"
            f"Product: {p_val}\n"
            f"Knock-out: {k_val}\n"
            f"Overexpression: {o_val}"
        )
        print("-" * 30)

        # Check for the function up front instead of catching NameError around
        # the call: a NameError/UnboundLocalError raised *inside* predict_demo
        # would otherwise be misreported as "predict_demo is not defined".
        predict_fn = globals().get('predict_demo')
        if not callable(predict_fn):
            print("❌ Error: `predict_demo` is not defined. Please run the cell that defines it first.")
            return

        # Call the core prediction function; report failures in the output area
        try:
            predict_fn(s_val, p_val, k_val, o_val)
        except Exception as e:
            print(f"❌ Runtime error: {e}")


# Hook the click handler up to the run button
run_btn.on_click(on_btn_click)


# Stack title, inputs, button and output area vertically, then render the form
ui_layout = widgets.VBox(children=[
    widgets.HTML(value="<h2>D2Cell-pred Demo</h2>"),
    strain_ui, product_ui, knock_ui, over_ui,
    run_btn,
    output_area,
])

display(ui_layout)

VBox(children=(HTML(value='<h2>D2Cell-pred Demo</h2>'), Dropdown(description='Strain:', options=('E. coli',), …

In [ ]:
#@markdown Enter the products to be predicted, knockout and overexpressed genes.
strain = 'E. coli' #@param {type:"string"}
product = 'ala__D_c' #@param {type:"string"}
knock_modi = 'B0002' #@param {type:"string"}
over_modi = '' #@param {type:"string"}

# Turn the semicolon-separated form strings into gene lists. Blank entries
# (empty field, trailing ';', stray whitespace) are dropped so they are not
# looked up as a gene named ''; this matches the parsing in the widget UI.
knock_modi = [gene.strip() for gene in knock_modi.split(';') if gene.strip()]
over_modi = [gene.strip() for gene in over_modi.split(';') if gene.strip()]

predict_demo(strain, product, knock_modi, over_modi)