In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import os

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
IMG_SIZE = 256
BATCH_SIZE = 64
EPOCHS = 100
LR = 0.001


train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


In [4]:
train_dir = "/content/drive/MyDrive/archive/chest_xray/train"
val_dir   = "/content/drive/MyDrive/archive/chest_xray/val"
test_dir  = "/content/drive/MyDrive/archive/chest_xray/test"

train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(root=val_dir, transform=val_transform)

#train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np

targets = [sample[1] for sample in train_dataset.samples]  # class indices
class_sample_counts = np.bincount(targets)
print("Samples per class:", class_sample_counts)


# Compute weights: inverse of class frequency
class_weights = 1.0 / class_sample_counts
print("Class weights:", class_weights)

# Assign weight for each sample in the dataset
sample_weights = [class_weights[t] for t in targets]

# Create sampler
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # draw same number as dataset size
    replacement=True
)

# create DataLoader with sampler instead of shuffle
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False
)


class_names = train_dataset.classes
print("Classes:", class_names)

Samples per class: [1183 3873]
Class weights: [0.00084531 0.0002582 ]
Classes: ['NORMAL', 'PNEUMONIA']


In [None]:
'''
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
    def forward(self, x):
        return self.block(x)

class SmallCNN64(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            ConvBlock(1, 32),
            ConvBlock(32, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * (IMG_SIZE // 16) * (IMG_SIZE // 16), 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)
'''

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dropout=0.3):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.block(x)

class SmallCNN64(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            ConvBlock(1, 32),
            ConvBlock(32, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),   # GAP
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.SiLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SmallCNN64(num_classes=len(class_names)).to(device)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())


def train_model(model, train_loader, val_loader, epochs=EPOCHS):
    train_losses, val_losses, val_accs = [], [], []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        correct, total, val_loss = 0, 0, 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        val_acc = 100 * correct / total
        val_accs.append(val_acc)

        print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, "
              f"Val Loss={avg_val_loss:.4f}, Val Acc={val_acc:.2f}%")

    return train_losses, val_losses, val_accs


import matplotlib.pyplot as plt

def plot_training(train_losses, val_losses, val_accs):
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(14,5))

    # Plot Loss
    plt.subplot(1,2,1)
    plt.plot(epochs, train_losses, 'b-', label="Training Loss")
    plt.plot(epochs, val_losses, 'r-', label="Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend()
    plt.grid(True)

    # Plot Accuracy
    plt.subplot(1,2,2)
    plt.plot(epochs, val_accs, 'g-', label="Validation Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy (%)")
    plt.title("Validation Accuracy")
    plt.legend()
    plt.grid(True)

    plt.show()

# Call it after training



train_losses, val_losses, val_accs = train_model(model, train_loader, val_loader, epochs=EPOCHS)
plot_training(train_losses, val_losses, val_accs)

def predict_image(image_path, model, class_names):
    model.eval()
    img = Image.open(image_path).convert("L")
    transform = val_transform
    img_t = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_t)
        _, predicted = outputs.max(1)

    return class_names[predicted.item()]

TP = 0
FP = 0
TN = 0
FN = 0
'''
for filename in os.listdir(test_dir):
    test_img = os.path.join(test_dir, filename)
    print("Test Image:", test_img)
    predicted = predict_image(test_img, model, class_names)
    print("Predicted:", predicted)
    if "N" in filename:
      if predicted == "NORMAL":
        TN += 1
      else:
        FP += 1
    if "P" in filename:
      if predicted == "PNEUMONIA":
        TP += 1
      else:
        FN += 1
'''
count = 1
for filename in os.listdir(test_dir):
    test_img = os.path.join(test_dir, filename)
    print("Test Image:", test_img)
    predicted = predict_image(test_img, model, class_names)
    print("Predicted:", predicted)
    if count < 51:
      if predicted == "NORMAL":
        TN += 1
      else:
        FP += 1
    if count >= 51:
      if predicted == "PNEUMONIA":
        TP += 1
      else:
        FN += 1
    count += 1

print("TP:", TP)
print("FP:", FP)
print("TN:", TN)
print("FN:", FN)



Epoch 1/100: 100%|██████████| 79/79 [18:17<00:00, 13.90s/it]


Epoch 1: Train Loss=0.4804, Val Loss=0.4651, Val Acc=79.00%


Epoch 2/100: 100%|██████████| 79/79 [09:54<00:00,  7.53s/it]


Epoch 2: Train Loss=0.3965, Val Loss=0.4188, Val Acc=80.67%


Epoch 3/100: 100%|██████████| 79/79 [06:00<00:00,  4.57s/it]


Epoch 3: Train Loss=0.3738, Val Loss=0.3436, Val Acc=86.83%


Epoch 4/100: 100%|██████████| 79/79 [04:15<00:00,  3.23s/it]


Epoch 4: Train Loss=0.3677, Val Loss=0.3687, Val Acc=84.83%


Epoch 5/100: 100%|██████████| 79/79 [03:15<00:00,  2.47s/it]


Epoch 5: Train Loss=0.3204, Val Loss=0.3927, Val Acc=83.00%


Epoch 6/100: 100%|██████████| 79/79 [02:35<00:00,  1.97s/it]


Epoch 6: Train Loss=0.3049, Val Loss=0.3005, Val Acc=88.33%


Epoch 7/100: 100%|██████████| 79/79 [02:20<00:00,  1.78s/it]


Epoch 7: Train Loss=0.3129, Val Loss=0.3211, Val Acc=86.50%


Epoch 8/100: 100%|██████████| 79/79 [02:17<00:00,  1.74s/it]


Epoch 8: Train Loss=0.2864, Val Loss=0.2857, Val Acc=88.33%


Epoch 9/100: 100%|██████████| 79/79 [02:13<00:00,  1.68s/it]


Epoch 9: Train Loss=0.3029, Val Loss=0.2738, Val Acc=90.33%


Epoch 10/100: 100%|██████████| 79/79 [02:11<00:00,  1.66s/it]


Epoch 10: Train Loss=0.2803, Val Loss=0.2393, Val Acc=89.83%


Epoch 11/100: 100%|██████████| 79/79 [02:11<00:00,  1.67s/it]


Epoch 11: Train Loss=0.2533, Val Loss=0.2704, Val Acc=90.50%


Epoch 12/100: 100%|██████████| 79/79 [02:08<00:00,  1.63s/it]


Epoch 12: Train Loss=0.2370, Val Loss=0.2140, Val Acc=91.50%


Epoch 13/100: 100%|██████████| 79/79 [02:12<00:00,  1.67s/it]


Epoch 13: Train Loss=0.2506, Val Loss=0.2335, Val Acc=89.83%


Epoch 14/100: 100%|██████████| 79/79 [02:08<00:00,  1.63s/it]


Epoch 14: Train Loss=0.2347, Val Loss=0.3162, Val Acc=86.17%


Epoch 15/100: 100%|██████████| 79/79 [02:09<00:00,  1.64s/it]


Epoch 15: Train Loss=0.2393, Val Loss=0.1965, Val Acc=91.17%


Epoch 16/100: 100%|██████████| 79/79 [02:09<00:00,  1.64s/it]


Epoch 16: Train Loss=0.2179, Val Loss=0.2028, Val Acc=91.83%


Epoch 17/100: 100%|██████████| 79/79 [02:07<00:00,  1.61s/it]


Epoch 17: Train Loss=0.2092, Val Loss=0.1860, Val Acc=92.67%


Epoch 18/100: 100%|██████████| 79/79 [02:09<00:00,  1.64s/it]


Epoch 18: Train Loss=0.2083, Val Loss=0.2123, Val Acc=91.17%


Epoch 19/100: 100%|██████████| 79/79 [02:10<00:00,  1.65s/it]


Epoch 19: Train Loss=0.1940, Val Loss=0.2144, Val Acc=90.83%


Epoch 20/100: 100%|██████████| 79/79 [02:06<00:00,  1.61s/it]


Epoch 20: Train Loss=0.2067, Val Loss=0.2006, Val Acc=91.00%


Epoch 21/100: 100%|██████████| 79/79 [02:07<00:00,  1.62s/it]


Epoch 21: Train Loss=0.2023, Val Loss=0.2449, Val Acc=90.00%


Epoch 22/100: 100%|██████████| 79/79 [02:05<00:00,  1.59s/it]


Epoch 22: Train Loss=0.1956, Val Loss=0.1817, Val Acc=91.67%


Epoch 23/100: 100%|██████████| 79/79 [02:07<00:00,  1.61s/it]


Epoch 23: Train Loss=0.1994, Val Loss=0.2237, Val Acc=89.00%


Epoch 24/100: 100%|██████████| 79/79 [02:07<00:00,  1.61s/it]


Epoch 24: Train Loss=0.1806, Val Loss=0.2079, Val Acc=92.17%


Epoch 25/100: 100%|██████████| 79/79 [02:05<00:00,  1.58s/it]


Epoch 25: Train Loss=0.1670, Val Loss=0.2694, Val Acc=88.50%


Epoch 26/100: 100%|██████████| 79/79 [02:08<00:00,  1.63s/it]


Epoch 26: Train Loss=0.1922, Val Loss=0.2321, Val Acc=89.50%


Epoch 27/100: 100%|██████████| 79/79 [02:07<00:00,  1.62s/it]


Epoch 27: Train Loss=0.1863, Val Loss=0.1814, Val Acc=91.67%


Epoch 28/100: 100%|██████████| 79/79 [02:09<00:00,  1.64s/it]


Epoch 28: Train Loss=0.1710, Val Loss=0.2087, Val Acc=91.33%


Epoch 29/100: 100%|██████████| 79/79 [02:07<00:00,  1.61s/it]


Epoch 29: Train Loss=0.1852, Val Loss=0.1941, Val Acc=91.50%


Epoch 30/100: 100%|██████████| 79/79 [02:05<00:00,  1.58s/it]


Epoch 30: Train Loss=0.1597, Val Loss=0.2041, Val Acc=91.33%


Epoch 31/100: 100%|██████████| 79/79 [02:08<00:00,  1.63s/it]


Epoch 31: Train Loss=0.1824, Val Loss=0.1940, Val Acc=90.50%


Epoch 32/100: 100%|██████████| 79/79 [02:05<00:00,  1.59s/it]


Epoch 32: Train Loss=0.1610, Val Loss=0.1959, Val Acc=91.83%


Epoch 33/100: 100%|██████████| 79/79 [02:06<00:00,  1.61s/it]


Epoch 33: Train Loss=0.1686, Val Loss=0.1928, Val Acc=91.83%


Epoch 34/100: 100%|██████████| 79/79 [02:06<00:00,  1.60s/it]


Epoch 34: Train Loss=0.1646, Val Loss=0.2705, Val Acc=90.33%


Epoch 35/100: 100%|██████████| 79/79 [02:06<00:00,  1.61s/it]


Epoch 35: Train Loss=0.1686, Val Loss=0.2018, Val Acc=91.83%


Epoch 36/100: 100%|██████████| 79/79 [02:05<00:00,  1.59s/it]


Epoch 36: Train Loss=0.1632, Val Loss=0.1885, Val Acc=91.83%


Epoch 37/100: 100%|██████████| 79/79 [02:06<00:00,  1.60s/it]


Epoch 37: Train Loss=0.1541, Val Loss=0.1632, Val Acc=92.67%


Epoch 38/100: 100%|██████████| 79/79 [02:07<00:00,  1.62s/it]


Epoch 38: Train Loss=0.1446, Val Loss=0.2344, Val Acc=90.50%


Epoch 39/100:  58%|█████▊    | 46/79 [01:15<00:55,  1.67s/it]

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

conf_matrix = np.array([[TN, FP],
                        [FN, TP]])

# Labels
labels = ["Negative", "Positive"]
classes = ["NORMAL","PNEUMONIA"]
# Plot with seaborn
plt.figure(figsize=(6,5))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
            xticklabels=classes, yticklabels=labels)


plt.title("Confusion Matrix")
plt.show()


In [None]:
import math

accuracy  = (TP + TN) / (TP + TN + FP + FN)
precision = TP / (TP + FP) if (TP + FP) else 0
recall    = TP / (TP + FN) if (TP + FN) else 0
specificity = TN / (TN + FP) if (TN + FP) else 0
fpr       = FP / (FP + TN) if (FP + TN) else 0
fnr       = FN / (FN + TP) if (FN + TP) else 0
npv       = TN / (TN + FN) if (TN + FN) else 0
f1        = 2 * precision * recall / (precision + recall) if (precision + recall) else 0
balanced_accuracy = (recall + specificity) / 2
mcc_num   = (TP*TN - FP*FN)
mcc_den   = math.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))
mcc       = mcc_num / mcc_den if mcc_den else 0

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("Specificity:", specificity)
print("F1 Score:", f1)
print("Balanced Accuracy:", balanced_accuracy)
print("MCC:", mcc)
