<a href="https://colab.research.google.com/github/HadarRosenwald/severity-detection/blob/main/TabularModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip -q install torchxrayvision
!pip -q install image_tabular

[K     |████████████████████████████████| 29.0 MB 13 kB/s 
[K     |████████████████████████████████| 2.0 MB 51.9 MB/s 
[?25h

In [15]:
import matplotlib.pyplot as plt
import os
import shutil
import torch
import torchxrayvision as xrv
import numpy as np
import pandas as pd
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
import fastai
from fastai.tabular.data import TabularList
import image_tabular as imtab
from sklearn.model_selection import train_test_split
import math

from fastai.vision import *
from fastai.tabular import *
from image_tabular.core import *
from image_tabular.dataset import *
from image_tabular.model import *
from image_tabular.metric import *

In [4]:
!git clone https://github.com/ieee8023/covid-chestxray-dataset

Cloning into 'covid-chestxray-dataset'...
remote: Enumerating objects: 3641, done.[K
remote: Total 3641 (delta 0), reused 0 (delta 0), pack-reused 3641[K
Receiving objects: 100% (3641/3641), 632.96 MiB | 35.70 MiB/s, done.
Resolving deltas: 100% (1450/1450), done.
Checking out files: 100% (1174/1174), done.


In [8]:
d = xrv.datasets.COVID19_Dataset(imgpath="covid-chestxray-dataset/images/",csvpath="covid-chestxray-dataset/metadata.csv")
metadata = d.csv
filtered_metadata = metadata.loc[(metadata.view!="APS") & (metadata.offset>=0) & (metadata.offset<=8) & (metadata.intubation_present != 'Y') & (metadata.in_icu != 'Y')]
filtered_metadata = filtered_metadata[['index','patientid','sex','age','RT_PCR_positive','temperature','pO2_saturation', 'leukocyte_count', 'neutrophil_count', 'lymphocyte_count', 'survival', 'intubated', 'went_icu', 'needed_supplemental_O2','filename']]

In [47]:
data_path = Path("./covid-chestxray-dataset/")

In [73]:
train_df, test_df = train_test_split(filtered_metadata, test_size=0.2)
train_df = train_df.dropna(subset=['survival', 'intubated', 'went_icu', 'needed_supplemental_O2'], how='any')
test_df = test_df.dropna(subset=['survival', 'intubated', 'went_icu', 'needed_supplemental_O2'], how='any')
train_df.head()

Unnamed: 0,index,patientid,sex,age,RT_PCR_positive,temperature,pO2_saturation,leukocyte_count,neutrophil_count,lymphocyte_count,survival,intubated,went_icu,needed_supplemental_O2,filename
368,623,331a,F,39.0,Y,36.8,,,,,Y,N,N,N,41182_2020_203_Fig3_HTML.jpg
3,3,2,M,65.0,Y,,,,,,Y,N,N,Y,auntminnie-d-2020_01_28_23_51_6665_2020_01_28_...
113,179,95,F,70.0,Y,,,,,,Y,N,N,N,58cb9263f16e94305c730685358e4e_jumbo.jpeg
345,592,315,F,78.0,Y,37.7,95.0,,,,Y,N,N,N,1-s2.0-S2214250920300834-gr1_lrg-b.png
2,2,2,M,65.0,Y,,,,,,Y,N,N,Y,auntminnie-c-2020_01_28_23_51_6665_2020_01_28_...


In [74]:
val_idx = get_valid_index(train_df)
# Features with categorical values
cat_names = ['sex', 'RT_PCR_positive']
# Features with continious values
cont_names = ['age', 'temperature', 'pO2_saturation', 'leukocyte_count', 'neutrophil_count', 'lymphocyte_count']
# Target
dep_var = ['survival', 'intubated', 'went_icu', 'needed_supplemental_O2']
procs = [FillMissing, Categorify, Normalize]

In [85]:
tab_data = (TabularList.from_df(train_df, path=data_path, cat_names=cat_names, cont_names=cont_names, procs=procs).split_by_idx(val_idx).label_from_df(cols=dep_var))

#add test - currently returns an error
#tab_data.add_test(TabularList.from_df(test_df, cat_names=cat_names, cont_names=cont_names, processor=tab_data.train.x.processor))

In [88]:
# embedding sizes of categorical data
emb_szs = tab_data.train.get_emb_szs()
# output size, will be concatenated with the CNN, same output size
tab_out_sz = 18
# The tabular model
tabular_model = TabularModel(emb_szs, len(cont_names), out_sz=tab_out_sz, layers=[18], ps=0.2)
tabular_model

TabularModel(
  (embeds): ModuleList(
    (0): Embedding(3, 3)
    (1): Embedding(2, 2)
    (2): Embedding(3, 3)
    (3): Embedding(3, 3)
    (4): Embedding(2, 2)
    (5): Embedding(2, 2)
    (6): Embedding(2, 2)
  )
  (emb_drop): Dropout(p=0.0, inplace=False)
  (bn_cont): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): Linear(in_features=23, out_features=18, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=18, out_features=18, bias=True)
  )
)