In [None]:
#Runs successfully on MBP 2024 metal_102 conda environment.
import numpy as np
import torch
import torch.nn as nn
import pymc as pm
from pytensor.compile.ops import as_op
import arviz as az
import pytensor.tensor as pt
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
# Hide all CUDA devices so PyTorch stays on the CPU. This only takes effect if CUDA
# has not been initialized yet in this process.
import os
os.environ.update(CUDA_VISIBLE_DEVICES="-1")

In [None]:
# Define constants
# Sequence length the Discriminator's final Linear layer is sized for. Simulated traces
# have T steps, so keep SIZE == T.
SIZE = 100
# Number of Markov steps per simulated trace (one output row per step).
T = 100
# Step size passed to the simulators; the Markov simulation itself does not depend on it.
dt = 0.1
# Traces per batch, for both the real and the synthetic data.
NUM_SAMPLES = 100
# Outer training epochs.
EPOCHS = 100

# Seed the global numpy and torch RNGs so Restart & Run All reproduces the simulated data.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def sim_channel(params, T, dt):
    """Simulate one noisy trace from a four-state Markov channel model.

    Parameters
    ----------
    params : sequence of 9 scalars
        ``(kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21)``: six
        transition rates, the standard deviation of the additive Gaussian noise,
        and the scale/offset applied to the 0/1 state indicator.
    T : int
        Number of Markov steps to simulate. The result always has at least one row.
    dt : float
        Accepted for signature compatibility; the simulation does not use it.

    Returns
    -------
    torch.Tensor of shape (T, 2)
        Column 0 is 1 while the chain is in state 2 or 3 and 0 otherwise. Column 1 is
        ``indicator * scale + offset + noise`` with noise ~ Normal(0, Fnoise).
        (States 0-1 are presumably the closed states and 2-3 the open ones,
        judging from the parameter names.)
    """
    kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21 = params
    n_steps = int(T)

    zero = torch.tensor(0.0)
    kc12, kc21, kco1, koc2, ko12, ko21 = (
        torch.tensor(rate) for rate in (kc12, kc21, kco1, koc2, ko12, ko21)
    )
    scale = torch.tensor(scale)
    offset = torch.tensor(offset)

    # Total outgoing rate of each state; it goes on the diagonal as 1 - rate.
    r1 = kc12
    r2 = kc21 + kco1
    r3 = koc2 + ko12
    r4 = ko21

    # NOTE(review): softmax over these rows gives non-zero probability to transitions
    # whose rate is 0 (e.g. state 0 -> state 2). The commented-out originals suggest the
    # raw rows [1 - r, rates...] were intended. Kept as-is because the inference-side
    # simulators use the same construction.
    transition_matrix = torch.stack([
        torch.softmax(torch.stack([1 - r1, kc12, zero, zero]), dim=0),
        torch.softmax(torch.stack([kc21, 1 - r2, kco1, zero]), dim=0),
        torch.softmax(torch.stack([zero, koc2, 1 - r3, ko12]), dim=0),
        torch.softmax(torch.stack([zero, zero, ko21, 1 - r4]), dim=0),
    ])
    # One categorical per current state, built once rather than on every step.
    step_dists = [torch.distributions.Categorical(probs=row) for row in transition_matrix]

    initial_distribution = torch.distributions.Categorical(
        probs=torch.tensor([0.3, 0.3, 0.2, 0.2])
    )
    states = [initial_distribution.sample().item()]
    for _ in range(n_steps - 1):
        states.append(step_dists[states[-1]].sample().item())

    is_open = (torch.tensor(states) >= 2).long()
    # Noise length follows the simulated trace (not a global constant), so T may vary.
    noise = torch.normal(0.0, float(Fnoise), size=(len(states),))
    signal = is_open * scale + offset + noise
    # Cast the indicator explicitly instead of relying on stack's dtype promotion.
    return torch.stack([is_open.to(signal.dtype), signal], dim=-1)

# Example usage: ground-truth parameters for the reference ("real") data.
kc12 = 0.1
kc21 = 0.2
Fnoise = 0.01
scale = 0.25
offset = -0.4
kco1 = 0.5
koc2 = 0.25
ko12 = 0.1
ko21 = 0.2
# Built from the names above, in the order sim_channel unpacks them, so the two cannot drift.
orig_params = [kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21]

# This trace is also what the Bayesian model cell uses as its observed data.
channels = sim_channel(orig_params, T, dt)
plt.plot(channels)

In [None]:
# Define the discriminator model
class Discriminator(nn.Module):
    """1-D CNN that maps a batch of two-column traces to one real/fake logit per trace.

    Two stride-2 convolutions (kernel 5, padding 2) reduce a length-L sequence to
    ceil(L / 4) positions, so the final Linear layer is sized from ``seq_len``.
    """

    def __init__(self, seq_len=None):
        """
        Args:
            seq_len: Number of time steps per input trace. Defaults to the notebook's
                ``SIZE`` constant, which is the original behaviour.
        """
        super().__init__()
        if seq_len is None:
            seq_len = SIZE
        # Each conv maps L -> floor((L + 2*2 - 5) / 2) + 1 = ceil(L / 2), so two give
        # ceil(L / 4). `SIZE // 4` is only right when the length is a multiple of 4.
        flat_len = (seq_len + 3) // 4
        self.model = nn.Sequential(
            nn.Conv1d(2, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(128 * flat_len, 1)
        )

    def forward(self, x):
        """Score ``x`` of shape (batch, seq_len, 2); returns logits of shape (batch, 1)."""
        # Conv1d expects (batch, channels, length), so move the 2 feature columns forward.
        return self.model(x.permute(0, 2, 1))

In [None]:
# Loss function for the discriminator
def discriminator_loss(real_output, fake_output):
    """Standard GAN discriminator loss on raw logits.

    Real samples are labelled 1 and synthetic samples 0. Each term is the mean
    binary cross-entropy with a built-in sigmoid, and the two terms are summed.
    """
    bce = nn.BCEWithLogitsLoss()
    real_targets = torch.ones_like(real_output)
    fake_targets = torch.zeros_like(fake_output)
    return bce(real_output, real_targets) + bce(fake_output, fake_targets)

In [None]:
# Instantiate the discriminator and its Adam optimizer (learning rate 1e-4).
discriminator = Discriminator()
discriminator_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)

In [None]:
# Reference dataset: NUM_SAMPLES independent traces from the ground-truth parameters,
# stacked into a batch along a new leading dimension.
real_data = torch.stack(
    [sim_channel(orig_params, T, dt) for _ in range(NUM_SAMPLES)],
    dim=0,
)

In [None]:
def pymc_sim_channel(params, T, dt):
    """NumPy-returning variant of ``sim_channel``, used to build synthetic batches.

    Parameters
    ----------
    params : sequence of 9 scalars (floats or 0-d arrays)
        ``(kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21)``.
    T : int
        Number of Markov steps. The result always has at least one row.
    dt : float
        Accepted for signature compatibility; the simulation does not use it.

    Returns
    -------
    numpy.ndarray of shape (T, 2)
        Column 0 is the 0/1 indicator for Markov states 2-3. Column 1 is
        ``indicator * scale + offset + noise`` with noise ~ Normal(0, Fnoise).
        The dtype follows the inputs (float64 for float64 0-d arrays).
    """
    kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21 = params
    n_steps = int(T)

    zero = torch.tensor(0.0)
    kc12, kc21, kco1, koc2, ko12, ko21 = (
        torch.tensor(rate) for rate in (kc12, kc21, kco1, koc2, ko12, ko21)
    )
    scale = torch.tensor(scale)
    offset = torch.tensor(offset)

    # Total outgoing rate of each state; it goes on the diagonal as 1 - rate.
    r1 = kc12
    r2 = kc21 + kco1
    r3 = koc2 + ko12
    r4 = ko21

    # NOTE(review): softmax gives non-zero probability to zero-rate transitions. Kept
    # because the data-generating simulator uses the same construction.
    transition_matrix = torch.stack([
        torch.softmax(torch.stack([1 - r1, kc12, zero, zero]), dim=0),
        torch.softmax(torch.stack([kc21, 1 - r2, kco1, zero]), dim=0),
        torch.softmax(torch.stack([zero, koc2, 1 - r3, ko12]), dim=0),
        torch.softmax(torch.stack([zero, zero, ko21, 1 - r4]), dim=0),
    ])
    # One categorical per current state, built once rather than on every step.
    step_dists = [torch.distributions.Categorical(probs=row) for row in transition_matrix]

    initial_distribution = torch.distributions.Categorical(
        probs=torch.tensor([0.3, 0.3, 0.2, 0.2])
    )
    states = [initial_distribution.sample().item()]
    for _ in range(n_steps - 1):
        states.append(step_dists[states[-1]].sample().item())

    is_open = (torch.tensor(states) >= 2).long()
    # Noise length follows the simulated trace rather than the global SIZE constant.
    noise = torch.normal(0.0, float(Fnoise), size=(len(states),))
    signal = is_open * scale + offset + noise
    res = torch.stack([is_open.to(signal.dtype), signal], dim=-1)
    return res.numpy()


In [None]:
@as_op(itypes=[pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar,
       pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar], otypes=[pt.dmatrix])
def new_pymc_sim_channel(kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21, T, dt):
    """Black-box PyTensor op: simulate one channel trace with torch, returned as float64.

    All eleven inputs arrive as float64 scalars (see ``itypes``). The trace has
    ``round(T / dt)`` steps (at least one). The result is a float64 matrix of shape
    (n_steps, 2): column 0 is the 0/1 indicator for Markov states 2-3, and column 1 is
    ``indicator * scale + offset + noise`` with noise ~ Normal(0, Fnoise).

    NOTE(review): a later cell redefines ``new_pymc_sim_channel`` before the model uses
    it, so this definition is shadowed. It is kept consistent with that one.
    """
    # round() rather than int(): T / dt can land just below an integer (e.g. 0.7 / 0.1).
    n_steps = int(round(float(T) / float(dt)))

    zero = torch.tensor(0.0, dtype=torch.float64)
    kc12, kc21, kco1, koc2, ko12, ko21 = (
        torch.tensor(rate) for rate in (kc12, kc21, kco1, koc2, ko12, ko21)
    )
    scale = torch.tensor(scale)
    offset = torch.tensor(offset)

    # Total outgoing rate of each state; it goes on the diagonal as 1 - rate.
    r1 = kc12
    r2 = kc21 + kco1
    r3 = koc2 + ko12
    r4 = ko21

    # NOTE(review): softmax gives non-zero probability to zero-rate transitions. Kept
    # because the data-generating simulator uses the same construction.
    transition_matrix = torch.stack([
        torch.softmax(torch.stack([1 - r1, kc12, zero, zero]), dim=0),
        torch.softmax(torch.stack([kc21, 1 - r2, kco1, zero]), dim=0),
        torch.softmax(torch.stack([zero, koc2, 1 - r3, ko12]), dim=0),
        torch.softmax(torch.stack([zero, zero, ko21, 1 - r4]), dim=0),
    ])
    step_dists = [torch.distributions.Categorical(probs=row) for row in transition_matrix]

    initial_distribution = torch.distributions.Categorical(
        probs=torch.tensor([0.3, 0.3, 0.2, 0.2])
    )
    states = [initial_distribution.sample().item()]
    for _ in range(n_steps - 1):
        states.append(step_dists[states[-1]].sample().item())

    is_open = (torch.tensor(states) >= 2).long()
    noise = torch.normal(0.0, float(Fnoise), size=(len(states),))
    signal = is_open * scale + offset + noise
    res = torch.stack([is_open.to(signal.dtype), signal], dim=-1)
    # dmatrix output: hand PyTensor a float64 2-D array.
    return res.to(torch.float64).numpy()




In [None]:
import numpy as np
import pymc as pm
from pytensor.compile.ops import as_op
import pytensor.tensor as pt
import arviz as az
import matplotlib.pyplot as plt
import torch
from scipy.fft import fft, fftfreq

# Define the custom simulation function
@as_op(itypes=[pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar, pt.dscalar], otypes=[pt.dmatrix])
def new_pymc_sim_channel(kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21, T, dt):
    """Black-box PyTensor op: simulate one channel trace with torch, returned as float64.

    All eleven inputs arrive as float64 scalars (see ``itypes``). The trace has
    ``round(T / dt)`` steps (at least one). The result is a float64 matrix of shape
    (n_steps, 2): column 0 is the 0/1 indicator for Markov states 2-3, and column 1 is
    ``indicator * scale + offset + noise`` with noise ~ Normal(0, Fnoise).

    The op declares no gradient, and every evaluation draws fresh random states.
    """
    # round() rather than int(): T / dt can land just below an integer (e.g. 0.7 / 0.1),
    # and truncation would silently drop a step.
    n_steps = int(round(float(T) / float(dt)))

    zero = torch.tensor(0.0, dtype=torch.float64)
    kc12, kc21, kco1, koc2, ko12, ko21 = (
        torch.tensor(rate) for rate in (kc12, kc21, kco1, koc2, ko12, ko21)
    )
    scale = torch.tensor(scale)
    offset = torch.tensor(offset)

    # Total outgoing rate of each state; it goes on the diagonal as 1 - rate.
    r1 = kc12
    r2 = kc21 + kco1
    r3 = koc2 + ko12
    r4 = ko21

    # NOTE(review): softmax gives non-zero probability to zero-rate transitions. Kept
    # because the data-generating simulator uses the same construction.
    transition_matrix = torch.stack([
        torch.softmax(torch.stack([1 - r1, kc12, zero, zero]), dim=0),
        torch.softmax(torch.stack([kc21, 1 - r2, kco1, zero]), dim=0),
        torch.softmax(torch.stack([zero, koc2, 1 - r3, ko12]), dim=0),
        torch.softmax(torch.stack([zero, zero, ko21, 1 - r4]), dim=0),
    ])
    # One categorical per current state, built once rather than on every step.
    step_dists = [torch.distributions.Categorical(probs=row) for row in transition_matrix]

    initial_distribution = torch.distributions.Categorical(
        probs=torch.tensor([0.3, 0.3, 0.2, 0.2])
    )
    states = [initial_distribution.sample().item()]
    for _ in range(n_steps - 1):
        states.append(step_dists[states[-1]].sample().item())

    is_open = (torch.tensor(states) >= 2).long()
    noise = torch.normal(0.0, float(Fnoise), size=(len(states),))
    signal = is_open * scale + offset + noise
    res = torch.stack([is_open.to(signal.dtype), signal], dim=-1)
    # dmatrix output: hand PyTensor a float64 2-D array.
    return res.to(torch.float64).numpy()

# Define the Bayesian model
with pm.Model() as model:
    # Priors: Beta(2, 5) keeps each transition rate in (0, 1); HalfNormal keeps the
    # noise level, scale and likelihood sigma non-negative.
    kc21 = pm.Beta('kc21', alpha=2, beta=5)
    kc12 = pm.Beta('kc12', alpha=2, beta=5)
    kco1 = pm.Beta('kco1', alpha=2, beta=5)
    koc2 = pm.Beta('koc2', alpha=2, beta=5)
    ko12 = pm.Beta('ko12', alpha=2, beta=5)
    ko21 = pm.Beta('ko21', alpha=2, beta=5)
    Fnoise = pm.HalfNormal('Fnoise', sigma=1)
    scale = pm.HalfNormal('scale', sigma=1)
    offset = pm.Normal('offset', mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Simulation length and step as PyTensor constants. They get their own names:
    # rebinding the notebook-level `T`/`dt` here would leave PyTensor constants behind
    # for later cells that pass them to the torch simulators (e.g. `range(T - 1)`).
    T_model = pt.as_tensor_variable(np.float64(100))  # Total time
    dt_model = pt.as_tensor_variable(np.float64(1))  # Time step

    # Likelihood: Gaussian observation noise around the black-box simulator output.
    # NOTE(review): the simulator draws fresh random states on every call, so this
    # likelihood is stochastic. The `as_op` also has no gradient, so presumably PyMC
    # falls back to gradient-free step methods; confirm in the sampler log.
    simulated = new_pymc_sim_channel(kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21, T_model, dt_model)
    predicted = pm.Normal('predicted', mu=simulated, sigma=sigma,
                          observed=np.asarray(channels, dtype=np.float64))

    trace = pm.sample(10000, tune=5000, return_inferencedata=True)

# Plot the results
az.plot_trace(trace)
plt.tight_layout()
plt.show()

# Print the summary of the posterior
print(az.summary(trace, var_names=["kc12", "kc21", "Fnoise", "scale", "offset", "kco1", "koc2", "ko12", "ko21", "sigma"]))
# Ground-truth parameters used to simulate the observed data, for comparison:
# orig_params = [0.1, 0.2, 0.01, 0.25, -0.4,  0.5, 0.25, 0.1, 0.2]

In [None]:
# Posterior means of every model variable, taken from a single summary table
# (az.summary is expensive, so compute it once rather than once per variable).
posterior_means = az.summary(trace)["mean"]
toffset = float(posterior_means["offset"])
print(toffset)

# Substitute each random variable with its posterior mean. sigma = 0 collapses the
# Normal likelihood onto its mean, so `predicted` evaluates to one simulator run at
# those means.
givens = {
    rv: float(posterior_means[rv.name])
    for rv in (kc12, kc21, Fnoise, scale, offset, kco1, koc2, ko12, ko21)
}
givens[sigma] = 0
with model:
    pred_eval = predicted.eval(givens)

plt.plot(pred_eval)

In [None]:
# Training loop
# NOTE(review): nothing the discriminator learns reaches `model` unless it defines a
# `discriminator_output` data container (see the guard at the end of the loop). As the
# model cell stands, every epoch re-samples the same posterior and only the
# discriminator is trained.

# Synthetic traces must match the real data's length for the discriminator. `T` is not
# used here because the model cell may have rebound it to a PyTensor constant.
n_steps = real_data.shape[1]
# Posterior variables in the order pymc_sim_channel unpacks them.
SIM_PARAM_NAMES = ("kc12", "kc21", "Fnoise", "scale", "offset", "kco1", "koc2", "ko12", "ko21")

for epoch in tqdm(range(EPOCHS)):
    if epoch > 0:
        summary = az.summary(trace)
        print(summary)
        print(disc_loss)
    # Sample from the posterior
    with model:
        trace = pm.sample(5000, tune=500, chains=8, cores=8)

    # Generate synthetic data. Every trace uses the posterior *mean* of each parameter
    # (over all chains and draws), so compute the means once per epoch, not per trace.
    posterior = trace.posterior
    mean_params = [posterior[name].mean().values for name in SIM_PARAM_NAMES]
    synthetic_data = torch.stack([
        torch.from_numpy(pymc_sim_channel(mean_params, n_steps, dt)).float()
        for _ in range(NUM_SAMPLES)
    ])

    # Train discriminator: several updates per posterior refresh.
    for _ in range(5):
        real_output = discriminator(real_data)
        fake_output = discriminator(synthetic_data)

        disc_loss = discriminator_loss(real_output, fake_output)

        discriminator_optimizer.zero_grad()
        disc_loss.backward()
        discriminator_optimizer.step()

    # Push the discriminator's feedback into the model. pm.set_data raises when the model
    # has no data container of this name, and the model cell defines none, so skip it then.
    if "discriminator_output" in model.named_vars:
        with model:
            pm.set_data({'discriminator_output': fake_output.mean().item()})
    clear_output(wait=True)

In [None]:
# Final evaluation: draw posterior-predictive samples from the model, then plot the
# traces of the most recent posterior sample.
posterior_predictive = pm.sample_posterior_predictive(trace, model=model)
az.plot_trace(trace)

In [None]:
# Tabulate posterior statistics for every sampled variable.
summary = az.summary(data=trace)
print(summary)

In [None]:
# Extracting parameter names
parameter_names = summary.index.tolist()

# Posterior mean of each simulator parameter, looked up by name and collected in the
# order sim_channel unpacks them. Walking the summary rows and stopping after six
# depended on row order and dropped most of the parameters.
SIM_PARAM_NAMES = ("kc12", "kc21", "Fnoise", "scale", "offset", "kco1", "koc2", "ko12", "ko21")
params = []
for param in SIM_PARAM_NAMES:
    mean_value = summary.loc[param, 'mean']
    params.append(mean_value)
    print(f"Parameter: {param}, Mean: {mean_value}")

In [None]:
# Re-simulate a trace from the posterior means. sim_channel unpacks nine parameters,
# so look each one up by name in that order.
# Plain numbers are passed for the length and step because `T`/`dt` may have been
# rebound to PyTensor constants by the model cell. sim_channel does not use the step
# for the simulation itself.
posterior_mean_params = [
    float(summary.loc[name, "mean"])
    for name in ("kc12", "kc21", "Fnoise", "scale", "offset", "kco1", "koc2", "ko12", "ko21")
]
plt.plot(sim_channel(posterior_mean_params, SIZE, 0.1))
plt.show()

In [None]:
# Display the `predicted` likelihood variable created in the model cell. This shows the
# variable object itself, not sampled or observed values.
predicted