# COVID-19 Chest X-Ray Database - CNN

In [51]:
from torchvision import transforms

In [52]:
%reload_ext autoreload
%autoreload 2

from src.cnn import CNN_Model, load_dataset

In [57]:
NUMBER_OF_CLASSES = 4
NUMBER_OF_EPOCHS = 20
BATCH_SIZE = 32
PATIENCE = 5
TARGET_VAL_LOSS = 0.1
IMAGE_SIZE = 224

PROJECT_NAME = "covid19-ChestXRay"
CLASS_NAMES = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]

# Define the transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## Raw Images

In [54]:
train_loader_raw, val_loader_raw, test_loader_raw = load_dataset("./data/raw", transform=transform, batch_size=64)

In [55]:
cnn_raw_model = CNN_Model(NUMBER_OF_CLASSES, class_names=CLASS_NAMES, project_name=PROJECT_NAME)

In [58]:
cnn_raw_model.train(train_loader=train_loader_raw, val_loader=val_loader_raw, epochs=NUMBER_OF_EPOCHS, patience=PATIENCE, target_val_loss=TARGET_VAL_LOSS)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113723355487713, max=1.0…

Early stopping: validation loss target of <= 0.150 has been reached: 0.128, stopped at 1 epochs.


VBox(children=(Label(value='0.005 MB of 0.006 MB uploaded\r'), FloatProgress(value=0.9012504007694774, max=1.0…

0,1
accuracy,▁█
epoch,▁█
precision,▁█
recall,▁█
train_loss,█▁
val_loss,█▁

0,1
accuracy,0.95699
epoch,1.0
precision,0.95699
recall,0.95699
train_loss,0.11081
val_loss,0.12843


In [59]:
cnn_raw_model.save_model('./models/cnn_raw_model.pth')

In [60]:
cnn_raw_model.test(test_loader_raw)

VBox(children=(Label(value='0.038 MB of 0.038 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

## Bilateral Filtered Images

In [61]:
train_loader_filt, val_loader_filt, test_loader_filt = load_dataset("./data/bf", transform=transform, batch_size=BATCH_SIZE)

In [62]:
cnn_filtered_model = CNN_Model(NUMBER_OF_CLASSES, class_names=CLASS_NAMES, project_name=PROJECT_NAME, data_preprocss='bilateral-filtering')

In [63]:
cnn_filtered_model.train(train_loader=train_loader_filt, val_loader=val_loader_filt, epochs=NUMBER_OF_EPOCHS, patience=PATIENCE, target_val_loss=TARGET_VAL_LOSS)

Early stopping: validation loss has not improved in 5 epochs, stopped at 17 epochs.


VBox(children=(Label(value='0.031 MB of 0.031 MB uploaded\r'), FloatProgress(value=0.9813194465469739, max=1.0…

0,1
accuracy,▁█████████████████
epoch,▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
precision,▁█████████████████
recall,▁█████████████████
train_loss,█▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
val_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.93195
epoch,17.0
precision,0.93195
recall,0.93195
train_loss,0.1218
val_loss,0.1916


In [64]:
cnn_filtered_model.save_model('./models/cnn-filtered.pth')

In [65]:
cnn_filtered_model.test(test_loader_filt)

VBox(children=(Label(value='0.040 MB of 0.041 MB uploaded\r'), FloatProgress(value=0.9855571967831939, max=1.0…