# COVID-19 Chest X-Ray Database - Experiment

## CNN Model Implementation

In [2]:
import torch
import wandb
from torchvision import models
from torchvision.models import ResNet50_Weights

In [3]:
NUMBER_OF_CLASSES = 4
IMAGE_SIZE = 299

In [4]:
#TODO: Check project name and other values
# Initialize wandb run | 'wandb login' terminal
wandb_run = wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",

    # track hyperparameters and run metadata
    config={
        "learning_rate": 0.001,
        "pretrained_model": "RestNet-50",
        "architecture": "CNN",
        "optimizer": "Adam",
        "criterion": "Cross entropy loss",
        "dataset": "COVID-19 Chest X-Ray Database",
        "epochs": 20
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdwayne-taylor-ucr[0m ([33muniversity-of-costa-rica[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Define device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Load pre-trained ResNet50, freeze early layers
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False

# Get number of input features from the original FC layer
num_features = model.fc.in_features

# Define new classifier head
model.fc = torch.nn.Sequential(
    torch.nn.Linear(num_features, 128),
    torch.nn.ReLU(inplace=True),
    torch.nn.Linear(128, NUMBER_OF_CLASSES),
)

# Define loss function
criterion = torch.nn.CrossEntropyLoss()

# Define optimizer (replace with your learning rate if needed)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Move model to chosen device
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
class EarlyStopping(object):
    def __init__(self, patience=5):
        self.patience = patience
        self.best_val_loss = float('inf')
        self.counter = 0

    def __call__(self, epoch, logs):
        val_loss = logs.get('val_loss')
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                print(f"Early stopping triggered after {self.patience} epochs with no improvement.")
                return False
        return True

In [8]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler, SubsetRandomSampler, SubsetRandomSampler
import numpy as np

# Define the transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3 channels
    transforms.Resize((248, 248)),  # Resize to 299x299
    transforms.ToTensor(),  # Convert PIL image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define the root directory
data_dir = './data/raw'

# Create the dataset
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Get the labels
labels = np.array([label for _, label in full_dataset])

# Split the dataset into training, validation, and testing sets with stratification
from sklearn.model_selection import train_test_split

train_indices, temp_indices = train_test_split(np.arange(len(full_dataset)), test_size=0.2, stratify=labels)
val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, stratify=labels[temp_indices])

# Create Samplers
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)

# Create DataLoaders
train_loader = DataLoader(full_dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(full_dataset, batch_size=32, sampler=val_sampler)
test_loader = DataLoader(full_dataset, batch_size=32, sampler=test_sampler)



In [None]:

from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, roc_curve, roc_auc_score
from sklearn.preprocessing import label_binarize
import torch.nn.functional as F

early_stopping = EarlyStopping(patience=5)

# Training loop
for epoch in range(1):
    # Initialize empty numpy arrays for predictions and targets
    all_predictions = np.array([])
    all_targets = np.array([])
    
    # Training phase
    for data, target in train_loader:
        # Move data and target to device
        data, target = data.to(device), target.to(device)

        # Forward pass, calculate loss
        output = model(data)
        loss = criterion(output, target)

        # Backpropagation, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation phase
    val_loss = 0.0
    all_outputs_proba = np.empty((0, NUMBER_OF_CLASSES))

    with torch.no_grad():
        for data, target in val_loader:
            # Move data and target to device
            data, target = data.to(device), target.to(device)

            # Forward pass
            output = model(data)

            # Calculate validation loss
            val_loss += criterion(output, target).item()
            # Apply softmax to get probabilities
            probabilities = F.softmax(output, dim=1)
            probabilities = probabilities.unsqueeze(0) if probabilities.ndim == 1 else probabilities

            # Concatenate predictions and targets for the current batch
            all_predictions = np.concatenate((all_predictions, output.argmax(dim=1).cpu().numpy()))
            all_targets = np.concatenate((all_targets, target.cpu().numpy()))
            all_outputs_proba = np.concatenate((all_outputs_proba, probabilities.cpu().numpy()), axis=0)

        # Calculate average validation loss
        val_loss /= len(val_loader)

        # Calculate metrics for the current epoch
        accuracy = accuracy_score(all_targets, all_predictions)
        precision = precision_score(all_targets, all_predictions, average='weighted')
        recall = recall_score(all_targets, all_predictions, average='weighted')
        # confusion = confusion_matrix(all_targets, all_predictions, labels=range(NUMBER_OF_CLASSES))
        confusion = wandb.plot.confusion_matrix(y_true=all_targets, preds=all_predictions, class_names=["x", "y", "z", "w"])
        print(confusion)
        # Binarize the targets for ROC curve calculation
        all_targets_binary = label_binarize(all_targets, classes=range(NUMBER_OF_CLASSES))

        # Initialize dictionaries to hold the TPR and FPR for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()

        # Calculate ROC curve and ROC area for each class
        for i in range(NUMBER_OF_CLASSES):
            fpr[i], tpr[i], _ = roc_curve(all_targets_binary[:, i], all_outputs_proba[:, i])
            roc_auc[i] = roc_auc_score(all_targets_binary[:, i], all_outputs_proba[:, i])


        # Log validation and metrics for the current epoch
        if wandb_run is not None:
            wandb_run.log({
                "val_loss": val_loss,
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                # "confusion_matrix": confusion,
                "fpr": fpr,
                "tpr": tpr,
                "roc_auc": roc_auc
            })

    if not early_stopping(epoch, logs={'val_loss': val_loss}): break

state_dict_path = "./models/cnn-raw.pth"
torch.save(model.state_dict(), state_dict_path)
# Finish Wandb run
if wandb_run is not None:
    wandb.finish()

: 

### Raw Images

### Bilateral Filtered Images

## Multilayer Perceptron