# Example: Generating the CRBN Open-to-Closed Conformational Path using rMD

This notebook demonstrates end-to-end use of the Reinforcement Molecular Dynamics (rMD) model. It focuses on generating the conformational transition of the CRBN protein described in the source paper (Figure 4).

```python
# Standard imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older Matplotlib versions)

# Project modules
from prepare_data import get_datasets  # data used for training (not needed for the generation steps below)
from generation_pipeline import generate_path, generate_structure

# Set up device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

## 1. Load the Trained Model (Mocked)

In a real scenario, we would load the checkpoint file of the fully trained dual-loss autoencoder. For this example, we substitute a mock model that exposes the same decoder interface: a 3-D latent/CV point goes in, and a 9696-dimensional coordinate vector comes out.
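Loading a real checkpoint typically follows PyTorch's `state_dict` pattern: rebuild the architecture, then restore the saved weights. The sketch below assumes the checkpoint stores a plain `state_dict` for a model whose layers match the class definition; `DecoderOnlyModel` and `load_trained_model` are hypothetical names, and the project's actual checkpoint layout is not specified in this notebook.

```python
import torch
from torch import nn


class DecoderOnlyModel(nn.Module):
    """Hypothetical stand-in; its layers must match those saved in the checkpoint."""

    def __init__(self, latent_dim: int = 3, n_features: int = 9696):
        super().__init__()
        self.decoder = nn.Linear(latent_dim, n_features)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)


def load_trained_model(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Rebuild the architecture, then restore weights (assumes a plain state_dict on disk)."""
    model = DecoderOnlyModel()
    state_dict = torch.load(checkpoint_path, map_location=device)  # remap tensors to the target device
    model.load_state_dict(state_dict)
    return model.to(device).eval()  # eval(): inference mode
```

With a real checkpoint, `rmd_model = load_trained_model("path/to/checkpoint.pt", DEVICE)` would replace the mock below; the path shown is a placeholder.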

```python
# --- Mock model setup (replace with actual checkpoint loading in production) ---
class FullyTrainedRMDModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for the decoder: maps a 3-D latent/CV point to 9696 coordinates
        self.decoder = torch.nn.Linear(3, 9696)

    def forward(self, ls_input):
        # Decode a 3-D CV/latent-space (LS) point back to 9696-dimensional coordinates
        return self.decoder(ls_input)

rmd_model = FullyTrainedRMDModel().to(DEVICE)
rmd_model.eval()  # inference mode
print("Mock rMD model loaded (stands in for the validated dual-loss autoencoder).")
```
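The mock above keeps only the decoder. For context, here is a minimal sketch of what a dual-loss training objective could look like. It assumes (as our reading of the rMD approach, not a confirmed detail of this project) that one term measures coordinate reconstruction error and the other pulls the 3-D latent vector toward the precomputed CV values of the same frame. `DualLossAutoencoder`, `dual_loss`, the hidden width, the use of MSE for both terms, and `latent_weight` are all hypothetical choices.

```python
import torch
from torch import nn
import torch.nn.functional as F


class DualLossAutoencoder(nn.Module):
    """Hypothetical encoder/decoder pair with a 3-D latent space."""

    def __init__(self, n_features: int = 9696, latent_dim: int = 3, hidden: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_features, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ReLU(), nn.Linear(hidden, n_features)
        )

    def forward(self, x):
        z = self.encoder(x)          # latent (LS) coordinates
        return z, self.decoder(z)    # reconstructed coordinate vector


def dual_loss(model, x, cv, latent_weight: float = 1.0):
    """Reconstruction MSE plus a latent term pulling z toward the frame's CV values."""
    z, x_hat = model(x)
    recon = F.mse_loss(x_hat, x)
    latent = F.mse_loss(z, cv)
    return recon + latent_weight * latent, recon, latent
```

Because the latent term ties each latent coordinate to a CV, a point picked in CV space can be handed directly to the decoder, which is what the rest of this notebook relies on.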

## 2. Define the Conformational Transition Path

As described in the paper and Figure 4, the path is defined by manually picking anchor points in the Collective Variable (CV) space that trace a low free-energy route from the open (inactive) state to the closed (active) state.

```python
# Illustrative anchor points in 3-D CV space (stand-ins for key points on the free-energy map)
anchor_points = np.array([
    [10.0, 10.0, 10.0],  # 0: Open state (high-flexibility region)
    [4.0, 12.0, 8.0],    # 1: Transition entry point
    [2.0, 6.0, 4.0],     # 2: Middle of the narrow path
    [1.5, 1.5, 1.5]      # 3: Closed state (low-flexibility region)
])

NUM_FRAMES = 20  # The 20 structures mentioned in the Supplementary Materials

# B-spline interpolation through the anchors gives a smooth path in CV space
cv_path = generate_path(anchor_points, NUM_FRAMES)

print(f"Generated path spanning {len(cv_path)} frames in 3D CV space.")
```
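`generate_path` comes from the project's `generation_pipeline` module, and its internals are not shown here. One plausible implementation of the B-spline step, assuming SciPy is available, fits an interpolating parametric spline through the anchors with `scipy.interpolate.splprep` (smoothing `s=0`) and samples it at evenly spaced parameter values. `bspline_path` is a hypothetical name.

```python
import numpy as np
from scipy.interpolate import splprep, splev


def bspline_path(anchor_points, num_frames: int, degree: int = 3) -> np.ndarray:
    """Interpolating B-spline through the anchors, sampled at num_frames points."""
    pts = np.asarray(anchor_points, dtype=float)
    k = min(degree, len(pts) - 1)          # spline degree must be below the anchor count
    tck, _ = splprep(pts.T, k=k, s=0)      # s=0: curve passes through every anchor
    u = np.linspace(0.0, 1.0, num_frames)  # evenly spaced spline parameter values
    return np.column_stack(splev(u, tck))  # shape (num_frames, 3)
```

By default `splprep` parameterizes the curve by cumulative chord length between anchors, so frames sampled at uniform `u` are only approximately evenly spaced along the path in CV space.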

### Visualization of the Path (Simulating Figure 4)

Here we visualize the path in the 3D CV space that the rMD model will use to generate structures.

```python
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

ax.plot(cv_path[:, 0], cv_path[:, 1], cv_path[:, 2], label='rMD Transition Path (B-Spline)', color='blue')
# Show the hand-picked anchors the spline passes through
ax.scatter(anchor_points[:, 0], anchor_points[:, 1], anchor_points[:, 2], color='gray', marker='^', s=60, label='Anchor points')
ax.scatter(cv_path[0, 0], cv_path[0, 1], cv_path[0, 2], color='black', s=100, label='Open State (PDB: 6H0F)')
ax.scatter(cv_path[-1, 0], cv_path[-1, 1], cv_path[-1, 2], color='red', s=100, label='Closed State (PDB: 6H0G)')

ax.set_xlabel('CV1')
ax.set_ylabel('CV2')
ax.set_zlabel('CV3')
ax.set_title('Conformational Path in CV Space (cf. Figure 4; free-energy map not shown)')
ax.legend()
plt.show()
```

## 3. Generate Atomistic Structures

For each of the 20 interpolated CV points along the path, we pass the point through the decoder of the trained rMD model, which generates the full 9696-dimensional coordinate vector (3,232 atoms × 3 Cartesian components) in a single forward pass, with no MD simulation needed.

```python
generated_structures = []
print("Generating atomistic structures...")

with torch.no_grad():  # inference only; no gradients needed
    for i, cv_point in enumerate(cv_path):
        # Decode the CV point into a full coordinate vector
        structure_vector = generate_structure(rmd_model, cv_point)
        # np.stack below needs NumPy arrays, so move any tensor to the CPU first
        if torch.is_tensor(structure_vector):
            structure_vector = structure_vector.detach().cpu().numpy()
        generated_structures.append(structure_vector)
        print(f"  Generated structure {i + 1}/{len(cv_path)}. Feature vector size: {structure_vector.shape}")

final_ensemble = np.stack(generated_structures)

print("\n--- Generation Complete ---")
print(f"Final generated trajectory shape: {final_ensemble.shape} (Frames x Features)")
```
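`generate_structure` is likewise a project function whose internals are not shown. Because the decoder is a feed-forward map, all path points can also be decoded in one batched forward pass. The sketch assumes the model accepts an `(N, 3)` float32 tensor and returns `(N, 9696)`, as the mock model does; `decode_path` is a hypothetical helper.

```python
import numpy as np
import torch
from torch import nn


def decode_path(model: nn.Module, cv_path, device="cpu") -> np.ndarray:
    """Decode every CV point along the path in one batched forward pass."""
    points = torch.as_tensor(np.asarray(cv_path), dtype=torch.float32, device=device)
    model.eval()
    with torch.no_grad():        # inference only
        coords = model(points)   # shape (N, n_features)
    return coords.cpu().numpy()
```

For example, `decode_path(rmd_model, cv_path, DEVICE)` should produce an array with the same `(20, 9696)` shape as `final_ensemble`; batching mainly matters when sampling many more frames than 20.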

## 4. Post-Processing and Output

The resulting structures are saved below. In a real workflow, these flat coordinate vectors would be reshaped back into per-atom coordinates, written to a PDB or trajectory file, and passed to a refinement tool such as Rosetta Relax to correct local geometric distortions, as the paper recommends.

```python
# Save the raw (frames x features) array in NumPy's .npy format;
# convert to PDB/trajectory format before loading into PyMOL or MD analysis tools
output_filename = 'crbn_transition_path.npy'
np.save(output_filename, final_ensemble)
print(f"Generated structure coordinates saved to '{output_filename}'")
```
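As a first step toward a trajectory file, the flat `(frames x 9696)` array can be reshaped to `(frames, 3232, 3)`. The sketch assumes each feature vector stores atoms consecutively as `x, y, z` triples, which should be checked against how `prepare_data` flattened the training coordinates. It renders a multi-frame XYZ file, a plain-text format that common molecular viewers such as VMD can open, but uses a placeholder element symbol because the real atom names live in the topology. `to_xyz_frames` is hypothetical.

```python
import numpy as np


def to_xyz_frames(flat_ensemble, element: str = "C", comment: str = "rMD frame"):
    """Reshape (frames, 3*n_atoms) to (frames, n_atoms, 3) and render multi-frame XYZ text."""
    ensemble = np.asarray(flat_ensemble, dtype=float)
    n_frames, n_features = ensemble.shape
    if n_features % 3 != 0:
        raise ValueError(f"feature count {n_features} is not divisible by 3")
    coords = ensemble.reshape(n_frames, n_features // 3, 3)  # assumes x, y, z per atom

    lines = []
    for i, frame in enumerate(coords):
        lines.append(str(len(frame)))   # XYZ header: atom count
        lines.append(f"{comment} {i}")  # XYZ header: comment line
        for x, y, z in frame:
            # Placeholder element symbol: real atom names come from the topology
            lines.append(f"{element} {x:.3f} {y:.3f} {z:.3f}")
    return coords, "\n".join(lines) + "\n"
```

For example, `coords, text = to_xyz_frames(final_ensemble)` gives `coords.shape == (20, 3232, 3)`, and `text` can be written to a `.xyz` file. A proper PDB would additionally need the atom and residue names from the CRBN topology.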
