In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import dask.array as da

import ipywidgets as widgets
from IPython.display import display, clear_output

import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # first GPU if available, otherwise fall back to the CPU

In [None]:
# Open the HDF5 container. Read-only mode matters here: opening it in another mode may fail with WinError 33.
h5_location = r'C:\Users\formanj\GitHub\FISH_Processing\Demos\DUSP1_Dex_0min_20220224\DUSP1_Dex_0min_20220224.h5'
h5_file = h5py.File(h5_location, 'r')

# Deepest nesting level to list (0 = only the top-level entries of the file).
level = 1

def print_group(name, obj):
    """``visititems`` callback: print `name`, indented by its depth, when the depth is <= `level`."""
    # h5py passes the '/'-separated path relative to the visited group, so the slash count is the depth.
    depth = name.count('/')
    if depth > level:
        return
    indent = "    " * depth
    print(indent + "- " + name)

h5_file.visititems(print_group)


In [None]:
# Select the spot DataFrame: the per-spot results table stored inside the same HDF5 file under this key.
spots = pd.read_hdf(path_or_buf=h5_location, key='Analysis_demo_2024-11-21/df_spotresults')
spots

In [None]:
# load in images
d = h5_file['raw_images']
# axis order: p (position), t (timepoint), c (channel), z, y, x
# One dask chunk per (position, timepoint): each chunk holds every channel's full z-stack.
images = da.from_array(d, (1,1,-1,-1,-1,-1))

# display each of the channels
def display_channel(channel, position=0, timepoint=0, percentile=99.99):
    """Show the max projection over z of one channel, with an automatic contrast stretch.

    channel    : index along the channel axis of the global `images` array.
    position   : position index to show (default 0, as previously hard-coded).
    timepoint  : timepoint index to show (default 0, as previously hard-coded).
    percentile : upper display limit, as a percentile of the z-stack intensities.
    """
    # Materialise the (z, y, x) stack once. The projection and the contrast limit both reuse it
    # instead of reading it twice through the lazy dask graph, and the NumPy functions below
    # receive a concrete array.
    stack = np.asarray(images[position, timepoint, channel, :, :, :])
    fig, ax = plt.subplots()
    im = ax.imshow(stack.max(axis=0))
    # increase the contrast of the image automatically
    im.set_clim(0, np.percentile(stack, percentile))
    ax.set_title(f'Channel {channel} (position {position}, timepoint {timepoint}) - max projection over z')
    plt.show()

for i in range(images.shape[2]):
    display_channel(i)

In [None]:
# build gui to classify spots
def classify_spots(image, spot_df, fish_channel, timepoint=0, crop_half_width=15, view_half_width=100):
    """Interactively label detected spots as good (1) or bad (0) using ipywidgets buttons.

    Labels are written in place into a ``good_spots`` column of `spot_df`. The column is
    (re)created and filled with NaN on every call, so re-running discards earlier labels.
    Spots are reviewed one position at a time; the spots of a position are the rows with
    ``fov == position``.

    Parameters
    ----------
    image : array indexed (position, timepoint, channel, z, y, x); NumPy or lazy (e.g. dask).
    spot_df : DataFrame with at least ``fov``, ``y_px`` and ``x_px`` columns. Modified in place.
    fish_channel : channel index to display.
    timepoint : timepoint index to display (default 0, the previously hard-coded value).
    crop_half_width : half side length in pixels of the "Spot Crop" panel (default 15).
    view_half_width : half side length in pixels of the zoomed region around the spot (default 100).
    """
    # 1 = good spot, 0 = bad spot, NaN = not reviewed yet.
    spot_df['good_spots'] = np.nan
    label_col = spot_df.columns.get_loc('good_spots')
    fov_values = spot_df['fov'].to_numpy()
    n_positions = image.shape[0]

    # Viewer state shared by the button callbacks.
    pos_index = 0
    idx = 0              # index of the current spot within `rows`
    rows = None          # integer row positions (for .iloc) of spot_df that belong to pos_index
    max_proj_img = None
    full_vmax = None

    # Buttons are created once and stay displayed. Plots and messages go into an Output widget,
    # because output printed from a button callback is otherwise not reliably shown in the cell.
    output = widgets.Output()
    good_button = widgets.Button(description="Good Spot")
    bad_button = widgets.Button(description="Bad Spot")
    next_button = widgets.Button(description="Next Image")
    finish_button = widgets.Button(description="Finish")
    buttons = [good_button, bad_button, next_button, finish_button]

    def load_position(position):
        """Switch to `position`: select its spots and compute its max projection once."""
        nonlocal pos_index, idx, rows, max_proj_img, full_vmax
        pos_index = position
        idx = 0
        # Positional row indices, so labelling does not depend on spot_df's index labels.
        rows = np.flatnonzero(fov_values == position)
        # Max project over z (axis 0 of the (z, y, x) stack). Materialise it now so every
        # redraw reuses the same array instead of recomputing a lazy projection.
        max_proj_img = np.asarray(np.max(image[position, timepoint, fish_channel, :, :, :], axis=0))
        full_vmax = np.percentile(max_proj_img, 99.99)  # TODO: Use big fish stretch function

    def update_display():
        """Redraw the current spot, or a status message when there is no spot left to show."""
        output.clear_output(wait=True)
        with output:
            if idx >= len(rows):
                if len(rows) == 0:
                    print(f"No spots found for position {pos_index}")
                else:
                    print(f"All {len(rows)} spots of position {pos_index} classified; press Next Image or Finish")
                return

            spot = spot_df.iloc[rows[idx]]
            y, x = int(spot['y_px']), int(spot['x_px'])
            n_img_rows, n_img_cols = max_proj_img.shape

            # Clamp slice starts at 0: a negative start would wrap around and give a wrong/empty crop.
            c = crop_half_width
            crop = max_proj_img[max(y - c, 0):y + c, max(x - c, 0):x + c]

            fig, ax = plt.subplots(1, 2, figsize=(10, 5))

            im = ax[0].imshow(max_proj_img)
            v = view_half_width
            ax[0].set_xlim(max(x - v, 0), min(x + v, n_img_cols))
            # imshow draws row 0 at the top. Passing (bottom, top) = (larger, smaller) keeps that
            # orientation; an increasing ylim would flip the zoomed view upside down.
            ax[0].set_ylim(min(y + v, n_img_rows), max(y - v, 0))
            ax[0].scatter([x], [y], c='r')
            im.set_clim(0, full_vmax)
            ax[0].set_title('Max Projected Image')

            im = ax[1].imshow(crop)
            if crop.size:  # percentile of an empty array would raise
                im.set_clim(0, np.percentile(crop, 99.99))
            ax[1].set_title(f'Spot Crop ({idx + 1}/{len(rows)}, position {pos_index})')

            plt.show()

    def label_current(value):
        """Store `value` for the current spot and advance to the next one."""
        nonlocal idx
        if idx < len(rows):
            spot_df.iloc[rows[idx], label_col] = value
            idx += 1
        update_display()

    def next_image(_):
        """Move to the next position, if there is one."""
        if pos_index + 1 >= n_positions:
            with output:
                print(f"Position {pos_index} is the last position")
            return
        load_position(pos_index + 1)
        update_display()

    def finish(_):
        """Disable all buttons so no further labels are recorded."""
        for button in buttons:
            button.disabled = True
        with output:
            print("Finished classifying spots")

    good_button.on_click(lambda _: label_current(1))
    bad_button.on_click(lambda _: label_current(0))
    next_button.on_click(next_image)
    finish_button.on_click(finish)

    load_position(0)
    display(widgets.HBox(buttons), output)
    update_display()

# Example usage
classify_spots(images, spots, fish_channel=0)




In [None]:
# Inspect the labelled spot table: a preview of the rows, plus how many spots were marked
# good (1), bad (0) or left unlabelled (NaN) by the classification GUI.
display(spots.head())
spots['good_spots'].value_counts(dropna=False)

In [None]:
# build binary classifier to classify spots as good or bad
input_width = 36  # side length (px) of the square crops fed to the network

# Seed the global torch RNG so the random weight initialisation below is reproducible across kernel restarts.
torch.manual_seed(0)

model = nn.Sequential(
    nn.Conv2d(1, 1, kernel_size=3, padding=1),  # 3x3 kernel with padding=1 keeps H x W unchanged
    nn.Sigmoid(),
    nn.Conv2d(1, 1, kernel_size=3, padding=1),
    nn.Sigmoid(),
    nn.Flatten(),                   # (N, 1, W, W) -> (N, W*W)
    nn.Linear(input_width**2, 1),   # so inputs must be exactly input_width x input_width
    nn.Sigmoid(),                   # output in (0, 1): predicted probability that a spot is good (label 1)
)

# Binary cross-entropy on probabilities; prediction and target must have the same shape.
# NOTE(review): nn.BCEWithLogitsLoss (with the final Sigmoid removed) is the numerically more
# stable formulation recommended by the PyTorch docs; inference code would then need a sigmoid.
loss = nn.BCELoss()


In [None]:
# build the training set: one max-projected crop per manually labelled spot
def extract_images(image, xyz, width):
    """Return a width x width crop, max-projected over z, centred on a spot.

    image : 3-D array indexed (z, y, x); NumPy or a lazy array such as dask.
    xyz   : spot coordinates in (z, y, x) order, in pixels. z is unused because the crop is
            max-projected over every z slice. y and x may be floats; they are truncated with int().
    width : side length of the returned crop in pixels.

    Returns a NumPy array of shape (width, width). Parts of the window outside the image are
    zero-filled, so crops of spots near the border have the same shape as interior ones
    (the classifier's Linear layer needs a fixed input size).
    """
    _, y, x = xyz
    y, x = int(y), int(x)
    y0 = y - width // 2
    x0 = x - width // 2
    # Clip the window to the image. Slicing with a negative start would wrap around instead.
    y_lo, y_hi = max(y0, 0), min(y0 + width, image.shape[-2])
    x_lo, x_hi = max(x0, 0), min(x0 + width, image.shape[-1])
    if y_hi <= y_lo or x_hi <= x_lo:
        # Window lies completely outside the image.
        return np.zeros((width, width), dtype=image.dtype)
    mip = np.asarray(np.max(image[:, y_lo:y_hi, x_lo:x_hi], axis=0))
    crop = np.zeros((width, width), dtype=mip.dtype)
    crop[y_lo - y0:y_hi - y0, x_lo - x0:x_hi - x0] = mip
    return crop

dataset = []
for _, row in spots.iterrows():
    # Only spots that were labelled in the GUI (1 = good, 0 = bad) are used; NaN rows are skipped.
    if row['good_spots'] == 1 or row['good_spots'] == 0:
        # iterrows yields each row as a single Series, so integer columns can come back as floats;
        # cast before using them as array indices. Select exactly one channel (previously
        # `row['FISH_Channel']:` sliced from that channel onwards), giving a (z, y, x) stack.
        stack = images[int(row['fov']), int(row['timepoint']), int(row['FISH_Channel']), :, :, :]
        crop = extract_images(stack, (row['z_px'], int(row['y_px']), int(row['x_px'])), input_width)
        # Materialise lazy (dask) crops so training works on plain NumPy arrays.
        dataset.append(np.asarray(crop))
        
        



In [None]:
n_epochs = 10

# Labels aligned with `dataset`. The extraction cell appends one crop per row whose good_spots is
# 0 or 1, iterating `spots` in row order, so this mask yields the labels in the same order.
# (Indexing spots.iloc[i] directly would pair crop i with the i-th row overall, including unlabelled rows.)
labels = spots.loc[spots['good_spots'].isin([0, 1]), 'good_spots'].to_numpy(dtype=np.float32)
assert len(labels) == len(dataset), "dataset and labels are out of sync; re-run the extraction cell"

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
model.train()

# NOTE(review): crops are fed as raw intensities; if these are large, the first Sigmoid may
# saturate. Consider normalising the crops (and applying the same step at inference).
for epoch in range(n_epochs):
    for i, (img, label) in enumerate(zip(dataset, labels)):
        # (H, W) crop -> (batch=1, channel=1, H, W) float tensor on the same device as the model.
        x = torch.as_tensor(np.asarray(img, dtype=np.float32), device=device).unsqueeze(0).unsqueeze(0)
        # nn.BCELoss needs the target to have the same (1, 1) shape as the prediction.
        y = torch.tensor([[label]], dtype=torch.float32, device=device)

        y_pred = model(x)
        l = loss(y_pred, y)

        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch {epoch}, Image {i}, Loss: {l.item()}")




In [None]:
# TODO: save the location of the training data
