<a href="https://colab.research.google.com/github/HadarRosenwald/severity-detection/blob/main/TabularModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install torchxrayvision
!pip -q install image_tabular

[K     |████████████████████████████████| 29.0 MB 13 kB/s 
[K     |████████████████████████████████| 2.0 MB 42.7 MB/s 
[?25h

In [85]:
# !rm -rf "/content/covid-chestxray-dataset"

In [2]:
import matplotlib.pyplot as plt
import os
import shutil
import torch
import torchxrayvision as xrv
import numpy as np
import pandas as pd
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
import fastai
from fastai.tabular.data import TabularList
import image_tabular as imtab
from sklearn.model_selection import train_test_split
import math

from fastai.vision import *
from fastai.tabular import *
from image_tabular.core import *
from image_tabular.dataset import *
from image_tabular.model import *
from image_tabular.metric import *

In [86]:
!git clone https://github.com/ieee8023/covid-chestxray-dataset
d = xrv.datasets.COVID19_Dataset(imgpath="covid-chestxray-dataset/images/",csvpath="covid-chestxray-dataset/metadata.csv")

Cloning into 'covid-chestxray-dataset'...
remote: Enumerating objects: 3641, done.[K
remote: Total 3641 (delta 0), reused 0 (delta 0), pack-reused 3641[K
Receiving objects: 100% (3641/3641), 632.96 MiB | 39.29 MiB/s, done.
Resolving deltas: 100% (1450/1450), done.
Checking out files: 100% (1174/1174), done.


# **Creating label**

In [87]:
def generate_label(x):
    """Map one metadata row to a severity class label.

    Parameters
    ----------
    x : row-like
        Object exposing the attributes ``survival``, ``intubated``,
        ``went_icu`` and ``needed_supplemental_O2`` (for example a row handed
        over by ``DataFrame.apply(..., axis=1)``).

    Returns
    -------
    float or str
        ``np.nan`` when all four fields are missing, otherwise one of:

        * ``'4'`` - ``survival`` is ``'N'``
        * ``'3'`` - ``intubated`` is ``'Y'``
        * ``'1'`` - ``survival`` is ``'Y'`` and ICU or supplemental O2 is ``'Y'``
        * ``'2'`` - ICU or supplemental O2 is ``'Y'`` and ``survival`` is not ``'Y'``
        * ``'0'`` - none of the above
    """
    outcome_fields = [x.survival, x.intubated, x.went_icu, x.needed_supplemental_O2]
    if np.all(pd.isna(outcome_fields)):
        # `np.nan`, not `np.NaN`: the capitalised alias was removed in NumPy 2.0.
        return np.nan

    # Did not survive.
    if x.survival == 'N':
        return '4'

    # From here on the patient either survived or survival is unknown.
    if x.intubated == 'Y':
        return '3'

    needed_support = x.went_icu == 'Y' or x.needed_supplemental_O2 == 'Y'
    if needed_support:
        # Confirmed survival lowers the class from '2' to '1'.
        return '1' if x.survival == 'Y' else '2'
    return '0'

In [88]:
# `metadata` is the dataset's own frame (no copy is taken), so the new column
# is added to `d.csv` as well.
metadata = d.csv
severity_labels = metadata.apply(generate_label, axis=1)
metadata['severity_class'] = severity_labels

# **Tabular Data handling**

In [89]:
# Row filter: drop the "APS" view, keep offsets between 0 and 8 (inclusive),
# and exclude rows flagged as intubated or in ICU.
keep_rows = (
    (metadata['view'] != "APS")
    & metadata['offset'].between(0, 8)
    & (metadata['intubation_present'] != 'Y')
    & (metadata['in_icu'] != 'Y')
)

# Identifier, tabular-feature, label and image-file columns used downstream.
kept_columns = [
    'index', 'patientid',
    'sex', 'age', 'RT_PCR_positive',
    'temperature', 'pO2_saturation',
    'leukocyte_count', 'neutrophil_count', 'lymphocyte_count',
    'severity_class',
    'filename',
]
filtered_metadata = metadata.loc[keep_rows][kept_columns]

In [90]:
# Root of the cloned dataset; handed to TabularList.from_df as `path` further down.
data_path = Path("./covid-chestxray-dataset/")

In [91]:
# A fixed random_state keeps the partition identical across kernel restarts.
# Without it the validation indices and the copied train/test image folders
# would change (and mix) on every run.
train_df, test_df = train_test_split(filtered_metadata, test_size=0.2, random_state=42)

# Rows without a severity label cannot be used for training or evaluation.
train_df = train_df.dropna(subset=['severity_class'])
test_df = test_df.dropna(subset=['severity_class'])

train_df.head()

Unnamed: 0,index,patientid,sex,age,RT_PCR_positive,temperature,pO2_saturation,leukocyte_count,neutrophil_count,lymphocyte_count,severity_class,filename
265,434,229,F,,Unclear,,55.0,,,,0,2cd63b76.jpg
445,765,399,F,62.0,,,,,,,3,000001-4.jpg
323,547,291,F,61.0,Y,37.6,,,,,0,296_2020_4584_Fig2_HTML-a.png
268,438,233,M,,Unclear,,25.0,,2.2,0.4,1,441c9cdd.jpg
397,673,358,F,25.0,Y,,,,,,0,b0f1684d1ee90dc09deef015e29dae_jumbo.jpeg


In [92]:
# Positions inside train_df that are held out for validation.
val_idx = get_valid_index(train_df)

# Features with categorical values.
cat_names = [
    'sex',
    'RT_PCR_positive',
]

# Features with continuous values.
cont_names = [
    'age',
    'temperature',
    'pO2_saturation',
    'leukocyte_count',
    'neutrophil_count',
    'lymphocyte_count',
]

# Target column.
dep_var = ['severity_class']

# Tabular preprocessing steps, in the order they are passed to fastai.
procs = [FillMissing, Categorify, Normalize]

In [93]:
# Training items -> train/validation split by position -> labels from dep_var.
tab_train_items = TabularList.from_df(
    train_df,
    path=data_path,
    cat_names=cat_names,
    cont_names=cont_names,
    procs=procs,
)
tab_data = tab_train_items.split_by_idx(val_idx).label_from_df(cols=dep_var)

# Test items reuse the processor of the training inputs.
# NOTE(review): an earlier comment here said this call "currently returns an
# error", while the saved output shows a LabelLists repr - confirm which holds.
tab_test_items = TabularList.from_df(
    test_df,
    cat_names=cat_names,
    cont_names=cont_names,
    processor=tab_data.train.x.processor,
)
tab_data.add_test(tab_test_items)

LabelLists;

Train: LabelList (82 items)
x: TabularList
sex F; RT_PCR_positive Unclear; age_na True; temperature_na True; pO2_saturation_na False; leukocyte_count_na True; neutrophil_count_na True; lymphocyte_count_na True; age 0.1223; temperature 0.0809; pO2_saturation -0.7301; leukocyte_count -0.0496; neutrophil_count -0.1245; lymphocyte_count -0.1266; ,sex F; RT_PCR_positive #na#; age_na False; temperature_na True; pO2_saturation_na True; leukocyte_count_na True; neutrophil_count_na True; lymphocyte_count_na True; age 0.6793; temperature 0.0809; pO2_saturation 0.0315; leukocyte_count -0.0496; neutrophil_count -0.1245; lymphocyte_count -0.1266; ,sex M; RT_PCR_positive Unclear; age_na True; temperature_na True; pO2_saturation_na False; leukocyte_count_na True; neutrophil_count_na False; lymphocyte_count_na False; age 0.1223; temperature 0.0809; pO2_saturation -2.6342; leukocyte_count -0.0496; neutrophil_count -0.9672; lymphocyte_count -0.8521; ,sex F; RT_PCR_positive Y; age_na False; 

In [94]:
# embedding sizes of categorical data
emb_szs = tab_data.train.get_emb_szs()
# output size, will be concatenated with the CNN, same output size
tab_out_sz = 5
# The tabular model
tabular_model = TabularModel(emb_szs, len(cont_names), out_sz=tab_out_sz, layers=[8], ps=0.2)
tabular_model

TabularModel(
  (embeds): ModuleList(
    (0): Embedding(3, 3)
    (1): Embedding(3, 3)
    (2): Embedding(3, 3)
    (3): Embedding(3, 3)
    (4): Embedding(3, 3)
    (5): Embedding(3, 3)
    (6): Embedding(3, 3)
    (7): Embedding(3, 3)
  )
  (emb_drop): Dropout(p=0.0, inplace=False)
  (bn_cont): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): Linear(in_features=30, out_features=8, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=8, out_features=5, bias=True)
  )
)

# **Image Data handling**

In [95]:
def create_sub_image_folder(dataframe, imgs_type, img_root=None):
    """Copy the images listed in ``dataframe`` into ``<img_root>/<imgs_type>/``.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain a ``filename`` column with image file names relative to
        ``img_root``.
    imgs_type : str
        Name of the sub-folder to create/fill (e.g. ``'train'`` or ``'test'``).
    img_root : str, optional
        Folder holding the source images. Defaults to the image path of the
        notebook-level dataset ``d``.

    Files already present in the destination are left untouched, so the
    function can be called repeatedly.
    """
    if img_root is None:
        img_root = d.imgpath

    target_dir = os.path.join(img_root, imgs_type)
    os.makedirs(target_dir, exist_ok=True)

    # Iterate over the file names directly; gluing name and label together
    # with ';' and splitting again breaks on names that contain ';'.
    for file_name in dataframe['filename']:
        src = os.path.join(img_root, file_name)
        dst = os.path.join(target_dir, file_name)
        if not os.path.exists(dst):
            shutil.copyfile(src, dst)

In [96]:
# Fill one image sub-folder per split, train first and then test.
for split_name, split_frame in (('train', train_df), ('test', test_df)):
    create_sub_image_folder(split_frame, split_name)

# Dataset image path with a trailing separator appended.
filtered_img_base_path = f"{d.imgpath}/"

In [171]:
# torchvision augmentation pipeline for training images. In this cell it is
# only referenced by the commented-out ImageFolder code at the bottom.
# Note: `(0.5)` is a parenthesised float, not a one-element tuple.
train_transforms = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                       transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5), (0.5))])


# Deterministic torchvision pipeline for evaluation images; unlike the training
# pipeline it has no Normalize step. Also only used by the commented-out code.
test_transforms = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                     transforms.Resize(255),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor()])

# Transform set passed to the image list's `.transform(...)` call below.
tfms = get_transforms()

# Image items read from the "train" sub-folder, split with the same val_idx as
# the tabular data and labelled from the same dep_var column.
# NOTE(review): convert_mode='LA' presumably yields two channels (luminance +
# alpha), while the torchvision pipelines above use a single grayscale
# channel - confirm what the CNN expects.
image_data = (ImageList.from_df(train_df, path=d.imgpath, convert_mode='LA', cols='filename', folder="train")
    .split_by_idx(val_idx).label_from_df(cols=dep_var))

# Apply the transforms and resize to 224 using the SQUISH method.
image_data = image_data.transform(tfms=tfms, size=224, resize_method = ResizeMethod.SQUISH)
# train_data = datasets.ImageFolder(filtered_img_base_path + '/train', transform=train_transforms)
# test_data = datasets.ImageFolder(filtered_img_base_path + '/test', transform=test_transforms)

# trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
# testloader = torch.utils.data.DataLoader(test_data, batch_size=64)

In [16]:
# torchxrayvision DenseNet created with the named pretrained weights
# (downloaded on first use, see the cell output).
model = xrv.models.DenseNet(weights="densenet121-res224-all")
# Show the current classification head before it is replaced.
model.classifier

Downloading weights...
If this fails you can run `wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt -O /root/.torchxrayvision/models_data/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt`
[██████████████████████████████████████████████████]


Linear(in_features=1024, out_features=18, bias=True)

In [17]:
# Don't backprop model parameters!
for param in model.parameters():
    param.requires_grad = False
    
# New classifier
model.classifier = nn.Sequential(nn.Linear(1024, 512),
                          nn.ReLU(),
                          nn.Dropout(0.2),
                          nn.Linear(512,256),
                          nn.ReLU(),
                          nn.Dropout(0.2),
                          nn.Linear(256,5))

#criterion = nn.NLLLoss()

# Training only the classifier parameters, model parameters remains unchanged
optimizer = optim.RMSprop(model.classifier.parameters(), lr=0.004)

# **Integrated Model**

In [46]:
# First layer size = CNN head output width + tabular branch output width.
combined_in_sz = model.classifier[-1].out_features + tab_out_sz

img_tabular_model = CNNTabularModel(
    model,
    tabular_model,
    layers=[combined_in_sz, 32],
    ps=0.2,
    out_sz=5,
)

In [None]:
# Display the combined image + tabular model as the cell output.
img_tabular_model

fastai.basic_data.DataBunch