#### In this notebook we use a dictionary of extracted features from *main3.ipynb*, where width and height was divided by 5
* Here more data were passed from Test set to Train set, resulting in 10% increase in accuracy on WSI images (`80%`)

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import math

import random
import pickle
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [12]:
with open('/root/ubc_ocean/anar/extracted-features/main3_512px_resnet101_200.pkl', 'rb') as f:
    slide_features = pickle.load(f)
    
len(slide_features)

---------------------

### Model Training

In [28]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import defaultdict
import random

# Set a fixed random state for reproducibility
random_state = 33

# Convert slide_features to a suitable format
data = [(features['features'], features['label']) for path, features in slide_features.items()]

# Organize data by labels
data_by_label = defaultdict(list)
for features, label in data:
    data_by_label[label].append((features, label))

# Split data for each label into train, validation, and test
train_data = []
val_data = []
test_data = []

for label, label_data in data_by_label.items():
    # Split data for this label into train and test with a fixed random state
    train_val_label_data, test_label_data = train_test_split(label_data, test_size=0.10, random_state=random_state)
    
    # Split train data into train and validation with a fixed random state
    train_label_data, val_label_data = train_test_split(train_val_label_data, test_size=0.1, random_state=random_state)  # 0.25 x 0.8 = 0.2 of original
    
    # Append split data to respective sets
    train_data.extend(train_label_data)
    val_data.extend(val_label_data)
    test_data.extend(test_label_data)

# Shuffle the datasets
random.seed(random_state)
random.shuffle(train_data)
random.shuffle(val_data)
random.shuffle(test_data)

# Function to check balance in each set
def check_balance(dataset):
    label_counts = defaultdict(int)
    for _, label in dataset:
        label_counts[label] += 1
    return dict(label_counts)

# Display balance of each set
print("Train balance:", check_balance(train_data))
print("Validation balance:", check_balance(val_data))
print("Test balance:", check_balance(test_data))

Train balance: {'LGSC': 32, 'EC': 32, 'CC': 32, 'HGSC': 32, 'MC': 32}
Validation balance: {'HGSC': 4, 'MC': 4, 'EC': 4, 'LGSC': 4, 'CC': 4}
Test balance: {'MC': 4, 'EC': 4, 'LGSC': 4, 'CC': 4, 'HGSC': 4}


In [29]:
# Create a mapping from label strings to integers
unique_labels = sorted(set(label for _, label in data))
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}

class MILDataset(Dataset):
    def __init__(self, data, label_to_idx):
        self.data = data
        self.label_to_idx = label_to_idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature_vectors, label = self.data[idx]
        label_idx = self.label_to_idx[label]  # Convert label to integer
        return torch.tensor(feature_vectors), torch.tensor(label_idx, dtype=torch.float32)

# Create Datasets for train, validation, and test
train_dataset = MILDataset(train_data, label_to_idx)
val_dataset = MILDataset(val_data, label_to_idx)
test_dataset = MILDataset(test_data, label_to_idx)

# Create DataLoaders for each set
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [30]:
import torch.nn as nn
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import ReduceLROnPlateau

class AttentionMIL(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(AttentionMIL, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=0)
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, bag):
        h = torch.relu(self.fc1(bag))
        a = self.attention(h)
        v = torch.sum(a * h, dim=0)
        y = self.classifier(v)
        return y, a

# Number of unique classes
num_classes = len(unique_labels)

model = AttentionMIL(input_dim=2048, hidden_dim=256, num_classes=num_classes)
loss_function = nn.CrossEntropyLoss()  # CrossEntropyLoss for multiclass
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

# Early Stopping Parameters
best_val_loss = float('inf')
patience = 4
patience_counter = 0

# Model Training with Validation
num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    # Training loop
    for bags, labels in train_loader:
        optimizer.zero_grad()

        if bags.nelement() == 0:  # Check if the bag is empty
            continue

        #print("Bags shape:", bags.shape)  # Add this line for debugging
        #print("Labels shape:", labels.shape)  # Add this line for debugging

        # Select the first bag in the batch
        bags = bags[0]  # bags now has the shape [75, 2048]
        # For labels, the following line should work fine
        labels = labels.squeeze(0).long()
    
        output, _ = model(bags)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(output.data, 0)
        train_total += 1
        train_correct += (predicted == labels).sum().item()

    train_accuracy = 100 * train_correct / train_total
    train_loss /= len(train_loader)

    # Validation loop
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for bags, labels in val_loader:
            if bags.nelement() == 0:  # Check if the bag is empty
                continue

            bags = bags.squeeze(0)
            labels = labels.squeeze(0).long()
            output, _ = model(bags)
            loss = loss_function(output, labels)
            val_loss += loss.item()
            _, predicted = torch.max(output.data, 0)
            val_total += 1
            val_correct += (predicted == labels).sum().item()

    val_accuracy = 100 * val_correct / val_total
    val_loss /= len(val_loader)
    
    # Step the scheduler
    scheduler.step(val_loss)

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Stopping early due to no improvement in validation loss.")
        break

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Validation Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

Epoch 1/15, Train Loss: 1.6114, Train Acc: 18.24%, Validation Loss: 1.5640, Val Acc: 50.00%
Epoch 2/15, Train Loss: 1.5170, Train Acc: 33.33%, Validation Loss: 1.4260, Val Acc: 50.00%
Epoch 3/15, Train Loss: 1.3571, Train Acc: 40.25%, Validation Loss: 1.2833, Val Acc: 40.00%
Epoch 4/15, Train Loss: 1.1448, Train Acc: 61.01%, Validation Loss: 1.1208, Val Acc: 60.00%
Epoch 5/15, Train Loss: 0.9701, Train Acc: 65.41%, Validation Loss: 0.9738, Val Acc: 70.00%
Epoch 6/15, Train Loss: 0.7930, Train Acc: 73.58%, Validation Loss: 0.8247, Val Acc: 80.00%
Epoch 7/15, Train Loss: 0.6699, Train Acc: 77.99%, Validation Loss: 0.7731, Val Acc: 75.00%
Epoch 8/15, Train Loss: 0.5448, Train Acc: 82.39%, Validation Loss: 0.7732, Val Acc: 75.00%
Epoch 9/15, Train Loss: 0.4730, Train Acc: 87.42%, Validation Loss: 0.8031, Val Acc: 75.00%
Epoch 10/15, Train Loss: 0.4041, Train Acc: 88.05%, Validation Loss: 0.7058, Val Acc: 85.00%
Epoch 11/15, Train Loss: 0.3201, Train Acc: 94.34%, Validation Loss: 0.7301, Va

In [31]:
model.eval()
predictions = []
true_labels = []

with torch.no_grad():
    for bags, labels in test_loader:
        output, _ = model(bags.squeeze(0))
        _, predicted_labels = torch.max(output, 0)  # Get the index of the max log-probability
        predictions.append(predicted_labels.item())  # Append scalar value
        true_labels.append(labels.squeeze(0).item())  # Append scalar value

# Convert lists to arrays for metric calculation
predictions = np.array(predictions)
true_labels = np.array(true_labels)

# Calculate metrics
accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions, average='macro', zero_division=1)
recall = recall_score(true_labels, predictions, average='macro')
f1 = f1_score(true_labels, predictions, average='macro')

# Print the metrics
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')

Accuracy: 0.8000
Precision: 0.8476
Recall: 0.8000
F1 Score: 0.8026


In [32]:
idx_to_label = {idx: label for label, idx in label_to_idx.items()}

# Use idx_to_label to map numeric predictions back to label names
predicted_labels = [idx_to_label[int(idx)] for idx in predictions]
true_label_names = [idx_to_label[int(idx)] for idx in true_labels]

# Now predicted_labels and true_label_names contain the label names
print(predicted_labels)
print(true_label_names)

['MC', 'EC', 'MC', 'MC', 'LGSC', 'CC', 'LGSC', 'HGSC', 'HGSC', 'LGSC', 'CC', 'EC', 'HGSC', 'CC', 'HGSC', 'EC', 'HGSC', 'HGSC', 'LGSC', 'HGSC']
['MC', 'EC', 'MC', 'MC', 'LGSC', 'CC', 'LGSC', 'EC', 'EC', 'LGSC', 'CC', 'EC', 'HGSC', 'CC', 'HGSC', 'MC', 'HGSC', 'HGSC', 'LGSC', 'CC']


-----

### save the model and load in kaggle submission notebook

In [36]:
torch.save(model.state_dict(), 'models/main4_model.pth')

In [41]:
model2 = AttentionMIL(input_dim=2048, hidden_dim=256, num_classes=num_classes)
model2.load_state_dict(torch.load('/root/ubc_ocean/anar/models/main4_model.pth'))

<All keys matched successfully>