In [1]:
import os

import pandas as pd 
import numpy as np
from sklearn.cluster import Birch
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
import networkx as nx
from PIL import Image
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import networkx as nx

from IPython.display import clear_output

from BIRCH import graph

import warnings
warnings.filterwarnings('ignore')


In [2]:
##### Section for constants #####
# Root folder under which the per-patient `patient_<id>` folders are looked up.
# It can be overridden with the MIMIC_EYE_PATH environment variable so the
# notebook is not tied to one machine; the previous hard-coded location is
# kept as the fallback so existing setups keep working unchanged.
MIMIC_EYE_PATH = os.environ.get("MIMIC_EYE_PATH", "C:\\Users\\mike8\\mimic-eye")

# Columns holding the corners of a REFLACX lesion box.
DEFAULT_REFLACX_BOX_COORD_COLS = ["xmin", "ymin", "xmax", "ymax"]
# Lesion label columns; a truthy value on a row adds the label to the box caption.
DEFAULT_REFLACX_LABEL_COLS = [
    "Pulmonary edema",
    "Enlarged cardiac silhouette",
    "Consolidation",
    "Atelectasis",
    "Pleural abnormality",
    # "Support devices",
]

# Colour used for each anatomical box, keyed by `bbox_name`.
# Boxes whose name is not listed here are not drawn.
anatomical_cmap = {
        "cardiac silhouette": "yellow",
        "left clavicle": "orange",
        "right clavicle": "orange",
        "left costophrenic angle": "royalblue",
        "right costophrenic angle": "royalblue",
        "left hilar structures": "bisque",
        "right hilar structures": "bisque",

        "left lung": "tomato",
        "right lung": "tomato",

        "trachea": "plum",
        "upper mediastinum": "darkseagreen",

        # "left lower lung zone": "lime",
        # "right lower lung zone": "lime",
        # "left mid lung zone": "slategray",
        # "right mid lung zone": "slategray",
        # "left upper lung zone": "darkviolet",
        # "right upper lung zone": "darkviolet",
    }

# Columns holding the corners of an EyeGaze anatomical box.
EYEGAZE_COORD_COLS = ["x1", "y1", "x2", "y2"]

In [3]:
# Choose clustering parameters
threshold = 0.2 # clustering threshold
branching_factor = 150 # branching factor
n_clusters = None  # number of clusters (None means BIRCH will choose automatically)

k = 50 # sparsify: maximum number of edges kept in the sparsified graph

# The sparsification step samples edges with numpy's global random generator.
# Seeding it here makes a "Restart Kernel -> Run All" produce the same graphs
# and figures every time.
SEED = 0
np.random.seed(SEED)

In [4]:
# Root output folder for the saved figures; exist_ok=True makes the cell safe to re-run.
os.makedirs("./figs",exist_ok=True)

In [5]:
# Table of cases to process, one row per case. The processing loop reads the
# `subject_id`, `study_id`, `id` and `dicom_id` columns of every row.
in_both_df = pd.read_csv("in_both.csv")

In [6]:
def gaze_df_preprocessing(df, img_width, img_height):
    """Drop gaze samples that lie outside the image or have no coordinates.

    Args:
        df: DataFrame with `x_position` and `y_position` columns.
        img_width: Largest accepted x value (bounds are inclusive).
        img_height: Largest accepted y value (bounds are inclusive).

    Returns:
        The rows of `df` whose x is in [0, img_width], whose y is in
        [0, img_height] and whose coordinates are both present. Row order
        and index labels are preserved.
    """
    x_pos = df['x_position']
    y_pos = df['y_position']

    inside_image = x_pos.between(0, img_width) & y_pos.between(0, img_height)
    has_coordinates = x_pos.notna() & y_pos.notna()

    return df[inside_image & has_coordinates]

def gaze_points_visualise_in_frame(norm_gaze_points, subcluster_labels):
    """Scatter the normalised gaze points, coloured by cluster, inside the frame.

    Args:
        norm_gaze_points: 2-D array; column 0 is x and column 1 is y.
        subcluster_labels: One cluster label per point, used as the colour.

    Returns:
        The matplotlib figure (it is also shown).
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    # Flip the y axis so that y grows downwards, as on an image.
    ax.invert_yaxis()

    xs = norm_gaze_points[:, 0]
    ys = norm_gaze_points[:, 1]
    ax.scatter(xs, ys, c=subcluster_labels, cmap="rainbow")

    # Outline of the square [-1, 1] x [-1, 1].
    frame = Rectangle((-1, -1), 2, 2, fill=False)
    ax.add_patch(frame)

    ax.set_title("BIRCH Clustering of Gaze Points")
    ax.set_xlabel("X coordinate")
    ax.set_ylabel("Y coordinate")
    plt.show()
    return fig


def gaze_points_visualise_on_cxr(
    cxr, norm_gaze_points, subcluster_labels, centre_x, centre_y
):
    """Scatter the gaze points, coloured by cluster, on top of the image `cxr`.

    Args:
        cxr: Image shown as the background.
        norm_gaze_points: 2-D array of normalised points (column 0 x, column 1 y).
        subcluster_labels: One cluster label per point, used as the colour.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)

    # Undo the normalisation to get back to image coordinates.
    pixel_x = (norm_gaze_points[:, 0] * centre_x) + centre_x
    pixel_y = (norm_gaze_points[:, 1] * centre_y) + centre_y
    ax.scatter(pixel_x, pixel_y, c=subcluster_labels, cmap="rainbow")
    return fig


def gaze_with_reflacx_lesion_bb(
    bb_df, norm_gaze_points, subcluster_labels, centre_x, centre_y, cxr=None
):
    """Draw clustered gaze points and labelled lesion boxes on an image.

    Args:
        bb_df: DataFrame with the DEFAULT_REFLACX_BOX_COORD_COLS and
            DEFAULT_REFLACX_LABEL_COLS columns, one row per box.
        norm_gaze_points: 2-D array of normalised points (column 0 x, column 1 y).
        subcluster_labels: One cluster label per point, used as the colour.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).
        cxr: Background image. When omitted, the function keeps its previous
            behaviour and loads "./BIRCH/cxr.jpg"; pass the image of the case
            being plotted so the overlay matches its background.

    Returns:
        The matplotlib figure.
    """
    if cxr is None:
        # Backward-compatible fallback. The context manager closes the file
        # handle as soon as the RGB copy has been made.
        with Image.open("./BIRCH/cxr.jpg") as cxr_file:
            cxr = cxr_file.convert("RGB")

    bb_df = bb_df[DEFAULT_REFLACX_BOX_COORD_COLS + DEFAULT_REFLACX_LABEL_COLS]

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)
    ax.scatter(
        (norm_gaze_points[:, 0] * centre_x) + centre_x,
        (norm_gaze_points[:, 1] * centre_y) + centre_y,
        c=subcluster_labels,
        cmap="rainbow",
    )

    for _, bb in bb_df.iterrows():
        xmin, ymin, xmax, ymax = bb[DEFAULT_REFLACX_BOX_COORD_COLS]
        # Caption lists every label column that is truthy on this row.
        lesion_name = ", ".join(
            label for label in DEFAULT_REFLACX_LABEL_COLS if bb[label]
        )

        ax.add_patch(
            Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color="darkorange",
                linewidth=2,
            )
        )
        ax.text(
            xmin,
            ymin,
            lesion_name,
            color="black",
            backgroundcolor="darkorange",
            fontdict={"size": 11},
        )

    return fig


def gaze_with_eyegaze_anatomical_bb(
    anatomical_df, norm_gaze_points, subcluster_labels, centre_x, centre_y, cxr=None
):
    """Draw clustered gaze points and anatomical boxes on an image.

    The boxes come from the `anatomical_df` argument. Earlier this argument
    was overwritten with a CSV of one fixed patient, so every case was drawn
    with the same boxes.

    Args:
        anatomical_df: DataFrame with the EYEGAZE_COORD_COLS columns and a
            `bbox_name` column, one row per box.
        norm_gaze_points: 2-D array of normalised points (column 0 x, column 1 y).
        subcluster_labels: One cluster label per point, used as the colour.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).
        cxr: Background image. When omitted, the function keeps its previous
            behaviour and loads "./BIRCH/cxr.jpg"; pass the image of the case
            being plotted so the overlay matches its background.

    Returns:
        The matplotlib figure.
    """
    if cxr is None:
        # Backward-compatible fallback. The context manager closes the file
        # handle as soon as the RGB copy has been made.
        with Image.open("./BIRCH/cxr.jpg") as cxr_file:
            cxr = cxr_file.convert("RGB")

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)
    ax.scatter(
        (norm_gaze_points[:, 0] * centre_x) + centre_x,
        (norm_gaze_points[:, 1] * centre_y) + centre_y,
        c=subcluster_labels,
        cmap="rainbow",
    )

    for _, bb in anatomical_df.iterrows():
        x1, y1, x2, y2 = bb[EYEGAZE_COORD_COLS]
        bbox_name = bb["bbox_name"]

        # Only regions that have a colour in anatomical_cmap are drawn.
        if bbox_name not in anatomical_cmap:
            continue

        c = anatomical_cmap[bbox_name]
        ax.add_patch(
            Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                fill=False,
                color=c,
                linewidth=2,
            )
        )
        ax.text(
            x1,
            y1,
            bbox_name,
            color="black",
            backgroundcolor=c,
            fontdict={"size": 11},
        )

    return fig


def centroid_in_frame(birch):
    """Plot the subcluster centroids of a fitted BIRCH model inside the frame.

    Args:
        birch: Object exposing `subcluster_centers_` (2-D array, column 0 x,
            column 1 y) and `subcluster_labels_` (used as the colour).

    Returns:
        The matplotlib figure (it is also shown).
    """
    fig, ax = plt.subplots(figsize=(15, 15))

    centres = birch.subcluster_centers_
    ax.scatter(
        centres[:, 0],
        centres[:, 1],
        c=birch.subcluster_labels_,
        cmap='rainbow',
        s=500,
    )
    # Flip the y axis so that y grows downwards, as on an image.
    ax.invert_yaxis()

    # Outline of the square [-1, 1] x [-1, 1].
    frame = Rectangle((-1, -1), 2, 2, fill=False)
    ax.add_patch(frame)

    ax.set_title('BIRCH Clustering of Gaze Points')
    ax.set_xlabel('X coordinate')
    ax.set_ylabel('Y coordinate')
    plt.show()
    return fig


def centroid_on_cxr(birch, centre_x, centre_y, cxr=None):
    """Plot the labelled subcluster centroids of a BIRCH model on an image.

    Args:
        birch: Object exposing `subcluster_centers_` (2-D array, column 0 x,
            column 1 y) and `subcluster_labels_`.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).
        cxr: Background image. When omitted, the function keeps its previous
            behaviour and loads './BIRCH/cxr.jpg'; pass the image of the case
            being plotted so the centroids match their background.

    Returns:
        The matplotlib figure.
    """
    if cxr is None:
        # Backward-compatible fallback. The context manager closes the file
        # handle as soon as the RGB copy has been made.
        with Image.open('./BIRCH/cxr.jpg') as cxr_file:
            cxr = cxr_file.convert("RGB")

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)

    norm_centroid_x = birch.subcluster_centers_[:, 0]
    norm_centroid_y = birch.subcluster_centers_[:, 1]

    # Undo the normalisation to get back to image coordinates.
    centroids_x = (norm_centroid_x * centre_x) + centre_x
    centroids_y = (norm_centroid_y * centre_y) + centre_y
    ax.scatter(centroids_x, centroids_y, c=birch.subcluster_labels_, cmap='rainbow', s=1000)

    # Write each subcluster label next to its centroid.
    for x, y, label in zip(centroids_x, centroids_y, birch.subcluster_labels_):
        ax.annotate(label, (x, y))

    return fig


def centroid_on_cxr_with_anatomical_bb(cxr, birch, anatomical_df, centre_x, centre_y):
    """Plot labelled BIRCH centroids and anatomical boxes on the image `cxr`.

    Args:
        cxr: Image shown as the background.
        birch: Object exposing `subcluster_centers_` and `subcluster_labels_`.
        anatomical_df: DataFrame with the EYEGAZE_COORD_COLS columns and a
            `bbox_name` column, one row per box.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)

    centres = birch.subcluster_centers_
    labels = birch.subcluster_labels_

    # Undo the normalisation to get back to image coordinates.
    pixel_x = (centres[:, 0] * centre_x) + centre_x
    pixel_y = (centres[:, 1] * centre_y) + centre_y
    ax.scatter(pixel_x, pixel_y, c=labels, cmap='rainbow', s=1000)
    for label, point in zip(labels, zip(pixel_x, pixel_y)):
        ax.annotate(label, point)

    for _, row in anatomical_df.iterrows():
        x1, y1, x2, y2 = row[EYEGAZE_COORD_COLS]
        region = row['bbox_name']

        # Regions without a colour in anatomical_cmap are skipped.
        if region not in anatomical_cmap:
            continue

        colour = anatomical_cmap[region]
        outline = Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            fill=False,
            color=colour,
            linewidth=2,
        )
        ax.add_patch(outline)
        ax.text(
            x1,
            y1,
            region,
            color="black",
            backgroundcolor=colour,
            fontdict={"size": 11},
        )
    return fig

def centroid_on_cxr_with_lision_bb(cxr, birch, bb_df, centre_x, centre_y):
    """Plot labelled BIRCH centroids and lesion boxes on the image `cxr`.

    Args:
        cxr: Image shown as the background.
        birch: Object exposing `subcluster_centers_` and `subcluster_labels_`.
        bb_df: DataFrame with the DEFAULT_REFLACX_BOX_COORD_COLS and
            DEFAULT_REFLACX_LABEL_COLS columns, one row per box.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)

    centres = birch.subcluster_centers_
    labels = birch.subcluster_labels_

    # Undo the normalisation to get back to image coordinates.
    pixel_x = (centres[:, 0] * centre_x) + centre_x
    pixel_y = (centres[:, 1] * centre_y) + centre_y
    ax.scatter(pixel_x, pixel_y, c=labels, cmap='rainbow', s=1000)
    for label, point in zip(labels, zip(pixel_x, pixel_y)):
        ax.annotate(label, point)

    for _, row in bb_df.iterrows():
        xmin, ymin, xmax, ymax = row[DEFAULT_REFLACX_BOX_COORD_COLS]
        # Caption lists every label column that is truthy on this row.
        present = [
            name for name, flag in row[DEFAULT_REFLACX_LABEL_COLS].items() if flag
        ]
        caption = ", ".join(present)

        outline = Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            fill=False,
            color='darkorange',
            linewidth=2,
        )
        ax.add_patch(outline)
        ax.text(
            xmin,
            ymin,
            caption,
            color="black",
            backgroundcolor='darkorange',
            fontdict={"size": 11},
        )
    return fig


def get_graph(gaze_df, birch, subcluster_labels, num_clusters):
    """Build the undirected cluster-transition graph of a gaze sequence.

    Consecutive gaze samples define transitions between their clusters.
    The transition counts are symmetrised and turned into a weighted graph
    with one node per subcluster.

    Unlike the earlier implementation, `gaze_df` is not modified (no
    `cluster` column is added to the caller's frame) and the transitions are
    counted with vectorised numpy calls instead of a per-row `iloc` loop.

    Args:
        gaze_df: Gaze samples in viewing order; only its length is used, to
            check that there is exactly one label per sample.
        birch: Object exposing `subcluster_centers_` (2-D array, column 0 x,
            column 1 y) and `subcluster_labels_`.
        subcluster_labels: Cluster label of every gaze sample, in the same
            order as `gaze_df`. Labels must be integers in [0, num_clusters).
        num_clusters: Size of the adjacency matrix.

    Returns:
        A `networkx.Graph` whose nodes carry `pos` (centroid), `N` (number of
        samples in the cluster) and `C` (self-transition count + 1), and
        whose edges carry `weight` (number of transitions between the two
        clusters, in either direction). Self-transitions appear as self-loops.

    Raises:
        ValueError: If the number of labels differs from the number of samples.
    """
    labels = np.asarray(subcluster_labels).astype(int)
    if len(labels) != len(gaze_df):
        raise ValueError(
            f"Expected one cluster label per gaze sample, got {len(labels)} "
            f"labels for {len(gaze_df)} samples."
        )

    # directed_adjacency[u, v] counts how often a sample in cluster u is
    # immediately followed by a sample in cluster v. np.add.at accumulates
    # correctly when the same (u, v) pair occurs several times.
    directed_adjacency = np.zeros((num_clusters, num_clusters), dtype=int)
    np.add.at(directed_adjacency, (labels[:-1], labels[1:]), 1)

    undirected_adjacency = directed_adjacency + directed_adjacency.T
    # Adding the transpose doubled the diagonal; put the original counts back.
    np.fill_diagonal(undirected_adjacency, np.diag(directed_adjacency))

    # Number of gaze samples assigned to each cluster.
    points_per_cluster = np.bincount(labels, minlength=num_clusters)

    norm_centroid_x = birch.subcluster_centers_[:, 0]
    norm_centroid_y = birch.subcluster_centers_[:, 1]

    G = nx.Graph()
    # Add nodes to the graph and set their positions
    for (i, x, y) in zip(birch.subcluster_labels_, norm_centroid_x, norm_centroid_y):
        N = int(points_per_cluster[i])
        C = undirected_adjacency[i, i] + 1
        G.add_node(i, pos=(x, y), N=N, C=C)

    # Add edges to the graph and set their weights. np.argwhere walks the
    # matrix row by row, the same order as a nested i/j loop.
    for i, j in np.argwhere(undirected_adjacency > 0):
        G.add_edge(int(i), int(j), weight=undirected_adjacency[i, j])

    return G

def graph_fig(G):
    """Draw graph `G` at its stored node positions with weighted edges.

    Args:
        G: networkx graph whose nodes have a `pos` attribute and whose edges
            have a `weight` attribute.

    Returns:
        The matplotlib figure (it is also shown).
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    # Flip the y axis so that y grows downwards, as on an image.
    ax.invert_yaxis()

    node_positions = nx.get_node_attributes(G, 'pos')
    weight_labels = nx.get_edge_attributes(G, 'weight')

    # Line width grows with the logarithm of the edge weight.
    line_widths = []
    for _, _, weight in G.edges(data='weight'):
        line_widths.append(np.log10(weight + 1))

    nx.draw_networkx(G, pos=node_positions, with_labels=True, width=line_widths)
    nx.draw_networkx_edge_labels(G, node_positions, edge_labels=weight_labels)
    plt.show()
    return fig

def graph_on_cxr(G, cxr, centre_x, centre_y):
    """Draw graph `G` on top of the image `cxr`.

    Args:
        G: networkx graph whose nodes have a normalised `pos` attribute and
            whose edges have a `weight` attribute.
        cxr: Image shown as the background.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).

    Returns:
        The matplotlib figure (it is also shown).
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)

    # Node positions converted from normalised to image coordinates.
    pixel_positions = {
        node: ((norm_x * centre_x) + centre_x, (norm_y * centre_y) + centre_y)
        for node, (norm_x, norm_y) in nx.get_node_attributes(G, 'pos').items()
    }

    # Line width grows with the logarithm of the edge weight.
    line_widths = []
    for _, _, weight in G.edges(data='weight'):
        line_widths.append(np.log10(weight + 1))

    nx.draw_networkx(
        G,
        pos=pixel_positions,
        with_labels=True,
        width=line_widths,
        edge_color="red",
    )
    plt.show()

    return fig

def sparsify(G, k, seed=None):
    """Sample at most `k` edges of `G` into a new graph.

    Every edge (u, v) gets a sampling weight

        R_uv * exp(-1 / (N_u**2 * C_u)) * exp(-1 / (N_v**2 * C_v))

    where R_uv is the effective resistance of the edge, computed from the
    pseudo-inverse of the graph Laplacian, and N / C are node attributes.
    Edges are then drawn without replacement with probability proportional
    to that weight. Self-loops have zero effective resistance and are
    therefore never sampled.

    Changes from the earlier implementation:
      * nodes are mapped to matrix rows through their position in
        `G.nodes()`, so node labels need not be 0..n-1;
      * a graph with no edges, or with self-loops only, yields the node-only
        graph instead of failing inside the sampler;
      * an unusable node weight raises `ValueError` with the offending
        values instead of printing them and raising `StopIteration`;
      * an optional `seed` makes the sampling reproducible.

    Args:
        G: networkx graph whose nodes carry `N` and `C` and whose edges
            carry `weight`.
        k: Maximum number of edges to keep.
        seed: Optional seed for a private random generator. With the default
            (None) numpy's global generator is used, as before.

    Returns:
        A new `networkx.Graph` with all nodes of `G` (attributes included)
        and the sampled edges with their original `weight`.

    Raises:
        ValueError: If `N**2 * C` is not positive for an endpoint of an edge,
            or if an edge ends up with a zero custom weight.
    """
    nodes = list(G.nodes())
    edges = list(G.edges())
    m = len(edges)

    H = nx.Graph()
    H.add_nodes_from(G.nodes(data=True))
    if m == 0:
        return H

    # Row/column of every node in the Laplacian (same order as `nodes`).
    row_of = {node: row for row, node in enumerate(nodes)}

    # .toarray() returns a plain ndarray for both scipy sparse matrices and
    # sparse arrays, whichever this networkx version produces.
    laplacian = nx.laplacian_matrix(G, nodelist=nodes).toarray().astype(float)
    laplacian_pinv = np.linalg.pinv(laplacian)

    def node_factor(node):
        """Return exp(-1 / (N**2 * C)) for `node`."""
        N = float(G.nodes[node]['N'])  # number of points in the cluster
        C = float(G.nodes[node]['C'])  # self-loop count + 1
        scale = (N ** 2) * C
        if not scale > 0:
            raise ValueError(
                f"Node {node!r} has N={N} and C={C}; N**2 * C must be positive."
            )
        return np.exp(-1.0 / scale)

    effective_resistance = np.zeros(m)
    custom_weights = np.zeros(m)
    for i, (u, v) in enumerate(edges):
        a, b = row_of[u], row_of[v]
        # (e_a - e_b)^T L+ (e_a - e_b), written out element-wise.
        effective_resistance[i] = (
            laplacian_pinv[a, a]
            + laplacian_pinv[b, b]
            - laplacian_pinv[a, b]
            - laplacian_pinv[b, a]
        )

        # custom_weights[i] = (np.exp(-(N_u**2)*C_u)**(-1)) * (np.exp(-np.power(N_v,2)*C_v)**(-1)) # according to the paper,
        custom_weights[i] = node_factor(u) * node_factor(v)

        if custom_weights[i] == 0:
            raise ValueError(
                f"Edge ({u!r}, {v!r}) has a zero custom weight "
                f"(N_u={G.nodes[u]['N']}, N_v={G.nodes[v]['N']}, "
                f"C_u={G.nodes[u]['C']}, C_v={G.nodes[v]['C']})."
            )

    combined_weights = custom_weights * effective_resistance
    total_weight = np.sum(combined_weights)
    if not total_weight > 0:
        # Only self-loops: there is nothing that can be sampled.
        return H
    p = combined_weights / total_weight

    # Sampling without replacement cannot return more edges than there are
    # edges with a non-zero probability.
    sample_size = min(k, np.count_nonzero(p))

    sampler = np.random if seed is None else np.random.RandomState(seed)
    sampled_edges = sampler.choice(m, size=sample_size, replace=False, p=p)

    for i in sampled_edges:
        u, v = edges[i]
        H.add_edge(u, v, weight=G.edges[u, v]['weight'])

    return H

def eyegaze_graph_on_cxr_with_anatomical(anatomical_df, cxr, H, centre_x, centre_y):
    """Draw anatomical boxes and graph `H` on top of the image `cxr`.

    Args:
        anatomical_df: DataFrame with the EYEGAZE_COORD_COLS columns and a
            `bbox_name` column, one row per box.
        cxr: Image shown as the background.
        H: networkx graph whose nodes have a normalised `pos` attribute and
            whose edges have a `weight` attribute.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).

    Returns:
        The matplotlib figure (it is also shown).
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)

    for _, row in anatomical_df.iterrows():
        x1, y1, x2, y2 = row[EYEGAZE_COORD_COLS]
        region = row['bbox_name']

        # Regions without a colour in anatomical_cmap are skipped.
        if region not in anatomical_cmap:
            continue

        colour = anatomical_cmap[region]
        outline = Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            fill=False,
            color=colour,
            linewidth=2,
        )
        ax.add_patch(outline)
        ax.text(
            x1,
            y1,
            region,
            color="black",
            backgroundcolor=colour,
            fontdict={"size": 11},
        )

    # Node positions converted from normalised to image coordinates.
    pixel_positions = {
        node: ((norm_x * centre_x) + centre_x, (norm_y * centre_y) + centre_y)
        for node, (norm_x, norm_y) in nx.get_node_attributes(H, 'pos').items()
    }

    # Line width grows with the logarithm of the edge weight.
    line_widths = []
    for _, _, weight in H.edges(data='weight'):
        line_widths.append(np.log10(weight + 1))

    nx.draw_networkx(
        H,
        pos=pixel_positions,
        with_labels=True,
        width=line_widths,
        edge_color="red",
    )

    plt.show()

    return fig

def reflacx_graph_on_cxr_with_lesion_bb(bb_df, cxr, H, centre_x, centre_y):
    """Draw lesion boxes and graph `H` on top of the image `cxr`.

    Args:
        bb_df: DataFrame with the DEFAULT_REFLACX_BOX_COORD_COLS and
            DEFAULT_REFLACX_LABEL_COLS columns, one row per box.
        cxr: Image shown as the background.
        H: networkx graph whose nodes have a normalised `pos` attribute and
            whose edges have a `weight` attribute.
        centre_x: Scale and offset applied to x (`x * centre_x + centre_x`).
        centre_y: Scale and offset applied to y (`y * centre_y + centre_y`).

    Returns:
        The matplotlib figure (it is also shown).
    """
    wanted_columns = DEFAULT_REFLACX_BOX_COORD_COLS + DEFAULT_REFLACX_LABEL_COLS
    bb_df = bb_df[wanted_columns]

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(cxr)

    for _, row in bb_df.iterrows():
        xmin, ymin, xmax, ymax = row[DEFAULT_REFLACX_BOX_COORD_COLS]
        # Caption lists every label column that is truthy on this row.
        present = [
            name for name, flag in row[DEFAULT_REFLACX_LABEL_COLS].items() if flag
        ]
        caption = ", ".join(present)

        outline = Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            fill=False,
            color='darkorange',
            linewidth=2,
        )
        ax.add_patch(outline)
        ax.text(
            xmin,
            ymin,
            caption,
            color="black",
            backgroundcolor='darkorange',
            fontdict={"size": 11},
        )

    # Node positions converted from normalised to image coordinates.
    pixel_positions = {
        node: ((norm_x * centre_x) + centre_x, (norm_y * centre_y) + centre_y)
        for node, (norm_x, norm_y) in nx.get_node_attributes(H, 'pos').items()
    }

    # Line width grows with the logarithm of the edge weight.
    line_widths = []
    for _, _, weight in H.edges(data='weight'):
        line_widths.append(np.log10(weight + 1))

    nx.draw_networkx(
        H,
        pos=pixel_positions,
        with_labels=True,
        width=line_widths,
        edge_color="red",
    )

    plt.show()

    return fig

In [7]:
def _save_and_close(fig, folder, filename):
    """Save `fig` as `folder/filename`, then release it and clear the cell output.

    Closing the figure after saving keeps the number of open figures (and
    the memory they hold) from growing with the number of cases processed.
    """
    fig.savefig(os.path.join(folder, filename))
    plt.close(fig)
    clear_output()


def _normalise_gaze_points(gaze_df, centre_x, centre_y):
    """Return the (x, y) gaze points of `gaze_df` scaled to the [-1, 1] frame."""
    points = gaze_df[["x_position", "y_position"]].to_numpy()
    centre = np.array([centre_x, centre_y])
    return (points - centre) / centre


def _fit_birch(norm_gaze_points):
    """Fit BIRCH with the notebook parameters; return the model and the point labels."""
    birch = Birch(
        threshold=threshold, branching_factor=branching_factor, n_clusters=n_clusters
    )
    birch.fit(norm_gaze_points)
    return birch, birch.predict(norm_gaze_points)


for _idx, data in in_both_df.iterrows():
    patient_id = data['subject_id']
    study_id = data['study_id']
    reflacx_id = data['id']
    dicom_id = data['dicom_id']

    patient_folder = os.path.join(MIMIC_EYE_PATH, f"patient_{patient_id}")
    save_folder_name = os.path.join("figs", f"{patient_id}_{study_id}_{dicom_id}")
    os.makedirs(save_folder_name, exist_ok=True)

    # ---- image and its size ----
    cxr_meta_df = pd.read_csv(os.path.join(patient_folder, "CXR-JPG", "cxr_meta.csv"))
    cxr_path = os.path.join(patient_folder, "CXR-JPG", f"s{study_id}", f"{dicom_id}.jpg")
    # The context manager closes the file once the RGB copy has been made.
    with Image.open(cxr_path) as cxr_file:
        cxr = cxr_file.convert("RGB")

    # NOTE(review): the size comes from the FIRST row of cxr_meta.csv, not
    # from the row of `dicom_id` -- confirm that this file has a single row
    # per patient, or select the matching row.
    img_width, img_height = cxr_meta_df["Columns"][0], cxr_meta_df["Rows"][0]
    centre_x = img_width / 2
    centre_y = img_height / 2

    # ---- gaze data ----
    eyegaze_gaze_df = pd.read_csv(os.path.join(patient_folder, "EyeGaze", "eye_gaze.csv"))
    eyegaze_gaze_df = eyegaze_gaze_df[["X_ORIGINAL", "Y_ORIGINAL"]].rename(
        columns={"X_ORIGINAL": "x_position", "Y_ORIGINAL": "y_position"}
    )
    eyegaze_gaze_df = gaze_df_preprocessing(eyegaze_gaze_df, img_width, img_height)

    reflacx_gaze_df = pd.read_csv(
        os.path.join(patient_folder, "REFLACX", "gaze_data", reflacx_id, "gaze.csv")
    )
    reflacx_gaze_df = gaze_df_preprocessing(reflacx_gaze_df, img_width, img_height)

    eyegaze_norm_gaze_points = _normalise_gaze_points(eyegaze_gaze_df, centre_x, centre_y)
    reflacx_norm_gaze_points = _normalise_gaze_points(reflacx_gaze_df, centre_x, centre_y)

    # ---- clustering ----
    eyegaze_birch, eyegaze_subcluster_labels = _fit_birch(eyegaze_norm_gaze_points)
    reflacx_birch, reflacx_subcluster_labels = _fit_birch(reflacx_norm_gaze_points)

    # ---- gaze point figures ----
    _save_and_close(
        gaze_points_visualise_in_frame(eyegaze_norm_gaze_points, eyegaze_subcluster_labels),
        save_folder_name,
        "eyegaze_gp_frame_fig.png",
    )
    _save_and_close(
        gaze_points_visualise_in_frame(reflacx_norm_gaze_points, reflacx_subcluster_labels),
        save_folder_name,
        "reflacx_gp_frame_fig.png",
    )
    _save_and_close(
        gaze_points_visualise_on_cxr(
            cxr, eyegaze_norm_gaze_points, eyegaze_subcluster_labels, centre_x, centre_y
        ),
        save_folder_name,
        "eyegaze_gp_cxr_fig.png",
    )
    _save_and_close(
        gaze_points_visualise_on_cxr(
            cxr, reflacx_norm_gaze_points, reflacx_subcluster_labels, centre_x, centre_y
        ),
        save_folder_name,
        "reflacx_gp_cxr_fig.png",
    )

    # ---- gaze points with bounding boxes ----
    # NOTE(review): gaze_with_reflacx_lesion_bb, gaze_with_eyegaze_anatomical_bb
    # and centroid_on_cxr are called without an image; when none is given they
    # load ./BIRCH/cxr.jpg themselves rather than drawing on this case's `cxr`.
    reflacx_lesion_bb_df = pd.read_csv(
        os.path.join(
            patient_folder,
            "REFLACX",
            "main_data",
            reflacx_id,
            "anomaly_location_ellipses.csv",
        )
    )
    _save_and_close(
        gaze_with_reflacx_lesion_bb(
            reflacx_lesion_bb_df,
            reflacx_norm_gaze_points,
            reflacx_subcluster_labels,
            centre_x,
            centre_y,
        ),
        save_folder_name,
        "reflacx_lesion_bb_fig.png",
    )

    eyegaze_anatomical_df = pd.read_csv(
        os.path.join(patient_folder, "EyeGaze", "bounding_boxes.csv")
    )
    _save_and_close(
        gaze_with_eyegaze_anatomical_bb(
            eyegaze_anatomical_df,
            eyegaze_norm_gaze_points,
            eyegaze_subcluster_labels,
            centre_x,
            centre_y,
        ),
        save_folder_name,
        "eyegaze_anatomical_fig.png",
    )

    # ---- centroid figures ----
    eyegaze_num_clusters = len(eyegaze_birch.subcluster_labels_)
    reflacx_num_clusters = len(reflacx_birch.subcluster_labels_)

    _save_and_close(
        centroid_in_frame(eyegaze_birch), save_folder_name, "eyegaze_centroid_frame_fig.png"
    )
    _save_and_close(
        centroid_in_frame(reflacx_birch), save_folder_name, "reflacx_centroid_frame_fig.png"
    )
    _save_and_close(
        centroid_on_cxr(eyegaze_birch, centre_x, centre_y),
        save_folder_name,
        "eyegaze_centroid_cxr_fig.png",
    )
    _save_and_close(
        centroid_on_cxr(reflacx_birch, centre_x, centre_y),
        save_folder_name,
        "reflacx_centroid_cxr_fig.png",
    )
    _save_and_close(
        centroid_on_cxr_with_anatomical_bb(
            cxr, eyegaze_birch, eyegaze_anatomical_df, centre_x, centre_y
        ),
        save_folder_name,
        "eyegaze_centroid_cxr_anatomical_fig.png",
    )
    # The file name below (including its spelling) is kept as it was so that
    # existing outputs and anything reading them stay valid.
    _save_and_close(
        centroid_on_cxr_with_lision_bb(
            cxr, reflacx_birch, reflacx_lesion_bb_df, centre_x, centre_y
        ),
        save_folder_name,
        "relfacx_centroid_cxr_lesion_fig.png",
    )

    # ---- full transition graphs ----
    eyegaze_G = get_graph(eyegaze_gaze_df, eyegaze_birch, eyegaze_subcluster_labels, eyegaze_num_clusters)
    reflacx_G = get_graph(reflacx_gaze_df, reflacx_birch, reflacx_subcluster_labels, reflacx_num_clusters)

    _save_and_close(graph_fig(eyegaze_G), save_folder_name, "eyegaze_G_fig.png")
    _save_and_close(graph_fig(reflacx_G), save_folder_name, "reflacx_G_fig.png")
    _save_and_close(
        graph_on_cxr(eyegaze_G, cxr, centre_x, centre_y), save_folder_name, "eyegaze_G_cxr_fig.png"
    )
    _save_and_close(
        graph_on_cxr(reflacx_G, cxr, centre_x, centre_y), save_folder_name, "reflacx_G_cxr_fig.png"
    )

    # ---- sparsified graphs ----
    eyegaze_H = sparsify(eyegaze_G, k)
    reflacx_H = sparsify(reflacx_G, k)

    _save_and_close(graph_fig(eyegaze_H), save_folder_name, "eyegaze_H_fig.png")
    _save_and_close(graph_fig(reflacx_H), save_folder_name, "reflacx_H_fig.png")
    _save_and_close(
        graph_on_cxr(eyegaze_H, cxr, centre_x, centre_y), save_folder_name, "eyegaze_H_cxr_fig.png"
    )
    _save_and_close(
        graph_on_cxr(reflacx_H, cxr, centre_x, centre_y), save_folder_name, "reflacx_H_cxr_fig.png"
    )
    _save_and_close(
        eyegaze_graph_on_cxr_with_anatomical(
            eyegaze_anatomical_df, cxr, eyegaze_H, centre_x, centre_y
        ),
        save_folder_name,
        "eyegaze_H_cxr_anatomical_fig.png",
    )
    _save_and_close(
        reflacx_graph_on_cxr_with_lesion_bb(
            reflacx_lesion_bb_df, cxr, reflacx_H, centre_x, centre_y
        ),
        save_folder_name,
        "reflacx_H_cxr_lesion_fig.png",
    )