In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# --- 1. Data Simulation and Input Preparation ---
print("Step 1: Simulating data...")

# Fixed seed so the simulated data (and, through torch.manual_seed, any torch
# randomness that follows such as model weight initialisation) is reproducible.
seed = 0
rng = np.random.default_rng(seed)
torch.manual_seed(seed)

# Dimensions
n_xenium_cells = 1000
n_visium_spots = 100
n_xenium_genes = 250
n_visium_genes = 500

# Mock expression matrices with values drawn uniformly from [0, 1).
xenium_expression = rng.random((n_xenium_cells, n_xenium_genes), dtype=np.float32)
visium_expression = rng.random((n_visium_spots, n_visium_genes), dtype=np.float32)

# One-hot cell -> spot mapping: each cell is assigned to exactly one random spot.
# A single fancy-indexing assignment replaces the per-cell Python loop.
mapping_matrix = np.zeros((n_xenium_cells, n_visium_spots), dtype=np.float32)
assigned_spots = rng.integers(0, n_visium_spots, size=n_xenium_cells)
mapping_matrix[np.arange(n_xenium_cells), assigned_spots] = 1

# Convert numpy arrays to PyTorch tensors (torch.from_numpy shares memory).
x_tensor = torch.from_numpy(xenium_expression)
v_tensor = torch.from_numpy(visium_expression)
map_tensor = torch.from_numpy(mapping_matrix)

# Visium context for each cell: because each row of map_tensor is one-hot,
# map_tensor @ v_tensor selects the expression row of that cell's spot.
v_mapped_to_cells = torch.matmul(map_tensor, v_tensor)

print(f"Xenium input shape: {x_tensor.shape}")
print(f"Visium context input shape: {v_mapped_to_cells.shape}\n")


# --- 2. PyTorch Model Definition (Dual-Input "Two-Tower" Architecture) ---

class DualInputModel(nn.Module):
    """Two-tower network that fuses Xenium and Visium inputs into a matrix U.

    Each cell's Xenium profile and its Visium context profile are encoded by
    separate MLP "towers". The concatenated tower features are decoded into
    one row of U (width ``n_visium_genes``) per cell. U is then used to
    reconstruct both modalities:

    * Xenium: a learned linear projection ``n_visium_genes -> n_xenium_genes``.
    * Visium: the mapping-weighted mean of U over the cells of each spot.

    Args:
        n_xenium_genes: Number of Xenium genes (feature width of ``x_input``).
        n_visium_genes: Number of Visium genes (feature width of
            ``v_context_input`` and of U).
        xenium_hidden_dim: Hidden width of the Xenium tower.
        xenium_feature_dim: Output width of the Xenium tower.
        visium_hidden_dim: Hidden width of the Visium tower.
        visium_feature_dim: Output width of the Visium tower.
        decoder_hidden_dim: Hidden width of the shared decoder.

    The default widths reproduce the original fixed architecture:
    Xenium tower -> 128 -> 64, Visium tower -> 256 -> 128, and the decoder
    (64 + 128 = 192) -> 512 -> n_visium_genes.
    """

    def __init__(
        self,
        n_xenium_genes,
        n_visium_genes,
        *,
        xenium_hidden_dim=128,
        xenium_feature_dim=64,
        visium_hidden_dim=256,
        visium_feature_dim=128,
        decoder_hidden_dim=512,
    ):
        super().__init__()

        # --- Tower 1: Xenium feature extractor ---
        self.xenium_tower = nn.Sequential(
            nn.Linear(n_xenium_genes, xenium_hidden_dim),
            nn.ReLU(),
            nn.Linear(xenium_hidden_dim, xenium_feature_dim),
        )

        # --- Tower 2: Visium feature extractor ---
        self.visium_tower = nn.Sequential(
            nn.Linear(n_visium_genes, visium_hidden_dim),
            nn.ReLU(),
            nn.Linear(visium_hidden_dim, visium_feature_dim),
        )

        # --- Shared decoder: merged tower features -> one row of U ---
        # Derived from the tower output widths so the two cannot drift apart.
        merged_dim = xenium_feature_dim + visium_feature_dim
        self.decoder = nn.Sequential(
            nn.Linear(merged_dim, decoder_hidden_dim),
            nn.ReLU(),
            nn.Linear(decoder_hidden_dim, n_visium_genes),
        )

        # --- Learned linear projector used for Xenium reconstruction ---
        self.xenium_projector = nn.Linear(n_visium_genes, n_xenium_genes)

    def forward(self, x_input, v_context_input, mapping_matrix):
        """Encode both inputs, decode U, and reconstruct both modalities.

        Args:
            x_input: Xenium expression, shape (n_cells, n_xenium_genes).
            v_context_input: Visium context per cell, shape
                (n_cells, n_visium_genes).
            mapping_matrix: Cell-to-spot weights, shape (n_cells, n_spots).

        Returns:
            Tuple ``(x_recon, v_recon, U)`` with shapes
            (n_cells, n_xenium_genes), (n_spots, n_visium_genes) and
            (n_cells, n_visium_genes).
        """
        # 1. Encode each modality with its own tower.
        x_features = self.xenium_tower(x_input)
        v_features = self.visium_tower(v_context_input)

        # 2. Concatenate the tower features along the feature axis.
        merged_features = torch.cat([x_features, v_features], dim=1)

        # 3. Decode one row of U per cell.
        U_predicted = self.decoder(merged_features)

        # 4. Xenium reconstruction: learned linear map out of U's gene space.
        x_recon = self.xenium_projector(U_predicted)

        # 5. Visium reconstruction: mapping-weighted mean of U per spot.
        #    The clamp avoids division by zero. A spot whose mapping column is
        #    all zero therefore reconstructs as an all-zero row.
        map_T = mapping_matrix.T
        v_aggregated = torch.matmul(map_T, U_predicted)
        cells_per_spot = map_T.sum(dim=1, keepdim=True).clamp(min=1e-6)
        v_recon = v_aggregated / cells_per_spot

        return x_recon, v_recon, U_predicted

# --- 3. Training Setup ---
# Build the model, an MSE reconstruction loss and an Adam optimizer; the
# training loop runs for num_epochs iterations.

print("Step 2: Initializing Dual-Input model, loss, and optimizer...")
model = DualInputModel(n_xenium_genes=n_xenium_genes, n_visium_genes=n_visium_genes)

num_epochs = 100
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Announce the architecture (print(model) shows the module tree) and start.
for banner in ("Model architecture:", model, "\nStarting training...\n"):
    print(banner)

# --- 4. The Training Loop ---

for epoch in range(num_epochs):
    model.train()
    # Reset gradients before the forward pass; the forward pass never reads
    # .grad, so this matches zeroing immediately before backward().
    optimizer.zero_grad()

    # Forward pass: Xenium and Visium-context tensors feed separate towers.
    x_recon, v_recon, _ = model(x_tensor, v_mapped_to_cells, map_tensor)

    # Per-modality reconstruction errors, combined with equal weight.
    loss_x = loss_fn(x_recon, x_tensor)
    loss_v = loss_fn(v_recon, v_tensor)
    total_loss = loss_x + loss_v

    total_loss.backward()
    optimizer.step()

    # Log only every 10th epoch.
    if (epoch + 1) % 10 != 0:
        continue
    print(
        f'Epoch [{epoch+1}/{num_epochs}], '
        f'Total Loss: {total_loss.item():.4f}, '
        f'X-Recon: {loss_x.item():.4f}, '
        f'V-Recon: {loss_v.item():.4f}'
    )

# --- 5. Generating the Final Universal Matrix ---
print("\n--- 5. Generating the Final Universal Matrix ---")

# Evaluation mode plus a gradient-free forward pass; only U (the third output)
# is kept, and it is copied out as a NumPy array.
model.eval()
with torch.no_grad():
    _, _, universal_matrix_tensor = model(x_tensor, v_mapped_to_cells, map_tensor)
    universal_matrix_numpy = universal_matrix_tensor.cpu().numpy()

print("Universal Matrix generation complete.")
print(f"Shape of the final matrix: {universal_matrix_numpy.shape}")



This design replaces the earlier "patchwork" idea with a model that takes Xenium and Visium data as two separate, direct inputs, encodes each of them independently, and then merges the two encodings to generate the final U matrix.

This is a classic "multi-modal" or "two-tower" network design.

The Dual-Input Architecture
Here's how the model works:

Two Input Towers: The model will have two distinct input pathways or "towers":

Xenium Tower: A dedicated two-layer MLP (n_xenium_genes → 128 → 64) that processes only the Xenium data (X). Its job is to learn an internal representation of each cell's own expression profile.

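Visium Tower: A second, parallel two-layer MLP (n_visium_genes → 256 → 128) that processes only the contextual Visium data. For each cell, this is the expression profile of the Visium spot the cell is mapped to, i.e. the corresponding row of M·V, where M is the cell-to-spot mapping matrix. Its job is to learn a representation of the cell's neighborhood.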

Merge Layer: The outputs of both towers are concatenated into a single 192-dimensional feature vector (64 + 128). This vector contains the learned features from both modalities.

Shared Decoder: This merged vector is fed into a two-layer MLP decoder (192 → 512 → n_visium_genes). The decoder turns the combined features into that cell's row of the U matrix.

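Reconstruction & Loss: The rest of the process remains the same. The generated U matrix is used to reconstruct two things:
- X', through a learned linear projection from Visium-gene space to Xenium-gene space.
- V', by averaging the rows of U over the cells mapped to each spot.

The loss is the sum of the mean-squared errors of X' and V' against the originals. The error is then backpropagated through the entire system, training all four learned components simultaneously: both towers, the decoder, and the Xenium projector.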

The rationale for this architecture is that it lets the model build a specialized representation of each data type before the two are forced to interact. This is a cleaner way to learn the connections between the modalities.

The code above implements this dual-input architecture.