In [18]:
# Import stuff
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import platform
import torch
import torch.nn as nn

# Data Preparation

#### Convert the images to grayscale and resize them to 128 x 128

In [None]:
# Set up paths: each raw class folder is paired with its 128x128 output folder.
normal_dir = "data/NORMAL"
pneumonia_dir = "data/PNEUMONIA"
resized_normal_dir = "128x128_data/NORMAL"
resized_pneumonia_dir = "128x128_data/PNEUMONIA"
TARGET_SIZE = 128
IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")


def list_images(directory):
    """Return the paths of the image files directly inside `directory`, sorted by name.

    Listing each directory separately keeps every file tied to the folder it came
    from, so identically named files in NORMAL and PNEUMONIA are never confused.
    The extension check is case-insensitive so e.g. ".JPEG" files are not skipped.
    """
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]


def resize_center_crop(img, size=TARGET_SIZE):
    """Return a `size` x `size` center crop of `img` after an aspect-preserving resize.

    The image is scaled so that its *shorter* side equals `size`; the longer side
    is then trimmed equally from both ends, so no black padding is introduced.
    """
    width, height = img.size
    scale = size / min(width, height)
    # max() guards against float rounding leaving the short side at size - 1.
    new_width = max(size, round(width * scale))
    new_height = max(size, round(height * scale))
    img = img.resize((new_width, new_height))
    left = (new_width - size) // 2
    upper = (new_height - size) // 2
    return img.crop((left, upper, left + size, upper + size))


# Find which image has the smallest area (width * height)
min_size = float('inf')
min_dimensions = (float('inf'), float('inf'))
smallest_img_path = ""
for img_path in list_images(normal_dir) + list_images(pneumonia_dir):
    with Image.open(img_path, "r") as img:
        width, height = img.size
    if width * height < min_size:
        min_size = width * height
        min_dimensions = (width, height)
        smallest_img_path = img_path

print(f"Smallest image dimensions: {min_dimensions}")
print(f"Smallest image path: {smallest_img_path}")

# Convert every image to grayscale, resize + center-crop it to 128x128 and save it
# under the same file name in the output folder for its class.
for src_dir, dst_dir in ((normal_dir, resized_normal_dir), (pneumonia_dir, resized_pneumonia_dir)):
    os.makedirs(dst_dir, exist_ok=True)  # img.save() does not create missing folders
    for img_path in list_images(src_dir):
        with Image.open(img_path, "r") as img:
            processed = resize_center_crop(img.convert("L"))  # "L" = 8-bit grayscale
        processed.save(os.path.join(dst_dir, os.path.basename(img_path)))

# ~30 seconds

Smallest image dimensions: (384, 127)
Smallest image path: data/PNEUMONIA/person407_virus_811.jpeg


#### Check that every resized image is 1 x 128 x 128 (single-channel grayscale)

In [None]:
# Convert each resized image to a tensor and report any that are not 1 x 128 x 128
# (e.g. a leading dimension of 3 means the file was stored with three channels).
to_tensor = transforms.ToTensor()  # stateless, so build it once instead of per image
counter = 0
for resized_dir in (resized_normal_dir, resized_pneumonia_dir):
    # Walk each folder separately so a file name present in both folders is checked
    # once per folder, instead of being resolved to the NORMAL copy both times.
    for image in sorted(os.listdir(resized_dir)):
        if not image.lower().endswith((".jpeg", ".jpg", ".png")):
            continue
        with Image.open(os.path.join(resized_dir, image), "r") as img:
            img_tensor = to_tensor(img)
        if img_tensor.shape != (1, 128, 128):
            print(f"{image} has shape {img_tensor.shape}")
            counter += 1
print(f"{counter} images with incorrect shape")

Image person69_bacteria_338.jpeg has shape torch.Size([3, 128, 128])
Image person407_virus_811.jpeg has shape torch.Size([3, 128, 128])
Image person1253_bacteria_3211.jpeg has shape torch.Size([3, 128, 128])
Image person977_virus_1652.jpeg has shape torch.Size([3, 128, 128])
Image person64_bacteria_316.jpeg has shape torch.Size([3, 128, 128])
Image person547_bacteria_2296.jpeg has shape torch.Size([3, 128, 128])
Image person495_bacteria_2094.jpeg has shape torch.Size([3, 128, 128])
Image person469_virus_965.jpeg has shape torch.Size([3, 128, 128])
Image person598_virus_1154.jpeg has shape torch.Size([3, 128, 128])
Image person1684_bacteria_4461.jpeg has shape torch.Size([3, 128, 128])
Image person461_virus_949.jpeg has shape torch.Size([3, 128, 128])
Image person600_virus_1156.jpeg has shape torch.Size([3, 128, 128])
Image person516_bacteria_2191.jpeg has shape torch.Size([3, 128, 128])
Image person757_virus_1385.jpeg has shape torch.Size([3, 128, 128])
Image person388_virus_777.jpeg h

In [19]:
# Select the compute device: CUDA first, then MPS, and the CPU as a last resort.
def _pick_device():
    """Return the preferred available torch.device ("cuda" > "mps" > "cpu")."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = _pick_device()
print(f"Using device: {device.type}\n")

Using device: mps



# Fully Connected Neural Network

In [21]:
# Define the network
class MLP(nn.Module):
    """Fully connected classifier operating on flattened inputs.

    Args:
        in_dim: number of features after flattening everything but the batch
            dimension (e.g. 128*128 for a 1 x 128 x 128 image).
        hidden_dims: widths of the hidden layers, each followed by a ReLU.
            Defaults to (2048, 512), the original architecture.
        out_dim: number of output units. The outputs are raw scores; no
            softmax/sigmoid is applied. Defaults to 2.
    """

    def __init__(self, in_dim, hidden_dims=(2048, 512), out_dim=2):
        super().__init__()
        layers = [nn.Flatten()]
        prev_dim = in_dim
        for hidden_dim in hidden_dims:
            layers += [nn.Linear(prev_dim, hidden_dim), nn.ReLU()]
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, out_dim))
        # With the default arguments this is the same layer sequence as the original
        # hard-coded Sequential, so state_dict keys (net.1, net.3, net.5) are unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch `x` of shape (N, ...) to raw scores of shape (N, out_dim)."""
        return self.net(x)


In [None]:
# Define the training parameters
SEED = 42
torch.manual_seed(SEED)  # reproducible weight initialisation on Restart & Run All
model = MLP(128*128).to(device)
# MLP's final Linear layer outputs 2 raw (un-normalised) scores per image, so use
# CrossEntropyLoss: it applies log-softmax internally and takes class-index targets
# of shape (N,) (or per-class probabilities of shape (N, 2)). BCELoss would instead
# require inputs already in [0, 1] with exactly the target's shape.
# NOTE(review): confirm the training loop passes labels in one of those forms.
loss = nn.CrossEntropyLoss()