In [1]:
import os
import json
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from skimage.filters import threshold_otsu
from skimage.io import imread
from sklearn.cluster import MiniBatchKMeans
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import ARGVA, GCNConv
from torch_geometric.utils import from_scipy_sparse_matrix
from tqdm.notebook import tqdm

from mxifpublic.community.graph_utils import (count_mask_pixels_in_radius,
                                              threshold_graph_edges_by_distance,
                                              generate_graph_adj_matrix,
                                              calculate_edge_length_statistic)
from mxifpublic.plotting.plot import cell_typing_plot

# Constants specification

Community analysis requires a cell-type annotation for every segmented cell, supplied as a CSV file (`cell_type_segment_annotation.csv` below, configured via `path_cells`) with one row per cell and the following columns:

 index | sample_id | center_x | center_y | cell_type 
 --- | --- | --- | --- | --- 
 1 | HP9849 | 1200 | 2300 | B-cells 
 2 | HP9849 | 1212 | 2567 | B-cells 
 3 | HP9849 | 1232 | 3411 | T-helpers 
 ... | ... | ... | ... | ...
 3450022 | HP9976 | 3476 | 1012 | Macrophages 
 3450023 | HP9976 | 3490 | 2489 | B-cells 


Raw image data in Imaris (.ims) format are available from the IDR repository (https://idr.openmicroscopy.org) under accession number idr0158. To use the code below, first convert the .ims images to TIFF, saving each channel as a separate TIFF file. Place all of these files in a single folder (here `IBEX_raw_images`, configured via `path_masks`) and name them `{sample_id}_{marker}.tif`, where `marker` is one of the entries of `markers`. Each image is binarized with an Otsu threshold in the *Mask percentages calculation* step below.

In [3]:
samples = ['HP9849', 'HP9976']

path_cells = 'cell_type_segment_annotation.csv'
path_masks = 'IBEX_raw_images/'

random_state = 42
# Seed every RNG used later (Python `random`, NumPy, PyTorch) so that model
# weight initialisation, DataLoader shuffling and latent sampling are repeatable.
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

markers = [
    'CD21',
    'CD31',
    'CD68',
    # Specify all markers of interest
    # ...
]

# Radius (pixels) of the neighbourhood in which mask pixels are counted around each cell.
radius = 100
# Per-cell columns carried along as node metadata; the mask-feature cell appends
# one neighbourhood-count column per marker to this list.
node_fetures_to_add = ['center_x', 'center_y', 'cell_type', 'index']

# Edges longer than this are dropped when building the cell graph.
distance_threshold = 200
center_x_col_name, center_y_col_name = 'center_x', 'center_y'

# Names of the per-marker neighbourhood features used as node inputs.
features_columns = [
    f'nb_num_pixels_{radius}_{marker}' for marker in markers
]

# Data preparation

## Mask percentages calculation

In [None]:
cells = pd.read_csv(path_cells)

# Keep only the samples listed in `samples` (configuration cell).
cells = cells[cells['sample_id'].isin(samples)].reset_index(drop=True)
rois = []

for sample_id, sample_group in tqdm(cells.groupby('sample_id')):
    # Work on an explicit copy of the per-sample rows so that adding feature
    # columns below never writes into a slice of `cells` (SettingWithCopyWarning).
    sample_data = sample_group.copy()
    centers = sample_data[[center_x_col_name, center_y_col_name]].values

    for marker in tqdm(markers, leave=False):
        image = imread(os.path.join(path_masks, f'{sample_id}_{marker}.tif'))

        # It is recommended to verify whether the Otsu method provides an appropriate threshold value; if it does not, manually set the threshold.
        mask = (image > threshold_otsu(image)).astype('uint8')

        neighbour_pixels = count_mask_pixels_in_radius(mask, centers, radius=radius)

        # Must match the names generated in `features_columns` (configuration cell).
        feature_name = f'nb_num_pixels_{radius}_{marker}'
        sample_data[feature_name] = neighbour_pixels
        # This loop runs once per sample and the cell may be re-run: register each
        # feature name only once.
        if feature_name not in node_fetures_to_add:
            node_fetures_to_add.append(feature_name)
    rois.append(sample_data)

cells = pd.concat(rois).reset_index(drop=True)
cells['cell_index'] = cells['index']

## Assemble graph

In [None]:
# Build one distance-thresholded adjacency matrix per sample and collect the
# per-cell edge-length statistic returned by the helper (named "median" here).
patient_graphs = {}
cell_distances = []

for sample_id, roi_data in tqdm(cells.groupby('sample_id')):
    full_adjacency = generate_graph_adj_matrix(roi_data)
    adjacency_matrix = threshold_graph_edges_by_distance(full_adjacency, distance_threshold)
    median_edge_distance = calculate_edge_length_statistic(adjacency_matrix, distance_threshold)

    patient_graphs[sample_id] = (adjacency_matrix, median_edge_distance)
    cell_distances.append(median_edge_distance)

# Fit a single min-max scaler on the statistic pooled over all samples so the
# resulting node feature shares one scale across the cohort.
cell_distances = np.concatenate(cell_distances)
cell_distance_scaler = MinMaxScaler().fit(cell_distances.reshape(-1, 1))

# Replace each raw statistic with its scaled column vector of shape (n_cells, 1).
patient_graphs = {
    sample: (adjacency, cell_distance_scaler.transform(distances.reshape(-1, 1)))
    for sample, (adjacency, distances) in patient_graphs.items()
}

# Dataset generation

In [None]:
# Dense one-hot encoding of cell types. The keyword was renamed from `sparse` to
# `sparse_output` in scikit-learn 1.2 and `sparse` was removed in 1.4, so try
# the new name first and fall back for older installations.
try:
    cell_type_encoder = OneHotEncoder(sparse_output=False)
except TypeError:
    cell_type_encoder = OneHotEncoder(sparse=False)
cell_type_encoder.fit(cells['cell_type'].values.reshape(-1, 1))

cohort_graph_dataset = []

for sample_id, roi_data in tqdm(cells.groupby('sample_id')):
    # Scaled per-cell edge statistic, shape (n_cells, 1), from the graph cell.
    adjacency_matrix, median_edge_distance = patient_graphs[sample_id]
    edge_indices, _ = from_scipy_sparse_matrix(adjacency_matrix)
    cell_types = cell_type_encoder.transform(roi_data['cell_type'].values.reshape(-1, 1))

    mask_percentages = roi_data[features_columns].to_numpy()

    # Node features: [one-hot cell type | scaled edge statistic | per-marker counts].
    node_data = np.concatenate([cell_types, median_edge_distance, mask_percentages], axis=1)

    sample_data_object = Data(edge_index=edge_indices,
                              x=torch.tensor(node_data, dtype=torch.float32),
                              # Ids of THIS sample's cells, positionally aligned with
                              # the rows of `x`. Kept as a NumPy array (not a tensor)
                              # so PyG batching does not offset this "*_index" key.
                              contour_index=roi_data['cell_index'].to_numpy(),
                              sample_id=sample_id)

    cohort_graph_dataset.append(sample_data_object)

# Training

## Configuration

In [10]:
# Node feature vector = one-hot cell type + 1 scaled median edge distance
# + one neighbourhood pixel count per marker (see the dataset generation cell).
in_channels = len(cells['cell_type'].unique()) + len(features_columns) + 1
# Fail fast if the model input size disagrees with the features actually built.
assert all(graph.num_node_features == in_channels for graph in cohort_graph_dataset), (
    f'in_channels={in_channels} does not match the node features of cohort_graph_dataset'
)
model_encoder_configuration = [in_channels, 32, 32]  # encoder: [in, hidden, latent]
model_decoder_configuration = [32, 64, 32]  # discriminator: [in (= latent), hidden, out]
device = 'cpu'  # a CUDA device is used only if torch.cuda.is_available() (model cell)
num_epochs = 10
num_repeat_discriminator = 5  # discriminator updates per encoder update
discriminator_lr = 0.001
encoder_lr = 0.005
save_models_dir = 'models'
features_num = model_encoder_configuration[2]  # latent embedding size
n_communities = 15

## Model

In [11]:
class Encoder(torch.nn.Module):
    """Variational GCN encoder used by ARGVA.

    A shared GCN layer with ReLU feeds two GCN heads that produce the per-node
    mean and log standard deviation of the latent Gaussian.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        # Caching of the normalised adjacency is disabled because a different
        # graph (one per sample) is passed on each call.
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=False)
        self.conv_mu = GCNConv(hidden_channels, out_channels, cached=False)
        self.conv_logstd = GCNConv(hidden_channels, out_channels, cached=False)

    def forward(self, x, edge_index):
        """Return ``(mu, logstd)`` for the nodes of the graph ``(x, edge_index)``."""
        hidden = F.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(hidden, edge_index)
        logstd = self.conv_logstd(hidden, edge_index)
        return mu, logstd


class Discriminator(torch.nn.Module):
    """Three-layer MLP critic applied by ARGVA to latent embeddings.

    Two hidden layers use ReLU; the last layer is linear and returns raw logits.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin3 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        """Map ``x`` (last dim ``in_channels``) to logits (last dim ``out_channels``)."""
        for hidden_layer in (self.lin1, self.lin2):
            x = F.relu(hidden_layer(x))
        return self.lin3(x)


In [12]:
# Fall back to CPU when the configured device is not usable (no CUDA available).
device = torch.device(device if torch.cuda.is_available() else 'cpu')

encoder = Encoder(*model_encoder_configuration)
discriminator = Discriminator(*model_decoder_configuration)
model = ARGVA(encoder, discriminator).to(device)
model.train();

## Train loop

In [None]:
os.makedirs(save_models_dir, exist_ok=True)
discriminator_optimizer = torch.optim.Adam(model.discriminator.parameters(), lr=discriminator_lr)
encoder_optimizer = torch.optim.Adam(model.encoder.parameters(), lr=encoder_lr)
# Any finite loss improves on this, so the first finite epoch always writes a checkpoint.
best_loss = float('inf')

# One sample graph per batch; the order of samples is shuffled every epoch.
dataloader = DataLoader(cohort_graph_dataset, batch_size=1, shuffle=True)
loss_history = []
for epoch in tqdm(range(num_epochs)):
    # The inference cell puts the model in eval mode. Re-enable training mode every
    # epoch so re-running this cell still samples the latent code (PyG's VGAE/ARGVA
    # only adds reparametrisation noise while `model.training` is True).
    model.train()
    batch_losses = []
    for data_train in dataloader:
        encoder_optimizer.zero_grad()
        data_train = data_train.to(device)

        z = model.encode(data_train.x, data_train.edge_index)

        # Several discriminator updates per encoder update.
        for _ in range(num_repeat_discriminator):
            discriminator_optimizer.zero_grad()
            discriminator_loss = model.discriminator_loss(z)
            discriminator_loss.backward()
            discriminator_optimizer.step()

        # Encoder objective: edge reconstruction + adversarial regularisation
        # + KL term scaled by 1 / num_nodes.
        loss = model.recon_loss(z, data_train.edge_index)
        loss = loss + model.reg_loss(z)
        loss = loss + (1 / data_train.num_nodes) * model.kl_loss()

        loss.backward()
        encoder_optimizer.step()

        batch_losses.append(loss.item())

    mean_epoch_loss = float(np.mean(batch_losses))
    loss_history.append(mean_epoch_loss)
    print(f'Epoch: {epoch} Mean epoch loss: {mean_epoch_loss}')

    if mean_epoch_loss < best_loss:
        best_loss = mean_epoch_loss
        # The inference cell reloads this file by formatting `best_loss` the same way.
        torch.save(model.state_dict(), f'{save_models_dir}/community_model_{mean_epoch_loss:.4f}.pth')

## Loss history

In [None]:
# Mean training loss per epoch, as recorded by the train loop.
fig, ax = plt.subplots(figsize=(9, 9))
ax.plot(loss_history, color='black')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss history');

## Inference

In [15]:
model = model.to(device)
# Restore the best checkpoint written by the train loop (same file-name format).
model.load_state_dict(torch.load(f'{save_models_dir}/community_model_{best_loss:.4f}.pth',
                                 map_location=device))
# Eval mode: ARGVA.encode returns the latent mean instead of a random sample.
model.eval()

cohort_embeddings = []
for sample_data_object in tqdm(cohort_graph_dataset):
    with torch.no_grad():
        z = model.encode(sample_data_object.x.to(device),
                         sample_data_object.edge_index.to(device))

    # `z` is (num_nodes, latent_dim). It is not squeezed: squeezing would collapse a
    # single-cell graph to 1-D and break the one-row-per-cell layout below.
    node_embeddings = pd.DataFrame(z.cpu().numpy())
    node_embeddings['sample_id'] = sample_data_object.sample_id
    node_embeddings['cell_index'] = sample_data_object.contour_index
    cohort_embeddings.append(node_embeddings)

cohort_embeddings = pd.concat(cohort_embeddings).reset_index(drop=True)

  0%|          | 0/2 [00:00<?, ?it/s]

## Cluster embedding

In [16]:
# Partition the pooled latent embeddings (columns 0..features_num-1) into
# `n_communities` clusters; a cell's cluster label is its community.
latent_columns = list(range(features_num))
clusterer = MiniBatchKMeans(n_clusters=n_communities, random_state=random_state)
cluster_labels = clusterer.fit_predict(cohort_embeddings[latent_columns].values)
embedding_clusters = 'cluster_' + pd.Series(cluster_labels, name='graph_cluster').astype(str)
cohort_embeddings['graph_cluster'] = embedding_clusters

## Relative content

In [None]:
colors = [
 '#11112e',
 '#0ca7ef',
 '#1b1bc0',
 '#041e72',
 '#5A00FF',
 '#CBD855',
 '#3BA31B',
 '#005222',
 '#1E82AF',
 '#00A38B',
 '#8BF16A',
 '#cc751f',
 '#f6bd60',
 '#5C493E',
 '#b196b3'
]
# One distinct colour per community is required; extend `colors` for more clusters.
assert n_communities <= len(colors), (
    f'n_communities={n_communities} exceeds the {len(colors)} available colours'
)
# Use a dedicated RNG seeded from `random_state` so the colour assignment is
# identical on every run (the global `random` state depends on earlier cells).
colors = random.Random(random_state).sample(colors, n_communities)
palette = {f'cluster_{i}': color for i, color in enumerate(colors)}

# Cells per (sample, community), converted to per-sample percentages.
pivot = pd.pivot_table(cohort_embeddings,
                       index='sample_id',
                       columns='graph_cluster',
                       values='cell_index',
                       aggfunc=len,
                       fill_value=0)
pivot = pivot.div(pivot.sum(axis=1), axis=0) * 100

fig, ax = plt.subplots(1, 1, figsize=(4, 8))
pivot.plot.bar(stacked=True, legend=False,
               ax=ax, color=palette)
ax.set_xlabel('Sample')
ax.set_ylabel('Community relative content');