## Importing modules

In [1]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import resnet50
import numpy as np
from functools import partial

import skimage
from skimage.color import rgb2hed
import pywt
from PIL import Image
import cv2
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchstain

from histomicstk.preprocessing.color_normalization.deconvolution_based_normalization import deconvolution_based_normalization


In [2]:
class H5Dataset(Dataset):
    """Map-style dataset over an in-memory copy of an HDF5 image file and label file.

    Parameters:
        image_file (str): HDF5 file whose dataset ``'x'`` holds the images, one per
            index along the first axis.
        label_file (str): HDF5 file whose dataset ``'y'`` holds the labels; it is
            flattened to 1-D, so it must contain exactly one label per image.
        transform (callable, optional): Applied to each image in ``__getitem__``.

    Raises:
        ValueError: If the number of images and the number of labels differ.
    """

    def __init__(self, image_file, label_file, transform=None):
        self.transform = transform

        # `[:]` materialises each HDF5 dataset as a numpy array, so both files are
        # read fully into memory and closed immediately.
        with h5py.File(image_file, 'r') as f:
            self.images = f['x'][:]
        with h5py.File(label_file, 'r') as f:
            self.labels = f['y'][:].reshape(-1)

        # Mismatched files would otherwise pair images with the wrong labels (or
        # raise IndexError part-way through an epoch), so fail fast here.
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"{image_file} contains {len(self.images)} images but {label_file} "
                f"contains {len(self.labels)} labels"
            )

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, with ``transform`` applied to the image."""
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [3]:
class RGB2HED(torch.nn.Module):
    """Stain-separate an RGB image with ``rgb2hed`` and keep a single channel, replicated 3x.

    ``mode`` is stored on the instance but ``forward`` never reads it.
    """

    def __init__(self, mode=None):
        super().__init__()
        self.mode = mode

    def forward(self, img):
        """Return an (H, W, 3) float array holding channel index 1 of ``rgb2hed``'s output, scaled by 255.

        NOTE(review): scikit-image documents the rgb2hed channel order as H, E, D, so
        index 1 would be eosin rather than hematoxylin — confirm this is intended.
        """
        unit_rgb = img.astype(np.float32) / 255.
        stain_density = rgb2hed(unit_rgb) * 255.
        # 1:2 keeps the channel axis so the slice is (H, W, 1) before repetition.
        selected = stain_density[:, :, 1:2]
        return np.repeat(selected, 3, axis=2)
    
        
class WaveletTransform(nn.Module):
    """Wavelet-denoise an image's luminance and return it as a 3-channel grayscale image.

    Parameters:
        wavelet (str): Wavelet name passed to ``pywt.wavedec2`` / ``pywt.waverec2``.
        threshold (float): Soft-threshold applied to every detail coefficient.
        level (int): Decomposition depth; defaults to 2 (the previously hard-coded value).
    """

    def __init__(self, wavelet='haar', threshold=20, level=2):
        super(WaveletTransform, self).__init__()
        self.wavelet = wavelet
        self.threshold = threshold
        self.level = level

    def forward(self, img):
        """Denoise ``img`` (channels-last RGB, cast to uint8 first).

        Returns:
            np.ndarray: uint8 array of shape (H, W, 3) — the denoised luminance repeated
            on all three channels, with the same H and W as the input.
        """
        # Step 1: RGB -> luminance with the weights 0.299 / 0.587 / 0.114.
        grayscale_image = np.dot(img.astype(np.uint8), [0.299, 0.587, 0.114])
        height, width = grayscale_image.shape[:2]

        # Step 2: 2-D wavelet decomposition into approximation + per-level detail bands.
        coeffs = pywt.wavedec2(grayscale_image, wavelet=self.wavelet, level=self.level)
        approx, details = coeffs[0], coeffs[1:]

        # Step 3: soft-threshold only the detail bands; the approximation is kept intact.
        details_thresh = [
            [pywt.threshold(band, self.threshold, mode='soft') for band in level_bands]
            for level_bands in details
        ]

        # Step 4: reconstruct. waverec2 can return one extra row/column for odd-sized
        # inputs, so crop back to the original size.
        compressed_image = pywt.waverec2([approx] + details_thresh, wavelet=self.wavelet)
        compressed_image = compressed_image[:height, :width]
        compressed_image = np.clip(compressed_image, 0, 255).astype(np.uint8)
        return np.repeat(compressed_image[:, :, np.newaxis], 3, axis=2)
    

class CLAHE(nn.Module):
    """Contrast-limited adaptive histogram equalisation of an RGB image's lightness (OpenCV).

    Parameters:
        mode: Stored on the instance but not used by ``forward``.
        clip_limit (float): OpenCV ``clipLimit``; defaults to the previous hard-coded 2.0.
        tile_grid_size (tuple[int, int]): OpenCV ``tileGridSize``; defaults to (8, 8).
    """

    def __init__(self, mode=None, clip_limit=2.0, tile_grid_size=(8, 8)):
        super(CLAHE, self).__init__()
        self.mode = mode
        self.clip_limit = clip_limit
        self.tile_grid_size = tuple(tile_grid_size)

    def forward(self, image):
        """Equalise ``image`` (RGB, cast to uint8) and return the result converted back to RGB."""
        # Work in LAB so only the L channel is equalised; a and b pass through unchanged.
        lab_image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2LAB)
        l_channel, a, b = cv2.split(lab_image)

        # Built per call, as before; presumably cv2 objects are not picklable, which
        # would matter if this transform is sent to spawned DataLoader workers.
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
        l_channel = clahe.apply(l_channel)

        # Merge and convert back to RGB
        lab_image = cv2.merge((l_channel, a, b))
        return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)


class Opening(nn.Module):
    """Grey-scale morphological opening applied to each colour channel independently.

    Parameters:
        footprint (np.ndarray, optional): 2-D structuring element. Defaults to a 3x3
            cross (connectivity 1), scikit-image's default footprint for 2-D images.

    Why per-channel: given an (H, W, C) array and no footprint, ``skimage.morphology.opening``
    builds a cross over all three axes, so neighbouring colour channels bleed into each
    other. Giving the footprint a singleton channel axis keeps R, G and B separate.
    """

    def __init__(self, footprint=None):
        super(Opening, self).__init__()
        if footprint is None:
            footprint = np.array([[0, 1, 0],
                                  [1, 1, 1],
                                  [0, 1, 0]], dtype=np.uint8)
        self.footprint = np.asarray(footprint)

    def forward(self, image):
        """Return the opened image, same shape as ``image`` (2-D, or channels-last 3-D)."""
        footprint = self.footprint
        if image.ndim == 3 and footprint.ndim == 2:
            # Size-1 channel axis: the neighbourhood never spans colour channels.
            footprint = footprint[:, :, np.newaxis]
        return skimage.morphology.opening(image, footprint)
    
    
class Macenko(nn.Module):
    """Best-effort Macenko stain normalisation towards a fixed reference patch (torchstain).

    Parameters:
        reference_image (np.ndarray): Reference patch; cast to uint8 and used to fit
            the normalizer.
        target_W, alpha, beta: Stored on the instance for API compatibility; not used
            by ``forward``.
    """

    def __init__(self, reference_image, target_W=None, alpha=1, beta=0.01):
        super(Macenko, self).__init__()
        self.target_W = target_W
        self.alpha = alpha
        self.beta = beta
        self.reference_image = reference_image.astype(np.uint8)
        # Fitted normalizer, created on first successful use so the reference
        # statistics are computed once rather than on every image.
        self._normalizer = None

    def _fitted_normalizer(self):
        """Return the cached torchstain normalizer, fitting it to the reference on first use."""
        if self._normalizer is None:
            normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch')
            # NOTE(review): torchstain's README fits the torch backend on
            # ToTensor()(img) * 255, i.e. (C, H, W) tensors in [0, 255]; the reference
            # here is an (H, W, C) numpy array. If fit rejects it, every call takes the
            # fallback path in forward — confirm.
            normalizer.fit(self.reference_image)
            self._normalizer = normalizer
        return self._normalizer

    def forward(self, image):
        """
        Apply Macenko normalization to a single image, falling back to the input on failure.

        Parameters:
            image: The image to normalize, passed unchanged to torchstain's ``normalize``.

        Returns:
            The first element returned by ``normalize(I=image, stains=True)`` on success;
            otherwise ``image`` itself, after emitting a RuntimeWarning.
        """
        import warnings

        try:
            norm_img, _, _ = self._fitted_normalizer().normalize(I=image, stains=True)
            return norm_img
        except Exception as exc:  # also covers torch.linalg.LinAlgError on degenerate patches
            # Deliberately best-effort, but no longer silent: a normalizer that always
            # fails would otherwise look like a working no-op transform.
            warnings.warn(
                f"Macenko normalization failed ({type(exc).__name__}); "
                "returning the input image unchanged.",
                RuntimeWarning,
            )
            return image

    
class ReinhardNormalization(nn.Module):
    """Best-effort Reinhard colour normalisation towards a fixed reference patch (torchstain).

    Parameters:
        reference_image (np.ndarray): Reference patch; cast to uint8 and used to fit
            the normalizer.
    """

    def __init__(self, reference_image):
        super(ReinhardNormalization, self).__init__()
        self.reference_image = reference_image.astype(np.uint8)
        # Fitted normalizer, created on first successful use so the reference
        # statistics are computed once rather than on every image.
        self._normalizer = None

    def _fitted_normalizer(self):
        """Return the cached torchstain normalizer, fitting it to the reference on first use."""
        if self._normalizer is None:
            normalizer = torchstain.normalizers.ReinhardNormalizer()
            normalizer.fit(self.reference_image)
            self._normalizer = normalizer
        return self._normalizer

    def forward(self, image):
        """
        Apply Reinhard normalization to a single image, falling back to the input on failure.

        Parameters:
            image: The image to normalize, passed unchanged to torchstain's ``normalize``.

        Returns:
            The value returned by ``normalize(image)`` on success; otherwise ``image``
            itself, after emitting a RuntimeWarning.
        """
        import warnings

        try:
            return self._fitted_normalizer().normalize(image)
        except Exception as exc:
            # Deliberately best-effort, but no longer silent: a normalizer that always
            # fails would otherwise look like a working no-op transform.
            warnings.warn(
                f"Reinhard normalization failed ({type(exc).__name__}); "
                "returning the input image unchanged.",
                RuntimeWarning,
            )
            return image


# Per-channel RGB mean / std used to standardise every split (shared by train and val).
NORM_MEAN = [0.6716241, 0.48636872, 0.60884315]
NORM_STD = [0.27210504, 0.31001145, 0.2918652]
# Index of the training patch used as the stain-normalisation reference.
REFERENCE_INDEX = 176298

train_data = H5Dataset(
    image_file='../../../pcam/training_split.h5',
    label_file='../../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5',
)
reference_image = train_data.images[REFERENCE_INDEX]

# Optional preprocessing steps, currently disabled:
#   on the raw array, before ToPILImage: RGB2HED(), WaveletTransform(), CLAHE(), Opening()
#   after ToTensor: Macenko(reference_image=reference_image),
#                   ReinhardNormalization(reference_image=reference_image)
augmentations = [
    transforms.ColorJitter(brightness=.5, saturation=.25, hue=.1, contrast=.5),
    transforms.RandomAffine(10, (0.05, 0.05), fill=255),
    transforms.RandomHorizontalFlip(.5),
    transforms.RandomVerticalFlip(.5),
]
train_transform = transforms.Compose(
    [transforms.ToPILImage()]
    + augmentations
    + [transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD)]
)

# Validation / test: no augmentation, only tensor conversion and standardisation.
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

In [4]:
# Load datasets
train_dataset = H5Dataset(image_file='../../../pcam/training_split.h5',
                          label_file='../../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5',
                          transform=train_transform)
val_dataset = H5Dataset(image_file='../../../pcam/validation_split.h5',
                        label_file='../../../Labels/Labels/camelyonpatch_level_2_split_valid_y.h5',
                        transform=val_transform)

test_dataset = H5Dataset(image_file='../../../pcam/test_split.h5',
                         label_file='../../../Labels/Labels/camelyonpatch_level_2_split_test_y.h5',
                         transform=val_transform)

# Create dataloaders. Only the training loader shuffles.
bs = 128
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False, num_workers=4)
# batch_size must be explicit: DataLoader defaults to 1, which would evaluate the
# test set one patch per forward pass.
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False, num_workers=4)

In [5]:
# U-Net model with a classification head
class UNet(nn.Module):
    """U-Net encoder/decoder whose finest decoder features are pooled into a classifier.

    Parameters:
        num_classes (int): Number of output logits.
        in_channels (int): Channels of the input image (3 for RGB).
        base_channels (int): Width of the first encoder stage; each deeper stage doubles
            it. The default 64 gives the classic 64-128-256-512-1024 widths.

    Input shape is (N, in_channels, H, W) with H, W >= 16; output is (N, num_classes).
    H and W need not be divisible by 16: each upsampled decoder map is zero-padded to
    the size of its skip connection before concatenation.
    """

    def __init__(self, num_classes=2, in_channels=3, base_channels=64):
        super(UNet, self).__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        # Encoder (Downsampling path)
        self.enc1 = self.conv_block(in_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)

        # Bottleneck
        self.bottleneck = self.conv_block(c4, c5)

        # Decoder (Upsampling path); each dec block consumes upsampled + skip channels.
        self.up4 = self.upconv(c5, c4)
        self.dec4 = self.conv_block(2 * c4, c4)
        self.up3 = self.upconv(c4, c3)
        self.dec3 = self.conv_block(2 * c3, c3)
        self.up2 = self.upconv(c3, c2)
        self.dec2 = self.conv_block(2 * c2, c2)
        self.up1 = self.upconv(c2, c1)
        self.dec1 = self.conv_block(2 * c1, c1)

        # Classification head
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(c1, num_classes)

        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 conv -> BatchNorm -> ReLU layers; padding=1 keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        """2x spatial upsampling with a learned transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _concat_skip(upsampled, skip):
        """Concatenate on channels, zero-padding ``upsampled`` to ``skip``'s height and width.

        MaxPool2d floors odd sizes, so a 2x upsample can come back one pixel short.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h or pad_w:
            # Pad order for the last two dims: (left, right, top, bottom).
            upsampled = nn.functional.pad(
                upsampled, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for a batch ``x``."""
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder with skip connections
        dec4 = self.dec4(self._concat_skip(self.up4(bottleneck), enc4))
        dec3 = self.dec3(self._concat_skip(self.up3(dec4), enc3))
        dec2 = self.dec2(self._concat_skip(self.up2(dec3), enc2))
        dec1 = self.dec1(self._concat_skip(self.up1(dec2), enc1))

        # Classification head: global average pool -> (N, c1) -> linear logits.
        pooled = torch.flatten(self.gap(dec1), 1)
        return self.fc(pooled)

In [6]:
# Initialize model, loss function, and optimizer
SEED = 0       # seeds weight init, loader shuffling and torch-based augmentations
GPU_INDEX = 2  # preferred GPU; falls back to GPU 0 when fewer GPUs are present
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # A hard-coded 'cuda:2' raises "invalid device ordinal" on hosts with < 3 GPUs.
    gpu = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu}')
else:
    device = torch.device('cpu')

model = UNet(num_classes=2).to(device)  # Binary classification
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
# Halve the LR once the validation loss has failed to improve by >= 1% (relative)
# for more than one consecutive epoch. `verbose` may emit a deprecation warning on
# recent PyTorch releases.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True, threshold=0.01)




In [7]:
# Training and validation loops
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

    Returns:
        tuple[float, float]: (loss averaged over batches, accuracy in percent over
        samples). Both are 0.0 when the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum, n_correct, n_seen, n_batches = 0.0, 0, 0, 0
    # Autograd bookkeeping is only needed for the training pass.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass (training only)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Metrics
            loss_sum += loss.item()
            n_correct += outputs.argmax(dim=1).eq(labels).sum().item()
            n_seen += labels.size(0)
            n_batches += 1

    mean_loss = loss_sum / n_batches if n_batches else 0.0
    accuracy = 100 * n_correct / n_seen if n_seen else 0.0
    return mean_loss, accuracy


def train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=10,
                       scheduler=None, test_loader=None, device=None):
    """Train ``model`` for ``epochs`` epochs, validating (and optionally testing) after each.

    Parameters:
        model, criterion, optimizer: The network, loss function and optimizer.
        train_loader, val_loader: Loaders yielding ``(images, labels)`` batches.
        epochs (int): Number of passes over ``train_loader``.
        scheduler: Optional; ``scheduler.step(mean_val_loss)`` is called once per epoch
            (the ReduceLROnPlateau convention).
        test_loader: Optional; when given, test metrics are also computed and printed
            each epoch. They are for reporting only — choosing epochs or hyper-parameters
            from them leaks the test set.
        device: Device batches are moved to; defaults to the device of the model's
            first parameter.

    Returns:
        list[dict]: One dict of per-epoch metrics (losses and accuracies in percent).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    history = []
    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        if scheduler is not None:
            scheduler.step(val_loss)

        metrics = {'epoch': epoch + 1, 'train_loss': train_loss, 'train_acc': train_acc,
                   'val_loss': val_loss, 'val_acc': val_acc}
        lines = [
            f"Epoch {epoch+1}/{epochs}",
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%",
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%",
        ]
        if test_loader is not None:
            test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)
            metrics.update(test_loss=test_loss, test_acc=test_acc)
            lines.append(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

        # Two blank lines separate epochs in the log.
        print("\n".join(lines) + "\n\n")
        history.append(metrics)
    return history

# Train and validate the model. Test metrics are printed each epoch for monitoring only.
# Assigning the result keeps the returned history out of the cell's displayed output.
history = train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=20,
                             scheduler=scheduler, test_loader=test_loader, device=device)

Epoch 1/20
Train Loss: 0.4497, Train Acc: 79.13%
Val Loss: 0.3595, Val Acc: 84.66%
Test Loss: 0.3503, Test Acc: 85.14%


Epoch 2/20
Train Loss: 0.3045, Train Acc: 87.17%
Val Loss: 0.3606, Val Acc: 86.10%
Test Loss: 0.3518, Test Acc: 86.54%


Epoch 3/20
Train Loss: 0.2535, Train Acc: 89.67%
Val Loss: 0.2718, Val Acc: 88.83%
Test Loss: 0.3054, Test Acc: 87.71%


Epoch 4/20
Train Loss: 0.2299, Train Acc: 90.80%
Val Loss: 0.2946, Val Acc: 87.87%
Test Loss: 0.2955, Test Acc: 88.04%


Epoch 5/20
Train Loss: 0.2138, Train Acc: 91.63%
Val Loss: 0.2451, Val Acc: 89.90%
Test Loss: 0.2849, Test Acc: 88.64%


Epoch 6/20
Train Loss: 0.2005, Train Acc: 92.14%
Val Loss: 0.2593, Val Acc: 89.86%
Test Loss: 0.3394, Test Acc: 87.03%


Epoch 7/20
Train Loss: 0.1904, Train Acc: 92.63%
Val Loss: 0.2866, Val Acc: 89.32%
Test Loss: 0.3368, Test Acc: 86.41%


Epoch 8/20
Train Loss: 0.1608, Train Acc: 93.97%
Val Loss: 0.3161, Val Acc: 89.17%
Test Loss: 0.2815, Test Acc: 88.83%


Epoch 9/20
Train Loss: 0.1545, T