In [11]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

### Setup

In [13]:
# Number of similar scans stored per query scan in a retrieval log; the log is
# read through the columns SMLR_0_PID .. SMLR_{N-1}_PID.
SIMILAR_LOG_TOP_SCAN_RETRIEVAL_SIZE = 5

# Project root holding the "metadata" and "logs" folders. Set the
# NLST_SIMILARITY_ROOT environment variable to run the notebook on another
# machine; the historical location stays the default.
ROOT_DIR = os.environ.get("NLST_SIMILARITY_ROOT", "/home/miclin/nlst-similarity")

In [14]:
class DataDict():
    '''Pairs a data dictionary with the metadata table it describes.

    Attributes:
        data_dict: the data dictionary as read from ``data_dict_loc``. This class
            reads its "Variable", "Category", "Label" and "Description" columns.
        metadata: the metadata table as read from ``metadata_loc``.
        by_categories: maps every non-missing category to the slice of
            ``metadata`` holding the columns the dictionary files under it.
    '''

    # Annotations only: the attributes are created per instance in __init__.
    data_dict: pd.DataFrame
    metadata: pd.DataFrame
    by_categories: dict

    def __init__(self, data_dict_loc="Data dictionary.csv",
                 metadata_loc="nlst_297_prsn_20170404.csv"):
        '''Reads both CSV files and groups the metadata columns by category.

        Parameters:
            1. data_dict_loc: path (or file-like object) of the data dictionary CSV
            2. metadata_loc: path (or file-like object) of the metadata CSV
        '''

        self.data_dict = pd.read_csv(data_dict_loc)
        self.metadata = pd.read_csv(metadata_loc)

        metadata_columns = self.metadata.columns
        by_categories = {}

        for category in self._categories():
            in_category = self.data_dict["Category"] == category
            dictionary_vars = self.data_dict.loc[in_category, "Variable"]
            # The dictionary may list variables that this metadata file does not
            # carry; keep only the ones that can actually be selected.
            present = [col for col in dictionary_vars if col in metadata_columns]
            by_categories[category] = self.metadata.loc[:, present]

        self.by_categories = by_categories

    def _categories(self):
        '''Returns the sorted, distinct, non-missing values of the "Category" column.'''

        # Missing categories are dropped first: np.unique sorts its input and
        # cannot order NaN against strings (it raises TypeError).
        return np.unique(self.data_dict["Category"].dropna())

    def __str__(self):
        return f"All categories: {self._categories()}"

    def return_category(self, category):
        '''Returns the columns in the metadata belonging to the given category.

        Raises AssertionError when the category is unknown.
        '''

        assert category in self.by_categories, f"category {category} doesn't exist"

        return self.by_categories[category]

    def category_info(self, category):
        '''Returns the data dictionary rows of a specific category.

        Unlike ``return_category`` this also includes variables that are absent
        from the metadata table. Raises AssertionError when the category is unknown.
        '''

        assert category in self.by_categories, f"category {category} doesn't exist"

        return self.data_dict[self.data_dict["Category"] == category]

    def column_desc(self, column_name):
        '''Prints the label and description of a metadata column and returns its dictionary rows.

        Raises AssertionError when ``column_name`` is not a metadata column. The
        returned frame is empty when the data dictionary has no entry for it.
        '''

        assert column_name in self.metadata.columns, "column not found"

        desc = self.data_dict[self.data_dict["Variable"] == column_name]
        print(f"Label: {desc['Label'].values}")
        print(f"Description: {desc['Description'].values}")
        return desc



In [15]:
class Visualize():
    '''Plots metadata attributes of query scans next to those of their retrieved similar scans.

    Attributes:
        metadata: the metadata table given to the constructor, indexed by its "pid" column.
        log: the retrieval log read from ``log_loc``; ``get_pids`` reads its
            "PID" and "SMLR_<i>_PID" columns.
    '''

    # Annotations only: the attributes are created per instance in __init__.
    log: pd.DataFrame
    metadata: pd.DataFrame

    # Each figure shows at most this many query scans.
    MAX_QUERY_SCANS = 100

    def __init__(self, metadata, log_loc):
        '''
        Parameters:
            1. metadata: metadata table with a "pid" column
            2. log_loc: path (or file-like object) of the retrieval log CSV
        '''

        self.metadata = metadata.set_index("pid")

        self.log = pd.read_csv(log_loc)

    def get_pids(self):
        '''Get a 2D array with just pids.
        An element of the array looks like the following:
        [PID of query scan, PID of SMLR 0, PID of SMLR 1, ...]
        '''
        pid_col_names = ["PID"] + [f"SMLR_{i}_PID" for i in range(SIMILAR_LOG_TOP_SCAN_RETRIEVAL_SIZE)]

        return self.log[pid_col_names].values

    def _binary_markers(self, col):
        '''Returns the two sorted non-missing values of a binary metadata column.

        Raises AssertionError when the column does not hold exactly two distinct values.
        '''

        # Dropping missing values through pandas (rather than np.isnan) also
        # works for columns stored as strings. Markers could be 0 and 1 or 1 and 2.
        binary_markers = np.unique(self.metadata[col].dropna())

        assert len(binary_markers) == 2, f"{col} is not binary data"

        return binary_markers

    def _similar_proportions(self, col, similar_pids, marker):
        '''For every row of ``similar_pids``, returns the fraction of those scans whose ``col`` equals ``marker``.

        ``similar_pids`` holds only the similar-scan pids (query column removed),
        so every returned value lies in [0, 1]. Scans with a missing value never
        match the marker but still count in the denominator.
        '''

        n_similar = similar_pids.shape[1]
        if n_similar == 0:
            raise ValueError("pids must contain at least one similar-scan column")

        column = self.metadata[col]
        proportions = []
        for row in similar_pids:
            data = column.loc[row].values
            proportions.append(float(np.sum(data == marker) / n_similar))

        return proportions

    def quantitative_(self, col: str, pids: np.ndarray,
                  width: int =1500, height: int=500):
        '''
        Parameters:
            1. col: column name
            2. pids: 2D array of pids; each row holds a query scan followed by its most similar scans
            3. width and height: size of the plot

        Produce a dot plot of up to 100 query scans at once, so will only take the first 100 rows of pids.
        One trace is drawn per column of pids, i.e. for the query scan and for every similar scan.
        Good for representing quantitative data, such as age, smokeage, height, etc.
        Appropriate to present one attribute at a time.

        '''

        pids = np.asarray(pids)[:self.MAX_QUERY_SCANS]
        query_pids = pids[:, 0].astype(str)

        # One name per pid column. Deriving the count from the data (instead of
        # the retrieval size, which is one short) keeps the last similar scan.
        names = ["Query scan"] + [f"SMLR_{i}" for i in range(pids.shape[1] - 1)]

        fig = go.Figure()
        for i, name in enumerate(names):

            y = self.metadata.loc[pids[:, i], col]

            fig.add_trace(go.Scatter(
                x=query_pids,
                y=y,
                marker=dict(color="#636EFA", size=5),
                mode="markers",
                name=name,
                showlegend=False
            ))

        fig.update_layout(title=col,
                        xaxis_title="Query scan pid",
                        yaxis_title=col,
                        xaxis=dict(dtick=1),
                        width=width,
                        height=height)

        fig.show()

    def binary_(self, cols: "list[str] | str", pids: np.ndarray,
                    width: int =1500, height: int=500):
        '''
        Parameters:
            1. cols: list of column names (a single column name is accepted too)
            2. pids: 2D array of pids; each row holds a query scan followed by its most similar scans
            3. width and height: size of the plot

        Produce bar plots of up to 100 query scans at once, so will only take the first 100 rows of pids.
        One subplot is drawn per column. Each bar is the proportion of the query's
        similar scans (the query scan itself is excluded) that carry the smaller
        of the column's two values.
        Good for representing binary data, such as gender.

        '''

        if isinstance(cols, str):
            # A bare string would otherwise be iterated character by character.
            cols = [cols]

        pids = np.asarray(pids)[:self.MAX_QUERY_SCANS]
        query_pids = pids[:, 0].astype(str)
        similar_pids = pids[:, 1:]

        fig = make_subplots(rows=len(cols), cols=1, shared_xaxes=True)

        for i, col in enumerate(cols):

            binary_markers = self._binary_markers(col)

            # proportion of the similar scans that has marker binary_markers[0]
            y = self._similar_proportions(col, similar_pids, binary_markers[0])

            fig.add_trace(go.Bar(
                x=query_pids,
                y=y,
                showlegend=False,
                name = f"{col}:{binary_markers[0]}",
            ), row= i + 1,
                col=1)

        # Applies to every subplot axis at once, so a single call is enough.
        fig.update_xaxes(dtick=1,title="Query scan pid")

        fig.update_layout(width=width,
                        height=height)
        fig.show()



In [16]:

# Resolve every input location up front.
model_name = "r3d_18_v0"
data_dict_loc = os.path.join(ROOT_DIR, "metadata", "Data dictionary.csv")
metadata_loc = os.path.join(ROOT_DIR, "metadata/nlst_297_prsn_20170404.csv")
log_loc = os.path.join(ROOT_DIR,"logs","same_patient", model_name, "0.csv")

# DataDict reads both CSV files itself (the metadata without dtype overrides).
dd = DataDict(data_dict_loc=data_dict_loc, metadata_loc=metadata_loc)

# Second read of the metadata for plotting, forcing the columns at positions
# 201, 224 and 225 to be parsed as strings.
metadata = pd.read_csv(metadata_loc, dtype=dict.fromkeys((201, 224, 225), "str"))

# Each row of pids is [query pid, similar pid 0, similar pid 1, ...].
visualize = Visualize(metadata=metadata, log_loc=log_loc)
pids = visualize.get_pids()
print(f"Number of records: {len(pids)}")


  if (await self.run_code(code, result,  async_=asy)):


Number of records: 300


In [18]:
# DataDict.__str__ lists every category found in the data dictionary.
print(str(dd))

All categories: ['Alcohol' 'Cancer history' 'Cancers of Any Site' 'Death' 'Demographic'
 'Disease history' 'Follow-Up/Procedures' 'Lung cancer' 'Progression'
 'Screening' 'Smoking' 'Study' 'Work history']


In [19]:
# Data dictionary rows for the smoking-related variables.
dd.category_info(category="Smoking")

Unnamed: 0,Variable,Label,Description,Format Text,Category
17,age_quit,Age at smoking cessation,Age at which a participant stopped smoking cig...,"Numeric .N=""No age given""",Smoking
18,cigar,Participant smokes/smoked cigars,,".M=""Missing"" 0=""No"" 1=""Yes""",Smoking
19,cigsmok,Smoking status at T0,Cigarette smoking status (current vs former) a...,"0=""Former"" 1=""Current""",Smoking
20,pipe,Participant smokes/smoked a pipe,,".M=""Missing"" 0=""No"" 1=""Yes""",Smoking
21,pkyr,Pack years,"Pack years, calculated as: (Total Years Smoked...",Numeric,Smoking
22,smokeage,Age at smoking onset,Age at which a participant started smoking cig...,"Numeric .M=""Missing""",Smoking
23,smokeday,Avg. num. of cigarettes per day,,Numeric,Smoking
24,smokelive,Participant lives/lived with smoker,,".M=""Missing"" 0=""No"" 1=""Yes""",Smoking
25,smokework,Participant works/worked with exposure to smokers,,".M=""Missing"" 0=""No"" 1=""Yes""",Smoking
26,smokeyr,Total years of smoking,Total number of years the participant smoked c...,Numeric,Smoking


In [8]:
# Dot plot of the "smokeage" column for the first 100 query scans.
visualize.quantitative_("smokeage", pids)



In [9]:
# One bar subplot per binary column for the first 100 query scans.
visualize.binary_(["gender", "diaghype"], pids)