## Importing modules

In [1]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import resnet50, vit_b_16
import numpy as np

import skimage
from skimage.color import rgb2hed, hed2rgb
import pywt
from PIL import Image
from sklearn.decomposition import PCA
import cv2
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchstain

from matplotlib.colors import ListedColormap
from histomicstk.preprocessing.color_normalization.deconvolution_based_normalization import deconvolution_based_normalization


In [2]:
class H5Dataset(Dataset):
    """Map-style dataset that serves image/label pairs read from HDF5 files.

    The ``'x'`` dataset of ``image_file`` and the ``'y'`` dataset of
    ``label_file`` are read completely into memory when the object is built;
    the label array is flattened to one dimension.

    Parameters:
        image_file: path to the HDF5 file holding the images under key ``'x'``.
        label_file: path to the HDF5 file holding the labels under key ``'y'``.
        transform: optional callable applied to each image on access.
    """

    def __init__(self, image_file, label_file, transform=None):
        self.transform = transform

        # Both files are closed again as soon as their contents are in memory.
        with h5py.File(image_file, 'r') as image_store:
            self.images = image_store['x'][:]
        with h5py.File(label_file, 'r') as label_store:
            raw_labels = label_store['y'][:]
        self.labels = raw_labels.reshape(-1)

    def __len__(self):
        """Return the number of labels (one per sample)."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transforming the image if configured."""
        image = self.images[idx]
        label = self.labels[idx]

        if not self.transform:
            return image, label
        return self.transform(image), label

In [3]:
class RGB2HED(torch.nn.Module):
    """Transform that replaces an RGB image by one HED stain channel, repeated 3x."""

    def __init__(self, mode=None):
        super().__init__()
        # Stored for callers; forward() does not read it.
        self.mode = mode

    def forward(self, img):
        """Convert ``img`` to HED space and return its second-to-last channel.

        Parameters:
            img (np.ndarray): array with the colour channels on the last axis;
                it is divided by 255 before conversion, so 0-255 values are
                presumably expected -- confirm against the caller.

        Returns:
            np.ndarray: the selected channel, scaled by 255 and tiled three
            times along the last axis.
        """
        unit_range = img.astype(np.float32) / 255.
        stains = rgb2hed(unit_range) * 255.
        # The slice -2:-1 keeps the channel axis, so the tile below yields
        # three identical channels (index 1 when there are three stains).
        selected = stains[:, :, -2:-1]
        return np.tile(selected, reps=(1, 1, 3))
    
        
class WaveletTransform(nn.Module):
    """Transform that soft-thresholds the wavelet detail coefficients of an image.

    The input is reduced to a luma-weighted grayscale image, decomposed with a
    2-D multilevel wavelet transform, its detail coefficients are
    soft-thresholded, and the reconstruction is returned as three identical
    uint8 channels.

    Parameters:
        wavelet: wavelet name passed to PyWavelets (default ``'haar'``).
        threshold: soft-threshold value applied to every detail band.
        level: number of decomposition levels (default 2, the value that was
            previously hard-coded).
    """

    def __init__(self, wavelet='haar', threshold=20, level=2):
        super().__init__()
        self.wavelet = wavelet
        self.threshold = threshold
        self.level = level

    def forward(self, img):
        """Return the thresholded reconstruction of ``img`` as an (H, W, 3) uint8 array.

        Parameters:
            img (np.ndarray): array with three colour channels on the last
                axis; it is cast to uint8 before the grayscale conversion.
        """
        # Step 1: Luma-weighted grayscale conversion
        grayscale_image = np.dot(img.astype(np.uint8), [0.299, 0.587, 0.114])
        height, width = grayscale_image.shape[:2]

        # Step 2: Perform 2D wavelet decomposition
        coeffs = pywt.wavedec2(grayscale_image, wavelet=self.wavelet, level=self.level)
        approximation, details = coeffs[0], coeffs[1:]

        # Step 3: Apply soft thresholding to every band of every detail level;
        # the approximation coefficients are left untouched.
        details_thresh = [
            tuple(pywt.threshold(band, self.threshold, mode='soft') for band in bands)
            for bands in details
        ]

        # Step 4: Reconstruct the image
        compressed_image = pywt.waverec2([approximation] + details_thresh,
                                         wavelet=self.wavelet)
        # When a side length is not divisible by 2**level the reconstruction
        # comes back padded; crop so the output always matches the input size.
        compressed_image = compressed_image[:height, :width]
        compressed_image = np.clip(compressed_image, 0, 255).astype(np.uint8)

        return np.tile(np.expand_dims(compressed_image, -1), (1, 1, 3))
    

class CLAHE(nn.Module):
    """Transform applying OpenCV's CLAHE to the lightness channel of an RGB image."""

    def __init__(self, mode=None):
        super().__init__()
        # Stored for callers; forward() does not read it.
        self.mode = mode

    def forward(self, image):
        """Equalize the L channel of ``image`` in LAB space and return RGB.

        Parameters:
            image (np.ndarray): RGB image; it is cast to uint8 before conversion.

        Returns:
            np.ndarray: the image converted back to RGB after equalization.
        """
        # Work in LAB so only lightness is altered; A and B pass through as-is.
        lab = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2LAB)
        lightness, chroma_a, chroma_b = cv2.split(lab)

        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        equalized = equalizer.apply(lightness)

        recombined = cv2.merge((equalized, chroma_a, chroma_b))
        return cv2.cvtColor(recombined, cv2.COLOR_LAB2RGB)


class Opening(nn.Module):
    """Transform applying scikit-image's morphological opening with its default footprint."""

    def __init__(self):
        super().__init__()

    def forward(self, image):
        """Return the morphological opening of ``image``.

        Parameters:
            image (np.ndarray): image passed unchanged to
                ``skimage.morphology.opening``.

        NOTE(review): for an (H, W, C) array the default footprint presumably
        spans the channel axis as well -- confirm this is intended.
        """
        # Import the submodule explicitly: a bare ``import skimage`` does not
        # expose ``skimage.morphology`` on scikit-image versions without lazy
        # submodule loading, which made the attribute lookup fail there.
        from skimage.morphology import opening

        return opening(image)
    
    
class Macenko(nn.Module):
    """Transform applying torchstain's Macenko stain normalization (torch backend).

    Normalization is best-effort: if anything fails, the input image is
    returned unchanged and a warning is emitted.

    Parameters:
        reference_image (np.ndarray): image the normalizer is fitted to; it is
            stored as uint8.
        target_W, alpha, beta: stored on the instance but not forwarded to
            torchstain by ``forward`` (behaviour kept from the original code).

    NOTE(review): torchstain's torch backend appears to expect a (C, H, W)
    tensor scaled to 0-255 for both ``fit`` and ``normalize``, whereas the
    reference stored here is a uint8 NumPy array -- confirm; a mismatch would
    make every call take the fallback path.
    """

    def __init__(self, reference_image, target_W=None, alpha=1, beta=0.01):
        super().__init__()
        self.target_W = target_W
        self.alpha = alpha
        self.beta = beta
        self.reference_image = reference_image.astype(np.uint8)
        # Fitted lazily on first use so construction never raises.
        self._normalizer = None

    def _fitted_normalizer(self):
        """Return the normalizer fitted to the reference image, building it once."""
        normalizer = getattr(self, '_normalizer', None)
        if normalizer is None:
            normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch')
            normalizer.fit(self.reference_image)
            # Cache only after a successful fit; a failed fit is retried next call.
            self._normalizer = normalizer
        return normalizer

    def forward(self, image):
        """
        Apply Macenko normalization to a single image with error handling.

        Parameters:
            image: the image to normalize, passed as ``I`` to the torchstain
                normalizer.

        Returns:
            The normalized image produced by torchstain, or ``image`` itself,
            unchanged, if fitting or normalization raises.
        """
        import warnings

        try:
            norm_img, _, _ = self._fitted_normalizer().normalize(I=image, stains=True)
        except Exception as exc:
            # Deliberate best-effort: keep the sample, but surface the failure
            # instead of swallowing it silently. LinAlgError is covered here too.
            warnings.warn(
                f"Macenko normalization failed ({type(exc).__name__}); "
                "returning the input image unchanged.",
                RuntimeWarning,
                stacklevel=2,
            )
            return image

        return norm_img

    
class ReinhardNormalization(nn.Module):
    """Transform applying torchstain's Reinhard colour normalization.

    Normalization is best-effort: if anything fails, the input image is
    returned unchanged and a warning is emitted.

    Parameters:
        reference_image (np.ndarray): image the normalizer is fitted to; it is
            stored as uint8.

    NOTE(review): the normalizer is created with torchstain's default backend;
    confirm that backend accepts the uint8 NumPy reference stored here.
    """

    def __init__(self, reference_image):
        super().__init__()
        self.reference_image = reference_image.astype(np.uint8)
        # Fitted lazily on first use so construction never raises.
        self._normalizer = None

    def _fitted_normalizer(self):
        """Return the normalizer fitted to the reference image, building it once."""
        normalizer = getattr(self, '_normalizer', None)
        if normalizer is None:
            normalizer = torchstain.normalizers.ReinhardNormalizer()
            normalizer.fit(self.reference_image)
            # Cache only after a successful fit; a failed fit is retried next call.
            self._normalizer = normalizer
        return normalizer

    def forward(self, image):
        """
        Apply Reinhard normalization to a single image with error handling.

        Parameters:
            image: the image to normalize, passed to the torchstain normalizer.

        Returns:
            The normalized image produced by torchstain, or ``image`` itself,
            unchanged, if fitting or normalization raises.
        """
        import warnings

        try:
            normalized_image = self._fitted_normalizer().normalize(image)
        except Exception as exc:
            # Deliberate best-effort: keep the sample, but surface the failure
            # instead of swallowing it silently.
            warnings.warn(
                f"Reinhard normalization failed ({type(exc).__name__}); "
                "returning the input image unchanged.",
                RuntimeWarning,
                stacklevel=2,
            )
            return image

        return normalized_image


# Untransformed view of the training split; one raw patch from it serves as
# the reference image for the (currently disabled) stain normalizers.
train_data = H5Dataset(
    image_file='../../../pcam/training_split.h5',
    label_file='../../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5',
)
reference_image = train_data.images[176298]

# Per-channel mean / std used by both pipelines below.
_CHANNEL_MEAN = [0.6716241, 0.48636872, 0.60884315]
_CHANNEL_STD = [0.27210504, 0.31001145, 0.2918652]

# Training pipeline: colour and geometric augmentation, then resize,
# tensor conversion and normalization.
train_transform = transforms.Compose([
    # Optional array-level preprocessing (disabled):
    # RGB2HED(),
    # WaveletTransform(),
    # CLAHE(),
    # Opening(),
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.25, hue=.1),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), fill=255),
    transforms.RandomHorizontalFlip(p=.5),
    transforms.RandomVerticalFlip(p=.5),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Optional stain normalization (disabled):
    # Macenko(reference_image=reference_image),
    # ReinhardNormalization(reference_image=reference_image),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Evaluation pipeline: same resize and normalization, no augmentation.
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

In [4]:
# Load datasets: augmentation for training, deterministic preprocessing for
# validation and test.
train_dataset = H5Dataset(image_file='../../../pcam/training_split.h5',
                          label_file='../../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5',
                          transform=train_transform)
val_dataset = H5Dataset(image_file='../../../pcam/validation_split.h5',
                        label_file='../../../Labels/Labels/camelyonpatch_level_2_split_valid_y.h5',
                        transform=val_transform)

test_dataset = H5Dataset(image_file='../../../pcam/test_split.h5',
                         label_file='../../../Labels/Labels/camelyonpatch_level_2_split_test_y.h5',
                         transform=val_transform)

# Create dataloaders; only the training loader is shuffled.
bs = 128
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False, num_workers=4)
# batch_size was previously omitted here, so DataLoader fell back to its
# default of 1 and the test pass ran one image per forward call.
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False, num_workers=4)

In [5]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Pre-trained torchvision ViT-B/16 whose classification head is replaced by
# a single linear layer with two outputs (binary classification).
model = vit_b_16(pretrained=True)
model.heads = nn.Linear(model.hidden_dim, 2)
model = model.to(device)

# Loss, optimizer, and a scheduler that halves the learning rate when the
# monitored value (minimized) stops improving for more than one epoch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=1,
    threshold=0.01,
    verbose=True,
)




In [6]:
# Training and validation loops
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over `loader`; return (mean batch loss, accuracy %)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics
        loss_sum += loss.item()
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        seen += labels.size(0)
    return loss_sum / len(loader), 100 * correct / seen


def _evaluate(model, loader, criterion, device):
    """Score `model` on `loader` without gradients; return (mean batch loss, accuracy %)."""
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)
    return loss_sum / len(loader), 100 * correct / seen


def train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=10,
                       scheduler=None, test_loader=None, device=None):
    """Train `model` for `epochs` epochs, evaluating after each one.

    Each epoch trains on `train_loader`, evaluates on `val_loader`, steps the
    scheduler with the mean validation loss, evaluates on the test loader,
    and prints the metrics.

    Parameters:
        model: network to optimize; left in eval mode on return.
        train_loader, val_loader: iterables yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the model's parameters.
        epochs: number of epochs to run.
        scheduler: optional scheduler stepped as ``scheduler.step(val_loss)``.
        test_loader: optional loader evaluated every epoch.
        device: device batches are moved to.

    For backward compatibility, `scheduler`, `test_loader` and `device` fall
    back to the notebook-level variables of the same name when omitted. If
    no `device` is found, the device of the model's parameters is used; a
    missing scheduler or test loader simply skips that step.

    Returns:
        list[dict]: one dict per epoch with ``train_loss``, ``train_acc``,
        ``val_loss``, ``val_acc`` and, when a test loader is used,
        ``test_loss`` and ``test_acc``. Losses are means of per-batch losses;
        accuracies are percentages.
    """
    notebook_vars = globals()
    if scheduler is None:
        scheduler = notebook_vars.get('scheduler')
    if test_loader is None:
        test_loader = notebook_vars.get('test_loader')
    if device is None:
        device = notebook_vars.get('device')
    if device is None:
        device = next(model.parameters()).device

    history = []
    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion,
                                                 optimizer, device)
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        record = {
            'train_loss': train_loss, 'train_acc': train_acc,
            'val_loss': val_loss, 'val_acc': val_acc,
        }

        # The scheduler is driven by the mean validation loss.
        if scheduler is not None:
            scheduler.step(val_loss)

        if test_loader is not None:
            test_loss, test_acc = _evaluate(model, test_loader, criterion, device)
            record['test_loss'] = test_loss
            record['test_acc'] = test_acc

        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        if test_loader is not None:
            print(f"Test Loss: {record['test_loss']:.4f}, Test Acc: {record['test_acc']:.2f}%\n\n")
        else:
            print("\n")

        history.append(record)

    return history

# Kick off the 10-epoch training run with the objects built in the cells above.
train_and_validate(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10,
)


# Epoch 9/10
# Train Loss: 0.0016, Train Acc: 99.96%
# Val Loss: 0.8319, Val Acc: 89.04%
# Test Loss: 1.1633, Test Acc: 85.61%

Epoch 1/10
Train Loss: 0.2013, Train Acc: 92.02%
Val Loss: 0.2464, Val Acc: 90.47%
Test Loss: 0.2224, Test Acc: 90.86%


Epoch 2/10
Train Loss: 0.1457, Train Acc: 94.55%
Val Loss: 0.3991, Val Acc: 85.88%
Test Loss: 0.4939, Test Acc: 83.26%


Epoch 3/10
Train Loss: 0.1271, Train Acc: 95.29%
Val Loss: 0.3157, Val Acc: 89.23%
Test Loss: 0.2654, Test Acc: 90.16%


Epoch 4/10
Train Loss: 0.0908, Train Acc: 96.78%
Val Loss: 0.3297, Val Acc: 90.11%
Test Loss: 0.4090, Test Acc: 87.60%


Epoch 5/10
Train Loss: 0.0821, Train Acc: 97.09%
Val Loss: 0.3135, Val Acc: 91.05%
Test Loss: 0.4162, Test Acc: 88.20%


Epoch 6/10
Train Loss: 0.0622, Train Acc: 97.84%
Val Loss: 0.3825, Val Acc: 89.19%
Test Loss: 0.4879, Test Acc: 86.37%


Epoch 7/10
Train Loss: 0.0560, Train Acc: 98.08%
Val Loss: 0.4696, Val Acc: 88.21%
Test Loss: 0.6007, Test Acc: 84.84%


Epoch 8/10
Train Loss: 0.0461, Train Acc: 98.40%
Val Loss: 0.4812, Val Acc: 89.03%
Test Loss: 0.6082, Test Acc: 87.03%


Epoch 9/10
Train Loss: 0.0418, T