# Detecting COVID-19 with Chest X Ray using PyTorch

Image classification of chest X-rays into one of three classes: Normal, Viral Pneumonia, or COVID-19.

Notebook created for the guided project [Detecting COVID-19 with Chest X Ray using PyTorch](https://www.coursera.org/projects/covid-19-detection-x-ray) on Coursera

Dataset from [COVID-19 Radiography Dataset](https://www.kaggle.com/tawsifurrahman/covid19-radiography-database) on Kaggle

In [24]:
%matplotlib inline

import os
import shutil
import random
import torch
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from matplotlib import pyplot as plt

# Seed every random number generator imported above -- Python's `random` and
# NumPy as well as PyTorch -- so that any shuffling or sampling done through
# them is repeatable after a kernel restart.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print('Using PyTorch version', torch.__version__)

Using PyTorch version 1.13.1


# Preparing Training and Test Sets.

## Cleaning the Dataset

In [22]:
# Run clean_dataset.py as a shell command. Per the output recorded below, it
# copies the images into one folder per class under Clean_ChestXRay_Dataset/.
!python clean_dataset.py
# List the cleaned dataset root to confirm which class folders are present.
print(os.listdir('Clean_ChestXRay_Dataset'))


✅ Copied 3616 images to Clean_ChestXRay_Dataset/covid
✅ Copied 1345 images to Clean_ChestXRay_Dataset/pneumonia
✅ Copied 10192 images to Clean_ChestXRay_Dataset/normal

🎉 Clean dataset created successfully!
['pneumonia', 'normal', 'covid']


# Splitting the Data into Training and Test Sets

In [23]:
# Run the split_dataset.py script as a shell command. Per the output recorded
# below, it divides each class into training and test images (roughly 80/20).
!python split_dataset.py

✅ pneumonia: 1076 train, 269 test
✅ normal: 8153 train, 2039 test
✅ covid: 2892 train, 724 test
🎉 Dataset successfully split into training and test sets!


# Creating Custom Dataset

In [None]:
class ChestXrayDataset(Dataset):
    """Map-style dataset of chest X-ray PNG images grouped into three classes.

    Each sample is an ``(image, label)`` pair, where ``label`` is the position
    of the image's class in ``class_names`` (normal=0, pneumonia=1, covid=2).

    Args:
        image_dirs: Mapping from class name ('normal', 'pneumonia', 'covid')
            to the directory that holds that class's images.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.

    Raises:
        KeyError: If ``image_dirs`` has no entry for one of the classes.
        OSError: If a class directory cannot be listed.
    """

    def __init__(self, image_dirs, transform=None):
        def get_images(class_name):
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and machines.
            images = sorted(
                x for x in os.listdir(image_dirs[class_name])
                if x.lower().endswith('.png')
            )
            print(f'Found {len(images)} images for class: {class_name}')
            return images

        self.images = {}
        self.class_names = ['normal', 'pneumonia', 'covid']  # consistent naming

        # Flat list of (class_name, file_name) pairs, grouped by class in
        # class_names order; this is what __len__ and __getitem__ index into.
        self.all_images = []
        for c in self.class_names:
            self.images[c] = get_images(c)
            self.all_images.extend((c, img) for img in self.images[c])

        self.image_dirs = image_dirs
        self.transform = transform

    def __len__(self):
        """Return the total number of images across all classes."""
        return len(self.all_images)

    def __getitem__(self, index):
        """Load the sample at ``index`` and return ``(image, label)``."""
        class_name, image_name = self.all_images[index]
        image_path = os.path.join(self.image_dirs[class_name], image_name)

        # The context manager releases the file handle even if decoding
        # fails. convert() returns an independent in-memory copy, so the
        # image remains usable after the file is closed.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.class_names.index(class_name)
        return image, label