In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
import tifffile
import numpy as np
from sklearn.cluster import DBSCAN
from scipy.spatial import ConvexHull
from pathlib import Path

def frame_counts(csv_file, frames):
    """Count the spots recorded for each requested frame of a spots CSV.

    Parameters
    ----------
    csv_file : str or path-like
        CSV file read with ``pandas.read_csv``; file rows 1-3 (the three rows
        right after the header) are skipped and a ``FRAME`` column is expected.
    frames : list or scalar
        Frame value(s) to count. Anything that is not a ``list`` is treated
        as a single frame value.

    Returns
    -------
    list of int
        One count per requested frame, in the order given.
    """
    table = pd.read_csv(csv_file, skiprows=[1, 2, 3], low_memory=False)
    wanted = frames if isinstance(frames, list) else [frames]
    # Summing the boolean mask gives the number of rows belonging to the frame.
    return [int((table["FRAME"] == frame_id).sum()) for frame_id in wanted]

def frame_clusters_with_dbscan(csv_file, frame, db_eps, db_min_sample):
    """Cluster the spots of one frame with DBSCAN.

    Parameters
    ----------
    csv_file : str or path-like
        CSV file read with ``pandas.read_csv``; file rows 1-3 are skipped and
        the columns ``FRAME``, ``POSITION_X`` and ``POSITION_Y`` are used.
    frame : scalar
        Value of ``FRAME`` whose spots are clustered.
    db_eps : float
        Passed to ``DBSCAN`` as ``eps``.
    db_min_sample : int
        Passed to ``DBSCAN`` as ``min_samples``.

    Returns
    -------
    tuple
        ``(num_clusters, pos, labels)`` where ``pos`` is an ``(n, 2)`` array
        of x/y positions and ``labels`` holds one DBSCAN label per row.
        Label ``-1`` (noise) is not counted as a cluster. When the frame has
        no spots the result is ``(0, pos, [])`` with an empty ``pos``.
    """
    table = pd.read_csv(csv_file, skiprows=[1, 2, 3], low_memory=False)
    in_frame = table[table["FRAME"] == frame]
    pos = np.column_stack(
        (in_frame["POSITION_X"].values, in_frame["POSITION_Y"].values)
    )
    if len(pos) == 0:
        return 0, pos, []

    labels = DBSCAN(eps=db_eps, min_samples=db_min_sample).fit(pos).labels_
    cluster_ids = set(labels)
    cluster_ids.discard(-1)  # -1 marks noise points, not a cluster
    return len(cluster_ids), pos, labels

def plot_clusters(ax, img, frame, pos, labels):
    """Show one page of a TIFF stack on ``ax`` and outline each cluster in red.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the image and the outlines.
    img : str or path-like
        TIFF file opened with ``tifffile.TiffFile``.
    frame : int
        Frame index; page ``frame * 2 + 1`` of the stack is displayed
        (presumably two pages per frame, the second being the one of
        interest -- TODO confirm against the acquisition layout).
    pos : numpy.ndarray
        ``(n, 2)`` array of x/y spot positions.
    labels : array-like
        One cluster label per row of ``pos``; label ``-1`` is skipped.

    Notes
    -----
    Clusters that cannot form a polygon (fewer than three points, or all
    points collinear/identical) are drawn as a circle around their centroid
    with a radius of at least ``min_radius``; all others are drawn as a
    closed convex hull.
    """
    min_radius = 10

    with tifffile.TiffFile(img) as tif:
        raw_img = tif.pages[frame * 2 + 1].asarray()

    # Find bright pixels on the raw data *before* narrowing to uint8:
    # casting first would wrap values above 255 (e.g. 16-bit stacks) to
    # small numbers that then escape the saturation step.
    saturated = raw_img > 100
    frame_img = raw_img.astype(np.uint8)
    frame_img[saturated] = 255
    ax.imshow(frame_img, cmap="gray")

    labels = np.asarray(labels)
    for label in np.unique(labels):
        if label == -1:
            continue
        cluster_pos = pos[labels == label]

        # ConvexHull needs at least three points that span a plane; it raises
        # on collinear or duplicate points, so those clusters get a circle.
        degenerate = (
            len(cluster_pos) < 3
            or np.linalg.matrix_rank(cluster_pos - cluster_pos[0]) < 2
        )
        if degenerate:
            center = cluster_pos.mean(axis=0)
            reach = np.linalg.norm(cluster_pos - center, axis=1).max()
            radius = max(reach, min_radius)
            circle = plt.Circle((center[0], center[1]), radius, color="r", fill=False)
            ax.add_patch(circle)
        else:
            hull = ConvexHull(cluster_pos)
            # Repeat the first vertex so the outline is a closed polygon.
            outline = np.append(hull.vertices, hull.vertices[0])
            ax.plot(cluster_pos[outline, 0], cluster_pos[outline, 1], "r-")

def get_files_list(root):
    """Return the non-hidden entries of ``root`` as sorted ``Path`` objects.

    Parameters
    ----------
    root : str or path-like
        Directory to list.

    Returns
    -------
    list of pathlib.Path
        ``root / name`` for every entry whose name does not start with a
        dot. The list is sorted so downstream matching and plotting happen
        in the same order on every run (directory listing order is
        otherwise arbitrary).
    """
    root = Path(root)
    return sorted(entry for entry in root.iterdir() if not entry.name.startswith("."))

def match_csv_and_img_file(csv_files, img_files):
    """Pair every CSV file with the image files it belongs to.

    An image matches a CSV when the image's file stem occurs anywhere inside
    the CSV's file stem.

    NOTE(review): this is a plain substring test, so a stem such as ``A1``
    also matches ``A10_spots`` -- confirm the naming scheme rules that out.

    Parameters
    ----------
    csv_files : iterable of str or path-like
    img_files : iterable of str or path-like

    Returns
    -------
    list of tuple
        ``(img_file, csv_file)`` pairs, ordered by CSV file first and image
        file second.
    """
    return [
        (img_file, csv_file)
        for csv_file in csv_files
        for img_file in img_files
        if Path(img_file).stem in Path(csv_file).stem
    ]


In [None]:
# Root of the local ViralTally checkout. Set the VIRALTALLY_ROOT environment
# variable to run the notebook on another machine; when it is unset the
# original hard-coded location is used.
PROJECT_ROOT = Path(os.environ.get("VIRALTALLY_ROOT", "/Users/ashkanhzdr/workspace/ViralTally"))
# All inputs used below live under this dataset directory.
DATASET_ROOT = PROJECT_ROOT / "dataset" / "Sars2Plaque"

# A549 plots

In [None]:
# For each A549 condition, list the spot tables (under "trackmate") and the
# images (under "processed"); both live in a sub-folder named after the condition.
dmso_csv_files, dmso_img_files = (
    get_files_list(DATASET_ROOT / stage / "ACE2_A549_DMSO_wo_bg")
    for stage in ("trackmate", "processed")
)

inhibitor_csv_files, inhibitor_img_files = (
    get_files_list(DATASET_ROOT / stage / "ACE2_A549_inhibitor_wo_bg")
    for stage in ("trackmate", "processed")
)

In [None]:
# Parameters for the A549 cluster plots.
FRAME = 5  # frame whose spots are clustered and whose image page is shown
DB_EPS = 50  # DBSCAN eps, in the units of the spot coordinates
DB_MIN_SAMPLES = 1  # DBSCAN min_samples

dmso_files = match_csv_and_img_file(dmso_csv_files, dmso_img_files)
inhibitor_files = match_csv_and_img_file(inhibitor_csv_files, inhibitor_img_files)
files = [dmso_files, inhibitor_files]
names = ["DMSO", "Inhibitor"]

# One figure per matched (image, spots CSV) pair, grouped by condition.
for group, group_files in zip(names, files):
    for img, csv in group_files:
        figure = plt.figure(figsize=(15, 15))
        ax = figure.add_subplot(111)
        num_cluster, pos, labels = frame_clusters_with_dbscan(csv, FRAME, db_eps=DB_EPS, db_min_sample=DB_MIN_SAMPLES)
        # Mark every detected spot (blue fill, green edge) on top of the image.
        ax.scatter(pos[:, 0], pos[:, 1], c="blue", edgecolors='green', s=20)
        plot_clusters(ax, img, FRAME, pos, labels)
        ax.set_title(f"{img.stem} : dbscan:{num_cluster} ; counts:{len(pos)}")
        ax.axis("off")
        # The condition name goes into the top margin reserved by `rect`.
        figure.suptitle(group)
        # Lay out, show and release each figure individually; doing this once
        # per group would only lay out the last figure and keep all of them
        # open in memory.
        figure.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
        plt.close(figure)

# Vero plots

In [None]:
# DATASET_ROOT comes from the configuration cell near the top of the
# notebook; it is not redefined here so there is a single place to change it.
dmso_csv_files = get_files_list(DATASET_ROOT / "trackmate" / "Vero_DMSO_wo_bg")
dmso_img_files = get_files_list(DATASET_ROOT / "processed" / "Vero_DMSO_wo_bg")

inhibitor_csv_files = get_files_list(DATASET_ROOT / "trackmate" / "Vero_inhibitor_wo_bg")
inhibitor_img_files = get_files_list(DATASET_ROOT / "processed" / "Vero_inhibitor_wo_bg")

# Parameters for the Vero cluster plots.
FRAME = 4  # frame whose spots are clustered and whose image page is shown

DB_EPS = 20  # DBSCAN eps, in the units of the spot coordinates
DB_MIN_SAMPLES = 1  # DBSCAN min_samples

dmso_files = match_csv_and_img_file(dmso_csv_files, dmso_img_files)
inhibitor_files = match_csv_and_img_file(inhibitor_csv_files, inhibitor_img_files)
files = [dmso_files, inhibitor_files]
names = ["DMSO", "Inhibitor"]

# One figure per matched (image, spots CSV) pair, grouped by condition.
for group, group_files in zip(names, files):
    for img, csv in group_files:
        figure = plt.figure(figsize=(15, 15))
        ax = figure.add_subplot(111)
        num_cluster, pos, labels = frame_clusters_with_dbscan(csv, FRAME, db_eps=DB_EPS, db_min_sample=DB_MIN_SAMPLES)
        # Mark every detected spot (blue fill, green edge) on top of the image.
        ax.scatter(pos[:, 0], pos[:, 1], c="blue", edgecolors='green', s=20)
        plot_clusters(ax, img, FRAME, pos, labels)
        ax.set_title(f"{img.stem} : dbscan:{num_cluster} ; counts:{len(pos)}")
        ax.axis("off")
        # The condition name goes into the top margin reserved by `rect`.
        figure.suptitle(group)
        # Lay out, show and release each figure individually; doing this once
        # per group would only lay out the last figure and keep all of them
        # open in memory.
        figure.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
        plt.close(figure)