In [None]:
# Install medcat
! pip install medcat==1.2.7

**Restart the runtime if you are on Colab; this is sometimes necessary after installing new packages.**

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import json 

from matplotlib import pyplot as plt
from medcat.cat import CAT

In [None]:
# Directory that holds every downloaded file.
# The trailing slash matters: paths below are built by plain string concatenation.
DATA_DIR = "./data/"

# Location of the MedCAT model pack archive inside DATA_DIR
model_pack_path = f"{DATA_DIR}medmen_wstatus_2021_oct.zip"

### Download and load the data

In [None]:
# Download the models and required data
!wget https://raw.githubusercontent.com/CogStack/MedCATtutorials/main/notebooks/introductory/data/pt_notes.csv -P ./data/
!wget https://raw.githubusercontent.com/CogStack/MedCATtutorials/main/notebooks/introductory/data/MedCAT_Export.json -P ./data/
!wget https://raw.githubusercontent.com/CogStack/MedCATtutorials/main/notebooks/introductory/data/cui_location.json -P ./data/

!wget https://medcat.rosalind.kcl.ac.uk/media/medmen_wstatus_2021_oct.zip -P ./data/

In [None]:
# Load the cui_location dictionary: it is iterated below as CUI -> row positions in `data`.
# The context manager closes the file handle as soon as the JSON has been parsed.
with open(DATA_DIR + "cui_location.json") as cui_location_file:
    cui_location = json.load(cui_location_file)
# Load the pt_notes
data = pd.read_csv(DATA_DIR + "pt_notes.csv")
# Get the model
cat = CAT.load_model_pack(model_pack_path)

## Calculate the required maps

`CUI` - Disease identifier

`cui_subjects_cnts` - Used to calculate the number of occurrences of a disease for a certain patient (in their text records).

`cui_subjects` - Map from CUI to the set of patientIds that have at least two occurrences of that CUI

`subject2gender` - Map from patientId to gender

`subject2agegroup` - Map from patientId to agegroup (defined [here](https://www.researchgate.net/publication/232746130_Automated_Medical_Literature_Retrieval))

`subject2age` - Map from patientId to age

`subject2cuis` - Map from patientId to the set of CUIs found in that patient's records

In [None]:
# We are only going to keep those subjects that have at least two appearances of a concept
cui_subjects_cnts = {}
cui_subjects = {}
subject2gender = {}
subject2agegroup = {}
subject2age = {}
subject2cuis = {}


def _age_to_agegroup(age):
    """Return the age-group label for an age given in years.

    Ages below 13 and missing ages (NaN) map to "UNK".
    """
    # Written as `not (age >= 13)` so that NaN, which fails every comparison, lands in "UNK".
    if not (age >= 13):
        return "UNK"
    # Half-open upper bounds: a fractional age (e.g. 18.5) can never fall between two groups,
    # and the 64/65 boundary belongs to exactly one group.
    if age < 19:
        return 'Adolescent'
    if age < 25:
        return "Young Adult"
    if age < 45:
        return 'Adult'
    if age < 65:
        return "Middle Aged"
    return "Aged"


# Column positions do not change while looping, so resolve them once up front
# instead of rebuilding the column list three times per row.
column_names = list(data.columns)
subject_id_col = column_names.index('subject_id')
gender_col = column_names.index('gender')
age_col = column_names.index('age_year')

for cui in cui_location:
    cui_subjects_cnts[cui] = {}
    cui_subjects[cui] = set()

    # Each location is used as a row position in `data`
    for location in cui_location[cui]:
        subject_id = data.iat[location, subject_id_col]
        gender = data.iat[location, gender_col]
        age = data.iat[location, age_col]
        agegroup = _age_to_agegroup(age)

        # Every CUI seen for a subject is recorded here, even if it appears only once
        subject2cuis.setdefault(subject_id, set()).add(cui)

        # Count how many rows mention this CUI for this subject
        cui_subjects_cnts[cui][subject_id] = cui_subjects_cnts[cui].get(subject_id, 0) + 1

        # On exactly the second mention the subject qualifies for this CUI; the
        # demographics are taken from the row of that second mention.
        if cui_subjects_cnts[cui][subject_id] == 2:
            cui_subjects[cui].add(subject_id)

            subject2gender[subject_id] = gender
            subject2agegroup[subject_id] = agegroup
            subject2age[subject_id] = age

# Calculate total number of patients in each group:

`pt_total` - Total number (all patients together)

`pt_male` - Number of male patients

`pt_female` - Number of female patients

`pt_adl` - Number of Adolescents

`pt_yadult` - Number of Young Adults

`pt_adult` - Number of Adults

`pt_maged` - Number of Middle Aged

`pt_aged` - Number of Aged

In [None]:
# Size of the whole cohort: every subject that made it into the demographic maps
pt_total = len(subject2age)

# Patients per gender
pt_male = sum(1 for gender in subject2gender.values() if gender == "M")
pt_female = sum(1 for gender in subject2gender.values() if gender == "F")

# Patients per age group
pt_adl = sum(1 for group in subject2agegroup.values() if group == "Adolescent")
pt_yadult = sum(1 for group in subject2agegroup.values() if group == "Young Adult")
pt_adult = sum(1 for group in subject2agegroup.values() if group == "Adult")
pt_maged = sum(1 for group in subject2agegroup.values() if group == "Middle Aged")
pt_aged = sum(1 for group in subject2agegroup.values() if group == "Aged")

### Create the header of our CSV that will contain all information for one disease

In [None]:
# First row of the table holds the column names; one row per disease is appended after it
dt = [
    [
        "disease", "cui", "tui",
        "total", "male", "female",
        "Adolescent", "Young Adult", "Adult", "Middle Aged", "Aged",
    ]
]

### Fill the array with data

In [None]:
def _percentage(count, total):
    """Return `count` as a percentage of `total`, or 0.0 when `total` is zero."""
    # An empty group (e.g. no adolescents in the cohort) would otherwise raise ZeroDivisionError.
    return (count / total) * 100 if total else 0.0


def _count_with_label(subjects, subject2label, label):
    """Count the subjects whose entry in `subject2label` equals `label`."""
    return sum(1 for subject in subjects if subject2label[subject] == label)


# Drop rows left over from an earlier run of this cell (the header row stays),
# so that re-running the cell does not append every disease twice.
del dt[1:]

for cui in cui_subjects:
    subjects = cui_subjects[cui]

    # Name and type ids come from the model's concept database; 'unk' when the CUI is not in it
    d = cat.cdb.cui2preferred_name.get(cui, 'unk')
    tui = cat.cdb.cui2type_ids.get(cui, 'unk')

    # Share of all patients, then of each gender, that have this disease
    t = _percentage(len(subjects), pt_total)
    m = _percentage(_count_with_label(subjects, subject2gender, 'M'), pt_male)
    f = _percentage(_count_with_label(subjects, subject2gender, 'F'), pt_female)

    # Share of each age group that has this disease
    adl = _percentage(_count_with_label(subjects, subject2agegroup, 'Adolescent'), pt_adl)
    yadult = _percentage(_count_with_label(subjects, subject2agegroup, 'Young Adult'), pt_yadult)
    adult = _percentage(_count_with_label(subjects, subject2agegroup, 'Adult'), pt_adult)
    maged = _percentage(_count_with_label(subjects, subject2agegroup, 'Middle Aged'), pt_maged)
    aged = _percentage(_count_with_label(subjects, subject2agegroup, 'Aged'), pt_aged)

    # Column order must match the header row in dt[0]
    dt.append([d, cui, tui, t, m, f, adl, yadult, adult, maged, aged])

### Convert to dataframe

In [None]:
# dt[0] is the header row and dt[1:] the per-disease rows.
# Sort ascending by overall percentage so the most frequent diseases end up last,
# then renumber the rows 0..n-1.
df = (
    pd.DataFrame(dt[1:], columns=dt[0])
    .sort_values(by=['total'], ascending=True)
    .reset_index(drop=True)
)

# Plot top 30 diseases for the type T047 (UMLS semantic type: Disease or Syndrome)

In [None]:
# Start from seaborn's defaults, then apply the theme used for this figure
sns.reset_defaults()
sns.set(
    style="whitegrid",
    palette='pastel',
    rc={'figure.figsize': (5, 10)},
)

# Subset the data: choose only rows whose type ids include T047, then keep the last 30
# (the 30 largest 'total' values when df is sorted ascending by 'total')
is_t047 = df['tui'].apply(lambda type_ids: 'T047' in type_ids)
_data = df[is_t047].tail(30)

# Horizontal bar chart: one bar per disease, bar length = percentage of all patients
ax = _data.plot(x="disease", y=['total'], kind="barh")

# Legend position and axis labels; the x axis is fixed to the 0-15 % range
ax.legend(loc='lower right')
ax.set_xlim(0, 15)
ax.set_ylabel("Disease Name")
ax.set_xlabel("Percentage of patients with disease")

# Show
plt.show()

End of tutorial 