In [None]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

def train_network(net, trainloader, criterion=CrossEntropyLoss(), optimizer_class=Adam, n_epochs=50, log_every=1000):
    """Train ``net`` on ``trainloader`` and return it.

    Args:
        net: the ``torch.nn.Module`` to train. It is moved to ``cuda:0`` when CUDA is
            available (otherwise CPU) and put in training mode.
        trainloader: iterable yielding ``(inputs, labels)`` batches.
        criterion: a loss instance, or a loss class (instantiated with no arguments).
        optimizer_class: optimizer class, constructed as ``optimizer_class(net.parameters())``.
        n_epochs: number of passes over ``trainloader``.
        log_every: print the mean loss of the most recent ``log_every`` batches every
            ``log_every`` batches.

    Returns:
        The same ``net`` object, trained and left on the selected device.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Device: ", device)

    net.to(device)
    net.train()  # make sure dropout / batch-norm layers are in training mode

    # Instantiate the criterion if it's a class.
    if isinstance(criterion, type):
        criterion = criterion()

    # Instantiate the optimizer with the network parameters.
    optimizer = optimizer_class(net.parameters())

    for epoch in range(n_epochs):  # loop over the dataset multiple times
        print("starting epoch ", epoch)
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs = inputs.to(device)
            labels = _prepare_targets(labels, criterion).to(device)

            # zero the parameter gradients, then forward + backward + optimize
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Report after every `log_every` batches, so the average covers exactly that many.
            if (i + 1) % log_every == 0:
                print(f'Epoch={epoch + 1} Iter={i + 1:5d} Loss={running_loss / log_every:.3f}')
                running_loss = 0.0

    print('Finished Training')
    return net


def _prepare_targets(labels, criterion):
    """Turn float class labels into the long class indices ``CrossEntropyLoss`` expects.

    ``CrossEntropyLoss`` treats floating-point targets as per-class probabilities that
    must match the logits' shape. A float tensor holding exactly one value per sample
    (shape ``(B,)`` or ``(B, 1, ...)``) is therefore flattened to ``(B,)`` and cast to
    long. Any other criterion or target layout is returned unchanged.
    """
    if (
        isinstance(criterion, CrossEntropyLoss)
        and labels.is_floating_point()
        and labels.dim() > 0
        and labels.numel() == labels.shape[0]
    ):
        return labels.reshape(-1).long()
    return labels

In [None]:
from torch import nn
import torch.nn.functional as F

class ResNetBlock(nn.Module):
    """Basic residual block: two 3x3 conv + batch-norm layers with a skip connection.

    The first conv applies ``stride``. When the stride is not 1 or the channel count
    changes, the skip path is a 1x1 conv + batch-norm projection so it can be added to
    the main path. Otherwise the skip path is an empty ``nn.Sequential``, i.e. identity.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        stride: stride of the first conv (and of the projection, if any).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.skip_connection = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply the block to ``x`` and return ``relu(main_path(x) + skip_connection(x))``."""
        # (The per-call debug print of x.shape was removed: it fired on every batch.)
        identity = self.skip_connection(x)

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return F.relu(out)

class ResNet(nn.Module):
    """Small ResNet-style classifier.

    Layout: a 7x7 stride-2 conv stem with max-pooling, four ``ResNetBlock`` stages
    (64 -> 64 -> 128 -> 256 -> 512 channels), adaptive average pooling to 1x1 and a
    linear head. Thanks to the adaptive pooling, the input's spatial size is not fixed.

    Args:
        input_channels: channels of the input images.
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, input_channels=3, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Define ResNet layers
        self.layer1 = ResNetBlock(64, 64)
        self.layer2 = ResNetBlock(64, 128, stride=2)
        self.layer3 = ResNetBlock(128, 256, stride=2)
        self.layer4 = ResNetBlock(256, 512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # This makes the network input size agnostic
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return un-normalised class scores of shape ``(batch, num_classes)`` for images ``x``."""
        # (The per-call debug print of x.shape was removed: it fired on every batch.)
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)  # Adaptive pooling -> (batch, 512, 1, 1)
        x = x.reshape(x.size(0), -1)
        return self.fc(x)

In [None]:
import numpy as np
import h5py
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
from PIL import Image

class HEDataset(Dataset):
    """Lazily-read H&E patch dataset backed by a pair of HDF5 files.

    The first dataset key of each file is used. ``__getitem__`` opens both files and
    reads only the requested sample. Nothing is held in memory between calls, and no
    open h5py handle is shared with DataLoader worker processes.

    Args:
        images_file_path: path to the HDF5 file holding the image stack.
        labels_file_path: path to the HDF5 file holding the labels.
        transform: optional callable applied to the PIL image.
        crop: optional side length, in pixels, of a centered square crop.
    """

    def __init__(self, images_file_path, labels_file_path, transform=None, crop = None):
        self.images_file_path = images_file_path
        self.labels_file_path = labels_file_path
        self.transform = transform
        self.crop = crop

        # Reads the dataset shapes, sets self.length and validates the label count.
        self._print_dataset_shapes()

    @staticmethod
    def _first_dataset(h5_file):
        """Return the dataset stored under the first key of an open HDF5 file."""
        return h5_file[list(h5_file.keys())[0]]

    @staticmethod
    def _image_hw(image_shape):
        """Return ``(height, width)`` of an image-stack shape.

        A stack whose last axis has size 3 is treated as channels-last ``(N, H, W, 3)``
        (the same test ``__getitem__`` uses). Anything else is treated as ``(N, C, H, W)``.
        """
        if image_shape[-1] == 3:
            return image_shape[1], image_shape[2]
        return image_shape[2], image_shape[3]

    @staticmethod
    def _center_crop_chw(image_data, crop):
        """Return the centered ``crop`` x ``crop`` window of a ``(C, H, W)`` array."""
        height, width = image_data.shape[1], image_data.shape[2]
        top = height // 2 - crop // 2
        left = width // 2 - crop // 2
        return image_data[:, top:top + crop, left:left + crop]

    def _print_dataset_shapes(self):
        """Set ``self.length``, print the per-item image shape and check the label count.

        Raises:
            ValueError: if the labels dataset's length differs from the images dataset's.
        """
        with h5py.File(self.images_file_path, 'r') as images_file:
            image_shape = self._first_dataset(images_file).shape
        self.length = image_shape[0]
        if self.crop:
            print(f"Image shape {image_shape} will crop to ", [self.length, 3, self.crop, self.crop])
        else:
            height, width = self._image_hw(image_shape)
            print(f"Image shape {[self.length, 3, height, width]}")

        with h5py.File(self.labels_file_path, 'r') as labels_file:
            label_shape = self._first_dataset(labels_file).shape
        if label_shape[0] != self.length:
            raise ValueError(
                f"Dataset x and y sizes do not match: {self.length} images vs {label_shape[0]} labels"
            )

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the (optionally cropped) sample as a PIL image, passed through
        ``transform`` when one is set. ``label`` is a float32 tensor.
        """
        # Index the HDF5 dataset directly so that only this one sample is read from disk.
        # The previous list(dataset)[idx] materialised the whole split on every call.
        with h5py.File(self.images_file_path, 'r') as images_file:
            image_data = self._first_dataset(images_file)[idx]
        with h5py.File(self.labels_file_path, 'r') as labels_file:
            label_data = self._first_dataset(labels_file)[idx]

        if image_data.shape[-1] == 3:
            image_data = image_data.transpose((2, 0, 1))  # HWC -> CHW (RGB filter in 0 position)

        if self.crop:
            image_data = self._center_crop_chw(image_data, self.crop)

        # Convert to PIL image (expects HWC uint8)
        image_data = Image.fromarray(np.uint8(image_data.transpose(1, 2, 0)))

        if self.transform:
            image_data = self.transform(image_data)

        return image_data, torch.tensor(label_data, dtype=torch.float32)

def load_datasets(test_x, test_y, train_x, train_y, valid_x, valid_y, transform=None, crop=None):
    """Build one ``HEDataset`` per split.

    Each ``*_x`` / ``*_y`` argument pair is the (images, labels) HDF5 path pair for a
    split. ``transform`` and ``crop`` are shared by all three splits.

    Returns:
        ``(test_dataset, train_dataset, validation_dataset)``, in that order.
    """
    split_paths = ((test_x, test_y), (train_x, train_y), (valid_x, valid_y))
    return tuple(
        HEDataset(images_path, labels_path, transform, crop)
        for images_path, labels_path in split_paths
    )

def load_dataloaders(test_dataset, train_dataset, validation_dataset, batch_size = 64, shuffle = True, num_workers = 4, shuffle_eval = False):
    """Wrap the test, train and validation datasets in ``DataLoader`` objects.

    Args:
        test_dataset, train_dataset, validation_dataset: the three split datasets.
        batch_size: batch size for all three loaders.
        shuffle: reshuffle the training loader every epoch.
        num_workers: worker processes per loader.
        shuffle_eval: also shuffle the test and validation loaders. Off by default, so
            evaluation order is deterministic and a sample's dataset index can be
            recovered from its batch position.

    Returns:
        ``(test_dataloader, train_dataloader, validation_dataloader)``.
    """
    def make_loader(dataset, do_shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=do_shuffle, num_workers=num_workers)

    test_dataloader = make_loader(test_dataset, shuffle_eval)
    train_dataloader = make_loader(train_dataset, shuffle)
    validation_dataloader = make_loader(validation_dataset, shuffle_eval)

    return test_dataloader, train_dataloader, validation_dataloader

In [None]:
# Download the PCam archive (about 8 GB, per the output below) from Google Drive into the working directory.
!gdown --fuzzy https://drive.google.com/file/d/1krtqpLh9Wi9QijMU7EcMOmlOB6ezwPI4/view?usp=sharing

Downloading...
From (original): https://drive.google.com/uc?id=1krtqpLh9Wi9QijMU7EcMOmlOB6ezwPI4
From (redirected): https://drive.google.com/uc?id=1krtqpLh9Wi9QijMU7EcMOmlOB6ezwPI4&confirm=t&uuid=155d21ea-f2f1-432c-a664-c70869fcb813
To: /content/pcam.zip
100% 8.02G/8.02G [01:53<00:00, 70.9MB/s]


In [None]:
# Extract the outer archive. This creates pcam/ (compressed .h5 splits plus metadata CSVs)
# and a __MACOSX/ folder of macOS `._*` metadata files that nothing below uses.
!unzip pcam.zip

Archive:  pcam.zip
   creating: pcam/
  inflating: pcam/camelyonpatch_level_2_split_train_meta.csv  
  inflating: __MACOSX/pcam/._camelyonpatch_level_2_split_train_meta.csv  
  inflating: pcam/camelyonpatch_level_2_split_test_y.h5.gz  
  inflating: __MACOSX/pcam/._camelyonpatch_level_2_split_test_y.h5.gz  
  inflating: pcam/camelyonpatch_level_2_split_train_y.h5.gz  
  inflating: __MACOSX/pcam/._camelyonpatch_level_2_split_train_y.h5.gz  
  inflating: pcam/camelyonpatch_level_2_split_test_x.h5.gz  
  inflating: __MACOSX/pcam/._camelyonpatch_level_2_split_test_x.h5.gz  
  inflating: pcam/camelyonpatch_level_2_split_test_meta.csv  
  inflating: __MACOSX/pcam/._camelyonpatch_level_2_split_test_meta.csv  
  inflating: pcam/camelyonpatch_level_2_split_train_x.h5.zip  
  inflating: __MACOSX/pcam/._camelyonpatch_level_2_split_train_x.h5.zip  
  inflating: pcam/camelyonpatch_level_2_split_valid_x.h5.gz  
  inflating: __MACOSX/pcam/._camelyonpatch_level_2_split_valid_x.h5.gz  
  inflating: pcam

In [None]:
# train_x ships as a nested .zip. This extracts it into the current working directory,
# not into pcam/, which is why the train_x path below has no `base` prefix.
!unzip /content/pcam/camelyonpatch_level_2_split_train_x.h5.zip

Archive:  /content/pcam/camelyonpatch_level_2_split_train_x.h5.zip
  inflating: camelyonpatch_level_2_split_train_x.h5  
  inflating: __MACOSX/._camelyonpatch_level_2_split_train_x.h5  


In [None]:
# Decompress the remaining .h5.gz splits in place inside /content/pcam/.
# NOTE(review): valid_y.h5.gz does not appear in the (apparently truncated) unzip listing
# above. Confirm it exists, otherwise the last command fails.
!gunzip /content/pcam/camelyonpatch_level_2_split_train_y.h5.gz
!gunzip /content/pcam/camelyonpatch_level_2_split_test_y.h5.gz
!gunzip /content/pcam/camelyonpatch_level_2_split_test_x.h5.gz
!gunzip /content/pcam/camelyonpatch_level_2_split_valid_x.h5.gz
!gunzip /content/pcam/camelyonpatch_level_2_split_valid_y.h5.gz

In [None]:
# Sanity check: confirm PyTorch can see a CUDA device before training.
print(f"{torch.cuda.is_available()}")

True


In [None]:
import pandas as pd

# PCam split files. Everything lives under `base` except train_x, which the nested
# `!unzip` above extracted into the working directory.
base = '/content/pcam/'
test_x = f"{base}camelyonpatch_level_2_split_test_x.h5"
test_y = f"{base}camelyonpatch_level_2_split_test_y.h5"
train_mask = f"{base}camelyonpatch_level_2_split_train_mask.h5"  # not used in the cells below
train_x = 'camelyonpatch_level_2_split_train_x.h5'
train_y = f"{base}camelyonpatch_level_2_split_train_y.h5"
valid_x = f"{base}camelyonpatch_level_2_split_valid_x.h5"
valid_y = f"{base}camelyonpatch_level_2_split_valid_y.h5"

# Load in the per-split metadata tables.
train_csv = f"{base}camelyonpatch_level_2_split_train_meta.csv"
test_csv = f"{base}camelyonpatch_level_2_split_test_meta.csv"
valid_csv = f"{base}camelyonpatch_level_2_split_valid_meta.csv"

train_meta, test_meta, val_meta = (pd.read_csv(csv_path) for csv_path in (train_csv, test_csv, valid_csv))

def convert_to_int(df):
    """Cast the ``tumor_patch`` and ``center_tumor_patch`` columns of ``df`` to int.

    The columns are overwritten on ``df`` itself. The same frame is returned so the call
    can be used in an assignment.
    """
    for column in ("tumor_patch", "center_tumor_patch"):
        df[column] = df[column].astype(int)
    return df

# Cast the tumor label columns of every metadata split to int.
train_meta, test_meta, val_meta = (convert_to_int(meta) for meta in (train_meta, test_meta, val_meta))

# ToTensor converts the uint8 PIL image to a [0, 1] tensor. Normalize with mean = std = 0.5
# then maps each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

test_dataset, train_dataset, validation_dataset = load_datasets(
    test_x, test_y, train_x, train_y, valid_x, valid_y, transform=transform, crop=32
)
test_dataloader, train_dataloader, validation_dataloader = load_dataloaders(
    test_dataset, train_dataset, validation_dataset, batch_size=128, shuffle=True
)

model = ResNet()

# Uses train_network's defaults: CrossEntropyLoss, the Adam optimizer and 50 epochs.
net = train_network(model, train_dataloader)

In [None]:
import cv2
import numpy as np
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import h5py
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import random

class HEDataset(Dataset):
    """In-memory H&E patch dataset backed by a pair of HDF5 files.

    NOTE(review): this redefines, and from here on shadows, the lazily-reading
    ``HEDataset`` from an earlier cell. Any later call that builds ``HEDataset`` gets
    this version.

    The constructor reads the first dataset of the images file and of the labels file
    completely into NumPy arrays. Items are indexed as ``(H, W, ...)`` arrays, optionally
    center-cropped, converted to PIL images and passed through ``transform``.

    Args:
        images_file_path: path to the images HDF5 file.
        labels_file_path: path to the labels HDF5 file.
        transform: optional callable applied to the PIL image.
        crop: optional side length of a centered square crop.
    """

    def __init__(self, images_file_path, labels_file_path, transform=None, crop=None):
        self.images_file_path = images_file_path
        self.labels_file_path = labels_file_path
        self.transform = transform
        self.crop = crop
        # Indices registered through update_exclusion_list.
        # NOTE(review): __len__ / __getitem__ do not consult this set yet.
        self.exclude_indices = set()
        self.images, self.labels = self.load_data()

    def load_data(self):
        """Read the first dataset of the images and labels files fully into memory.

        Returns:
            ``(images, labels)`` as NumPy arrays.
        """
        with h5py.File(self.images_file_path, 'r') as images_file:
            images = np.array(images_file[list(images_file.keys())[0]])
        with h5py.File(self.labels_file_path, 'r') as labels_file:
            labels = np.array(labels_file[list(labels_file.keys())[0]])
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a (cropped, transformed) PIL image and a float32 tensor."""
        image_data = self.images[idx]
        label_data = self.labels[idx]
        if self.crop:
            image_data = self.crop_image(image_data)
        image_data = Image.fromarray(image_data)
        if self.transform:
            image_data = self.transform(image_data)
        return image_data, torch.tensor(label_data, dtype=torch.float32)

    def crop_image(self, image_data):
        """Return the centered ``self.crop`` x ``self.crop`` window of an ``(H, W, ...)`` array."""
        w, h = image_data.shape[1], image_data.shape[0]
        startx = w // 2 - self.crop // 2
        starty = h // 2 - self.crop // 2
        return image_data[starty:starty + self.crop, startx:startx + self.crop]

    def update_exclusion_list(self, exclude_indices):
        """ Update the list of indices to exclude from the dataset. """
        self.exclude_indices = set(exclude_indices)

    def degrade_all_images(self):
        """Degrade every stored image and return the results as a list.

        NOTE: ``degrade_image`` writes into its argument, and iterating ``self.images``
        yields views into the stored array. This therefore also degrades
        ``self.images`` itself.
        """
        return [self.degrade_image(image) for image in self.images]

    @staticmethod
    def _to_uint8(values):
        """Clip ``values`` to the 8-bit range [0, 255] and cast them to uint8."""
        return np.clip(values, 0, 255).astype(np.uint8)

    def degrade_image(self, image, focus_area_size=32, sigma=20, noise_std=10):
        """Blur a random ``focus_area_size`` square of ``image`` and add Gaussian noise to it.

        ``image`` is modified in place and also returned. Noisy values are clipped to
        [0, 255] before the uint8 cast, so out-of-range values saturate instead of
        overflowing.

        Args:
            image: array whose first two axes are the spatial axes.
            focus_area_size: side length of the degraded square.
            sigma: Gaussian blur sigma passed to ``cv2.GaussianBlur`` (5x5 kernel).
            noise_std: standard deviation of the additive Gaussian noise.
        """
        row = random.randint(0, image.shape[0] - focus_area_size)
        col = random.randint(0, image.shape[1] - focus_area_size)
        window = (slice(row, row + focus_area_size), slice(col, col + focus_area_size))
        # A 2-D slice of a larger array is not contiguous; hand cv2 a contiguous copy.
        patch = np.ascontiguousarray(image[window])
        blurred_patch = cv2.GaussianBlur(patch, (5, 5), sigma)
        noise = np.random.normal(0, noise_std, blurred_patch.shape)
        image[window] = self._to_uint8(blurred_patch + noise)
        return image

def show_dataset_images(dataset, indices, ncols=3):
    """Plot ``dataset[idx]`` for each index in ``indices`` on a grid with ``ncols`` columns.

    Items may be ``(image, label)`` pairs, as returned by ``HEDataset``, or bare images,
    such as the list returned by ``degrade_all_images``. Tensor images are treated as
    ``(C, H, W)`` and permuted to ``(H, W, C)`` for Matplotlib.

    Args:
        dataset: anything indexable by the values in ``indices``.
        indices: the indices to display; as many rows are added as needed.
        ncols: number of panels per row.
    """
    indices = list(indices)
    if not indices:
        return
    nrows = (len(indices) + ncols - 1) // ncols  # ceil(len / ncols)
    plt.figure(figsize=(15, 5 * nrows))  # Adjust the size as needed
    for position, idx in enumerate(indices, start=1):
        item = dataset[idx]
        image = item[0] if isinstance(item, (tuple, list)) else item
        if isinstance(image, torch.Tensor):  # Check if the image needs to be converted from a tensor
            image = image.detach().cpu().permute(1, 2, 0).numpy()  # Adjust dimensions for Matplotlib
        plt.subplot(nrows, ncols, position)
        plt.imshow(image)
        plt.title(f"Index: {idx}")
        plt.axis('off')
    plt.show()

# Example usage: load the training split into memory, degrade a few images and preview them.
dataset = HEDataset('camelyonpatch_level_2_split_train_x.h5', '/content/pcam/camelyonpatch_level_2_split_train_y.h5', transform=None, crop=128)
n_preview = 12
# Degrade copies of only the images being shown. degrade_image writes into the array it
# is given, and degrade_all_images would touch (and overwrite) every image in dataset.images.
degraded_images = [dataset.degrade_image(dataset.images[i].copy()) for i in range(n_preview)]
# show_dataset_images takes (image, label) items. Using one column per image keeps every
# panel on a single row.
show_dataset_images(
    [(image, dataset.labels[i]) for i, image in enumerate(degraded_images)],
    range(n_preview),
    ncols=n_preview,
)


In [None]:
# The code below is based on the slideflow implementation of the DeepFocus algorithm (https://github.com/jamesdolezal/slideflow/blob/master/slideflow/slide/qc/deepfocus.py)

import torch
import torch.nn as nn
import torch.nn.functional as F

class DeepFocusV3(nn.Module):
    """PyTorch port of the DeepFocus v3 two-class classifier.

    Five conv + batch-norm blocks with three 2x2 max-pools, followed by three fully
    connected layers. ``forward`` returns softmax probabilities, not logits.

    Args:
        filters: output channels of the five conv layers.
        kernel_sizes: kernel size of each conv layer.
        fc: widths of the two hidden fully connected layers.
        input_size: side length of the square input images. It sizes the first fully
            connected layer; the default of 64 reproduces ``filters[4] * 8 * 8``.
    """

    def __init__(self, filters=(32, 32, 64, 128, 128), kernel_sizes=(5, 3, 3, 3, 3), fc=(128, 64), input_size=64):
        super(DeepFocusV3, self).__init__()
        self.filters = filters
        self.kernel_sizes = kernel_sizes
        self.input_size = input_size

        # Convolutional layers
        self.conv1 = nn.Conv2d(3, filters[0], kernel_size=kernel_sizes[0], padding='same')
        self.bn1 = nn.BatchNorm2d(filters[0])

        self.conv2 = nn.Conv2d(filters[0], filters[1], kernel_size=kernel_sizes[1], padding='same')
        self.bn2 = nn.BatchNorm2d(filters[1])

        # nn.MaxPool2d only accepts integer padding, so padding='same' fails at forward time.
        # ceil_mode=True gives the ceil(n / 2) output size of 'same' pooling with a 2x2, stride-2 window.
        self.conv3 = nn.Conv2d(filters[1], filters[2], kernel_size=kernel_sizes[2], padding='same')
        self.bn3 = nn.BatchNorm2d(filters[2])
        self.pool1 = nn.MaxPool2d(2, ceil_mode=True)

        self.conv4 = nn.Conv2d(filters[2], filters[3], kernel_size=kernel_sizes[3], padding='same')
        self.bn4 = nn.BatchNorm2d(filters[3])
        self.pool2 = nn.MaxPool2d(2, ceil_mode=True)

        self.conv5 = nn.Conv2d(filters[3], filters[4], kernel_size=kernel_sizes[4], padding='same')
        self.bn5 = nn.BatchNorm2d(filters[4])
        self.pool3 = nn.MaxPool2d(2, ceil_mode=True)

        # Spatial size after the three ceil-mode pools: one ceil-halving per pool.
        pooled = input_size
        for _ in range(3):
            pooled = -(-pooled // 2)

        # Fully connected layers
        self.fc1 = nn.Linear(filters[4] * pooled * pooled, fc[0])
        self.bn6 = nn.BatchNorm1d(fc[0])
        self.dropout1 = nn.Dropout(0.2)

        self.fc2 = nn.Linear(fc[0], fc[1])
        self.bn7 = nn.BatchNorm1d(fc[1])
        self.dropout2 = nn.Dropout(0.2)

        self.fc3 = nn.Linear(fc[1], 2)  # Output layer for binary classification

    def forward(self, x):
        """Return softmax probabilities of shape ``(batch, 2)`` for ``(batch, 3, H, W)`` images."""
        # Subtract each image's per-channel spatial mean
        x = x - torch.mean(x, dim=(2, 3), keepdim=True)

        # Convolutional blocks
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool1(x)

        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)

        x = F.relu(self.bn5(self.conv5(x)))
        x = self.pool3(x)

        # Flatten the output for the fully connected layers
        x = torch.flatten(x, 1)

        # Fully connected layers
        x = self.dropout1(F.relu(self.bn6(self.fc1(x))))
        x = self.dropout2(F.relu(self.bn7(self.fc2(x))))

        return F.softmax(self.fc3(x), dim=1)

# `device` is otherwise only a local variable inside train_network, so define it at
# module level for this cell and the ones after it to run on a fresh kernel.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the model and transfer it to the device.
# NOTE(review): no pretrained DeepFocus weights are loaded here, so predictions come from a
# randomly initialised network. Load a checkpoint (e.g. model.load_state_dict(...)) before
# trusting clarity scores.
model = DeepFocusV3()
model.to(device)
model.eval()  # Set the model to evaluation mode

# Function to predict clarity
def predict_clarity(dataloader, model, threshold = 0.75, exclude_indices=None, apply_softmax=False):
    """Return the dataset indices whose predicted 'clear' probability is below ``threshold``.

    Args:
        dataloader: yields ``(images, labels)`` batches. Indices are reconstructed from
            batch position, so the loader must not shuffle; a warning is issued if it
            uses a ``RandomSampler``.
        model: two-class model. Column 1 of its output is taken as 'clear'. By default
            the output is assumed to already be probabilities, as with ``DeepFocusV3``,
            whose ``forward`` ends in a softmax.
        threshold: minimum 'clear' probability for an image to be kept.
        exclude_indices: optional list to append the flagged indices to. A new list is
            created when omitted.
        apply_softmax: set to True for models that return logits instead of probabilities.

    Returns:
        The list of flagged dataset indices (``exclude_indices`` if one was passed).
    """
    import warnings
    from torch.utils.data import RandomSampler

    if exclude_indices is None:
        exclude_indices = []
    if isinstance(getattr(dataloader, "sampler", None), RandomSampler):
        warnings.warn(
            "predict_clarity: the dataloader shuffles, so returned indices will not match "
            "dataset positions; build it with shuffle=False.",
            stacklevel=2,
        )

    model.eval()
    # Run on whichever device holds the model's parameters; use CPU for parameter-free models.
    device = next((p.device for p in model.parameters()), torch.device("cpu"))

    offset = 0  # dataset index of the first image in the current batch
    with torch.no_grad():  # No need to track gradients
        for images, _labels in dataloader:
            outputs = model(images.to(device))
            # Do not re-apply softmax to probabilities. For two classes a second softmax caps
            # the result at 1 / (1 + e^-1) ≈ 0.73, so no image could ever pass a 0.75 threshold.
            probabilities = torch.softmax(outputs, dim=1) if apply_softmax else outputs
            clear_probs = probabilities[:, 1]  # Index 1 for 'clear'
            flagged = torch.nonzero(clear_probs < threshold).flatten()
            exclude_indices.extend((flagged + offset).tolist())
            offset += images.shape[0]
    return exclude_indices

# Predict clarity of PCAM images
# NOTE(review): `dataloader` is not defined in any cell above. The loaders created earlier are
# test_dataloader / train_dataloader / validation_dataloader, so on a fresh kernel this line
# raises NameError. Pick the intended split and pass a loader built with shuffle=False, because
# excluded indices are derived from batch position. Also check that the loader's image size
# matches what DeepFocusV3's first linear layer is sized for (64x64 inputs by default), whereas
# the datasets above use crop=32.
predict_clarity(dataloader, model)
