In [1]:
import math
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm

# Dataset location, random seed, and compute device used by the rest of the notebook.
root = '/mnt/c/Users/Thanasak/Downloads/archive/'
seed = 4912
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed NumPy and PyTorch; the CUDA generators are seeded only when the CUDA device was selected.
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == "cuda":
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [110]:
# Read the training label table that sits directly under the dataset root.
train_csv = pd.read_csv(f'{root}train.csv')

##### Keep only the first set of patients

In [14]:
# Exclusive upper bound on the numeric patient id kept in the first set.
FIRST_SET_PANTIENT = 21514

def is_patient_id_in_first_set(row):
  """Return True when the patient number encoded in ``row['Path']`` is below the cutoff.

  The third '/'-separated component of the path is taken as the patient folder
  (e.g. ``patient00001``); the ``patient`` prefix is stripped and the rest is
  parsed as an integer.
  """
  path_parts = row['Path'].split('/')
  patient_number = int(path_parts[2].replace('patient', ''))
  return patient_number < FIRST_SET_PANTIENT

# Keep only the rows whose patient id falls in the first set. The row-wise
# predicate is used directly as a boolean mask, and the result is copied so
# later column assignments do not touch `train_csv`.
train_csv_set1 = train_csv[train_csv.apply(is_patient_id_in_first_set, axis=1)].copy()

##### replace image path

In [15]:
# Point every image path at the local dataset root. `regex=False` makes this a
# literal substring replacement: the prefix contains '.' characters, and the
# default value of `regex` differs between pandas versions.
train_csv_set1['Path'] = train_csv_set1['Path'].str.replace('CheXpert-v1.0-small/', root, regex=False)

##### get only necessary columns

In [13]:
# The first five columns are kept as-is (treated as non-disease metadata).
non_disease_cols = list(train_csv_set1.columns[:5])
interested_disease = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
# Disease columns of interest, in the order they appear in the table.
disease_cols = [name for name in train_csv_set1.columns[5:] if name in interested_disease]
# Drop every column that is neither metadata nor a disease of interest.
train_csv_set1 = train_csv_set1.drop(
  columns=train_csv_set1.columns.difference(non_disease_cols + disease_cols)
)

##### Replace -1.0 with 1.0 (U-Ones) and NaN with 0.0

In [18]:
# Map -1.0 labels to 1.0 and fill missing labels with 0.0 (the two steps touch
# disjoint values, so their order does not matter).
train_csv_set1[disease_cols] = (
  train_csv_set1[disease_cols]
  .replace({-1.0: 1.0})
  .fillna(0.0)
)

##### Drop images with an unbalanced aspect ratio (width and height differ by more than 200 pixels)

In [19]:
def width_height_difference(image_path):
  """Return ``(abs(width - height), width, height)`` for the image at ``image_path``.

  The file is opened in a ``with`` block so its handle is released as soon as
  the size has been read; this function is called once per row of the table.
  """
  with Image.open(image_path) as img:
    width, height = img.size
  return abs(width - height), width, height

# Collect one record per image whose width and height differ by more than 200.
weird_ratio_image = []

for _, row in train_csv_set1.iterrows():
  diff, width, height = width_height_difference(row['Path'])
  if diff <= 200:
    continue
  ratio = max(width / height, height / width)
  weird_ratio_image.append(
    (row['Path'], width, height, diff, ratio, *(row[disease] for disease in interested_disease))
  )

dropping_image = pd.DataFrame(
  weird_ratio_image,
  columns=['Path', 'Width', 'Height', 'Difference', 'Ratio'] + interested_disease,
)
print(dropping_image)

# Remove the collected images from the training table.
image_diff_paths = list(dropping_image['Path'])
train_csv_set1 = train_csv_set1.loc[~train_csv_set1['Path'].isin(image_diff_paths)]

##### Save clean csv

In [23]:
# Persist the cleaned table (without the index column) so later cells can reload it.
train_csv_set1.to_csv(path_or_buf='train_set1.csv', index=False)

##### Load train and validation set

In [2]:
disease_cols = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
# Reload the cleaned training table written by the "Save clean csv" cell.
train_csv_set1 = pd.read_csv('train_set1.csv')
valid_csv = pd.read_csv(root + 'valid.csv')
# Literal (non-regex) prefix replacement: the prefix contains '.' characters and
# the default value of `regex` differs between pandas versions.
valid_csv['Path'] = valid_csv['Path'].str.replace('CheXpert-v1.0-small/', root, regex=False)

In [None]:
# One bar chart per disease column, laid out on a grid three panels wide.
num_cols = 3
num_rows = math.ceil(len(disease_cols) / num_cols)

fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 10))
axes = axes.flatten()

for ax, col in zip(axes, disease_cols):
    # Always show the same four categories, in a fixed order, even when absent.
    counts = train_csv_set1[col].value_counts(dropna=False).reindex([np.nan, -1.0, 0.0, 1.0])
    counts.plot(kind='bar', ax=ax, color='b', alpha=0.7)

    ax.set_title(f"Value counts for {col}")
    ax.set_ylabel('Count')

# Remove the grid cells that have no disease column to show.
for unused_ax in axes[len(disease_cols):]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.show()

In [5]:
def imshow_grid_tensor(images, title):
  """Show a batch of channel-first image tensors on a square grid.

  The grid side is the ceiling of the square root of the batch size; each
  image is permuted to channel-last before being passed to ``plt.imshow``.
  """
  side = int(math.ceil(math.sqrt(images.shape[0])))
  plt.figure(figsize=(side * 2, side * 2))
  for cell, chw in enumerate(images, start=1):
    plt.subplot(side, side, cell)
    plt.imshow(chw.permute(1, 2, 0), cmap='gray')
  plt.suptitle(title, fontsize=16)
  plt.tight_layout()
  plt.show()

##### Create Custom Dataset

In [4]:
class CustomDataset(Dataset):
  """Dataset that yields ``(image, labels, angle)`` for each row of a DataFrame.

  Args:
    data: DataFrame with a ``"Path"`` column, a ``"Frontal/Lateral"`` column and
      one column per entry of ``disease_list``.
    disease_list: names of the label columns to return, in order.
    transforms: callable applied to the RGB image array read from disk.
  """
  def __init__(self, data, disease_list, transforms):
    self.data = data
    self.disease_list = disease_list
    self.transforms = transforms

  def __len__(self):
    return self.data.shape[0]

  def __getitem__(self, idx):
    """Load, convert and transform the image of row ``idx``.

    Returns:
      The transformed image, the labels as a float NumPy array, and the value
      of the ``"Frontal/Lateral"`` column.

    Raises:
      FileNotFoundError: if OpenCV cannot read the image file.
    """
    data = self.data.iloc[idx]
    image = cv2.imread(data["Path"])
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising; fail here with the offending path rather than inside cvtColor.
    if image is None:
      raise FileNotFoundError(f"Could not read image: {data['Path']}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = self.transforms(image)
    labels = data[self.disease_list].values.astype(float)
    angle = data["Frontal/Lateral"]

    return image, labels, angle

In [99]:
class GradCAM:
  """Grad-CAM heat maps for one or more named layers of a model.

  Args:
    model: the ``nn.Module`` to explain.
    target_layers: collection of layer names (as yielded by
      ``model.named_modules()``) whose activations and gradients are captured.

  Attributes populated during ``generate_cam``:
    activations: layer name -> detached output of that layer.
    gradients: layer name -> gradient of the target score w.r.t. that output.
  """
  def __init__(self, model, target_layers):
    self.model = model
    self.target_layers = target_layers
    self.gradients = {}
    self.activations = {}
    self._handles = []  # hook handles, kept so the hooks can be removed again

    self._register_hooks()

  def _register_hooks(self):
    """Attach a forward hook to every target layer.

    The forward hook stores the layer output and registers a tensor hook on it,
    which stores the gradient flowing into that output during backward. This
    replaces the deprecated ``Module.register_backward_hook``.
    """
    def make_forward_hook(layer_name):
      def store_gradient(grad):
        self.gradients[layer_name] = grad.detach()

      def forward_hook(module, inputs, output):
        self.activations[layer_name] = output.detach()
        # A tensor hook can only be registered on a tensor that requires grad.
        if output.requires_grad:
          output.register_hook(store_gradient)

      return forward_hook

    for name, module in self.model.named_modules():
      if name in self.target_layers:
        self._handles.append(module.register_forward_hook(make_forward_hook(name)))

  def remove_hooks(self):
    """Detach every hook this instance registered on the model."""
    for handle in self._handles:
      handle.remove()
    self._handles = []

  def generate_cam(self, input_image, class_idx=None):
    """Compute one normalized CAM per target layer for the first image of the batch.

    Args:
      input_image: tensor indexed as (batch, channel, height, width); only
        batch element 0 is explained.
      class_idx: output index to explain; defaults to the arg-max of the first
        output row.

    Returns:
      Dict mapping layer name -> float32 array resized to the input's
      (height, width), scaled to [0, 1] (all zeros when the map has no
      positive evidence).

    Raises:
      RuntimeError: if no gradient reached one of the target layers.
    """
    self.gradients.clear()
    self.activations.clear()

    # With a frozen backbone no intermediate output requires grad unless the
    # input does; use a detached alias so the caller's tensor is not modified.
    if not input_image.requires_grad:
      input_image = input_image.detach().requires_grad_(True)

    with torch.enable_grad():
      output = self.model(input_image)
      if class_idx is None:
        class_idx = int(torch.argmax(output[0]))

      self.model.zero_grad()
      target = output[0, class_idx]
      target.backward()

    height, width = input_image.shape[2], input_image.shape[3]
    cams = {}

    for layer in self.target_layers:
      if layer not in self.gradients:
        raise RuntimeError(f"No gradient was captured for layer '{layer}'")

      gradients = self.gradients[layer][0].cpu().numpy()
      activations = self.activations[layer][0].cpu().numpy()

      # Channel weights: spatial mean of the gradients.
      weights = np.mean(gradients, axis=(1, 2))

      # Weighted sum of the activation maps over the channel axis.
      cam = np.tensordot(weights, activations, axes=1).astype(np.float32)
      cam = np.maximum(cam, 0)

      cam = cam - np.min(cam)
      peak = np.max(cam)
      if peak > 0:  # an all-zero map would otherwise become NaN (0 / 0)
        cam = cam / peak
      # cv2.resize expects the target size as (width, height).
      cam = cv2.resize(cam, (width, height))
      cams[layer] = cam

    return cams

  def visualize(self, input_image_path, cams, output_path):
    """Overlay each layer's CAM on the image and save it as ``{output_path}/{layer}.jpg``.

    The image is resized and center-cropped to 320x320, so ``cams`` is expected
    to hold 320x320 maps.
    """
    with Image.open(input_image_path) as raw:
      img = raw.convert('RGB')

    preprocess = transforms.Compose([transforms.Resize(320), transforms.CenterCrop(320)])
    img = preprocess(img)

    img = np.array(img)
    img = np.float32(img) / 255

    for layer in self.target_layers:
      heatmap = cv2.applyColorMap(np.uint8(255 * cams[layer]), cv2.COLORMAP_JET)
      # applyColorMap returns BGR; the image and matplotlib are RGB.
      heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
      heatmap = np.float32(heatmap) / 255

      superimposed_img = heatmap + img
      superimposed_img = superimposed_img / np.max(superimposed_img)

      # One fresh figure per layer so overlays do not pile up on the same axes.
      fig, ax = plt.subplots()
      ax.imshow(superimposed_img)
      ax.axis('off')
      fig.savefig(f"{output_path}/{layer}.jpg", bbox_inches='tight', pad_inches=0)
      plt.show()
      plt.close(fig)

In [7]:
def min_max_norm(tensor):
  """Rescale ``tensor`` linearly so its minimum maps to 0 and its maximum to 1.

  A constant tensor (max == min) is returned as all zeros instead of the NaNs
  that 0 / 0 would produce.
  """
  lowest = tensor.min()
  span = tensor.max() - lowest
  # Substitute 1 for a zero range so the division below stays finite.
  span = torch.where(span == 0, torch.ones_like(span), span)
  return (tensor - lowest) / span

##### Transform with the ImageNet mean and std

In [8]:
# Preprocessing pipeline: tensor conversion, resize + center crop to 320x320,
# min-max rescaling, then normalization with the ImageNet mean and std.
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Resize(320, interpolation=transforms.InterpolationMode.BILINEAR),
  transforms.CenterCrop(320),
  # Pass the function itself instead of wrapping it in a lambda: a lambda cannot
  # be pickled, which breaks DataLoader workers started with the "spawn" method.
  transforms.Lambda(min_max_norm),
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [9]:
# Dataset and shuffled loader over the cleaned training table, used for the preview below.
dataset = CustomDataset(
  data=train_csv_set1,
  disease_list=disease_cols,
  transforms=transform,
)
dataloader = DataLoader(
  dataset=dataset,
  batch_size=16,
  shuffle=True,
  num_workers=8,
  prefetch_factor=4,
  pin_memory=True,
)

In [None]:
# Pull a single batch to sanity-check the images, labels, and view angles.
images, labels, angles = next(iter(dataloader))
imshow_grid_tensor(images, 'test first batch')
for item in (images.shape, labels, angles):
  print(item)

##### Custom Models

In [14]:
class CustomDenseNet(nn.Module):
  """DenseNet-121 with default pretrained weights, frozen features and a new linear head.

  Args:
    num_classes: number of outputs of the replacement classifier.
  """
  def __init__(self, num_classes=5):
    super().__init__()
    backbone = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)

    # Freeze the feature extractor; the freshly created head below is trainable.
    backbone.features.requires_grad_(False)
    backbone.classifier = nn.Linear(
      in_features=backbone.classifier.in_features,
      out_features=num_classes,
    )
    self.net = backbone

  def forward(self, x):
    return self.net(x)

In [None]:
# Build the DenseNet-121 wrapper (five outputs) and display its architecture.
densenet = CustomDenseNet(num_classes=5)
densenet

In [124]:
class CustomResNet(nn.Module):
  """ResNet-50 with default pretrained weights, frozen backbone and a new linear head.

  Args:
    num_classes: number of outputs of the replacement fully connected layer.
  """
  def __init__(self, num_classes=5):
    super().__init__()
    backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

    # Freeze every pretrained parameter; the freshly created head below is trainable.
    backbone.requires_grad_(False)
    backbone.fc = nn.Linear(
      in_features=backbone.fc.in_features,
      out_features=num_classes,
    )
    self.net = backbone

  def forward(self, x):
    return self.net(x)

In [None]:
# Build the ResNet-50 wrapper (five outputs) and display its architecture.
resnet = CustomResNet(num_classes=5)
resnet

##### Focal Loss with Weighted image angle

In [141]:
class FocalLossWithAngle(nn.Module):
  """Binary focal loss on logits with an optional per-sample weight by view angle.

  Args:
    alpha: scalar factor applied to every loss element.
    gamma: focusing exponent applied to ``(1 - pt)``.
    angle_weights: optional mapping from angle value to a multiplicative
      weight; angles missing from the mapping get weight 1.0.
    reduction: ``'mean'`` or ``'sum'``; any other value returns the
      unreduced element-wise loss.
  """
  def __init__(self, alpha=.25, gamma=2, angle_weights=None, reduction='mean'):
    super(FocalLossWithAngle, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.reduction = reduction
    self.angle_weights = angle_weights

  def _weight_for(self, angle_value):
    """Look up the weight of one angle, defaulting to 1.0 for unknown angles."""
    try:
      return self.angle_weights[angle_value]
    except KeyError:
      return 1.0

  def forward(self, inputs, targets, angle):
    """Compute the loss.

    Args:
      inputs: logits.
      targets: binary targets with the same shape as ``inputs``.
      angle: one angle value per sample (first dimension of ``inputs``).
    """
    loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
    # BCE is -log(pt), so the probability assigned to the true class is exp(-loss).
    pt = torch.exp(-loss)
    focal_loss = self.alpha * (1 - pt) ** self.gamma * loss

    if self.angle_weights:
      sample_weights = torch.tensor(
        [self._weight_for(a) for a in angle],
        dtype=focal_loss.dtype,
        device=focal_loss.device,
      )
      # Reshape to (batch, 1, ...) so each weight broadcasts over its sample's elements.
      focal_loss = focal_loss * sample_weights.view(-1, *([1] * (focal_loss.dim() - 1)))

    if self.reduction == 'mean':
      return focal_loss.mean()
    if self.reduction == 'sum':
      return focal_loss.sum()
    return focal_loss

##### Train Function

In [115]:
def train(model, opt, loss_fn, train_loader, val_loader, epochs=10, checkpoint_path=None, device='cpu'):
  """Train ``model`` and report the mean train/validation loss after every epoch.

  Args:
    model: module to optimize; it is moved to ``device``.
    opt: optimizer over the model's parameters.
    loss_fn: callable ``loss_fn(outputs, labels, angles)`` returning a scalar loss.
    train_loader: iterable of ``(images, labels, angles)`` training batches.
    val_loader: iterable of ``(images, labels, angles)`` validation batches.
    epochs: number of passes over ``train_loader``.
    checkpoint_path: optional path prefix; after each epoch the state dict is
      saved to ``{checkpoint_path}-{epoch}.pth`` (the directory is created if needed).
    device: device the model and batches are moved to.
  """
  model = model.to(device)
  # Mixed precision and loss scaling only make sense on CUDA; keep both disabled elsewhere.
  use_amp = torch.device(device).type == 'cuda'
  scaler = GradScaler(enabled=use_amp)

  for epoch in range(epochs):

    avg_train_loss = 0.0
    avg_test_loss = 0.0

    model.train()
    train_bar = tqdm(train_loader, desc=f'🚀Training Epoch [{epoch+1}/{epochs}]', unit='batch')
    for images, labels, angles in train_bar:
      images, labels = images.to(device), labels.to(device)
      opt.zero_grad()

      with torch.autocast(device_type='cuda', enabled=use_amp):
        outputs = model(images)
        train_loss = loss_fn(outputs, labels, angles)

      scaler.scale(train_loss).backward()
      scaler.step(opt)
      scaler.update()

      avg_train_loss += train_loss.item()
      train_bar.set_postfix(train_loss=train_loss.item())

    model.eval()
    test_bar = tqdm(val_loader, desc='📄Testing', unit='batch')
    # No gradients are needed for validation; skipping the graph saves memory and time.
    with torch.no_grad():
      for images, labels, angles in test_bar:
        images, labels = images.to(device), labels.to(device)

        with torch.autocast(device_type='cuda', enabled=use_amp):
          outputs = model(images)
          test_loss = loss_fn(outputs, labels, angles)

        avg_test_loss += test_loss.item()

        test_bar.set_postfix(test_loss=test_loss.item())

    avg_train_loss /= len(train_loader)
    avg_test_loss /= len(val_loader)

    print(f"Loss epoch: {epoch + 1}")
    print(f"Train Loss {avg_train_loss}")
    print(f"Test Loss {avg_test_loss}")

    if checkpoint_path is not None:
      checkpoint_file = checkpoint_path + f'-{epoch + 1}.pth'
      checkpoint_dir = os.path.dirname(checkpoint_file)
      if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
      torch.save(model.state_dict(), checkpoint_file)
      print(f"Model saved to {checkpoint_file}")

#### Train ResNet50

In [None]:
loss_fn = FocalLossWithAngle(angle_weights={"Frontal": 1.0, "Lateral": 2.0})
optimizer = torch.optim.Adam(resnet.parameters(), lr=1e-4, betas=(0.9, 0.999))

train_dataset = CustomDataset(data=train_csv_set1, disease_list=disease_cols, transforms=transform)
valid_dataset = CustomDataset(data=valid_csv, disease_list=disease_cols, transforms=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, num_workers=8, prefetch_factor=4, pin_memory=True, shuffle=True)
# Validate on the validation set (not the training set), in a fixed order.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=16, num_workers=8, prefetch_factor=4, pin_memory=True, shuffle=False)

# The evaluation section loads checkpoints from 'checkpoints/', so save there too.
os.makedirs('checkpoints', exist_ok=True)

train(
  model=resnet,
  opt=optimizer,
  loss_fn=loss_fn,
  train_loader=train_loader,
  val_loader=valid_loader,
  epochs=10,
  checkpoint_path='checkpoints/resnet-ver1',
  device=device
)

### Evaluate models

In [None]:
# `map_location=device` lets checkpoints saved on a GPU be loaded on a machine
# where that GPU is not available.
densenet.load_state_dict(torch.load('checkpoints/densenet-ver1-10.pth', weights_only=True, map_location=device))
resnet.load_state_dict(torch.load('checkpoints/resnet-ver1-10.pth', weights_only=True, map_location=device))

In [142]:
# Label columns evaluated on the test set, and the test label table itself.
disease_cols = [
  'Atelectasis',
  'Cardiomegaly',
  'Consolidation',
  'Edema',
  'Pleural Effusion',
]
test_csv = pd.read_csv('test_labels.csv')

def get_view_type(path):
  """Classify an image path as 'Frontal', 'Lateral' or 'Unknown'.

  The match is a case-insensitive substring test; 'frontal' is checked before
  'lateral'.
  """
  lowered = path.lower()
  for keyword, view in (('frontal', 'Frontal'), ('lateral', 'Lateral')):
    if keyword in lowered:
      return view
  return 'Unknown'

# Derive the view-angle column from the file name.
test_csv['Frontal/Lateral'] = test_csv['Path'].apply(get_view_type)

test_set = CustomDataset(data=test_csv, disease_list=disease_cols, transforms=transform)
# Evaluation does not need shuffling; a fixed order keeps the sample galleries
# and the timing input reproducible between runs.
test_loader = DataLoader(dataset=test_set, batch_size=16, num_workers=8, prefetch_factor=4, pin_memory=True, shuffle=False)
loss_fn = FocalLossWithAngle(angle_weights={"Frontal": 1.0, "Lateral": 2.0})

##### Try to plot test set labels

In [None]:
# One bar chart per disease column of the test labels, on a grid three panels wide.
num_cols = 3
num_rows = math.ceil(len(disease_cols) / num_cols)

fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 10))
axes = axes.flatten()

for ax, col in zip(axes, disease_cols):
  # Always show the same four categories, in a fixed order, even when absent.
  counts = test_csv[col].value_counts(dropna=False).reindex([np.nan, -1.0, 0.0, 1.0])
  counts.plot(kind='bar', ax=ax, color='b', alpha=0.7)

  ax.set_title(f"Value counts for {col}")
  ax.set_ylabel('Count')

# Remove the grid cells that have no disease column to show.
for unused_ax in axes[len(disease_cols):]:
  fig.delaxes(unused_ax)

plt.tight_layout()
plt.show()

In [206]:
def plot_confusion_matrices(outputs, labels, disease_classes=['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']):
  """Print precision/recall/F1 and draw a 2x2 confusion matrix for every class.

  Args:
    outputs: NumPy array of logits; thresholded at a sigmoid probability of 0.5.
    labels: NumPy array of binary targets.
    disease_classes: class names, one per column. Flat (1-D) ``outputs`` /
      ``labels`` are reshaped to ``(-1, len(disease_classes))``.
  """
  preds = (torch.sigmoid(torch.from_numpy(outputs)).numpy() > 0.5)

  if len(labels.shape) == 1:
    labels = labels.reshape(-1, len(disease_classes))
  if len(preds.shape) == 1:
    preds = preds.reshape(-1, len(disease_classes))

  num_classes = labels.shape[1]

  for i in range(num_classes):
    y_true = labels[:, i].astype(int)
    y_pred = preds[:, i].astype(int)

    precision = precision_score(y_true, y_pred, zero_division=1)
    recall = recall_score(y_true, y_pred, zero_division=1)
    f1 = f1_score(y_true, y_pred, zero_division=1)

    print(f"{disease_classes[i]} - Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

    # Fix the label set so the matrix is always 2x2, even when only one of the
    # two classes occurs in y_true / y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure(figsize=(4, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
    plt.title(f"Confusion Matrix for {disease_classes[i]}")
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

In [196]:
def plot_correct_incorrect(correct_samples, incorrect_samples, title, n_samples=6):
  """Draw up to ``n_samples`` correct images on the top row and incorrect ones below.

  Each sample is a channel-first array; it is clipped to [0, 1] and transposed
  to channel-last before display.
  """
  plt.figure(figsize=(8, 4))
  plt.suptitle(title, fontsize=16)

  # (caption, samples to draw, subplot offset of the row)
  rows = (
    ('Correct', correct_samples[:n_samples], 0),
    ('Incorrect', incorrect_samples[:n_samples], n_samples),
  )
  for caption, samples, offset in rows:
    for position, chw in enumerate(samples, start=1):
      hwc = np.transpose(np.clip(chw, 0, 1), (1, 2, 0))
      plt.subplot(2, n_samples, offset + position)
      plt.imshow(hwc)
      plt.title(caption)
      plt.axis('off')

  plt.tight_layout()
  plt.show()

##### Test DenseNet121

In [None]:
disease_classes = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
# Up to six example images per class for each outcome.
correct_images_per_class = {class_name: [] for class_name in disease_classes}
incorrect_images_per_class = {class_name: [] for class_name in disease_classes}

densenet.eval()
densenet.to(device)
avg_loss = 0.0
correct_preds = 0
total_samples = 0
output_list = []
label_list = []
# Inference only: no autograd graph is needed.
with torch.no_grad():
  for images, labels, angles in test_loader:
    images, labels = images.to(device), labels.to(device)
    outputs = densenet(images)
    loss = loss_fn(outputs, labels, angles)
    avg_loss += loss.item()

    # Move the batch to the host once instead of once per element.
    labels_np = labels.cpu().numpy()
    preds = (torch.sigmoid(outputs).cpu().numpy() > 0.5)
    matches = (preds == labels_np)

    correct_preds += matches.sum()
    total_samples += labels.size(0) * labels.size(1)

    # Appending whole batches keeps the (samples, classes) layout after concatenation.
    output_list.append(outputs.cpu().numpy())
    label_list.append(labels_np)

    for i in range(len(preds)):
      for class_idx, class_name in enumerate(disease_classes):
        # Route the image by whether this class was predicted correctly; each
        # gallery is capped at six images.
        if matches[i][class_idx]:
          gallery = correct_images_per_class[class_name]
        else:
          gallery = incorrect_images_per_class[class_name]
        if len(gallery) < 6:
          gallery.append(images[i].cpu().numpy())

print("Test Loss:", avg_loss / len(test_loader))
print("Accuracy:", correct_preds / total_samples)

all_outputs = np.concatenate(output_list, axis=0)
all_labels = np.concatenate(label_list, axis=0)

print()
plot_confusion_matrices(all_outputs, all_labels)

##### Plot samples of correct and incorrect predictions for each class

In [None]:
# One gallery per class: correct predictions on top, incorrect ones below.
for class_name, correct_images in correct_images_per_class.items():
  plot_correct_incorrect(
    correct_images,
    incorrect_images_per_class[class_name],
    f"sample of {class_name}",
  )

##### Test ResNet50

In [None]:
disease_classes = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
# Up to six example images per class for each outcome.
correct_images_per_class = {class_name: [] for class_name in disease_classes}
incorrect_images_per_class = {class_name: [] for class_name in disease_classes}

resnet.eval()
resnet.to(device)
avg_loss = 0.0
correct_preds = 0
total_samples = 0
output_list = []
label_list = []
# Inference only: no autograd graph is needed.
with torch.no_grad():
  for images, labels, angles in test_loader:
    images, labels = images.to(device), labels.to(device)
    outputs = resnet(images)
    loss = loss_fn(outputs, labels, angles)
    avg_loss += loss.item()

    # Move the batch to the host once instead of once per element.
    labels_np = labels.cpu().numpy()
    preds = (torch.sigmoid(outputs).cpu().numpy() > 0.5)
    matches = (preds == labels_np)

    correct_preds += matches.sum()
    total_samples += labels.size(0) * labels.size(1)

    # Appending whole batches keeps the (samples, classes) layout after concatenation.
    output_list.append(outputs.cpu().numpy())
    label_list.append(labels_np)

    for i in range(len(preds)):
      for class_idx, class_name in enumerate(disease_classes):
        # Route the image by whether this class was predicted correctly; each
        # gallery is capped at six images.
        if matches[i][class_idx]:
          gallery = correct_images_per_class[class_name]
        else:
          gallery = incorrect_images_per_class[class_name]
        if len(gallery) < 6:
          gallery.append(images[i].cpu().numpy())

print("Test Loss:", avg_loss / len(test_loader))
print("Accuracy:", correct_preds / total_samples)

all_outputs = np.concatenate(output_list, axis=0)
all_labels = np.concatenate(label_list, axis=0)

print()
plot_confusion_matrices(all_outputs, all_labels)


In [None]:
# One gallery per class: correct predictions on top, incorrect ones below.
for class_name, correct_images in correct_images_per_class.items():
  plot_correct_incorrect(
    correct_images,
    incorrect_images_per_class[class_name],
    f"sample of {class_name}",
  )

##### Computation times

In [226]:
# `time` is imported in the notebook's import cell.
densenet.eval()
resnet.eval()
# Use the first image of one test batch, with a leading batch dimension, as the timing input.
images, labels, angles = next(iter(test_loader))
test_input = images[0].to(device).unsqueeze(0)

##### DenseNet121

In [None]:
with torch.no_grad():
  # Warm-up pass so one-off initialization cost is not part of the measurement.
  densenet(test_input)
  # CUDA kernels run asynchronously; synchronize so the timer brackets the real work.
  if device.type == 'cuda':
    torch.cuda.synchronize()
  start_time = time.perf_counter()
  densenet(test_input)
  if device.type == 'cuda':
    torch.cuda.synchronize()
  end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Inference time: {elapsed_time:.6f} seconds")

##### ResNet50

In [None]:
with torch.no_grad():
  # Warm-up pass so one-off initialization cost is not part of the measurement.
  resnet(test_input)
  # CUDA kernels run asynchronously; synchronize so the timer brackets the real work.
  if device.type == 'cuda':
    torch.cuda.synchronize()
  start_time = time.perf_counter()
  resnet(test_input)
  if device.type == 'cuda':
    torch.cuda.synchronize()
  end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Inference time: {elapsed_time:.6f} seconds")