In [1]:
import argparse
import json
import os
import pathlib
import sys
import time

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import scipy
import tifffile
import torch
from arg_parsing_utils import check_for_missing_args, parse_args
from cellpose import models
from file_reading import *
from file_reading import read_zstack_image
from general_segmentation_utils import *
from notebook_init_utils import bandicoot_check, init_notebook
from organoid_segmentation import *
from segmentation_decoupling import *
from skimage.filters import sobel
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

In [2]:
def read_labels(infile: str) -> dict:
    """
    Description
    ----------
    Load a parquet file of labels and return its contents column-wise.
    Parameters
    ----------
    infile : str
        Location of the parquet file to read.
    Returns
    -------
    dict
        Mapping from each column name to the list of values in that column.
    """
    label_table = pd.read_parquet(infile)
    # orient="list" yields {column name: [values...]}
    return label_table.to_dict(orient="list")

In [3]:
# wall-clock timestamp taken at this point of the notebook
start_time = time.time()
# resident set size (cpu memory) of the current process, converted from bytes to MiB
current_process = psutil.Process(os.getpid())
start_mem = current_process.memory_info().rss / (1024 * 1024)

In [4]:
# init_notebook returns a pair, bound here as the project root and a flag
# (presumably whether the code runs inside a notebook -- confirm in notebook_init_utils)
root_dir, in_notebook = init_notebook()

# absolute form of ~/mnt/bandicoot, handed to bandicoot_check together with root_dir
bandicoot_mount_path = pathlib.Path(os.path.expanduser("~/mnt/bandicoot")).resolve()
image_base_dir = bandicoot_check(bandicoot_mount_path, root_dir)

# strict=True makes resolve() raise if the patient ID list does not exist
patient_ids_path = pathlib.Path(f"{root_dir}/data/patient_IDs.txt")
patient_list_file_path = patient_ids_path.resolve(strict=True)

In [5]:
# input tables: the organoid image labels and the SAMMed3D feature table
labels_save_file = pathlib.Path(
    "../image_labels/organoid_image_labels.parquet"
).resolve()
sammed_features_save_path = pathlib.Path(
    "../../3.cellprofiling/results/sammed_features.parquet"
).resolve()

labels = read_labels(labels_save_file)
labels_df = pd.DataFrame(labels)
sammed_features_df = pd.read_parquet(sammed_features_save_path)

# how="right" keeps every row of the label table, matched to its features on
# (patient, well_fov); afterwards only rows whose label is missing are dropped
# (NaNs in other columns are left untouched)
df = pd.merge(
    sammed_features_df,
    labels_df,
    on=["patient", "well_fov"],
    how="right",
).dropna(subset=["label"])
df

Unnamed: 0,patient,well_fov,405_SAMMed3D_feature_0,405_SAMMed3D_feature_1,405_SAMMed3D_feature_10,405_SAMMed3D_feature_100,405_SAMMed3D_feature_101,405_SAMMed3D_feature_102,405_SAMMed3D_feature_103,405_SAMMed3D_feature_104,...,TRANS_SAMMed3D_feature_92,TRANS_SAMMed3D_feature_93,TRANS_SAMMed3D_feature_94,TRANS_SAMMed3D_feature_95,TRANS_SAMMed3D_feature_96,TRANS_SAMMed3D_feature_97,TRANS_SAMMed3D_feature_98,TRANS_SAMMed3D_feature_99,label,annotator
0,NF0014_T1,C10-1,-0.319416,-0.301876,0.236247,-0.006220,-0.155010,0.167629,0.024517,-0.111055,...,-0.122748,-0.011013,0.003297,-0.104058,0.051536,0.142032,0.460218,0.099868,globular,Mike
1,NF0014_T1,C10-2,-0.262450,-0.291977,0.253318,0.016509,-0.097607,0.184147,0.029037,-0.125325,...,-0.107115,-0.011186,0.009765,-0.084305,0.069489,0.154950,0.442492,0.085907,small,Mike
2,NF0014_T1,C11-1,-0.245149,-0.331004,0.227504,-0.010081,-0.153452,0.186263,0.018825,-0.117956,...,-0.122888,-0.011035,0.006764,-0.073882,0.056722,0.142125,0.444284,0.087057,small,Mike
3,NF0014_T1,C11-2,-0.252592,-0.328891,0.218564,-0.032857,-0.171053,0.201910,0.015939,-0.130160,...,-0.123045,-0.011132,0.002153,-0.078759,0.066817,0.144077,0.465682,0.101818,small,Mike
4,NF0014_T1,C2-1,-0.213654,-0.348314,0.227645,-0.011737,-0.186671,0.208217,-0.001262,-0.114908,...,-0.140545,-0.011159,0.005114,-0.117489,0.047207,0.131581,0.476471,0.087278,small,Mike
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
143,NF0014_T2,C5-6,-0.294294,-0.331082,0.218519,-0.010784,-0.177558,0.190228,0.018600,-0.123144,...,-0.097683,-0.011064,0.001831,-0.072271,0.067838,0.154291,0.436165,0.112196,dissociated,Mike
144,NF0014_T2,C5-7,-0.402969,-0.368304,0.220878,0.007899,-0.157321,0.125653,0.031363,-0.081504,...,-0.111580,-0.011178,0.003155,-0.083915,0.044134,0.141032,0.442127,0.085489,dissociated,Mike
145,NF0014_T2,C6-1,-0.320862,-0.333989,0.223607,-0.003560,-0.157684,0.186301,0.017634,-0.113109,...,-0.107170,-0.011085,0.000660,-0.083736,0.052167,0.137304,0.459293,0.096390,globular,Mike
146,NF0014_T2,C6-2,-0.307637,-0.332819,0.221561,-0.013901,-0.164055,0.187039,0.018587,-0.103550,...,-0.108075,-0.011085,0.003191,-0.087506,0.048573,0.143679,0.441101,0.105372,small,Mike


In [8]:
# set up data splits
# train: 70%, val: 15%, test: 15%
# stratified by label only (patient is not part of the stratification key)

SPLIT_SEED = 42
TEST_FRACTION = 0.15
VAL_FRACTION = 0.15
# the validation rows are taken from what is left after the test split, so the
# requested share has to be expressed relative to that remainder (0.15 / 0.85)
VAL_FRACTION_OF_REMAINDER = VAL_FRACTION / (1 - TEST_FRACTION)

# first hold out the test set ...
train_val_df, test_df = train_test_split(
    df,
    test_size=TEST_FRACTION,
    random_state=SPLIT_SEED,
    stratify=df["label"],
)
# ... then divide the remainder into train and validation
train_df, val_df = train_test_split(
    train_val_df,
    test_size=VAL_FRACTION_OF_REMAINDER,
    random_state=SPLIT_SEED,
    stratify=train_val_df["label"],
)
print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(val_df)}")
print(f"Test size: {len(test_df)}")

Train size: 102
Validation size: 22
Test size: 22


In [9]:
# train a random forest classifier for the organoid labels

# identifier / target / annotation columns that are not model features
NON_FEATURE_COLUMNS = ["patient", "well_fov", "label", "annotator"]


def evaluate_split(model, split_df: pd.DataFrame, split_name: str) -> np.ndarray:
    """
    Description
    ----------
    Predict labels for one data split and print its classification report
    and confusion matrix.
    Parameters
    ----------
    model
        Fitted classifier exposing a ``predict`` method.
    split_df : pd.DataFrame
        Split to evaluate; must contain the columns in NON_FEATURE_COLUMNS,
        including the true ``label``.
    split_name : str
        Name used as the prefix of the printed headings.
    Returns
    -------
    np.ndarray
        Predicted labels for the rows of ``split_df``.
    """
    preds = model.predict(split_df.drop(columns=NON_FEATURE_COLUMNS))
    print(f"{split_name} Classification Report:")
    # zero_division=0 reports 0.0 for metrics whose denominator is zero (e.g. a
    # class that is never predicted) without emitting an UndefinedMetricWarning
    # for each of them
    print(classification_report(split_df["label"], preds, zero_division=0))
    print(f"{split_name} Confusion Matrix:")
    print(confusion_matrix(split_df["label"], preds))
    return preds


rf_model = RandomForestClassifier(n_estimators=1000, random_state=42)
rf_model.fit(train_df.drop(columns=NON_FEATURE_COLUMNS), train_df["label"])
val_preds = evaluate_split(rf_model, val_df, "Validation")
test_preds = evaluate_split(rf_model, test_df, "Test")

Validation Classification Report:
              precision    recall  f1-score   support

 dissociated       0.67      0.25      0.36         8
   elongated       0.00      0.00      0.00         1
    globular       0.60      0.86      0.71         7
       small       0.56      0.83      0.67         6

    accuracy                           0.59        22
   macro avg       0.46      0.49      0.43        22
weighted avg       0.58      0.59      0.54        22

Validation Confusion Matrix:
[[2 0 3 3]
 [0 0 0 1]
 [1 0 6 0]
 [0 0 1 5]]
Test Classification Report:
              precision    recall  f1-score   support

 dissociated       0.00      0.00      0.00         8
   elongated       0.00      0.00      0.00         2
    globular       0.70      1.00      0.82         7
       small       0.45      1.00      0.62         5

    accuracy                           0.55        22
   macro avg       0.29      0.50      0.36        22
weighted avg       0.33      0.55      0.40      

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
