# Using Resnet for chest X-ray Tuberculosis classification

In [1]:
pip install opendatasets

Defaulting to user installation because normal site-packages is not writeable
Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)
Collecting kaggle (from opendatasets)
  Downloading kaggle-1.7.4.5-py3-none-any.whl.metadata (16 kB)
Collecting python-slugify (from kaggle->opendatasets)
  Downloading python_slugify-8.0.4-py2.py3-none-any.whl.metadata (8.5 kB)
Collecting text-unidecode (from kaggle->opendatasets)
  Downloading text_unidecode-1.3-py2.py3-none-any.whl.metadata (2.4 kB)
Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Downloading kaggle-1.7.4.5-py3-none-any.whl (181 kB)
Downloading python_slugify-8.0.4-py2.py3-none-any.whl (10 kB)
Downloading text_unidecode-1.3-py2.py3-none-any.whl (78 kB)
Installing collected packages: text-unidecode, python-slugify, kaggle, opendatasets

   ---------- ----------------------------- 1/4 [python-slugify]
   ---------- ----------------------------- 1/4 [python-slugify]
   ---------- -------------



In [2]:
import os
import opendatasets as od
import os
import random
import argparse
from pathlib import Path
from sklearn.model_selection import train_test_split
import numpy as np
from PIL import Image
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report
import torch
from torch import nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Download the dataset
dataset_url = 'https://www.kaggle.com/datasets/tawsifurrahman/tuberculosis-tb-chest-xray-dataset'
od.download(dataset_url)

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle username:Your Kaggle Key:Dataset URL: https://www.kaggle.com/data

100%|██████████| 663M/663M [00:00<00:00, 736MB/s] 





In [7]:
# Define the data directory
data_dir = r'C:\Users\Admin\Downloads\DeepLearning\Practice\w3.1\tuberculosis-tb-chest-xray-dataset\TB_Chest_Radiography_Database'
# Define model checkpoints directory
save_dir = "checkpoints"

In [8]:
# ---------------------------
# Reproducibility & device
# ---------------------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Utility: metrics
# ---------------------------
def evaluate_model(model, loader):
    model.eval()
    y_true, y_probs, y_pred = [], [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            logits = model(xb).squeeze(-1).cpu()  # [batch]
            probs = torch.sigmoid(logits).numpy()
            preds = (probs >= 0.5).astype(int)
            y_probs.extend(probs.tolist())
            y_pred.extend(preds.tolist())
            y_true.extend(yb.numpy().tolist())
    acc = accuracy_score(y_true, y_pred)
    try:
        auc = roc_auc_score(y_true, y_probs)
    except Exception:
        auc = float("nan")
    cls_report = classification_report(y_true, y_pred, digits=4)
    return acc, auc, cls_report

In [9]:
set_seed(42)

# transforms
tf = transforms.Compose([
      transforms.RandomResizedCrop((224,224)),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
  ])

full_ds = datasets.ImageFolder(os.path.join(data_dir), transform=tf)
# Store class_to_idx before splitting
class_to_idx = full_ds.class_to_idx

train_len = int(len(full_ds)*0.8)
val_len = len(full_ds) - train_len
train_ds, val_ds = random_split(full_ds, [train_len, val_len], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=1)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=1)

In [10]:
# model:resnet50 -> single logit output
# using pretrained with weights='DEFAULT'
# training from scratch with weights=None
model = models.resnet50(weights=None)

for param in model.parameters():
      param.requires_grad = True  # fine-tune all (or set False to freeze)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)  # single logit
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

best_auc = 0.0
os.makedirs(save_dir, exist_ok=True)

for epoch in range(2):
    model.train()
    running_loss = 0.0
    for xb, yb in tqdm(train_loader):
        xb, yb = xb.to(device), yb.float().to(device)
        logits = model(xb).squeeze(-1)  # [batch]
        loss = criterion(logits, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * xb.size(0)

    train_loss = running_loss / len(train_loader.dataset)
    val_acc, val_auc, val_report = evaluate_model(model, val_loader)
    scheduler.step(val_loss := train_loss)  # or use val_auc etc
    print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_acc={val_acc:.4f} val_auc={val_auc:.4f} \n {val_report}")

    if val_auc > best_auc:
        best_auc = val_auc
        ckpt = os.path.join(save_dir, "tb_resnet50_best.pt")
        torch.save({"model_state": model.state_dict(), "class_to_idx": class_to_idx}, ckpt)
        print(f"  Saved best checkpoint to {ckpt}")

  # final eval
test_acc, test_auc, test_report = evaluate_model(model, val_loader)
print(f"Final val acc={test_acc:.4f}, auc={test_auc:.4f} \n {test_report}")

100%|██████████| 210/210 [23:03<00:00,  6.59s/it]


[Epoch 0] train_loss=0.2791 val_acc=0.8107 val_auc=0.9476 
               precision    recall  f1-score   support

           0     0.9837    0.7835    0.8723       693
           1     0.4792    0.9388    0.6345       147

    accuracy                         0.8107       840
   macro avg     0.7314    0.8612    0.7534       840
weighted avg     0.8954    0.8107    0.8307       840

  Saved best checkpoint to checkpoints\tb_resnet50_best.pt


100%|██████████| 210/210 [2:33:10<00:00, 43.76s/it]     


[Epoch 1] train_loss=0.2133 val_acc=0.8702 val_auc=0.9460 
               precision    recall  f1-score   support

           0     0.9725    0.8672    0.9169       693
           1     0.5856    0.8844    0.7046       147

    accuracy                         0.8702       840
   macro avg     0.7790    0.8758    0.8107       840
weighted avg     0.9048    0.8702    0.8797       840

Final val acc=0.8655, auc=0.9452 
               precision    recall  f1-score   support

           0     0.9739    0.8600    0.9134       693
           1     0.5746    0.8912    0.6987       147

    accuracy                         0.8655       840
   macro avg     0.7742    0.8756    0.8060       840
weighted avg     0.9040    0.8655    0.8758       840

