In [12]:
from PIL import Image
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from sklearn.metrics import accuracy_score

In [3]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [9]:
class PneumoniaDataset(Dataset):
    """Image dataset read from a ``root_dir/NORMAL`` + ``root_dir/PNEUMONIA`` layout.

    Each sample is ``(image, label)`` where ``label`` is 0 for files found
    under ``NORMAL`` and 1 for files found under ``PNEUMONIA``.

    Args:
        root_dir: Directory that contains the ``NORMAL`` and ``PNEUMONIA``
            sub-directories.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        FileNotFoundError: If one of the class sub-directories is missing
            (raised by ``os.listdir``).
    """

    # Sub-directory names; the position in this tuple is the integer label.
    CLASSES = ('NORMAL', 'PNEUMONIA')
    # Only entries with one of these (case-insensitive) suffixes are indexed,
    # so stray files such as ``.DS_Store`` or ``Thumbs.db`` cannot crash
    # ``Image.open`` in the middle of an epoch.
    IMAGE_EXTENSIONS = (
        '.jpeg', '.jpg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp',
    )

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for class_index, class_name in enumerate(self.CLASSES):
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable across runs and machines.
            for img_name in sorted(os.listdir(class_dir)):
                if not img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                img_path = os.path.join(class_dir, img_name)
                if not os.path.isfile(img_path):
                    # A directory that merely looks like an image name.
                    continue
                self.image_paths.append(img_path)
                self.labels.append(class_index)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # convert() produces a new, fully loaded image, so the result stays
        # valid after the context manager closes the underlying file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label
        

In [10]:
# Preprocessing applied to every image, in order: resize to 224x224,
# convert to a tensor, then normalize each channel with the mean/std
# statistics published for torchvision's ImageNet-pretrained weights.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [13]:
# One dataset per split directory under data/, all sharing the same transform.
train_dataset, test_dataset, val_dataset = (
    PneumoniaDataset(root_dir=f'data/{split}', transform=transform)
    for split in ('train', 'test', 'val')
)

In [14]:
# Batches of 32 for every split; only the training data is reshuffled.
train_loader, test_loader, val_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=reshuffle)
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (test_dataset, False),
        (val_dataset, False),
    )
)

In [15]:
# Start from ResNet-18 with the IMAGENET1K_V1 weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Swap the classification head for a 2-way layer: normal vs. pneumonia.
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=2,
)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\Pragya/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:32<00:00, 1.44MB/s]


In [17]:
def collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        ``(true_labels, predicted_labels)``: two flat lists with one entry per
        sample, in loader order. Predictions are the index of the largest
        output per sample.
    """
    model.eval()
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            _, preds = torch.max(outputs, 1)

            # Labels are only compared on the host, so they never need to
            # visit the device.
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(preds.cpu().numpy())

    return true_labels, predicted_labels


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)

num_epochs = 10

for epoch in range(num_epochs):
    # collect_predictions() leaves the model in eval mode, so training mode
    # has to be re-enabled at the start of every epoch.
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss= criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Reported loss is the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch +1} / {num_epochs}, Loss: {running_loss/ len(train_loader)}")

    val_labels, val_preds = collect_predictions(model, val_loader, device)
    val_accuracy = accuracy_score(val_labels, val_preds)
    print('Validation accuracy:', val_accuracy)

# Final evaluation on the held-out test split, using the last-epoch weights.
test_labels, test_preds = collect_predictions(model, test_loader, device)
test_accuracy = accuracy_score(test_labels, test_preds)
print('Test accuracy:', test_accuracy)

# Persists the weights as they are after the final epoch; no best-epoch
# selection is performed.
torch.save(model.state_dict(), 'pneumonia_classifier.pth')

Epoch 1 / 10, Loss: 0.12229617671938213
Validation accuracy: 0.5625
Epoch 2 / 10, Loss: 0.06267665189325215
Validation accuracy: 1.0
Epoch 3 / 10, Loss: 0.05042446671379444
Validation accuracy: 0.6875
Epoch 4 / 10, Loss: 0.050440363019584836
Validation accuracy: 0.625
Epoch 5 / 10, Loss: 0.03320916366289895
Validation accuracy: 1.0
Epoch 6 / 10, Loss: 0.02501620096074163
Validation accuracy: 0.625
Epoch 7 / 10, Loss: 0.0233052359818748
Validation accuracy: 0.8125
Epoch 8 / 10, Loss: 0.026901992052265705
Validation accuracy: 0.9375
Epoch 9 / 10, Loss: 0.007623186606543847
Validation accuracy: 0.9375
Epoch 10 / 10, Loss: 0.027557481184299674
Validation accuracy: 1.0
Test accuracy: 0.8108974358974359
