In [None]:
from typing import List, Tuple

import medical_image_segmentation.analyze_data.utils as utils
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter, defaultdict

import json
import os

import pydicom
import random


In [None]:
import matplotlib as mpl
# Set the `image.cmap` rcParam to "cool" so it becomes the default colormap for this notebook's plots.
mpl.rc('image', cmap='cool')

def get_discrete_colors(n: int, solid=False) -> List[Tuple[float, float, float, float]]:
    """
    Gets the discrete colors for all plots.

    Parameters
    ----------
    n : int
        The number of items to get colors for.
    solid : bool [default: False]
        If True, every item gets the same color ("darkorchid"). Otherwise the colors are
        sampled at `n` evenly spaced points along the "cool" colormap.

    Returns
    -------
    List[Tuple[float, float, float, float]]
        List of `n` RGBA tuples.
    """
    if solid:
        return [mpl.colors.to_rgba("darkorchid")] * n

    # Calling the colormap on an array returns an (n, 4) ndarray of RGBA rows; convert it so
    # both branches return the documented list of plain tuples.
    gradient = mpl.colormaps["cool"](np.linspace(0, 1, n))
    return [tuple(rgba) for rgba in gradient.tolist()]

In [None]:
def plot_image_shapes(image_shapes: List[List[int]]):
    """
    Plot image sizes on a scatter plot, with more frequent sizes represented by larger points.

    Each distinct shape becomes one point. Its marker size grows with how often the shape
    occurs, and its color encodes sqrt(shape[0] * shape[1]).

    Parameters
    ----------
    image_shapes : List[List[int]]
        List of width, height pairs.

    Raises
    ------
    ValueError
        If `image_shapes` is empty, since there would be nothing to plot.
    """
    # Fail early with a clear message instead of an opaque unpacking error further down.
    if len(image_shapes) == 0:
        raise ValueError("image_shapes is empty; there is nothing to plot.")

    # Lists are unhashable, so each shape is converted to a tuple before counting.
    shape_counter = Counter(tuple(shape) for shape in image_shapes)

    widths = np.array([shape[0] for shape in shape_counter])
    heights = np.array([shape[1] for shape in shape_counter])
    counts = np.array(list(shape_counter.values()))

    # Marker sizes are proportional to frequency: the most common shape gets `max_size`,
    # and no marker is allowed to shrink below `min_size`.
    max_size = 10_000
    min_size = 1
    point_sizes = np.clip(counts / np.max(counts) * max_size, min_size, max_size)

    fig, ax = plt.subplots(figsize=(10, 6))
    scatter = ax.scatter(widths, heights, s=point_sizes, c=np.sqrt(widths * heights), alpha=0.8)
    # The square root of an area is a length, so the unit is pixels (not pixels^2).
    fig.colorbar(scatter, ax=ax, label="Sqrt of Area (pixels)")
    ax.set_title("Distribution of DICOM Image Dimensions")
    ax.set_xlabel("Width (pixels)")
    ax.set_ylabel("Height (pixels)")
    fig.tight_layout()

    plt.show()

In [None]:
# Text file listing the DICOM image paths to analyze (one path per line).
# NOTE(review): absolute, user-specific path — consider deriving it from a configurable data directory.
subset_image_path_list_file_path = "/scratch/gpfs/eh0560/repos/medical-image-segmentation/data/dicom_image_analysis_info/image_path_list"

In [None]:
# Read the image path list, one path per line. Blank lines are dropped so that a trailing
# newline or an empty row never becomes an empty-string path in `files`.
with open(subset_image_path_list_file_path, "r") as f:
    files = [line.strip() for line in f if line.strip()]

In [None]:
# JSON file of image dimensions keyed by file path; when it exists it is loaded instead of recomputing.
# NOTE(review): absolute, user-specific path — consider deriving it from a configurable data directory.
dimensions_output_path = "/scratch/gpfs/eh0560/repos/medical-image-segmentation/data/dicom_image_analysis_info/dicom_image_dimensions.json"

In [None]:
# Use the dimensions stored in the JSON file when it exists; otherwise compute them from the files.
if not os.path.isfile(dimensions_output_path):
    # NOTE(review): nothing in this cell writes the result to `dimensions_output_path` —
    # confirm whether the helper persists it or whether it is recomputed on every run.
    subset_dimensions = utils.get_dicom_image_dimensions(files, num_processes=8)
else:
    with open(dimensions_output_path, "r") as f:
        dimensions = json.load(f)

    # Keep only the entries for the paths in `files`; a path missing from the JSON raises KeyError.
    subset_dimensions = {path: dimensions[path] for path in files}

    # Sanity check: one entry per listed file.
    if len(files) != len(subset_dimensions):
        raise ValueError(f"subset_dimensions has length different than files. Length of files is {len(files)}, and length of subset_dimensions is {len(subset_dimensions)}")

In [None]:
# Scatter-plot the image shapes collected for the listed files.
plot_image_shapes(list(subset_dimensions.values()))

In [None]:
# Count images per dataset. The dataset name is the first path component that
# follows the "med_datasets/" marker in each file path.
dataset_counts = defaultdict(int)
for file_path in files:
    path_after_marker = file_path.split("med_datasets/")[1]
    dataset_name = path_after_marker.split("/")[0]
    dataset_counts[dataset_name] = dataset_counts[dataset_name] + 1

In [None]:
def plot_dataset_counts(dataset_dict: dict[str, int]):
    """
    Draw a bar chart of how many images each dataset contains.

    Bars are ordered by sorted dataset name and all drawn in the same solid color.

    Parameters
    ----------
    dataset_dict : dict[str, int]
        Mapping from dataset name to the number of images in that dataset.
    """
    # Iterating a dict yields its keys, so sorting it directly gives the ordered names.
    names = sorted(dataset_dict)
    counts = [dataset_dict[name] for name in names]
    bar_colors = get_discrete_colors(len(counts), solid=True)

    plt.figure(figsize=(10, 6))
    plt.bar(names, counts, color=bar_colors)
    plt.title("Distribution of DICOM Images in Each Dataset")
    plt.xlabel("Dataset")
    plt.ylabel("Frequency")
    # Slanted, right-aligned tick labels keep long dataset names readable.
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()

    plt.show()

In [None]:
# Bar chart of the number of listed images per dataset.
plot_dataset_counts(dataset_counts)

In [None]:
# Load one DICOM file for inspection. The first listed path is used, so reruns are deterministic.
# (The earlier `random.choice(files)` result was overwritten immediately and never used.)
file_path = files[0]
ds = pydicom.dcmread(file_path)
img = ds.pixel_array
print(img.dtype)

In [None]:
# Show which file was loaded.
file_path

In [None]:
# Display the dataset object returned by pydicom.dcmread.
# NOTE(review): DICOM headers may carry patient identifiers — confirm the data is de-identified before sharing outputs.
ds

In [None]:
# Report the smallest pixel value, the largest pixel value, and the span between them.
pixel_min, pixel_max = img.min(), img.max()
print(pixel_min, pixel_max, pixel_max - pixel_min)

In [None]:
# Cast to float32 rather than float16: float16 cannot represent integers above 2048 exactly and
# overflows to inf beyond 65504, which can distort 16-bit pixel data.
plt.imshow(img.astype(np.float32), cmap="gray")

In [None]:
# NOTE(review): this cell repeats the previous imshow call — consider removing one of them.
# float32 is used instead of float16, which loses integer precision above 2048 and overflows past 65504.
plt.imshow(img.astype(np.float32), cmap="gray")