<center>
    <h3>
        <a href="https://www.wadhwaniai.org/work/cough-against-covid/">Cough Against COVID-19</a>
    </h3>
    This is a demo notebook to test the COVID-19 detector model on sample contextual data.
</center>

### Instructions for use:
1. Manually fill in the values in the *Insert Values Here* cell of the **Load and process contextual data** step.

> **Important Note**: These predictions come from a model that has not gone through clinical trials. Please treat them only as a demo and consult your medical professionals for clinical advice.

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell executes, so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from os.path import join, splitext, exists
from subprocess import call
from IPython.display import Markdown as md
import torch

from cac.config import Config
from cac.utils.logger import set_logger, color
from cac.models import factory as model_factory
from utils import _preprocess_raw_context_data, check_data_correct_format
import ipywidgets as widgets

##### Load trained model

In [3]:
# Experiment to load, given as the config file's path below the experiments folder.
version = "iclrw/context/v9.7/context-neural.yml"
# Epoch number of the checkpoint to restore.
ckpt = 31

# Location of the config file and of the checkpoint, both derived from `version`.
config_path = join("experiments", version)
ckpt_path = join(splitext(version)[0], 'checkpoints/{}_ckpt.pth.tar'.format(ckpt))

# Decision threshold applied to the predicted COVID-19 probability.
# Configure it from the model's performance on the validation set: for this
# model, the value below was observed to maximize precision at a recall of 90%.
threshold = 0.3862

In [4]:
# Fail fast, naming the missing path, if the trained checkpoint is not
# present within the repository.
model_ckpt_file = join("../assets/models", ckpt_path)
assert exists(model_ckpt_file), f"Model checkpoint not found: {model_ckpt_file}"

In [5]:
# Fail fast, naming the missing path, if the experiment config file is not
# present within the repository.
config_file = join('../configs/', config_path)
assert exists(config_file), f"Config file not found: {config_file}"

In [6]:
# Run the repository's checkpoint copy script, passing the checkpoint path via -p.
# IPython expands $ckpt_path to the value of the Python variable `ckpt_path`.
# Per the logged output, the file is copied from assets/models to /output.
!python ../training/copy_model_ckpts.py -p $ckpt_path

[33mCopying from /workspace/cough-against-covid/assets/models/iclrw/context/v9.7/context-neural/checkpoints/31_ckpt.pth.tar to /output//iclrw/context/v9.7/context-neural/checkpoints/31_ckpt.pth.tar[0m
sending incremental file list

sent 54 bytes  received 12 bytes  132.00 bytes/sec
total size is 433,841  speedup is 6,573.35


##### Load and set config parameters appropriately

In [7]:
# Build the experiment configuration object from `config_path`; later cells
# read and modify its attributes (model settings, log/output directories).
config = Config(config_path)

In [8]:
# Route log messages for this demo to a dedicated file inside the
# experiment's log directory.
demo_log_path = join(config.log_dir, 'demo_inference.log')
set_logger(demo_log_path)

In [9]:
# add info about the model
config.model['load']['epoch'] = ckpt
config.model['load']['load_best'] = False
config.model['load']['version'] = config_path.replace(".yml", "")

In [10]:
# set inference directories for logging
dirpaths = ['config_save_path', 'output_dir', 'log_dir', 'checkpoint_dir']
for key in dirpaths:
    train_version = splitext(version)[0]
    infer_version = train_version + "_demo_inference"

    dirpath = getattr(config, key).replace(train_version, infer_version)
    os.makedirs(dirpath, exist_ok=True)
    setattr(config, key + '_demo_inference', dirpath)

In [11]:
# Instantiate the model class named in the config, handing it the full config.
# Per the logged output, this also loads the checkpoint weights selected above.
model = model_factory.create(config.model['name'], config=config)

[33mBuilding the network[0m
[33mSetting up the optimizer ...[0m
[33m=> Loading model weights from /output/experiments/iclrw/context/v9.7/context-neural/checkpoints/31_ckpt.pth.tar[0m
[33mFreezing specified layers[0m
[33mUsing loss functions:[0m
{'train': {'name': 'cross-entropy', 'params': {'reduction': 'none'}}, 'val': {'name': 'cross-entropy', 'params': {'reduction': 'none'}}}


In [12]:
# Put the network into evaluation mode for inference; the object returned by
# .eval() is assigned back to model.network.
model.network = model.network.eval()

#### Load and process contextual data

##### Insert Values Here 

In [13]:
# Contextual information for one individual -- edit the values below.
input_ = {
    # Age, chosen between [1, 100]
    'enroll_patient_age': 65,

    # Temperature, chosen between [95, 103]
    'enroll_patient_temperature': 98.,

    # Days with each symptom (cough, shortness of breath, fever);
    # values usually range between [0., 30.]
    'enroll_days_with_cough': 10,
    'enroll_days_with_shortness_of_breath': 0,
    'enroll_days_with_fever': 0,

    # Travel history, one of four options:
    # {0: 'No', 1: 'Other country', 2: 'Other district', 3: 'Other state'}
    'enroll_travel_history': 0.,

    # Binary answers: {0: No, 1: Yes}
    'enroll_contact_with_confirmed_covid_case': 0.,  # contact with a confirmed COVID case?
    'enroll_health_worker': 0.,                      # are you a health worker?
    'enroll_fever': 0.,                              # do you have fever?
    'enroll_cough': 1.,                              # do you have cough?
    'enroll_shortness_of_breath': 0.,                # do you have shortness of breath?
}

# Validate the values entered above before they are preprocessed.
check_data_correct_format(input_)

#### Preprocessing (Normalizing continuous values)

In [14]:
# Run the project's preprocessing helper on the raw inputs; per this section's
# heading, this step normalizes the continuous values.
input_processed = _preprocess_raw_context_data(input_)

In [15]:
# Order in which the processed features are laid out in the model's input vector.
feature_order = (
    'enroll_patient_age',
    'enroll_patient_temperature',
    'enroll_days_with_cough',
    'enroll_days_with_shortness_of_breath',
    'enroll_days_with_fever',
    'enroll_travel_history',
    'enroll_contact_with_confirmed_covid_case',
    'enroll_health_worker',
    'enroll_fever',
    'enroll_cough',
    'enroll_shortness_of_breath',
)

# Flat feature vector for the single sample, following `feature_order`.
x = [input_processed[feature] for feature in feature_order]

In [16]:
# Convert the feature list to a float32 tensor and add a leading batch
# dimension, giving a batch that holds this single sample.
# (Concatenating a one-element list of tensors is a no-op, so it is omitted.)
batch = torch.tensor(x).float().unsqueeze(0)

In [17]:
# Move the input batch to the device reported by the model wrapper
# (presumably where the network's weights live -- confirm in cac.models).
batch = batch.to(model.device)

In [18]:
# Sanity check: one sample with 11 contextual features, i.e. torch.Size([1, 11]).
batch.shape

torch.Size([1, 11])

##### Forward pass through the model to obtain prediction

In [19]:
# Inference only: disable gradient tracking so no autograd graph is built
# for the forward pass.
with torch.no_grad():
    # Raw network outputs for the batch.
    predictions = model.network(batch)
    # Softmax over the class dimension; column 1 is used as the COVID-19 score.
    predicted_proba = torch.nn.functional.softmax(predictions, dim=1)[:, 1]

# Flag the sample as positive when its score reaches the chosen threshold.
final_predicted_label = predicted_proba >= threshold

In [20]:
# Pick the verdict wording from the thresholded prediction, then wrap it in
# the centered HTML message that is rendered in the next cell.
verdict = "COVID-19 is detected" if final_predicted_label else "COVID-19 is NOT detected"
output_string = (
    f"<center>Based on the model prediction, <b>{verdict} given the contextual information</b>. </center>"
)

In [21]:
# Render the HTML verdict string as rich output (md is IPython.display.Markdown).
md(output_string)

<center>Based on the model prediction, <b>COVID-19 is NOT detected given the contextual information</b>. </center>

> **Important Note**: These predictions come from a model that has not gone through clinical trials. Please treat them only as a demo and consult your medical professionals for clinical advice.