In [47]:
import pycromanager
from pycromanager import Core
import numpy as np
import matplotlib.pyplot as plt


# Connection to the Micro-Manager core through pycromanager; get_data() below
# pulls frames through this object.
# NOTE(review): presumably requires a running Micro-Manager instance with its
# server enabled — confirm for your setup.
core = Core()
# One-time setup helper kept for reference (presumably downloads/installs
# Micro-Manager into the given folder); uncomment only when needed.
#from mmpycorex import download_and_install_mm
#download_and_install_mm("C:\\Program Files\\Micro-Manager-pycro")

In [48]:
# Leftover scratch check: "\\" in a normal string literal prints a single backslash.
print("test\\")

test\


In [49]:
def get_data():
    tagged_image = core.get_tagged_image()
    # get the pixels in numpy array and reshape it according to its height and width
    image_array = np.reshape(
        tagged_image.pix,
        newshape=[-1, tagged_image.tags["Height"], tagged_image.tags["Width"]],
    )
    # for display, we can scale the image into the range of 0~255
    image_array = (image_array / image_array.max() * 255).astype("uint8")
    # return the first channel if multiple exists
    return image_array[0, :, :]

In [50]:
import torch

from torch_utils.transform import NormalizeIntensityTrace 
from torchvision import transforms
from torch_utils.transform import BackgroundRemovalNormalize, SkipFrames

def inference(model: torch.nn.Module, input: torch.Tensor, apply_transforms: bool = False) -> torch.Tensor:
    """Run ``model`` on one stack of frames and return the squeezed prediction.

    Args:
        model: segmentation network; invoked as ``model.forward(x, inference=True)``.
        input: frames stacked along dim 0, e.g. (frames, H, W) as assembled by
            ``OnTheFlySegmentation``.
        apply_transforms: if True, apply ``NormalizeIntensityTrace()`` and then
            ``SkipFrames(skip=3)`` to ``input`` before it reaches the model.

    Returns:
        The model output with dim 0 squeezed twice (each squeeze only removes
        a size-1 dimension).
    """
    clip = input
    if apply_transforms:
        # Applied in sequence, exactly like a transforms.Compose([...]) pipeline.
        for step in (NormalizeIntensityTrace(), SkipFrames(skip=3)):
            clip = step(clip)

    with torch.no_grad():
        # Model expects (batch, frames, channels, H, W): insert batch at dim 0
        # and a single channel at dim 2.
        batched = torch.unsqueeze(torch.unsqueeze(clip, 0), 2)
        # NOTE(review): calling .forward directly skips nn.Module hooks, and the
        # model is never put in eval() mode here — confirm both are intended.
        prediction = model.forward(batched, inference=True)
        return torch.squeeze(torch.squeeze(prediction, 0), 0)

In [56]:

class OnTheFlySegmentation:
    """Accumulate frames one at a time and segment them as a single batch.

    Frames are pushed with :meth:`add_image`. Once at least ``n_frames`` are
    buffered, :meth:`inference` stacks the first ``n_frames`` into an
    (n_frames, *frame_shape) tensor, clears the buffer, and runs the
    module-level ``inference`` helper with transforms enabled.
    """

    def __init__(self, model, frames: int):
        self.model = model
        self.n_frames = frames
        # Each entry is an independent float32 tensor copy of one frame.
        self.frame_buffer: list[torch.Tensor] = []

    def set_n_frames(self, n: int):
        """Change how many frames the next :meth:`inference` call consumes."""
        self.n_frames = n

    def clear_buffer(self):
        """Drop all buffered frames."""
        # Rebind rather than .clear() so a list handed out earlier is left untouched.
        self.frame_buffer = []

    def add_image(self, image: np.ndarray):
        """Buffer a float32 copy of ``image``.

        ``torch.from_numpy`` on its own shares memory with the caller's array,
        so a producer that reuses its buffer would silently overwrite frames
        that are already queued. Copying to float32 here also avoids integer
        dtypes (e.g. uint16) that some torch versions cannot wrap. The frames
        end up as floating point for the model anyway.
        """
        self.frame_buffer.append(torch.from_numpy(np.array(image, dtype=np.float32)))

    def inference(self) -> np.ndarray:
        """Segment the first ``n_frames`` buffered frames and empty the buffer.

        Frames beyond ``n_frames`` are discarded together with the buffer.

        Returns:
            The model prediction as a NumPy array.

        Raises:
            ValueError: if ``n_frames`` is not positive, or fewer than
                ``n_frames`` frames are buffered.
        """
        if self.n_frames <= 0:
            raise ValueError(f"n_frames must be positive, got {self.n_frames}.")
        if len(self.frame_buffer) < self.n_frames:
            raise ValueError("Number of frames in buffer smaller than requested amount.")

        # Shape (n_frames, *frame_shape); all frames must share one shape.
        input_frames = torch.stack(self.frame_buffer[: self.n_frames]).to(torch.get_default_dtype())
        self.clear_buffer()

        # Resolves to the module-level inference() helper, not this method.
        return inference(self.model, input_frames, apply_transforms=True).numpy()
        


In [52]:
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from PIL import Image, ImageTk
def apply_color_map(data, cmap_name: str='viridis', vmin=None, vmax=None):
    """Convert a scalar field into an 8-bit RGB image using a matplotlib colormap.

    Args:
        data: array of scalars of any shape (typically a 2-D segmentation map).
        cmap_name: name of a registered matplotlib colormap.
        vmin, vmax: normalization limits; ``None`` lets ``Normalize`` take
            them from the data.

    Returns:
        uint8 array of shape ``data.shape + (3,)``.
    """
    cmap = plt.get_cmap(cmap_name)
    norm = Normalize(vmin=vmin, vmax=vmax)

    rgba = np.asarray(cmap(norm(data)))
    # Drop alpha along the last axis; `...` keeps this valid for any input rank
    # (a fixed [:, :, :3] slice only works for 2-D data).
    rgb = rgba[..., :3]

    # Colormap output lies in [0, 1]; astype truncates toward zero.
    return (255 * rgb).astype(np.uint8)


In [62]:
from models import *

import os

model_class = PLSegmentationModel
# Build the path portably; a hard-coded "\\" separator only works on Windows.
model_path = os.path.join('saved_models', 'default_model_v2.model')

model = model_class.load(model_path)

Model loaded from: 'saved_models\default_model_v2.model'


In [63]:
## Synthetic dataset
from models.psf import GuassionPSF
from simulation.grain_PL_simulation import TrainingDataSimulationOptions
from torch_utils.dataset import GeneratedPLOutlineDataset
from torch.utils.data import Dataset

from torchvision import transforms

from torch_utils.transform import BackgroundRemovalNormalize, SkipFrames


## 1 Pixel is 200 nm
def get_training_data(length: int = 20) -> Dataset:
    """Build a synthetic photoluminescence video dataset with ``length`` samples.

    Uses a ``GuassionPSF(2.5)`` point-spread function (presumably sigma in
    pixels, where 1 pixel is 200 nm) and a 128x128 simulation grid. Each
    sample has sample_rate * seconds = 110 frames.

    Args:
        length: number of samples the returned dataset reports.

    Returns:
        A ``GeneratedPLOutlineDataset`` configured with the options below.
    """
    psf = GuassionPSF(2.5)

    # Grid downsampling factor; grain counts shrink with grid area (factor**2)
    # and are additionally halved.
    factor = 2
    options = TrainingDataSimulationOptions(
        grid_size=256 // factor,
        min_grains=3000 // (2 * factor * factor),
        max_grains=3200 // (2 * factor * factor),
        min_noise=0.05,
        max_noise=0.12,
        sample_rate=10,
        seconds=11,  # frames per sample = sample_rate * seconds
        min_blinker_transition=0.04,
        max_blinker_transition=0.1,
        min_base_counts=6000,
        max_base_counts=12000,
        min_hole_chance=0.01,
        max_hole_chance=0.1,
        min_boundary_dimish=0,
        max_boundary_dimish=1.0,
        min_blinker_strength=0.005,
        max_blinker_strength=0.08,
        min_blinkers_average=50,
        max_blinkers_average=80,
        psf=psf,
        label_scaling=2,
    )

    # Forward the caller's `length`; a hard-coded value here would silently ignore it.
    return GeneratedPLOutlineDataset(length=length, sim_options=options)

# Number of frames fed to the segmenter (each simulated sample has 110).
N_TEST_FRAMES = 100

dataset = get_training_data(20)
video, label = dataset[0]

# Feed frames one at a time, the same way the live acquisition path does.
segmentation = OnTheFlySegmentation(model, N_TEST_FRAMES)
for frame in video[:N_TEST_FRAMES]:
    segmentation.add_image(frame.numpy())

out = segmentation.inference()
mapped = apply_color_map(out)

# A uint8 (H, W, 3) array is inferred as RGB; Pillow deprecates the `mode` argument here.
im = Image.fromarray(mapped)
im.save('testing.png')

mapped.dtype

dtype('uint8')

In [None]:
import tkinter as tk
import numpy as np
from PIL import Image, ImageTk
from pycromanager import Acquisition
from pycromanager import multi_d_acquisition_events
import time

# Create the main window
root = tk.Tk()
root.geometry("800x600")
root.title("Test On the Fly segmentation")

# Toolbar strip across the top of the window; buttons are packed into it later.
button_frame = tk.Frame(root, bg="#f0f0f0")
button_frame.pack(fill=tk.X, side=tk.TOP, pady=10, padx=10)

# Most recent segmentation result; stays None until the first acquisition.
current_image: "Image.Image | None" = None

# Function to generate and display a random image
def display_random_image():
    """Show a 400x300 image of uniform random RGB noise in ``image_label``."""
    width, height = 400, 300
    # randint's upper bound is exclusive: 256 lets full intensity (255) occur.
    noise = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)

    photo = ImageTk.PhotoImage(Image.fromarray(noise))
    image_label.config(image=photo)
    # Keep a Python reference so the PhotoImage is not garbage-collected.
    image_label.image = photo


# Button stubs
def open_image():
    """Placeholder handler for the "Open Image" button; only logs the click."""
    message = "Open image button clicked"
    print(message)

def save_image():
    """Write the latest segmentation image to ``saved_images/<unix-seconds>_segmentation.png``.

    Does nothing when no segmentation has been produced yet. The output
    directory is created on first use. Two saves within the same second
    overwrite each other because the name has one-second resolution.
    """
    import os

    if current_image is None:
        return

    # Without this, the first save raises FileNotFoundError on a fresh checkout.
    os.makedirs('saved_images', exist_ok=True)

    curr_time = int(time.time())
    file_name = f'saved_images/{curr_time}_segmentation.png'

    current_image.save(file_name)

def generate_random_image():
    """Handler for the "Generate Random" button; currently a no-op stub.

    The random-image logic itself lives in ``display_random_image``.
    """

def generate_inference():
    """Acquire frames from Micro-Manager, segment them, and show the result.

    Reads the frame count from ``num_frames_entry`` and runs a time-lapse
    acquisition with that many time points. Every frame goes into an
    ``OnTheFlySegmentation``. The prediction is colour-mapped, upscaled to
    800x800 and shown in ``image_label``. The PIL image is stored in the
    global ``current_image`` for the Save button.

    Invalid frame counts and incomplete acquisitions are reported on stdout
    instead of raising out of the Tk callback.
    """
    global current_image

    raw_frames = num_frames_entry.get()
    try:
        n_frames = int(raw_frames)
    except ValueError:
        print(f'Invalid number of frames: {raw_frames!r}')
        return
    if n_frames <= 0:
        print(f'Number of frames must be positive, got {n_frames}')
        return

    segmentation = OnTheFlySegmentation(model, n_frames)

    def add_image_func(im_data, metadata, *args):
        # Buffer the frame, then hand it back to the acquisition unchanged.
        segmentation.add_image(im_data)
        return (im_data, metadata)

    print('Starting acquisition')
    # NOTE: this runs inside the Tk button callback, so the window is
    # unresponsive until acquisition and inference have finished.
    with Acquisition(directory=None, name=None, image_process_fn=add_image_func, show_display=False) as acq:
        events = multi_d_acquisition_events(num_time_points=n_frames)
        acq.acquire(events)
    print('Finished acquisition')

    print('Starting inference')
    try:
        result = segmentation.inference()
    except ValueError as err:
        # Raised when fewer frames were buffered than requested (e.g. if the acquisition stopped early).
        print(f'Inference skipped: {err}')
        return
    print('Finished inference')

    mapped_image = apply_color_map(result)
    print(f'Max: {np.max(result)}')

    print(f'Output image size: {mapped_image.shape}')

    # A uint8 (H, W, 3) array is inferred as RGB. Nearest-neighbour resizing
    # keeps class boundaries sharp instead of blending colours.
    pil_img = Image.fromarray(mapped_image).resize((800, 800), Image.Resampling.NEAREST)
    current_image = pil_img

    photo = ImageTk.PhotoImage(pil_img)
    image_label.config(image=photo)
    # Keep a Python reference so the PhotoImage is not garbage-collected.
    image_label.image = photo

    print('\n\n')
# Add buttons to the button bar
# Every widget below is pack()ed with side=tk.LEFT into button_frame, so the
# statement order here is the left-to-right order on screen.
open_button = tk.Button(button_frame, text="Open Image", command=open_image)
open_button.pack(side=tk.LEFT, padx=5)

save_button = tk.Button(button_frame, text="Save Image", command=save_image)
save_button.pack(side=tk.LEFT, padx=5)

# NOTE(review): generate_random_image is currently a no-op stub.
generate_button = tk.Button(button_frame, text="Generate Random", command=generate_random_image)
generate_button.pack(side=tk.LEFT, padx=5)


# Frame-count input read by generate_inference(); pre-filled with "100".
entry_label = tk.Label(button_frame, text="Number of frames: ")
entry_label.pack(side=tk.LEFT, padx=5)

num_frames_entry = tk.Entry(button_frame)
num_frames_entry.pack(side=tk.LEFT, padx=2)
num_frames_entry.insert(0, "100")

# Triggers acquisition + segmentation via generate_inference().
acquire_button = tk.Button(button_frame, text="Acquire", command=generate_inference)
acquire_button.pack(side=tk.LEFT, padx=2)


# Create frame for image display
# Packed after the toolbar, so it sits below it and expands into the remaining space.
image_frame = tk.Frame(root, bg="#ffffff")
image_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=10)

# Label to display image
# display_random_image() and generate_inference() both swap a PhotoImage into this label.
image_label = tk.Label(image_frame, bg="#ffffff", text="Loading image...")
image_label.pack(fill=tk.BOTH, expand=True)

# Schedule the initial image display after the main loop starts
# (after 100 ms; the image is random noise used as a placeholder).
root.after(100, display_random_image)

# Start the main loop
# Blocks this cell until the window is closed.
root.mainloop()

Starting acquisition
Finished acquisition
Starting inference
Finished inference
Max: 0.9863914251327515
Output image size: (424, 424, 3)



Starting acquisition
Finished acquisition
Starting inference
Finished inference
Max: 0.967772901058197
Output image size: (424, 424, 3)



