In [None]:
import numpy as np
import pandas as pd
import os
from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, precision_score, roc_curve, f1_score, auc, recall_score, accuracy_score, classification_report, multilabel_confusion_matrix
from transformers import ViTModel, ViTFeatureExtractor
from torch.cuda.amp import GradScaler, autocast
import json
import seaborn as sns
import cv2
from skimage import exposure
from skimage.metrics import structural_similarity as ssim

In [None]:
# Prefer the first CUDA device when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')

## Implementation (CheXclusion)

In [None]:
# Root directory that the relative image paths stored in the CSV are joined onto.
BASE_PATH = "path_to_directory"
# Metadata table; later cells read its subject_id, study_id, frontal_image,
# lateral_image and pathology label columns.
dataset = pd.read_csv("path_to_directory")
# Last expression of the cell, so the notebook displays the row count.
len(dataset)
# Optional debugging aid: uncomment to work on a 100-row subset.
# # Get the first 100 records
# dataset = dataset.head(100)

# # Display the length of the subset to confirm
# print(len(dataset))

In [None]:
main_dirs = ['p10', 'p11', 'p12', 'p13', 'p14', 'p15', 'p16', 'p17', 'p18', 'p19']

# Function to check if the path format is correct
def check_path_format(row, image_column, base_path=None):
    """Return True when a row's image path has the expected layout and exists on disk.

    The relative path must look like ``<main_dir>/p<subject_id>/s<study_id>/<file>``
    where ``<main_dir>`` is one of ``main_dirs``, and the file must exist under
    the root directory.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) providing ``image_column``,
            ``'subject_id'`` and ``'study_id'``.
        image_column: Name of the field holding the relative image path.
        base_path: Root directory the relative path is joined onto. Defaults to
            the module-level ``BASE_PATH`` when omitted.

    Returns:
        bool: True only if every check passes; False otherwise, including when
        the path value is missing (NaN/None) or not a string.
    """
    raw_path = row[image_column]
    # Missing values arrive as float NaN / None; normpath would raise on them.
    if not isinstance(raw_path, str) or not raw_path:
        return False

    image_path = os.path.normpath(raw_path)

    # Split the path components (os.sep is the platform-specific separator)
    parts = image_path.split(os.sep)

    # Need at least <main_dir>/<subject>/<study>/<file>
    if len(parts) < 4:
        return False

    # Check the main directory is valid
    if parts[0] not in main_dirs:
        return False

    # Second and third components must match the row's own identifiers
    if parts[1] != 'p' + str(row['subject_id']):
        return False
    if parts[2] != 's' + str(row['study_id']):
        return False

    # Check if file exists at the expected location
    root = BASE_PATH if base_path is None else base_path
    return os.path.exists(os.path.join(root, *parts))

# Flag each row according to whether its frontal image path passes check_path_format.
dataset['frontal_image_valid'] = dataset.apply(check_path_format, axis=1, args=('frontal_image',))

# Collect the rows that failed the check. Note that `dataset` itself is NOT
# filtered here; the invalid rows are only counted.
invalid_rows = dataset.loc[~dataset['frontal_image_valid']]

# Uncomment to save the invalid rows to a CSV file for inspection.
# invalid_rows.to_csv('invalid_image_paths.csv', index=False)
print(f"Number of invalid rows: {len(invalid_rows)}")
# print("Invalid rows saved to 'invalid_image_paths.csv'")

In [None]:
# Names of the label columns read from the dataset, one model output per entry.
labels = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

In [None]:
# Split by patient rather than by row: a row-level split can place different
# studies of the same subject_id in train and test, which leaks patient-specific
# appearance into the evaluation sets.
subject_ids = dataset['subject_id'].unique()
train_subjects, temp_subjects = train_test_split(subject_ids, test_size=0.2, random_state=42)  # ~80% of patients for training
valid_subjects, test_subjects = train_test_split(temp_subjects, test_size=0.5, random_state=42)  # ~10% validation, ~10% testing

train_df = dataset[dataset['subject_id'].isin(train_subjects)]
temp_data = dataset[dataset['subject_id'].isin(temp_subjects)]  # kept for cells that still reference the hold-out pool
valid_df = dataset[dataset['subject_id'].isin(valid_subjects)]
test_df = dataset[dataset['subject_id'].isin(test_subjects)]

# Every row must land in exactly one split.
assert len(train_df) + len(valid_df) + len(test_df) == len(dataset)

# Display the lengths of the datasets
print(f'Number of samples in training set: {len(train_df)}')
print(f'Number of samples in validation set: {len(valid_df)}')
print(f'Number of samples in test set: {len(test_df)}')

In [None]:
def check_for_leakage(df1, df2, study_col):
    """Report whether two DataFrames share at least one value in ``study_col``.

    Args:
        df1: First DataFrame.
        df2: Second DataFrame.
        study_col: Name of the identifier column compared between the two.

    Returns:
        bool: True if any identifier occurs in both frames, else False.
    """
    ids_first = set(df1[study_col].unique().tolist())
    ids_second = set(df2[study_col].unique().tolist())

    # The sets overlap exactly when they are not disjoint.
    return not ids_first.isdisjoint(ids_second)

# Check for leakage using the 'study_id' column
print("Leakage between train and test: {}".format(check_for_leakage(train_df, test_df, 'study_id')))
print("Leakage between valid and test: {}".format(check_for_leakage(valid_df, test_df, 'study_id')))
# The train/valid pair matters as well: validation loss drives checkpointing,
# early stopping and the LR schedule.
print("Leakage between train and valid: {}".format(check_for_leakage(train_df, valid_df, 'study_id')))

# Study-level checks cannot reveal a patient whose studies are spread over
# several splits, so also compare on 'subject_id'.
print("Patient leakage between train and test: {}".format(check_for_leakage(train_df, test_df, 'subject_id')))
print("Patient leakage between valid and test: {}".format(check_for_leakage(valid_df, test_df, 'subject_id')))
print("Patient leakage between train and valid: {}".format(check_for_leakage(train_df, valid_df, 'subject_id')))

In [None]:
def extract_image_paths(df):
    """Collect the non-missing frontal and lateral image paths of a DataFrame.

    Args:
        df: DataFrame with ``'frontal_image'`` and ``'lateral_image'`` columns.

    Returns:
        dict: ``{"frontal_images": [...], "lateral_images": [...]}`` where each
        list holds that column's values with missing entries dropped.
    """
    output_key_to_column = (
        ("frontal_images", 'frontal_image'),
        ("lateral_images", 'lateral_image'),
    )
    return {
        output_key: df[column].dropna().tolist()
        for output_key, column in output_key_to_column
    }

# Extract image paths for each dataset (train, validation, test — in that order)
train_image_paths, valid_image_paths, test_image_paths = (
    extract_image_paths(split_df) for split_df in (train_df, valid_df, test_df)
)

# Save each split to a JSON file
def save_to_json(data, filename):
    """Write ``data`` to ``filename`` as JSON indented by four spaces.

    Args:
        data: JSON-serialisable object.
        filename: Destination path; an existing file is overwritten.
    """
    with open(filename, 'w') as handle:
        handle.write(json.dumps(data, indent=4))

# Save the image paths of every split to its own JSON file
for _split_paths, _json_name in (
    (train_image_paths, 'train_image_paths.json'),
    (valid_image_paths, 'valid_image_paths.json'),
    (test_image_paths, 'test_image_paths.json'),
):
    save_to_json(_split_paths, _json_name)

print("Image paths saved successfully in JSON format!")

In [None]:
# Load one image and report its maximum pixel value and array shape.
pil_img = Image.open("path_to_directory")
img = np.asarray(pil_img).astype(np.uint8)
print(img.max(), img.shape, sep='\n')

In [None]:
class CLAHETransform:
    """Apply OpenCV CLAHE (contrast-limited adaptive histogram equalisation).

    Three-dimensional (colour) input is converted RGB -> LAB, equalised on the
    L channel only, and converted back to RGB. Two-dimensional (grayscale)
    input is equalised directly.

    Args:
        clip_limit: Passed to ``cv2.createCLAHE`` as ``clipLimit``.
        tile_grid_size: Passed to ``cv2.createCLAHE`` as ``tileGridSize``.
    """

    def __init__(self, clip_limit=0.10, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    def __getstate__(self):
        # The cv2 CLAHE handle cannot be pickled, which would break DataLoader
        # worker processes and copy.deepcopy. Drop it; it is rebuilt on load.
        state = self.__dict__.copy()
        state.pop('clahe', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)

    def __repr__(self):
        return (f"{self.__class__.__name__}(clip_limit={self.clip_limit}, "
                f"tile_grid_size={self.tile_grid_size})")

    def __call__(self, img):
        """Equalise ``img`` and return the result as a PIL image.

        Args:
            img: A PIL image or a numpy array.

        Returns:
            PIL.Image.Image: The equalised image, cast to uint8.
        """
        # Convert PIL image to numpy array if necessary
        if isinstance(img, Image.Image):
            img = np.array(img)

        # If the image is RGB (3 channels), convert it to LAB color space
        if img.ndim == 3:
            lab_img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
            l_channel, a_channel, b_channel = cv2.split(lab_img)

            # Apply CLAHE only to the L (lightness) channel so colour is untouched
            l_channel = self.clahe.apply(l_channel)

            # Merge back and convert to RGB
            lab_img = cv2.merge((l_channel, a_channel, b_channel))
            img = cv2.cvtColor(lab_img, cv2.COLOR_LAB2RGB)
        else:
            # If the image is grayscale, apply CLAHE directly
            img = self.clahe.apply(img)

        # Convert back to PIL Image before returning
        return Image.fromarray(img.astype('uint8'))

In [None]:
class MimicCXR_Dataset(Dataset):
    """Dataset yielding paired frontal/lateral images and a multi-label target.

    Args:
        img_data: DataFrame with ``'frontal_image'`` and ``'lateral_image'``
            path columns plus one column per entry of ``labels``.
        img_path: Root directory the relative image paths are joined onto.
        labels: List of label column names.
        transform: Optional callable applied to each PIL image independently.
    """

    def __init__(self, img_data, img_path, labels, transform=None):
        self.img_data = img_data
        self.img_path = img_path
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.img_data)

    def _full_path(self, relative_path):
        """Join ``relative_path`` onto the image root; None for a missing value."""
        if pd.notna(relative_path):
            return os.path.join(self.img_path, str(relative_path))
        return None

    def _load_label(self, row):
        """Return the row's label columns as a float32 tensor (NaN -> 0)."""
        # A row of a mixed-type frame is object-dtype; convert explicitly rather
        # than relying on fillna to downcast it to a numeric dtype.
        values = pd.to_numeric(row[self.labels]).fillna(0).to_numpy(dtype=np.float32)
        return torch.tensor(values, dtype=torch.float32)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: ``((frontal_image, lateral_image), label)`` where both images
            are RGB (transformed when ``transform`` is set) and ``label`` is a
            float32 tensor with one entry per label column.

        Raises:
            FileNotFoundError: If either image path is missing or not on disk.
        """
        row = self.img_data.iloc[index]  # fetch the row once

        frontal_img_name = self._full_path(row['frontal_image'])
        lateral_img_name = self._full_path(row['lateral_image'])

        # Check if the paths are valid
        if frontal_img_name is None or not os.path.exists(frontal_img_name):
            raise FileNotFoundError(f'Frontal image not found at path: {frontal_img_name}')
        if lateral_img_name is None or not os.path.exists(lateral_img_name):
            raise FileNotFoundError(f'Lateral image not found at path: {lateral_img_name}')

        # Open both images as 3-channel RGB, closing the files once decoded
        with Image.open(frontal_img_name) as opened:
            frontal_image = opened.convert('RGB')
        with Image.open(lateral_img_name) as opened:
            lateral_image = opened.convert('RGB')

        label = self._load_label(row)

        # Apply transformations if specified (each view is transformed separately,
        # so random augmentations are drawn independently per view)
        if self.transform:
            frontal_image = self.transform(frontal_image)
            lateral_image = self.transform(lateral_image)

        # Return both images (frontal and lateral) and the label
        return (frontal_image, lateral_image), label

    def calculate_mean_std(self):
        """Average the per-image pixel mean and std over all frontal images.

        Values are computed on the raw uint8 arrays, so they are on the 0-255
        scale. The result is the mean of per-image statistics, not the pooled
        statistics of all pixels.

        Returns:
            tuple: ``(mean, std)``.

        Raises:
            ValueError: If there are no frontal images to measure.
        """
        mean = 0.0
        std = 0.0
        num_images = 0

        # Iterate over all frontal images, skipping rows without a path
        for img_file in self.img_data['frontal_image'].dropna().to_list():
            img_path = os.path.join(self.img_path, img_file)
            with Image.open(img_path) as pil_img:
                img = np.asarray(pil_img).astype('uint8')
            mean += np.mean(img, axis=(0, 1))
            std += np.std(img, axis=(0, 1))
            num_images += 1

        if num_images == 0:
            raise ValueError("calculate_mean_std requires at least one frontal image")

        # Calculate the mean and std across the dataset
        mean /= num_images
        std /= num_images
        print(f"Calculated Mean: {mean}")
        print(f"Calculated Std: {std}")

        return mean, std


# Backward-compatible alias: this used to be a module-level function taking the
# dataset as its first argument.
calculate_mean_std = MimicCXR_Dataset.calculate_mean_std

# temp_dataset = MimicCXR_Dataset(dataset, BASE_PATH, labels)
# mean, std = temp_dataset.calculate_mean_std()

In [None]:
# mean=0.47339121
# std= 0.30462474
# mean,std

In [None]:
# CLAHE settings shared by both pipelines. Training previously used a clip limit
# of 0.34 while validation used 0.35; deterministic preprocessing must be
# identical at train and evaluation time, so both now read the same constants.
CLAHE_CLIP_LIMIT = 0.35
CLAHE_TILE_GRID_SIZE = (8, 8)

# Training pipeline: shared preprocessing plus random flip/rotation augmentation.
train_transform = transforms.Compose([
    transforms.Resize(256),
    CLAHETransform(clip_limit=CLAHE_CLIP_LIMIT, tile_grid_size=CLAHE_TILE_GRID_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/test pipeline: the same preprocessing without random augmentation.
validation_transform = transforms.Compose([
    transforms.Resize(256),
    CLAHETransform(clip_limit=CLAHE_CLIP_LIMIT, tile_grid_size=CLAHE_TILE_GRID_SIZE),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## Display

In [None]:
# Paths of the example frontal and lateral images to visualise
frontal_image_path = "path_to_directory"
lateral_image_path = "path_to_directory"

# Open each image as 3-channel RGB and resize it to 320x320
frontal_image = Image.open(frontal_image_path).convert("RGB").resize((320, 320), Image.LANCZOS)
lateral_image = Image.open(lateral_image_path).convert("RGB").resize((320, 320), Image.LANCZOS)

# Keep copies of the resized originals for reference
frontal_image.save('frontal_image_original.png')
lateral_image.save('lateral_image_original.png')

# Run both images through the validation pipeline
frontal_transformed_image = validation_transform(frontal_image)
lateral_transformed_image = validation_transform(lateral_image)

# Channels-first tensor -> channels-last numpy array for plotting
frontal_image_numpy = frontal_transformed_image.numpy().transpose((1, 2, 0))
lateral_image_numpy = lateral_transformed_image.numpy().transpose((1, 2, 0))

# Statistics used by the Normalize step; needed to undo it
mean = np.array([0.485, 0.456, 0.406])  # ImageNet mean for RGB
std = np.array([0.229, 0.224, 0.225])   # ImageNet std for RGB

# Reverse the normalisation, clamp to [0, 1] and convert back to 8-bit PIL images
frontal_unnormalized_image = np.clip(frontal_image_numpy * std + mean, 0, 1)
lateral_unnormalized_image = np.clip(lateral_image_numpy * std + mean, 0, 1)
frontal_transformed_image_pil = Image.fromarray((frontal_unnormalized_image * 255).astype(np.uint8))
lateral_transformed_image_pil = Image.fromarray((lateral_unnormalized_image * 255).astype(np.uint8))

# Save enhanced images
frontal_transformed_image_pil.save('frontal_enhanced_image.png')
lateral_transformed_image_pil.save('lateral_enhanced_image.png')

# 2x2 grid: originals on the top row, enhanced versions on the bottom row
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
for _axis, (_picture, _caption) in zip(axes.flat, (
    (frontal_image, "Original Frontal Image"),
    (lateral_image, "Original Lateral Image"),
    (frontal_transformed_image_pil, "Enhanced Frontal Image"),
    (lateral_transformed_image_pil, "Enhanced Lateral Image"),
)):
    _axis.imshow(_picture)
    _axis.set_title(_caption)

# Save the comparison plot
plt.savefig('frontal_lateral_comparison.png')
plt.show()

## Dataloader

In [None]:
# Parameters
batch_size = 50
N_LABELS = 14
start_epoch = 0
num_epochs = 64  # Number of epochs to train for

# The model emits one score per entry of `labels`; fail fast if the two drift apart.
assert N_LABELS == len(labels), "N_LABELS must match the number of label columns"

# Training uses the augmenting pipeline; validation and test use the deterministic one.
train_dataset = MimicCXR_Dataset(train_df, BASE_PATH, labels, transform=train_transform)
valid_dataset = MimicCXR_Dataset(valid_df, BASE_PATH, labels, transform=validation_transform)
test_dataset = MimicCXR_Dataset(test_df, BASE_PATH, labels, transform=validation_transform)

# Create data loaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Optional: Print data loader lengths (number of batches per loader)
print(f'Train Loader Size: {len(train_loader)}')
print(f'Validation Loader Size: {len(valid_loader)}')
print(f'Test Loader Size: {len(test_loader)}')

## Model

In [None]:
class DenseNet121(nn.Module):
    """ImageNet-pretrained DenseNet-121 used as a feature extractor.

    The classification layer is replaced with ``nn.Identity`` so ``forward``
    returns the features that would have been fed to the original classifier.

    Attributes:
        num_features: Input width of the removed classifier, i.e. the size of
            the feature vector produced per sample.
    """

    def __init__(self):
        super(DenseNet121, self).__init__()
        self.densenet121 = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        # Record the feature width before the classifier is discarded so
        # callers need not hard-code it.
        self.num_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Identity()  # No classifier yet, only features

    def forward(self, x):
        """Return the backbone features for the image batch ``x``."""
        return self.densenet121(x)

class ResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 used as a feature extractor.

    The final fully connected layer is replaced with ``nn.Identity`` so
    ``forward`` returns the features that would have been fed to it.

    Attributes:
        num_features: Input width of the removed ``fc`` layer, i.e. the size of
            the feature vector produced per sample.
    """

    def __init__(self):
        super(ResNet50, self).__init__()
        self.resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Record the feature width before the head is discarded so callers
        # need not hard-code it.
        self.num_features = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Identity()  # No classifier yet, only features

    def forward(self, x):
        """Return the backbone features for the image batch ``x``."""
        return self.resnet50(x)

class CombinedModel(nn.Module):
    """Two-view classifier built on four separate feature extractors.

    Each view (frontal, lateral) has its own DenseNet121 and ResNet50. The four
    feature vectors are concatenated and mapped by one linear layer followed by
    a sigmoid, so the outputs are per-label probabilities.

    Args:
        out_size: Number of output labels.
    """

    def __init__(self, out_size):
        super().__init__()
        # One DenseNet121 + ResNet50 pair per view
        self.densenet_frontal = DenseNet121()
        self.resnet_frontal = ResNet50()

        self.densenet_lateral = DenseNet121()
        self.resnet_lateral = ResNet50()

        # Assumes DenseNet121 yields 1024 features and ResNet50 yields 2048
        features_per_view = 1024 + 2048
        num_views = 2

        # Final classifier layer; the sigmoid makes each label an independent probability
        self.classifier = nn.Sequential(
            nn.Linear(num_views * features_per_view, out_size),
            nn.Sigmoid()
        )

    def forward(self, x_frontal, x_lateral):
        """Return per-label probabilities for a batch of frontal/lateral pairs."""
        # Order matters: frontal (DenseNet, ResNet) first, then lateral (DenseNet, ResNet)
        feature_blocks = [
            self.densenet_frontal(x_frontal),
            self.resnet_frontal(x_frontal),
            self.densenet_lateral(x_lateral),
            self.resnet_lateral(x_lateral),
        ]
        fused = torch.cat(feature_blocks, dim=1)
        return self.classifier(fused)

In [None]:
N_LABELS = 14

# Build the two-view network and move it to the selected device
model = CombinedModel(out_size=N_LABELS)
model = model.to(device)

# Define loss and optimizer (BCELoss pairs with the model's sigmoid outputs)
criterion = nn.BCELoss()
criterion = criterion.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=0.0005)

## Training

In [None]:
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch with the validation loss; ``early_stop``
    becomes True after ``patience`` consecutive calls without improvement.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, print the counter on each non-improving call.
        min_delta: Minimum decrease relative to the best loss that counts as
            an improvement. With the default of 0.0 an equal loss resets the
            counter.

    Attributes:
        counter: Consecutive non-improving calls so far.
        best_loss: Lowest loss seen, or None before the first finite loss.
        early_stop: True once the patience budget has been used up.
    """

    def __init__(self, patience=15, verbose=False, min_delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        # NaN is the only value that is not equal to itself. A NaN loss must
        # count as "no improvement": storing it as best_loss would make every
        # later comparison False and early stopping could never trigger.
        is_nan = val_loss != val_loss
        improved = (not is_nan) and (
            self.best_loss is None or val_loss <= self.best_loss - self.min_delta
        )

        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True
            
def _weighted_metrics(y_true, y_pred):
    """Return (f1, precision, recall, accuracy) for multi-label predictions, using weighted averaging."""
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)
    return f1, precision, recall, accuracy


def _run_epoch(model, loader, device, threshold, criterion=None, optimizer=None, desc=None):
    """Run one pass over ``loader``; trains when ``optimizer`` is given, else evaluates.

    Returns ``(summed_loss, labels, predictions)`` where ``summed_loss`` is the
    per-batch loss weighted by batch size (0.0 when ``criterion`` is None) and
    the two lists hold one numpy row per sample.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    # A progress bar is shown only when a description is supplied.
    batches = tqdm(loader, desc=desc, leave=True) if desc is not None else loader

    summed_loss = 0.0
    all_labels, all_predictions = [], []
    with torch.set_grad_enabled(training):
        for (frontal_images, lateral_images), labels in batches:
            frontal_images, lateral_images, labels = frontal_images.to(device), lateral_images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(frontal_images, lateral_images)

            if criterion is not None:
                loss = criterion(outputs, labels)
                if training:
                    loss.backward()
                    optimizer.step()
                # Weight by batch size so the last, smaller batch is not over-counted
                summed_loss += loss.item() * frontal_images.size(0)

            predictions = (outputs > threshold).float()
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return summed_loss, all_labels, all_predictions


def train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=50, threshold=0.5):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Each epoch runs a training pass and a validation pass, records loss and
    weighted F1/precision/recall plus accuracy, saves the weights to
    ``best_model.pth`` whenever validation loss improves, steps a
    ``ReduceLROnPlateau`` scheduler on the validation loss, and stops early
    after 15 epochs without improvement. Loss curves are plotted every third
    epoch. After training, a classification report on the validation set is
    printed and the per-epoch metrics are written to ``epoch_metrics.csv``.

    Args:
        model: Network called as ``model(frontal_images, lateral_images)``.
        train_loader: Loader yielding ``((frontal, lateral), labels)`` batches.
        valid_loader: Loader of the same form used for validation.
        criterion: Loss applied to ``(outputs, labels)``.
        optimizer: Optimizer over the model parameters.
        num_epochs: Maximum number of epochs.
        threshold: Outputs above this value count as positive predictions.

    Returns:
        tuple: Ten per-epoch lists in the order train_losses, valid_losses,
        train_f1s, valid_f1s, train_precisions, valid_precisions,
        train_recalls, valid_recalls, train_accuracies, valid_accuracies.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_losses, valid_losses = [], []
    train_f1s, valid_f1s = [], []
    train_precisions, valid_precisions = [], []
    train_recalls, valid_recalls = [], []
    train_accuracies, valid_accuracies = [], []

    epoch_data = []

    early_stopping = EarlyStopping(patience=15, verbose=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        threshold=0.0001,
        threshold_mode='abs'
    )

    best_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        train_total, train_labels, train_predictions = _run_epoch(
            model, train_loader, device, threshold,
            criterion=criterion, optimizer=optimizer,
            desc=f'Epoch {epoch + 1}/{num_epochs} [Train]',
        )
        epoch_loss = train_total / len(train_loader.dataset)
        train_losses.append(epoch_loss)

        train_f1, train_precision, train_recall, train_accuracy = _weighted_metrics(train_labels, train_predictions)
        train_f1s.append(train_f1)
        train_precisions.append(train_precision)
        train_recalls.append(train_recall)
        train_accuracies.append(train_accuracy)

        # Validation phase
        valid_total, valid_labels, valid_predictions = _run_epoch(
            model, valid_loader, device, threshold,
            criterion=criterion,
            desc=f'Epoch {epoch + 1}/{num_epochs} [Valid]',
        )
        valid_epoch_loss = valid_total / len(valid_loader.dataset)
        valid_losses.append(valid_epoch_loss)

        valid_f1, valid_precision, valid_recall, valid_accuracy = _weighted_metrics(valid_labels, valid_predictions)
        valid_f1s.append(valid_f1)
        valid_precisions.append(valid_precision)
        valid_recalls.append(valid_recall)
        valid_accuracies.append(valid_accuracy)

        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Train - Loss: {epoch_loss:.4f}, F1: {train_f1:.4f}, Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, Accuracy: {train_accuracy:.4f}')
        print(f'Valid - Loss: {valid_epoch_loss:.4f}, F1: {valid_f1:.4f}, Precision: {valid_precision:.4f}, Recall: {valid_recall:.4f}, Accuracy: {valid_accuracy:.4f}')

        epoch_data.append([epoch + 1, epoch_loss, valid_epoch_loss, train_f1, valid_f1, train_precision, valid_precision, train_recall, valid_recall, train_accuracy, valid_accuracy])

        # Checkpoint whenever validation loss reaches a new minimum
        if valid_epoch_loss < best_loss:
            best_loss = valid_epoch_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print("Model saved!")

        early_stopping(valid_epoch_loss)
        scheduler.step(valid_epoch_loss)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        if (epoch + 1) % 3 == 0:
            plot_metrics(train_losses, valid_losses, train_f1s, valid_f1s, train_precisions, valid_precisions,
                         train_recalls, valid_recalls, train_accuracies, valid_accuracies)

    # Final evaluation
    # NOTE(review): this scores the weights of the last epoch run, not the best
    # checkpoint written to best_model.pth — confirm that is intended.
    _, final_labels, final_predictions = _run_epoch(model, valid_loader, device, threshold)

    print("\nClassification Report:")
    print(classification_report(final_labels, final_predictions, zero_division=0))

    metrics_df = pd.DataFrame(epoch_data, columns=[
        'Epoch', 'Train Loss', 'Valid Loss', 'Train F1', 'Valid F1',
        'Train Precision', 'Valid Precision', 'Train Recall', 'Valid Recall',
        'Train Accuracy', 'Valid Accuracy'
    ])
    metrics_df.to_csv('epoch_metrics.csv', index=False)
    print("Epoch data saved to epoch_metrics.csv")

    return train_losses, valid_losses, train_f1s, valid_f1s, train_precisions, valid_precisions, train_recalls, valid_recalls, train_accuracies, valid_accuracies

def plot_metrics(train_losses, valid_losses, train_f1s, valid_f1s, train_precisions, valid_precisions, 
                 train_recalls, valid_recalls, train_accuracies, valid_accuracies):
    """Plot training and validation loss per epoch.

    Only the two loss series are drawn. The F1, precision, recall and accuracy
    series are accepted so the call signature mirrors the training history,
    but they are not plotted.

    Args:
        train_losses: Training loss per epoch.
        valid_losses: Validation loss per epoch.
        train_f1s, valid_f1s, train_precisions, valid_precisions,
        train_recalls, valid_recalls, train_accuracies, valid_accuracies:
            Accepted but unused.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(valid_losses, label='Validation Loss', color='orange')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    # Lay out before showing: calling tight_layout()/show() after the figure
    # has been shown acts on a new, empty figure.
    plt.tight_layout()
    plt.show()

# Example usage
# Pass the configured `num_epochs` explicitly; without it train_model silently
# falls back to its own default of 50.
(
    train_losses, valid_losses,
    train_f1s, valid_f1s,
    train_precisions, valid_precisions,
    train_recalls, valid_recalls,
    train_accuracies, valid_accuracies,
) = train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=num_epochs)