In [1]:
!pip install torchxrayvision

Collecting torchxrayvision
  Downloading torchxrayvision-0.0.37-py3-none-any.whl (29.0 MB)
[K     |████████████████████████████████| 29.0 MB 1.5 MB/s 
Installing collected packages: torchxrayvision
Successfully installed torchxrayvision-0.0.37


In [55]:
import torchxrayvision as xrv

from skimage import color
from skimage import io
from skimage.transform import resize
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from math import nan

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from skimage.transform import resize
from sklearn.metrics import accuracy_score

from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, ComplementNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier

In [3]:
!git clone https://github.com/ieee8023/covid-chestxray-dataset

Cloning into 'covid-chestxray-dataset'...
remote: Enumerating objects: 3641, done.[K
remote: Total 3641 (delta 0), reused 0 (delta 0), pack-reused 3641[K
Receiving objects: 100% (3641/3641), 632.96 MiB | 24.31 MiB/s, done.
Resolving deltas: 100% (1450/1450), done.
Checking out files: 100% (1174/1174), done.


In [4]:
d = xrv.datasets.COVID19_Dataset(imgpath="covid-chestxray-dataset/images/",csvpath="covid-chestxray-dataset/metadata.csv")

In [5]:
d.csv.head()

Unnamed: 0,index,patientid,offset,sex,age,finding,RT_PCR_positive,survival,intubated,intubation_present,...,location,folder,filename,doi,url,license,clinical_notes,other_notes,Unnamed: 29,offset_day_int
0,0,2,0.0,M,65.0,Pneumonia/Viral/COVID-19,Y,Y,N,N,...,"Cho Ray Hospital, Ho Chi Minh City, Vietnam",images,auntminnie-a-2020_01_28_23_51_6665_2020_01_28_...,10.1056/nejmc2001272,https://www.nejm.org/doi/full/10.1056/NEJMc200...,,"On January 22, 2020, a 65-year-old man with a ...",,,0.0
1,1,2,3.0,M,65.0,Pneumonia/Viral/COVID-19,Y,Y,N,N,...,"Cho Ray Hospital, Ho Chi Minh City, Vietnam",images,auntminnie-b-2020_01_28_23_51_6665_2020_01_28_...,10.1056/nejmc2001272,https://www.nejm.org/doi/full/10.1056/NEJMc200...,,"On January 22, 2020, a 65-year-old man with a ...",,,3.0
2,2,2,5.0,M,65.0,Pneumonia/Viral/COVID-19,Y,Y,N,N,...,"Cho Ray Hospital, Ho Chi Minh City, Vietnam",images,auntminnie-c-2020_01_28_23_51_6665_2020_01_28_...,10.1056/nejmc2001272,https://www.nejm.org/doi/full/10.1056/NEJMc200...,,"On January 22, 2020, a 65-year-old man with a ...",,,5.0
3,3,2,6.0,M,65.0,Pneumonia/Viral/COVID-19,Y,Y,N,N,...,"Cho Ray Hospital, Ho Chi Minh City, Vietnam",images,auntminnie-d-2020_01_28_23_51_6665_2020_01_28_...,10.1056/nejmc2001272,https://www.nejm.org/doi/full/10.1056/NEJMc200...,,"On January 22, 2020, a 65-year-old man with a ...",,,6.0
4,4,4,0.0,F,52.0,Pneumonia/Viral/COVID-19,Y,,N,N,...,"Changhua Christian Hospital, Changhua City, Ta...",images,nejmc2001573_f1a.jpeg,10.1056/NEJMc2001573,https://www.nejm.org/doi/full/10.1056/NEJMc200...,,diffuse infiltrates in the bilateral lower lungs,,,0.0


In [6]:
split_tags = [ val.split('/') for val in d.csv['finding'] ]

In [7]:
all_tags = set()

for split_tag in split_tags:
    for val in split_tag:
        all_tags.add(val)

all_tags = list(all_tags)
print(all_tags)

['Lipoid', 'MRSA', 'Viral', 'Klebsiella', 'No Finding', 'Mycoplasma', 'SARS', 'COVID-19', 'Herpes ', 'Nocardia', 'Legionella', 'Aspiration', 'Tuberculosis', 'Chlamydophila', 'Aspergillosis', 'Pneumocystis', 'Influenza', 'Staphylococcus', 'Fungal', 'MERS-CoV', 'Varicella', 'H1N1', 'Streptococcus', 'Pneumonia', 'Bacterial']


In [8]:
tag_index = {'Bacterial':1, 'Klebsiella':1, 'H1N1':2, 'SARS':0, 'Aspergillosis':1, 'Staphylococcus':1, 'Herpes ':0, 'Varicella':0, 'Viral':0, 'Pneumonia':0, 'Mycoplasma':1, 'Streptococcus':1, 'Nocardia':1, 'MERS-CoV':0, 'MRSA':2, 'Influenza':0, 'Lipoid':2, 'Pneumocystis':1, 'No Finding':2, 'Chlamydophila':1, 'Fungal':1, 'Aspiration':2, 'Legionella':1, 'Tuberculosis':2, 'COVID-19':0}

In [9]:
data = d.csv[['finding', 'filename']]

In [10]:
data.head()

Unnamed: 0,finding,filename
0,Pneumonia/Viral/COVID-19,auntminnie-a-2020_01_28_23_51_6665_2020_01_28_...
1,Pneumonia/Viral/COVID-19,auntminnie-b-2020_01_28_23_51_6665_2020_01_28_...
2,Pneumonia/Viral/COVID-19,auntminnie-c-2020_01_28_23_51_6665_2020_01_28_...
3,Pneumonia/Viral/COVID-19,auntminnie-d-2020_01_28_23_51_6665_2020_01_28_...
4,Pneumonia/Viral/COVID-19,nejmc2001573_f1a.jpeg


In [11]:
severity_scores = pd.read_csv('/content/covid-chestxray-dataset/annotations/covid-severity-scores.csv', header=5)

In [12]:
severity_scores

Unnamed: 0,filename,geographic_mean,opacity_mean
0,01E392EE-69F9-4E33-BFCE-E5C968654078.jpeg,6.0,4.0
1,03BF7561-A9BA-4C3C-B8A0-D3E585F73F3C.jpeg,2.7,2.0
2,1-s2.0-S0140673620303706-fx1_lrg.jpg,2.0,2.0
3,1-s2.0-S1684118220300608-main.pdf-001.jpg,3.7,3.0
4,1-s2.0-S1684118220300608-main.pdf-002.jpg,1.7,1.7
...,...,...,...
89,ryct.2020200028.fig1a.jpeg,2.0,1.7
90,ryct.2020200034.fig2.jpeg,1.3,1.3
91,ryct.2020200034.fig5-day0.jpeg,3.7,3.3
92,ryct.2020200034.fig5-day4.jpeg,5.7,4.0


In [13]:
bitmasks = []
geographic_mean = []
opacity_mean = []

for i, row in data.iterrows():
    bitmask = 0
    
    for val in [tag_index[extracted_tag] for extracted_tag in row['finding'].split('/') ]:
        bitmask = bitmask | (1 << val)

    if row['filename'] in severity_scores['filename']:
        geographic_mean.append(severity_scores['geographic_mean'])
        opacity_mean.append(severity_scores['opacity_mean'])
    else:
        geographic_mean.append(nan)
        opacity_mean.append(nan)

    bitmasks.append(bitmask)

data['bitmasks'] = bitmasks
data['geographic_mean'] = geographic_mean
data['opacity_mean'] = opacity_mean
data.head()

Unnamed: 0,finding,filename,bitmasks,geographic_mean,opacity_mean
0,Pneumonia/Viral/COVID-19,auntminnie-a-2020_01_28_23_51_6665_2020_01_28_...,1,,
1,Pneumonia/Viral/COVID-19,auntminnie-b-2020_01_28_23_51_6665_2020_01_28_...,1,,
2,Pneumonia/Viral/COVID-19,auntminnie-c-2020_01_28_23_51_6665_2020_01_28_...,1,,
3,Pneumonia/Viral/COVID-19,auntminnie-d-2020_01_28_23_51_6665_2020_01_28_...,1,,
4,Pneumonia/Viral/COVID-19,nejmc2001573_f1a.jpeg,1,,


In [14]:
print(data['bitmasks'].unique())

[1 3 4 5 7]


In [36]:
def read_image(filename):
    image_as_np = io.imread("./covid-chestxray-dataset/images/" + filename, as_gray=True)
    RESOLUTION = 64
    return resize(image_as_np, (RESOLUTION, RESOLUTION), anti_aliasing=True).flatten()

In [39]:
X = []
y = []

for i, val in data.iterrows():
    image = read_image(val['filename'])
    if image is None:
        continue
    X.append(image)
    y.append(val['bitmasks'])

In [41]:
X = np.array(X)
y = np.array(y)

In [44]:
# Test train split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

(374, 4096) (161, 4096) (374,) (161,)


In [45]:
def get_accuracy_for_model(X_train, X_test, y_train, y_test, model):
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    return accuracy_score(y_test, y_pred)

In [46]:
print(get_accuracy_for_model(X_train, X_test, y_train, y_test, GaussianNB()))

0.36645962732919257


In [47]:
print(get_accuracy_for_model(X_train, X_test, y_train, y_test, ComplementNB()))

0.5962732919254659


In [48]:
print(get_accuracy_for_model(X_train, X_test, y_train, y_test, BernoulliNB()))

0.7639751552795031


In [49]:
print(get_accuracy_for_model(X_train, X_test, y_train, y_test, MultinomialNB()))

0.4161490683229814


In [50]:
print(get_accuracy_for_model(X_train, X_test, y_train, y_test, RandomForestClassifier()))

0.7950310559006211


In [51]:
print(get_accuracy_for_model(X_train, X_test, y_train, y_test, SVC()))

0.7950310559006211


In [52]:
print(get_accuracy_for_model(X_train, X_test, y_train, y_test, MLPClassifier(learning_rate='adaptive', max_iter=300)))

0.7639751552795031


In [57]:
roc_auc_score(y_train, GaussianNB().fit(X_train, y_train).predict_proba(X_train), multi_class='ovr')

0.8566929494357218

In [58]:
roc_auc_score(y_train, ComplementNB().fit(X_train, y_train).predict_proba(X_train), multi_class='ovr')

0.6502881246944636

In [59]:
roc_auc_score(y_train, BernoulliNB().fit(X_train, y_train).predict_proba(X_train), multi_class='ovr')

0.5428498123828766

In [60]:
roc_auc_score(y_train, MultinomialNB().fit(X_train, y_train).predict_proba(X_train), multi_class='ovr')

0.7800997298537133

In [61]:
roc_auc_score(y_train, RandomForestClassifier().fit(X_train, y_train).predict_proba(X_train), multi_class='ovr')

1.0

In [64]:
roc_auc_score(y_train, SVC(probability=True).fit(X_train, y_train).predict_proba(X_train), multi_class='ovr')

0.8010690084697991

In [65]:
roc_auc_score(y_train, MLPClassifier(learning_rate='adaptive', max_iter=300).fit(X_train, y_train).predict_proba(X_train), multi_class='ovr')

0.997174768789051