In [None]:
import os
from timeit import default_timer as timer
from datetime import timedelta
import numpy as np
import nd2
import plotly.graph_objects as go
from io import BytesIO
import base64
from PIL import Image
import piscis
from sklearn.neighbors import NearestNeighbors

In [None]:
# Folder holding the ICOS .nd2 acquisitions. The path is relative, so it resolves
# against the kernel's current working directory.
icos_img_dir = "../tobias_ICOS"
os.path.exists(icos_img_dir)  # shown as the cell output: True when the path exists, so a bad path is noticed before loading

In [None]:
##### Read in .nd2 files with nd2 library
# Each file is read as (Z, C, Y, X) and reordered to (C, Z, Y, X), so indexing the
# first axis selects one channel's complete z-stack.
untr_path = os.path.join(icos_img_dir, "Myla_ICOS_slide1_well2_loc001.nd2")
promKD_path = os.path.join(icos_img_dir, "Myla_ICOS_slide1_well1_loc001.nd2")

loc1_untr = nd2.imread(untr_path).transpose(1, 0, 2, 3)      # untreated
loc1_promKD = nd2.imread(promKD_path).transpose(1, 0, 2, 3)  # promoter KD

In [None]:
# Sanity check of the array layout after the transpose; the tuple is shown as the cell output.
loc1_untr.shape

In [None]:
##### Select desired channel
# Position on the leading (channel) axis of the stack that is handed to spot detection.
dot_channel_idx = 1
loc_dot_channel_untr = loc1_untr[dot_channel_idx]
loc_dot_channel_promKD = loc1_promKD[dot_channel_idx]

In [None]:
# Report the dimensions of both single-channel stacks (untreated first, then promoter KD).
for channel_stack in (loc_dot_channel_untr, loc_dot_channel_promKD):
    print(channel_stack.shape)

In [None]:
##### Piscis
# Instantiate the Piscis model by name; this single object is reused for every
# prediction cell that follows.
model = piscis.Piscis(model_name='20230905')

In [None]:
### Untreated
# Run spot prediction on the untreated stack and report the wall-clock time it took.
start = timer()
spots_pred_untr = model.predict(loc_dot_channel_untr, threshold=1)
elapsed = timedelta(seconds=timer() - start)
print(f"{elapsed} elapsed")

In [None]:
### Promoter KD
# Run spot prediction on the promoter-KD stack and report the wall-clock time it took.
start = timer()
spots_pred_promKD = model.predict(loc_dot_channel_promKD, threshold=1)
elapsed = timedelta(seconds=timer() - start)
print(f"{elapsed} elapsed")

In [None]:
##### Simple comparative analysis
# One count per entry of each prediction result (one entry per z slice), then the
# mean of those counts for each condition.
numSpots_loc1_untr = [len(s) for s in spots_pred_untr]
numSpots_loc1_promKD = [len(s) for s in spots_pred_promKD]
# Labels follow the source files: the untreated stack was read from the "well2" file
# and the promoter-KD stack from the "well1" file.
print(f"Mean no. of spots in Untreated (well #2): {round(np.mean(numSpots_loc1_untr),1)}")
print(f"Mean no. of spots in Promoter KD (well #1): {round(np.mean(numSpots_loc1_promKD),1)}")

In [None]:
# Peek at the untreated predictions to see how the per-slice spot coordinates are laid out.
# NOTE(review): `.shape` assumes predict() returned an array-like rather than a plain
# Python list — TODO confirm against the Piscis return type.
spots = spots_pred_untr
print(f"spots shape: {spots.shape}")
print(f"spots[0]: {spots[0]}")

In [None]:
##### Interactive plotting for each Z plane independently
# Inputs for the viewer. Change both assignments together so the image stack and the
# spot predictions come from the same condition.
img = loc1_promKD[1] # channel index 1 of the promoter-KD stack
spots = spots_pred_promKD

def normalize_to_uint8(slice_2d):
    """Contrast-stretch an image to uint8 using its 1st-99th percentile range.

    Values at or below the 1st percentile map to 0, values at or above the
    99th percentile map to 255, and everything in between is scaled linearly
    (the fractional part is truncated by the integer cast).

    Parameters
    ----------
    slice_2d : array-like
        Image intensities for one slice.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as ``slice_2d``. When the percentile
        range is zero or undefined (e.g. a constant slice) there is no
        contrast to stretch and an all-zero image is returned.
    """
    p_min, p_max = np.percentile(slice_2d, (1, 99))
    span = p_max - p_min
    if not span > 0:
        # Guard the division below: a flat slice would otherwise give 0/0 = NaN,
        # and casting NaN to uint8 is undefined. `not span > 0` also catches NaN.
        return np.zeros(np.shape(slice_2d), dtype=np.uint8)
    norm = np.clip((slice_2d - p_min) / span, 0, 1)
    return (norm * 255).astype(np.uint8)

Z = img.shape[0]  # number of z slices in the stack being viewed


def _slice_traces(z):
    """Return the [image heatmap, spot scatter] trace pair for z slice ``z``."""
    coords_z = spots[z]  # spot coordinates of this slice: column 0 is y, column 1 is x
    return [
        go.Heatmap(  # the slice itself, contrast-stretched to [0, 255]
            z=normalize_to_uint8(img[z]),
            colorscale='gray',
            showscale=False,
        ),
        go.Scatter(  # detected spots drawn on top
            x=coords_z[:, 1],
            y=coords_z[:, 0],
            mode='markers',
            marker=dict(color='red', size=5),
            name='Spots',
        ),
    ]


# One animation frame per z slice, named by its 0-based index.
frames = [go.Frame(data=_slice_traces(z), name=str(z)) for z in range(Z)]

# The figure initially shows slice 0; the frames supply the remaining slices.
fig = go.Figure(data=_slice_traces(0), frames=frames)

# Slider with one step per slice; each step jumps straight to the frame of that name.
slider_steps = [
    {"method": "animate", "args": [[str(z)], {"mode": "immediate"}], "label": f"Z={z+1}"}
    for z in range(Z)
]
fig.update_layout(
    sliders=[{"steps": slider_steps, "currentvalue": {"prefix": "Slice: "}}],
    height=700,
    width=700,
    title="Z-stack Spot Viewer",
)

# Flip the y axis so row 0 is at the top, as in an image.
fig.update_yaxes(autorange="reversed")
fig.show()


In [None]:
##### Function to remove redundant spots in successive z slices
def dedup_spots(spots, radius=0.5):
    """Collapse detections of the same spot that recur across z slices.

    All per-slice detections are pooled into one coordinate array and then
    thinned greedily: walking the pooled array in order, a spot is kept unless
    it lies within ``radius`` of a spot that was already kept.

    Note that the slice index is not part of the distance: two detections that
    are close in the image plane are merged no matter how many slices apart
    they were found.

    Parameters
    ----------
    spots : sequence of array-like
        One coordinate array per z slice, each of shape (n_spots, n_coords).
        Slices without detections may be empty.
    radius : float, optional
        Maximum distance, in coordinate units, at which two detections are
        treated as the same spot. Defaults to 0.5.

    Returns
    -------
    numpy.ndarray
        The retained coordinates, one row per unique spot, in first-seen
        order. An empty (0, 2) array is returned when there are no detections.
    """
    # Local import keeps this cell self-contained; SciPy is not imported at the top of the notebook.
    from scipy.spatial import cKDTree

    # Drop empty slices so the concatenation cannot fail on shape mismatches.
    per_slice = [np.asarray(s) for s in spots]
    per_slice = [s for s in per_slice if s.size]
    if not per_slice:
        return np.empty((0, 2))

    pooled = np.concatenate(per_slice)

    # For every spot, the indices of all pooled spots within `radius` (itself included).
    neighborhoods = cKDTree(pooled).query_ball_point(pooled, r=radius)

    suppressed = np.zeros(len(pooled), dtype=bool)
    keep_indices = []
    for idx, nbrs in enumerate(neighborhoods):
        if suppressed[idx]:
            continue
        keep_indices.append(idx)
        # Everything near a kept spot is redundant, including the kept spot itself.
        suppressed[nbrs] = True

    return pooled[keep_indices]

In [None]:
# Pool the untreated detections from every z slice into a single coordinate array.
all_spots = np.concatenate(spots_pred_untr)

# Radius search over the pooled spots. The radius is 0.5 in coordinate units (tune this).
# Each spot's neighbour list contains its own index as well.
nn = NearestNeighbors(radius=0.5).fit(all_spots)
neighbors = nn.radius_neighbors(all_spots, return_distance=False)

In [None]:
# Greedy thinning: walk the spots in order and keep a spot only when no previously
# kept spot has already claimed it as a neighbour.
visited = set()
keep_indices = []

for i, neigh in enumerate(neighbors):
    if i not in visited:
        keep_indices.append(i)
        visited.update(neigh)  # neighbours of a kept spot (itself included) are now claimed

# M retained rows with the same columns as all_spots.
# NOTE(review): other cells read columns 0/1 as y/x, which suggests two columns — confirm.
deduped_spots = all_spots[keep_indices]

In [None]:
##### Max project z stack
loc1_untr_maxProj = np.max(loc1_untr[1], axis=0)
loc1_promKD_maxProj = np.max(loc1_promKD[1], axis=0)

In [None]:
# Display the pooled spot coordinates as the cell output.
all_spots

In [None]:
# Split the pooled coordinates into plotting axes: column 0 is y, column 1 is x.
y, x = all_spots[:, 0], all_spots[:, 1]

In [27]:

# Line segments linking every pair of neighbouring spots. Each segment is stored as
# (start, end, None); the None breaks the line so consecutive segments are not joined.
lines_x, lines_y = [], []
for i, nbrs in enumerate(neighbors):
    for j in nbrs:
        if j > i:  # each unordered pair is drawn once, and a spot is never linked to itself
            lines_x += [x[i], x[j], None]
            lines_y += [y[i], y[j], None]

# Normalize and convert an image slice to uint8
def normalize_to_uint8(slice_2d):
    """Contrast-stretch an image to uint8 using its 1st-99th percentile range.

    Values at or below the 1st percentile map to 0, values at or above the
    99th percentile map to 255, and everything in between is scaled linearly
    (the fractional part is truncated by the integer cast).

    Parameters
    ----------
    slice_2d : array-like
        Image intensities for one slice.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as ``slice_2d``. When the percentile
        range is zero or undefined (e.g. a constant slice) there is no
        contrast to stretch and an all-zero image is returned.
    """
    p_min, p_max = np.percentile(slice_2d, (1, 99))
    span = p_max - p_min
    if not span > 0:
        # Guard the division below: a flat slice would otherwise give 0/0 = NaN,
        # and casting NaN to uint8 is undefined. `not span > 0` also catches NaN.
        return np.zeros(np.shape(slice_2d), dtype=np.uint8)
    norm = np.clip((slice_2d - p_min) / span, 0, 1)
    return (norm * 255).astype(np.uint8)

img_slice = normalize_to_uint8(loc1_untr_maxProj)

# PNG-encode the projection in memory and base64 it so it can be embedded in the
# figure as a data URI.
pil_img = Image.fromarray(img_slice)
buffer = BytesIO()
pil_img.save(buffer, format="PNG")
encoded = base64.b64encode(buffer.getvalue()).decode()

# Links between neighbouring spots (first trace) with the spot markers after them.
neighbor_links = go.Scatter(
    x=lines_x,
    y=lines_y,
    mode='lines',
    line=dict(color='blue', width=1),
    name='Neighbor Links',
)
spot_markers = go.Scatter(
    x=x,
    y=y,
    mode='markers',
    marker=dict(color='rgba(255,0,0,1)', size=5),
    name='Spots',
)
fig = go.Figure(data=[neighbor_links, spot_markers])

# Projection image placed below the traces, anchored at data (0, 0) and sized
# to the image's own width and height in axis units.
img_height, img_width = img_slice.shape[:2]
background = dict(
    source=f'data:image/png;base64,{encoded}',
    xref="x", yref="y",
    x=0, y=0,
    sizex=img_width,
    sizey=img_height,
    sizing="stretch",
    opacity=1.0,
    layer="below",
)
fig.update_layout(
    images=[background],
    height=700,
    width=700,
    title="Max-Z projection with Neighbors",
)

# Flip the y axis so row 0 is at the top, then export a standalone HTML file.
fig.update_yaxes(autorange='reversed')
fig.write_html('../piscis/loc01_ICOS_untr_spotNeighborhoods.html')
