In [8]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms,models
from sklearn.metrics import accuracy_score

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PneumoniaDataset(Dataset):
    def __init__(self,root_dir,transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label in ['NORMAL','PNEUMONIA']:
            class_dir = os.path.join(root_dir,label)
            for img_name in os.listdir(class_dir):
                self.image_paths.append(os.path.join(class_dir,img_name))
                self.labels.append(0 if label == 'Normal' else 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        img_path = self.image_paths[item]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[item]

        if self.transform:
            image = self.transform(image)
        return image,label

In [11]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])

train_dataset = PneumoniaDataset(root_dir='C:\\Projects\\python\\data_sets\\chest_xray\\train',transform=transform)
test_dataset = PneumoniaDataset(root_dir='C:\\Projects\\python\\data_sets\\chest_xray\\test',transform=transform)
val_dataset = PneumoniaDataset(root_dir='C:\\Projects\\python\\data_sets\\chest_xray\\val',transform=transform)

train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=32,shuffle=False)
val_loader = DataLoader(val_dataset,batch_size=32,shuffle=False)

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features,2)#NORMAL,PNEUMONIA
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\jorda/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth


100.0%


In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.001)

In [13]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images,labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs,labels)
        loss.backward()

        optimizer.step()
        running_loss += loss

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}')

Epoch 1/10, Loss: 0.012368422001600266
Epoch 2/10, Loss: 9.618944204703439e-06
Epoch 3/10, Loss: 7.542238108726451e-06
Epoch 4/10, Loss: 6.22243169345893e-06
Epoch 5/10, Loss: 5.223051630309783e-06
Epoch 6/10, Loss: 4.452813755051466e-06
Epoch 7/10, Loss: 3.840632416540757e-06
Epoch 8/10, Loss: 3.343297294122749e-06
Epoch 9/10, Loss: 2.931276412709849e-06
Epoch 10/10, Loss: 2.587246854091063e-06


In [14]:
model.eval()
val_labels = []
val_preds = []

with torch.no_grad():
    for images,labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _,preds = torch.max(outputs,1)

        val_labels.extend(labels.cpu().numpy())
        val_preds.extend(preds.cpu().numpy())

    val_accuracy = accuracy_score(val_labels,val_preds)
    print('Validation accuracy:',val_accuracy)

Validation accuracy: 1.0


In [15]:
model.eval()
test_labels = []
test_preds = []


with torch.no_grad():
    for images,labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _,preds = torch.max(outputs,1)

        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

    test_accuracy = accuracy_score(test_labels,test_preds)
    print('Test accuracy:',test_accuracy)

Test accuracy: 1.0


In [16]:
torch.save(model.state_dict(),'pneumonia_classifier.pth')