In [11]:
import os
import numpy as np
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision.transforms import transforms
from torch import nn, optim
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from PIL import Image

from sklearn.model_selection import train_test_split


In [6]:
# Load metadata
metadata = pd.read_csv('./datasets/MetaData.csv')

# Display the first few rows to understand the structure
print(metadata.head())

     id gender age    county  ptb                                      remarks
0  1000   male  31  Shenxhen    0                                       normal
1  1001   male  64  Shenxhen    0                                       normal
2  1002   male  35  Shenxhen    0                                       normal
3  1003   male  32  Shenxhen    1               STB,ATB,tuberculosis pleuritis
4  1004   male   2  Shenxhen    1  secondary PTB  in the bilateral upper field


In [7]:
# Define paths for images and masks
image_dir = './datasets/image/'
mask_dir = './datasets/mask/'

In [8]:
# Add paths for images and masks in the metadata
metadata['image_path'] = metadata['id'].apply(lambda x: os.path.join(image_dir, f'{x}.png'))
metadata['mask_path'] = metadata['id'].apply(lambda x: os.path.join(mask_dir, f'{x}.png'))

In [9]:
# Verify file existence
metadata = metadata[metadata['image_path'].apply(os.path.exists) & metadata['mask_path'].apply(os.path.exists)]


In [10]:
# Check the prepared metadata
print(metadata.head())

     id gender age    county  ptb  \
0  1000   male  31  Shenxhen    0   
1  1001   male  64  Shenxhen    0   
2  1002   male  35  Shenxhen    0   
3  1003   male  32  Shenxhen    1   
4  1004   male   2  Shenxhen    1   

                                       remarks                 image_path  \
0                                       normal  ./datasets/image/1000.png   
1                                       normal  ./datasets/image/1001.png   
2                                       normal  ./datasets/image/1002.png   
3               STB,ATB,tuberculosis pleuritis  ./datasets/image/1003.png   
4  secondary PTB  in the bilateral upper field  ./datasets/image/1004.png   

                  mask_path  
0  ./datasets/mask/1000.png  
1  ./datasets/mask/1001.png  
2  ./datasets/mask/1002.png  
3  ./datasets/mask/1003.png  
4  ./datasets/mask/1004.png  


In [12]:
class ChestXrayDataset(Dataset):
    def __init__(self, metadata, transform=None, mask_transform=None):
        self.metadata = metadata
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        
        # Load image and mask
        image = Image.open(row['image_path']).convert('RGB')
        mask = Image.open(row['mask_path']).convert('L')  # Grayscale
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        
        # Get label
        label = torch.tensor(row['ptb'], dtype=torch.float32)
        
        return image, mask, label

In [13]:
# Split metadata
train_meta, test_meta = train_test_split(metadata, test_size=0.2, stratify=metadata['ptb'], random_state=42)
val_meta, test_meta = train_test_split(test_meta, test_size=0.5, stratify=test_meta['ptb'], random_state=42)


In [14]:
# Define transformations
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

mask_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [15]:
# Create datasets
train_dataset = ChestXrayDataset(train_meta, transform=image_transform, mask_transform=mask_transform)
val_dataset = ChestXrayDataset(val_meta, transform=image_transform, mask_transform=mask_transform)
test_dataset = ChestXrayDataset(test_meta, transform=image_transform, mask_transform=mask_transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [35]:
from torchvision import models
from torch import nn

# Load a pre-trained ResNet50 model
model = models.resnet50(pretrained=True)

# Modify the final layer for binary classification
model.fc = nn.Linear(model.fc.in_features, 1)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [41]:
import torch.optim as optim

# Binary cross-entropy loss
class_weights = torch.tensor([2.0, 2.0]).to(device)  # Adjust weights as needed
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

# criterion = nn.BCEWithLogitsLoss()

# Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [42]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    model.train()
    patience = 5
    best_val_loss = float('inf')
    counter = 0
    for epoch in range(epochs):
        train_loss = 0
        model.train()
        for images, masks, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation step
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for images, masks, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images).squeeze(1)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
            torch.save(model.state_dict(), 'best_model.pth')  # Save best model
            print(f"Saved Best Model at Epoch {epoch + 1}")
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered")
                break

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")


In [None]:
train_model(
    model=model,
    train_loader=train_loader,  # DataLoader for the training set
    val_loader=val_loader,      # DataLoader for the validation set
    criterion=criterion,
    optimizer=optimizer,
    epochs=10  # Number of epochs to train
)


In [19]:
# Save the model
torch.save(model.state_dict(), 'tb_detection_model.pth')

# Load the model
model.load_state_dict(torch.load('tb_detection_model.pth'))
model.eval()


  model.load_state_dict(torch.load('tb_detection_model.pth'))


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [39]:
def evaluate_with_threshold(model, loader, threshold=0.5):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for images, masks, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).squeeze(1)
            probabilities = torch.sigmoid(outputs)
            predictions = (probabilities > threshold).float()
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    # Classification report and confusion matrix
    from sklearn.metrics import classification_report, confusion_matrix
    print("Classification Report:\n", classification_report(y_true, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))

In [None]:
evaluate_with_threshold(model, test_loader, threshold=0.6)

Classification Report:
               precision    recall  f1-score   support

         0.0       0.61      0.97      0.75        36
         1.0       0.93      0.37      0.53        35

    accuracy                           0.68        71
   macro avg       0.77      0.67      0.64        71
weighted avg       0.77      0.68      0.64        71

Confusion Matrix:
 [[35  1]
 [22 13]]
