# COVID-19 Detection from Chest X-Rays Using a CNN Classifier

The COVID-19 pandemic has been, over the last 3 years, one of the biggest challenges for healthcare systems worldwide.

In this notebook, we will use X-ray data of lungs from both normal and COVID-positive patients and train a deep learning model to differentiate between them.

## Introduction - Dataset and models

The dataset used in this project won the COVID-19 Dataset Award from the Kaggle community. It was collected by researchers from Qatar and Bangladesh and is available at: https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database

References:
- M.E.H. Chowdhury, T. Rahman, A. Khandakar, R. Mazhar, M.A. Kadir, Z.B. Mahbub, K.R. Islam, M.S. Khan, A. Iqbal, N. Al-Emadi, M.B.I. Reaz, M.T. Islam, “Can AI help in screening Viral and COVID-19 pneumonia?” IEEE Access, Vol. 8, 2020, pp. 132665-132676. [Paper link](https://ieeexplore.ieee.org/document/9144185)
- Rahman, T., Khandakar, A., Qiblawey, Y., Tahir, A., Kiranyaz, S., Kashem, S.B.A., Islam, M.T., Maadeed, S.A., Zughaier, S.M., Khan, M.S. and Chowdhury, M.E., 2020. Exploring the Effect of Image Enhancement Techniques on COVID-19 Detection using Chest X-ray Images. [Paper link](https://doi.org/10.1016/j.compbiomed.2021.104319)

This dataset contains a total of 21,215 images of 4 types:

1. COVID-19 positive (3,616 images)
1. Viral Pneumonia (1,395 images)
1. Normal X-ray (10,192 images)
1. Lung Opacity (6,012 images)

We will only consider the first three types, so this is a 3-class classification problem, and the model will end with a softmax layer that outputs one probability per class. The images are 299×299 pixels with 3 color channels.
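As a minimal sketch of what the final softmax layer computes, the snippet below turns three raw scores (logits) into probabilities that sum to 1 and picks the most likely class. The logits and the class order are made-up values for illustration; in the trained network the logits would come from Xception's last linear layer.

```python
import math

def softmax(logits):
    """Convert a list of raw scores into probabilities that sum to 1."""
    # Subtracting the largest logit avoids overflow and does not change the result
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative class order and logits (not model output)
kept_classes = ["COVID", "Normal", "Viral Pneumonia"]
logits = [2.0, 0.5, -1.0]

probs = softmax(logits)
for name, p in zip(kept_classes, probs):
    print(f"{name:16s} {p:.3f}")

# The predicted class is the one with the highest probability
prediction = kept_classes[probs.index(max(probs))]
print("Predicted:", prediction)
```

Because softmax is monotonic, the predicted class is always the one with the largest logit; the probabilities add a calibrated-looking score on top of that ranking.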

The authors of the dataset also trained a ResNet-34 model and achieved a classification accuracy of 98.5%. In this notebook we'll use the **Xception** model. Xception is a deep convolutional neural network (CNN) architecture built on [Depthwise Separable Convolutions](https://paperswithcode.com/method/depthwise-separable-convolution). It was introduced by Francois Chollet in the paper ["Xception: Deep Learning With Depthwise Separable Convolutions"](https://openaccess.thecvf.com/content_cvpr_2017/html/Chollet_Xception_Deep_Learning_CVPR_2017_paper.html), Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017, pp. 1251-1258. The name stands for “Extreme Inception”: the architecture can be seen as an extreme version of the Inception module. On ImageNet, Xception reaches a top-1 accuracy of 79.0% and a top-5 accuracy of 94.5%.
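To see why depthwise separable convolutions are cheaper than standard ones, the snippet below counts the weights of a standard k×k convolution and of its depthwise separable counterpart (one k×k filter per input channel, followed by a 1×1 pointwise convolution that mixes channels). Biases are ignored, and the layer sizes are arbitrary examples, not Xception's actual configuration.

```python
def standard_conv_params(k, c_in, c_out):
    """Weights in a k x k convolution mapping c_in -> c_out channels (no bias)."""
    return k * k * c_in * c_out

def separable_conv_params(k, c_in, c_out):
    """Weights in a depthwise (k x k, one filter per input channel) plus
    pointwise (1 x 1, c_in -> c_out) convolution, without biases."""
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise

# Example layer: 3x3 kernel, 128 input channels, 256 output channels
k, c_in, c_out = 3, 128, 256
std = standard_conv_params(k, c_in, c_out)
sep = separable_conv_params(k, c_in, c_out)

print(f"standard:  {std:,} weights")   # 294,912
print(f"separable: {sep:,} weights")   # 33,920
print(f"about {std / sep:.1f}x fewer weights")
```

The saving comes from factoring spatial filtering (depthwise) and channel mixing (pointwise) into two separate, much smaller operations.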

## Libraries

We will fine-tune a pre-trained version of Xception from the excellent `timm` library ([link](https://github.com/huggingface/pytorch-image-models)). PyTorch Image Models (`timm`), a deep-learning library created by Ross Wightman, is a collection of state-of-the-art (SOTA) computer vision models, layers, utilities, optimizers, schedulers, data loaders and augmentations, together with training and validation scripts that can reproduce ImageNet training results. We will also use `fastai` ([link](https://docs.fast.ai/)), a deep learning library built on [*PyTorch*](https://pytorch.org/) that simplifies training neural networks by providing high-level components for standard deep learning tasks; `torchtnt` ([link](https://pytorch.org/tnt/stable/)), a library of PyTorch training tools and utilities; and the usual libraries of the PyData stack: `numpy`, `pandas`, `sklearn` and `matplotlib`.

### About fast.ai
`fastai` is a deep learning library which provides practitioners with high-level components that can quickly and easily provide state-of-the-art results in standard deep learning domains, and provides researchers with low-level components that can be mixed and matched to build new approaches. It aims to do both things without substantial compromises in ease of use, flexibility, or performance. This is possible thanks to a carefully layered architecture, which expresses common underlying patterns of many deep learning and data processing techniques in terms of decoupled abstractions. These abstractions can be expressed concisely and clearly by leveraging the dynamism of the underlying Python language and the flexibility of the PyTorch library.

In [None]:
%pip install timm fastai torchtnt

### Import libraries

Now, we will import the necessary Python packages into our Jupyter Notebook. Here’s a brief overview of how we’ll use these packages:

1. Python Standard Library dependencies: These are built-in modules that come with Python. We’ll use them for various tasks like handling file paths ([`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path)), manipulating JSON files ([`json`](https://docs.python.org/3/library/json.html)), random number generation ([`random`](https://docs.python.org/3/library/random.html)), mathematical operations ([`math`](https://docs.python.org/3/library/math.html)), copying Python objects ([`copy`](https://docs.python.org/3/library/copy.html)), working with dates and times ([`datetime`](https://docs.python.org/3/library/datetime.html)), chaining iterables ([`itertools.chain`](https://docs.python.org/3/library/itertools.html#itertools.chain)) and counting items ([`collections.Counter`](https://docs.python.org/3/library/collections.html#collections.Counter)).
1. Utility functions: These are helper functions from the packages we installed earlier. They provide shortcuts for routine tasks and keep our code clean and readable.
1. matplotlib: We use the matplotlib package to explore the dataset samples and class distribution.
1. NumPy: We’ll use it to store PIL Images as arrays of pixel values.
1. pandas: We use Pandas `DataFrame` and `Series` objects to format data as tables.
1. PIL (Pillow): We’ll use it for opening and working with image files.
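1. timm library: We’ll use the timm library to download and prepare a pre-trained Xception model for fine-tuning.
1. PyTorch and torchvision: We’ll use them for tensors, device handling and image transforms.
1. fastai: We’ll use its high-level vision API to load the images and train the model.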

In [2]:
# Import Python Standard Library dependencies
from copy import copy
import datetime
import json
import math
from pathlib import Path
import random
from itertools import chain
from collections import Counter

In [3]:
# Import matplotlib for creating plots
import matplotlib.pyplot as plt

# Import numpy 
import numpy as np

# Import pandas module for data manipulation
import pandas as pd

# Set options for Pandas DataFrame display
pd.set_option('display.max_colwidth', None)  # Do not truncate the contents of cells in the DataFrame
pd.set_option('display.max_rows', None)  # Display all rows in the DataFrame
pd.set_option('display.max_columns', None)  # Display all columns in the DataFrame

# Import PIL for image manipulation
from PIL import Image

# Import timm library
import timm

# Import PyTorch dependencies
import torch
import torchvision
#torchvision.disable_beta_transforms_warning()
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF
#from torchtnt.utils import get_module_summary

# Import fastai computer vision functionality 
from fastai.vision.all import *

### The timm library

One of the most popular features of `timm` is its large and ever-growing collection of model architectures. Many of these models come with pretrained weights, either trained natively in PyTorch or ported from other libraries such as JAX and TensorFlow, which can be easily downloaded and used.

We can list and query the collection of available models as demonstrated below:

In [4]:
len(timm.list_models('*'))

1007

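We can also pass `pretrained=True` to list only the models that have pretrained weights. In recent `timm` versions each entry is an `architecture.tag` name, one per set of pretrained weights, so this count can be larger than the total number of architectures above: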

In [5]:
len(timm.list_models(pretrained=True))

1260
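The sketch below strips the tag from a few hand-written names in `architecture.tag` form to count the distinct architectures behind them, and filters them with a shell-style wildcard like the `'*'` pattern passed to `timm.list_models` above. The names are illustrative, not output from `timm`; in practice you would get the real list from `timm.list_models(pretrained=True)`.

```python
from fnmatch import fnmatch

# Illustrative names in 'architecture.tag' form, one per set of pretrained weights
# (hand-written for this example, not the output of timm.list_models)
tagged_names = [
    "resnet50.a1_in1k",
    "resnet50.tv_in1k",
    "xception41.tf_in1k",
    "xception65.tf_in1k",
    "xception71.tf_in1k",
]

# Drop everything after the first dot to recover the architecture name
architectures = sorted({name.split(".")[0] for name in tagged_names})
print(len(tagged_names), "sets of weights for", len(architectures), "architectures")

# Shell-style wildcard filtering, similar in spirit to timm.list_models('xception*')
xception_archs = [a for a in architectures if fnmatch(a, "xception*")]
print(xception_archs)
```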

## Setup project

In this section, we set up some basics for our project.

### Set random seed

First, we set a seed for the random number generators of PyTorch, NumPy and Python's `random` module (fastai's `set_seed` function performs the equivalent calls).
A fixed seed value is helpful when training deep-learning models for reproducibility, debugging, and comparison.
Having reproducible results allows others to confirm your findings. Using a fixed seed can make it easier to find bugs as it ensures the same inputs produce the same outputs. Likewise, using fixed seed values lets you compare performance between models and training parameters.
That said, it’s often a good idea to test different seed values to see how your model’s performance varies between them. Also, don’t use a fixed seed value when you deploy the final model.
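To see what a fixed seed buys you, the sketch below draws numbers twice from a generator seeded with the same value and once with a different value. It uses only Python's `random` module (`draw` is a helper name introduced here); `np.random.seed` and `torch.manual_seed` follow the same principle for their own generators.

```python
import random

def draw(seed, n=5):
    """Return n pseudo-random integers from a generator seeded with `seed`."""
    # A dedicated Random instance leaves the global random state untouched
    rng = random.Random(seed)
    return [rng.randint(0, 100) for _ in range(n)]

run_a = draw(1234)
run_b = draw(1234)
run_c = draw(4321)

print(run_a == run_b)  # True: same seed, same sequence
print(run_a, run_c)    # a different seed gives a different sequence
```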

In [6]:
# Set the seed for generating random numbers in PyTorch, NumPy, and Python's random module.
seed = 1234
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
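# Allow non-deterministic (often faster) algorithms; passing True would force deterministic
# implementations, which can be slower and raises an error for ops that lack one
torch.use_deterministic_algorithms(False)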

### Set the PyTorch Device and Data Type
Next, we determine the device to run our computations on using fastai’s [`default_device()`](https://docs.fast.ai/torch_core.html#default_device) function (a GPU if one is available, otherwise the CPU), and we set the data type of our tensors to `float32`.

In [7]:
device = default_device()
dtype = torch.float32
device, dtype

(device(type='cuda', index=0), torch.float32)

### Set directory paths

We then need to set up a directory for our project to store our results and other related files. The following code creates the folder in the current directory (./). Update the path if that is not suitable for you.

We also need a place to store our datasets and a location to download the zip file containing the dataset. Readers following the tutorial on their local machine should select locations with read and write access to store archive files and datasets. For a cloud service like Google Colab, you can set it to the current directory.

In [8]:
# A name for the project
project_name = "covid19-classifier"

# The path for the project folder
project_dir = Path(f"./{project_name}/")

# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)

# Define path to store datasets
dataset_dir = Path("./Datasets/")
# Create the dataset directory if it does not exist
dataset_dir.mkdir(parents=True, exist_ok=True)

# Define path to store archive files
archive_dir = project_dir/'Archive/'
# Create the archive directory if it does not exist
archive_dir.mkdir(parents=True, exist_ok=True)

pd.Series({
    "Project Directory:": project_dir, 
    "Dataset Directory:": dataset_dir, 
    "Archive Directory:": archive_dir
}).to_frame().style.hide(axis='columns')

| Directory | Path |
|---|---|
| Project Directory | covid19-classifier |
| Dataset Directory | Datasets |
| Archive Directory | covid19-classifier/Archive |


Double-check the project and dataset directories exist in the specified paths and that you can add files to them before continuing.
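A quick way to run that check from the notebook is to create and immediately delete a temporary file in each directory. The helper below is a sketch; `is_writable_dir` is a name introduced here, not a library function.

```python
import tempfile
from pathlib import Path

def is_writable_dir(path):
    """Return True if `path` is an existing directory in which we can create files."""
    path = Path(path)
    if not path.is_dir():
        return False
    try:
        # Creating (and automatically deleting) a temporary file is the most direct test
        with tempfile.NamedTemporaryFile(dir=path):
            pass
        return True
    except OSError:  # includes PermissionError
        return False

# Demo on a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    ok_existing = is_writable_dir(tmp)
    ok_missing = is_writable_dir(Path(tmp) / "missing")
    print(ok_existing, ok_missing)  # True False
```

In the notebook you could call it on the directories defined above, for example `[is_writable_dir(d) for d in (project_dir, dataset_dir, archive_dir)]`.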

At this point, our environment is set up and ready to go. We’ve set our random seed, determined our computation device, and set up directories for our project and dataset. In the next section, we will download and explore the dataset.

## Download the dataset

The following steps demonstrate how to download the dataset from Kaggle, inspect the dataset, and visualize some sample images. 
To download the Kaggle dataset to the local Jupyter environment, we will use the [`opendatasets`](https://pypi.org/project/opendatasets/) library. If it is not already installed on your system, install it with Python’s package manager `pip`:

In [None]:
%pip install opendatasets

The process is as follows. (**For your convenience, I've already downloaded the dataset into the _Dataset/COVID19-Radiography-Dataset_ folder, so there's no need to run the following cells; they only show how to use the `opendatasets` library.**)

1. Import the opendatasets library

In [None]:
import opendatasets as od

2. Next, we call the `download()` function of the `opendatasets` library, which, as the name suggests, downloads the dataset. It takes the URL of the dataset as its first argument; the `data_dir` argument sets the folder to download into.

In [None]:
url = "https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database"
data_dir = "Datasets"
od.download(url, data_dir=data_dir)

3. On executing the above line, it will prompt for your Kaggle username. Your Kaggle username can be fetched from the Account tab of the My Profile section.

4. After you enter your username, it will prompt for your Kaggle key. Again, go to the Account tab of the My Profile section and click on Create New API Token. This will download a kaggle.json file.

5. On opening this file, you will find the username and key in it. Copy the key and paste it into the prompted Jupyter Notebook cell. The content of the downloaded file would look like this:

`{"username":"<KAGGLE USERNAME>","key":"<KAGGLE KEY>"}`

6. Do NOT store the API Token (kaggle.json file) in your GitHub repository.

7. A progress bar shows how much of the dataset has been downloaded.

8. After successful completion of the download, a folder will be created in the Datasets directory of your Jupyter Notebook. This folder contains our dataset.

9. Since we will use neither the metadata files nor the Lung Opacity images, delete them. This also makes loading the dataset simpler, because fewer files and sub-directories have to be filtered out when building the training and validation sets.
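Step 9 can also be scripted. The sketch below is hypothetical: `remove_unused` is a helper written for this example, and the folder name `Lung_Opacity` and the metadata extensions `.xlsx`/`.txt` are assumptions about the downloaded layout. By default it only returns the paths it would delete; pass `dry_run=False` to actually delete them, after checking the list against your copy of the dataset. The demo runs on a small throwaway directory tree with made-up file names.

```python
import shutil
import tempfile
from pathlib import Path

def remove_unused(root, dirs_to_remove=("Lung_Opacity",),
                  exts_to_remove=(".xlsx", ".txt"), dry_run=True):
    """Delete (or, if dry_run, only list) unused folders and metadata files under root."""
    root = Path(root)
    # Collect every matching directory and file before touching anything
    targets = [p for p in root.rglob("*")
               if (p.is_dir() and p.name in dirs_to_remove)
               or (p.is_file() and p.suffix.lower() in exts_to_remove)]
    # Skip entries that live inside a directory we are already removing
    dir_targets = {p for p in targets if p.is_dir()}
    targets = sorted(p for p in targets
                     if not any(parent in dir_targets for parent in p.parents))
    if not dry_run:
        for p in targets:
            if p.is_dir():
                shutil.rmtree(p)
            else:
                p.unlink()
    return targets

# Demo on a small throwaway tree (hypothetical names)
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for sub in ("COVID/images", "Lung_Opacity/images"):
        (root / sub).mkdir(parents=True)
    (root / "COVID/images/COVID-1.png").touch()
    (root / "Lung_Opacity/images/Lung_Opacity-1.png").touch()
    (root / "COVID.metadata.xlsx").touch()
    (root / "README.md.txt").touch()

    # Dry run: nothing is deleted, we only see what would go
    listed = [p.relative_to(root).as_posix() for p in remove_unused(root)]
    print(listed)

    # Real run, then list what is left
    remove_unused(root, dry_run=False)
    remaining = sorted(p.relative_to(root).as_posix() for p in root.rglob("*"))
    print(remaining)
```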

### Get dataset information

In [9]:
# Set the dataset path.
# We want every file under the dataset root except:
#   - the lung segmentation masks stored in the 'masks' subdirectories, and
#   - the metadata files (.xlsx spreadsheets and .txt files).

# Define the root directory of the dataset
dataset_root_dir = Path(dataset_dir)

# Exclusions as (pattern, type) pairs:
#   "dir" -> skip any file located inside a directory with this name
#   "ext" -> skip any file with this extension
exclusions = [("masks", "dir"), (".xlsx", "ext"), (".txt", "ext")]

def include_path(path, exclusions):
    "Return True if `path` is a file that matches none of the exclusions."
    if not path.is_file():
        return False
    for pattern, excl_type in exclusions:
        if excl_type == "dir" and pattern in path.parts[:-1]:
            return False
        if excl_type == "ext" and path.suffix.lower() == pattern:
            return False
    return True

# Walk the dataset root recursively and keep only the files that pass every exclusion
dataset_path = sorted(p for p in dataset_root_dir.rglob("*") if include_path(p, exclusions))

In [None]:
# Get the file paths for each image in the dataset, skipping the lung masks
img_paths = [path for path in get_image_files(dataset_root_dir) if "masks" not in path.parts]

# Get the number of samples for each image class
class_counts = Counter(path.parent.name for path in img_paths)

# Get the class names
class_names = list(class_counts.keys())

# Print the number of samples for each image class
class_counts_df = pd.DataFrame.from_dict(class_counts, orient='index', columns=['Count'])
class_counts_df
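The counting idiom above takes each image's parent folder name as its class label. The sketch below runs it on a few hypothetical paths (the `<class>/<file>` layout and the file names are assumptions for illustration) and adds each class's share of the total, which makes any class imbalance easy to spot.

```python
from collections import Counter
from pathlib import PurePosixPath

# Hypothetical image paths laid out as <class_name>/<file>
paths = [PurePosixPath(p) for p in [
    "Datasets/COVID/COVID-1.png",
    "Datasets/COVID/COVID-2.png",
    "Datasets/Normal/Normal-1.png",
    "Datasets/Normal/Normal-2.png",
    "Datasets/Normal/Normal-3.png",
    "Datasets/Viral Pneumonia/Viral Pneumonia-1.png",
]]

# The parent folder name is used as the class label
counts = Counter(p.parent.name for p in paths)
total = sum(counts.values())

# Print classes from most to least frequent with their share of the data
for name, n in counts.most_common():
    print(f"{name:16s} {n:3d}  ({n / total:.0%})")
```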

### Visualizing Sample Images

Lastly, we will visualize a sample image from each class in our dataset. Visualizing the samples helps us get a feel for the kind of images we’re working with and whether they’re suitable for the task at hand.

In [None]:
# Get the path of the first image found for each class
sample_paths = [next(p for p in img_paths if p.parent.name == class_name) for class_name in class_names]
sample_paths.sort()

# Calculate the number of rows and columns of a near-square grid with one panel per class
n_cols = math.ceil(math.sqrt(len(sample_paths)))
n_rows = math.ceil(len(sample_paths) / n_cols)

# Load the sample images and their class labels
images = [Image.open(path) for path in sample_paths]
labels = [path.parent.name for path in sample_paths]

# Create a figure for the grid (squeeze=False always returns a 2D array of axes)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)

for i, ax in enumerate(axs.flatten()):
    # If we have an image for this subplot
    if i < len(images):
        # Add the image to the subplot (cmap only affects single-channel images)
        ax.imshow(np.array(images[i]), cmap='gray')
        # Set the title to the corresponding class name
        ax.set_title(labels[i])
    # Remove the axis; subplots without an image stay empty
    ax.axis('off')

# Display the grid
plt.tight_layout()
plt.show()

The original dataset is large (over 20,000 images), so training on all of it would take a long time. For illustration purposes, we will instead take a random sample of 20% of the images from each category and use data augmentation. The dataset also contains lung mask images (in the *masks* subdirectory of each of the four category directories). We won't be using these, so we'll have to build a fastai `DataBlock` that skips them. First, let's define some parameters:

1. we'll resize images (if needed) to (299,299), the default input size of the Xception model;
1. we'll use a batch size of 32;
1. we'll use a training/validation split of 80/20.
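The sketch below puts these choices into numbers: it draws a 20% random sample from each class (a stratified sample, so class proportions are preserved), splits the result 80/20 into training and validation sets, and counts the batches of 32 per epoch. The file names are placeholders and the per-class counts are the ones listed in the introduction. It is plain Python for illustration only; fastai can perform the split for you, for example with `RandomSplitter(valid_pct=0.2, seed=...)`.

```python
import math
import random

SAMPLE_FRAC = 0.2   # fraction of each class we keep
VALID_PCT = 0.2     # fraction of the sample used for validation
BATCH_SIZE = 32

# Per-class image counts from the dataset description (Lung Opacity excluded)
class_sizes = {"COVID": 3616, "Viral Pneumonia": 1395, "Normal": 10192}

rng = random.Random(1234)

# Stratified sample: draw the same fraction from every class
sample = []
for name, size in class_sizes.items():
    files = [f"{name}-{i}.png" for i in range(1, size + 1)]  # placeholder file names
    k = round(size * SAMPLE_FRAC)
    sample += [(f, name) for f in rng.sample(files, k)]

# Random 80/20 train/validation split of the sample
rng.shuffle(sample)
n_valid = int(len(sample) * VALID_PCT)
valid, train = sample[:n_valid], sample[n_valid:]

print("sampled:", len(sample))
print("train:", len(train), "valid:", len(valid))
print("batches per epoch:", math.ceil(len(train) / BATCH_SIZE))
```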

In [None]:
IMG_HEIGHT = 299
IMG_WIDTH = 299
NUM_CHANNELS = 3