-----------------

## In this notebook we repeat the same feature extraction process
#### `only difference (WSI image width and heights divided by 2)`

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import math
import pickle

from sklearn.model_selection import train_test_split

In [2]:
training_data = pd.read_csv('/root/ubc/train.csv')

In [3]:
# Select WSI images
non_tma_data_simple = training_data[training_data['is_tma'] == False]

# Sampling 40 images from each cancer type
balanced_data = non_tma_data_simple.groupby('label').sample(n=40, random_state=33)

# Adding a new column 'path' to the balanced dataset
balanced_data['path'] = balanced_data['image_id'].astype(str) + '.png'

balanced_data.reset_index(drop=True, inplace=True)

# Displaying the updated DataFrame
balanced_data

Unnamed: 0,image_id,label,image_width,image_height,is_tma,path
0,5970,CC,27265,22900,False,5970.png
1,64824,CC,46589,19365,False,64824.png
2,1952,CC,33685,38053,False,1952.png
3,59515,CC,64700,36387,False,59515.png
4,54928,CC,36166,31487,False,54928.png
...,...,...,...,...,...,...
195,48550,MC,32431,25393,False,48550.png
196,39252,MC,48980,40700,False,39252.png
197,47431,MC,67495,46563,False,47431.png
198,65094,MC,55042,45080,False,65094.png


## Feature Extraction
additional: `img = img.resize((img.width // 2, img.height // 2))`

In [6]:
import cv2
import timm
import torch
from torchvision import transforms
from PIL import Image, ImageFile
import zipfile
import os

# Assuming the ZIP file path and the root directory inside the ZIP
zip_path = '/root/UBC-OCEAN.zip'
zip_root_dir = 'train_images/'

# Increase the maximum number of pixels PIL can process
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Load a pre-trained model for feature extraction
model = timm.create_model('resnet101', pretrained=True, num_classes=0)
model.eval()

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to check if the tile has tissue present
def is_tissue_present(tile, area_threshold=0.55, low_saturation_threshold=20):   # 55% tissue
    hsv = cv2.cvtColor(tile, cv2.COLOR_RGB2HSV)
    h, s, v = cv2.split(hsv)
    _, high_sat = cv2.threshold(s, low_saturation_threshold, 255, cv2.THRESH_BINARY)
    kernel = np.ones((5, 5), np.uint8)
    tissue_mask = cv2.dilate(high_sat, kernel, iterations=2)
    tissue_mask = cv2.erode(tissue_mask, kernel, iterations=2)
    tissue_ratio = np.sum(tissue_mask > 0) / (tile_size * tile_size)
    return tissue_ratio > area_threshold

# Function to extract features from a tile
def extract_features(tile, model, transform):
    tile = Image.fromarray(tile)
    tile = transform(tile).unsqueeze(0)
    with torch.no_grad():
        features = model(tile)
    return features.squeeze(0).numpy()

# Function to process a patch of the image
def process_patch(patch, model, transform):
    if is_tissue_present(patch):
        features = extract_features(patch, model, transform)
        return features
    return None

# Define the size for the tiles
tile_size = 512

#slide_features = {}

# Process each image, extract tiles, extract features, and store them
total_images = len(balanced_data)
for index, row in balanced_data.iterrows():
    #if index >= 173:
        # Stop after processing the first 50 images
    #   continue

    tile_features = [] # List to hold the features for the current image

    # Extracting image from ZIP
    image_name = row['path']  # Adjust based on your DataFrame structure
    image_path = os.path.join(zip_root_dir, image_name)

    # Print the current status
    print(f"Processing image {index + 1} of {total_images}: {image_path}")
    
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extract(image_path, '/root/ubc_ocean/temp')
            extracted_image_path = os.path.join('/root/ubc_ocean/temp', image_path)

            with Image.open(extracted_image_path) as img:
                img = img.resize((img.width // 2, img.height // 2))

                for y in range(0, img.height, tile_size):
                    for x in range(0, img.width, tile_size):
                        # Read the patch
                        patch = img.crop((x, y, x + tile_size, y + tile_size))
                        patch = np.array(patch)  # Convert PIL Image to NumPy array

                        # Process the patch
                        features = process_patch(patch, model, transform)
                        if features is not None:
                            tile_features.append(features)

            # Delete the extracted image to save space
            os.remove(extracted_image_path)

        # Store the extracted features and the label in the slide_features_part1 dictionary
        slide_features[image_name] = {
            'features': tile_features,
            'label': row['label']
        }
    except Exception as e:
        print(f"Error processing image {image_name}: {e}")
    
    import pickle

    with open('/root/ubc_ocean/anar/extracted-features/half_512px_resnet101_200.pkl', 'wb') as f:
        pickle.dump(slide_features, f)

Processing image 175 of 200: train_images/62476.png
Processing image 176 of 200: train_images/51893.png
Processing image 177 of 200: train_images/39872.png
Processing image 178 of 200: train_images/14532.png
Processing image 179 of 200: train_images/56799.png
Processing image 180 of 200: train_images/34508.png
Processing image 181 of 200: train_images/23523.png
Processing image 182 of 200: train_images/28562.png
Processing image 183 of 200: train_images/9254.png
Processing image 184 of 200: train_images/35792.png
Processing image 185 of 200: train_images/56993.png
Processing image 186 of 200: train_images/20329.png
Processing image 187 of 200: train_images/5456.png
Processing image 188 of 200: train_images/38019.png
Processing image 189 of 200: train_images/36678.png
Processing image 190 of 200: train_images/25331.png
Processing image 191 of 200: train_images/49587.png
Processing image 192 of 200: train_images/37190.png
Processing image 193 of 200: train_images/30986.png
Processing ima

In [None]:
with open('/root/ubc_ocean/anar/extracted-features/half_512px_resnet101_200.pkl', 'wb') as f:
    pickle.dump(slide_features, f)

len(slide_features)

---------------------

## Model Training

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import defaultdict
import random

# Set a fixed random state for reproducibility
random_state = 33

# Convert slide_features to a suitable format
data = [(features['features'], features['label']) for path, features in slide_features.items()]

# Organize data by labels
data_by_label = defaultdict(list)
for features, label in data:
    data_by_label[label].append((features, label))

# Split data for each label into train, validation, and test
train_data = []
val_data = []
test_data = []

for label, label_data in data_by_label.items():
    # Split data for this label into train and test with a fixed random state
    train_val_label_data, test_label_data = train_test_split(label_data, test_size=0.2, random_state=random_state)
    
    # Split train data into train and validation with a fixed random state
    train_label_data, val_label_data = train_test_split(train_val_label_data, test_size=0.125, random_state=random_state)  # 0.25 x 0.8 = 0.2 of original
    
    # Append split data to respective sets
    train_data.extend(train_label_data)
    val_data.extend(val_label_data)
    test_data.extend(test_label_data)

# Shuffle the datasets
random.seed(random_state)
random.shuffle(train_data)
random.shuffle(val_data)
random.shuffle(test_data)

# Function to check balance in each set
def check_balance(dataset):
    label_counts = defaultdict(int)
    for _, label in dataset:
        label_counts[label] += 1
    return dict(label_counts)

# Display balance of each set
print("Train balance:", check_balance(train_data))
print("Validation balance:", check_balance(val_data))
print("Test balance:", check_balance(test_data))

Train balance: {'CC': 28, 'EC': 28, 'LGSC': 28, 'HGSC': 28, 'MC': 28}
Validation balance: {'HGSC': 4, 'CC': 4, 'LGSC': 4, 'MC': 4, 'EC': 4}
Test balance: {'MC': 8, 'HGSC': 8, 'LGSC': 8, 'EC': 8, 'CC': 8}


In [9]:
# Create a mapping from label strings to integers
unique_labels = sorted(set(label for _, label in data))
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}

class MILDataset(Dataset):
    def __init__(self, data, label_to_idx):
        self.data = data
        self.label_to_idx = label_to_idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature_vectors, label = self.data[idx]
        label_idx = self.label_to_idx[label]  # Convert label to integer
        return torch.tensor(feature_vectors), torch.tensor(label_idx, dtype=torch.float32)

# Create Datasets for train, validation, and test
train_dataset = MILDataset(train_data, label_to_idx)
val_dataset = MILDataset(val_data, label_to_idx)
test_dataset = MILDataset(test_data, label_to_idx)

# Create DataLoaders for each set
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [10]:
import torch.nn as nn
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import ReduceLROnPlateau

class AttentionMIL(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(AttentionMIL, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=0)
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)  # num_classes instead of 1

    def forward(self, bag):
        h = torch.relu(self.fc1(bag))
        a = self.attention(h)
        v = torch.sum(a * h, dim=0)
        y = self.classifier(v)  # Remove softmax here; output raw scores
        return y, a

# Number of unique classes
num_classes = len(unique_labels)

model = AttentionMIL(input_dim=2048, hidden_dim=256, num_classes=num_classes)
loss_function = nn.CrossEntropyLoss()  # CrossEntropyLoss for multiclass
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

# Early Stopping Parameters
best_val_loss = float('inf')
patience = 4
patience_counter = 0

# Model Training with Validation
num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    # Training loop
    for bags, labels in train_loader:
        optimizer.zero_grad()
        bags = bags.squeeze(0)  # Remove the extra dimension from bags
        labels = labels.squeeze(0).long()  # Remove extra dimension and ensure long type for labels
        output, _ = model(bags)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(output.data, 0)
        train_total += 1
        train_correct += (predicted == labels).sum().item()

    train_accuracy = 100 * train_correct / train_total
    train_loss /= len(train_loader)

    # Validation loop
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for bags, labels in val_loader:
            bags = bags.squeeze(0)
            labels = labels.squeeze(0).long()
            output, _ = model(bags)
            loss = loss_function(output, labels)
            val_loss += loss.item()
            _, predicted = torch.max(output.data, 0)
            val_total += 1
            val_correct += (predicted == labels).sum().item()

    val_accuracy = 100 * val_correct / val_total
    val_loss /= len(val_loader)
    
    # Step the scheduler
    scheduler.step(val_loss)

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Stopping early due to no improvement in validation loss.")
        break

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Validation Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

  return torch.tensor(feature_vectors), torch.tensor(label_idx, dtype=torch.float32)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 1/15, Train Loss: 1.6396, Train Acc: 20.71%, Validation Loss: 1.6049, Val Acc: 20.00%
Epoch 2/15, Train Loss: 1.5628, Train Acc: 25.00%, Validation Loss: 1.5783, Val Acc: 25.00%
Epoch 3/15, Train Loss: 1.4579, Train Acc: 37.86%, Validation Loss: 1.5405, Val Acc: 30.00%
Epoch 4/15, Train Loss: 1.2339, Train Acc: 52.14%, Validation Loss: 1.5252, Val Acc: 35.00%
Epoch 5/15, Train Loss: 1.1075, Train Acc: 54.29%, Validation Loss: 1.3934, Val Acc: 45.00%
Epoch 6/15, Train Loss: 0.9177, Train Acc: 62.86%, Validation Loss: 1.3636, Val Acc: 40.00%
Epoch 7/15, Train Loss: 0.7558, Train Acc: 73.57%, Validation Loss: 1.3521, Val Acc: 35.00%
Epoch 8/15, Train Loss: 0.6362, Train Acc: 75.00%, Validation Loss: 1.2380, Val Acc: 55.00%
Epoch 9/15, Train Loss: 0.5334, Train Acc: 80.71%, Validation Loss: 1.4628, Val Acc: 45.00%
Epoch 10/15, Train Loss: 0.4399, Train Acc: 85.00%, Validation Loss: 1.4080, Val Acc: 50.00%
Epoch 00011: reducing learning rate of group 0 to 1.0000e-04.
Epoch 11/15, Trai

In [11]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

model.eval()
predictions = []
true_labels = []

with torch.no_grad():
    for bags, labels in test_loader:
        output, _ = model(bags.squeeze(0))
        _, predicted_labels = torch.max(output, 0)  # Get the index of the max log-probability
        predictions.append(predicted_labels.item())  # Append scalar value
        true_labels.append(labels.squeeze(0).item())  # Append scalar value

# Convert lists to arrays for metric calculation
predictions = np.array(predictions)
true_labels = np.array(true_labels)

# Calculate metrics
accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions, average='macro', zero_division=1)
recall = recall_score(true_labels, predictions, average='macro')
f1 = f1_score(true_labels, predictions, average='macro')

# Print the metrics
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')

Accuracy: 0.7000
Precision: 0.7062
Recall: 0.7000
F1 Score: 0.6948


In [12]:
idx_to_label = {idx: label for label, idx in label_to_idx.items()}

# Use idx_to_label to map numeric predictions back to label names
predicted_labels = [idx_to_label[int(idx)] for idx in predictions]
true_label_names = [idx_to_label[int(idx)] for idx in true_labels]

# Now predicted_labels and true_label_names contain the label names
print(predicted_labels)
print(true_label_names)

['MC', 'LGSC', 'EC', 'LGSC', 'LGSC', 'HGSC', 'HGSC', 'LGSC', 'HGSC', 'HGSC', 'EC', 'EC', 'LGSC', 'LGSC', 'MC', 'LGSC', 'CC', 'HGSC', 'EC', 'CC', 'CC', 'HGSC', 'CC', 'LGSC', 'MC', 'LGSC', 'MC', 'EC', 'HGSC', 'MC', 'MC', 'CC', 'CC', 'LGSC', 'MC', 'EC', 'HGSC', 'MC', 'LGSC', 'EC']
['MC', 'HGSC', 'HGSC', 'LGSC', 'LGSC', 'HGSC', 'EC', 'LGSC', 'HGSC', 'HGSC', 'EC', 'MC', 'LGSC', 'LGSC', 'EC', 'EC', 'CC', 'EC', 'EC', 'CC', 'CC', 'HGSC', 'CC', 'HGSC', 'MC', 'LGSC', 'MC', 'EC', 'CC', 'MC', 'MC', 'CC', 'CC', 'LGSC', 'EC', 'CC', 'HGSC', 'MC', 'LGSC', 'MC']
