# Meta-Net
### Code for the MIDL 2024 short paper "Meta-Learning for Segmentation of In Situ Hybridization Gene Expression Images" (Brain Image Analysis Unit, RIKEN CBS)

## Import libraries

In [None]:
import logging
import os
import sys

import numpy as np
import torch

from datetime import datetime

import monai
from monai.utils import set_determinism
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose)

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.functional as tf
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import time
import re
import glob
import random
import numpy as np
import cv2
from PIL import Image
from skimage.exposure import match_histograms

from IPython.display import clear_output

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = plt.rcParamsDefault["figure.figsize"]

from numpy.lib import recfunctions as rfn
import math

import importlib

In [None]:
# Seed the random number generators through MONAI's helper so repeated runs match.
set_determinism(seed=0)

# Send INFO-level (and above) log records to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'device: {device}')

## Prepare input directories for loading the images, GT labels, and segmentations

### Get segmentation input directories

In [None]:
"""
Assuming the data has the following structure:

├── seg_iroot
│   ├── segmentations_from_modelA
│   │   ├── 000.png
│   │   ├── 001.png
│   │   ├── 002.png
│   │   ├── .
│   │   ├── .
│   ├── segmentations_from_modelB
│   │   ├── 000.png
│   │   ├── 001.png
│   │   ├── 002.png
│   │   ├── .
│   │   ├── .
│   ├── segmentations_from_modelC
│   ├── .
│   ├── .

where /seg_iroot/segmentations_from_modelA/000.png
and   /seg_iroot/segmentations_from_modelB/000.png
are 2 segmentations of the same image.

"""

seg_iroot = '/path/to/your/directory/containing/segmentations/'

# Never populated in this notebook; kept only so the name still exists for
# any later cell that may reference it.
seg_roots = []

# One sub-directory per segmentation model. Only entries whose name contains
# 'net' are kept; each path ends with '/' because later code concatenates
# gene and file names directly onto it.
seg_paths_all = sorted([seg_iroot+x+'/' for x in os.listdir(seg_iroot) if 'net' in x])

num_models = len(seg_paths_all)

# Report counts and, when available, the first/last directory. The previous
# version indexed fixed positions (seg_roots[2], seg_paths_all[10]) and raised
# IndexError whenever those positions did not exist.
print(f'seg_paths: {len(seg_paths_all)}')
print(f'num_models: {num_models}')
if seg_paths_all:
    print(f'first: {seg_paths_all[0]}, last: {seg_paths_all[-1]}')

### Get image, label input directories

In [None]:
"""
Assuming the data has the following structure:

├── image_root
│   ├── geneA
│   │   ├── 000.png
│   │   ├── 001.png
│   │   ├── .
│   │   ├── .
│   ├── geneB
│   │   ├── 000.png
│   │   ├── 001.png
│   │   ├── .
│   │   ├── .
├── label_root
│   ├── geneA
│   │   ├── 000.png
│   │   ├── 001.png
│   │   ├── .
│   │   ├── .
│   ├── geneB
│   │   ├── 000.png
│   │   ├── 001.png
│   │   ├── .

where /image_root/geneA/005.png
and   /label_root/geneA/005.png
are an ISH image and its corresponding ground-truth label, respectively.

"""

# root directory holding one sub-directory of ISH images per gene
image_root = '/path/to/gene/images/'

# NOTE(review): `genes_with_gt` was referenced here without ever being defined
# in this notebook (NameError). It is assumed to be the directory that holds
# one sub-directory per gene with ground-truth labels, derived with the same
# 'image' -> 'label' substitution used for label_paths_all below. Confirm and
# point it at the real label root if that assumption is wrong.
genes_with_gt = image_root.replace('image', 'label')

# getting a list of gene names that have ground-truth labels
cmps_w_segs = sorted(os.listdir(genes_with_gt))

# keep only the genes whose image directory exists
image_paths_all = sorted([image_root+x+'/' for x in cmps_w_segs if os.path.exists(image_root+x)])

# paths to the corresponding ground truth directories
label_paths_all = [x.replace('image','label') for x in image_paths_all]

# Print the actual count (the previous message printed element [5] instead and
# raised IndexError with fewer than six genes).
print(f'number of different genes with GT labels: {len(image_paths_all)}')
if image_paths_all:
    print(f'example image/label directories: {image_paths_all[0], label_paths_all[0]}')

### Split the training images and labels into 7:3 train:val subsets

In [None]:
import math

# Index of the first validation gene: the leading 70% of the (sorted) gene
# directories go to training, the remainder to validation.
splitter = math.floor(.7 * len(image_paths_all))

train_images, val_images = image_paths_all[:splitter], image_paths_all[splitter:]
train_labels, val_labels = label_paths_all[:splitter], label_paths_all[splitter:]

# image and label lists must stay index-aligned within each subset
assert len(train_labels) == len(train_images)
assert len(val_labels) == len(val_images)
print(len(train_images), len(val_images))

### ishDataset class

In [None]:
#%% Thin Dataset wrapper: hands back pre-loaded images, ground-truth labels, and names
class ishDataset_flex_wnames(Dataset):
    """Index-aligned view over three parallel sequences.

    Item ``idx`` is the tuple ``(image_ls[idx], label_ls[idx], name_ls[idx])``.
    Nothing is loaded or transformed here; the sequences are stored as given.
    The dataset length is the length of ``image_ls``.
    """

    def __init__(self, image_ls: list, label_ls: list, im_names_ls: list):
        self.image_ls, self.label_ls, self.name_ls = image_ls, label_ls, im_names_ls

    def __len__(self):
        return len(self.image_ls)

    def __getitem__(self, idx):
        return self.image_ls[idx], self.label_ls[idx], self.name_ls[idx]


print("ok")

### Image, label transforms

In [None]:
# torchvision converter used throughout the notebook to turn arrays into tensors
tt = transforms.ToTensor()

# Augmentation pipeline: colour jitter applied with probability 0.8,
# followed by a Gaussian blur with a 3x3 kernel.
augm_transforms = transforms.Compose([
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.2, hue=0.1)],
        p=0.8,
    ),
    transforms.GaussianBlur(kernel_size=3),
])

from skimage.transform import resize

# Reduce a resized label to a single-channel 0/1 mask
def _binarize_label(label):
    """Return a 2-D mask for a resized label array.

    * Labels with a channel axis: channel 1 is thresholded at 0.1.
    * 2-D labels with more than two distinct values: thresholded at 0.1.
    * 2-D labels with at most two distinct values: returned unchanged.
    """
    if label.ndim > 2:
        return (label[:, :, 1] > 0.1).astype(np.float64)
    if len(np.unique(label)) > 2:
        return (label > 0.1).astype(np.float64)
    return label


# Load the images/labels of one gene as stacked arrays
def load_stack_names(image_dir, label_dir):
    """Load all ``*.png`` image/label pairs of one gene into two stacks.

    Files are paired by sorted order. For every pair, the right half of the
    image and of the label is kept (ground truth only exists for half of each
    ISH image), both are resized to 256 x 128, the label is binarized, and
    both are converted with the module-level ``tt`` transform.

    Parameters
    ----------
    image_dir, label_dir : str
        Directories containing the ``.png`` images and labels of one gene.

    Returns
    -------
    image_np : numpy.ndarray, float32
        Image stack, one entry per file, in the layout produced by ``tt``.
    label_np : torch.Tensor, float32
        Label stack, one entry per file, in the layout produced by ``tt``.
    image_f, label_f : list of str
        Sorted file paths of the images and labels.

    Raises
    ------
    ValueError
        If the two directories hold a different number of ``.png`` files or
        none at all (the latter previously ended in an UnboundLocalError).
    """
    label_f = sorted(glob.glob(label_dir+'/*.png'))
    image_f = sorted(glob.glob(image_dir+'/*.png'))

    if len(image_f) != len(label_f):
        raise ValueError(
            f'number of images ({len(image_f)}) in {image_dir} does not match '
            f'number of labels ({len(label_f)}) in {label_dir}')
    if not image_f:
        raise ValueError(f'no .png files found in {image_dir}')

    print(f'num files:  images {len(image_f)}  labels {len(label_f)}')

    # Images and labels must come from the same gene, i.e. do not mix images
    # from geneA with labels from geneB. All files of one glob share a parent
    # directory, so one comparison is enough. A mismatch is only reported,
    # not raised, matching the previous best-effort behaviour.
    image_gene = os.path.basename(os.path.dirname(image_f[0]))
    label_gene = os.path.basename(os.path.dirname(label_f[0]))
    if image_gene != label_gene:
        print(f'assertion error: {image_f[0]} and {label_f[0]}')

    image_np = None
    label_np = None
    for idx, (image_path, label_path) in enumerate(zip(image_f, label_f)):
        # the context manager closes the file handle as soon as the pixels are copied
        with Image.open(image_path) as im:
            image = np.asarray(im)  # e.g. (1440, 1680, 3)
        with Image.open(label_path) as im:
            label = np.asarray(im)

        # split images and labels in half, vertically, because GT labels only
        # exist for half of each ISH image
        half = image.shape[1] // 2
        image = image[:, half:, :]
        label = label[:, half:]  # works with or without a channel axis

        image = resize(image, (256, 128, image.shape[2]))
        label = resize(label, (256, 128))

        # ensure that gt labels are binary
        label = _binarize_label(label)

        # transform images and labels to tensors
        image = tt(image)
        label = tt(label)

        # allocate the stacks once the per-item shapes are known
        if image_np is None:
            image_np = np.zeros((len(image_f),) + tuple(image.shape), dtype='float32')
            label_np = torch.zeros((len(image_f),) + tuple(label.shape), dtype=torch.float32)

        # load images and labels into the stack
        image_np[idx, ...] = image
        label_np[idx, ...] = label

    return image_np, label_np, image_f, label_f

    
# Histogram matching using a sliding window (9 sections by default)
def histogram_matching(image_stack, train_with_transforms_flag, ministack_size=9):
    """Match each image's histogram to the median of its neighbouring sections.

    For image ``idx`` a window ("ministack") of ``ministack_size`` consecutive
    images is taken, centred on ``idx`` where possible and shifted inwards at
    both ends of the stack. The image is matched to the per-pixel median of
    that window with ``match_histograms``. Stacks shorter than
    ``ministack_size`` use the whole stack as the window (the previous code
    computed a negative window start and a wrong centre index in that case).

    Parameters
    ----------
    image_stack : numpy.ndarray
        Stack laid out as (N, C, H, W), which is what ``load_stack_names``
        returns.
    train_with_transforms_flag : bool
        When false, no matching is done and the stack is only converted.
    ministack_size : int
        Number of sections in the sliding window.

    Returns
    -------
    torch.Tensor
        Tensor with the same (N, C, H, W) layout in torch's default float
        dtype. uint8 input is scaled by 1/255, as ``ToTensor`` would do.
    """
    n_images = image_stack.shape[0]

    # histogram matched images will be loaded into this array
    matched_np = np.copy(image_stack)

    if train_with_transforms_flag and n_images > 0:
        window = min(ministack_size, n_images)
        buffer = ministack_size // 2
        last_start = n_images - window

        for idx in range(n_images):
            # centre the window on idx, clamped so it never leaves the stack
            start = min(max(idx - buffer, 0), last_start)
            # position of the section-of-interest inside the window
            this = idx - start

            ministack = image_stack[start:start + window, ...]
            ministack_median = np.median(ministack, axis=0)

            # NOTE(review): each image here is (C, H, W), so channel_axis=1
            # treats image rows as channels; 0 looks like the intended value.
            # Left unchanged because switching it alters the preprocessing
            # that existing checkpoints were trained with -- confirm.
            matched = match_histograms(ministack[this, ...], ministack_median, channel_axis=1)

            # place all histogram-matched images into the matched_np array
            matched_np[idx, ...] = matched

    # Same message as before, which reported the stack in (N, H, W, C) order.
    nhwc_shape = (matched_np.shape[0],) + tuple(matched_np.shape[2:]) + (matched_np.shape[1],)
    print(f'in histogram_matching, after transpose matched_np: {nhwc_shape}')

    # numpy array to torch array. This yields the same values as transposing
    # to (H, W, C) and passing each image through ToTensor, without the
    # per-image loop.
    matched_torch = torch.from_numpy(np.ascontiguousarray(matched_np))
    if matched_torch.dtype == torch.uint8:
        matched_torch = matched_torch.to(torch.get_default_dtype()).div(255)
    else:
        matched_torch = matched_torch.to(torch.get_default_dtype())

    return matched_torch


# optionally run the image stack through augm_transforms; labels are passed through
def trafo(image_stack, label_stack, train_with_transforms_flag):
    """Return ``(image_stack, label_stack)``, augmenting the images on request.

    When ``train_with_transforms_flag`` is truthy the images are passed
    through the module-level ``augm_transforms``; the labels are returned
    unchanged either way.
    """
    if not train_with_transforms_flag:
        return image_stack, label_stack
    return augm_transforms(image_stack), label_stack




In [None]:
plt.rcParams["figure.figsize"] = (15,5)


def _load_gene_stacks(image_dirs, label_dirs, with_transforms):
    """Load and histogram-match every gene directory, returning flat lists.

    ``image_dirs`` and ``label_dirs`` are index-aligned lists of per-gene
    directories. Returns four lists of equal length: images, labels, image
    file names, and label file names, accumulated over all genes.
    """
    images, labels, image_names, label_names = [], [], [], []

    for image_dir, label_dir in zip(image_dirs, label_dirs):
        image_np, label_np, gene_image_names, gene_label_names = load_stack_names(image_dir, label_dir)
        matched = histogram_matching(image_np, train_with_transforms_flag=with_transforms)
        assert(len(matched) == len(label_np))

        # extend() iterates over the first axis, so each list entry is one image/label
        images.extend(matched)
        labels.extend(label_np)
        image_names.extend(gene_image_names)
        label_names.extend(gene_label_names)
        assert(len(images) == len(labels))
        assert(len(images) == len(image_names))
        assert(len(images) == len(label_names))

    return images, labels, image_names, label_names


# histogram matching is enabled for both the training and the validation data
train_with_transforms_flag = True

# load images and labels for training
(train_images_list_all,
 train_labels_list_all,
 train_images_names_all,
 train_labels_names_all) = _load_gene_stacks(train_images, train_labels, train_with_transforms_flag)
print(f'images: {len(train_images_list_all)}, labels: {len(train_labels_list_all)}')

# load images and labels for validation
(val_images_list_all,
 val_labels_list_all,
 val_images_names_all,
 val_labels_names_all) = _load_gene_stacks(val_images, val_labels, train_with_transforms_flag)
print(f'images: {len(val_images_list_all)}, labels: {len(val_labels_list_all)}')

## Dataset and Dataloaders

In [None]:
# Training data: one sample per batch, served in stored order.
train_dataset = ishDataset_flex_wnames(
    image_ls=train_images_list_all,
    label_ls=train_labels_list_all,
    im_names_ls=train_images_names_all,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=1, num_workers=10)

# Validation data, configured the same way.
val_dataset = ishDataset_flex_wnames(
    image_ls=val_images_list_all,
    label_ls=val_labels_list_all,
    im_names_ls=val_images_names_all,
)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, num_workers=10)

## Helper functions - find and load the relevant segmentations for a given image

In [None]:
# torchvision ToTensor converter applied to each segmentation array in load_segs
ttt = transforms.ToTensor()

# find the gene name and slice name of a given ISH image, so that the corresponding segmentations can be found
def find_segs(name, seg_paths_all):
    """Return the segmentation file paths that belong to one ISH image.

    Parameters
    ----------
    name : str
        '/'-separated path of the image. The first component containing
        'P0_' is taken as the gene directory and the first component
        containing '.png' as the slice file name.
    seg_paths_all : list of str
        One directory per segmentation model, each ending with '/'.

    Returns
    -------
    list of str
        ``<model dir><gene>/<slice>`` for every model directory that contains
        the gene. The list is empty when the first model directory does not
        contain the gene (this used to end in an UnboundLocalError).

    Raises
    ------
    ValueError
        If ``name`` has no 'P0_' component or no '.png' component.
    """
    parts = name.split('/')

    # 1 find the P0 cmp and corresponding z
    cmp_matches = [x for x in parts if 'P0_' in x]
    z_matches = [x for x in parts if '.png' in x]
    if not cmp_matches or not z_matches:
        raise ValueError(f"cannot find a 'P0_' directory and a '.png' file name in {name}")
    cmp, z = cmp_matches[0], z_matches[0]
    print(f'cmp: {cmp}, z: {z}')

    # 2 find the p0 cmp's matching segmentation dirs
    seg_idirs_4_cmp_z = []

    # The first model directory acts as the reference listing of genes.
    available = os.listdir(seg_paths_all[0])
    if cmp not in available:
        print(f'no {cmp} in segmentations {available}')
        return seg_idirs_4_cmp_z

    for i in seg_paths_all:
        if os.path.exists(i+cmp):
            seg_idirs_4_cmp_z.append(i+cmp+'/'+z)
        else:
            print(f'no matching seg idir for {i} and {cmp}')

    return seg_idirs_4_cmp_z
    
# load the relevant segmentations for a given ISH image, found by find_segs function
def load_segs(seg_idirs_4_cmp_z, out_shape=(256, 128)):
    """Load one image's segmentations from all models into a single tensor.

    Each file is opened, its right half is kept, it is resized to
    ``out_shape`` and divided by 255. The first channel after ``ttt`` is
    stored. If the maximum of the whole stack is below 0.999, the stack is
    rescaled so that its maximum becomes 1.

    Parameters
    ----------
    seg_idirs_4_cmp_z : list of str
        Segmentation file paths, one per model (output of ``find_segs``).
    out_shape : tuple of int
        (rows, cols) to resize to; the default matches the size used for the
        images and labels.

    Returns
    -------
    torch.Tensor
        Tensor of shape (1, number of paths, rows, cols) on the module-level
        ``device``.

    Raises
    ------
    ValueError
        If no paths are given (this used to end in an UnboundLocalError).
    """
    if len(seg_idirs_4_cmp_z) == 0:
        raise ValueError('load_segs needs at least one segmentation path')

    seg_all = torch.zeros((len(seg_idirs_4_cmp_z),) + tuple(out_shape))

    # 3 load the relevant SEGS #
    for k, seg_path in enumerate(seg_idirs_4_cmp_z):
        # the context manager closes the file handle once the pixels are copied
        with Image.open(seg_path) as im:
            seg = np.asarray(im)
        half = int(seg.shape[1]/2)
        seg = seg[:,half:]

        # downsize the segmentations as the images and labels were downsized
        seg = resize(seg, out_shape)

        seg = seg.astype('float32')
        seg /= 255

        seg_all[k] = ttt(seg)[0]

    # Rescale so the strongest response equals 1. An all-zero stack is left
    # as zeros; dividing by its maximum would fill the tensor with NaN.
    peak = seg_all.max()
    if 0 < peak < 0.999:
        seg_all = seg_all*1/peak
    seg_all = torch.unsqueeze(seg_all, dim=0)
    seg_all = seg_all.to(device)

    return seg_all

## Helper functions - plotting

In [None]:
# plot an ISH image, its GT label, the meta prediction, and all the segmentations
def plot_x_y_yhat_preds(x,y,yhat,pred_all_stack, epoch, vis_out):
    """Save a one-row figure: x, y, yhat, then every entry of ``pred_all_stack``.

    Only the first batch element of each tensor is drawn. The figure is
    written to ``vis_out + '_epochNNN_xyyhatpreds.png'`` with the epoch
    zero-padded to three digits.
    """
    # zero-pad the epoch number for the file name
    if epoch < 10:
        pad = '00'
    elif epoch < 100:
        pad = '0'
    else:
        pad = ''
    str_epoch = '_epoch' + pad + str(epoch)

    def to_numpy(t):
        return t.clone().cpu().detach().numpy()

    num_preds = pred_all_stack.shape[1]
    fig, axs = plt.subplots(1, num_preds+3, figsize=(20,6))

    # (title, tensor) for every column, left to right
    panels = [
        ('x', torch.permute(x[0,...],(1,2,0))),
        ('y', y[0,0,...]),
        ('yhat', yhat[0,0]),
    ]
    # prediction titles carry the column index, so they start at 'pred3'
    panels += [('pred'+str(k+3), pred_all_stack[0,k,...]) for k in range(num_preds)]

    for ax, (title, panel) in zip(axs, panels):
        ax.imshow(to_numpy(panel))
        ax.title.set_text(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(vis_out+str_epoch+'_xyyhatpreds.png')
    

## Meta-Net

In [None]:
def convbnrelu_im_33(in_channels, out_channels):
    """Build a 3x3 Conv2d (stride 1, padding 1) -> BatchNorm2d -> in-place ReLU block.

    The padding keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels``.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), stride=(1,1), padding=(1,1))
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, act)

In [None]:
import torch.nn.functional as F

class Meta_Unet(nn.Module):
    """U-Net that turns several candidate segmentations into one meta-segmentation.

    The RGB image is reduced to one channel, concatenated with the
    ``num_models`` candidate segmentations, and passed through a four-level
    encoder/decoder with skip connections. The network output is
    soft-maxed over its ``num_models + 1`` channels; the channel belonging to
    the image is dropped and the remaining channels are used as per-pixel
    weights for the candidate segmentations.

    Parameters
    ----------
    im_shape : sequence of int
        Shape of an input batch. Only stored (``im_shape``, ``im_size``);
        the layers do not depend on it.
    num_models : int
        Number of candidate segmentations per image.
    init_features : int
        Channel count of the first encoder level; doubled at every level.
    """

    def __init__(self, im_shape, num_models, init_features=16):
        super(Meta_Unet, self).__init__()

        self.im_shape = im_shape
        self.num_models = num_models

        # Total number of elements for a shape of any length. The previous
        # version handled only 2-4 dimensions and otherwise printed 'error',
        # leaving the attribute unset. Not read anywhere in this class.
        self.im_size = int(np.prod(tuple(im_shape)))

        self.in_channels = num_models + 1  # add 1 channel for the image
        self.hidden_size = 128

        self.device = device

        self.conv2d_im = nn.Conv2d(3, 1, kernel_size=(3,3), stride=(1,1), padding=(1,1))  # reduce 3 ch in image to 1 ch

        # NOTE: attribute names and creation order below are kept as they
        # were so that saved state_dicts keep loading and seeded
        # initialisation is unchanged.

        ## ENCODER
        self.maxpool = nn.MaxPool2d(kernel_size=(2,2), stride=2)
        self.conv_down1 = convbnrelu_im_33(self.in_channels, init_features)
        self.conv_down2 = convbnrelu_im_33(init_features, init_features*2)
        self.conv_down3 = convbnrelu_im_33(init_features*2, init_features*4)
        self.conv_down4 = convbnrelu_im_33(init_features*4, init_features*8)
        self.bottleneck = convbnrelu_im_33(init_features*8, init_features*16)

        ## DECODER
        self.convT4 = nn.ConvTranspose2d(init_features*16, init_features*8, kernel_size=(2,2), stride=(2,2))
        self.conv_up4 = convbnrelu_im_33((init_features*8)*2, init_features*8)
        self.convT3 = nn.ConvTranspose2d(init_features*8, init_features*4, kernel_size=(2,2), stride=(2,2))
        self.conv_up3 = convbnrelu_im_33((init_features*4)*2, init_features*4)
        self.convT2 = nn.ConvTranspose2d(init_features*4, init_features*2, kernel_size=(2,2), stride=(2,2))
        self.conv_up2 = convbnrelu_im_33((init_features*2)*2, init_features*2)
        self.convT1 = nn.ConvTranspose2d(init_features*2, init_features, kernel_size=(2,2), stride=(2,2))
        self.conv_up1 = convbnrelu_im_33(init_features*2, init_features)
        self.conv_up0 = convbnrelu_im_33(init_features, self.in_channels)

        self.conv2d_out = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=(1,1), stride=(1,1), padding=(0,0))

    def forward(self, x, yhat_all):
        """Compute the meta-segmentation and the per-model weight maps.

        Parameters
        ----------
        x : torch.Tensor
            Image batch with 3 channels, shape (B, 3, H, W).
        yhat_all : torch.Tensor
            Candidate segmentations, shape (B, num_models, H, W).

        H and W have to be divisible by 16: the input is pooled by 2 four
        times and up-sampled by 2 four times, and the skip concatenations
        require matching sizes.

        Returns
        -------
        meta_out : torch.Tensor
            Weighted sum of the candidates, shape (B, 1, H, W).
        meta_weights : torch.Tensor
            Weights, shape (B, num_models, H, W). The softmax is taken over
            num_models + 1 channels and the first one is dropped afterwards,
            so the returned weights sum to less than 1 at each pixel.
        """
        ## PREPARE IMAGE AND SEGS
        # reduce x from 3 channels to 1 so it sits alongside the candidates
        x = self.conv2d_im(x)

        # concat x and preds stack -> (B, num_models + 1, H, W)
        x_preds = torch.cat((x, yhat_all), dim=1)

        ## ENCODE (f = init_features)
        enc1 = self.conv_down1(x_preds)               # (B, f,    H,    W)
        enc2 = self.conv_down2(self.maxpool(enc1))    # (B, 2f,   H/2,  W/2)
        enc3 = self.conv_down3(self.maxpool(enc2))    # (B, 4f,   H/4,  W/4)
        enc4 = self.conv_down4(self.maxpool(enc3))    # (B, 8f,   H/8,  W/8)

        bottle = self.bottleneck(self.maxpool(enc4))  # (B, 16f,  H/16, W/16)

        ## DECODE
        dec4 = self.convT4(bottle)
        dec4 = torch.cat([dec4, enc4], dim=1)
        dec4 = self.conv_up4(dec4)                     # (B, 8f,   H/8,  W/8)

        dec3 = self.convT3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)
        dec3 = self.conv_up3(dec3)                     # (B, 4f,   H/4,  W/4)

        dec2 = self.convT2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)
        dec2 = self.conv_up2(dec2)                     # (B, 2f,   H/2,  W/2)

        dec1 = self.convT1(dec2)
        dec1 = torch.cat([dec1, enc1], dim=1)
        dec1 = self.conv_up1(dec1)                     # (B, f,    H,    W)

        out = self.conv_up0(dec1)                      # (B, num_models + 1, H, W)

        original_segmentation_masks = yhat_all
        network_outputs = self.conv2d_out(out)
        meta_weights = torch.nn.functional.softmax(network_outputs,dim=1)   # (B, num_models + 1, H, W)
        meta_weights = meta_weights[:,1:,:,:]   # remove image from calculating the meta-segmentation

        meta_out = torch.sum(meta_weights*original_segmentation_masks,dim=1,keepdim=True)

        return meta_out, meta_weights




## Train_one_epoch and early stopper

In [None]:
def train_one_epoch(epoch_index, loss_function, vis_opath, tb_writer, log_every=100):
    """Run one pass over ``train_loader`` and return the latest mean loss.

    Uses the module-level ``train_loader``, ``meta_model``, ``optimizer``,
    ``device`` and ``seg_paths_all``. For every batch the matching
    segmentations are located with ``find_segs`` and loaded with
    ``load_segs``, then one optimisation step is taken.

    Parameters
    ----------
    epoch_index : int
        Zero-based epoch number, used for the TensorBoard step.
    loss_function : callable
        Called as ``loss_function(yhat, y)``.
    vis_opath : str
        Not used by this function; kept for call compatibility.
    tb_writer : SummaryWriter
        Receives 'Loss/train' every ``log_every`` batches.
    log_every : int
        Number of batches per reported average.

    Returns
    -------
    float
        Mean loss of the last completed ``log_every``-batch window. If the
        epoch had fewer batches than one window, the mean over all of its
        batches (0.0 was returned in that case before).
    """
    running_loss = 0.
    last_loss = 0.
    num_batches = len(train_loader)
    batches_in_window = 0
    window_reported = False

    # enumerate keeps track of the batch index for the TensorBoard step
    for i, data in enumerate(train_loader):
        x, y, n = data
        x = x.to(device)
        y = y.to(device)

        ########### find the segs ###########
        seg_idirs = find_segs(n[0], seg_paths_all)
        seg_all = load_segs(seg_idirs)
        #####################################

        optimizer.zero_grad()

        yhat, yhat_weights = meta_model(x, seg_all)

        loss = loss_function(yhat, y)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1

        if batches_in_window == log_every:
            last_loss = running_loss / log_every
            # Global batch counter. This has to be epoch * batches-per-epoch;
            # the former '+' made the steps of consecutive epochs overlap.
            tb_x = epoch_index * num_batches + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
            batches_in_window = 0
            window_reported = True

    # short epochs never complete a window: report what was seen instead of 0
    if not window_reported and batches_in_window:
        last_loss = running_loss / batches_in_window

    return last_loss

In [None]:
# https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
class EarlyStopper:
    """Signal a stop after ``patience`` validation losses worse than the best.

    A loss counts as worse when it exceeds the lowest loss seen so far by
    more than ``min_delta``. Any new lowest loss resets the counter. Losses
    in between leave the counter untouched.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            # new best: remember it and start counting from scratch
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        threshold = self.min_validation_loss + self.min_delta
        if not (validation_loss > threshold):
            # within tolerance of the best loss: neither progress nor a strike
            return False

        self.counter += 1
        return self.counter >= self.patience


## Training

In [None]:
# DiceLoss and DiceMetric are not imported at the top of the notebook;
# import them here so this cell runs on a fresh kernel.
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

loss_function = DiceLoss(sigmoid=False, include_background=True)  # x: ff ; o: ft
ttt = transforms.ToTensor()

# detaches the gradient for some reason
# (thresholding is only applied to validation predictions, after the loss is computed)
post_pred = Compose([Activations(sigmoid=False), AsDiscrete(threshold=0.5005)])

dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
metric_values = []
loss_values = []

max_epochs = 100
best_metric = -1
best_metric_epoch = -1  # set when the first best model is saved
val_interval = 1


weights_ls = []
val_weights_ls = []
train_loss_ls = []
val_loss_ls = []
val_metric_ls = []

# set up naming for the saved models
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
print("timestamp:", timestamp)

workdir = '/path/to/where/you/want/to/save/things/'
# checkpoints and loss logs are written straight into workdir, so it must exist
os.makedirs(workdir, exist_ok=True)
last_model_path = workdir+timestamp+'_256128_meta_unet_lastmodel.pth'


##### create vis_odir #####
# figures go to a sibling '/vis/' directory when workdir contains '/runs/',
# otherwise into workdir itself
vis_odir = workdir.replace('/runs/','/vis/')
os.makedirs(vis_odir, exist_ok=True)
vis_opath = vis_odir+timestamp


##### initialize the model #####
x,y,n = next(iter(train_loader))
meta_model = Meta_Unet(x.shape, num_models)
meta_model = meta_model.to(device)

optimizer = torch.optim.Adam(meta_model.parameters(), 1e-3)


# initialize early stopper
early_stopper = EarlyStopper(patience=3, min_delta=.04)

writer = SummaryWriter(workdir+'test_{}'.format(timestamp))




# Main loop: train one epoch, validate every `val_interval` epochs, keep the
# checkpoint with the best validation Dice, and stop early when the
# validation loss stops improving.
for epoch in range(max_epochs):

    print(f'{epoch}/{max_epochs}')

    meta_model.train(True)
    avg_loss = train_one_epoch(epoch, loss_function, vis_opath, writer)
    print(f'{epoch}/{max_epochs}: {avg_loss}')
    train_loss_ls.append([epoch, avg_loss])

    # Decided during validation, acted upon at the very end of the epoch so
    # that logging and checkpointing of the final epoch are not skipped.
    stop_training = False

    if (epoch + 1) % val_interval == 0:
        running_vloss = 0.

        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        meta_model.eval()

        with torch.no_grad():
            ### validation data: test on unseen data ###
            for i, val_data in enumerate(val_loader):

                # forward pass
                val_x, val_y, val_n = val_data[0].to(device), val_data[1].to(device), val_data[2]

                ########### find the segs ###########
                val_seg_idirs = find_segs(val_n[0], seg_paths_all)
                val_seg_all = load_segs(val_seg_idirs)
                ####################################

                val_yhat, val_yhat_weights = meta_model(val_x, val_seg_all)
                vloss = loss_function(val_yhat, val_y)
                running_vloss += vloss.item()

                # threshold the prediction before accumulating the Dice metric
                val_yhat = post_pred(val_yhat)

                dice_metric(y_pred=val_yhat, y=val_y)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()

            # reset the status for next validation round
            val_metric_ls.append([epoch, metric])
            dice_metric.reset()

            metric_values.append([epoch, metric])
            print(f'epoch: {epoch}, metric: {metric}')

            # the plots below show the last validation sample of this epoch
            if (epoch+1) % 10 == 0:
                plot_x_y_yhat_preds(val_x, val_y, val_yhat, val_seg_all, epoch, vis_opath)

            # quick-look figure; zero-padded epoch also works beyond epoch 99
            str_epoch = f'{epoch:03d}'
            plt.subplot(1,3,1)
            plt.imshow(val_x[0,0].cpu().detach().numpy())
            plt.title('x')
            plt.subplot(1,3,2)
            plt.imshow(val_yhat[0,0].cpu().detach().numpy())
            plt.title('yhat')
            plt.subplot(1,3,3)
            plt.imshow(val_y[0,0].cpu().detach().numpy())
            plt.title('y')
            plt.savefig(vis_opath+str_epoch+'.png')
            plt.show()

            del val_seg_all

            avg_vloss = running_vloss / (i+1)
            print('LOSS train {} val {}'.format(avg_loss, avg_vloss))
            val_loss_ls.append([epoch, avg_vloss])

            # early stopper: only record the decision here
            stop_training = early_stopper.early_stop(avg_vloss)

            # log the running loss averaged per batch for both training and validation
            writer.add_scalars('train vs validation loss',
                               {'train': avg_loss, 'val': avg_vloss},
                               epoch + 1)
            writer.flush()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_model_path = '{}_{}_256128_bestmodel.pth'.format(workdir,timestamp)
                state = {'epoch' : epoch+1, 'state_dict' : meta_model.state_dict(), 'optimizer' : optimizer.state_dict()}
                torch.save(state, best_model_path)

                print("current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(epoch + 1, metric, best_metric, best_metric_epoch))
                print(f"model saved at {best_model_path}")

    # just in case, write the train loss, val loss, and val metric
    # (rewritten in full after every epoch)
    with open(workdir+timestamp+'0_train_loss.txt','w') as f:
        for line in train_loss_ls:
            f.write(f'{line}\n')

    with open(workdir+timestamp+'0_val_loss.txt','w') as g:
        for line in val_loss_ls:
            g.write(f'{line}\n')

    with open(workdir+timestamp+'0_val_metric.txt','w') as h:
        for line in metric_values:
            h.write(f'{line}\n')

    if stop_training:
        break

# save the weights of the last epoch, together with the optimizer state
state = {'epoch' : epoch+1, 'state_dict' : meta_model.state_dict(), 'optimizer' : optimizer.state_dict()}
torch.save(state, last_model_path)
print(f'------ done at epoch {epoch}, saved last model as {last_model_path}')

## Inference

### Prepare test images and labels

In [None]:
# one directory of test ISH images per gene
test_images = ['/directory/with/some/test/ISH/images/A/',
'/directory/with/some/test/ISH/images/B/',
'/directory/with/some/test/ISH/images/C/',]

# Label directories use the same 'image' -> 'label' substitution as the
# training/validation paths. The former pattern '/image/' never matched paths
# of the form '.../images/...', so labels were read from the image directories.
test_labels = [x.replace('image','label') for x in test_images]

In [None]:
# Load, crop, resize and histogram-match the test images gene by gene and
# collect everything into flat, index-aligned lists.
test_images_list_all = []
test_labels_list_all = []
test_images_names_all = []
test_labels_names_all = []

# histogram matching is applied to the test images as it was during training
test_with_transforms_flag = True

for image_dir, label_dir in zip(test_images, test_labels):
    image_np, label_np, test_image_names, test_label_names = load_stack_names(image_dir, label_dir)

    new_image_np = histogram_matching(image_np, train_with_transforms_flag=test_with_transforms_flag)
    assert(len(new_image_np) == len(label_np))

    # extend() iterates over the first axis: one list entry per image/label
    test_images_list_all.extend(new_image_np)
    test_labels_list_all.extend(label_np)
    test_images_names_all.extend(test_image_names)
    test_labels_names_all.extend(test_label_names)
    assert(len(test_images_list_all) == len(test_labels_list_all))
    assert(len(test_images_list_all) == len(test_labels_names_all))

print(f'images: {len(test_images_list_all)}, labels: {len(test_labels_list_all)}')

### Visualize images, labels, predictions from test_loader

In [None]:
# Test data: one sample per batch, served in stored order.
test_dataset = ishDataset_flex_wnames(
    image_ls=test_images_list_all,
    label_ls=test_labels_list_all,
    im_names_ls=test_images_names_all,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=10)

In [None]:
# Take the first test batch and show the image next to its ground-truth label.
x,y,n = next(iter(test_loader))
print(x.shape, y.shape, n)

# left panel: image, moved from channel-first to channel-last for imshow
plt.subplot(1,2,1)
plt.imshow(x[0].permute(1,2,0))

# right panel: label, titled with the file name of the sample
plt.subplot(1,2,2)
plt.imshow(y[0].permute(1,2,0))
plt.title(n[0])

plt.tight_layout()
plt.show()

### Get and save predictions

In [None]:
# one batch is only needed for its shape, which Meta_Unet stores
x,y,n = next(iter(test_loader))

meta_model = Meta_Unet(x.shape, num_models)  # im_shape, num_models
meta_model = meta_model.to(device)
# Inference mode: BatchNorm layers use their stored running statistics
# instead of the statistics of each single-image batch.
meta_model.eval()

# if loading a trained model
best_model_path = '/path/to/your/best/model/bestmodel.pth'
# load on the CPU first; load_state_dict copies the values into the model's parameters
check = torch.load(best_model_path, map_location=torch.device('cpu'))

meta_model.load_state_dict(check['state_dict'])

In [None]:
from skimage.filters import threshold_otsu
from sklearn.metrics import precision_recall_fscore_support

plt.rcParams["figure.figsize"] = (12,8)
#plt.rcParams["figure.figsize"] = plt.rcParamsDefault["figure.figsize"]


post_pred = Compose([Activations(sigmoid=False), AsDiscrete(threshold=0.5005)])

# output directories; created up front so that saving cannot fail on a missing folder
meta_odir = '/directory/to/saving/metanet/outs/'
ensemble_odir = '/directory/to/saving/emsemble/outs/'
os.makedirs(meta_odir, exist_ok=True)
os.makedirs(ensemble_odir, exist_ok=True)

# Inference only: use BatchNorm running statistics and build no autograd graph.
meta_model.eval()

with torch.no_grad():
    for d, data in enumerate(test_loader):

        x,y,n = data[0].to(device), data[1].to(device), data[2]

        ########### find the segs ###########
        seg_idirs = find_segs(n[0], seg_paths_all)
        seg_all = load_segs(seg_idirs)              # (1, num_models, H, W), float32
        ####################################

        ##### run meta_unet #####
        yhat, yhat_weights = meta_model(x, seg_all)

        ##### make numpy versions, ensure shapes, dtypes, are the same #####
        y_np = torch.squeeze(y.detach().cpu()).numpy()
        yhat_np = torch.squeeze(yhat.detach().cpu()).numpy()

        # threshold, done on numpys
        thresh_meta = threshold_otsu(yhat_np)
        yhat_np = (yhat_np > thresh_meta)*1
        yhat_np = yhat_np.astype(np.float32)

        ##### plain ensemble: mean of all candidate segmentations #####
        # copy to the CPU once, not once per model
        segs_cpu = seg_all.detach().cpu()[0]        # (num_models, H, W)
        preds_sum = torch.zeros(segs_cpu.shape[1:])
        for seg in segs_cpu:
            preds_sum += seg
        preds_sum /= segs_cpu.shape[0]
        # threshold, done on numpys
        thresh_sum = threshold_otsu(preds_sum.numpy())
        preds_sum_01 = (preds_sum > thresh_sum)*1
        preds_sum_01 = preds_sum_01.type(torch.float32)  # has to be kept a Tensor for dice calculation
        preds_sum_01_np = preds_sum_01.detach().cpu().numpy()

        # save preds as 0/255 images named '<gene directory>_<file name>'
        out_name = n[0].split('/')[-2]+'_'+n[0].split('/')[-1]

        yhat_np_int = yhat_np.astype(np.uint8)*255
        pred_meta = Image.fromarray(yhat_np_int)
        pred_meta.save(meta_odir+out_name)

        preds_sum_01_np_int = preds_sum_01_np.astype(np.uint8)*255
        pred_ensemble = Image.fromarray(preds_sum_01_np_int)
        pred_ensemble.save(ensemble_odir+out_name)
    
 