In [None]:
import os
import sys
import cv2
import math
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import torch 
from pykeops.torch import LazyTensor
import matplotlib.pylab as pl
from matplotlib.colors import ListedColormap
import pickle

In [None]:
# Dataset selection used by the single-experiment cells below.
split, experiment, cell_type = "split_20231212", "FWF006", "Proximal"
# Side length of the square frames and of the sampling grid (frames are N x N).
N = 512

In [None]:
def sample_grid(n, dim=2, dtype=torch.float32, device="cpu"):
    """Return the points of a regular grid over the unit cube [0, 1]^dim.

    Each axis is sampled at ``n`` evenly spaced values from 0 to 1 (inclusive).
    The result has shape ``(n**dim, dim)``; points are in row-major ("ij") order,
    so the last coordinate varies fastest.
    """
    axis = torch.linspace(0, 1, n, dtype=dtype, device=device)
    mesh = torch.meshgrid((axis,) * dim, indexing="ij")
    stacked = torch.stack(mesh, dim=-1)
    return stacked.reshape(-1, dim)

# N*N points of the unit square [0, 1]^2, one per pixel, in row-major order:
# row i*N + j holds (i/(N-1), j/(N-1)).
y_grid = sample_grid(N)
# Radius of the neighbourhood disc, in the same unit-square coordinates as y_grid.
R = 0.1

def neighbours_grid(i, y_grid, R):
    """Mark the grid points lying within Euclidean distance ``R`` of point ``i``.

    Args:
        i: row index of the centre point in ``y_grid``.
        y_grid: tensor of shape (num_points, dim), e.g. from ``sample_grid``;
            any point dimension is accepted (not only 2-D).
        R: radius, in the same coordinates as ``y_grid``.

    Returns:
        Float tensor of length num_points with 1.0 where the squared distance to
        point ``i`` is <= R**2 and 0.0 elsewhere (squeezed, so a single-point grid
        yields a 0-d tensor).
    """
    # reshape(1, -1) keeps the point's own dimension instead of assuming 2-D.
    yi = y_grid[i].reshape(1, -1)
    sq_dist = ((yi - y_grid) ** 2).sum(1)
    return (sq_dist <= R ** 2).float().squeeze()



In [None]:
# Directory holding the per-frame pickles for the configured experiment (trailing "/" kept).
data_dir = "/".join((split + "_" + experiment + "_smooth", cell_type, "data", ""))

In [None]:
def open_data(data_dir,N=512):
    """Load the frames ``data_1.pkl`` ... ``data_T.pkl`` of ``data_dir`` into one tensor.

    Only files named ``data_<integer>.pkl`` are counted, so unrelated files in the
    directory (e.g. ``.DS_Store``) no longer inflate the frame count. Frames are
    then loaded in order 1..T, so a gap in the numbering raises FileNotFoundError.

    Args:
        data_dir: directory containing the frame pickles (trailing separator optional).
        N: side length; each pickle must hold an (N, N) array/tensor.

    Returns:
        Float tensor of shape (N, N, T) with frame t (1-based) at ``[:, :, t-1]``.
    """
    prefix, suffix = "data_", ".pkl"
    frame_files = [
        name for name in os.listdir(data_dir)
        if os.path.isfile(os.path.join(data_dir, name))
        and name.startswith(prefix)
        and name.endswith(suffix)
        and name[len(prefix):-len(suffix)].isdigit()
    ]
    T = len(frame_files)
    data = torch.zeros((N, N, T))
    for t in range(1, T + 1):
        path = os.path.join(data_dir, prefix + str(t) + suffix)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as fh:
            data[:, :, t - 1] = pickle.load(fh)
    return data

# Load every frame of the configured experiment into one (N, N, T) tensor.
data = open_data(data_dir,N=N)

In [None]:
# plt.plot(smooth_max)
# # plt.plot(data.reshape((N*N,T))[data_max[1][-1].item(),:])

In [None]:
from matplotlib.widgets import Button, Slider

def smooth_max(data, y_grid, R):
    """Find the last frame's brightest pixel and integrate a disc around it over time.

    Args:
        data: tensor indexed ``data[:, :, t]``; frames are assumed square
            (``data.shape[0] == data.shape[1]``) since they are flattened to N*N rows.
        y_grid: (N*N, 2) grid from ``sample_grid(N)``, same row order as the flattening.
        R: disc radius in unit-square coordinates.

    Returns:
        (index, smooth_max): ``index`` is the flattened pixel index of the maximum
        of the final frame; ``smooth_max`` is a length-T tensor holding, per frame,
        the sum of pixels within distance R of that pixel times (1/N)**2.
    """
    n_side, n_frames = data.shape[0], data.shape[2]
    pixels_by_time = data.reshape((n_side * n_side, n_frames))
    # Argmax over pixels for every frame; only the final frame's argmax is used.
    peak_index = torch.max(pixels_by_time, dim=0).indices[-1].item()
    in_disc = neighbours_grid(peak_index, y_grid, R).bool()
    cell_area = (1 / n_side) ** 2
    integral = pixels_by_time[in_disc, :].sum(0) * cell_area
    return peak_index, integral

def cmap(color):
    """Build a colormap fading from fully transparent to an opaque matplotlib palette.

    Args:
        color: "blue" (Blues), "magenta" (Purples) or "green" (Greens).

    Returns:
        ListedColormap with the base palette's RGB entries and an alpha channel
        rising linearly from 0 to 1 across the entries.

    Raises:
        ValueError: for any other colour name.
    """
    palettes = (
        ("blue", pl.cm.Blues),
        ("magenta", pl.cm.Purples),
        ("green", pl.cm.Greens),
    )
    base = next((palette for name, palette in palettes if color == name), None)
    if base is None:
        raise ValueError("Unknown color")
    rgba = base(np.arange(base.N))
    # Replace the alpha column so low values are see-through.
    rgba[:, -1] = np.linspace(0, 1, base.N)
    return ListedColormap(rgba)
    
def data_to_smax(data,y_grid,R=0.1,T0=1,cmapcolor="blue"):
    """Interactive viewer: one frame of ``data`` next to its smoothed-max time curve.

    Left panel: frame ``T0`` (1-based) drawn with ``cmap(cmapcolor)``, the peak
    pixel found by ``smooth_max`` (red dot) and the radius-``R`` disc around it
    (red circle). Right panel: the ``smooth_max`` curve over all frames with a red
    vertical line at the displayed frame. A "FRAME" slider changes the frame.

    Args:
        data: tensor indexed ``data[:, :, t]``; square frames of side ``data.shape[0]``.
        y_grid: grid of ``data.shape[0]**2`` points from ``sample_grid``.
        R: disc radius in unit-square coordinates.
        T0: initially displayed frame, 1-based.
        cmapcolor: "blue", "magenta" or "green".

    Returns:
        (fig, axs). The slider is stored as ``fig.frame_slider`` so it stays
        responsive for as long as the figure is alive.
    """
    # Pixel scale of *this* data, rather than the notebook-global N.
    n_side = data.shape[0]
    index, smax = smooth_max(data, y_grid, R)
    yx = y_grid[index, 0].item()
    yy = y_grid[index, 1].item()
    timeline_top = 1.1 * smax.max().item()

    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    im = axs[0].imshow(
        data[:, :, T0 - 1].numpy().transpose(),
        origin='lower',
        cmap=cmap(cmapcolor),
        vmin=0.0,
        vmax=data[:, :, -1].max().item(),
    )
    axs[0].invert_yaxis()
    axs[0].set_aspect("equal")

    theta = 2 * math.pi * np.linspace(0, 1, 100)
    axs[0].scatter([yx * n_side], [yy * n_side], s=15, c='red')
    axs[0].plot(n_side * (yx + R * np.cos(theta)), n_side * (yy + R * np.sin(theta)), c='r')

    axs[1].plot(smax)
    # smax is plotted against 0-based frame indices, so 1-based frame T0 is at x = T0 - 1
    # (same convention as update() below).
    timeline = axs[1].vlines(x=T0 - 1, ymin=0, ymax=timeline_top, color='r')
    fig.subplots_adjust(left=0.25, bottom=0.25)

    axtime = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    frame_slider = Slider(
        ax=axtime,
        label='FRAME',
        valmin=1,
        valmax=data.shape[2],
        valinit=T0,
        valstep=1,
    )

    def update(val):
        frame = int(frame_slider.val) - 1
        im.set_data(data[:, :, frame].numpy().transpose())
        timeline.set_segments([np.array([[frame, 0], [frame, timeline_top]])])
        fig.canvas.draw_idle()

    frame_slider.on_changed(update)
    # Matplotlib widgets must stay referenced or they stop responding once this
    # function returns; pin the slider to the figure.
    fig.frame_slider = frame_slider

    return fig, axs



In [None]:
# %matplotlib widget
# fig, axs = data_to_smax(data,y_grid,R=0.1)

In [None]:
# for k in range(9):
#     split = "split_20231212"
#     experiment = "FWF00"+str(k+1)
#     cell_type = "Proximal"
#     N = 512
#     data_dir = split+"_"+experiment+"_smooth/"+cell_type+"/data/"
#     data = open_data(data_dir,N=N)
#     fig, axs = data_to_smax(data,y_grid,R=0.1,cmapcolor="magenta",T0=data.shape[2])
#     fig.suptitle(experiment + " (" + cell_type + ")")

In [None]:
# for k in range(9):
#     split = "split_20231212"
#     experiment = "FWF00"+str(k+1)
#     cell_type = "Distal"
#     N = 512
#     data_dir = split+"_"+experiment+"_smooth/"+cell_type+"/data/"
#     data = open_data(data_dir,N=N)
#     fig, axs = data_to_smax(data,y_grid,R=0.1,cmapcolor="green",T0=data.shape[2])
#     fig.suptitle(experiment + " (" + cell_type + ")")

In [None]:
def plot_all(split,experiment,cell_type,N=512,grid=None,radius=None):
    """Draw a 3x3 overview for experiments ``experiment + "1"`` ... ``experiment + "9"``.

    Each subfigure shows one experiment. On the left is its last frame with the
    peak pixel and radius disc from ``smooth_max``. On the right is the smoothed
    maximum over time. Data is read by ``open_data`` from
    ``<split>_<experiment><k>_smooth/<cell_type>/data/``. An experiment that fails
    to load or plot is reported with a printed message (including the error),
    and its panel is left empty.

    Args:
        split: dataset split prefix, e.g. "split_20231212".
        experiment: experiment prefix to which "1".."9" is appended, e.g. "FWF00".
        cell_type: "Proximal" (magenta colormap) or "Distal" (green colormap).
        N: side length of the square frames.
        grid: optional (N*N, 2) sampling grid; defaults to ``sample_grid(N)`` so it
            always matches ``N``.
        radius: neighbourhood radius; defaults to the notebook-level ``R``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``cell_type`` is neither "Proximal" nor "Distal".
    """
    colors = {"Proximal": "magenta", "Distal": "green"}
    if cell_type not in colors:
        # Previously this left cmapcolor unbound and every panel failed silently.
        raise ValueError(
            "Unknown cell_type " + repr(cell_type) + "; expected 'Proximal' or 'Distal'"
        )
    cmapcolor = colors[cell_type]
    if grid is None:
        # Build the grid for this N; the global y_grid only matches the notebook-level N.
        grid = sample_grid(N)
    if radius is None:
        radius = R
    theta = 2 * math.pi * np.linspace(0, 1, 100)

    fig = plt.figure(figsize=(28,14))
    subfigs = fig.subfigures(3, 3, wspace=0.07)
    for k in range(9):
        row, col = divmod(k, 3)
        subfig = subfigs[row, col]
        experiment_number = experiment + str(k + 1)
        data_dir = split + "_" + experiment_number + "_smooth/" + cell_type + "/data/"
        try:
            data = open_data(data_dir, N=N)
            index, smax = smooth_max(data, grid, radius)
            yx = grid[index, 0].item()
            yy = grid[index, 1].item()
            axs = subfig.subplots(1, 2)
            axs[0].imshow(
                data[:, :, -1].numpy().transpose(),
                origin='lower',
                cmap=cmap(cmapcolor),
                vmin=0.0,
                vmax=data[:, :, -1].max().item(),
            )
            axs[0].invert_yaxis()
            axs[0].set_aspect("equal")
            axs[0].scatter([yx * N], [yy * N], s=15, c='red')
            axs[0].plot(N * (yx + radius * np.cos(theta)), N * (yy + radius * np.sin(theta)), c='r')
            axs[1].plot(smax)
            subfig.suptitle(experiment_number + " (" + cell_type + ")")
        except Exception as exc:
            # Best-effort: keep drawing the other panels, but say what went wrong.
            print("There is a problem with " + experiment_number + ": " + repr(exc))
    return fig

In [None]:
%matplotlib inline
# Build and save the 3x3 overview figures (experiments FWF001-FWF009) for both cell types.
fig_prox = plot_all("split_20231212","FWF00","Proximal")
fig_dist = plot_all("split_20231212","FWF00","Distal")

fig_prox.savefig("FWF_Proximal_MaxPeak.png")
fig_dist.savefig("FWF_Distal_MaxPeak.png")

In [None]:
%matplotlib inline
# Build and save the 3x3 overview figures (experiments NCF001-NCF009) for both cell types.
fig_prox = plot_all("split_20231212","NCF00","Proximal")
fig_dist = plot_all("split_20231212","NCF00","Distal")

fig_prox.savefig("NCF_Proximal_MaxPeak.png")
fig_dist.savefig("NCF_Distal_MaxPeak.png")

In [None]:
%matplotlib inline
# Build and save the 3x3 overview figures (experiments RAF001-RAF009) for both cell types.
fig_prox = plot_all("split_20231212","RAF00","Proximal")
fig_dist = plot_all("split_20231212","RAF00","Distal")

fig_prox.savefig("RAF_Proximal_MaxPeak.png")
fig_dist.savefig("RAF_Distal_MaxPeak.png")