In [None]:
import os
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import deeplay as dl
import deeptrack as dt
import h5py
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
from PIL import Image
from torchvision import transforms
from matplotlib.widgets import Slider
from IPython.display import clear_output

In [2]:
class CNN(nn.Module):
    """Fully convolutional backbone used as the LodeSTAR model.

    Architecture: three 3x3 conv layers (32 filters, ReLU), a 2x2 max pool
    (halves the spatial size), eight more 3x3 conv layers (32 filters, ReLU)
    and a final 1x1 conv with no activation.

    Parameters
    ----------
    in_channels : int, default 1
        Number of channels of the input images.
    out_channels : int, default 3
        Number of channels produced by the final 1x1 convolution.

    ``__call__`` is intentionally not overridden: ``nn.Module.__call__``
    already dispatches to ``forward`` and additionally runs any registered
    forward / pre-forward hooks.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 3):
        super().__init__()
        # First block of three 3x3 conv layers with 32 filters each.
        # Attribute names are kept stable: pickled models saved with
        # torch.save(model) rely on them.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

        # Max pooling layer
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Second block of eight 3x3 conv layers with 32 filters each
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv11 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

        # Final 1x1 conv layer producing `out_channels` maps
        self.final_conv = nn.Conv2d(in_channels=32, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Map a (N, in_channels, H, W) batch to (N, out_channels, H//2, W//2)."""
        # First three convolutional layers with ReLU
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))

        x = self.pool(x)

        # Next eight convolutional layers with ReLU
        for conv in (self.conv4, self.conv5, self.conv6, self.conv7,
                     self.conv8, self.conv9, self.conv10, self.conv11):
            x = F.relu(conv(x))

        # Final convolutional layer without activation
        return self.final_conv(x)

In [3]:
def generate_standard(image_size: int=64, noise_value: float=1e-4, position: tuple[int]=(0,0), minv: float=0, maxv: float=1) -> tuple:

    """Simulate one particle with DeepTrack and return the image and its position.

    A dt.MieSphere (radius=7e-8, refractive_index=1.4) is placed `position`
    pixels away from the image centre and imaged with dt.Brightfield, using a
    dt.HorizontalComa aberration whose coefficient comes from
    np.random.uniform(-100, 100). Each pixel value I is then replaced by
    |(I - 1) * exp(1j * pi/2) + 1|. dt.Gaussian(sigma=noise_value) is applied,
    and the image is min-max normalised to [minv, maxv].

    Returns
    -------
    tuple
        (image, positions), where positions is the image's 'position'
        property list (get_one=False).
    """

    centre = image_size // 2
    particle = dt.MieSphere(
        position=(position[0] + centre, position[1] + centre),
        radius=7e-8,
        refractive_index=1.4,
        z=0,
        position_objective=np.array([0, 0, 0]),
    )

    # The coefficient is a callable; presumably DeepTrack re-draws it on each update().
    coma_args = dt.Arguments(hccoeff=lambda: np.random.uniform(-100, 100))
    coma = dt.HorizontalComa(coefficient=coma_args.hccoeff)

    microscope = dt.Brightfield(
        NA=1.0,
        working_distance=.2e-3,
        aberration=coma,
        wavelength=660e-9,
        resolution=.15e-6,
        magnification=1,
        output_region=(0, 0, image_size, image_size),
        return_field=False,
        illumination_angle=np.pi,
    )

    def phase_adder(ph):
        # The parameter is named `ph` to match the keyword given to dt.Lambda below.
        def shift_phase(image):
            return np.abs((image - 1) * np.exp(1j * ph) + 1)
        return shift_phase

    phase_step = dt.Lambda(phase_adder, ph=np.pi/2)
    imaged = microscope(particle)
    pipeline = imaged >> phase_step
    pipeline = pipeline >> dt.Gaussian(sigma=noise_value)
    pipeline = pipeline >> dt.NormalizeMinMax(minv, maxv)

    resolved = pipeline.update()()
    return resolved, resolved.get_property('position', get_one=False)

In [4]:
class Particle_dataset(torch.utils.data.Dataset):

    """Base Particle_dataset from which all specialized data sets inherit. \n
    Contains necessary methods for torch, such as __len__ and __getitem__,
    and some useful methods, such as __add__ and __setitem__.
    Images are stored in the plain list `self.images`."""

    def __init__(self):
        # List of images (typically tensors); subclasses fill it in their constructors.
        self.images=[]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i: int):
        return self.images[i]

    def __setitem__(self, i: int, value: object):
        self.images[i] = value

    def __add__(self, data):

        """Combines two datasets into one, or adds images to an existing dataset.

        If `data` is a dataset (its type name contains 'dataset' or 'Matlab'),
        its images are appended to this dataset; otherwise `data` itself is
        appended as a single image. This mutates and returns self, so `a + b`
        also changes `a`."""

        # Type-name check rather than isinstance; presumably so instances created
        # before a class cell was re-run in the notebook are still recognised.
        if 'dataset' in str(type(data)) or 'Matlab' in str(type(data)):
            self.images += data.images
        else:
            self.images += [data]
        return self

    def __neg__(self):

        """Negates the pixel values of every image in place (inverts the images) and returns self"""

        for i in range(len(self.images)):
            self.images[i] = -self.images[i]
        return self

    def __call__(self):
        return self.images

    def delete(self, i: int):

        """Remove the image at index i (no-op if i is not a valid non-negative index)."""

        # enumerate yields (index, item); keep every item whose index differs from i.
        self.images = [image for j, image in enumerate(self.images) if j != i]

class Image_dataset(Particle_dataset):

    """Creates a Particle_dataset from all images (.png) contained in the folder path
    Training the network on .png files is not recommended due to possible loss of information
    during saving and loading from our method, but could work if care is taken.
    Every regular file in the folder is opened with PIL, so the folder should
    contain only image files."""

    def __init__(self, path):
        super().__init__()
        self.folder_path = path
        self.extract_pngs(path)

    def extract_pngs(self, folder_path):
        """Load every regular file in folder_path as an image tensor via ToTensor.

        Filenames are sorted because os.listdir returns them in arbitrary
        order, and an image's position in the dataset is used as its frame
        index (e.g. by Trainer.get_df). Sub-directories are skipped.
        """
        candidate_paths = (os.path.join(folder_path, name) for name in sorted(os.listdir(folder_path)))
        self.image_paths = [path for path in candidate_paths if os.path.isfile(path)]

        to_tensor = transforms.ToTensor()
        for image_path in self.image_paths:
            # The context manager closes the file handle PIL keeps open for lazy loading.
            with Image.open(image_path) as img:
                self.images.append(to_tensor(img))

class Array_dataset(Particle_dataset):

    """Creates a Particle_dataset from all arrays (.npy) contained in the folder path.
    Each file is loaded with np.load, cast to float32 and given a leading
    channel axis."""

    def __init__(self, path):
        super().__init__()
        self.folder_path = path
        self.extract_arrays(path)

    def extract_arrays(self, folder_path):
        """Load every regular file in folder_path as a float32 tensor with a leading channel axis.

        Filenames are sorted because os.listdir returns them in arbitrary
        order, and an image's position in the dataset is used as its frame
        index (e.g. by Trainer.get_df). Sub-directories are skipped.
        """
        candidate_paths = (os.path.join(folder_path, name) for name in sorted(os.listdir(folder_path)))
        self.image_paths = [path for path in candidate_paths if os.path.isfile(path)]

        for image_path in self.image_paths:
            image = np.load(image_path).astype('float32')
            self.images.append(torch.from_numpy(image).unsqueeze(0))

class Generative_dataset(Particle_dataset):

    """Particle_dataset filled with `num_samples` simulated images.

    Each sample is produced by calling `generator(**kwargs)`, which must return
    an (image, label) pair; the label is discarded. The image is converted to a
    float tensor with its last axis moved to the front (channels first).
    With invert=True every image is multiplied by -1. If both inverted and
    non-inverted images are used for training, make sure to normalize them in
    some way so the images are comparable to each other."""

    def __init__(self, num_samples=1, generator=generate_standard, invert=False, **kwargs):
        super().__init__()
        sign = (-1)**invert
        for _ in range(num_samples):
            image, _label = generator(**kwargs)
            channels_first = torch.tensor(image).float().permute(2, 0, 1)
            self.images.append(sign*channels_first)

class Matlab_dataset(Particle_dataset):

    """Create a Particle_dataset from the 'Im_stack' dataset of an HDF5-based
    (MATLAB v7.3) .mat file, loading the frames with indices in range(*Range). \n
    If Range is not specified, all images in the file are loaded, which takes
    a lot of memory and time. Each frame becomes a float32 tensor with a
    leading channel axis."""

    def __init__(self, path: str=None, Range: tuple[int]=None):
        super().__init__()
        # Open read-only and close the file as soon as the selected frames
        # have been copied into memory.
        with h5py.File(path, 'r') as mat_file:
            stack = mat_file['Im_stack']
            indices = range(*Range) if Range is not None else range(len(stack))
            for i in indices:
                frame = torch.from_numpy(stack[i]).unsqueeze(0)
                self.images.append(frame.type(torch.float32))

In [21]:
class Trainer:

    """Wraps a CNN in a deeplay LodeSTAR model and provides helpers to manage
    data sets, train the network, detect particles, visualise detections and
    export them as a pandas DataFrame.

    Two data sets are kept: `dataset` (the images to analyse) and `trainset`
    (the images or crops used by train())."""

    def __init__(self):

        self.epochs = 2000
        self.num_transforms = 8
        self.batch_size = 1
        self.learning_rate = 1e-4
        self.position = (0,0)
        self.num_samples = 1 # Does nothing atm
        # Early-stopping threshold for train(): stop once the last three batch
        # losses are all below it. None disables early stopping.
        self.lower_b = None

        self.device = torch.device('cpu')
        self.Network = CNN().to(self.device)
        self.lodestar = dl.LodeSTAR(model=self.Network, n_transforms=self.num_transforms)

        self.dataset = Particle_dataset()
        self.trainset = Particle_dataset()
        # Per-batch losses from the most recent call to train(); empty until then.
        self.losses = []

    @staticmethod
    def _is_dataset(obj) -> bool:

        """Return True if obj is one of the Particle_dataset classes.

        Matches on the type name (as Particle_dataset.__add__ does) rather than
        isinstance, presumably so instances created before a class cell was
        re-run are still recognised."""

        type_name = str(type(obj))
        return 'dataset' in type_name or 'Matlab' in type_name

    def set_dataset(self, dataset: Particle_dataset) -> None:

        """Set trainer's regular data set (recommended method)"""

        self.dataset = dataset

    def add_dataset(self, dataset: Particle_dataset) -> None:

        """Add a data set to trainer's regular data set (recommended method). \n
        Only works if data set has an __add__ method such as a Particle_dataset does"""

        self.dataset += dataset

    def generate_data(self, **kwargs) -> None:

        """Add generated images from deeptrack to the trainer's regular data set (not recommended).

        `image_size` (default 64), `generator` (default generate_standard) and
        `invert` (default False) are handled explicitly; every other keyword
        (e.g. num_samples, noise_value) is forwarded to Generative_dataset."""

        # pop (not get) so these keys are not passed a second time through
        # **kwargs, which would raise "got multiple values for keyword argument".
        generated_data = Generative_dataset(
            image_size=kwargs.pop('image_size', 64),
            generator=kwargs.pop('generator', generate_standard),
            invert=kwargs.pop('invert', False),
            **kwargs
            )
        self.dataset += generated_data

    def load_data(self, folder_path: str=None, mapp: str='static_removed_images', Dataset: Particle_dataset=Array_dataset, **kwargs) -> None:

        """Add images loaded from the folder `folder_path/mapp` to the trainer's regular data set (not recommended).

        folder_path defaults to the 'Bilder' folder in the current working
        directory, resolved when the method is called. If mapp is empty,
        folder_path itself is read. Dataset is the dataset class used to read
        the folder; it is called with the folder path and **kwargs."""

        if folder_path is None:
            folder_path = os.path.join(os.getcwd(), 'Bilder')
        path = os.path.join(folder_path, mapp) if mapp else folder_path
        loaded_data = Dataset(path, **kwargs)
        self.dataset += loaded_data

    def clear_data(self) -> None:

        """Clear trainer's regular data set"""

        self.dataset = Particle_dataset()

    def clear_train_data(self) -> None:

        """Clear trainer's train set"""

        self.trainset = Particle_dataset()

    def add_train_data(self, i: int|list|torch.Tensor=None) -> None:

        """Add data to trainer's train set. \n
        If None the trainer will load the entire dataset as trainset.
        An int or a list of ints selects images of the regular data set by
        index; a tensor is added directly (e.g. a crop of an image)."""

        if i is None:
            self.trainset += self.dataset
        elif isinstance(i, list):
            for j in i:
                self.trainset += self.dataset[j]
        elif isinstance(i, int):
            self.trainset += self.dataset[i]
        elif isinstance(i, torch.Tensor):
            self.trainset += i

    def show_data(self, i: int|list=None) -> None:

        """Plots all data from trainer's regular data if indices are not inputted,
        otherwise the image(s) at the given index or list of indices"""

        if i is None:
            indices = range(len(self.dataset))
        elif isinstance(i, list):
            indices = i
        elif isinstance(i, int):
            indices = [i]
        else:
            return
        for j in indices:
            plt.title(j)
            plt.imshow(self.dataset[j].squeeze(), cmap='gray')
            plt.show()

    def transform(self, image: torch.FloatTensor, mode: str='barebones') -> tuple:

        """Extension of the transform_data transform; returns (transforms, inverses). \n
        If the mode='realistic' then a random intensity factor in [0.1, 2) and
        Gaussian noise with a random standard deviation below 1e-3 are applied
        before the transformations. Raises ValueError for an unknown mode."""

        if mode == 'realistic':
            # Out-of-place ops; the noise is generated with torch so it matches
            # the image's dtype and device.
            image = image * float(np.random.uniform(0.1, 2))
            noise_std = float(np.random.rand() * 1e-3)
            image = image + noise_std * torch.randn_like(image)
        elif mode != 'barebones':
            raise ValueError(f"Unknown mode {mode!r}; expected 'barebones' or 'realistic'.")
        return self.lodestar.transform_data(image)

    def train(self, mode: str='barebones') -> None:

        """Starts training the network on the trainer's train set.

        Runs self.epochs epochs with Adam (lr=self.learning_rate). Every 100
        epochs the cell output is cleared and the epoch, the latest batch loss
        and a progress bar are printed. Per-batch losses are stored in
        self.losses. If self.lower_b is set, training stops once the last three
        batch losses are all below it."""

        if len(self.trainset)==0:
            print('No train data. Add train data with add_train_data.')
            return

        dataloader = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)
        optimizer = Adam(params = self.Network.parameters(), lr=self.learning_rate)
        self.losses = []
        bar_width = 35

        for epoch in range(1, self.epochs+1):
            if epoch % 100 == 0:
                clear_output()
                print(f'Epoch {epoch}/{self.epochs} \t Loss: {self.losses[-1]:.3f}')
                filled = bar_width*epoch//self.epochs
                print('|' + '-'*filled + ' '*(bar_width-filled) + '|')

            for image in dataloader:
                image = image.to(self.device)
                transformed, inverses = self.transform(image, mode=mode)
                if epoch == 1:
                    self.transforms_example = transformed
                prediction = self.lodestar.forward(transformed)
                loss_dict = self.lodestar.compute_loss(prediction, inverses)
                loss = loss_dict['between_image_disagreement'] + loss_dict['within_image_disagreement']
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                self.losses.append(loss.item())

            # Early stopping on the three most recent batch losses.
            if (self.lower_b is not None and len(self.losses) >= 3
                    and all(value < self.lower_b for value in self.losses[-3:])):
                break

    def plot_losses(self, yscale: str='linear') -> None:

        """Plots per-batch losses from the last training run (empty if train was never run)"""

        plt.scatter(range(1, len(self.losses)+1), self.losses, s=1, color='black')
        plt.xlabel('Batch')
        plt.ylabel('Loss')
        plt.grid()
        plt.yscale(yscale)
        plt.show()

    def plot_scatter(self, image: torch.Tensor, points: list[list[float]], show: bool=False, **kwargs: any) -> None:

        """Draws image with points scattered on top; calls plt.show() if show=True.
        Reads the optional kwargs cmap, marker, color and s."""

        plt.imshow(image.detach().numpy().squeeze(), cmap=kwargs.get('cmap'))
        for point in points:
            plt.scatter(point[0], point[1], marker=kwargs.get('marker'), color=kwargs.get('color'), s=kwargs.get('s', 10))
        if show:
            plt.show()

    def animate(self, dataset: Particle_dataset, points: list[list[float]]=None, show: bool=True, **kwargs: any) -> None:

        """Animates images contained in dataset with a frame slider and scatters points if provided [[points frame 0], [points frame 1] ...].
        Reads the optional kwargs cmap, marker, color, s and alpha."""

        fig, ax = plt.subplots()
        im = ax.imshow(dataset[0].squeeze(), cmap=kwargs.get('cmap'))
        if points is not None:
            graph = ax.scatter([point[0] for point in points[0]], [point[1] for point in points[0]],
                               marker=kwargs.get('marker', 'o'),
                               color=kwargs.get('color', 'blue'),
                               s=kwargs.get('s'),
                               alpha=kwargs.get('alpha')
                               )

        ax_slider = plt.axes([0.1, 0.01, 0.8, 0.03])
        slider = Slider(ax_slider, 'Frame', 0, len(dataset)-1, valinit=0, valstep=1)
        # Matplotlib widgets stop responding once garbage collected, so keep a
        # reference after this method returns.
        self._animation_slider = slider

        def update(val):
            frame = int(slider.val)
            im.set_array(dataset[frame].squeeze())
            if points is not None:
                graph.set_offsets(points[frame])
            fig.canvas.draw_idle()

        slider.on_changed(update)
        if show:
            plt.show()

    def predict_single(self, image: torch.FloatTensor, plot: bool=True) -> None:

        """Makes a single prediction of the position of one particle.\n
        This function will always predict one particle, and is thus not intended for detection.
        Plots the prediction if plot=True, otherwise returns it."""

        raw_pred = self.lodestar.pooled(image.unsqueeze(0)).detach().numpy()
        position = raw_pred + np.array([image.shape[2], image.shape[1]]) / 2
        if plot:
            self.plot_scatter(image, position, show=True)
        else:
            return position

    def get_detections(self, image: torch.Tensor, **kwargs) -> np.ndarray:

        """Returns the positions of all particles detected in the image.
        Optional kwargs: cutoff (default 0.9) and alpha (default 0.5), passed to LodeSTAR.detect."""

        raw_preds = self.lodestar.detect(image.unsqueeze(0), cutoff=kwargs.get('cutoff', 0.9), alpha=kwargs.get('alpha', 0.5))[0]
        processed_pred = raw_preds + np.array([image.shape[1], image.shape[2]])/2
        # Swap the two coordinate columns for plotting.
        return processed_pred[:, ::-1]

    def detect(self, Input: np.ndarray|torch.Tensor|Particle_dataset, plot: bool=None, **kwargs) -> list[list[float]]|list[list[list[float]]]|None:

        """Detects particles in an image or in every image of a data set.

        Input is a single image (numpy array or tensor with a leading channel
        axis, as stored in the data sets) or a data set of such images.
        `cutoff` and `alpha` in kwargs go to get_detections; all other kwargs
        (e.g. marker, color, s, cmap) go to the plotting functions. A single
        image is plotted with its detections; a data set is shown as a slider
        animation. plot=None plots and returns the detections, plot=True only
        plots, plot=False only returns them. Unrecognised input prints a
        message and returns None."""

        detect_kwargs = {key: kwargs[key] for key in ('cutoff', 'alpha') if key in kwargs}
        plot_kwargs = {key: value for key, value in kwargs.items() if key not in ('cutoff', 'alpha')}

        if isinstance(Input, np.ndarray):
            # The network weights are float32.
            Input = torch.from_numpy(Input).float()
        if isinstance(Input, torch.Tensor):
            image = Input
            detections = self.get_detections(image, **detect_kwargs)
            if plot is None or plot:
                self.plot_scatter(image, detections, show=True, **plot_kwargs)
        elif self._is_dataset(Input):
            dataset = Input
            detections = [self.get_detections(dataset[i], **detect_kwargs) for i in range(len(dataset))]
            if plot is None or plot:
                self.animate(dataset, detections, **plot_kwargs)
        else:
            print('Input not recognised. Input either an unsqueezed torch tensor, unsqueezed numpy array, or a dataset with unsqueezed images.')
            return None
        if not plot:
            return detections

    def get_df(self, dataset: Particle_dataset=None, time_interval: tuple= None, **kwargs: any) -> pd.DataFrame:

        """Returns a dataframe which can be used in MAGIK, with one row per detection.
        If dataset is not specified, the function uses the trainer's regular data set.
        Frames in range(*time_interval) are processed (all frames by default);
        kwargs (cutoff, alpha) are passed to get_detections."""

        data_dict = {'centroid-0': [], 'centroid-1': [], 'frame': [], 'label': [], 'set': [], 'solution': []}
        if dataset is None:
            dataset=self.dataset
        if time_interval is None:
            time_interval=[len(dataset)]
        for t in range(*time_interval):
            detections = self.get_detections(dataset[t], **kwargs)
            data_dict['centroid-0'] += [pos[0] for pos in detections]
            data_dict['centroid-1'] += [pos[1] for pos in detections]
            data_dict['frame'] += [t]*len(detections)
            data_dict['label'] += [0]*len(detections)
            data_dict['set'] += [0]*len(detections)
            data_dict['solution'] += [0]*len(detections)
        df = pd.DataFrame.from_dict(data_dict)
        self.df = df
        return df

    def get_score(self, image: torch.Tensor) -> torch.Tensor:

        """Returns the detection score map of an image as outlined in the paper
        "Single-shot self-supervised object detection in microscopy" \n
        The local maxima of the score map are the final detections"""

        yhat = self.lodestar(image.unsqueeze(0))[0]
        # The last output channel is the weight map; the others are the position predictions.
        pred, weight = yhat[:, :-1], yhat[:,-1:]
        score = self.lodestar.get_detection_score(pred, weight)
        return score

    def show_score(self, image: torch.Tensor, mode: str='3D') -> None:

        """Shows the detection score map as a 3D surface (mode='3D') or as a 2D image (mode='2D')"""

        score = self.get_score(image)

        if mode.lower()=='3d':
            X, Y = np.meshgrid(np.arange(score.shape[1]), np.arange(score.shape[0]))
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            ax.plot_surface(X, Y, score, cmap='gray')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            plt.show()
        if mode.lower()=='2d':
            plt.imshow(score, cmap='gray')
            plt.show()

    def downscale(self, i: int|list=None) -> None:

        """Downscales images in trainer's data set in place with 2x2 max pooling:
        all images if no index is given, else the image(s) at index i"""

        if i is None:
            indices = range(len(self.dataset))
        elif isinstance(i, int):
            indices = [i]
        elif isinstance(i, list):
            indices = i
        else:
            return
        for j in indices:
            self.dataset[j] = torch.nn.functional.max_pool2d(self.dataset[j], kernel_size=2)

    def Normalize(self, value: float, Input: torch.Tensor|Particle_dataset=None) -> None|torch.Tensor|Particle_dataset:

        """Multiplies pixel values by `value` (despite the name, this rescales rather than normalizes).

        With Input=None every image of the trainer's data set is rescaled in
        place and None is returned; a tensor or data set is rescaled in place
        and returned."""

        if Input is None:
            for i in range(len(self.dataset.images)):
                self.dataset.images[i] *= value
            return None
        if isinstance(Input, torch.Tensor):
            Input *= value
            return Input
        if self._is_dataset(Input):
            for i in range(len(Input.images)):
                Input.images[i] *= value
            return Input

    def load_network(self, path: str) -> None:

        """Loads a network saved with torch.save and rebuilds the LodeSTAR wrapper.

        Accepts a whole pickled model or a state_dict for CNN. Unpickling a
        whole model can execute arbitrary code: only load trusted files."""

        # weights_only=False is required to unpickle a full nn.Module on
        # torch >= 2.6, where torch.load defaults to weights_only=True.
        loaded = torch.load(path, map_location=self.device, weights_only=False)
        if isinstance(loaded, dict):
            network = CNN()
            network.load_state_dict(loaded)
            loaded = network
        self.Network = loaded.to(self.device)
        self.lodestar = dl.LodeSTAR(model=self.Network, n_transforms=self.num_transforms)

Load the pre-trained network and the image data

In [22]:
# Fill in the path to the .mat image stack. network_path is optional: leave it
# empty to start from an untrained network and train it in the next section.
network_path = r''
matlab_files_path = r''
if not matlab_files_path:
    raise ValueError('Set matlab_files_path before running this cell.')

# Load only frames 0-9 of the stack (range(0, 10)); loading every frame is slow.
dataset = Matlab_dataset(matlab_files_path, Range=(0,10))

T = Trainer()
if network_path:
    T.load_network(network_path)
T.set_dataset(dataset)
# Multiply every pixel value of the data set by 4.
T.Normalize(4)

Train the network (skip this section if you loaded a pre-trained network)

In [6]:
# Select training data: uncomment (or add) at least one crop below, otherwise
# train() prints a warning and returns without training. Each crop is a
# (channel, row, column) slice of frame 0, presumably around a single particle
# of the indicated size.
# T.add_train_data(T.dataset[0][:,226:290, 426:493]) # 300 nm

# T.add_train_data(T.dataset[0][:,295:365,330:399]) # 500 nm black
# T.add_train_data(T.dataset[0][:,585:654,464:534]) # 500 nm white

# T.add_train_data(T.dataset[0][:,528:594,147:210]) # 5100 nm black
# T.add_train_data(T.dataset[0][:,95:223, 220:347]) # 5100 nm large
# T.add_train_data(T.dataset[0][:,42:105, 244:306]) # 5100 nm white

T.train(mode='barebones')
# Only plot when training actually produced losses.
if getattr(T, 'losses', None):
    T.plot_losses(yscale='log')

Visualize data

In [None]:
# v Required for animations v
# (interactive ipympl backend; the frame slider below needs it)
%matplotlib widget

# Detect particles in every frame of the data set and show them on an
# interactive slider animation.
T.detect(dataset, plot=True, fc='None', marker='o')
# Optional: detection score map of frame 0
# T.show_score(T.dataset[0])
# Optional: show frame 0 on its own
# T.show_data(0)

Export detections to .csv file

In [11]:
# Path of the .csv file to write, e.g. r'detections.csv'.
save_path = r''
if not save_path:
    raise ValueError('Set save_path before exporting detections.')

# One row per detection (centroids and frame index), in the layout used by MAGIK.
df = T.get_df(T.dataset, cutoff=0.95)
df.to_csv(save_path)