In [2]:
from pathlib import Path
from collections import Counter
import random
import h5py
from PIL import Image
import pandas as pd
import os
import shutil
import numpy as np
from dataset import make_dataset_file
import pydicom
from pydicom import dcmread
from collections import defaultdict
import sys
import re
import matplotlib.pyplot as plt

# Notebook-wide settings: show full cell text (e.g. long file paths) in
# DataFrame output, and seed the stdlib RNG.
pd.set_option("display.max_colwidth", None)
random.seed(42)


In [7]:
# -------- KATHER CRC 100K -------- #


def generate_crc100k_data(input_dir, output_dir, seed=42, n_test=15):
    """Split Kather CRC-100K tiles into a fixed test set and an x-shot pool.

    Every ``*.png`` below ``input_dir`` (searched recursively) is assigned to
    the class given by its file-name prefix before the first ``-`` (e.g.
    ``ADI-XXXX.png`` -> ``ADI``). Per class, ``n_test`` tiles are drawn as
    test samples; all remaining tiles form the pool for x-shot examples.

    Args:
        input_dir: Folder with the PNG tiles; per-class subfolders are allowed.
        output_dir: Folder for ``test_samples.csv`` and ``x_shot_samples.csv``;
            created if it does not exist.
        seed: Seed of the local RNG used for sampling.
        n_test: Number of test tiles drawn per class.

    Returns:
        ``(test_samples_df, x_shot_samples_df, counts)``. Both frames have the
        columns ``fname``, ``label``, ``path`` where ``path`` is a one-element
        list holding the absolute tile path; ``counts`` is a ``Counter`` of
        tiles per class.

    Raises:
        ValueError: if a class has fewer than ``n_test`` tiles.
    """
    rng = random.Random(seed)
    input_dir = Path(input_dir).resolve()
    # Sorted so the draw depends only on ``seed``, not on filesystem listing order.
    crc100k_paths = sorted(input_dir.glob("**/*.png"))

    # Group the recursive listing directly; re-globbing "<label>-*.png" in the
    # top-level folder only would find nothing when tiles live in subfolders.
    paths_by_label = defaultdict(list)
    for path in crc100k_paths:
        paths_by_label[path.stem.split("-")[0]].append(path)
    counts = Counter({label: len(paths) for label, paths in paths_by_label.items()})

    columns = ["fname", "label", "path"]
    test_records, x_shot_records = [], []
    for label, label_paths in paths_by_label.items():
        test_paths = rng.sample(label_paths, n_test)
        chosen = set(test_paths)
        test_records.extend((p.stem, label, [str(p)]) for p in test_paths)
        # Keep the pool in sorted order; a set difference would make it hash-order dependent.
        x_shot_records.extend((p.stem, label, [str(p)]) for p in label_paths if p not in chosen)

    test_samples_df = pd.DataFrame(test_records, columns=columns)
    x_shot_samples_df = pd.DataFrame(x_shot_records, columns=columns)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    test_samples_df.to_csv(output_dir / "test_samples.csv", index=False)
    x_shot_samples_df.to_csv(output_dir / "x_shot_samples.csv", index=False)

    return test_samples_df, x_shot_samples_df, counts

#test_df, x_shot_ex_df, counts = generate_crc100k_data(input_dir="./data/CRC-VAL-HE-7K-png", output_dir="./Datafiles/CRC100K")

In [17]:
def generate_more_crc100k_sample_data(existing_df_path, full_samples_df_path, seed=42,
                                      output_path="./Datafiles/CRC100K/test_samples_complete.csv",
                                      n_per_class=15):
    """Draw a fresh class-balanced CRC100K sample that avoids already-used tiles.

    Rows of ``full_samples_df_path`` whose ``path`` value also appears in
    ``existing_df_path`` are dropped; ``n_per_class`` rows are then drawn per
    ``label`` and written to ``output_path``. The ``path`` columns are
    compared as the raw text pandas reads from the CSVs, so both files must
    serialize paths the same way.

    Args:
        existing_df_path: CSV of samples already in use (needs a ``path`` column).
        full_samples_df_path: CSV of all candidates (``label`` and ``path`` columns).
        seed: ``random_state`` for pandas sampling.
        output_path: Destination CSV; its parent folder is created if missing.
        n_per_class: Rows drawn per label.

    Returns:
        The sampled DataFrame.

    Raises:
        ValueError: (from pandas) if a label has fewer than ``n_per_class`` unused rows.
    """
    existing = pd.read_csv(existing_df_path)
    full = pd.read_csv(full_samples_df_path)
    unused = full[~full["path"].isin(existing["path"])]

    test_samples = unused.groupby("label").sample(n_per_class, random_state=seed)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    test_samples.to_csv(output_path, index=False)
    return test_samples


# generate_more_crc100k_sample_data("./Datafiles/CRC100K/prompt_samples.csv", "./Datafiles/CRC100K/all_samples.csv")

In [26]:
# Sanity check: the newly drawn CRC100K test samples must not overlap the prompt samples.
test = pd.read_csv("./Datafiles/CRC100K/test_samples_complete.csv")
prompt = pd.read_csv("./Datafiles/CRC100K/prompt_samples.csv")

overlap = set(test.fname) & set(prompt.fname)
assert not overlap, f"{len(overlap)} file(s) are in both test and prompt samples, e.g. {sorted(overlap)[:5]}"
overlap

set()

In [10]:
# -------- PATCH CAMELYON -------- #


def generate_pcam_data(input_dir, output_dir, seed=42, n_per_class=15):
    """Export the PatchCamelyon test split to PNGs and sample a balanced subset.

    Looks below ``input_dir`` for ``*.h5`` files whose names contain
    ``test_x`` (dataset ``"x"``: images) and ``test_y`` (dataset ``"y"``:
    labels). Every image is saved as ``<TUM|NORM>-PCAM-<index>.png`` in
    ``<input_dir>/../full_imgs``; the list of all images is written to
    ``<input_dir>/../pcam_full_samples.csv``.

    Args:
        input_dir: Folder containing the PCam ``.h5`` files.
        output_dir: Folder for ``pcam_test_samples.csv``; created if missing.
        seed: ``random_state`` for pandas sampling.
        n_per_class: Images drawn per label.

    Returns:
        The sampled DataFrame (columns ``fname``, ``label``, ``path``).

    Raises:
        FileNotFoundError: if no matching ``test_x``/``test_y`` file exists.
    """
    input_dir = Path(input_dir).resolve()
    h5_paths = sorted(input_dir.rglob("*.h5"))

    def find_h5(tag):
        """Return the first .h5 file whose name contains ``tag``."""
        matches = [p for p in h5_paths if tag in p.name]
        if not matches:
            raise FileNotFoundError(f"No .h5 file containing '{tag}' found under {input_dir}")
        return matches[0]

    # NOTE: both arrays are loaded fully into memory.
    with h5py.File(find_h5("test_x"), "r") as h5imgs:
        images = h5imgs["x"][:]
    with h5py.File(find_h5("test_y"), "r") as h5labels:
        labels = h5labels["y"][:].flatten()

    # The export folder is wiped first so a re-run never mixes in tiles from
    # an earlier export.
    img_dir = input_dir.parent / "full_imgs"
    if img_dir.exists():
        shutil.rmtree(img_dir)
    img_dir.mkdir(parents=True)

    label_map = {1: "TUM", 0: "NORM"}
    records = []
    for sidx, (img, label) in enumerate(zip(images, labels, strict=True)):
        label_name = label_map[int(label)]
        fname = f"{label_name}-PCAM-{sidx}"
        save_name = f"{img_dir}/{fname}.png"
        Image.fromarray(img).save(save_name)
        records.append([fname, label_name, [save_name]])

    full_samples = pd.DataFrame(records, columns=["fname", "label", "path"])
    full_samples.to_csv(input_dir.parent / "pcam_full_samples.csv", index=False)
    test_samples = full_samples.groupby("label").sample(n_per_class, random_state=seed)

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    test_samples.to_csv(f"{output_dir}/pcam_test_samples.csv", index=False)
    return test_samples

# test_samples = generate_pcam_data("./data/PCam/h5s", output_dir="./Datafiles/PCam")

In [5]:
def generate_more_pcam_sample_data(existing_df_path, full_samples_df_path, seed=42,
                                   output_path="./Datafiles/PCam/pcam_test_samples2.csv",
                                   n_per_class=15):
    """Draw a fresh class-balanced PCam sample that avoids already-used images.

    Rows of ``full_samples_df_path`` whose ``path`` value also appears in
    ``existing_df_path`` are dropped; ``n_per_class`` rows are then drawn per
    ``label`` and written to ``output_path``. The ``path`` columns are
    compared as the raw text pandas reads from the CSVs, so both files must
    serialize paths the same way.

    Args:
        existing_df_path: CSV of samples already in use (needs a ``path`` column).
        full_samples_df_path: CSV of all candidates (``label`` and ``path`` columns).
        seed: ``random_state`` for pandas sampling.
        output_path: Destination CSV; its parent folder is created if missing.
        n_per_class: Rows drawn per label.

    Returns:
        The sampled DataFrame.

    Raises:
        ValueError: (from pandas) if a label has fewer than ``n_per_class`` unused rows.
    """
    existing = pd.read_csv(existing_df_path)
    full = pd.read_csv(full_samples_df_path)
    unused = full[~full["path"].isin(existing["path"])]

    test_samples = unused.groupby("label").sample(n_per_class, random_state=seed)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    test_samples.to_csv(output_path, index=False)
    return test_samples

# generate_more_pcam_sample_data("./Datafiles/PCam/pcam_test_samples.csv",
#                           "./data/PCam/pcam_full_samples.csv")

In [5]:
# df = pd.read_csv("...csv")
# df.rename(columns={"label": "orig_label"}, inplace=True)
# df["label"] = df.orig_label.map({1: "TUM", 0: "NORM"})
# df.to_csv("....csv", index=False)

In [4]:
# -------- MHIST -------- #

def prepend_labels_to_imgs(input_dir, annotations_path):
    """Prefix each MHIST image file name with its majority-vote label.

    ``foo.png`` annotated as ``HP`` becomes ``HP_foo.png``; the ``Image Name``
    column of the annotations CSV is updated to match and written back in
    place. Files whose name already starts with ``"<label>_"`` are left
    alone, so running the function twice does not double-prefix.

    Args:
        input_dir: Folder with the MHIST ``.png`` images.
        annotations_path: CSV with ``Image Name`` and ``Majority Vote Label``.

    Raises:
        KeyError: if an image in ``input_dir`` has no row in the annotations.
    """
    annotations = pd.read_csv(annotations_path)
    # Materialize the listing before renaming files inside the same folder.
    for img_path in sorted(Path(input_dir).glob("*.png")):
        mask = annotations["Image Name"] == img_path.name
        if not mask.any():
            raise KeyError(f"{img_path.name} not found in {annotations_path}")
        prefix = annotations.loc[mask, "Majority Vote Label"].iloc[0]
        if img_path.name.startswith(f"{prefix}_"):
            continue  # already renamed by an earlier run
        new_name = f"{prefix}_{img_path.name}"
        img_path.rename(img_path.with_name(new_name))
        annotations.loc[mask, "Image Name"] = new_name
    annotations.to_csv(annotations_path, index=False)
    print("Done.")


def generate_mhist_data(input_dir, output_dir, seed=42, simplify=True,
                        annotations_path="./data/MHIST/annotations.csv",
                        n_per_class=15, full_samples_path=None):
    """Sample a balanced HP/SSA subset of MHIST and export the full image list.

    Image names are expected to carry their label as prefix (``HP_...`` /
    ``SSA_...``, as produced by ``prepend_labels_to_imgs``); each sample's
    label is read from that prefix.

    Args:
        input_dir: Folder containing the MHIST ``.png`` images.
        output_dir: Folder for ``mhist_samples.csv``; created if missing.
        seed: Seed of the local RNG used for sampling.
        simplify: If True, sample only unambiguous annotated cases (HP with
            0/7 SSA votes, SSA with 7/7); otherwise sample from all files in
            ``input_dir`` by name prefix.
        annotations_path: MHIST annotations CSV (``Image Name``,
            ``Majority Vote Label``, ``Number of Annotators who Selected SSA (Out of 7)``).
        n_per_class: Images drawn per class.
        full_samples_path: Destination of the list of all annotated images;
            defaults to ``<input_dir>/../mhist_full_samples.csv``.

    Returns:
        The sampled DataFrame with columns ``fname``, ``label``, ``path``.
    """
    rng = random.Random(seed)
    input_dir = Path(input_dir).resolve()
    # Read unconditionally: the full-sample export below needs it even when
    # simplify=False (previously a NameError in that case).
    annotations = pd.read_csv(annotations_path)

    if simplify:
        hp_imgs = annotations.query("`Majority Vote Label` == 'HP' and `Number of Annotators who Selected SSA (Out of 7)` == 0")["Image Name"].tolist()
        ssa_imgs = annotations.query("`Majority Vote Label` == 'SSA' and `Number of Annotators who Selected SSA (Out of 7)` == 7")["Image Name"].tolist()
    else:
        # Sorted so the draw depends only on ``seed``, not on listing order.
        paths = sorted(input_dir.glob("*.png"))
        hp_imgs = [p for p in paths if p.name.startswith("HP")]
        ssa_imgs = [p for p in paths if p.name.startswith("SSA")]

    samples = rng.sample(hp_imgs, n_per_class) + rng.sample(ssa_imgs, n_per_class)
    samples_df = pd.DataFrame(
        [[Path(p).stem, Path(p).stem.split("_")[0], [os.path.join(input_dir, p)]] for p in samples],
        columns=["fname", "label", "path"],
    )
    os.makedirs(output_dir, exist_ok=True)
    samples_df.to_csv(f"{output_dir}/mhist_samples.csv", index=False)

    image_names = annotations["Image Name"]
    full_samples = pd.DataFrame({
        "fname": image_names.apply(lambda x: Path(x).stem),
        "label": image_names.apply(lambda x: str(x).split("_")[0]),
        "path": image_names.apply(lambda x: [os.path.join(input_dir, x)]),
    })
    if full_samples_path is None:
        # Next to the image folder (like pcam_full_samples.csv), where the
        # generate_more_mhist_sample_data usage expects it.
        full_samples_path = input_dir.parent / "mhist_full_samples.csv"
    full_samples.to_csv(full_samples_path, index=False)
    return samples_df

# prepend_labels_to_imgs("./data/MHIST/images", "./data/MHIST/annotations.csv")
# generate_mhist_data("./data/MHIST/images", "./Datafiles/MHIST")

In [19]:
def generate_more_mhist_sample_data(existing_df_path, full_samples_df_path, seed=42,
                                    output_path="./Datafiles/MHIST/test_samples.csv",
                                    n_per_class=15):
    """Draw a fresh class-balanced MHIST sample that avoids already-used images.

    Rows of ``full_samples_df_path`` whose ``path`` value also appears in
    ``existing_df_path`` are dropped; ``n_per_class`` rows are then drawn per
    ``label`` and written to ``output_path``. The ``path`` columns are
    compared as the raw text pandas reads from the CSVs, so both files must
    serialize paths the same way.

    Args:
        existing_df_path: CSV of samples already in use (needs a ``path`` column).
        full_samples_df_path: CSV of all candidates (``label`` and ``path`` columns).
        seed: ``random_state`` for pandas sampling.
        output_path: Destination CSV; its parent folder is created if missing.
        n_per_class: Rows drawn per label.

    Returns:
        The sampled DataFrame.

    Raises:
        ValueError: (from pandas) if a label has fewer than ``n_per_class`` unused rows.
    """
    existing = pd.read_csv(existing_df_path)
    full = pd.read_csv(full_samples_df_path)
    unused = full[~full["path"].isin(existing["path"])]

    test_samples = unused.groupby("label").sample(n_per_class, random_state=seed)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    test_samples.to_csv(output_path, index=False)
    return test_samples

# generate_more_mhist_sample_data("./Datafiles/MHIST/mhist_samples.csv",
#                           "./data/MHIST/mhist_full_samples.csv")

In [1]:
# Rename the MSSI tiles so their label can be read from the file name; this
# makes label extraction for k-NN sampling easier. The tiles originally sit
# in two folders, MSIMUT and MSS:
#   - images in MSIMUT get the prefix "MSI"
#   - images in MSS get the prefix "MSS"

# Set your paths here
msimut_path = ""
mss_path = ""


def prepend_labels_to_imgs(input_dir, prefix):
    """Rename every ``*.png`` directly in ``input_dir`` to ``<prefix>_<name>``.

    Files that already start with ``<prefix>_`` are skipped, so re-running is
    safe. NOTE: this redefinition shadows the MHIST ``prepend_labels_to_imgs``
    (which takes an annotations CSV) defined in an earlier cell.

    Raises:
        ValueError: if ``input_dir`` is empty; ``Path("")`` would otherwise
            point at the current working directory.
    """
    if not input_dir:
        raise ValueError("input_dir is empty; set the folder path first")
    # Materialize the listing before renaming inside the same folder.
    for img_path in sorted(Path(input_dir).glob("*.png")):
        if img_path.name.startswith(f"{prefix}_"):
            continue  # already renamed by an earlier run
        img_path.rename(img_path.with_name(f"{prefix}_{img_path.name}"))
    print("Done.")


# Guarded so "Run All" with unset paths does not rename files in the CWD.
if msimut_path and mss_path:
    prepend_labels_to_imgs(msimut_path, "MSI")
    prepend_labels_to_imgs(mss_path, "MSS")
else:
    print("Set msimut_path and mss_path before renaming the MSSI images.")

In [22]:
# -------- CMMD -------- #
# !{sys.executable} -m pip install pydicom


def process_metadata(clini_path, n_per_class=15, seed=42):
    """Read the CMMD clinical spreadsheet and draw a class-balanced sample.

    Args:
        clini_path: Path to the clinical-data ``.xlsx`` file (read with openpyxl).
        n_per_class: Rows drawn per ``classification`` value.
        seed: ``random_state`` passed to pandas for reproducible sampling.

    Returns:
        DataFrame with ``n_per_class`` rows per ``classification`` value.

    Raises:
        ValueError: (from pandas) if a class has fewer than ``n_per_class`` rows.
    """
    clini_df = pd.read_excel(clini_path, engine="openpyxl")
    return clini_df.groupby("classification").sample(n_per_class, random_state=seed)


def process_dicoms(image_path, dicom_df, output_folder, verbose=False):
    """Convert the sampled CMMD DICOMs of the requested breast side to PNG.

    For every row of ``dicom_df``, all ``*.dcm`` files below ``row.path`` are
    read. Those whose ``ImageLaterality`` equals ``row.LeftRight`` and that
    contain pixel data are saved as PNG under ``output_folder``, mirroring
    their location relative to ``image_path.parent``.

    Args:
        image_path: Root folder of the CMMD DICOM tree (``str`` or ``Path``).
        dicom_df: Frame with columns ``ID1``, ``LeftRight``, ``classification``, ``path``.
        output_folder: Destination root for the PNGs (``str`` or ``Path``).
        verbose: If True, display every converted image.

    Returns:
        ``defaultdict`` mapping ``ID1`` to
        ``{"path": [sorted PNG paths as str], "label": classification}``.
    """
    image_path = Path(image_path)
    output_folder = Path(output_folder)

    def load_and_save_dicom(dicom_file, dicom_lr):
        """Save one DICOM as PNG if it shows side ``dicom_lr``; return its path as str, else None."""
        ds = pydicom.dcmread(dicom_file)
        # Skip the other breast (and files without a laterality tag).
        if getattr(ds, "ImageLaterality", None) != dicom_lr:
            return None
        # Check for the raw element: hasattr(ds, "pixel_array") would decode
        # the pixel data once more just to test for it.
        if "PixelData" not in ds:
            return None
        pixels = ds.pixel_array
        if verbose:
            plt.imshow(pixels, cmap=plt.cm.bone)
            plt.axis("off")
            plt.show()

        relative_path = dicom_file.relative_to(image_path.parent)
        output_path = output_folder / relative_path.with_suffix(".png")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        Image.fromarray(pixels).save(output_path)
        # Return a str: Path objects would be written to the CSV as "PosixPath('...')",
        # unlike the plain path strings used for the other datasets.
        return str(output_path)

    cmmd_test_samples = defaultdict(lambda: {"path": [], "label": ""})
    for _, sample in dicom_df.iterrows():
        # NOTE(review): if the same ID1 is sampled twice (e.g. both breasts of one
        # patient), the label is overwritten and the paths are merged — verify.
        entry = cmmd_test_samples[sample.ID1]
        entry["label"] = sample.classification
        for dicom_file in Path(sample.path).rglob("*.dcm"):
            output_path = load_and_save_dicom(dicom_file, sample.LeftRight)
            if output_path is not None:
                entry["path"].append(output_path)
        entry["path"].sort()

    return cmmd_test_samples


def generate_cmmd_data(clini_path, image_path, output_dir, csv_dir="./Datafiles/CMMD"):
    """Sample CMMD cases, convert their DICOMs to PNG and write the sample CSV.

    Args:
        clini_path: Clinical-data ``.xlsx`` passed to ``process_metadata``.
        image_path: Root of the CMMD DICOM tree; one sub-folder per ``ID1``.
        output_dir: Destination root for the converted PNGs.
        csv_dir: Folder for ``cmmd_sampled_data.csv``; created if missing.

    Returns:
        DataFrame with columns ``fname`` (the ``ID1``), ``label``, ``path``.
    """
    sampled_df = process_metadata(clini_path)
    # .assign builds a new frame, avoiding SettingWithCopyWarning from writing
    # a new column into a column slice.
    dicom_df = sampled_df[["ID1", "LeftRight", "classification"]].assign(
        path=lambda df: df["ID1"].map(lambda x: f"{image_path}/{x}")
    )
    cmmd_test_samples = process_dicoms(image_path, dicom_df, output_dir, verbose=False)
    sampled_df = (
        pd.DataFrame.from_dict(cmmd_test_samples, orient="index")
        .reset_index(names="fname")
        .reindex(columns=["fname", "label", "path"])
    )
    os.makedirs(csv_dir, exist_ok=True)
    sampled_df.to_csv(os.path.join(csv_dir, "cmmd_sampled_data.csv"), index=False)
    return sampled_df


# CMMD inputs (clinical spreadsheet, DICOM tree) and the PNG output folder.
_CMMD_ROOT = "./data/CMMD"
clini_path = f"{_CMMD_ROOT}/CMMD_clinicaldata_revision.xlsx"
image_path = Path(_CMMD_ROOT, "CMMD")
output_dir = Path(_CMMD_ROOT, "selected_imgs")
# s = generate_cmmd_data(clini_path, image_path, output_dir)

In [21]:
# -------- MSSI -------- #

def generate_mssi_data(input_mss, input_msi, output_dir, seed=42, n_pats=15, n_tiles=10):
    """Sample patients and tiles for the MSI-vs-MSS task.

    Tiles (``*.png``, searched recursively) are grouped by the TCGA patient ID
    (regex ``TCGA-..-....``) found in their path. ``n_pats`` patients are
    drawn per class and, for each, up to ``n_tiles`` of their tiles.

    Args:
        input_mss: Folder with the MSS tiles.
        input_msi: Folder with the MSI tiles.
        output_dir: Folder for ``mssi_samples.csv``; created if missing.
        seed: Seed for both the patient draw (pandas) and the tile draw.
        n_pats: Patients drawn per class.
        n_tiles: Maximum tiles kept per patient.

    Returns:
        ``(sampled_df, sample_indices_per_row)``. ``sampled_df`` is indexed by
        patient with columns ``path``, ``label``, ``patient``, ``fname``
        (``path``/``fname`` are parallel lists of the chosen tiles).
        ``sample_indices_per_row`` holds, per sampled patient, the indices of
        the chosen tiles into that patient's sorted tile list.

    Raises:
        ValueError: if a tile path contains no TCGA patient ID.
    """
    # Local RNG seeded with ``seed`` (the parameter used to be ignored in favour of 42).
    rng = random.Random(seed)
    patient_id_regex = re.compile("TCGA-..-....")

    def tiles_by_patient(root, label):
        """Return one row per patient (index: patient) with its sorted tile paths."""
        grouped = defaultdict(list)
        # Single pass over sorted paths: tile indices, and hence the sampled
        # tiles, no longer depend on filesystem listing order.
        for path in sorted(Path(root).resolve().glob("**/*.png")):
            match = patient_id_regex.search(str(path))
            if match is None:
                raise ValueError(f"No TCGA patient ID in tile path: {path}")
            grouped[match.group()].append(str(path))
        records = [(patient, paths, label) for patient, paths in sorted(grouped.items())]
        frame = pd.DataFrame(records, columns=["patient", "path", "label"])
        return frame.set_index("patient", drop=False)[["path", "label", "patient"]]

    df = pd.concat([tiles_by_patient(input_mss, "MSS"), tiles_by_patient(input_msi, "MSI")])
    df["fname"] = df["path"].apply(lambda paths: [Path(fp).stem for fp in paths])

    # Sample n_pats patients per class.
    sampled_df = df.groupby("label").sample(n=n_pats, random_state=seed, replace=False)

    # Sample up to n_tiles tile indices per patient, in row order.
    def sample_tiles(paths):
        return rng.sample(range(len(paths)), min(n_tiles, len(paths)))

    sample_indices_per_row = sampled_df["path"].apply(sample_tiles)
    new_paths = [[paths[i] for i in idxs] for paths, idxs in zip(sampled_df["path"], sample_indices_per_row)]
    new_fnames = [[names[i] for i in idxs] for names, idxs in zip(sampled_df["fname"], sample_indices_per_row)]
    # Whole-column replacement instead of per-cell .at writes of lists.
    sampled_df = sampled_df.assign(
        path=pd.Series(new_paths, index=sampled_df.index, dtype=object),
        fname=pd.Series(new_fnames, index=sampled_df.index, dtype=object),
    )

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    sampled_df.to_csv(f"{output_dir}/mssi_samples.csv", index=False)

    return sampled_df, sample_indices_per_row


### Set your paths here: folders with the MSI (MSIMUT) and MSS tiles.
msi_path = ""
mss_path = ""

# Empty paths would make generate_mssi_data glob the current working directory
# (and crash on non-TCGA files), so only run once both are configured.
if msi_path and mss_path:
    df, idx = generate_mssi_data(mss_path, msi_path, "./Datafiles/MSSI")
    print(df.shape)
else:
    print("Set msi_path and mss_path to run generate_mssi_data.")

In [None]:
#### merge MSI Mut and MSS train and test data

# import os
# import shutil

# src = "/Users/dykeferber/Downloads/MSS 2"
# dest = "/Users/dykeferber/Desktop/GPT4VMed/data/MSSI/MSS"

# for f in os.listdir(src):
#     shutil.move(os.path.join(src, f), dest)

# src = "/Users/dykeferber/Downloads/MSIMUT 2"
# dest = "/Users/dykeferber/Desktop/GPT4VMed/data/MSSI/MSIMUT"

# for f in os.listdir(src):
#     shutil.move(os.path.join(src, f), dest)

# msst = os.listdir("/Users/dykeferber/Desktop/GPT4VMed/data/MSSI/MSS")
# msit = os.listdir("/Users/dykeferber/Desktop/GPT4VMed/data/MSSI/MSIMUT")

# print(len(msst), len(msit))

# msstest = os.listdir("/Users/dykeferber/Downloads/MSS 2")
# msitest = os.listdir("/Users/dykeferber/Downloads/MSIMUT 2")

# print(len(msstest), len(msitest))

# assert len(os.listdir("/Users/dykeferber/Desktop/GPT4VMed/data/MSSI/MSS")) == len(msst+msstest)
# assert len(os.listdir("/Users/dykeferber/Desktop/GPT4VMed/data/MSSI/MSIMUT")) == len(msit+msitest)