In [17]:
import sys
sys.path.append('/app')
# unit_test/test_single_receptor.py
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os

from src.environment import *
from src.physics import *
from objectives.loss import InformationLoss


In [6]:
CONF = {
        "n_units": 1,
        "n_families": 1,
        "k_sub": 5,
        "batch_size": 512,
        "epochs": 600,
        "lr": 0.05
    }

In [None]:
# Experiment A: Gaussian


In [18]:

def test_optimization_visuals(CONFIG):
    print("--- Starting Single Receptor Optimization Test ---")
    
    # 1. SETUP
    # -----------------------------------------------------
    # We use 1 Unit, 1 Family. The receptor is a homopentamer (Unit 0 five times).
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Initialize Modules
    conc_strategy = NormalConcentration(n_families=CONFIG['n_families'], init_mean=5.0)
    env = LigandEnvironment(CONFIG['n_units'], CONFIG['n_families'], conc_model=conc_strategy).to(device)

    physics = Receptor(CONFIG["n_units"], CONFIG["k_sub"]).to(device)
    loss_fn = InformationLoss() # Default bandwidth
    
    # Create the receptor identity: [[0, 0, 0, 0, 0]]
    receptor_indices = torch.zeros(1, CONFIG["k_sub"], dtype=torch.long, device=device)
    
    # Optimizer
    optimizer = optim.Adam(list(env.parameters()) + list(physics.parameters()), lr=CONFIG["lr"])
    
    # Helper to calculate deterministic response curve for plotting
    def get_response_curve():
        c_range = torch.linspace(0, 10, 200, device=device)
        
        # Use mean interaction energies (no noise) to see the "average" curve
        # interaction_mu shape: (Units, Families, 2) -> We want Family 0
        mean_energies = env.interaction_mu[:, 0:1, :] # Keep dims
        # Expand for the batch of concentration points
        # Need shape: (200, Units, 2)
        energies_expanded = mean_energies.permute(1, 0, 2).expand(200, -1, -1)
        
        with torch.no_grad():
            # Pass c_range as batch of concentrations
            activity = physics(energies_expanded, c_range, receptor_indices)
        return c_range.cpu().numpy(), activity.cpu().squeeze().numpy()

    # 2. CAPTURE INITIAL STATE
    # -----------------------------------------------------
    print("Capturing initial state...")
    x_axis, y_initial = get_response_curve()

    # 3. OPTIMIZATION LOOP
    # -----------------------------------------------------
    print(f"Training for {CONFIG['epochs']} epochs...")
    loss_history = []
    
    for epoch in range(CONFIG['epochs']):
        optimizer.zero_grad()
        
        # A. Sample Batch
        # energies: (B, 1, 2), concs: (B,)
        energies, concs, _ = env.sample_batch(CONFIG['batch_size'])
        
        # B. Physics
        # activity: (B, 1)
        activity = physics(energies, concs, receptor_indices)
        
        # C. Loss (Maximize Entropy)
        # We assume Covariance is 0 because there is only 1 receptor
        loss, stats = loss_fn(activity)
        
        loss.backward()
        optimizer.step()
        
        loss_history.append(stats['entropy'].item())
        
        if epoch % 100 == 0:
            print(f"Epoch {epoch}: Entropy = {stats['entropy'].item():.4f}")

    # 4. CAPTURE FINAL STATE
    # -----------------------------------------------------
    _, y_final = get_response_curve()

    # 5. PLOTTING
    # -----------------------------------------------------
    print("Generating plot...")
    plt.figure(figsize=(12, 6))
    
    # Subplot 1: Response Curves
    plt.subplot(1, 2, 1)
    
    # Plot Input Distribution (Gaussian centered at 5)
    # This is hardcoded to match the initialization in LigandEnvironment
    mu_c, sigma_c = 5.0, 1.0 
    pdf = (1 / (sigma_c * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x_axis - mu_c)/sigma_c)**2)
    # Scale PDF to match plotting range [0,1] roughly
    plt.fill_between(x_axis, pdf / pdf.max(), color='gray', alpha=0.3, label='Concentration P(c)')
    
    plt.plot(x_axis, y_initial, 'r--', lw=2, label='Initial Response')
    plt.plot(x_axis, y_final, 'g-', lw=3, label='Optimized Response')
    
    plt.title("Impedance Matching")
    plt.xlabel("Concentration")
    plt.ylabel("P(open) / Density")
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Subplot 2: Entropy History
    plt.subplot(1, 2, 2)
    plt.plot(loss_history)
    plt.title("Entropy Maximization")
    plt.xlabel("Epoch")
    plt.ylabel("Estimated Entropy")
    plt.grid(True, alpha=0.3)
    
    # Save
    os.makedirs("unit_test/plots", exist_ok=True)
    save_path = "unit_test/plots/single_receptor_test.png"
    plt.savefig(save_path)
    print(f"Test Complete. Plot saved to: {save_path}")

In [19]:
test_optimization_visuals(CONF)

--- Starting Single Receptor Optimization Test ---
Capturing initial state...
Training for 600 epochs...
Epoch 0: Entropy = -0.7794
Epoch 100: Entropy = 0.2061
Epoch 200: Entropy = 0.1232
Epoch 300: Entropy = -0.0455
Epoch 400: Entropy = -0.1521
Epoch 500: Entropy = -0.1596


ValueError: Expected parameter loc (Tensor of shape (512, 1, 2)) of distribution Normal(loc: torch.Size([512, 1, 2]), scale: torch.Size([512, 1, 2])) to satisfy the constraint Real(), but found invalid values:
tensor([[[nan, nan]],

        [[nan, nan]],

        [[nan, nan]],

        ...,

        [[nan, nan]],

        [[nan, nan]],

        [[nan, nan]]], device='cuda:0', grad_fn=<IndexBackward0>)