# main

In [1]:
import os
import shutil
import random
import torch
import torchvision
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt

torch.manual_seed(0)

print('Using PyTorch version', torch.__version__)

Using PyTorch version 2.5.1


In [2]:
class_names = ['normal', 'viral', 'covid']
root_dir = 'COVID-19_Radiography_Dataset'
source_dirs = ['Normal', 'viral', 'COVID']

if os.path.isdir(os.path.join(root_dir, source_dirs[1])):
    os.mkdir(os.path.join(root_dir, 'test'))

    # Rename the folders to class_names
    for i, d in enumerate(source_dirs):
        os.rename(os.path.join(root_dir, d), os.path.join(root_dir, class_names[i]))

    # Create subdirectories under 'test'
    for c in class_names:
        os.mkdir(os.path.join(root_dir, 'test', c))

    # Sample from the 'images' subfolders and move to 'test'
    for c in class_names:
        # Specify the 'images' subfolder path
        images_folder_path = os.path.join(root_dir, c, 'images')
        
        # List all PNG images in the 'images' subfolder
        images = [x for x in os.listdir(images_folder_path) if x.lower().endswith('png')]
        print(f"Found {len(images)} images in {images_folder_path}")

        # Check if there are at least 30 images
        if len(images) < 30:
            print(f"Not enough images in {c}, selecting all available images.")
            selected_images = images  # Select all available images
        else:
            selected_images = random.sample(images, 30)  # Sample 30 images

        # Move selected images to the 'test' folder
        for image in selected_images:
            source_path = os.path.join(images_folder_path, image)
            target_path = os.path.join(root_dir, 'test', c, image)
            shutil.move(source_path, target_path)
            print(f"Moved {source_path} to {target_path}")
else:
    print(f"Directory does not exist: {os.path.join(root_dir, source_dirs[1])}")

FileExistsError: [Errno 17] File exists: 'COVID-19_Radiography_Dataset/test'

In [6]:
class chestXRayDataset(torch.utils.data.Dataset):
    def __init__(self,image_dirs,transform):
        def get_images(class_name):
            images = [x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('png')]
            print(f'Found {len(images)} {class_name} examples')
            return images
        
        self.images = {}
        self.class_names = ['normal','viral','covid']

        for c in self.class_names:
            self.images[c] = get_images(c)

        self.image_dirs = image_dirs
        self.transform = transform

    def __len__(self):
        return sum([len(self.images[c]) for c in self.class_names])
    
    def __getitem__(self,index):
        class_name =  random.choice(self.class_names)
        index = index%len(self.images[class_name])
        image_name = self.images[class_name][index]
        image_path = os.path.join(self.image_dirs[class_name],image_name)
        image = Image.open(image_path).convert('RGB')
        return self.transform(image),self.class_names.index(class_name)



# Image Transformation

In [11]:
train_tranform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size = (224,224)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean = [0.485,0.456,0.406],std = [0.229,0.224,0.225])
    ]
)

In [12]:
test_tranform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size = (224,224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean = [0.485,0.456,0.406],std = [0.229,0.224,0.225])
    ]
)