In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from matplotlib import pyplot as plt
import os

In [2]:
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

In [None]:
class XRay(Dataset):
    """Three-class chest X-ray dataset read from per-class sub-folders.

    ``path`` must contain the sub-directories ``normal``, ``pneumonia`` and
    ``tuberculosis``; their images get labels 0, 1 and 2 respectively. Every
    image is loaded with ``plt.imread``, converted to a single-channel tensor,
    resized to ``dim x dim`` and, by default, flattened.

    Args:
        path: root directory of one split, e.g. ``'xray/train'``.
        dim: side length, in pixels, that every image is resized to.
        flatten: if True (default, the original behaviour) items are 1-D
            tensors of length ``dim * dim``; if False they keep the
            ``(1, dim, dim)`` shape that ``nn.Conv2d`` expects.
    """

    # Sub-folder names in label order: position in this tuple == class label.
    CLASSES = ('normal', 'pneumonia', 'tuberculosis')

    def __init__(self, path, dim, flatten=True):
        self.path = path
        self.normal = self._list_images('normal')
        self.pneumonia = self._list_images('pneumonia')
        self.tuberculosis = self._list_images('tuberculosis')
        steps = [transforms.ToTensor(),
                 transforms.Grayscale(),
                 transforms.Resize([dim, dim])]
        if flatten:
            steps.append(transforms.Lambda(lambda x: x.view(-1)))
        self.t = transforms.Compose(steps)

    def _list_images(self, folder):
        """Return the sorted regular, non-hidden file names in ``path/folder``.

        Sorting makes the index -> file mapping reproducible (``os.listdir``
        order is arbitrary). Hidden entries such as ``.DS_Store`` and
        sub-directories are skipped because they are not images to decode.
        """
        folder_path = os.path.join(self.path, folder)
        return sorted(
            name for name in os.listdir(folder_path)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(folder_path, name))
        )

    def __len__(self):
        return len(self.normal) + len(self.pneumonia) + len(self.tuberculosis)

    def _locate(self, idx):
        """Map a flat dataset index to ``(file path, label)``.

        Negative indices count from the end like a Python sequence; indices
        outside ``[-len, len)`` raise ``IndexError``.
        """
        size = len(self)
        pos = idx + size if idx < 0 else idx
        if not 0 <= pos < size:
            raise IndexError(f'index {idx} out of range for dataset of size {size}')
        groups = (self.normal, self.pneumonia, self.tuberculosis)
        for label, (folder, files) in enumerate(zip(self.CLASSES, groups)):
            if pos < len(files):
                return os.path.join(self.path, folder, files[pos]), label
            pos -= len(files)
        raise IndexError(f'index {idx} out of range for dataset of size {size}')

    def __getitem__(self, idx):
        file_path, label = self._locate(idx)
        img = plt.imread(file_path)
        # plt.imread returns an RGBA array for PNGs with an alpha channel, but
        # Grayscale only accepts 1- or 3-channel input, so drop the alpha plane.
        if img.ndim == 3 and img.shape[-1] == 4:
            img = img[..., :3]
        return self.t(img), label

In [None]:
# Build the train and validation splits at 64x64 resolution; each item is a
# flattened 64*64 grayscale vector plus its integer class label.
ds_train, ds_val = (XRay(f'xray/{split}', 64) for split in ('train', 'val'))

In [None]:
# Reshuffle the training split every epoch. The validation split only feeds
# order-independent metrics (summed loss / correct counts), so keep it in a
# fixed order: shuffling there costs work and makes runs less reproducible.
dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=32, shuffle=False)

In [None]:
# Candidate architectures keyed by display name. Every model takes the flat
# 64*64 = 4096-feature vectors produced by XRay(..., 64) and returns 3 logits
# (labels 0/1/2 = normal/pneumonia/tuberculosis).
models = {
    'Shallow MLP': nn.Sequential(
        nn.Linear(64*64, 256),
        nn.ReLU(),
        nn.Linear(256, 3)
    ),
    'Deep MLP': nn.Sequential(
        nn.Linear(64*64, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 3)
    ),
    'Shallow CNN': nn.Sequential(
        # Restore the (1, 64, 64) image from the flat vector the dataset yields;
        # Conv2d rejects 2-D (batch, features) input.
        nn.Unflatten(1, (1, 64, 64)),
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),   # 64x64 -> 32x32
        nn.Flatten(),
        nn.Linear(32*32*32, 3)
    ),
    'Deep CNN': nn.Sequential(
        nn.Unflatten(1, (1, 64, 64)),
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),   # 64x64 -> 32x32
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),   # 32x32 -> 16x16
        nn.Flatten(),
        # Two 2x2 pools leave a 16x16 map with 64 channels.
        nn.Linear(64*16*16, 128),
        nn.ReLU(),
        nn.Linear(128, 3)
    ),
}

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode with gradients disabled.
    Each batch's mean loss is weighted by its size, so the result is the exact
    per-sample mean even when the final batch is smaller. An empty loader
    yields ``(nan, nan)`` instead of dividing by zero.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            output = model(batch_x)
            loss = criterion(output, batch_y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = batch_y.size(0)
            total_loss += loss.item() * batch_size
            correct += (output.argmax(dim=1) == batch_y).sum().item()
            seen += batch_size
    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, correct / seen


def train_model(model, train_dl, val_dl, epochs=1000, lr=0.1):
    """Train ``model`` with Adam and cross-entropy loss, evaluating every epoch.

    Args:
        model: classifier producing one logit per class; moved to CUDA when
            available.
        train_dl: DataLoader of ``(inputs, class-index labels)`` batches used
            for parameter updates.
        val_dl: DataLoader of the same form, used for evaluation only.
        epochs: number of passes over ``train_dl``.
        lr: Adam learning rate. NOTE(review): 0.1 is far above Adam's
            documented default of 1e-3 and is likely to destabilise training;
            consider passing a smaller value.

    Returns:
        ``(train_loss, train_acc, val_loss, val_acc)``: four lists with one
        entry per epoch. Losses are per-sample means, accuracies fractions in
        [0, 1] (NaN when the corresponding loader is empty).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_dl, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_dl, criterion, device)

        train_loss_history.append(train_loss)
        train_acc_history.append(train_acc)
        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)

        # Log every 10th epoch and always the final one.
        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}, Train Loss: {train_loss_history[-1]:.4f}, Train Acc: {train_acc_history[-1]:.4f}, Val Loss: {val_loss_history[-1]:.4f}, Val Acc: {val_acc_history[-1]:.4f}")

    return train_loss_history, train_acc_history, val_loss_history, val_acc_history
