In [1]:
import sys
import os

additional_paths = [
    "/home/homesOnMaster/pehrlich/dataset_preparation",
    "/home/homesOnMaster/pehrlich/dataset_preparation/lib",
]

# Make the project checkout and its lib/ folder importable. Each entry is
# user-expanded and appended only once, so re-running the cell adds no duplicates.
for _path in additional_paths:
    parent_dir = os.path.expanduser(_path)
    if parent_dir in sys.path:
        continue
    sys.path.append(parent_dir)

In [2]:
from copy import deepcopy

import numpy as np

from lib.metadata import StudyMetadata, Series
from lib.utils import flatten, remove_dupes

In [3]:
# Load the pickled study map from disk. The file presumably stores a single dict
# (study UID -> StudyMetadata) inside a 0-d object array, which .tolist()
# unwraps back into the plain dict — TODO confirm against the writer.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
study_metadata: dict[str, StudyMetadata] = np.load(
    "/home/homesOnMaster/pehrlich/dataset_preparation/osg/study_map_osg.npy",
    allow_pickle=True
).tolist()

  warn_and_log(msg)
  warn_and_log(msg)


In [None]:
import matplotlib.pyplot as plt

def plot_histogram(data, labels, title, xlabel, ylabel="Count", color='skyblue', figsize=(10, 6)):
    """
    Draw a bar chart with one bar per label and the count printed above each bar.

    Args:
        data: Sequence of counts, one per group (may be empty)
        labels: Sequence of labels, one per group, in the same order as ``data``
        title: Title for the plot
        xlabel: Label for x-axis
        ylabel: Label for y-axis (default: "Count")
        color: Color for the bars (default: 'skyblue')
        figsize: Figure size tuple (default: (10, 6))

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=figsize)
    bars = ax.bar(labels, data, color=color)

    # Vertical gap between a bar and its label: 1% of the tallest bar.
    # Computed once instead of per bar; default=0 keeps an empty ``data``
    # from raising ValueError in max().
    label_offset = max(data, default=0) * 0.01

    # Add count labels on top of bars
    for bar, count in zip(bars, data):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + label_offset,
            str(count),
            ha='center',
            va='bottom',
            fontweight='bold',
        )

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    plt.show()

In [None]:
import re

# Descriptions matching any of these are dropped before modality matching
# (reports, scouts/localizers and reformatted series).
ignore_patterns = [
    r"(?i).*befund.*",
    r"(?i).*survey.*",
    r"(?i).*localizer.*",
    r"(?i).*iso.*",
    r"(?i).*MPR.*",
]

# Modality key -> regexes; a description belongs to the modality if ANY regex
# is found in it. All rules are case-insensitive. Lookaheads `(?=.*x)` require
# a token anywhere in the description, `(?!.*km)` rejects descriptions that
# contain "km". Each rule is listed once per modality.
series_matching_patterns = {
    "cor_t1": [
        r"(?i).*t1_tse_cor(?!.*km).*",
        r"(?i)(?=.*t1w)(?=.*cor)(?=.*tse)(?!.*km).*",
    ],
    "cor_t1_km": [
        r"(?i)(?=.*t1)(?=.*cor)(?=.*km).*",
    ],
    "cor_pd": [
        r"(?i)(?=.*pd)(?=.*cor)(?!.*km).*",
        r"(?i).*PDW_SPIR_COR.*",
    ],
    # NOTE(review): unlike cor_t1, these rules do not reject "km", so a
    # description containing both a sag_t1 token and "km" matches sag_t1 as
    # well as sag_t1_km — confirm this is intended.
    "sag_t1": [
        r"(?i).*t1_tse_sag_DRB.*",
        r"(?i).*T1W_TSE_SAG.*",
        r"(?i).*T1W_aTSE_SAG.*",
        r"(?i).*t1_tse_sag.*",
    ],
    "sag_t1_km": [
        r"(?i)(?=.*t1)(?=.*sag)(?=.*km).*",
    ],
    "sag_pd": [
        r"(?i)(?=.*pd)(?=.*sag)(?!.*km).*",
        r"(?i).*PDW_SPIR_SAG.*",
        r"(?i).*PDW_SPIR sag.*",
    ],
    "tra_t1_km": [
        r"(?i)(?=.*t1)(?=.*tra)(?=.*km).*",
    ],
    "tra_t2": [
        r"(?i).*t2_tse_tra.*",
        r"(?i)(?=.*t2)(?=.*tra)(?=.*tse)(?!.*km).*",
    ],
    "tra_pd": [
        r"(?i)(?=.*pd)(?=.*tra)(?!.*km).*",
    ]
}

# Modalities every study is checked for in the completeness/redundancy cells.
relevant_mods = ["tra_t2", "sag_pd", "cor_pd", "cor_t1"]

def contains_pattern(text, patterns):
    """Return True if any regex in ``patterns`` is found in ``text`` (``re.search``).

    Args:
        text: String to test. ``None`` (e.g. a series that has no description)
            counts as "no match" instead of raising ``TypeError``.
        patterns: Iterable of regular-expression strings.

    Returns:
        bool: True as soon as one pattern matches, otherwise False
        (also False for an empty ``patterns``).
    """
    if text is None:
        return False
    return any(re.search(pattern, text) for pattern in patterns)

def calculate_matches_in_modality(series_counts):
    """Collect, for every modality, the (description, count) entries its rules match.

    Prints one summary line per modality and returns all matched entries
    concatenated in modality order. Entries are de-duplicated within a
    modality only, so an entry matched by several modalities appears once
    per modality in the result.
    """
    counts_snapshot = deepcopy(series_counts)
    collected = []

    for modality, patterns in series_matching_patterns.items():
        # One hit list per rule, then merged and de-duplicated for the modality.
        hits_per_rule = [
            [entry for entry in counts_snapshot if contains_pattern(entry[0], [pattern])]
            for pattern in patterns
        ]
        unique_hits = remove_dupes(flatten(hits_per_rule))
        collected.extend(unique_hits)
        print(f"{modality}: Found {len(unique_hits)} rules matching {sum([x[1] for x in unique_hits])}")

    return collected

def show_matches_in_modality(series_counts: list[tuple[str, int]], mod: str) -> None:
    """Print every description matched by one modality's rules, most frequent first.

    Args:
        series_counts: ``(series_description, count)`` entries to search.
        mod: Key of ``series_matching_patterns`` whose rules are applied.

    Raises:
        KeyError: If ``mod`` is not a key of ``series_matching_patterns``.
    """
    # The input is only read (each filter builds a new list), so no copy is needed.
    matches = flatten([
        [entry for entry in series_counts if contains_pattern(entry[0], [rule])]
        for rule in series_matching_patterns[mod]
    ])
    matches = remove_dupes(matches)
    sorted_matches = sorted(matches, key=lambda x: x[1], reverse=True)
    print(f"{mod}: Found {len(matches)} rules matching {sum([x[1] for x in matches])}")
    for desc, count in sorted_matches:
        print(f"  {desc}: {count}")



In [None]:
series_description_counts = {}
failed_series_descriptions = 0
total_series = 0

# Tally how often each series description occurs across all studies.
# A study that cannot be walked (TypeError/KeyError) is counted as failed.
for study in study_metadata.values():
    try:
        for series in study["series"].values():
            total_series += 1
            # Missing or empty descriptions are pooled under 'Unknown'.
            label = series.get("series_description") or 'Unknown'
            series_description_counts[label] = series_description_counts.get(label, 0) + 1
    except (TypeError, KeyError):
        failed_series_descriptions += 1

# How many of the most frequent descriptions are listed / plotted.
TOP_N_PRINTED = 20
TOP_N_PLOTTED = 15

# Sort series descriptions by count (descending)
sorted_series_descriptions = sorted(series_description_counts.items(), key=lambda x: x[1], reverse=True)

print(f"Total number of series: {total_series}")
print(f"Number of different series descriptions: {len(series_description_counts)}")
print(f"Failed to parse series descriptions for {failed_series_descriptions} studies")
print(f"\nTop {TOP_N_PRINTED} series descriptions:")
for desc, count in sorted_series_descriptions[:TOP_N_PRINTED]:
    print(f"  {desc}: {count}")

# Plot the most frequent descriptions. The plot shows fewer entries than the
# printed list; the title is derived from the same constant as the slice so
# the two can no longer disagree.
top_plotted_series = sorted_series_descriptions[:TOP_N_PLOTTED]
descriptions = [entry[0] for entry in top_plotted_series]
counts = [entry[1] for entry in top_plotted_series]

plot_histogram(
    data=counts,
    labels=descriptions,
    title=f"Top {TOP_N_PLOTTED} Series Descriptions by Frequency",
    xlabel="Series Description",
    color='mediumpurple',
    figsize=(12, 8)
)

In [None]:

# Keep only descriptions that hit none of the ignore patterns.
filtered_series_counts: list[tuple[str, int]] = [
    (desc, count)
    for desc, count in series_description_counts.items()
    if not contains_pattern(desc, patterns=ignore_patterns)
]

# Most frequent descriptions first.
sorted_series_counts = sorted(filtered_series_counts, key=lambda entry: entry[1], reverse=True)

In [None]:
# Descriptions that none of the modality rule sets claim.
unmatched_series_counts = deepcopy(sorted_series_counts)

for rules in series_matching_patterns.values():
    unmatched_series_counts = [
        entry for entry in unmatched_series_counts
        if not contains_pattern(entry[0], rules)
    ]


matches = calculate_matches_in_modality(sorted_series_counts)

matched_series_total = sum([x[1] for x in matches])
candidate_series_total = sum([x[1] for x in sorted_series_counts])
unmatched_series_total = sum([x[1] for x in unmatched_series_counts])

# Sanity check: the matched total has to stay below the number of series seen.
assert matched_series_total < total_series

print(f"Matched {matched_series_total} from series {candidate_series_total}")
print(f"Still found {len(unmatched_series_counts)} unmatched descriptions and {unmatched_series_total} unmatched series")
for desc, count in unmatched_series_counts[:20]:
    print(f"  {desc}: {count}")

In [None]:
# Number of matched series per modality.
modality_matches = {}

for mod, rules in series_matching_patterns.items():
    # One hit list per rule, merged and de-duplicated before summing the counts.
    hits_per_rule = [
        [entry for entry in sorted_series_counts if contains_pattern(entry[0], [rule])]
        for rule in rules
    ]
    matches = remove_dupes(flatten(hits_per_rule))
    modality_matches[mod] = sum([x[1] for x in matches])

modalities = list(modality_matches)
counts = list(modality_matches.values())

plot_histogram(
    data=counts,
    labels=modalities,
    title="Series Matches by Modality",
    xlabel="Modality",
    ylabel="Number of Matched Series",
    color='lightcoral',
    figsize=(12, 8)
)

print(f"Total matched series: {sum(counts)}")

## Improve sequence mapping - Find incomplete studies


In [None]:
# find studies which do not include all major sequences 
def get_modality_series(modality: str, study_metadata: StudyMetadata) -> list[Series]:
    """Return the series of one study whose description matches ``modality``.

    Args:
        modality: Key of ``series_matching_patterns`` whose rules are applied.
        study_metadata: Metadata of a single study; its ``"series"`` mapping is scanned.

    Returns:
        The matching series in the iteration order of ``study_metadata["series"]``
        (empty list if none match).

    Raises:
        KeyError: If ``modality`` is unknown or the study has no ``"series"`` entry.
    """
    rules = series_matching_patterns[modality]
    matched_series = []
    for series in study_metadata["series"].values():
        # The description is optional (the counting cell treats a missing one as
        # 'Unknown'); a series without one cannot match and must not raise.
        description = series.get("series_description")
        if description and contains_pattern(description, rules):
            matched_series.append(series)

    return matched_series

# study UID -> relevant modalities for which no series was found
studies_with_missing_sequences: dict[str, list[str]] = {}
# study UID -> modality -> descriptions of the series found for that modality
studies_with_found_sequences: dict[str, dict[str, list[str]]] = {}

for study in study_metadata.values():
    study_uid = study["study_instance_uid"]
    for mod in relevant_mods:
        series = get_modality_series(mod, study)
        if not series:
            studies_with_missing_sequences.setdefault(study_uid, []).append(mod)
        else:
            studies_with_found_sequences.setdefault(study_uid, {})[mod] = [
                s["series_description"] for s in series
            ]


print(f"Found {len(studies_with_missing_sequences)} studies with missing sequences")
print(studies_with_missing_sequences)




In [None]:
import pandas as pd

# Build one CSV row per incomplete study: missing modalities stay empty, found
# modalities list the matched descriptions, and the remaining (non-ignored)
# descriptions of the study are attached for manual mapping.
study_correct = [
    {
        "StudyInstanceUID": study_uid,
        **dict.fromkeys(sequences),
        **studies_with_found_sequences.get(study_uid, {}),
        "AvailableDescriptions": [
            s["series_description"]
            for s in study_metadata[study_uid]["series"].values()
            if not contains_pattern(s["series_description"], ignore_patterns)
        ],
    }
    for study_uid, sequences in studies_with_missing_sequences.items()
]

column_order = [
    "StudyInstanceUID",
    "tra_t2",
    "sag_pd",
    "cor_pd",
    "cor_t1",
    "AvailableDescriptions",
]

df = pd.DataFrame(study_correct, columns=column_order)
df.to_csv("studies_with_missing_sequences.csv", index=False)



## Improve sequence mapping - Find redundant studies

In [None]:
## Improve sequence mapping - Find redundant studies

# study UID -> relevant modalities matched by more than one series
studies_with_redundant_sequences: dict[str, list[str]] = {}
# study UID -> one list of series per redundant modality, in the same order as
# the modality list above (the CSV cell zips the two together)
studies_with_redundant_sequences_series: dict[str, list[list[Series]]] = {}
# study UID -> modality -> descriptions, for modalities without redundancy
studies_with_clear_sequences: dict[str, dict[str, list[str]]] = {}

for study in study_metadata.values():
    study_uid = study["study_instance_uid"]
    # Incomplete studies are handled by the missing-sequences report.
    if study_uid in studies_with_missing_sequences:
        continue

    for mod in relevant_mods:
        series = get_modality_series(mod, study)
        if len(series) > 1:
            studies_with_redundant_sequences.setdefault(study_uid, []).append(mod)
            studies_with_redundant_sequences_series.setdefault(study_uid, []).append(series)
        else:
            studies_with_clear_sequences.setdefault(study_uid, {})[mod] = [
                s["series_description"] for s in series
            ]


# This cell reports redundant sequences, not missing ones.
print(f"Found {len(studies_with_redundant_sequences)} studies with redundant sequences")
print(studies_with_redundant_sequences)

In [None]:
# DEBUG CELL
# print(list(studies_with_redundant_sequences_series.keys()))
# print([list(map(lambda x: x["series_instance_uid"], series)) for series in studies_with_redundant_sequences_series.values()])

import pandas as pd

# Build one CSV row per study with redundant sequences: for each redundant
# modality, list the SeriesInstanceUIDs of all series that matched it.
study_correct = [
    {
        "StudyInstanceUID": study_uid,
        **{
            redundant_mod: [s["series_instance_uid"] for s in mod_series]
            for redundant_mod, mod_series in zip(redundant_mods, series_groups)
        },
    }
    for (study_uid, redundant_mods), series_groups in zip(
        studies_with_redundant_sequences.items(),
        studies_with_redundant_sequences_series.values(),
    )
]


column_order = [
    "StudyInstanceUID",
    "tra_t2",
    "sag_pd",
    "cor_pd",
    "cor_t1",
]

df = pd.DataFrame(study_correct, columns=column_order)
df.to_csv("studies_with_redundant_sequences.csv", index=False)


## Extract interesting studies 

In [4]:
interesting_studies = []

# Group the studies by contrast flag and by magnetic field strength
# (exact equality on the stored value).
km_studies = [s for s in study_metadata.values() if s["with_km"]]
low_fs_studies = [s for s in study_metadata.values() if s["magnetic_field_strength"] == 0.55]
mid_fs_studies = [s for s in study_metadata.values() if s["magnetic_field_strength"] == 1.5]
high_fs_studies = [s for s in study_metadata.values() if s["magnetic_field_strength"] == 3]

# Take up to three studies from each group; a set removes studies picked twice.
studies = km_studies[:3] + low_fs_studies[:3] + mid_fs_studies[:3] + high_fs_studies[:3]
study_uids = {s["study_instance_uid"] for s in studies}
print(len(study_uids), list(study_uids))



12 ['1.2.840.113619.6.95.31.0.3.4.1.3096.13.361415', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.361585', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.245827', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.251187', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.262522', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.251182', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.255552', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.362128', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.260426', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.263561', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.243674', '1.2.840.113619.6.95.31.0.3.4.1.3096.13.255578']
