In [None]:

# Initial libraries & data exploration 
import os
import glob
import cv2
import matplotlib.pyplot as plt
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

print(f"PyTorch Version: {torch.__version__}")

# Config & Setup: paths and hyperparameters shared by the cells below.
DATA_DIR = 'chest_xray'  # root that holds the 'train' / 'val' / 'test' split folders
BATCH_SIZE = 32          # mini-batch size used by every DataLoader
# presumably consumed by a training loop later in the notebook — not used in this cell
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

# Run on the GPU when PyTorch reports CUDA as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Visual inspection of data: display one NORMAL training image as a sanity check.
# This is best-effort — any failure is reported and the notebook keeps running.
try:
    # Sorted so the same sample is shown on every run (glob order is arbitrary).
    train_normal_files = sorted(
        glob.glob(os.path.join(DATA_DIR, "train", "NORMAL", "*.jpeg"))
    )
    if train_normal_files:
        print(f"Found {len(train_normal_files)} normal images in training set.")
        sample_path = train_normal_files[0]
        # cv2.imread signals an unreadable/undecodable file by returning None
        # rather than raising, so check explicitly instead of letting cvtColor
        # fail later with an opaque error.
        sample_image = cv2.imread(sample_path)
        if sample_image is None:
            print(f"Could not read sample image: {sample_path}")
        else:
            # OpenCV loads channels as BGR; matplotlib expects RGB.
            plt.imshow(cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB))
            plt.title("Normal Image Sample")
            plt.show()
    else:
        print("Could not find sample images. Check DATA_DIR path.")
except Exception as e:
    print(f"An error occurred during sample image loading: {e}")


# Data Preprocessing & loading 
# Per-split image pipelines. Only 'train' applies random augmentation; 'val' and
# 'test' use the same deterministic pipeline so both splits are evaluated alike.
IMG_SIZE = (224, 224)
# Per-channel normalization statistics used by torchvision's ImageNet-pretrained
# models. They are 3-channel values, so this assumes images arrive as RGB.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _eval_transform():
    """Return a new deterministic resize -> tensor -> normalize pipeline."""
    return transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(IMG_SIZE),
        # NOTE(review): a horizontal flip mirrors the image left/right; kept as in
        # the original, but confirm it is an appropriate augmentation here.
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    # Separate instances so mutating one split's pipeline cannot affect the other.
    'val': _eval_transform(),
    'test': _eval_transform(),
}

train_dir = os.path.join(DATA_DIR, 'train')
val_dir = os.path.join(DATA_DIR, 'val')
test_dir = os.path.join(DATA_DIR, 'test')

# Create datasets. ImageFolder derives labels from each split directory's
# sub-folder names independently of the other splits.
train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])
val_dataset = datasets.ImageFolder(val_dir, transform=data_transforms['val'])
test_dataset = datasets.ImageFolder(test_dir, transform=data_transforms['test'])


def _check_same_classes(reference, split_name, dataset):
    """Raise ValueError if *dataset*'s class->index mapping differs from *reference*'s.

    Because each split is indexed on its own, a missing or misnamed class folder
    in one split would silently shift its label indices relative to training.
    """
    if dataset.class_to_idx != reference.class_to_idx:
        raise ValueError(
            f"Class mapping of the '{split_name}' split {dataset.class_to_idx} "
            f"does not match the training split {reference.class_to_idx}."
        )


_check_same_classes(train_dataset, 'val', val_dataset)
_check_same_classes(train_dataset, 'test', test_dataset)

# Create data loaders. Only training batches are shuffled; evaluation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Found {len(train_dataset)} images in the training folder.")
print(f"Found {len(val_dataset)} images in the validation folder.")
print(f"Found {len(test_dataset)} images in the test folder.")
