# Part 1: Setup

In [1]:
# Imports

# Setup environment correctly first: these variables are set before torch is
# imported so that CUDA_VISIBLE_DEVICES is in place before CUDA initialises.
import os 
# Go to location where UniverSeg data is: DATAPATH is a colon-separated list of
# directories (presumably searched in order by the data loaders).
os.environ['DATAPATH'] = ':'.join((
       '/storage',
       '/storage/megamedical'
))
# Expose only physical GPU 0 to this process; it appears as 'cuda:0' below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# UniverSeg Imports
from universeg.experiment.experiment import UniversegExperiment

# Torch imports
import torch
from torch import nn

# pylot imports
from pylot.analysis import ResultsLoader

# Misc imports
import numpy as np
import copy
import seaborn as sns
import matplotlib.pyplot as plt
from dataclasses import dataclass
from pydantic import validate_arguments

# Run on the (single visible) GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
# Sanity check that a GPU is visible: later cells cache the dataset on CUDA.
torch.cuda.is_available()

True

In [3]:
from universeg.experiment.results import load_configs

# Load the configs of the available experiment runs; the next cell filters them
# by hyperparameter with `configs.select(...)`.
configs = load_configs()

  0%|          | 0/9 [00:00<?, ?it/s]

For now, load the model trained with support size 8, so we get a good idea of what subsampling a larger dataset down to a smaller support set looks like.

In [4]:
# Select the run(s) trained with support size 8 and keep the results directory of
# the first match only.
# NOTE(review): assumes at least one run matches; `.path.values[0]` fails otherwise.
small_support_model = configs.select(support_size=8)
model = small_support_model.path.values[0]

In [5]:
# Results directory of the selected run (used as the experiment path below).
model

PosixPath('/storage/jjgo/results/omni/2023-03-01_Universeg-SupportSet-Ablation_pt2/20230304_125802-NH10-26180efa296261c9502bf25a551378e5')

# Part 2: Base Random Evaluation

In [6]:
from universeg.experiment.analysis import compute_dice_breakdown

# Baseline: per-example Dice of the selected model on the WBC validation split,
# using support sets of size 8, "midslice" slicing, no augmentation and a single
# prediction per query. This is the random-support reference the learned sampler
# is meant to improve on.
df = compute_dice_breakdown(
                model,
                datasets=["WBC"],
                support_size=8,
                split="val",
                checkpoint="max-val_od-dice_score",
                augmentations=None,
                slicing="midslice",
                n_predictions=1,
                preload_cuda=False
            )

  warn("Intel MKL extensions not available for SciPy")
  warn("libjpeg_turbo not enabled for Pillow")


Loaded checkpoint with tag:max-val_od-dice_score. Last epoch:4430


In [7]:
# Mean Dice of the random-support baseline.
df.dice_score.mean()

0.85023385

# Part 3: DiSeg Hyper-Net (Sampler)

First, let's define a sampler that, given an (image, label) pair, tries to predict the indices of the best support set.

In [8]:
import torchvision.models as models
from torchvision.models import ResNet18_Weights, ResNet50_Weights, ResNet101_Weights

@validate_arguments
@dataclass(eq=False, repr=False)
class SupportSampler(nn.Module):
    """Predict logits over candidate support-set indices for a query (image, label) pair.

    A torchvision ResNet backbone maps the query to ``dset_size`` logits, one per
    candidate example. A support set can then be sampled from these logits
    (e.g. with Gumbel-softmax).

    Fields:
        dset_size: Number of candidate examples; output size of the replaced ``fc``.
        model_type: Backbone name: ``"resnet18"``, ``"resnet50"`` or ``"resnet101"``.
        pretrained: If True, load torchvision's ``DEFAULT`` weights for the backbone.
        freeze_backbone: If True, set ``requires_grad=False`` on the backbone's
            original parameters. The replacement ``fc`` layer and the ``head_conv``
            adapter are created afterwards and therefore stay trainable.
        use_label: If True, ``forward`` requires the label and fuses it with the
            image through ``head_conv``; otherwise the image is repeated to 3 channels.
    """

    dset_size: int
    model_type: str = "resnet18"
    pretrained: bool = True
    freeze_backbone: bool = False
    use_label: bool = True

    def __post_init__(self):
        # nn.Module's internal state must exist before submodules are assigned; the
        # dataclass-generated __init__ has only set the plain fields at this point.
        super().__init__()

        # name -> (builder, weights enum). torchvision's builders expect a *member*
        # of the weights enum (e.g. ResNet18_Weights.DEFAULT); passing the enum
        # class itself is rejected with a TypeError.
        backbones = {
            "resnet18": (models.resnet18, ResNet18_Weights),
            "resnet50": (models.resnet50, ResNet50_Weights),
            "resnet101": (models.resnet101, ResNet101_Weights),
        }
        if self.model_type not in backbones:
            raise ValueError(f"Unsupported ResNet model type: {self.model_type}")
        builder, weights_enum = backbones[self.model_type]
        self.model = builder(weights=weights_enum.DEFAULT if self.pretrained else None)

        if self.freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False

        # Adapter from the 2-channel [image, label] input to the 3 channels the
        # ResNet stem expects.
        self.head_conv = nn.Conv2d(2, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        # Replace the classifier so the output has one logit per candidate example.
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, self.dset_size)

    def forward(self, x, y=None):
        """Return logits of shape ``(batch, dset_size)`` over candidate indices.

        Args:
            x: Query image batch, expected to be single-channel (it is concatenated
               with ``y`` into the 2-channel ``head_conv`` input, or repeated to 3
               channels when ``use_label`` is False).
            y: Query label batch; required when ``use_label`` is True, otherwise ignored.
        """
        assert not (self.use_label and y is None), "If using label, must provide label"

        if self.use_label:
            z_in = self.relu(self.head_conv(torch.cat([x, y], dim=1)))
        else:
            z_in = x.repeat(1, 3, 1, 1)

        return self.model(z_in)

Now we need to be able to get some datasets which we can use to train this sampler.

In [9]:
from universeg.experiment.datasets import MultiBinarySegment2DIndex, Segment2D

def get_datasets(datasets, slicing="midslice", split="val", resolution=128, version="v4.2"):
    """Build one preloaded ``Segment2D`` dataset per row of the filtered task table.

    Args:
        datasets: Dataset names to select from the index (e.g. ``["WBC"]``).
        slicing: Slicing strategy used both to filter the index and for each dataset.
        split: Data split to load (e.g. ``"val"``).
        resolution: Resolution used to filter the index (default 128, as before).
        version: Index version used to filter the index (default ``"v4.2"``, as before).

    Returns:
        A list of ``Segment2D`` datasets, one per row of the task table returned by
        ``MultiBinarySegment2DIndex.task_df`` (empty if no rows match).
    """
    index = MultiBinarySegment2DIndex()
    task_df = index.task_df(
        slicing=slicing,
        datasets=datasets,
        resolution=resolution,
        version=version,
        expand_labels=True,
    )

    # Per-row fields forwarded verbatim from the task table to Segment2D.
    copy_keys = ("task", "label", "resolution", "slicing", "version")
    return [
        Segment2D(
            split=split,
            min_label_density=0,
            preload=True,
            **{k: row[k] for k in copy_keys},
        )
        for _, row in task_df.iterrows()
    ]

In [10]:
import torch
from torch.utils.data import Dataset
from pylot.util import to_device


class ModifiedCUDACachedDataset(Dataset):
    """Cache a whole dataset on the GPU as two stacked tensors.

    Unlike a per-item cache, all examples are stacked into a single ``image_block``
    and a single ``labels_block`` tensor, with a new leading axis indexing the
    examples. Code can then treat the dataset as one object, e.g. select a support
    set with a weighted sum over the first axis. Attributes not found on this
    wrapper are looked up on the wrapped dataset.

    Args:
        dataset: Iterable dataset yielding ``(image, label)`` tensor pairs whose
            shapes are consistent (required by ``torch.stack``).

    Raises:
        RuntimeError: If CUDA is not available.
        ValueError: If ``dataset`` yields no examples.
    """

    def __init__(self, dataset: Dataset):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError("ModifiedCUDACachedDataset requires CUDA, but it is not available")
        self._dataset = dataset
        # Collect every example, then stack once: the difference from a per-item
        # cache is that the dataset is treated as one object (one tensor per modality).
        images, labels = [], []
        for image, label in self._dataset:
            images.append(image)
            labels.append(label)
        if not images:
            raise ValueError("Cannot cache an empty dataset")
        self.image_block = to_device(torch.stack(images), "cuda")
        self.labels_block = to_device(torch.stack(labels), "cuda")

    def __getitem__(self, idx):
        # Get a particular index if you want, but not the intended use.
        return self.image_block[idx, ...], self.labels_block[idx, ...]

    def __getattr__(self, key):
        # Only called as a last resort, after normal attribute lookup has failed:
        # https://stackoverflow.com/questions/2405590/how-do-i-override-getattr-without-breaking-the-default-behavior
        # Special-method lookups (e.g. __getstate__/__setstate__ probed by copy and
        # pickle) concern this wrapper, not the inner dataset, so they are not forwarded.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        # Read `_dataset` from __dict__ directly: on an instance created without
        # __init__ (copy/pickle use cls.__new__), `self._dataset` would re-enter
        # __getattr__ and recurse until RecursionError.
        try:
            dataset = self.__dict__["_dataset"]
        except KeyError:
            raise AttributeError(key) from None
        return getattr(dataset, key)

    def __len__(self):
        return len(self.image_block)

Getting closer! Now we need to load the in-context learning model and then run the meta-training procedure on it: first, initialize the model and the sampler; second, run the training.

In [11]:
# load the in-context learning model and freeze the parameters
ic_experiment = UniversegExperiment(model)
ic_experiment.load('max-val_od-dice_score')
ic_experiment.to_device()
ic_model = ic_experiment.model
for _, p in ic_model.named_parameters():
   p.requires_grad = False

# load the dataset
datasets = get_datasets(["WBC"])
dset = ModifiedCUDACachedDataset(datasets[0]) # temporary set

print()

Loaded checkpoint with tag:max-val_od-dice_score. Last epoch:4430



  warn("Intel MKL extensions not available for SciPy")
  warn("libjpeg_turbo not enabled for Pillow")


One major issue with this formulation is how to backpropagate through our sampling and selection steps. To address this, we can use two solutions that are simple to start with and give us some flexibility going forward. For sampling, we can use the Gumbel-softmax to draw samples from our categorical distribution. Selection is a bit trickier, but with some matrix-multiplication tricks it isn't terribly difficult either.


In [12]:
import neurite as ne
from torch.nn.functional import gumbel_softmax

def select_support(indice_probs, dataset, temperature=0.1, support_size=8):
    """Differentiably sample a support set of ``support_size`` examples from ``dataset``.

    Draws ``support_size`` independent hard Gumbel-softmax samples from the
    categorical distribution defined by the logits. Each sample is a one-hot row
    with straight-through gradients. The matching examples are then gathered with a
    single contraction over the dataset axis, so gradients flow back into
    ``indice_probs``. Sampling is with replacement, so an example may repeat.

    Args:
        indice_probs: Unnormalised logits of shape ``(1, N)`` over the ``N`` dataset
            examples (``gumbel_softmax`` treats them as logits despite the name).
        dataset: Object exposing ``image_block`` and ``labels_block`` tensors whose
            first axis has length ``N``; assumed ``(N, C, H, W)`` — TODO confirm.
        temperature: Gumbel-softmax temperature ``tau``.
        support_size: Number of examples to draw.

    Returns:
        ``(x_sample, y_sample)``, each of shape ``(1, support_size, C, H, W)``.
    """
    assert indice_probs.shape[0] == 1, "Only batch size of 1 is supported."
    logits = indice_probs.squeeze(0)  # Remove the batch dimension -> (N,)

    # (support_size, N): every row gets its own Gumbel noise, so this is equivalent
    # to drawing the samples one at a time in a loop.
    one_hot = gumbel_softmax(
        logits.unsqueeze(0).repeat(support_size, 1), tau=temperature, hard=True, dim=-1
    )

    def _gather(block):
        # (S, N) x (N, ...) -> (S, ...). Unlike broadcasting against (N, S, 1, 1),
        # this works for any channel count and never materialises an (N, S, H, W)
        # intermediate. Dtypes are promoted the same way `*` would promote them.
        dtype = torch.promote_types(one_hot.dtype, block.dtype)
        return torch.einsum("sn,n...->s...", one_hot.to(dtype), block.to(dtype))

    # Add the batch dimension and return.
    return _gather(dataset.image_block)[None], _gather(dataset.labels_block)[None]

KeyboardInterrupt: 

Tada! Now we have a fully differentiable way to both sample and select our desired support set. Now let's describe our meta training procedure.

In [None]:
from pylot.metrics.segmentation import soft_dice_score 
from tqdm.notebook import tqdm


def meta_train(IC_Model, SupportSampler, dataset, support_size, bsize=1, lr=1e-3, iterations=100, temperature=0.1):
    """Train a support sampler against a frozen in-context segmentation model.

    Each iteration does the following:
      1. Draws one random query (image, label) pair from ``dataset``.
      2. Asks the sampler for logits over the dataset's indices.
      3. Draws ``bsize`` support sets of ``support_size`` examples with
         Gumbel-softmax.
      4. Takes an optimiser step on the sampler only, so that the in-context
         model's soft Dice on the query increases.

    Args:
        IC_Model: Frozen in-context model, called as
            ``IC_Model(support_images, support_labels, query_images)``; its output is
            treated as logits.
        SupportSampler: The sampler module to train (updated in place).
        dataset: Cached dataset exposing ``image_block``/``labels_block``, ``len`` and
            integer indexing.
        support_size: Number of examples per sampled support set.
        bsize: Number of independent support sets drawn per query (must be >= 1).
        lr: Adam learning rate.
        iterations: Number of optimisation steps.
        temperature: Gumbel-softmax temperature.

    Returns:
        The trained ``SupportSampler`` (the same object that was passed in).

    Raises:
        ValueError: If ``bsize`` < 1.
    """
    if bsize < 1:
        raise ValueError(f"bsize must be >= 1, got {bsize}")

    optimizer = torch.optim.Adam(SupportSampler.parameters(), lr=lr)

    for _ in tqdm(range(iterations), desc="Training", unit="epoch"):
        # Sample a random query datapoint and add the batch dimension.
        x, y = dataset[np.random.choice(len(dataset))]
        x, y = x[None], y[None]

        optimizer.zero_grad()
        # Logits over dataset indices describing the best support for this query.
        indice_probs = SupportSampler(x, y)

        # Draw bsize independent support sets (differentiable w.r.t. the logits)
        # and stack them along the batch dimension with a single concatenation.
        samples = [
            select_support(indice_probs, dataset, temperature=temperature, support_size=support_size)
            for _ in range(bsize)
        ]
        support_image_sets = torch.cat([images for images, _ in samples], dim=0)
        support_label_sets = torch.cat([labels for _, labels in samples], dim=0)

        # Pair every sampled support set with a copy of the same query.
        query_images = x.repeat(bsize, 1, 1, 1)
        query_labels = y.repeat(bsize, 1, 1, 1)

        # Make predictions for the sampled supports with the frozen IC model.
        y_hat = IC_Model(support_image_sets, support_label_sets, query_images)

        # soft_dice_score is a similarity (higher is better) and the optimiser
        # minimises, so train on 1 - dice. Minimising the raw score would push the
        # sampler towards the *worst* support sets.
        dice_loss = 1 - soft_dice_score(y_hat, query_labels, from_logits=True)

        dice_loss.backward()
        optimizer.step()

    return SupportSampler

Finally, we run our optimization procedure and train the model.

In [None]:
# set hyperparameters
support_size = 8
bsize = 2
lr = 1e-3
iterations = 1000
temperature = 0.1

# Initialize Sampler
support_sampler = SupportSampler(dset_size=len(dset), model_type="resnet18") 
support_sampler.to(device)

# meta train the model
meta_trained_model = meta_train(ic_model, support_sampler, dset, support_size, bsize, lr, iterations, temperature)

And just like that, we've trained a sampler that can choose what support examples should be selected for a particular query image and label pair.

# Part 4: Preliminary Evaluation

We need to build out this repo so that these experiments can scale (and thus move out of this notebook), but let's quickly look at some of the support sets that get sampled. First, choose a random image and label pair from our dataset.

In [None]:
# Sample random datapoint and send datapoints to device
x, y = dset[np.random.choice(len(dset))]
# add the batch dimension to query image and label
x, y = x[None], y[None]

ne.plot.slices(torch.cat([x, y], dim=0).squeeze().cpu())

Now let's visually inspect the kinds of examples our network chooses.

In [None]:
# Visualisation only: no gradients are needed, so don't build an autograd graph
# through the sampler and the Gumbel-softmax selection.
# NOTE(review): the sampler is still in training mode here (meta_train never calls
# .eval()), so the ResNet's BatchNorm layers use, and update, batch statistics.
# Consider support_sampler.eval() if that matters.
with torch.no_grad():
    support_images, support_labels = select_support(
        support_sampler(x, y), dset, temperature, support_size
    )

ne.plot.slices(support_images.squeeze().cpu())
ne.plot.slices(support_labels.squeeze().cpu())