## Importing modules

In [1]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import resnet50
import numpy as np

import skimage
from skimage.color import rgb2hed, hed2rgb
import pywt
from PIL import Image
from sklearn.decomposition import PCA
import cv2

In [2]:
class H5Dataset(Dataset):
    def __init__(self, image_file, label_file, transform=None):
        self.transform = transform
        
        # Load data from the H5 file
        with h5py.File(image_file, 'r') as f:
            self.images = f['x'][:]
        with h5py.File(label_file, 'r') as f:
            self.labels = f['y'][:].reshape(-1,)
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

In [3]:
class RGB2HED(torch.nn.Module):
    def __init__(self, mode=None):
        super(RGB2HED, self).__init__()
        self.mode = mode
    def forward(self, img):
        img = img.astype(np.float32) / 255.
        hed_img = rgb2hed(img) * 255.
        hed_img = np.tile(hed_img[:, :, -2:-1], reps=(1,1,3))
        return hed_img
    
        
class WaveletTransform(nn.Module):
    def __init__(self, wavelet='haar', threshold=20):
        super(WaveletTransform, self).__init__()
        self.wavelet = wavelet
        self.threshold = threshold
        
    def forward(self, img):
        grayscale_image = np.dot(img.astype(np.uint8), [0.299, 0.587, 0.114])
        
        # Step 2: Perform 2D wavelet decomposition
        coeffs = pywt.wavedec2(grayscale_image, wavelet=self.wavelet, level=2)
        cA, details = coeffs[0], coeffs[1:]
        
        # Step 3: Apply thresholding to detail coefficients
        def threshold_coeffs(coeffs, threshold):
            return [pywt.threshold(c, threshold, mode='soft') for c in coeffs] 
        
        
        details_thresh = [threshold_coeffs(detail, self.threshold) for detail in details]
        coeffs_thresh = [cA] + details_thresh
        
        # Step 4: Reconstruct the image
        compressed_image = pywt.waverec2(coeffs_thresh, wavelet=self.wavelet)
        compressed_image = np.clip(compressed_image, 0, 255).astype(np.uint8)
        compressed_image = np.tile(np.expand_dims(compressed_image, -1), (1,1,3))
        
        return compressed_image
    

class CLAHE(nn.Module):
    def __init__(self, mode=None):
        super(CLAHE, self).__init__()
        self.mode = mode
    def forward(self, image):
        # Convert to LAB color space
        lab_image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2LAB)
        l_channel, a, b = cv2.split(lab_image)

        # Apply CLAHE to the L channel
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l_channel = clahe.apply(l_channel)

        # Merge and convert back to RGB
        lab_image = cv2.merge((l_channel, a, b))
        return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
    

class Macenko(nn.Module):
    def __init__(self):
        super(Macenko, self).__init__()
    
    def forward(self, image):
        """Normalize H&E stained images using Macenko method."""
        # Reshape image to 2D
        h, w, c = image.shape
        image_flat = image.reshape((-1, c))

        # PCA for stain separation
        pca = PCA(n_components=c)
        pca.fit(image_flat)
        stains = pca.components_

        # Normalize to intensity ranges
        norms = np.sqrt(np.sum(stains**2, axis=0))
        normalized_stains = stains / norms
        normalized_image = np.dot(image_flat, normalized_stains.T)
        
        # Scale back and reshape
        return normalized_image.reshape((h, w, c))
    
class Opening(nn.Module):
    def __init__(self):
        super(Opening, self).__init__()
        
    def forward(self, image):
        return skimage.morphology.opening(image)


transform = transforms.Compose([
    Opening(),
    CLAHE(),
    Macenko(),
    WaveletTransform(),
    transforms.ToPILImage(),
    # transforms.Resize((96, 96)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [4]:
# Load datasets
train_dataset = H5Dataset(image_file='../../pcam/training_split.h5', 
                          label_file='../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5', 
                          transform=transform)
val_dataset = H5Dataset(image_file='../../pcam/validation_split.h5', 
                        label_file='../../Labels/Labels/camelyonpatch_level_2_split_valid_y.h5',
                        transform=transform)

test_dataset = H5Dataset(image_file='../../pcam/test_split.h5', 
                        label_file='../../Labels/Labels/camelyonpatch_level_2_split_test_y.h5',
                        transform=transform)

# Create dataloaders
bs = 128
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=4)

In [5]:
class ResNetModel(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNetModel, self).__init__()
        self.resnet = resnet50(pretrained=False)
        # Replace the final fully connected layer
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
    
    def forward(self, x):
        return self.resnet(x)

In [6]:
# Initialize model, loss function, and optimizer
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ResNetModel(num_classes=2).to(device)  # Binary classification
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)



In [None]:
# Training and validation loops
def train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()
        
        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                # Metrics
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
                
        # Test phase
        model.eval()
        test_loss, test_correct, test_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                # Metrics
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                test_total += labels.size(0)
                test_correct += predicted.eq(labels).sum().item()
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {100 * train_correct/train_total:.2f}%")
        print(f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {100 * val_correct/val_total:.2f}%")
        print(f"Test Loss: {test_loss/len(test_loader):.4f}, Test Acc: {100 * test_correct/test_total:.2f}%\n\n")

# Train and validate the model
train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=10)

  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
