In [None]:
import numpy as np
import pandas as pd
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import precision_score, recall_score, accuracy_score, classification_report
import wandb
import cv2
from torchvision.models import ViT_B_16_Weights

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Seed the global RNGs so DataLoader shuffling, the randomly initialised
# classifier heads and dropout are reproducible across kernel restarts.
# (The train/valid/test split is seeded separately via random_state.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

In [None]:
# Root directory that each relative 'frontal_image' path from the metadata CSV is
# joined onto when images are loaded (placeholder — set to the real image root).
BASE_PATH = 'path_to_directory'

In [None]:
# Image metadata table. It must contain a 'frontal_image' column (image path
# relative to BASE_PATH) plus one column per entry of `labels` below.
# NOTE(review): the argument is a placeholder and should point at the CSV file itself.
dataset = pd.read_csv(filepath_or_buffer="path_to_directory")

# Finding columns read from the metadata CSV. Their order fixes the index of each
# entry in the label/output vectors and the row order of per-class reports.
labels = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

In [None]:
# 70/20/10 split: hold out 30 % of the rows, then give one third of that hold-out
# to test (10 % overall) and the remaining two thirds to validation (20 % overall).
# NOTE(review): rows are split independently. If the CSV has several images per
# patient/study, the same patient can land in more than one split — consider a
# group-aware split keyed on the patient id column, if one exists.
train_df, temp_data = train_test_split(dataset, random_state=42, test_size=0.3)
valid_df, test_df = train_test_split(temp_data, random_state=42, test_size=1 / 3)

In [None]:
class CLAHETransform:
    """Callable preprocessing step applying OpenCV CLAHE to an image.

    Accepts a PIL image or a numpy array and always returns a PIL image.
    3-D arrays (the dataset passes RGB) are converted to LAB, and only the L
    channel is equalised before converting back to RGB. Any other array is
    equalised directly.

    Args:
        clip_limit: passed to ``cv2.createCLAHE`` as ``clipLimit``.
        tile_grid_size: passed to ``cv2.createCLAHE`` as ``tileGridSize``.
    """

    def __init__(self, clip_limit=0.34, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)

    def _equalize_lightness(self, rgb):
        # Equalise only lightness so the colour channels (a, b) are untouched.
        light, chan_a, chan_b = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))
        lab = cv2.merge((self.clahe.apply(light), chan_a, chan_b))
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

    def __call__(self, img):
        arr = np.array(img) if isinstance(img, Image.Image) else img
        if arr.ndim != 3:
            equalized = self.clahe.apply(arr)
        else:
            equalized = self._equalize_lightness(arr)
        return Image.fromarray(np.clip(equalized, 0, 255).astype('uint8'))

In [None]:
# Preprocessing shared by the train, validation and test datasets: CLAHE contrast
# enhancement, resize to 224x224, conversion to a tensor, then per-channel
# normalisation with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        CLAHETransform(tile_grid_size=(8, 8), clip_limit=0.34),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
class MimicCXR_Dataset(Dataset):
    """Map-style dataset yielding ``(image, label_vector)`` pairs from a metadata DataFrame.

    Args:
        img_data: DataFrame with a 'frontal_image' column (image path relative to
            ``img_path``) and one numeric column per entry of the global ``labels``.
        img_path: directory the relative image paths are joined onto.
        transform: optional callable applied to the loaded RGB PIL image.

    ``__getitem__`` returns the (optionally transformed) image and a float32 tensor
    of length ``len(labels)`` holding the row's label values in ``labels`` order.

    Raises (from ``__getitem__``):
        ValueError: if the row's 'frontal_image' entry is missing.
        FileNotFoundError: if the image file does not exist.
    """

    def __init__(self, img_data, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform
        self.img_data = img_data

    def __len__(self):
        return len(self.img_data)

    def __getitem__(self, index):
        # Fetch the row once instead of once per label column.
        row = self.img_data.iloc[index]
        img_name = row['frontal_image']
        if pd.isna(img_name):
            raise ValueError(f'Missing image path for index {index}')
        img_name = os.path.join(self.img_path, str(img_name))
        try:
            # The context manager guarantees the file handle is closed; convert()
            # returns an independent in-memory copy.
            with Image.open(img_name) as raw_image:
                image = raw_image.convert('RGB')
        except FileNotFoundError as err:
            raise FileNotFoundError(f'Image not found at path: {img_name}') from err
        # NOTE(review): values are copied as-is. If the CSV encodes uncertain or
        # missing findings as -1/NaN, they reach the loss unchanged — confirm the
        # label encoding.
        label = torch.tensor(row[labels].to_numpy(dtype=np.float32), dtype=torch.float32)
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
class EarlyStopping:
    """Flag when validation loss stops improving, checkpointing the best model.

    Each call compares ``val_loss`` with the best loss seen so far. A call that
    improves on it by more than ``delta`` saves ``model.state_dict()`` to ``path``
    and resets the counter. Otherwise the counter is incremented, and
    ``early_stop`` becomes True once it reaches ``patience``.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: if True, print a message whenever a checkpoint is saved.
        delta: minimum decrease in validation loss that counts as an improvement.
        path: file the best state_dict is written to. Missing parent
            directories are created on first save.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='path_to_model.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # negated best validation loss (higher is better)
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and record ``val_loss`` as the new minimum."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        directory = os.path.dirname(self.path)
        if directory:
            # torch.save does not create missing parent directories.
            os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
class DenseNetModel(nn.Module):
    """DenseNet-121 with torchvision's DEFAULT pretrained weights and a multi-label head.

    The ImageNet classifier is replaced by a linear layer with one output per
    entry of the global ``labels``. ``forward`` returns per-class sigmoid
    probabilities, not logits.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        head_inputs = self.base_model.classifier.in_features
        self.base_model.classifier = nn.Linear(head_inputs, len(labels))

    def forward(self, x):
        return self.base_model(x).sigmoid()

class ResNetModel(nn.Module):
    """ResNet-50 with torchvision's DEFAULT pretrained weights and a multi-label head.

    The final ``fc`` layer is swapped for a linear layer with ``len(labels)``
    outputs. ``forward`` returns sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.base_model.fc = nn.Linear(self.base_model.fc.in_features, len(labels))

    def forward(self, x):
        logits = self.base_model(x)
        return logits.sigmoid()

class EfficientNetModel(nn.Module):
    """EfficientNet-B0 with torchvision's DEFAULT pretrained weights and a multi-label head.

    The linear layer at ``classifier[1]`` is replaced by one with
    ``len(labels)`` outputs. ``forward`` returns sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        old_head = self.base_model.classifier[1]
        self.base_model.classifier[1] = nn.Linear(old_head.in_features, len(labels))

    def forward(self, x):
        probabilities = torch.sigmoid(self.base_model(x))
        return probabilities

class InceptionV3Model(nn.Module):
    """Inception v3 with torchvision's DEFAULT pretrained weights and a multi-label head.

    The final ``fc`` layer is replaced by a linear layer with ``len(labels)``
    outputs. ``forward`` resizes inputs to 299x299 when needed and returns
    sigmoid probabilities of the main head only.
    """

    def __init__(self):
        super(InceptionV3Model, self).__init__()
        self.base_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(num_ftrs, len(labels))
        # Disable the auxiliary output.
        # NOTE(review): the AuxLogits submodule itself is kept (checkpoint keys are
        # unchanged). Confirm whether it is still evaluated in training mode and
        # whether setting it to None is worth the state_dict change.
        self.base_model.aux_logits = False

    # A second, shadowing definition of forward() used to follow an identical
    # one without tuple handling; only this single definition remains.
    def forward(self, x):
        # Inception v3 expects (299, 299) sized images.
        if x.size()[2:] != (299, 299):
            x = nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        x = self.base_model(x)
        if isinstance(x, tuple):
            x = x[0]  # In training, inception may return (output, aux_output)
        return torch.sigmoid(x)

class VGG16Model(nn.Module):
    """VGG-16 with torchvision's DEFAULT pretrained weights and a multi-label head.

    The last classifier layer (``classifier[6]``) is replaced by a linear layer
    with ``len(labels)`` outputs. ``forward`` returns sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        feature_dim = self.base_model.classifier[6].in_features
        new_head = nn.Linear(feature_dim, len(labels))
        self.base_model.classifier[6] = new_head

    def forward(self, x):
        out = self.base_model(x)
        return out.sigmoid()

class DenseNet201Model(nn.Module):
    """DenseNet-201 with torchvision's DEFAULT pretrained weights and a multi-label head.

    ``classifier`` is replaced by a linear layer with ``len(labels)`` outputs.
    ``forward`` returns sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        backbone = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
        self.base_model = backbone
        backbone.classifier = nn.Linear(backbone.classifier.in_features, len(labels))

    def forward(self, x):
        return torch.sigmoid(self.base_model(x))

class ResNeXt101Model(nn.Module):
    """ResNeXt-101 32x8d with torchvision's DEFAULT pretrained weights and a multi-label head.

    The final ``fc`` layer is replaced by a linear layer with ``len(labels)``
    outputs. ``forward`` returns sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT)
        fc_inputs = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(fc_inputs, len(labels))

    def forward(self, x):
        scores = self.base_model(x)
        return torch.sigmoid(scores)

class EfficientNetB4Model(nn.Module):
    """EfficientNet-B4 with torchvision's DEFAULT pretrained weights and a multi-label head.

    The linear layer at ``classifier[1]`` is replaced by one with
    ``len(labels)`` outputs. ``forward`` returns sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        net = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, len(labels))
        self.base_model = net

    def forward(self, x):
        return self.base_model(x).sigmoid()

class VisionTransformerModel(nn.Module):
    """ViT-B/16 with torchvision's DEFAULT pretrained weights and a multi-label head.

    ``heads.head`` is replaced by a linear layer with ``len(labels)`` outputs.
    ``forward`` bilinearly resizes non-224x224 inputs to 224x224 and returns
    sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        width = self.base_model.heads.head.in_features
        self.base_model.heads.head = nn.Linear(width, len(labels))

    def forward(self, x):
        # The ViT patch embedding is set up for 224x224 inputs.
        expected_hw = (224, 224)
        if tuple(x.size()[2:]) != expected_hw:
            x = nn.functional.interpolate(x, size=expected_hw, mode='bilinear', align_corners=False)
        return self.base_model(x).sigmoid()

In [None]:
def _build_model(architecture):
    """Return a new (CPU) model instance for the sweep's ``architecture`` value.

    Raises:
        ValueError: if ``architecture`` is not a supported name.
    """
    model_classes = {
        'densenet121': DenseNetModel,
        'resnet50': ResNetModel,
        'efficientnet': EfficientNetModel,
        'inceptionv3': InceptionV3Model,
        'vgg16': VGG16Model,
        'densenet201': DenseNet201Model,
        'resnext101': ResNeXt101Model,
        'efficientnet_b4': EfficientNetB4Model,
        'vit': VisionTransformerModel,
    }
    if architecture not in model_classes:
        raise ValueError(f'Unknown architecture: {architecture!r}')
    return model_classes[architecture]()


def _build_optimizer(config, params):
    """Create the optimizer named by ``config.optimizer`` over ``params``.

    Raises:
        ValueError: if ``config.optimizer`` is not a supported name.
    """
    name = config.optimizer
    if name == 'adam':
        return optim.Adam(params, lr=config.learning_rate, weight_decay=config.weight_decay)
    if name == 'sgd':
        return optim.SGD(params, lr=config.learning_rate, momentum=0.9, weight_decay=config.weight_decay)
    if name == 'rmsprop':
        return optim.RMSprop(params, lr=config.learning_rate, weight_decay=config.weight_decay)
    if name == 'adamw':
        return optim.AdamW(params, lr=config.learning_rate, weight_decay=config.weight_decay)
    raise ValueError(f'Unknown optimizer: {name!r}')


def _build_scheduler(config, optimizer):
    """Create the LR scheduler named by ``config.lr_scheduler``.

    Raises:
        ValueError: if ``config.lr_scheduler`` is not a supported name.
    """
    if config.lr_scheduler == 'step':
        return optim.lr_scheduler.StepLR(optimizer, step_size=config.lr_step_size, gamma=config.lr_gamma)
    if config.lr_scheduler == 'cosine':
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    raise ValueError(f'Unknown lr_scheduler: {config.lr_scheduler!r}')


def _predict(model, loader, criterion=None):
    """Run ``model`` in eval mode over every batch of ``loader`` without gradients.

    Returns:
        ``(mean_loss, outputs, targets)``. ``mean_loss`` is the per-sample average
        of ``criterion`` (None when no criterion is given). ``outputs`` and
        ``targets`` are numpy arrays stacked over all samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    output_chunks, target_chunks = [], []
    with torch.no_grad():
        for data, target in loader:
            # In eval mode BatchNorm uses running statistics, so 1-sample batches are fine.
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            if criterion is not None:
                # criterion averages within the batch; re-weight by batch size so a
                # smaller final batch does not skew the overall mean.
                total_loss += criterion(outputs, target).item() * data.size(0)
            n_samples += data.size(0)
            output_chunks.append(outputs.cpu().numpy())
            target_chunks.append(target.cpu().numpy())
    if n_samples == 0:
        raise ValueError('Evaluation loader produced no samples')
    mean_loss = total_loss / n_samples if criterion is not None else None
    return mean_loss, np.concatenate(output_chunks), np.concatenate(target_chunks)


def train_model():
    """Run one W&B sweep trial: train, validate with early stopping, then test.

    Hyperparameters are read from ``wandb.config``: batch_size, architecture,
    optimizer, learning_rate, weight_decay, lr_scheduler, lr_step_size,
    lr_gamma and epochs. The best validation-loss checkpoint is written to
    ``hyperparameter_best_models/`` and reloaded for the test-set evaluation.
    Metrics are logged to W&B.

    Raises:
        ValueError: if the config names an unknown architecture, optimizer or
            scheduler, or if the validation/test split yields no samples.
    """
    wandb.init()
    config = wandb.config

    train_dataset = MimicCXR_Dataset(train_df, BASE_PATH, transform)
    valid_dataset = MimicCXR_Dataset(valid_df, BASE_PATH, transform)
    test_dataset = MimicCXR_Dataset(test_df, BASE_PATH, transform)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True)
    # Evaluation loaders keep their final partial batch so every held-out image is scored.
    valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size)

    model = _build_model(config.architecture).to(device)

    # Every model class's forward() already ends in a sigmoid, so the loss must take
    # probabilities. BCEWithLogitsLoss would apply a second sigmoid, squashing
    # predictions into [0.5, 0.73] and distorting the gradients.
    criterion = nn.BCELoss()
    optimizer = _build_optimizer(config, model.parameters())
    scheduler = _build_scheduler(config, optimizer)

    checkpoint_path = f'hyperparameter_best_models/best_model_{wandb.run.id}.pt'
    # torch.save (used by EarlyStopping) does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    early_stopping = EarlyStopping(patience=5, verbose=True, path=checkpoint_path)

    # Log model architecture and hyperparameters
    wandb.watch(model, log="all")
    wandb.log({
        "batch_size": config.batch_size,
        "architecture": config.architecture,
        "optimizer": config.optimizer,
        "learning_rate": config.learning_rate,
        "weight_decay": config.weight_decay,
        "lr_scheduler": config.lr_scheduler,
    })

    for epoch in range(1, config.epochs + 1):
        model.train()
        running_loss = 0.0
        n_batches = 0
        for data, target in train_loader:
            # Skip batches smaller than 2 (presumably to avoid BatchNorm's
            # single-sample error in train mode).
            if data.size(0) < 2:
                continue
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1

        # Average over the batches actually used for training.
        train_loss = running_loss / max(n_batches, 1)

        val_loss, val_outputs, val_targets = _predict(model, valid_loader, criterion)

        val_preds = (val_outputs > 0.5).astype(int)
        val_precision = precision_score(val_targets, val_preds, average='macro', zero_division=1)
        val_recall = recall_score(val_targets, val_preds, average='macro', zero_division=1)
        # On multi-label indicator arrays accuracy_score is exact-match (subset) accuracy.
        val_accuracy = accuracy_score(val_targets, val_preds)

        scheduler.step()

        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_precision": val_precision,
            "val_recall": val_recall,
            "val_accuracy": val_accuracy,
            "learning_rate": scheduler.get_last_lr()[0],
        })

        print(f'Epoch {epoch}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, '
              f'Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}, '
              f'Validation Accuracy: {val_accuracy:.4f}')

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Reload the best (lowest validation loss) weights before testing.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    _, all_outputs, all_targets = _predict(model, test_loader)

    test_preds = (all_outputs > 0.5).astype(int)
    test_precision = precision_score(all_targets, test_preds, average='macro', zero_division=1)
    test_recall = recall_score(all_targets, test_preds, average='macro', zero_division=1)
    test_accuracy = accuracy_score(all_targets, test_preds)

    wandb.log({
        "test_precision": test_precision,
        "test_recall": test_recall,
        "test_accuracy": test_accuracy
    })

    report = classification_report(all_targets, test_preds, target_names=labels, output_dict=True, zero_division=1)
    wandb.log({"classification_report": wandb.Table(dataframe=pd.DataFrame(report).transpose())})

    # NOTE(review): argmax collapses each multi-hot row to a single class (an
    # all-zero prediction maps to index 0), so this matrix is only a rough
    # summary for a multi-label task.
    wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(
        probs=None,
        y_true=all_targets.argmax(axis=1),
        preds=test_preds.argmax(axis=1),
        class_names=labels
    )})

    wandb.finish()

# W&B sweep definition: Bayesian search maximising validation (subset) accuracy.
# lr_step_size and lr_gamma only matter when lr_scheduler == 'step'.
sweep_configuration = {
    'method': 'bayes',
    'name': 'mimic-cxr-sweep',
    'metric': {'goal': 'maximize', 'name': 'val_accuracy'},
    'parameters':
    {
        'batch_size': {'values': [16, 32, 64]},
        'epochs': {'values': [15, 30, 50]},
        # 'log_uniform_values' samples log-uniformly between the literal min and max.
        # Plain 'log_uniform' treats min/max as natural-log exponents (exp(min)..exp(max)),
        # which with these bounds would sample learning rates and weight decays near 1.0.
        'learning_rate': {'max': 0.01, 'min': 0.0001, 'distribution': 'log_uniform_values'},
        'weight_decay': {'max': 0.1, 'min': 1e-5, 'distribution': 'log_uniform_values'},
        'optimizer': {'values': ['adam', 'sgd', 'rmsprop', 'adamw']},
        'architecture': {'values': ['densenet121', 'resnet50', 'efficientnet', 'inceptionv3', 'vgg16',
                                    'densenet201', 'resnext101', 'efficientnet_b4', 'vit']},
        'lr_scheduler': {'values': ['step', 'cosine']},
        'lr_step_size': {'values': [5, 10, 15]},
        'lr_gamma': {'min': 0.1, 'max': 0.5, 'distribution': 'uniform'}
    }
}

# Register the sweep definition with W&B under the (placeholder) project name.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project='name',
)

# Hand control to the W&B agent, which calls train_model() for up to 20 trials.
wandb.agent(sweep_id, function=train_model, count=20)