In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
import cv2
from tqdm import tqdm
import random
from scipy.ndimage import gaussian_filter
import pydicom
import numpy as np
from skimage import exposure
from skimage.transform import resize


In [2]:
class NSCLCPatchDataset(Dataset):
    def __init__(self, root_dir, patient_ids, magnification='20X', transform=None, patch_size=500, 
                 sub_patch_size=400, augment=True):
        self.root_dir = root_dir
        self.patient_ids = patient_ids
        self.magnification = magnification
        self.transform = transform
        self.patch_size = patch_size
        self.sub_patch_size = sub_patch_size
        self.augment = augment
        
        # For NSCLC dataset, we have two classes: ADC (0) and SCC (1)
        # We'll need to determine this from the metadata or a separate file
        self.class_mapping = {
            'ADC': 0,
            'SCC': 1
        }
        
        # Load patient data and labels
        self.patches = []
        self.labels = []
        self.image_indices = []  # To track which image each patch belongs to
        
        self._load_patches()
    
    def _load_patches(self):
        print(f"Loading patches for {len(self.patient_ids)} patients...")
        
        for i, patient_id in enumerate(tqdm(self.patient_ids)):
            patient_dir = os.path.join(self.root_dir, patient_id)
            
            # Find the study folder (there's usually one per patient)
            study_folders = [f for f in os.listdir(patient_dir) if os.path.isdir(os.path.join(patient_dir, f))]
            
            if not study_folders:
                continue
                
            study_dir = os.path.join(patient_dir, study_folders[0])
            
            # Find the folder with CT images (usually the one with many .dcm files)
            ct_folders = [f for f in os.listdir(study_dir) if os.path.isdir(os.path.join(study_dir, f))]
            
            if not ct_folders:
                continue
                
            # Typically the first folder contains the CT images
            ct_dir = os.path.join(study_dir, ct_folders[0])
            
            # Get all DICOM files
            dcm_files = [f for f in os.listdir(ct_dir) if f.endswith('.dcm')]
            
            if not dcm_files:
                continue
                
            # Load DICOM files and extract patches
            # For simplicity, we'll use a subset of slices
            slice_step = max(1, len(dcm_files) // 20)  # Take about 20 slices per patient
            
            # Determine patient label (ADC or SCC)
            # This would typically come from clinical data
            # For now, we'll use a placeholder approach
            if 'ADC' in patient_id:
                label = self.class_mapping['ADC']
            elif 'SCC' in patient_id:
                label = self.class_mapping['SCC']
            else:
                # If label can't be determined, assign randomly for now
                # In a real implementation, you'd want to read this from metadata
                label = random.randint(0, 1)
            
            for j, dcm_file in enumerate(sorted(dcm_files)[::slice_step]):
                try:
                    # Load DICOM file
                    ds = pydicom.dcmread(os.path.join(ct_dir, dcm_file))
                    if 'PixelData' not in ds:
                        continue 
                    
                    # Convert to numpy array
                    img = ds.pixel_array
                    
                    # Handle different image dimensions
                    if len(img.shape) > 2:
                        # For 3D or 4D images, handle differently
                        if len(img.shape) == 3:
                            img = img[img.shape[0]//2]  # Take middle slice
                        # If it's RGB or has channels
                        elif len(img.shape) == 3 and img.shape[2] <= 4:
                            img = exposure.rescale_intensity(img, out_range=(0, 255)).astype(np.uint8)
                        else:
                            # Skip images with unexpected dimensions
                            continue
                    else:
                        # For 2D grayscale images
                        img = exposure.rescale_intensity(img, out_range=(0, 255)).astype(np.uint8)
                    
                    # Now check if image is large enough for patch extraction
                    if img.shape[0] < self.patch_size or img.shape[1] < self.patch_size:
                        # Resize if needed - make sure to specify correct output shape
                        new_size = (self.patch_size, self.patch_size, 3)  # Include channel dimension
                        img = resize(img, new_size, preserve_range=True).astype(np.uint8)
                    
                    # Extract patches from this slice
                    num_patches_per_slice = 5
                    
                    for k in range(num_patches_per_slice):
                        # Check if image is large enough for patch extraction
                        if img.shape[0] < self.patch_size or img.shape[1] < self.patch_size:
                            # Resize if needed
                            scale = max(self.patch_size / img.shape[0], self.patch_size / img.shape[1])
                            new_size = (int(img.shape[0] * scale), int(img.shape[1] * scale), 3)  # Include channel dimension
                            img = resize(img, new_size, preserve_range=True).astype(np.uint8)
                        
                        # Random crop to extract patch
                        if img.shape[0] > self.patch_size and img.shape[1] > self.patch_size:
                            x = random.randint(0, img.shape[1] - self.patch_size)
                            y = random.randint(0, img.shape[0] - self.patch_size)
                            patch = img[y:y+self.patch_size, x:x+self.patch_size]
                        else:
                            # If image is too small, pad it
                            patch = np.zeros((self.patch_size, self.patch_size, 3), dtype=np.uint8)
                            patch[:min(img.shape[0], self.patch_size), :min(img.shape[1], self.patch_size)] = img[:min(img.shape[0], self.patch_size), :min(img.shape[1], self.patch_size)]
                        
                        # Check if patch has enough tissue (simplified)
                        # In a real implementation, you'd want a more sophisticated method
                        if np.mean(patch) > 10:  # Simple threshold to check if patch has content
                            self.patches.append(patch)
                            self.labels.append(label)
                            self.image_indices.append(i)
                
                except Exception as e:
                    print(f"Error processing {dcm_file}: {e}")
    
    def __len__(self):
        return len(self.patches)
    
    def __getitem__(self, idx):
        patch = self.patches[idx]
        label = self.labels[idx]
        image_idx = self.image_indices[idx]
        
        # Convert to PIL Image for transformations
        patch = Image.fromarray(patch)
        
        # Ensure image is in RGB mode
        if patch.mode != 'RGB':
            patch = patch.convert('RGB')
        
        if self.augment:
            # 1. Random crop (sub-patch selection)
            i, j = random.randint(0, self.patch_size - self.sub_patch_size), random.randint(0, self.patch_size - self.sub_patch_size)
            patch = transforms.functional.crop(patch, i, j, self.sub_patch_size, self.sub_patch_size)
            
            # 2. Random rotation and mirroring
            if random.random() > 0.5:
                patch = transforms.functional.hflip(patch)
            rotation_angle = random.choice([0, 90, 180, 270])
            patch = transforms.functional.rotate(patch, rotation_angle)
        
        if self.transform:
            patch = self.transform(patch)
            
        return patch, label, image_idx



In [3]:
class PatchCNN(nn.Module):
    def __init__(self, num_classes=2):
        super(PatchCNN, self).__init__()
        
        # First convolutional block
        self.conv1 = nn.Conv2d(3, 96, kernel_size=9, stride=3, padding=0)
        self.bn1 = nn.BatchNorm2d(96)  # Using BatchNorm instead of LRN
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # Second convolutional block
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(256)  # Using BatchNorm instead of LRN
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # Third convolutional block
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU(inplace=True)
        
        # Fourth convolutional block
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU(inplace=True)
        
        # Fifth convolutional block
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.adaptive_pool = nn.AdaptiveAvgPool2d((15, 15))
        
        # Calculate the output size of the last convolutional layer
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, 400, 400)
            x = self.pool1(self.relu1(self.bn1(self.conv1(dummy_input))))
            x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
            x = self.relu3(self.conv3(x))
            x = self.relu4(self.conv4(x))
            x = self.pool5(self.relu5(self.conv5(x)))
            x = self.adaptive_pool(x)
            fc_input_size = x.view(1, -1).size(1)
            print(f"Feature map size after convolutions: {fc_input_size}")
        
        # Create classifier with dynamically calculated input size
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(fc_input_size, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )
        
    def forward(self, x):
        # Feature extraction
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.conv4(x))
        x = self.pool5(self.relu5(self.conv5(x)))

        x = self.adaptive_pool(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Classification
        x = self.classifier(x)
        
        return x


In [4]:
class EMBasedPatchSelection:
    def __init__(self, model_20x, model_5x, p1=30, p2=50):
        """
        Args:
            model_20x: CNN model trained on 20X magnification
            model_5x: CNN model trained on 5X magnification
            p1: P1-th percentile for image-level threshold (default: 30)
            p2: P2-th percentile for class-level threshold (default: 50)
        """
        self.model_20x = model_20x
        self.model_5x = model_5x
        self.p1 = p1
        self.p2 = p2
        
    def compute_patch_probabilities(self, dataloader):
        """
        Compute P(y_i|x_ij) for all patches by averaging predictions from two CNNs
        """
        self.model_20x.eval()
        self.model_5x.eval()
        
        patch_probs = []
        image_indices = []
        labels = []
        
        with torch.no_grad():
            for patches, batch_labels, batch_image_indices in dataloader:
                # Get predictions from both models
                patches = patches.cuda()
                outputs_20x = torch.softmax(self.model_20x(patches), dim=1)
                outputs_5x = torch.softmax(self.model_5x(patches), dim=1)
                
                # Average predictions
                avg_outputs = (outputs_20x + outputs_5x) / 2
                
                # Extract probability for the true class
                batch_probs = torch.gather(avg_outputs, 1, batch_labels.unsqueeze(1).cuda()).cpu().numpy()
                
                patch_probs.extend(batch_probs)
                image_indices.extend(batch_image_indices.numpy())
                labels.extend(batch_labels.numpy())
        
        return np.array(patch_probs).flatten(), np.array(image_indices), np.array(labels)
    
    def apply_gaussian_smoothing(self, patch_probs, image_indices, sigma=1.0):
        """
        Apply Gaussian smoothing to probability maps to compute P(H_ij|X)
        """
        # Group probabilities by image
        unique_images = np.unique(image_indices)
        smoothed_probs = np.copy(patch_probs)
        
        # In a real implementation, you would need spatial information about patches
        # For simplicity, we'll just apply Gaussian smoothing to the 1D array of probabilities per image
        for img_idx in unique_images:
            img_mask = (image_indices == img_idx)
            img_probs = patch_probs[img_mask]
            
            # Apply Gaussian smoothing
            smoothed_img_probs = gaussian_filter(img_probs, sigma=sigma)
            
            # Update the smoothed probabilities
            smoothed_probs[img_mask] = smoothed_img_probs
        
        return smoothed_probs
    
    def select_discriminative_patches(self, smoothed_probs, image_indices, labels):
        """
        Select discriminative patches based on thresholds
        """
        unique_images = np.unique(image_indices)
        unique_classes = np.unique(labels)
        
        # Initialize mask for discriminative patches
        is_discriminative = np.zeros_like(smoothed_probs, dtype=bool)
        
        # Compute thresholds and select discriminative patches
        for img_idx in unique_images:
            img_mask = (image_indices == img_idx)
            img_probs = smoothed_probs[img_mask]
            
            # Image-level threshold (Hi): P1-th percentile of Si
            img_threshold = np.percentile(img_probs, self.p1)
            
            # Get class of this image (assuming all patches from same image have same label)
            img_class = labels[img_mask][0]
            
            # Class-level threshold (Ri): P2-th percentile of Ec
            class_mask = (labels == img_class)
            class_probs = smoothed_probs[class_mask]
            class_threshold = np.percentile(class_probs, self.p2)
            
            # Final threshold (Tij): min(Hi, Ri)
            threshold = min(img_threshold, class_threshold)
            
            # Select discriminative patches
            is_discriminative[img_mask] = (img_probs >= threshold)
        
        return is_discriminative


In [5]:
class DecisionFusionModel:
    def __init__(self, classifier_type='logistic_regression'):
        """
        Args:
            classifier_type: Type of classifier to use ('logistic_regression' or 'svm')
        """
        self.classifier_type = classifier_type
        
        if classifier_type == 'logistic_regression':
            self.classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
        elif classifier_type == 'svm':
            self.classifier = SVC(kernel='rbf', probability=True)
        else:
            raise ValueError("classifier_type must be 'logistic_regression' or 'svm'")
    
    def create_histograms(self, models, dataloader, device):
        """
        Create histograms of patch-level predictions for each image
        
        Args:
            models: List of CNN models
            dataloader: DataLoader for patches
            device: Device to run inference on
        
        Returns:
            histograms: Histograms of patch-level predictions for each image
            image_labels: Labels for each image
        """
        # Set all models to evaluation mode
        for model in models:
            model.eval()
        
        # Get number of classes from the first model
        num_classes = models[0].classifier[-1].out_features
        
        # Initialize dictionaries to store predictions and counts
        image_predictions = {}
        image_labels = {}
        
        with torch.no_grad():
            for patches, batch_labels, batch_image_indices in dataloader:
                patches = patches.to(device)
                
                # Get predictions from all models and concatenate
                all_model_probs = []
                
                for model in models:
                    outputs = model(patches)
                    probs = torch.softmax(outputs, dim=1).cpu().numpy()
                    all_model_probs.append(probs)
                
                # Concatenate probabilities from all models
                combined_probs = np.concatenate(all_model_probs, axis=1)
                
                # Update image predictions and labels
                for i, img_idx in enumerate(batch_image_indices.numpy()):
                    if img_idx not in image_predictions:
                        image_predictions[img_idx] = np.zeros(combined_probs.shape[1])
                        image_labels[img_idx] = batch_labels[i].item()
                    
                    # Sum up probabilities for histogram
                    image_predictions[img_idx] += combined_probs[i]
        
        # Convert dictionaries to arrays
        unique_images = sorted(image_predictions.keys())
        histograms = np.array([image_predictions[img_idx] for img_idx in unique_images])
        labels = np.array([image_labels[img_idx] for img_idx in unique_images])
        
        return histograms, labels


    
    def fit(self, histograms, labels):
        """
        Fit the decision fusion model
        """
        self.classifier.fit(histograms, labels)
    
    def predict(self, histograms):
        """
        Predict image-level labels
        """
        return self.classifier.predict(histograms)
    
    def predict_proba(self, histograms):
        """
        Predict image-level class probabilities
        """
        return self.classifier.predict_proba(histograms)


In [6]:
def train_patch_cnn(model, train_loader, val_loader, num_epochs=10):
    """
    Train a patch-level CNN model
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    
    model = model.cuda()
    best_val_acc = 0.0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for patches, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            patches, labels = patches.cuda(), labels.cuda()
            
            optimizer.zero_grad()
            
            outputs = model(patches)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for patches, labels, _ in val_loader:
                patches, labels = patches.cuda(), labels.cuda()
                
                outputs = model(patches)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        
        val_loss = val_loss / len(val_loader)
        val_acc = 100. * val_correct / val_total
        
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_patch_cnn.pth')
        
        scheduler.step()
    
    return model

def train_em_based_model(dataset, num_iterations=3):
    """
    Train the full EM-based model with discriminative patch selection
    """
    # Split dataset into train and validation
    train_indices, val_indices = train_test_split(range(len(dataset)), test_size=0.2, stratify=dataset.labels)
    
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    # Initialize models for 20X and 5X magnifications
    model_20x = PatchCNN(num_classes=2)  # Assuming binary classification (ADC vs SCC)
    model_5x = PatchCNN(num_classes=2)
    
    # Initial training with all patches
    print("Initial training with all patches...")
    model_20x = train_patch_cnn(model_20x, train_loader, val_loader, num_epochs=5)
    model_5x = train_patch_cnn(model_5x, train_loader, val_loader, num_epochs=5)
    
    # Initialize EM-based patch selection
    em_selector = EMBasedPatchSelection(model_20x, model_5x)
    
    # EM iterations
    for iteration in range(num_iterations):
        print(f"\nEM Iteration {iteration+1}/{num_iterations}")
        
        # E-step: Select discriminative patches
        print("E-step: Selecting discriminative patches...")
        patch_probs, image_indices, labels = em_selector.compute_patch_probabilities(val_loader)
        smoothed_probs = em_selector.apply_gaussian_smoothing(patch_probs, image_indices)
        is_discriminative = em_selector.select_discriminative_patches(smoothed_probs, image_indices, labels)
        
        # Create a new dataset with only discriminative patches
        discriminative_indices = [idx for idx, is_disc in enumerate(is_discriminative) if is_disc]
        discriminative_train_dataset = torch.utils.data.Subset(train_dataset, discriminative_indices)
        discriminative_train_loader = DataLoader(discriminative_train_dataset, batch_size=32, shuffle=True, num_workers=4)
        
        # M-step: Retrain models with discriminative patches
        print("M-step: Retraining models with discriminative patches...")
        model_20x = train_patch_cnn(model_20x, discriminative_train_loader, val_loader, num_epochs=5)
        model_5x = train_patch_cnn(model_5x, discriminative_train_loader, val_loader, num_epochs=5)
    
    # Train additional models with different numbers of iterations
    model_20x_early = PatchCNN(num_classes=2)
    model_5x_early = PatchCNN(num_classes=2)
    
    model_20x_early.load_state_dict(torch.load('best_patch_cnn.pth'))
    model_5x_early.load_state_dict(torch.load('best_patch_cnn.pth'))
    
    # Return all models for decision fusion
    return [model_20x, model_5x, model_20x_early, model_5x_early]

def train_decision_fusion(models, dataset, device, classifier_type='logistic_regression'):
    """
    Train the image-level decision fusion model
    """
    # Split dataset into train and validation
    train_indices, val_indices = train_test_split(range(len(dataset)), test_size=0.2, stratify=dataset.labels)
    
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    # Initialize decision fusion model
    fusion_model = DecisionFusionModel(classifier_type=classifier_type)
    
    # Make sure all models are on the same device
    for model in models:
        model.to(device)
    
    # Create histograms for training
    print("Creating histograms for training...")
    train_histograms, train_labels = fusion_model.create_histograms(models, train_loader, device)
    
    # Fit decision fusion model
    print("Fitting decision fusion model...")
    fusion_model.fit(train_histograms, train_labels)
    
    # Evaluate on validation set
    print("Evaluating on validation set...")
    val_histograms, val_labels = fusion_model.create_histograms(models, val_loader, device)
    val_predictions = fusion_model.predict(val_histograms)
    
    # Calculate accuracy
    accuracy = np.mean(val_predictions == val_labels)
    print(f"Validation accuracy: {accuracy:.4f}")
    
    return fusion_model



In [7]:
def main():
    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    
    # Dataset path
    dataset_path = '/kaggle/input/nsclc-radiomics/NSCLC-Radiomics'
    
    # Get list of patient IDs
    patient_ids = [f for f in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, f))]
    
    # Split patients into train and test sets (80% train, 20% test)
    train_patients, test_patients = train_test_split(patient_ids, test_size=0.2, random_state=42)
    
    # Define transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create datasets for 20X and 5X magnifications
    dataset_20x = NSCLCPatchDataset(
        root_dir=dataset_path,
        patient_ids=train_patients,
        magnification='20X',
        transform=transform,
        patch_size=500,
        sub_patch_size=400,
        augment=True
    )
    
    dataset_5x = NSCLCPatchDataset(
        root_dir=dataset_path,
        patient_ids=train_patients,
        magnification='5X',
        transform=transform,
        patch_size=500,
        sub_patch_size=400,
        augment=True
    )
    
    # Train EM-based models
    print("Training EM-based models for 20X magnification...")
    models_20x = train_em_based_model(dataset_20x, num_iterations=3)
    
    print("Training EM-based models for 5X magnification...")
    models_5x = train_em_based_model(dataset_5x, num_iterations=3)
    
    # Combine all models
    all_models = models_20x + models_5x
    
    # Train decision fusion model
    print("Training decision fusion model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fusion_model = train_decision_fusion(all_models, dataset_20x, device, classifier_type='logistic_regression')
    
    # Create test dataset
    test_dataset = NSCLCPatchDataset(
        root_dir=dataset_path,
        patient_ids=test_patients,
        magnification='20X',
        transform=transform,
        patch_size=500,
        sub_patch_size=400,
        augment=False
    )
    
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    # Evaluate on test set
    print("Evaluating on test set...")
    test_histograms, test_labels = fusion_model.create_histograms(all_models, test_loader, device)
    test_predictions = fusion_model.predict(test_histograms)
    
    # Calculate accuracy
    accuracy = np.mean(test_predictions == test_labels)
    print(f"Test accuracy: {accuracy:.4f}")

if __name__ == "__main__":
    main()


Loading patches for 337 patients...


100%|██████████| 337/337 [01:50<00:00,  3.04it/s]


Loading patches for 337 patients...


100%|██████████| 337/337 [01:01<00:00,  5.51it/s]


Training EM-based models for 20X magnification...
Feature map size after convolutions: 57600
Feature map size after convolutions: 57600
Initial training with all patches...


Epoch 1/5: 100%|██████████| 292/292 [01:11<00:00,  4.09it/s]


Epoch 1/5, Train Loss: 0.6817, Train Acc: 57.91%, Val Loss: 0.6797, Val Acc: 57.91%


Epoch 2/5: 100%|██████████| 292/292 [01:13<00:00,  3.95it/s]


Epoch 2/5, Train Loss: 0.6791, Train Acc: 57.91%, Val Loss: 0.6787, Val Acc: 57.91%


Epoch 3/5: 100%|██████████| 292/292 [01:14<00:00,  3.92it/s]


Epoch 3/5, Train Loss: 0.6780, Train Acc: 57.91%, Val Loss: 0.6774, Val Acc: 57.91%


Epoch 4/5: 100%|██████████| 292/292 [01:14<00:00,  3.91it/s]


Epoch 4/5, Train Loss: 0.6750, Train Acc: 57.90%, Val Loss: 0.6772, Val Acc: 57.91%


Epoch 5/5: 100%|██████████| 292/292 [01:14<00:00,  3.92it/s]


Epoch 5/5, Train Loss: 0.6716, Train Acc: 57.97%, Val Loss: 0.6696, Val Acc: 57.31%


Epoch 1/5: 100%|██████████| 292/292 [01:14<00:00,  3.90it/s]


Epoch 1/5, Train Loss: 0.6818, Train Acc: 57.84%, Val Loss: 0.6806, Val Acc: 57.91%


Epoch 2/5: 100%|██████████| 292/292 [01:14<00:00,  3.90it/s]


Epoch 2/5, Train Loss: 0.6807, Train Acc: 57.91%, Val Loss: 0.6788, Val Acc: 57.91%


Epoch 3/5: 100%|██████████| 292/292 [01:14<00:00,  3.91it/s]


Epoch 3/5, Train Loss: 0.6779, Train Acc: 57.92%, Val Loss: 0.6795, Val Acc: 57.91%


Epoch 4/5: 100%|██████████| 292/292 [01:14<00:00,  3.91it/s]


Epoch 4/5, Train Loss: 0.6761, Train Acc: 57.80%, Val Loss: 0.6827, Val Acc: 56.88%


Epoch 5/5: 100%|██████████| 292/292 [01:14<00:00,  3.92it/s]


Epoch 5/5, Train Loss: 0.6739, Train Acc: 57.89%, Val Loss: 0.6749, Val Acc: 57.22%

EM Iteration 1/3
E-step: Selecting discriminative patches...
M-step: Retraining models with discriminative patches...


Epoch 1/5: 100%|██████████| 58/58 [00:15<00:00,  3.80it/s]


Epoch 1/5, Train Loss: 0.6654, Train Acc: 60.57%, Val Loss: 0.6746, Val Acc: 57.91%


Epoch 2/5: 100%|██████████| 58/58 [00:15<00:00,  3.84it/s]


Epoch 2/5, Train Loss: 0.6615, Train Acc: 60.84%, Val Loss: 0.6746, Val Acc: 57.91%


Epoch 3/5: 100%|██████████| 58/58 [00:15<00:00,  3.82it/s]


Epoch 3/5, Train Loss: 0.6598, Train Acc: 60.90%, Val Loss: 0.6745, Val Acc: 57.91%


Epoch 4/5: 100%|██████████| 58/58 [00:15<00:00,  3.83it/s]


Epoch 4/5, Train Loss: 0.6548, Train Acc: 60.79%, Val Loss: 0.6746, Val Acc: 57.91%


Epoch 5/5: 100%|██████████| 58/58 [00:15<00:00,  3.86it/s]


Epoch 5/5, Train Loss: 0.6621, Train Acc: 60.57%, Val Loss: 0.6720, Val Acc: 57.95%


Epoch 1/5: 100%|██████████| 58/58 [00:15<00:00,  3.82it/s]


Epoch 1/5, Train Loss: 0.6625, Train Acc: 61.00%, Val Loss: 0.6753, Val Acc: 57.91%


Epoch 2/5: 100%|██████████| 58/58 [00:15<00:00,  3.82it/s]


Epoch 2/5, Train Loss: 0.6634, Train Acc: 60.90%, Val Loss: 0.6741, Val Acc: 57.91%


Epoch 3/5: 100%|██████████| 58/58 [00:15<00:00,  3.80it/s]


Epoch 3/5, Train Loss: 0.6616, Train Acc: 60.90%, Val Loss: 0.6729, Val Acc: 59.58%


Epoch 4/5: 100%|██████████| 58/58 [00:15<00:00,  3.85it/s]


Epoch 4/5, Train Loss: 0.6626, Train Acc: 61.28%, Val Loss: 0.6672, Val Acc: 58.59%


Epoch 5/5: 100%|██████████| 58/58 [00:15<00:00,  3.80it/s]


Epoch 5/5, Train Loss: 0.6575, Train Acc: 61.00%, Val Loss: 0.6753, Val Acc: 57.99%

EM Iteration 2/3
E-step: Selecting discriminative patches...
M-step: Retraining models with discriminative patches...


Epoch 1/5: 100%|██████████| 57/57 [00:14<00:00,  3.83it/s]


Epoch 1/5, Train Loss: 0.6644, Train Acc: 60.00%, Val Loss: 0.6704, Val Acc: 60.52%


Epoch 2/5: 100%|██████████| 57/57 [00:14<00:00,  3.86it/s]


Epoch 2/5, Train Loss: 0.6615, Train Acc: 59.50%, Val Loss: 0.6736, Val Acc: 58.04%


Epoch 3/5: 100%|██████████| 57/57 [00:15<00:00,  3.79it/s]


Epoch 3/5, Train Loss: 0.6551, Train Acc: 59.22%, Val Loss: 0.6778, Val Acc: 57.69%


Epoch 4/5: 100%|██████████| 57/57 [00:14<00:00,  3.85it/s]


Epoch 4/5, Train Loss: 0.6585, Train Acc: 60.33%, Val Loss: 0.6667, Val Acc: 58.29%


Epoch 5/5: 100%|██████████| 57/57 [00:14<00:00,  3.86it/s]


Epoch 5/5, Train Loss: 0.6600, Train Acc: 60.17%, Val Loss: 0.6698, Val Acc: 57.22%


Epoch 1/5: 100%|██████████| 57/57 [00:14<00:00,  3.83it/s]


Epoch 1/5, Train Loss: 0.6577, Train Acc: 60.22%, Val Loss: 0.6660, Val Acc: 58.81%


Epoch 2/5: 100%|██████████| 57/57 [00:14<00:00,  3.84it/s]


Epoch 2/5, Train Loss: 0.6609, Train Acc: 59.55%, Val Loss: 0.6663, Val Acc: 57.44%


Epoch 3/5: 100%|██████████| 57/57 [00:15<00:00,  3.79it/s]


Epoch 3/5, Train Loss: 0.6575, Train Acc: 60.72%, Val Loss: 0.6692, Val Acc: 58.34%


Epoch 4/5: 100%|██████████| 57/57 [00:14<00:00,  3.88it/s]


Epoch 4/5, Train Loss: 0.6577, Train Acc: 60.33%, Val Loss: 0.6644, Val Acc: 59.37%


Epoch 5/5: 100%|██████████| 57/57 [00:14<00:00,  3.86it/s]


Epoch 5/5, Train Loss: 0.6616, Train Acc: 59.89%, Val Loss: 0.6736, Val Acc: 57.22%

EM Iteration 3/3
E-step: Selecting discriminative patches...
M-step: Retraining models with discriminative patches...


Epoch 1/5: 100%|██████████| 58/58 [00:15<00:00,  3.85it/s]


Epoch 1/5, Train Loss: 0.6553, Train Acc: 59.57%, Val Loss: 0.6624, Val Acc: 61.85%


Epoch 2/5: 100%|██████████| 58/58 [00:15<00:00,  3.85it/s]


Epoch 2/5, Train Loss: 0.6579, Train Acc: 59.25%, Val Loss: 0.6658, Val Acc: 58.21%


Epoch 3/5: 100%|██████████| 58/58 [00:15<00:00,  3.81it/s]


Epoch 3/5, Train Loss: 0.6599, Train Acc: 60.28%, Val Loss: 0.6652, Val Acc: 59.11%


Epoch 4/5: 100%|██████████| 58/58 [00:15<00:00,  3.85it/s]


Epoch 4/5, Train Loss: 0.6599, Train Acc: 59.30%, Val Loss: 0.6668, Val Acc: 57.91%


Epoch 5/5: 100%|██████████| 58/58 [00:14<00:00,  3.88it/s]


Epoch 5/5, Train Loss: 0.6553, Train Acc: 60.56%, Val Loss: 0.6694, Val Acc: 58.29%


Epoch 1/5: 100%|██████████| 58/58 [00:15<00:00,  3.84it/s]


Epoch 1/5, Train Loss: 0.6558, Train Acc: 59.79%, Val Loss: 0.6688, Val Acc: 58.85%


Epoch 2/5: 100%|██████████| 58/58 [00:15<00:00,  3.83it/s]


Epoch 2/5, Train Loss: 0.6577, Train Acc: 59.52%, Val Loss: 0.6743, Val Acc: 57.87%


Epoch 3/5: 100%|██████████| 58/58 [00:15<00:00,  3.84it/s]


Epoch 3/5, Train Loss: 0.6616, Train Acc: 60.50%, Val Loss: 0.6907, Val Acc: 56.54%


Epoch 4/5: 100%|██████████| 58/58 [00:15<00:00,  3.86it/s]


Epoch 4/5, Train Loss: 0.6504, Train Acc: 61.76%, Val Loss: 0.6607, Val Acc: 60.35%


Epoch 5/5: 100%|██████████| 58/58 [00:15<00:00,  3.85it/s]


Epoch 5/5, Train Loss: 0.6600, Train Acc: 60.01%, Val Loss: 0.6713, Val Acc: 57.69%
Feature map size after convolutions: 57600
Feature map size after convolutions: 57600


  model_20x_early.load_state_dict(torch.load('best_patch_cnn.pth'))
  model_5x_early.load_state_dict(torch.load('best_patch_cnn.pth'))


Training EM-based models for 5X magnification...
Feature map size after convolutions: 57600
Feature map size after convolutions: 57600
Initial training with all patches...


Epoch 1/5: 100%|██████████| 292/292 [01:15<00:00,  3.87it/s]


Epoch 1/5, Train Loss: 0.6844, Train Acc: 55.87%, Val Loss: 0.6747, Val Acc: 58.81%


Epoch 2/5: 100%|██████████| 292/292 [01:15<00:00,  3.88it/s]


Epoch 2/5, Train Loss: 0.6811, Train Acc: 57.36%, Val Loss: 0.6803, Val Acc: 56.32%


Epoch 3/5: 100%|██████████| 292/292 [01:14<00:00,  3.91it/s]


Epoch 3/5, Train Loss: 0.6753, Train Acc: 58.81%, Val Loss: 0.6696, Val Acc: 60.09%


Epoch 4/5: 100%|██████████| 292/292 [01:14<00:00,  3.90it/s]


Epoch 4/5, Train Loss: 0.6733, Train Acc: 58.82%, Val Loss: 0.6717, Val Acc: 59.41%


Epoch 5/5: 100%|██████████| 292/292 [01:14<00:00,  3.91it/s]


Epoch 5/5, Train Loss: 0.6729, Train Acc: 59.23%, Val Loss: 0.6772, Val Acc: 57.82%


Epoch 1/5: 100%|██████████| 292/292 [01:15<00:00,  3.88it/s]


Epoch 1/5, Train Loss: 0.6861, Train Acc: 55.56%, Val Loss: 0.6797, Val Acc: 58.81%


Epoch 2/5: 100%|██████████| 292/292 [01:14<00:00,  3.90it/s]


Epoch 2/5, Train Loss: 0.6807, Train Acc: 57.61%, Val Loss: 0.6707, Val Acc: 59.97%


Epoch 3/5: 100%|██████████| 292/292 [01:14<00:00,  3.90it/s]


Epoch 3/5, Train Loss: 0.6784, Train Acc: 58.12%, Val Loss: 0.6703, Val Acc: 59.62%


Epoch 4/5: 100%|██████████| 292/292 [01:14<00:00,  3.90it/s]


Epoch 4/5, Train Loss: 0.6733, Train Acc: 58.90%, Val Loss: 0.6729, Val Acc: 59.62%


Epoch 5/5: 100%|██████████| 292/292 [01:14<00:00,  3.91it/s]


Epoch 5/5, Train Loss: 0.6726, Train Acc: 59.16%, Val Loss: 0.6649, Val Acc: 60.01%

EM Iteration 1/3
E-step: Selecting discriminative patches...
M-step: Retraining models with discriminative patches...


Epoch 1/5: 100%|██████████| 56/56 [00:14<00:00,  3.77it/s]


Epoch 1/5, Train Loss: 0.6645, Train Acc: 60.14%, Val Loss: 0.6683, Val Acc: 59.62%


Epoch 2/5: 100%|██████████| 56/56 [00:14<00:00,  3.79it/s]


Epoch 2/5, Train Loss: 0.6659, Train Acc: 60.76%, Val Loss: 0.6619, Val Acc: 60.65%


Epoch 3/5: 100%|██████████| 56/56 [00:14<00:00,  3.79it/s]


Epoch 3/5, Train Loss: 0.6640, Train Acc: 60.53%, Val Loss: 0.6661, Val Acc: 60.27%


Epoch 4/5: 100%|██████████| 56/56 [00:14<00:00,  3.77it/s]


Epoch 4/5, Train Loss: 0.6601, Train Acc: 61.38%, Val Loss: 0.6764, Val Acc: 59.37%


Epoch 5/5: 100%|██████████| 56/56 [00:14<00:00,  3.82it/s]


Epoch 5/5, Train Loss: 0.6581, Train Acc: 61.21%, Val Loss: 0.6616, Val Acc: 60.52%


Epoch 1/5: 100%|██████████| 56/56 [00:14<00:00,  3.83it/s]


Epoch 1/5, Train Loss: 0.6658, Train Acc: 60.93%, Val Loss: 0.6677, Val Acc: 60.14%


Epoch 2/5: 100%|██████████| 56/56 [00:14<00:00,  3.75it/s]


Epoch 2/5, Train Loss: 0.6667, Train Acc: 59.17%, Val Loss: 0.6792, Val Acc: 57.35%


Epoch 3/5: 100%|██████████| 56/56 [00:15<00:00,  3.73it/s]


Epoch 3/5, Train Loss: 0.6697, Train Acc: 60.25%, Val Loss: 0.6746, Val Acc: 58.59%


Epoch 4/5: 100%|██████████| 56/56 [00:14<00:00,  3.80it/s]


Epoch 4/5, Train Loss: 0.6594, Train Acc: 61.72%, Val Loss: 0.6657, Val Acc: 60.61%


Epoch 5/5: 100%|██████████| 56/56 [00:14<00:00,  3.78it/s]


Epoch 5/5, Train Loss: 0.6671, Train Acc: 60.42%, Val Loss: 0.6643, Val Acc: 60.61%

EM Iteration 2/3
E-step: Selecting discriminative patches...
M-step: Retraining models with discriminative patches...


Epoch 1/5: 100%|██████████| 56/56 [00:14<00:00,  3.77it/s]


Epoch 1/5, Train Loss: 0.6574, Train Acc: 60.90%, Val Loss: 0.6670, Val Acc: 58.98%


Epoch 2/5: 100%|██████████| 56/56 [00:14<00:00,  3.79it/s]


Epoch 2/5, Train Loss: 0.6680, Train Acc: 60.22%, Val Loss: 0.6873, Val Acc: 57.27%


Epoch 3/5: 100%|██████████| 56/56 [00:14<00:00,  3.76it/s]


Epoch 3/5, Train Loss: 0.6649, Train Acc: 59.61%, Val Loss: 0.6659, Val Acc: 60.14%


Epoch 4/5: 100%|██████████| 56/56 [00:14<00:00,  3.77it/s]


Epoch 4/5, Train Loss: 0.6605, Train Acc: 61.29%, Val Loss: 0.6685, Val Acc: 59.75%


Epoch 5/5: 100%|██████████| 56/56 [00:14<00:00,  3.77it/s]


Epoch 5/5, Train Loss: 0.6571, Train Acc: 61.12%, Val Loss: 0.6755, Val Acc: 59.28%


Epoch 1/5: 100%|██████████| 56/56 [00:14<00:00,  3.76it/s]


Epoch 1/5, Train Loss: 0.6613, Train Acc: 61.46%, Val Loss: 0.6599, Val Acc: 61.29%


Epoch 2/5: 100%|██████████| 56/56 [00:14<00:00,  3.78it/s]


Epoch 2/5, Train Loss: 0.6640, Train Acc: 61.24%, Val Loss: 0.6672, Val Acc: 59.67%


Epoch 3/5: 100%|██████████| 56/56 [00:14<00:00,  3.74it/s]


Epoch 3/5, Train Loss: 0.6636, Train Acc: 61.07%, Val Loss: 0.6609, Val Acc: 60.39%


Epoch 4/5: 100%|██████████| 56/56 [00:14<00:00,  3.77it/s]


Epoch 4/5, Train Loss: 0.6625, Train Acc: 60.79%, Val Loss: 0.6674, Val Acc: 59.97%


Epoch 5/5: 100%|██████████| 56/56 [00:14<00:00,  3.78it/s]


Epoch 5/5, Train Loss: 0.6578, Train Acc: 62.08%, Val Loss: 0.6550, Val Acc: 61.34%

EM Iteration 3/3
E-step: Selecting discriminative patches...
M-step: Retraining models with discriminative patches...


Epoch 1/5: 100%|██████████| 56/56 [00:14<00:00,  3.80it/s]


Epoch 1/5, Train Loss: 0.6633, Train Acc: 60.36%, Val Loss: 0.6680, Val Acc: 59.84%


Epoch 2/5: 100%|██████████| 56/56 [00:14<00:00,  3.82it/s]


Epoch 2/5, Train Loss: 0.6626, Train Acc: 60.48%, Val Loss: 0.6668, Val Acc: 59.11%


Epoch 3/5: 100%|██████████| 56/56 [00:14<00:00,  3.78it/s]


Epoch 3/5, Train Loss: 0.6607, Train Acc: 60.65%, Val Loss: 0.6707, Val Acc: 59.45%


Epoch 4/5: 100%|██████████| 56/56 [00:14<00:00,  3.80it/s]


Epoch 4/5, Train Loss: 0.6661, Train Acc: 60.48%, Val Loss: 0.6720, Val Acc: 58.89%


Epoch 5/5: 100%|██████████| 56/56 [00:14<00:00,  3.82it/s]


Epoch 5/5, Train Loss: 0.6607, Train Acc: 61.16%, Val Loss: 0.6668, Val Acc: 59.84%


Epoch 1/5: 100%|██████████| 56/56 [00:14<00:00,  3.79it/s]


Epoch 1/5, Train Loss: 0.6603, Train Acc: 60.99%, Val Loss: 0.6603, Val Acc: 60.27%


Epoch 2/5: 100%|██████████| 56/56 [00:14<00:00,  3.78it/s]


Epoch 2/5, Train Loss: 0.6675, Train Acc: 60.36%, Val Loss: 0.6767, Val Acc: 57.61%


Epoch 3/5: 100%|██████████| 56/56 [00:14<00:00,  3.80it/s]


Epoch 3/5, Train Loss: 0.6660, Train Acc: 59.97%, Val Loss: 0.6648, Val Acc: 60.27%


Epoch 4/5: 100%|██████████| 56/56 [00:14<00:00,  3.82it/s]


Epoch 4/5, Train Loss: 0.6659, Train Acc: 61.38%, Val Loss: 0.6665, Val Acc: 59.41%


Epoch 5/5: 100%|██████████| 56/56 [00:14<00:00,  3.81it/s]


Epoch 5/5, Train Loss: 0.6641, Train Acc: 59.85%, Val Loss: 0.6610, Val Acc: 61.17%
Feature map size after convolutions: 57600
Feature map size after convolutions: 57600


  model_20x_early.load_state_dict(torch.load('best_patch_cnn.pth'))
  model_5x_early.load_state_dict(torch.load('best_patch_cnn.pth'))


Training decision fusion model...
Creating histograms for training...
Fitting decision fusion model...
Evaluating on validation set...
Validation accuracy: 0.5825
Loading patches for 85 patients...


100%|██████████| 85/85 [00:29<00:00,  2.90it/s]

Evaluating on test set...





Test accuracy: 0.5172
