In [1]:
import pycromanager
from pycromanager import Core
import numpy as np
import matplotlib.pyplot as plt


# Micro-Manager core handle used by get_data() to snap frames.
# NOTE(review): assumes a Micro-Manager instance reachable by pycromanager is running.
core = Core()
# One-time setup (presumably installs Micro-Manager via mmpycorex); run manually if needed:
#from mmpycorex import download_and_install_mm
#download_and_install_mm("C:\\Program Files\\Micro-Manager-pycro")

In [2]:
def get_data(mmc=None):
    """Snap one image and return its first channel as an 8-bit 2-D array.

    Args:
        mmc: object exposing ``snap_image()`` and ``get_tagged_image()``;
            defaults to the notebook-level Micro-Manager ``core``.

    Returns:
        uint8 array of shape (Height, Width): the first channel, linearly scaled
        so the brightest pixel across all channels maps to 255. An all-zero
        frame is returned as zeros instead of the undefined uint8 cast of 0/0.
    """
    if mmc is None:
        mmc = core
    mmc.snap_image()
    tagged_image = mmc.get_tagged_image()
    height = tagged_image.tags["Height"]
    width = tagged_image.tags["Width"]
    # Shape is passed positionally: NumPy 2.1 deprecated the ``newshape=`` keyword.
    image_array = np.reshape(tagged_image.pix, (-1, height, width))

    peak = image_array.max()
    if peak == 0:
        # 0/0 would give NaN, whose conversion to uint8 is undefined.
        return np.zeros((height, width), dtype=np.uint8)
    # For display, scale into the range 0..255 relative to the brightest pixel.
    image_array = (image_array / peak * 255).astype("uint8")
    # Return the first channel if multiple exist.
    return image_array[0, :, :]

In [3]:
import torch

from torch_utils.transform import NormalizeIntensityTrace 
from torchvision import transforms
from torch_utils.transform import *

def inference(model: torch.nn.Module, input: torch.Tensor, apply_transforms: bool = False) -> torch.Tensor:
    """Run ``model`` on one video clip and return its squeezed prediction.

    Args:
        model: segmentation network whose ``forward`` accepts an ``inference``
            keyword argument.
        input: clip of stacked frames, (frames, height, width) as produced by
            stacking ``get_data()``-style 2-D images.
        apply_transforms: if True, preprocess the clip with
            ``SkipFrames(skip=3)`` followed by ``ZScoreNorm()`` first.

    Returns:
        The model output with dim 0 squeezed twice (batch, then the next dim
        when it is also of size 1).
    """
    clip = input
    if apply_transforms:
        preprocess = transforms.Compose([SkipFrames(skip=3), ZScoreNorm()])
        clip = preprocess(clip)

    with torch.no_grad():
        # (frames, H, W) -> (batch=1, frames, channels=1, H, W), the layout the model expects.
        batched = clip[None, :, None]
        prediction = model.forward(batched, inference=True)
        return prediction.squeeze(0).squeeze(0)

In [4]:

class OnTheFlySegmentation:
    """Accumulate frames one by one, then segment the first ``n_frames`` of them.

    ``inference`` consumes the buffer: after a successful call the buffer is
    emptied, including any frames beyond the first ``n_frames``.
    """

    def __init__(self, model, frames: int):
        self.model = model
        self.n_frames = frames
        # Holds torch tensors converted from the numpy frames passed to add_image.
        self.frame_buffer: list[torch.Tensor] = []

    def set_n_frames(self, n: int):
        """Change how many frames ``inference`` requires and uses."""
        self.n_frames = n

    def clear_buffer(self):
        """Drop every stored frame."""
        self.frame_buffer = []

    def add_image(self, image: np.ndarray):
        """Append one frame (stored as a tensor via ``torch.from_numpy``)."""
        self.frame_buffer.append(torch.from_numpy(image))

    def inference(self) -> np.ndarray:
        """Segment the first ``n_frames`` buffered frames, then clear the buffer.

        Raises:
            ValueError: if fewer than ``n_frames`` frames are buffered.
        """
        wanted = self.n_frames
        if len(self.frame_buffer) < wanted:
            raise ValueError("Number of frames in buffer smaller than requested amount.")

        batch = torch.zeros((wanted, *self.frame_buffer[0].shape))
        for slot, frame in zip(range(wanted), self.frame_buffer):
            batch[slot, :, :] = frame
        self.clear_buffer()

        return inference(self.model, batch, apply_transforms=True).numpy()
        


In [5]:
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from PIL import Image, ImageTk
def apply_color_map(data, cmap_name: str='viridis', vmin=None, vmax=None):
    """Map scalar data to an 8-bit RGB array through a matplotlib colormap.

    Args:
        data: array of scalar values of any shape (a 2-D image gives (H, W, 3)).
        cmap_name: name of a registered matplotlib colormap.
        vmin, vmax: normalization limits; ``None`` lets ``Normalize`` take
            them from ``data``.

    Returns:
        uint8 array of shape ``data.shape + (3,)``; the alpha channel is dropped.
    """
    cmap = plt.get_cmap(cmap_name)
    norm = Normalize(vmin=vmin, vmax=vmax)

    rgba = cmap(norm(data))
    # ``...`` keeps every leading axis, so inputs of any rank work, not only
    # the 2-D images the previous ``[:, :, :3]`` slice required.
    rgb = rgba[..., :3]

    return (255 * rgb).astype(np.uint8)


In [6]:
from models import *

import os

model_class = PLSegmentationModel
# os.path.join keeps the path portable; on Windows it produces the same
# 'saved_models\\default_model_v3.model' string as the previous hard-coded literal.
model_path = os.path.join('saved_models', 'default_model_v3.model')

model = model_class.load(model_path)

Model loaded from: 'saved_models\default_model_v3.model'


In [7]:
## Synthetic dataset
from models.psf import GuassionPSF
from simulation.grain_PL_simulation import TrainingDataSimulationOptions
from torch_utils.dataset import GeneratedPLOutlineDataset
from torch.utils.data import Dataset

from torchvision import transforms

from torch_utils.transform import BackgroundRemovalNormalize, SkipFrames


## 1 Pixel is 200 nm
def get_training_data(length: int = 20) -> Dataset:
    """Build a synthetic PL-outline dataset with ``length`` simulated samples.

    Each sample has ``sample_rate * seconds`` = 10 * 11 = 110 frames. The grid
    side is 256 px divided by ``factor``, and the grain counts are divided by
    ``2 * factor**2``.

    Args:
        length: number of samples in the dataset. Previously ignored; the
            dataset always had 20 samples.

    Returns:
        A ``GeneratedPLOutlineDataset``.
    """
    psf = GuassionPSF(2.5)

    # Downscaling factor applied to the grid size and grain counts below.
    factor = 2
    options = TrainingDataSimulationOptions(
        grid_size=256 // factor,
        min_grains=3000 // (2 * factor * factor),
        max_grains=3200 // (2 * factor * factor),
        min_noise=0.05,
        max_noise=0.12,
        sample_rate=10,
        seconds=11, ## Decides how many frames each test samples has: total frames = sample_rate * seconds
        min_blinker_transition=0.04,
        max_blinker_transition=0.1,
        min_base_counts=6000,
        max_base_counts=12000,
        min_hole_chance=0.01,
        max_hole_chance=0.1,
        min_boundary_dimish=0,
        max_boundary_dimish=1.0,
        min_blinker_strength=0.005,
        max_blinker_strength=0.08,
        min_blinkers_average=50,
        max_blinkers_average=80,
        psf=psf,
        label_scaling=2,
    )

    generated_dataset = GeneratedPLOutlineDataset(length=length,
                                                  sim_options=options)

    return generated_dataset

# Smoke test: segment one simulated clip with OnTheFlySegmentation and save
# the colour-mapped result to testing.png.
N_TEST_FRAMES = 100  # number of frames fed to the segmenter

dataset = get_training_data(20)
video, label = dataset[0]
# Each simulated sample has sample_rate * seconds frames (10 * 11 = 110 with the
# options in get_training_data); it must cover N_TEST_FRAMES.
assert video.shape[0] >= N_TEST_FRAMES, f"clip has only {video.shape[0]} frames"

segmentation = OnTheFlySegmentation(model, N_TEST_FRAMES)
for frame_idx in range(N_TEST_FRAMES):
    segmentation.add_image(video[frame_idx, :, :].numpy())

out = segmentation.inference()
mapped = apply_color_map(out)

im = Image.fromarray(mapped, mode='RGB')
im.save('testing.png')

mapped.dtype

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


dtype('uint8')

In [8]:
from collections import deque
import threading

class BufferedSegmentation:
    """Rolling buffer of the newest ``n`` frames, segmented on demand.

    Frames may be appended from one thread while ``inference`` runs on another,
    so every access to ``buffer`` (and the swap done by ``set_n_frames`` /
    ``clear_buffer``) happens under ``lock``.
    """

    def __init__(self, model, n: int, color_map = 'viridis'):
        self.n = n
        self.model = model

        # deque(maxlen=n) discards the oldest frame once n frames are held.
        self.buffer = deque(maxlen=n)

        self.color_map = color_map
        self.lock = threading.Lock()

    def set_n_frames(self, n: int):
        """Resize the buffer to ``n``, keeping the newest frames already collected."""
        with self.lock:
            self.buffer = deque(self.buffer, maxlen=n)
            self.n = n

    def clear_buffer(self):
        """Drop every buffered frame."""
        with self.lock:
            self.buffer = deque(maxlen=self.n)

    def add_image(self, image: np.ndarray):
        """Append one frame (converted with ``torch.from_numpy``)."""
        with self.lock:
            self.buffer.append(torch.from_numpy(image))

    def inference(self) -> np.ndarray:
        """Segment the buffered frames and return a colour-mapped uint8 RGB image.

        Raises:
            ValueError: if the buffer is empty.
        """
        # Snapshot under the lock: previously the tensor was sized outside the
        # lock, so a concurrent add_image could grow the buffer past it (IndexError).
        with self.lock:
            frames = list(self.buffer)
        if not frames:
            raise ValueError("No frames in buffer to segment.")

        # (frames, H, W) in torch's default float dtype, matching the previous
        # torch.zeros-based preallocation.
        input_frames = torch.stack(frames).to(torch.get_default_dtype())

        model_output = inference(self.model, input_frames, apply_transforms=True).numpy()

        return apply_color_map(model_output, self.color_map)

In [None]:
import tkinter as tk
import numpy as np

from PIL import Image, ImageTk
from pycromanager import Acquisition
from pycromanager import multi_d_acquisition_events
import time
import os

# Shared rolling buffer (newest 150 frames) fed and segmented by the main App;
# AcquisitionWindow creates its own separate BufferedSegmentation.
buffer_segment = BufferedSegmentation(model, 150)

class FramedComponent:
    """Base for widgets that live in their own ``tk.Frame``.

    The frame is packed immediately with ``pack_args``; ``hide`` and ``show``
    unpack it and re-pack it with those same options.
    """

    def __init__(self, root, **pack_args):
        self.root = root
        container = tk.Frame(root)
        container.pack(**pack_args)
        self.frame = container
        self.pack_args = pack_args

    def hide(self):
        """Remove the frame from the layout (its widgets are kept)."""
        self.frame.pack_forget()

    def show(self):
        """Put the frame back using the construction-time pack options."""
        self.frame.pack(**self.pack_args)

class ToggleButton:
    """A ``tk.Button`` that flips a boolean state and shows matching text.

    Args:
        root: parent widget.
        on_text: button text while the state is True.
        off_text: button text while the state is False.
        command: optional callback, called with the new state on every change
            (not for ``init_state``).
        init_state: starting state.
        **pack_args: options for ``pack``; reused by ``show``.
    """

    def __init__(self, root, on_text:str, off_text: str, command=None, init_state: bool=False, **pack_args):
        self.state = init_state
        self.command = command
        self.on_text = on_text
        self.off_text = off_text
        # Needed by show(); previously never stored, so show() raised AttributeError.
        self.pack_args = pack_args

        self.frame = tk.Button(root, text=self._current_text(), command=self.toggle)
        self.frame.pack(**pack_args)

    def _current_text(self) -> str:
        """Button text for the current state."""
        return self.on_text if self.state else self.off_text

    def toggle(self):
        """Invert the state (button click handler)."""
        self.set_state(not self.state)

    def set_state(self, state):
        """Set the state, notify ``command`` and refresh the button text."""
        self.state = state
        if self.command:
            self.command(self.state)

        self.frame.config(text=self._current_text())

    def hide(self):
        self.frame.pack_forget()

    def show(self):
        self.frame.pack(**self.pack_args)


class LabeledTypeEntry(FramedComponent):
    """A labelled text entry whose contents are parsed with ``type``.

    Args:
        root: parent widget.
        label: text shown to the left of the entry.
        type: callable used to parse the text (e.g. ``int``, ``float``, ``str``).
        default: initial entry text.
        command_on_valid: optional callback, called with the parsed value
            whenever the text changes and parses successfully.
        **pack_args: forwarded to ``FramedComponent``.
    """

    def __init__(self, root, label: str, type: type, default: str, command_on_valid=None, **pack_args):
        super().__init__(root, **pack_args)

        self.default = default
        self.type = type
        self.command_on_valid = command_on_valid

        self.label = tk.Label(self.frame, text=label)
        self.label.pack(side=tk.LEFT, padx=5)

        self.entry_var = tk.StringVar()
        # The trace is registered before the default is inserted, so
        # command_on_valid also sees the default during construction.
        self.entry_var.trace_add("write", self.on_changed)

        self.entry = tk.Entry(self.frame, textvariable=self.entry_var)
        self.entry.pack(side=tk.LEFT, padx=5)
        self.entry.insert(0, default)

    def get_value(self):
        """Return the entry text parsed with ``self.type``, or ``None`` if it does not parse."""
        try:
            return self.type(self.entry_var.get())
        except (ValueError, TypeError):
            return None

    def on_changed(self, *args):
        """Variable-trace callback: forward successfully parsed values to ``command_on_valid``.

        Only parse failures are ignored. An exception raised by the callback
        itself is no longer silently swallowed; tkinter reports it instead.
        """
        if not self.command_on_valid:
            return
        try:
            typed_value = self.type(self.entry_var.get())
        except (ValueError, TypeError):
            return
        self.command_on_valid(typed_value)
    



class LabeledImageViewer(FramedComponent):
    """A caption stacked above an image label, inside its own packed frame.

    Args:
        root: parent widget.
        label_text: caption shown above the image.
        im_size: (width, height) the image is resized to for display.
        **pack_args: forwarded to ``FramedComponent``.
    """

    def __init__(self, root, label_text: str, im_size, **pack_args):
        super().__init__(root, **pack_args)
        self.im_size = im_size

        self.text_label = tk.Label(self.frame, text=label_text)
        self.text_label.pack(side=tk.TOP, pady=5)

        self.im_label = tk.Label(self.frame, text="No image")
        self.im_label.pack(side=tk.TOP)

    def set_image(self, data, mode=None):
        """Display ``data`` (any array ``PIL.Image.fromarray`` accepts), nearest-neighbour resized.

        Returns:
            The resized PIL image, so callers can save what is shown.
        """
        resized = Image.fromarray(data, mode=mode).resize(self.im_size, Image.Resampling.NEAREST)

        tk_photo = ImageTk.PhotoImage(resized)
        self.im_label.config(image=tk_photo)
        # Keep a Python reference; otherwise the PhotoImage can be garbage
        # collected and the label goes blank.
        self.im_label.image = tk_photo

        return resized


class AcquisitionWindow:
    """Modal window that runs a Micro-Manager time-lapse with periodic segmentation.

    "Start Acquire" runs ``run_acquire`` on a worker thread. Each received
    frame is displayed and pushed into ``buffered_segment``. Once ``n`` frames
    have arrived, a segmentation PNG is written into the acquisition directory
    every ``segment_period`` frames. The constructor blocks in ``wait_window``
    until the window is closed.
    """

    def __init__(self, root):
        self.root = root

        new_window = tk.Toplevel(root)
        new_window.title("Acquisition")
        new_window.geometry("1900x800")
        new_window.resizable(True, True)

        # '-toolwindow' is a Windows-only wm attribute and raises TclError on other
        # platforms, so it is only touched on win32. The former unconditional
        # attributes('-toolwindow', False) call defeated this check.
        if new_window.tk.call('tk', 'windowingsystem') == 'win32':
            new_window.wm_attributes('-toolwindow', 0)

        # Keep the window above the main app and route all input to it.
        new_window.transient(root)
        new_window.grab_set()

        self.window = new_window

        ## Values
        self.acquiring = False
        self.buffered_segment = BufferedSegmentation(model, 100)
        self.n = 500

        ## Frames
        self.upper_bar = tk.Frame(new_window, bg="#f0f0f0")
        self.upper_bar.pack(side=tk.TOP, fill=tk.X, padx=10, pady=10)

        self.lower_content = tk.Frame(new_window, bg="#ffffff")
        self.lower_content.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=10)
        ## Control buttons and inputs
        self.frames = LabeledTypeEntry(self.upper_bar, "Count", int, "500",
                                       side=tk.LEFT, padx=10)
        self.interval = LabeledTypeEntry(self.upper_bar, "Interval(ms)", float, "0",
                                         side=tk.LEFT, padx=10)

        self.n_segmentation_frames = LabeledTypeEntry(self.upper_bar, "Segmentation bin size", int, "100",
                                                      side=tk.LEFT, padx=10)
        self.segmentation_period = LabeledTypeEntry(self.upper_bar, "Segmentation period", int, "100",
                                                    side=tk.LEFT, padx=10)

        self.directory = LabeledTypeEntry(self.upper_bar, "Directory", str, "",
                                          side=tk.LEFT, padx=10)
        self.file_name = LabeledTypeEntry(self.upper_bar, "Filename", str, "mov",
                                          side=tk.LEFT, padx=10)

        self.start_acquire_button = tk.Button(self.upper_bar, text="Start Acquire", command=self.run_acquire_threaded)
        self.start_acquire_button.pack(side=tk.LEFT, padx=10)

        self.frame_counter = tk.Label(self.lower_content, text=" ")
        self.frame_counter.pack(side=tk.TOP, pady=10)

        self.view_frame = tk.Frame(self.lower_content)
        self.view_frame.pack(side=tk.TOP, pady=10)

        self.pl_viewer = LabeledImageViewer(self.view_frame, "PL view", (700, 700),
                                            side=tk.LEFT, padx=10)
        self.segment_viewer = LabeledImageViewer(self.view_frame, "Segmentation", (700, 700),
                                                 side=tk.LEFT, padx=10)

        # Blocks until the window is destroyed.
        new_window.wait_window()

    def num_frames_changed(self, n: int):
        self.n = n

    def received_image(self, im: np.ndarray, metadata, *args):
        """``image_process_fn`` for the pycromanager Acquisition.

        Buffers and displays the frame and, when due, writes a segmentation PNG.
        Returns ``(im, metadata)`` unchanged to pycromanager.
        """
        self.buffered_segment.add_image(im)
        self.frames_received += 1

        self.frame_counter.config(text=f'Frame: {self.frames_received}/{self.total_frames}')
        pl_cmap = apply_color_map(im, cmap_name='grey')
        self.pl_viewer.set_image(pl_cmap)

        ## We have not had enough frames to fill the buffer to n
        if self.frames_received < self.n:
            return (im, metadata)
        ## We already have done a segmentation within the segment period
        if self.frames_received - self.last_run < self.segment_period:
            return (im, metadata)

        print(f'Generating segmentation')

        image = self.create_segmentation()
        self.last_run = self.frames_received

        filename = f'frame{self.frames_received}_segmentation.png'
        path = os.path.join(self.dir, filename)
        image.save(path)

        print(f'Finished generating segmentation')

        return (im, metadata)

    def run_acquire_threaded(self):
        """Run ``run_acquire`` on a background thread so the UI stays responsive."""
        threading.Thread(target=self.run_acquire, args=[]).start()

    def run_acquire(self):
        """Validate the settings and run one acquisition (blocking).

        Does nothing if an acquisition is already running or a setting is
        missing or invalid. ``acquiring`` is always reset on return, so a
        rejected or failed attempt no longer locks out further acquisitions.
        """
        if self.acquiring:
            return

        self.acquiring = True
        try:
            out_dir = self.directory.get_value()
            name = self.file_name.get_value()

            if not out_dir or not name:
                print('No dir or name given')
                return

            n_segment_frames = self.n_segmentation_frames.get_value()
            segment_period = self.segmentation_period.get_value()
            count = self.frames.get_value()
            interval = self.interval.get_value()

            # get_value() returns None for unparsable text; reject such values
            # (and impossible sizes) before anything is started.
            if (
                n_segment_frames is None or n_segment_frames < 1
                or segment_period is None or segment_period < 0
                or count is None or count < 1
                or interval is None or interval < 0
            ):
                print('Invalid acquisition settings')
                return

            self.n = n_segment_frames
            self.segment_period = segment_period

            self.buffered_segment.set_n_frames(n_segment_frames)
            self.frames_received = 0
            # Far in the past so the first segmentation is not held back by the period check.
            self.last_run = -100000
            self.total_frames = count

            self.dir = out_dir

            os.makedirs(out_dir, exist_ok=True)
            with Acquisition(directory=out_dir, name=name, image_process_fn=self.received_image, show_display=False) as acq:
                events = multi_d_acquisition_events(num_time_points=count, time_interval_s=interval / 1000.0)
                acq.acquire(events)
        finally:
            self.acquiring = False

    def create_segmentation(self):
        """Segment the buffered frames, show the result and return the displayed PIL image."""
        output_im = self.buffered_segment.inference()
        final_image = self.segment_viewer.set_image(output_im, mode='RGB')

        return final_image

class App:
    """Main "On The Fly Segmentation" window.

    Every 200 ms (unless paused) a frame is snapped with ``get_data()`` into the
    module-level ``buffer_segment``. Every ``segment_period`` seconds the
    buffered frames are segmented on a worker thread and shown next to the
    live preview. "Open Acquire" pauses both while the modal
    ``AcquisitionWindow`` is open.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("On The Fly Segmentation")

        self.value = 0

        self.last_output = None
        self.last_camera_im = None
        self.segment_period: float = 2.0
        # Held while a segmentation runs, so periodic/manual requests are skipped
        # instead of stacking up concurrent inference threads.
        self._segmentation_lock = threading.Lock()

        ## FRAMES
        self.upper_bar = tk.Frame(root, bg="#f0f0f0")
        self.upper_bar.pack(side=tk.TOP, fill=tk.X, padx=10, pady=10)

        self.lower_content = tk.Frame(root, bg="#ffffff")
        self.lower_content.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=10)

        ##
        self.live_image = LabeledImageViewer(self.lower_content, "Live preview", (512, 512), side=tk.LEFT, padx=10, pady=10)
        self.segmentation_image = LabeledImageViewer(self.lower_content, "Segmentation", (512, 512), side=tk.LEFT, padx=10)

        ## Toggle live button
        self.show_live = ToggleButton(self.upper_bar, "Stop Live", "Start Live", init_state=True, side=tk.LEFT, padx=10)

        ## Toggle hide
        self.hide_button = ToggleButton(self.upper_bar, on_text="Unhide", off_text="Hide", init_state=False, command=self.toggle_hidden, side=tk.LEFT, padx=10)

        ## Create segmentation image
        self.create_segment_button = tk.Button(self.upper_bar, text="Create Segmentation", command=self.create_segmentation_pressed)
        self.create_segment_button.pack(side=tk.LEFT, padx=10)

        ## Segmentation period
        self.period_entry = LabeledTypeEntry(self.upper_bar, "Segmentation period", float, str(self.segment_period), command_on_valid=self.period_changed,
                                             side=tk.LEFT, padx=10)

        ## Open acquire button
        self.open_acquire_button = tk.Button(self.upper_bar, text="Open Acquire", command=self.open_acquire)
        self.open_acquire_button.pack(side=tk.LEFT, padx=10)

        ## Frame counter
        self.frame_count_label = tk.Label(self.upper_bar, text="")
        self.frame_count_label.pack(side=tk.LEFT, padx=10)

        ## Total counter
        self.total_count_label = tk.Label(self.upper_bar, text="")
        self.total_count_label.pack(side=tk.LEFT, padx=10)

        ## Clear buffer button
        self.clear_buffer_button = tk.Button(self.upper_bar, text="Clear Buffer", command=self.clear_buffer)
        self.clear_buffer_button.pack(side=tk.LEFT, padx=10)

        ## Save segmentation button
        self.save_segmentation_button = tk.Button(self.upper_bar, text="Save Segmentation", command=self.save_image)
        self.save_segmentation_button.pack(side=tk.LEFT, padx=10)

        # Gates both the live capture loop and periodic segmentation.
        self.do_periodic_segment = True
        self.total_frames = 0

        root.after(100, self.update_live)
        root.after(2000, self.periodic_segment)

    def open_acquire(self):
        """Pause live capture/segmentation while the modal acquisition window is open."""
        self.show_live.set_state(False)
        self.do_periodic_segment = False
        try:
            # Blocks until the acquisition window is closed.
            AcquisitionWindow(self.root)
            print('Acquire closed')
        finally:
            # Always resume, even if the acquisition window failed to open.
            self.show_live.set_state(True)
            self.do_periodic_segment = True

    def period_changed(self, value: float):
        """Accept a new segmentation period in seconds; values below 0.5 are ignored."""
        if value < 0.5:
            return

        self.segment_period = value

    def toggle_hidden(self, state):
        """Hide the live preview when ``state`` is True, show it otherwise."""
        if state:
            self.live_image.hide()
        else:
            self.live_image.show()

    def clear_buffer(self):
        """Reset the frame counter and empty the shared frame buffer."""
        self.total_frames = 0

        buffer_segment.clear_buffer()

    def update_live(self):
        """Snap a frame into the shared buffer (unless paused); reschedules itself every 200 ms."""
        if self.do_periodic_segment:
            try:
                im = get_data()
                self.last_camera_im = im

                buffer_segment.add_image(im)

                self.total_frames += 1

                self.frame_count_label.config(text=f"Stored frames: {len(buffer_segment.buffer)}")
                self.total_count_label.config(text=f"Total frames: {self.total_frames}")

                if self.show_live.state:
                    self.live_image.set_image(im)
            except Exception as exc:
                # Best effort: keep the live loop running, but report why it failed.
                print(f'Failed to get image: {exc}')

        self.root.after(int(1000/5), self.update_live)

    def save_image(self):
        """Save the last segmentation and camera frame under ``saved_images/``."""
        if self.last_output is None or self.last_camera_im is None:
            print('Nothing to save yet')
            return

        os.makedirs('saved_images', exist_ok=True)
        curr_time = int(time.time())
        file_name = f'saved_images/{curr_time}_segmentation.png'

        last_segmentation = Image.fromarray(self.last_output, mode='RGB')
        last_segmentation.save(file_name)

        last_snap = Image.fromarray(self.last_camera_im)
        last_snap.save(f'saved_images/{curr_time}_snap.png')

    def periodic_segment(self):
        """Start a background segmentation (unless paused) and reschedule after ``segment_period`` s."""
        if self.do_periodic_segment:
            threading.Thread(target=self.create_segmentation, args=[]).start()
        self.root.after(int(self.segment_period * 1000), self.periodic_segment)

    def create_segmentation_pressed(self):
        """Button handler: segment on a background thread."""
        threading.Thread(target=self.create_segmentation, args=[]).start()

    def create_segmentation(self):
        """Segment the shared buffer and display the result; skipped if busy or empty."""
        if not self._segmentation_lock.acquire(blocking=False):
            return
        try:
            if len(buffer_segment.buffer) == 0:
                # e.g. right after "Clear Buffer": nothing to segment yet.
                return
            output_im = buffer_segment.inference()
            self.segmentation_image.set_image(output_im, mode='RGB')
            self.last_output = output_im
        finally:
            self._segmentation_lock.release()

        
# Create the main Tk window, attach the App, and enter the Tk event loop
# (mainloop() blocks until the window is closed).
root = tk.Tk()
root.geometry("800x600")
app = App(root)
root.mainloop()

  image_array = np.reshape(


Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Generating segmentation
Finished generating segmentation
Acquire closed
