# COVID-19 Chest X-Ray Database - Experiment

In [None]:
from torchvision import transforms

## CNN Model Implementation

In [None]:
# Automatically reload imported modules before executing each cell, so edits
# to the project sources (e.g. src/cnn.py) take effect without a kernel restart.
%reload_ext autoreload
%autoreload 2

from src.cnn import CNN_Model, load_dataset

In [None]:
# Experiment configuration shared by the raw and bilateral-filtered CNN runs.
class_names = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]
NUMBER_OF_CLASSES = len(class_names)  # == 4; derived so it cannot drift from class_names
# NOTE(review): IMAGE_SIZE (299) is not referenced by `transform` below, which
# resizes to 248x248 — confirm which input size the models are meant to use.
IMAGE_SIZE = 299
NUMBER_OF_EPOCHS = 20

# Preprocessing pipeline: grayscale replicated into 3 channels, resize to
# 248x248, convert to a tensor, then normalize per channel using the commonly
# used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(size=(248, 248)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

### Raw Images

In [None]:
# Raw images: unpack the train / validation / test loaders returned by load_dataset.
(
    train_loader_raw,
    val_loader_raw,
    test_loader_raw,
) = load_dataset("./data/raw", transform=transform)

In [None]:
# CNN for the raw images; project_name presumably labels this run in the
# experiment tracker used by CNN_Model — confirm in src/cnn.py.
cnn_raw_model = CNN_Model(
    NUMBER_OF_CLASSES,
    class_names=class_names,
    project_name="covid-cnn-raw",
)

In [None]:
# Train the raw-image model on the raw training split, validating on the raw
# validation split, for NUMBER_OF_EPOCHS epochs.
cnn_raw_model.train(
    train_loader=train_loader_raw,
    val_loader=val_loader_raw,
    epochs=NUMBER_OF_EPOCHS,
)

In [None]:
# Evaluate the raw-image model on the raw test split.
cnn_raw_model.test(test_loader_raw)

### Bilateral Filtered Images

In [None]:
# Bilateral-filtered images: train / validation / test loaders.
# Uses the module-level `load_dataset` imported from src.cnn, the same function
# the raw-image cell calls, rather than a CNN_Model attribute the import line
# does not suggest exists.
train_loader_filt, val_loader_filt, test_loader_filt = load_dataset("./data/bf", transform=transform)

In [None]:
# CNN for the bilateral-filtered images; same configuration as the raw model,
# only the project_name differs.
cnn_filtered_model = CNN_Model(
    NUMBER_OF_CLASSES,
    class_names=class_names,
    project_name="covid-cnn-filtered",
)

In [None]:
# Train the filtered-image model on the filtered training split, validating on
# the filtered validation split, for NUMBER_OF_EPOCHS epochs.
cnn_filtered_model.train(
    train_loader=train_loader_filt,
    val_loader=val_loader_filt,
    epochs=NUMBER_OF_EPOCHS,
)

In [None]:
# Evaluate the filtered-image model on the filtered test split.
cnn_filtered_model.test(test_loader_filt)

In [None]:

import os

# Persist both trained models. Create the output directory first so saving
# does not fail on a fresh checkout where ./models does not exist yet.
# NOTE(review): assumes save_model does not create parent directories itself
# (torch.save does not) — confirm in src/cnn.py.
os.makedirs("./models", exist_ok=True)
# NOTE(review): the two file names follow different conventions
# (snake_case vs kebab-case); left unchanged in case other code loads these paths.
cnn_raw_model.save_model('./models/cnn_raw_model.pth')
cnn_filtered_model.save_model('./models/cnn-filtered.pth')

## Multilayer Perceptron