## Importing modules

In [1]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import resnet50
import numpy as np
from functools import partial

import skimage
from skimage.color import rgb2hed
import pywt
from PIL import Image
import cv2
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchstain

from histomicstk.preprocessing.color_normalization.deconvolution_based_normalization import (
    deconvolution_based_normalization,
)

In [2]:
class H5Dataset(Dataset):
    def __init__(self, image_file, label_file, transform=None):
        self.transform = transform

        # Load data from the H5 file
        with h5py.File(image_file, "r") as f:
            self.images = f["x"][:]
        with h5py.File(label_file, "r") as f:
            self.labels = f["y"][:].reshape(
                -1,
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
class RGB2HED(torch.nn.Module):
    def __init__(self, mode=None):
        super(RGB2HED, self).__init__()
        self.mode = mode

    def forward(self, img):
        img = img.astype(np.float32) / 255.0
        hed_img = rgb2hed(img) * 255.0
        hed_img = np.tile(hed_img[:, :, -2:-1], reps=(1, 1, 3))
        return hed_img


class WaveletTransform(nn.Module):
    def __init__(self, wavelet="haar", threshold=20):
        super(WaveletTransform, self).__init__()
        self.wavelet = wavelet
        self.threshold = threshold

    def forward(self, img):
        grayscale_image = np.dot(img.astype(np.uint8), [0.299, 0.587, 0.114])

        # Step 2: Perform 2D wavelet decomposition
        coeffs = pywt.wavedec2(grayscale_image, wavelet=self.wavelet, level=2)
        cA, details = coeffs[0], coeffs[1:]

        # Step 3: Apply thresholding to detail coefficients
        def threshold_coeffs(coeffs, threshold):
            return [pywt.threshold(c, threshold, mode="soft") for c in coeffs]

        details_thresh = [
            threshold_coeffs(detail, self.threshold) for detail in details
        ]
        coeffs_thresh = [cA] + details_thresh

        # Step 4: Reconstruct the image
        compressed_image = pywt.waverec2(coeffs_thresh, wavelet=self.wavelet)
        compressed_image = np.clip(compressed_image, 0, 255).astype(np.uint8)
        compressed_image = np.tile(np.expand_dims(compressed_image, -1), (1, 1, 3))

        return compressed_image


class CLAHE(nn.Module):
    def __init__(self, mode=None):
        super(CLAHE, self).__init__()
        self.mode = mode

    def forward(self, image):
        # Convert to LAB color space
        lab_image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2LAB)
        l_channel, a, b = cv2.split(lab_image)

        # Apply CLAHE to the L channel
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l_channel = clahe.apply(l_channel)

        # Merge and convert back to RGB
        lab_image = cv2.merge((l_channel, a, b))
        return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)


class Opening(nn.Module):
    def __init__(self):
        super(Opening, self).__init__()

    def forward(self, image):
        return skimage.morphology.opening(image)


class Macenko(nn.Module):
    def __init__(self, reference_image, target_W=None, alpha=1, beta=0.01):
        super(Macenko, self).__init__()
        self.target_W = target_W
        self.alpha = alpha
        self.beta = beta
        self.reference_image = reference_image.astype(np.uint8)

    def forward(self, image):
        """
        Apply Macenko normalization to a single image with error handling.

        Parameters:
            image (np.ndarray): The image to normalize, shape (H, W, C) in RGB format.
            reference_image (np.ndarray): The reference image for normalization, shape (H, W, C) in RGB format.

        Returns:
            np.ndarray: The normalized image, shape (C, H, W) in normalized format.
            None: If normalization fails for any reason.
        """
        try:
            # # Set up the transformation
            # T = transforms.Compose([
            #     transforms.ToTensor(),
            # ])

            # Initialize the MacenkoNormalizer
            normalizer = torchstain.normalizers.MacenkoNormalizer(backend="torch")

            # Fit the normalizer with the reference image
            normalizer.fit(self.reference_image)

            # Transform the image and apply normalization
            t_to_transform = image
            norm_img, _, _ = normalizer.normalize(I=t_to_transform, stains=True)
            print('norm_img', norm_img.shape)
            final_img = torch.stack((image, norm_img), axis=-3)
            # print('final img', final_img.shape)

            # Return the normalized image
            return final_img

        except torch.linalg.LinAlgError as e:
            # print(f"LinAlgError during normalization: {e}")
            pass
        except Exception as e:
            # print(f"Unexpected error during normalization: {e}")
            pass

        # Return None if normalization fails
        return torch.stack((image, image), axis=-3)


class ReinhardNormalization(nn.Module):
    def __init__(self, reference_image):
        super(ReinhardNormalization, self).__init__()
        self.reference_image = reference_image.astype(np.uint8)

    def forward(self, image):
        """
        Apply Reinhard normalization to a single image with error handling.

        Parameters:
            image (np.ndarray): The image to normalize, shape (H, W, C) in RGB format.
            reference_image (np.ndarray): The reference image for normalization, shape (H, W, C) in RGB format.

        Returns:
            np.ndarray: The normalized image, shape (H, W, C) in normalized format.
            None: If normalization fails for any reason.
        """
        try:
            # Initialize the ReinhardNormalizer
            normalizer = torchstain.normalizers.ReinhardNormalizer()

            # Fit the normalizer with the reference image
            normalizer.fit(self.reference_image)

            # Normalize the image
            normalized_image = normalizer.normalize(image)
            final_img = torch.stack((image, normalized_image), axis=-3)
            print('final img', final_img.shape)

            # Return the normalized image
            return normalized_image

        except Exception as e:
            # print(f"Unexpected error during Reinhard normalization: {e}")
            pass

        # Return None if normalization fails
        return image #torch.stack((image, image), axis=-3)


train_data = H5Dataset(
    image_file="../../../pcam/training_split.h5",
    label_file="../../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5",
)
reference_image = train_data.images[176298]

train_transform = transforms.Compose(
    [
        # RGB2HED(),
        # WaveletTransform(),
        # CLAHE(),
        # Opening(),
        transforms.ToPILImage(),
        transforms.ColorJitter(brightness=0.5, saturation=0.25, hue=0.1, contrast=0.5),
        transforms.RandomAffine(10, (0.05, 0.05), fill=255),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.5),
        transforms.ToTensor(),
        Macenko(reference_image=reference_image),
        # ReinhardNormalization(reference_image=reference_image),
        # transforms.Normalize(
        #     [0.6716241, 0.48636872, 0.60884315, 0.6716241, 0.48636872, 0.60884315],
        #     [0.27210504, 0.31001145, 0.2918652, 0.6716241, 0.48636872, 0.60884315]
        # ),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.ToTensor(),
        # transforms.Normalize(
        #     [0.6716241, 0.48636872, 0.60884315, 0.6716241, 0.48636872, 0.60884315], 
        #     [0.27210504, 0.31001145, 0.2918652, 0.6716241, 0.48636872, 0.60884315]
        # ),
    ]
)

In [4]:
# Load datasets
train_dataset = H5Dataset(
    image_file="../../../pcam/training_split.h5",
    label_file="../../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5",
    transform=train_transform,
)
val_dataset = H5Dataset(
    image_file="../../../pcam/validation_split.h5",
    label_file="../../../Labels/Labels/camelyonpatch_level_2_split_valid_y.h5",
    transform=val_transform,
)

test_dataset = H5Dataset(
    image_file="../../../pcam/test_split.h5",
    label_file="../../../Labels/Labels/camelyonpatch_level_2_split_test_y.h5",
    transform=val_transform,
)

# Create dataloaders
bs = 128
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=4)

In [5]:
class ResNetModel(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNetModel, self).__init__()
        self.resnet = resnet50(pretrained=False)
        # Replace the final fully connected layer
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

In [6]:
# Initialize model, loss function, and optimizer
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model = ResNetModel(num_classes=2).to(device)  # Binary classification
model.resnet.conv1.in_channels = 6
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)



In [7]:
# Training and validation loops
def train_and_validate(
    model, train_loader, val_loader, criterion, optimizer, epochs=10
):
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Metrics
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        # Test phase
        model.eval()
        test_loss, test_correct, test_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Metrics
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                test_total += labels.size(0)
                test_correct += predicted.eq(labels).sum().item()

        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs}")
        print(
            f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {100 * train_correct/train_total:.2f}%"
        )
        print(
            f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {100 * val_correct/val_total:.2f}%"
        )
        print(
            f"Test Loss: {test_loss/len(test_loader):.4f}, Test Acc: {100 * test_correct/test_total:.2f}%\n\n"
        )


# Train and validate the model
train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=20)

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_921154/1861484373.py", line 21, in __getitem__
    image = self.transform(image)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
          ^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torchvision/transforms/transforms.py", line 277, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torchvision/transforms/functional.py", line 350, in normalize
    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vaibhav/miniconda3/envs/debo/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py", line 928, in normalize
    return tensor.sub_(mean).div_(std)
           ^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (2) must match the size of tensor b (6) at non-singleton dimension 1


In [None]:
# Adamw lr=1e-3, wd=0.1

# Epoch 1/20
# Train Loss: 0.4875, Train Acc: 77.33%
# Val Loss: 0.5107, Val Acc: 76.40%
# Test Loss: 0.5035, Test Acc: 77.21%


# Epoch 2/20
# Train Loss: 0.3851, Train Acc: 83.07%
# Val Loss: 0.3695, Val Acc: 83.42%
# Test Loss: 0.3493, Test Acc: 84.66%


# Epoch 3/20
# Train Loss: 0.3143, Train Acc: 86.64%
# Val Loss: 0.3746, Val Acc: 83.56%
# Test Loss: 0.3669, Test Acc: 82.75%


# Epoch 4/20
# Train Loss: 0.2835, Train Acc: 88.26%
# Val Loss: 0.3393, Val Acc: 85.56%
# Test Loss: 0.3460, Test Acc: 84.50%


# Epoch 5/20
# Train Loss: 0.2653, Train Acc: 89.18%
# Val Loss: 0.4198, Val Acc: 84.33%
# Test Loss: 0.4474, Test Acc: 81.17%


# Epoch 6/20
# Train Loss: 0.2547, Train Acc: 89.61%
# Val Loss: 0.3080, Val Acc: 86.96%
# Test Loss: 0.3253, Test Acc: 85.78%


# Epoch 7/20
# Train Loss: 0.2447, Train Acc: 90.04%
# Val Loss: 0.2909, Val Acc: 87.81%
# Test Loss: 0.3318, Test Acc: 85.75%


# Epoch 8/20
# Train Loss: 0.2380, Train Acc: 90.41%
# Val Loss: 0.2811, Val Acc: 88.64%
# Test Loss: 0.3492, Test Acc: 85.32%


# Epoch 9/20
# Train Loss: 0.2333, Train Acc: 90.70%
# Val Loss: 0.2742, Val Acc: 88.77%
# Test Loss: 0.2985, Test Acc: 87.68%


# Epoch 10/20
# Train Loss: 0.2280, Train Acc: 90.88%
# Val Loss: 0.2770, Val Acc: 88.74%
# Test Loss: 0.2955, Test Acc: 87.54%


# Epoch 11/20
# Train Loss: 0.2250, Train Acc: 91.04%
# Val Loss: 0.3370, Val Acc: 87.11%
# Test Loss: 0.3484, Test Acc: 85.82%


# Epoch 12/20
# Train Loss: 0.2209, Train Acc: 91.23%
# Val Loss: 0.4230, Val Acc: 81.84%
# Test Loss: 0.4566, Test Acc: 80.81%


# Epoch 13/20
# Train Loss: 0.2188, Train Acc: 91.34%
# Val Loss: 0.3244, Val Acc: 86.85%
# Test Loss: 0.3417, Test Acc: 84.73%


# Epoch 14/20
# Train Loss: 0.2153, Train Acc: 91.44%
# Val Loss: 0.3606, Val Acc: 86.03%
# Test Loss: 0.4001, Test Acc: 82.55%


# Epoch 15/20
# Train Loss: 0.2147, Train Acc: 91.51%
# Val Loss: 0.2853, Val Acc: 88.75%
# Test Loss: 0.3511, Test Acc: 85.88%


# Epoch 16/20
# Train Loss: 0.2147, Train Acc: 91.51%
# Val Loss: 0.2635, Val Acc: 89.07%
# Test Loss: 0.3208, Test Acc: 86.80%


# Epoch 17/20
# Train Loss: 0.2120, Train Acc: 91.64%
# Val Loss: 0.3522, Val Acc: 86.07%
# Test Loss: 0.3377, Test Acc: 85.63%


# Epoch 18/20
# Train Loss: 0.2099, Train Acc: 91.74%
# Val Loss: 0.2562, Val Acc: 89.83%
# Test Loss: 0.2840, Test Acc: 88.69%


# Epoch 19/20
# Train Loss: 0.2096, Train Acc: 91.73%
# Val Loss: 0.3078, Val Acc: 88.31%
# Test Loss: 0.3197, Test Acc: 87.25%


# Epoch 20/20
# Train Loss: 0.2084, Train Acc: 91.76%
# Val Loss: 0.2884, Val Acc: 87.91%
# Test Loss: 0.2865, Test Acc: 87.74%


In [None]:
# Adamw lr=1e-4, wd=0.01

# Epoch 1/10
# Train Loss: 0.5258, Train Acc: 74.43%
# Val Loss: 0.4653, Val Acc: 78.62%
# Test Loss: 0.4727, Test Acc: 78.53%


# Epoch 2/10
# Train Loss: 0.4634, Train Acc: 78.67%
# Val Loss: 0.4820, Val Acc: 76.75%
# Test Loss: 0.4805, Test Acc: 77.37%


# Epoch 3/10
# Train Loss: 0.4302, Train Acc: 80.49%
# Val Loss: 0.4225, Val Acc: 79.57%
# Test Loss: 0.4196, Test Acc: 80.92%


# Epoch 4/10
# Train Loss: 0.3916, Train Acc: 82.53%
# Val Loss: 0.4054, Val Acc: 80.63%
# Test Loss: 0.3912, Test Acc: 81.98%


# Epoch 5/10
# Train Loss: 0.3528, Train Acc: 84.56%
# Val Loss: 0.4083, Val Acc: 80.78%
# Test Loss: 0.3981, Test Acc: 81.82%


# Epoch 6/10
# Train Loss: 0.3193, Train Acc: 86.38%
# Val Loss: 0.3495, Val Acc: 83.68%
# Test Loss: 0.3886, Test Acc: 82.54%


# Epoch 7/10
# Train Loss: 0.2909, Train Acc: 87.75%
# Val Loss: 0.3562, Val Acc: 84.77%
# Test Loss: 0.3520, Test Acc: 85.36%


# Epoch 8/10
# Train Loss: 0.2679, Train Acc: 88.99%
# Val Loss: 0.3755, Val Acc: 84.72%
# Test Loss: 0.4703, Test Acc: 80.82%


# Epoch 9/10
# Train Loss: 0.2513, Train Acc: 89.77%
# Val Loss: 0.2995, Val Acc: 87.69%
# Test Loss: 0.3098, Test Acc: 86.81%


# Epoch 10/10
# Train Loss: 0.2349, Train Acc: 90.56%
# Val Loss: 0.3082, Val Acc: 87.26%
# Test Loss: 0.3217, Test Acc: 86.59%