# X-Ray Classifier

## Importing Dependencies

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path

## Dataset Downloading

In [8]:
# import kagglehub
# import zipfile
# from kaggle.api.kaggle_api_extended import KaggleApi

In [9]:
# api = KaggleApi()
# api.authenticate()

In [10]:
# dataset_identifier = "tolgadincer/labeled-chest-xray-images"
# download_dir = "../data/x-ray-images"

# os.makedirs(download_dir, exist_ok=True)

In [11]:
# Download latest version
# api.dataset_download_files(dataset_identifier, path=download_dir, unzip=True)

# print(f"Dataset downloaded to {download_dir}")

## Dataset Loading and Preprocessing

In [4]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [5]:
# Show which PyTorch build this notebook is running against.
print(torch.__version__)

2.7.1+cpu


In [6]:
# Root folders of the two image splits used below.
train_dir = "../data/chest_xray/train"
# NOTE(review): the "validation" split is read from the dataset's test folder,
# so the validation metrics reported during training are computed on test data.
val_dir = "../data/chest_xray/test"

In [40]:
# Settings shared by the training and validation pipelines.
_INPUT_SIZE = (256, 256)
_NORM_MEAN = [0.485, 0.456, 0.406]  # per-channel mean (standard ImageNet statistics)
_NORM_STD = [0.229, 0.224, 0.225]   # per-channel std (standard ImageNet statistics)


def _geometry_steps():
    """Return the deterministic steps applied first to every image."""
    return [
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=_INPUT_SIZE),
    ]


def _tensor_steps():
    """Return the steps that convert an image to a normalised float32 tensor."""
    return [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]


# Random augmentations, applied to training images only.
# Disabled alternatives kept for reference:
#   v2.RandomResizedCrop(size=(256, 256))  (after the resize)
#   v2.GaussianNoise()                     (after the blur)
_augmentation_steps = [
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=15),
    v2.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
]

train_transforms = v2.Compose(
    [*_geometry_steps(), *_augmentation_steps, *_tensor_steps()]
)

# Validation uses the same pipeline minus the random augmentations.
val_transforms = v2.Compose([*_geometry_steps(), *_tensor_steps()])

In [41]:
# One ImageFolder per split, each paired with its own transform pipeline.
train_dataset = ImageFolder(train_dir, transform=train_transforms)
val_dataset = ImageFolder(val_dir, transform=val_transforms)

In [42]:
# Number of images per batch; passed to both DataLoaders below.
batch_size = 16

In [43]:
# Keyword arguments common to both loaders.
_shared_loader_kwargs = {
    "batch_size": batch_size,
    "num_workers": 4,
    "pin_memory": False,
    "prefetch_factor": 1,
    "persistent_workers": False,
}

# Training batches are reshuffled each epoch and may be yielded out of order.
train_loader = DataLoader(
    train_dataset, shuffle=True, in_order=False, **_shared_loader_kwargs
)

# Validation batches keep the dataset order.
val_loader = DataLoader(
    val_dataset, shuffle=False, in_order=True, **_shared_loader_kwargs
)

In [11]:
# for idx, (x, y) in enumerate(val_loader):
#     print(model.forward(x)[0][0])
#     print(f"Index: {idx} | Shape: {np.shape(x)} | Target length: {len(y)}")

## Model Architecture

In [12]:
# class XRayNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv_relu_stack = nn.Sequential(
#             nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=0, bias=True),
#             nn.ReLU(),
#             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
            
#             nn.Conv2d(128, 256, kernel_size=5, stride=1, padding=1, bias=True),
#             nn.ReLU(),
#             nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0, bias=True),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
            
#             nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0, bias=True),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2)
#         )

#         self.classifier = nn.Sequential(     
#             nn.Flatten(),
#             nn.Linear(430592, 128),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(128, 1)
#         )

#     def forward(self, x):
#         x = self.conv_relu_stack(x)
#         x = self.classifier(x)

#         return x

In [44]:
def initialize_resnet18(num_classes=1):
    """Build a pretrained ResNet-18 with a new, trainable classification head.

    All pretrained parameters are frozen. The original ``fc`` layer is then
    replaced by a small MLP; because the replacement is created after the
    freeze, its parameters keep ``requires_grad=True`` and are the only ones
    that train.

    Args:
        num_classes: Number of output logits produced by the new head.

    Returns:
        The modified torchvision ResNet-18 model.
    """
    # Pass an enum *member*. The bare ``ResNet18_Weights`` class is not a valid
    # ``weights`` value; it is only tolerated by torchvision's deprecated
    # legacy-argument handling, which resolves it to IMAGENET1K_V1.
    weights = models.ResNet18_Weights.IMAGENET1K_V1
    model = models.resnet18(weights=weights)

    # Freeze the entire pretrained network (not just the initial layers).
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final fully connected layer with a fresh trainable head.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes)
    )

    return model

## Model Training

In [45]:
# model training script
def train_model(train_loader, val_loader, batch_size=16, epochs=50, learning_rate=1e-3, log_dir='../reports/exp1'):
    """Train the ResNet-18 binary classifier and log per-epoch metrics.

    Each epoch runs one pass over ``train_loader`` followed by one pass over
    ``val_loader``. The mean training loss, mean validation loss and validation
    accuracy are printed and written to TensorBoard.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        val_loader: Iterable yielding ``(images, labels)`` validation batches.
        batch_size: Unused here (batching is fixed by the loaders); kept so
            existing callers keep working.
        epochs: Number of passes over the training data.
        learning_rate: Learning rate given to the Adam optimizer.
        log_dir: Directory the TensorBoard ``SummaryWriter`` writes to.

    Returns:
        The trained model.
    """
    device = torch.device('cpu')

    model = initialize_resnet18(1).to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    writer = SummaryWriter(log_dir=log_dir)

    # The writer is closed exactly once, after the last epoch or when training
    # is interrupted, instead of being closed at the end of every epoch.
    try:
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}\n------------------------------------------------------------------------------------")

            # Training
            # NOTE(review): train() also switches the frozen backbone's
            # BatchNorm layers to training mode - confirm this is intended.
            model.train()

            total_train_loss = 0.0

            train_progress_bar = tqdm(train_loader, desc='Training', leave=True)

            for x, y in train_progress_bar:
                # Keep inputs and targets on the same device as the model.
                x = x.to(device)
                # BCEWithLogitsLoss needs float targets shaped like the logits.
                targets = y.to(device).unsqueeze(1).float()

                optimizer.zero_grad()

                y_pred = model(x)

                train_loss = loss_fn(y_pred, targets)
                train_loss.backward()

                optimizer.step()

                batch_loss = train_loss.item()
                total_train_loss += batch_loss

                train_progress_bar.set_postfix({'Batch Loss': f"{batch_loss:.3f}"})

            # max(..., 1) avoids a ZeroDivisionError on an empty loader.
            avg_train_loss = total_train_loss / max(len(train_loader), 1)
            writer.add_scalar('Loss/Train', avg_train_loss, epoch)

            # Validation
            model.eval()

            total_val_loss, correct = 0.0, 0
            total = 0

            with torch.no_grad():
                val_progress_bar = tqdm(val_loader, desc='Validation', leave=True)

                for x, y in val_progress_bar:
                    x = x.to(device)
                    targets = y.to(device).unsqueeze(1).float()

                    y_pred = model(x)

                    batch_val_loss = loss_fn(y_pred, targets).item()
                    total_val_loss += batch_val_loss

                    val_progress_bar.set_postfix({'Val Loss': f"{batch_val_loss:.3f}"})

                    # Predict the positive class when sigmoid(logit) > 0.5.
                    y_pred_labels = (torch.sigmoid(y_pred) > 0.5).int()

                    correct += (y_pred_labels == targets.int()).sum().item()
                    total += y.size(0)

            avg_val_loss = total_val_loss / max(len(val_loader), 1)
            accuracy = correct / total if total else 0.0

            writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
            writer.add_scalar('Accuracy/Validation', accuracy, epoch)

            print(f"Train Loss: {avg_train_loss:.3f} | Val Loss: {avg_val_loss:.3f} | Acccuracy: {(accuracy*100):.2f} % \n")
    finally:
        writer.close()

    return model

In [None]:
# Keep the returned model: without this assignment the trained weights are
# discarded as soon as the call returns and cannot be evaluated or saved later.
model = train_model(train_loader, val_loader)

print("Training completed!")

Epoch 1
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [05:34<00:00,  1.02s/it, Batch Loss=0.075]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:46<00:00,  1.19s/it, Val Loss=0.074]


Train Loss: 0.253 | Val Loss: 0.174 | Acccuracy: 93.59 % 

Epoch 2
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [09:43<00:00,  1.78s/it, Batch Loss=0.142]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:26<00:00,  2.23s/it, Val Loss=0.030]


Train Loss: 0.167 | Val Loss: 0.219 | Acccuracy: 92.15 % 

Epoch 3
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [09:41<00:00,  1.78s/it, Batch Loss=0.026]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:14<00:00,  1.92s/it, Val Loss=0.060]


Train Loss: 0.175 | Val Loss: 0.184 | Acccuracy: 93.11 % 

Epoch 4
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [10:14<00:00,  1.88s/it, Batch Loss=1.092]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:07<00:00,  1.74s/it, Val Loss=0.060]


Train Loss: 0.164 | Val Loss: 0.207 | Acccuracy: 93.27 % 

Epoch 5
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [12:43<00:00,  2.34s/it, Batch Loss=0.183]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:15<00:00,  1.92s/it, Val Loss=0.057]


Train Loss: 0.158 | Val Loss: 0.220 | Acccuracy: 91.67 % 

Epoch 6
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [08:51<00:00,  1.63s/it, Batch Loss=0.188]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:13<00:00,  1.89s/it, Val Loss=0.285]


Train Loss: 0.163 | Val Loss: 0.167 | Acccuracy: 92.95 % 

Epoch 7
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [09:18<00:00,  1.71s/it, Batch Loss=0.084]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:29<00:00,  2.29s/it, Val Loss=0.080]


Train Loss: 0.153 | Val Loss: 0.209 | Acccuracy: 91.99 % 

Epoch 8
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [10:45<00:00,  1.97s/it, Batch Loss=0.050]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:17<00:00,  1.98s/it, Val Loss=0.119]


Train Loss: 0.146 | Val Loss: 0.186 | Acccuracy: 92.79 % 

Epoch 9
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [10:44<00:00,  1.97s/it, Batch Loss=0.024]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:08<00:00,  1.75s/it, Val Loss=0.161]


Train Loss: 0.148 | Val Loss: 0.184 | Acccuracy: 92.63 % 

Epoch 10
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [10:36<00:00,  1.95s/it, Batch Loss=0.225]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:15<00:00,  1.93s/it, Val Loss=0.102]


Train Loss: 0.150 | Val Loss: 0.181 | Acccuracy: 92.15 % 

Epoch 11
------------------------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 327/327 [09:10<00:00,  1.68s/it, Batch Loss=0.025]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:08<00:00,  1.75s/it, Val Loss=0.090]


Train Loss: 0.138 | Val Loss: 0.208 | Acccuracy: 91.99 % 

Epoch 12
------------------------------------------------------------------------------------


Training:   4%|███▌                                                                                     | 13/327 [00:56<07:59,  1.53s/it, Batch Loss=0.109]

## Model Testing