In [1]:
import pandas as pd
import numpy as np
import csv
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import logging
import math
from lmfit import minimize, fit_report, Parameters
from aim2_population_model_spatial_aff_parallel import get_mod_spike
from model_constants import (MC_GROUPS, LifConstants)
from popul_model import pop_model
from aim2_population_model_spatial_aff_parallel import Afferent, SimulationConfig, Simulation

In [2]:
#Global Variables
lmpars_init_dict = {}
lmpars = Parameters()
lmpars.add('tau1', value=8, vary=False)
lmpars.add('tau2', value=200, vary=False)
lmpars.add('tau3', value=1744.6, vary=False)
lmpars.add('tau4', value=np.inf, vary=False)
lmpars.add('k1', value=.74, vary=False, min=0) #a constant
lmpars.add('k2', value=2.75, vary=False, min=0) #b constant
lmpars.add('k3', value=.07, vary=False, min=0) #c constant
lmpars.add('k4', value=.0312, vary=False, min=0)
lmpars_init_dict['t3f12v3final'] = lmpars


In [3]:
#random helper functino calculating the distance between 2 points
def distance(x1,y1,x2,y2):
    x = (x2-x1) **2
    y  = (y2 -y1) **2
    return np.sqrt(x+y)
    

**Population Model Class**

In [None]:
class VF_Population_Model:
    
    def __init__(self, vf_tip_size, aff_type):
        self.vf_tip_size = vf_tip_size
        self.aff_type = aff_type
        self.results = None
        self.stress_data = None
        self.x_coords = None
        self.y_coords = None
    """
        functino takes in a vf_tip_size (given that that there is data assicated with it) an
        afferent type, and runs the single unit model for all of those coordinates with the data
        
    """
    def spatial_stress_vf_model(self, scaling_factor = 0.3):

        #reading data in 
        coords = pd.read_csv(f"data/vfspatial/{self.vf_tip_size}_spatial_coords.csv", header = None)

        #assinging the instance variables for all the x coords and y coords of spatial points
        self.x_coords = [float(row[0]) for row in coords.iloc[1:].values]
        self.y_coords = [float(row[1]) for row in coords.iloc[1:].values]


        stress_data = pd.read_csv(f"data/vfspatial/{self.vf_tip_size}_spatial_stress.csv", )
        time = stress_data['Time (ms)'].to_numpy()

        afferent_type = []
        x_pos = []
        y_pos = []
        spikes = []
        mean_firing_frequency = []
        peak_firing_frequency = []
        first_spike_time = []
        last_spike_time = []

        #iterating through each of the coordinates
        for i, row in coords.iloc[1:].iterrows():
            #getting stress data
            
            if f"Coord {i} Stress (kPa)" in stress_data.columns:
                stress = stress_data[f"Coord {i} Stress (kPa)"]
            else:
                logging.warning("STRESS VALUE COULD NOT BE INDEXED")

            lmpars = lmpars_init_dict['t3f12v3final']
            if self.aff_type == "RA":
                lmpars['tau1'].value = 2.5
                lmpars['tau2'].value = 200
                lmpars['tau3'].value = 1
                lmpars['k1'].value = 35
                lmpars['k2'].value = 0
                lmpars['k3'].value = 0.0
                lmpars['k4'].value = 0

            groups = MC_GROUPS
            mod_spike_time, mod_fr_inst = get_mod_spike(lmpars, groups, time, stress)

            if len(mod_spike_time) == 0 or len(mod_fr_inst) == 0:
                logging.warning(f"SPIKES COULD NOT BE GENERATED FOR {self.vf_tip_size}")
                continue

            if len(mod_spike_time) != len(mod_fr_inst):
                if len(mod_fr_inst) > 1:
                    mod_fr_inst_interp = np.interp(mod_spike_time, time, mod_fr_inst)
                else:
                    mod_fr_inst_interp = np.zeros_like(mod_spike_time)
            else:
                mod_fr_inst_interp = mod_fr_inst

            features, _ = pop_model(mod_spike_time,mod_fr_inst_interp)

            #appending stuff to lists
            afferent_type.append(self.aff_type)
            x_pos.append(row[0])
            y_pos.append(row[1])
            spikes.append(len(mod_spike_time) if len(mod_spike_time) !=0 else None)
            mean_firing_frequency.append(features["Average Firing Rate"])
            peak_firing_frequency.append(np.max(mod_fr_inst_interp))
            first_spike_time.append(mod_spike_time[0] if len(mod_spike_time) != None else None)
            last_spike_time.append(mod_spike_time[-1])
            
        model_results = {
            'afferent_type': self.aff_type,
            'x_position': x_pos,
            'y_position': y_pos,
            'num_of_spikes' : spikes,
            'mean_firing_frequency' : mean_firing_frequency,
            'peak_firing_frequency' : peak_firing_frequency, 
            'first_spike_time': first_spike_time,
            'last_spike_time' : last_spike_time
        }

        self.results = model_results
        return model_results

    def radial_stress_vf_model(self,scaling_factor = 0.3):
        """ Read in the Radial which has sample stress traces for every 2mm from a center point
        to calculate firing"""

        #regex pattern for exstracting the distance from the middle point
        distance_regex = r'\d\.\d{2}'

        #reading data from spatial stress data file for 50 (or n) data points in a grid
        coords = pd.read_csv(f"data/vfspatial/{self.vf_tip_size}_spatial_coords.csv", header = None)

        self.x_coords = [float(row[0]) for row in coords.iloc[1:].values]
        self.y_coords = [float(row[1]) for row in coords.iloc[1:].values]


        
        stress_df = pd.read_csv(f"data/vfspatial/{self.vf_tip_size}_spatial_stress.csv", )
        time = stress_df['Time (ms)'].to_numpy()

        #Reading in the radial stress file
        radial_stress = pd.read_csv(f"data/vfspatial/{self.vf_tip_size}_radial_stress.csv")
        radial_time = stress_df['Time (ms)'].to_numpy()
        
        stress_data = {}
        iff_data = {}


        #Outer loop top iterate through all n spatial points
        for i, row in coords.iloc[1:].iterrows():
            radial_spatial_flag = True
            stress_data[i] = {}
            iff_data[i] = {}

            if f"Coord {i} Stress (kPa)" in stress_df.columns:
                spatial_stress = stress_df[f"Coord {i} Stress (kPa)"]
                spatial_stress_max = np.max(spatial_stress)

                # Inner loop to iterate through radial distances
                for col in radial_stress.columns[1:]:
                    distance_from_center = float(re.findall(distance_regex, col)[0])

                    # Initialize lists for each coordinate-distance pair
                    afferent_type = []
                    x_pos = []
                    y_pos = []
                    spikes = []
                    mean_firing_frequency = []
                    peak_firing_frequency = []
                    first_spike_time = []
                    last_spike_time = []

                    if radial_spatial_flag:
                        radial_stress_vals = radial_stress[col]
                        radial_stress_max = np.max(radial_stress_vals)
                        distance_scaling_factor = spatial_stress_max / radial_stress_max
                        radial_spatial_flag = False
                    
                    scaled_stress = radial_stress[col] * distance_scaling_factor * scaling_factor

                    stress_data[i][distance_from_center] = {
                        "Time": radial_time,
                        distance_from_center: scaled_stress.to_numpy()
                    }

                    lmpars = lmpars_init_dict['t3f12v3final']
                    if self.aff_type == "RA":
                        lmpars['tau1'].value = 2.5
                        lmpars['tau2'].value = 200
                        lmpars['tau3'].value = 1
                        lmpars['k1'].value = 35
                        lmpars['k2'].value = 0
                        lmpars['k3'].value = 0.0
                        lmpars['k4'].value = 0

                    groups = MC_GROUPS
                    mod_spike_time, mod_fr_inst = get_mod_spike(lmpars, groups, stress_data[i][distance_from_center]["Time"], stress_data[i][distance_from_center][distance_from_center])

                    if len(mod_spike_time) == 0 or len(mod_fr_inst) == 0:
                        # logging.warning(f"SPIKES COULD NOT BE GENERATED FOR {self.vf_tip_size}")
                        continue

                    if len(mod_spike_time) != len(mod_fr_inst):
                        if len(mod_fr_inst) > 1:
                            mod_fr_inst_interp = np.interp(mod_spike_time, time, mod_fr_inst)
                        else:
                            mod_fr_inst_interp = np.zeros_like(mod_spike_time)
                    else:
                        mod_fr_inst_interp = mod_fr_inst

                    features, _ = pop_model(mod_spike_time, mod_fr_inst_interp)

                    # Append single values to the lists
                    afferent_type.append(self.aff_type)
                    x_pos.append(row[0])
                    y_pos.append(row[1])
                    spikes.append(len(mod_spike_time) if len(mod_spike_time) != 0 else None)
                    mean_firing_frequency.append(features["Average Firing Rate"])
                    peak_firing_frequency.append(np.max(mod_fr_inst_interp))
                    first_spike_time.append(mod_spike_time[0] if len(mod_spike_time) != 0 else None)
                    last_spike_time.append(mod_spike_time[-1])

                    # Store each coordinate-distance dictionary within iff_data
                    iff_data[i][distance_from_center] = {
                        'afferent_type': self.aff_type,
                        'x_position': x_pos[0],
                        'y_position': y_pos[0],
                        'num_of_spikes': spikes[0],
                        'mean_firing_frequency': mean_firing_frequency[0],
                        'peak_firing_frequency': peak_firing_frequency[0],
                        'first_spike_time': first_spike_time[0],
                        'last_spike_time': last_spike_time[0]
                    }
            else:
                logging.warning("STRESS VALUE COULD NOT BE INDEXED")
        self.stress_data = stress_data
        self.results = iff_data


        

    def aggregate_results(self):
        df = pd.DataFrame(self.results)
        file_path = f"generated_csv_files/{self.vf_tip_size}_vf_popul_model.csv"
        df.to_csv(file_path, index = False)
        return file_path 
    
    def plot_spatial_coords(self):
        """
        Plots the iffs on a grid for the original n points, the magniude of the peak firing
        frequency directly affects the size of the circle plotted, and the opacity
        """
        #colors for differnet afferents
        colors = {'SA': '#31a354', 'RA': '#3182bd'}
        plt.figure(figsize=(10, 5))

        # Plot the stimulus locations as circles
        x_positions = self.results.get("x_position")
        y_positions = self.results.get("y_position")
        mean_iffs = self.results.get("mean_firing_frequency")
        peak_iffs = self.results.get("peak_firing_frequency")
        
        x_positions = [float(value) for value in x_positions]
        y_positions = [float(value) for value in y_positions]
        #scaling peak_iffs so it looks better when plotting
    
        alphas = [float(value)/max(peak_iffs) for value in peak_iffs]

        #Scatter plot
                # Plot the stimulus locations as circles
        for x_pos, y_pos, radius, alpha in zip(x_positions, y_positions, peak_iffs, alphas):
            plt.gca().add_patch(
                patches.Circle((x_pos, y_pos), radius*2, edgecolor='black', facecolor = colors.get(self.aff_type) , linewidth=1, alpha = 0.5)
            )
        plt.xlabel('Length (mm)')
        plt.ylabel('Width (mm)')
        plt.title("VF Afferent Stress Distribution")
        plt.gca().set_aspect('equal', adjustable='box')
        plt.xlim(min(x_positions) - 1, max(x_positions) + 1)
        plt.ylim(min(y_positions) - 1, max(y_positions) + 1)
        plt.savefig(f"vf_graphs/aggregated_results_on_grid/{self.vf_tip_size}_{self.aff_type}_constant_opacity.png")

    def simulate_afferent_response(self, aff : Afferent):
        radii = list(self.results.get(1).keys())


        #calculate the distance between the afferent and all of the spatial point centers to find which one is the closest
        distance_from_spatial_centers = [distance(x1, y1, aff.x_pos, aff.y_pos) for x1, y1 in zip(self.x_coords, self.y_coords)]

        #ffinding the index of the minimum distance
        min_index = distance_from_spatial_centers.index(min(distance_from_spatial_centers))

        # getting the x,y coordiantes of the spatial coordiante closest to the the afferent
        min_x = self.x_coords[min_index]
        min_y = self.y_coords[min_index]



        #the distance from afferent point to the distance of the spatial center
        distance_to_nearest_center = distance_from_spatial_centers[min_index]


        #If the afferent is equivalent to the stress trace at the spatial point
        for radius in radii:    
            if distance_to_nearest_center <= radius:
                        return (
                        self.stress_data.get(min_index).get(radius), 
                        self.results.get(min_index).get(radius)
                        )
        
        #In the case that its too far away from all spatial coordinates
        return None, None
    
    def plot_radial_spatial_afferents(self,aff: Afferent, config : SimulationConfig, stress, iff):
        radii = list(self.results.get(1).keys())
        colors = {'SA': '#31a354', 'RA': '#3182bd'}

        plt.figure( figsize=(10,5))
        """
        Plotting the "Stimuli" or the know stress data this will look similar to spatial
        data points in a grid and each of the spatial points will be surrounded be concentric
        circles
        """

        for x_pos,y_pos in zip(sl





        pass


    def get_iffs(self):
        return self.results
    
    def get_stress_traces(self):
        return self.stress_data


**COnfiguring the Von-Frey Population Model**

In [48]:
#creates model class
vf_model = VF_Population_Model(4.17, "SA")

#runs the model which calculates the results
# vf_model.spatial_stress_vf_model()
vf_model.radial_stress_vf_model()
iffs_results = vf_model.get_iffs()
stress_results = vf_model.get_stress_traces()


**Randomly Generating Afferents and setting up the configuration for the Simulation**

In [49]:
tongue_size = (50, 25)  # in mm
density_ratio = (1, 0)  # Ratio of SA and RA afferents
n_afferents = 200
rf_sizes = {
    'SA': [1, 10, 19.6],
    'RA': [1, 6.5, 12.5]
}

config = SimulationConfig(tongue_size, density_ratio, n_afferents, rf_sizes,
                          stimulus_diameter=None, 
                          x_stimulus=None, y_stimulus=None,
                          stress=None)

simulation = Simulation(config)

afferents = simulation.get_afferents()



**Getting Stress & Firing Data for the Afferents**

In [52]:
stress_data = []
iff_data = []

for afferent in afferents:
   stresses, iffs = vf_model.simulate_afferent_response(afferent)
   stress_data.append(stresses)
   iff_data.append(iffs)




Radii: [0.0, 0.33, 0.51, 0.71, 0.9, 1.09, 1.31, 1.52, 1.69, 1.88, 2.11, 2.3, 2.48, 2.71, 2.92, 3.31, 3.51, 3.69, 3.89, 4.1, 4.3, 4.49, 4.7, 4.9]
Radii: [0.0, 0.33, 0.51, 0.71, 0.9, 1.09, 1.31, 1.52, 1.69, 1.88, 2.11, 2.3, 2.48, 2.71, 2.92, 3.31, 3.51, 3.69, 3.89, 4.1, 4.3, 4.49, 4.7, 4.9]
Radii: [0.0, 0.33, 0.51, 0.71, 0.9, 1.09, 1.31, 1.52, 1.69, 1.88, 2.11, 2.3, 2.48, 2.71, 2.92, 3.31, 3.51, 3.69, 3.89, 4.1, 4.3, 4.49, 4.7, 4.9]
Radii: [0.0, 0.33, 0.51, 0.71, 0.9, 1.09, 1.31, 1.52, 1.69, 1.88, 2.11, 2.3, 2.48, 2.71, 2.92, 3.31, 3.51, 3.69, 3.89, 4.1, 4.3, 4.49, 4.7, 4.9]
Radii: [0.0, 0.33, 0.51, 0.71, 0.9, 1.09, 1.31, 1.52, 1.69, 1.88, 2.11, 2.3, 2.48, 2.71, 2.92, 3.31, 3.51, 3.69, 3.89, 4.1, 4.3, 4.49, 4.7, 4.9]
Radii: [0.0, 0.33, 0.51, 0.71, 0.9, 1.09, 1.31, 1.52, 1.69, 1.88, 2.11, 2.3, 2.48, 2.71, 2.92, 3.31, 3.51, 3.69, 3.89, 4.1, 4.3, 4.49, 4.7, 4.9]
Radii: [0.0, 0.33, 0.51, 0.71, 0.9, 1.09, 1.31, 1.52, 1.69, 1.88, 2.11, 2.3, 2.48, 2.71, 2.92, 3.31, 3.51, 3.69, 3.89, 4.1, 4.3, 4

In [51]:
for iff in iff_data:
    print(iff)


(None, None)
(None, None)
({'Time': array([   0,    1,    2, ..., 4983, 4984, 4985]), 0.9: array([ 0.        ,  0.00018411,  0.00036821, ..., -0.01701775,
       -0.01597566, -0.01493358])}, None)
({'Time': array([   0,    1,    2, ..., 4983, 4984, 4985]), 4.7: array([ 0.00000000e+00,  2.73366536e-05,  5.46733072e-05, ...,
       -2.52684779e-03, -2.37211589e-03, -2.21738399e-03])}, None)
(None, None)
(None, None)
(None, None)
(None, None)
({'Time': array([   0,    1,    2, ..., 4983, 4984, 4985]), 1.88: array([ 0.00000000e+00,  6.23566262e-05,  1.24713252e-04, ...,
       -5.76389884e-03, -5.41094550e-03, -5.05799245e-03])}, None)
(None, None)
({'Time': array([   0,    1,    2, ..., 4983, 4984, 4985]), 3.89: array([ 0.00000000e+00,  2.05597679e-05,  4.11195357e-05, ...,
       -1.90043029e-03, -1.78405716e-03, -1.66768403e-03])}, None)
(None, None)
({'Time': array([   0,    1,    2, ..., 4983, 4984, 4985]), 1.31: array([ 0.        ,  0.00014247,  0.00028494, ..., -0.01316893,
       -

In [23]:
stress_results.get(4).get(0.0), 
iffs_results.get(4).get(0.0)

{'afferent_type': 'SA',
 'x_position': '7.48174',
 'y_position': '1.18685',
 'num_of_spikes': 17,
 'mean_firing_frequency': 25.942516694926557,
 'peak_firing_frequency': 0.038461538461538464,
 'first_spike_time': 155.0,
 'last_spike_time': 2486.5}

In [7]:
iffs_results.get(3)

{0.0: {'afferent_type': 'SA',
  'x_position': '5.22687',
  'y_position': '2.46636',
  'num_of_spikes': 31,
  'mean_firing_frequency': 44.738098914941304,
  'peak_firing_frequency': 0.06896551724137931,
  'first_spike_time': 129.5,
  'last_spike_time': 2487.0},
 0.33: {'afferent_type': 'SA',
  'x_position': '5.22687',
  'y_position': '2.46636',
  'num_of_spikes': 18,
  'mean_firing_frequency': 27.269562497924895,
  'peak_firing_frequency': 0.043478260869565216,
  'first_spike_time': 152.0,
  'last_spike_time': 2466.0},
 0.51: {'afferent_type': 'SA',
  'x_position': '5.22687',
  'y_position': '2.46636',
  'num_of_spikes': 13,
  'mean_firing_frequency': 21.585531606126818,
  'peak_firing_frequency': 0.03225806451612903,
  'first_spike_time': 165.5,
  'last_spike_time': 2454.0},
 0.71: {'afferent_type': 'SA',
  'x_position': '5.22687',
  'y_position': '2.46636',
  'num_of_spikes': 11,
  'mean_firing_frequency': 17.676772047343814,
  'peak_firing_frequency': 0.02666666666666667,
  'first_sp