# An automated pipeline to create an atlas of in-situ hybridization gene expression data in the adult marmoset brain (ISBI 2023, Poon, C., et al.)

## UNet-based semantic segmentation model with contrastive loss. This is the pretraining model, which is trained using only the contrastive loss.

In [None]:
# based on 20221027

### Import libraries

In [None]:
import os
import re
import glob
import random
import numpy as np
import time
import cv2

from PIL import Image

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms.functional as tf
import torchvision.transforms as transforms
import torchvision

import matplotlib.pyplot as plt

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger

torch.manual_seed(42)

In [None]:
# Report the CUDA / PyTorch environment this notebook is running in.
# Uses os.environ instead of the IPython-only `!echo` so the cell is plain
# Python; an unset variable prints an empty line, just like `echo` would.
print(os.environ.get('CUDA_VISIBLE_DEVICES', ''))
print(torch.__version__)
print(torch.cuda.is_available())
# The two device queries below raise when no CUDA device is present, so they
# are only issued when CUDA is actually available.
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))
print(torch.cuda.device_count())
print(torch.backends.cudnn.version())
print(torch.version.cuda)

### Encoder and decoder blocks of UNet

In [None]:
#%%  encoder and decoder from UNet

def convbnrelu2(in_channels, out_channels):
    """Build one 3x3 convolution -> batch norm -> LeakyReLU stage.

    The convolution uses stride 1 with ``padding='same'``, so the spatial
    size of the input is kept and only the channel count changes.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    stage = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*stage)

def convbnrelu2_T(in_channels, out_channels):
    """Build one upsampling stage: 2x2 transposed conv -> batch norm -> LeakyReLU.

    The transposed convolution has kernel size 2 and stride 2, so it doubles
    the height and width of its input.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the transposed convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    stage = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*stage)

class unet_encoder(nn.Module):
    """Contracting path: two conv stages, each followed by 2x2 max pooling, then a bottleneck conv.

    Channel widths go ``in_channels -> init_features -> 2*init_features ->
    4*init_features``; the two poolings reduce height and width by a factor
    of four overall.

    Args:
        in_channels: channel count of the input image.
        out_channels: stored on the instance but not used by this module.
        init_features: channel width of the first conv stage.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=16):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.init_features = init_features

        width = init_features

        # level 1
        self.conv_down1 = convbnrelu2(in_channels, width)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # level 2
        self.conv_down2 = convbnrelu2(width, width * 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = convbnrelu2(width * 2, width * 4)

    def forward(self, x):
        """Return the bottleneck feature map for ``x``.

        Only the bottleneck is returned; the intermediate level outputs are
        not exposed to the caller.
        """
        level1 = self.conv_down1(x)
        level2 = self.conv_down2(self.maxpool1(level1))
        return self.bottleneck(self.maxpool2(level2))

class unet_decoder(nn.Module):
    """Expanding path: two (transposed conv, conv) stages, then a 1x1 conv and a sigmoid.

    Channel widths go ``4*init_features -> 2*init_features -> init_features
    -> out_channels``. The forward pass consumes only the bottleneck tensor;
    no encoder feature maps are concatenated in.

    Args:
        init_features: channel width of the last decoder stage (the bottleneck
            is expected to carry four times as many channels).
        out_channels: channel count of the final prediction map.
    """

    def __init__(self, init_features=16, out_channels=1):
        super().__init__()

        width = init_features

        # upsampling stage 2: 4*width -> 2*width
        self.convT2 = convbnrelu2_T(width * 4, width * 2)
        self.conv_up2 = convbnrelu2(width * 2, width * 2)

        # upsampling stage 1: 2*width -> width
        self.convT1 = convbnrelu2_T(width * 2, width)
        self.conv_up1 = convbnrelu2(width, width)

        self.final_layer = nn.Conv2d(width, out_channels, kernel_size=1)

    def forward(self, bottleneck):
        """Decode ``bottleneck`` into a prediction map squashed to (0, 1) by a sigmoid."""
        up2 = self.conv_up2(self.convT2(bottleneck))
        up1 = self.conv_up1(self.convT1(up2))
        logits = self.final_layer(up1)
        return torch.sigmoid(logits)

### Define contrastive transformations (see SimCLR paper for details)

In [None]:
class ContrastiveTransformations:
    """Produce several independently transformed views of the same input.

    Args:
        flag: ``'augm'`` selects the module-level ``augm_transforms`` pipeline,
            ``'noaugm'`` selects the module-level ``noaugm_transforms``.
        n_views: how many views ``__call__`` returns.

    Raises:
        ValueError: if ``flag`` is neither ``'augm'`` nor ``'noaugm'``.
    """

    def __init__(self, flag, n_views=2):
        # The pipelines are looked up when an instance is created, so they
        # only need to exist at module level by then.
        if flag == 'augm':
            self.base_transforms = augm_transforms
        elif flag == 'noaugm':
            self.base_transforms = noaugm_transforms
        else:
            raise ValueError(f"flag must be 'augm' or 'noaugm', got {flag!r}")
        self.n_views = n_views

    def __call__(self, x):
        """Return a list of ``n_views`` results of applying the pipeline to ``x``."""
        return [self.base_transforms(x) for _ in range(self.n_views)]

    
# Augmentation pipeline used for the contrastive views (applied to images only):
#   1. colour jitter, applied with probability 0.8
#   2. conversion to grayscale, applied with probability 0.2
#   3. Gaussian blur with a 9x9 kernel, always applied
augm_transforms = transforms.Compose([
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=9),
])

# Empty pipeline: returns its input unchanged.
noaugm_transforms = transforms.Compose([])
print("ok")

### Dataset used for contrastive learning

In [None]:
class ishDataset_c(Dataset):
    """Image-only dataset for contrastive pretraining.

    Collects every file matching ``root_dir + '*.' + img_type`` (sorted by
    name). Each image is randomly cropped, resized to 400x400, possibly
    flipped horizontally, and converted to a tensor.

    Args:
        root_dir: directory prefix the glob pattern is appended to; it is
            concatenated as-is, so it should end with a path separator.
        img_type: file extension without the dot (e.g. ``'png'``).
        fast: when True (default, the original behaviour) every image is
            loaded and transformed once in the constructor and served from
            memory, so the random crop/flip of an image is fixed for the
            lifetime of the dataset. When False, images are loaded and
            transformed again on every access.
    """

    def __init__(self, root_dir: str, img_type: str, fast: bool = True):
        self.img_path = root_dir
        # Must be set before building the glob pattern below.
        self.img_type = img_type
        self.img_fn_list = sorted(glob.glob(self.img_path + '*.' + self.img_type))
        self.img_list = []

        self.fast = fast
        self.tt = transforms.ToTensor()

        if self.fast:
            self.img_list = [self._load(img_fn) for img_fn in self.img_fn_list]

    def _load(self, img_fn):
        """Open one image file, apply the random crop/flip and convert it to a tensor."""
        # The context manager closes the file handle once the pixel data has
        # been consumed by the transform.
        with Image.open(img_fn) as im:
            return self.tt(self.transform(im))

    def __len__(self):
        return len(self.img_fn_list)

    def transform(self, image):
        """Randomly crop ``image``, resize it to 400x400 and flip it with probability 0.5."""
        i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=(0.9, 1.5), ratio=(0.9, 1.33))
        output_size = (400, 400)
        image = tf.resized_crop(image, i, j, h, w, output_size)

        if random.random() > 0.5:
            image = tf.hflip(image)

        return image

    def __getitem__(self, idx):
        """Return the image tensor at ``idx`` (no label is returned)."""
        if self.fast:
            return self.img_list[idx]
        # Lazy path: previously this fell through with no image bound.
        return self._load(self.img_fn_list[idx])


print("ok")
        
        
print("ok")

### Load datasets and visualize some samples from dataloaders

In [None]:
"""
Data is assumed to be organized in the following structure:

├── root directory                        , ie root_dir      (used in pretraining)
│   ├── image directory (training)        , ie img_suffix    (not used in pretraining)
│   ├── label directory (training)        , ie label_suffix  (not used in pretraining)
│   ├── image directory (testing)         , ie img_suffix    (not used in pretraining)
│   ├── label directory (testing)         , ie label_suffix  (not used in pretraining)
│   ├── image only directory (training)   , used in trainc_data  (used in pretraining)


where the image directory has the path:     (not used in pretraining)
/root_dir/img_suffix/

and the label directory has the path:       (not used in pretraining)
/root_dir/label_suffix/

and the images only directory has the path: (used in pretraining)
/root_dir/img_only

"""

# File extension (without the dot) of the images in both directories.
# ishDataset_c requires it; 'png' is a placeholder like the paths below.
# TODO confirm against your data before running.
IMG_TYPE = 'png'

train_dataset = ishDataset_c(root_dir='/home/username/projectA/img_suffix/', img_type=IMG_TYPE)
# Absolute path, consistent with the other directories (the leading '/' was missing).
train_c = ishDataset_c(root_dir='/home/username/projectA/img_only/', img_type=IMG_TYPE)
# Adding two datasets concatenates them.
train_all = train_dataset + train_c

print('train_dataset size:',len(train_all))
# 70 % / 30 % train / validation split
training_size = int(0.7 * len(train_all))
val_size = len(train_all) - training_size
print('training_size:', training_size, 'val_size:', val_size)
train_data, val_data = random_split(train_all, [training_size, val_size])

train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True, drop_last=True, num_workers=1)
val_dataloader = DataLoader(val_data, batch_size=8, shuffle=False, drop_last=True, num_workers=1)

# Show the first four images of one training batch as a sanity check.
x = next(iter(train_dataloader))
print(x[0].shape)
for s in range(4):
    # channels-first -> channels-last for matplotlib
    plt.imshow(np.transpose(x[s], (1,2,0)), cmap='gray')
    plt.title('train images')
    plt.show()

print("ok")
print(len(train_dataloader))
### Segmentation model

In [None]:
from IPython.display import clear_output

class SegModel(pl.LightningModule):
    """Encoder/decoder model pretrained with an InfoNCE contrastive loss only.

    Two augmented views of every batch are pushed through the encoder; the
    bottleneck features are projected by a small head (``self.mlp``) and the
    InfoNCE loss is computed on the projections. The decoder is part of the
    module but its output does not enter the loss.

    The module reads the module-level ``train_data`` and ``val_data`` datasets
    in its constructor and the module-level ``mid_outpath`` directory in
    ``validation_step``; these must exist before the model is built / trained.

    Args:
        hidden_dim: width of the projection head.
        bottleneck_ch: stored on the instance; the head's input channel count
            is fixed at 64 (what the encoder yields with ``init_features=16``).
        bottleneck_h, bottleneck_w: kernel size and stride of the head's
            convolution.
        lr: learning rate passed to Adam.
        temperature: temperature of the InfoNCE softmax.
        weight_decay: weight decay passed to Adam.
        p_conloss: stored on the instance; not used by the code in this class.
        out_name: sub-directory name (under ``mid_outpath``) for saved
            validation images.
    """

    def __init__(self, hidden_dim, bottleneck_ch, bottleneck_h, bottleneck_w, lr, temperature, weight_decay, p_conloss, out_name):
        super(SegModel, self).__init__()
        self.save_hyperparameters()

        # Optimisation is driven by hand in training_step.
        self.automatic_optimization = False
        self.hidden_dim=hidden_dim
        self.lr = lr
        self.temperature = temperature
        self.weight_decay = weight_decay
        # Fixed batch size; info_nce_loss builds its mask from this value, so
        # every batch must contain exactly this many images (drop_last=True
        # in the dataloaders guarantees that).
        self.batch_size = 8
        self.bottleneck_ch = bottleneck_ch
        self.bottleneck_w = bottleneck_w
        self.bottleneck_h = bottleneck_h
        # Incremented in both training_step and validation_step; gates the
        # visualisation / image saving in validation_step.
        self.counter = 0
        self.out_name = out_name
        self.p_conloss = p_conloss

        self.encoder = unet_encoder(in_channels=3, out_channels=1, init_features=16)  #64
        self.decoder = unet_decoder(init_features=16, out_channels=1)  #64

        # module-level datasets created in the data-loading cell
        self.trainset = train_data
        self.valset = val_data

        self.ct_null = ContrastiveTransformations(flag='noaugm', n_views=2)
        self.ct = ContrastiveTransformations(flag='augm', n_views=2)

        # conv2d for the MLP (below)
        # NOTE(review): kernel is given as [w, h] and stride as [h, w]; this is
        # only consistent when bottleneck_h == bottleneck_w.
        self.conv2d = torch.nn.Conv2d(
            in_channels=64,  #1024
            out_channels=hidden_dim*4,
            kernel_size=[bottleneck_w, bottleneck_h],
            stride = [bottleneck_h, bottleneck_w]
        )
        # MLP used for the contrastive loss
        # NOTE(review): the first Linear expects 64*hidden_dim inputs, i.e. the
        # conv above must yield 16 spatial positions per channel - confirm when
        # changing the patch size or bottleneck_h/w.
        self.mlp = nn.Sequential(
            self.conv2d,
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(64*hidden_dim, hidden_dim),
            nn.Linear(hidden_dim,4*2)
        )


    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=1)

    def val_dataloader(self):
        return DataLoader(self.valset, batch_size=self.batch_size, shuffle=False, drop_last=True, num_workers=1)

    def forward(self, x):
        """Return ``(bottleneck, preds)``: the encoder output and the decoder prediction."""
        bottleneck = self.encoder(x)
        preds = self.decoder(bottleneck)
        return (bottleneck, preds)

    def training_step(self, batch, batch_idx) :
        """Run one contrastive training step on an image-only batch."""

        if not self.automatic_optimization:
            opt = self.optimizers()
            opt.zero_grad()

        img_train = batch
        print('img_train len:', len(img_train))


        # INFO NCE LOSS #

        # two augmented views of the batch, stacked along the batch axis
        img_train_c = self.ct(img_train)
        imgcat_train_c = torch.cat(img_train_c, dim=0)
        bottleneck_train_c, _ = self.forward(imgcat_train_c)

        loss_nce_train = self.info_nce_loss(bottleneck_train_c, mode='train')
        self.log("train_contr_loss", loss_nce_train, on_epoch=True)


        self.counter += 1

        if not self.automatic_optimization:
            print('if not self.automatic_optimization:')
            self.manual_backward(loss_nce_train)
            opt.step()

        else:
            return {'loss' :loss_nce_train}



    def validation_step(self, batch, batch_idx):  #
        """Compute the contrastive loss on a validation batch and visualise it.

        The body only runs when ``self.counter`` is a multiple of 5; a PNG of
        the augmented inputs is additionally written when it is a multiple
        of 20.
        """

        if self.counter % 5 == 0:

            clear_output(wait=True)

            x = batch

            # INFO NCE LOSS #
            x_c = self.ct(x)
            x_cat_c = torch.cat(x_c, dim=0)

            b_c, _ = self.forward(x_cat_c)
            # mode='val' so the per-call metric is logged as "val_nce_loss"
            # rather than being mixed into the training metric.
            loss_c = self.info_nce_loss(b_c, mode='val')
            self.log("val_cont_loss", loss_c, on_epoch=True)


            # visualize validation outputs
            if True:
                # first channel of the first four images of view 0, side by side
                cx1 = torch.cat(
                            (x_c[0].detach().cpu()[0,0,...], x_c[0].detach().cpu()[1,0,...],
                              x_c[0].detach().cpu()[2,0,...], x_c[0].detach().cpu()[3,0,...]),dim=1
                        )
                plt.figure(figsize = (20,5))
                plt.imshow(cx1)
                plt.title('contrastive: x')
                plt.show()


                sx = torch.cat(
                            (x.detach().cpu()[0,0,...], x.detach().cpu()[1,0,...],
                             x.detach().cpu()[2,0,...], x.detach().cpu()[3,0,...]),dim=1)
                plt.figure(figsize = (20,5))
                plt.imshow(sx)
                plt.title('supervised: x')
                plt.show()

        # Any multiple of 20 is also a multiple of 5, so cx1 is defined here.
        if self.counter % 20 == 0:
            # naming
            # mid_outpath is a module-level variable that must be defined
            # before training starts.
            out_dir = mid_outpath+self.out_name+'/'
            # makedirs also creates missing parents and tolerates an existing dir
            os.makedirs(out_dir, exist_ok=True)

            out_name_root = out_dir + str(self.counter) + '_'

            cv2.imwrite((out_name_root+'con_x.png'), np.asarray(cx1*255).astype(np.uint8))

        self.counter +=1


    def info_nce_loss(self, bottleneck, mode):
        """Compute the InfoNCE loss on projected bottleneck features.

        Args:
            bottleneck: encoder output for ``2 * self.batch_size`` samples,
                where sample ``i`` and sample ``i + self.batch_size`` are the
                two views of the same image.
            mode: prefix of the logged metric name (``mode + "_nce_loss"``).

        Returns:
            Scalar loss averaged over all ``2 * self.batch_size`` samples.
        """
        # from https://theaisummer.com/simclr/

        feats = self.mlp(bottleneck)
        # 1 everywhere except on the diagonal (a sample is never its own negative)
        mask = (~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=bool, device=self.device)).float()

        # pairwise cosine similarity between all projections
        ai_cos = F.cosine_similarity(feats.unsqueeze(1), feats.unsqueeze(0), dim=2)

        # the off-diagonals at +/- batch_size hold the positive pairs
        sim_ij = torch.diag(ai_cos, self.batch_size)
        sim_ji = torch.diag(ai_cos, -self.batch_size)

        positives = torch.cat([sim_ij, sim_ji], dim=0)
        # nominator has positive pairs only
        nominator = torch.exp(positives / self.temperature)
        # denominator has both positive and negative pairs, but mask each element from itself (inverse identity)
        denominator = mask * torch.exp(ai_cos / self.temperature)

        all_losses = -torch.log(nominator / torch.sum(denominator, dim=1))
        nce_loss = torch.sum(all_losses) / (2 * self.batch_size)
        print('nce_loss:', nce_loss)

        self.log(mode + "_nce_loss", nce_loss)

        return nce_loss


    def configure_optimizers(self):
        """Return an Adam optimizer that honours the ``lr`` and ``weight_decay`` hyperparameters."""
        # Previously both hyperparameters were accepted but silently ignored.
        opt = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return [opt]#, [sch]
    
    


### Run pretraining

In [None]:
# create name for saved model
date = 20221027
max_epochs = 100
p_conloss = 1
out_name=str(date)+'_pretraining_epochs'+str(max_epochs)+'_patch400'

# Directory for the intermediate validation images. SegModel.validation_step
# reads this module-level variable, so it has to exist before training starts.
# Placeholder path like the others in this notebook - adjust to your setup.
mid_outpath = '/home/username/projectA/mid_output/'
os.makedirs(mid_outpath, exist_ok=True)

segmodel = SegModel(hidden_dim=32, bottleneck_ch=64, bottleneck_h=24, bottleneck_w=24, lr=1e-3, temperature=.07, weight_decay=1e-4, p_conloss=p_conloss, out_name=out_name)

# define checkpoint directory
ckpt_dir = '/home/username/projectA/ckpt/'
os.makedirs(ckpt_dir,exist_ok=True)

if True:
    print('training with if True:')
    # One GPU. `devices` takes a device count (or list of indices); the old
    # `gpus=` argument is redundant next to accelerator/devices and the
    # string 'gpus' is not a valid `devices` value.
    trainer = pl.Trainer(logger=CSVLogger(save_dir = 'csvlogs/', name=out_name),
                         accelerator="gpu",
                         devices=1,
                         default_root_dir = ckpt_dir,
                         enable_progress_bar=True,
                         enable_model_summary=True,
                         max_epochs=max_epochs)

    trainer.fit(segmodel)

else:  # pure pytorch
    # NOTE(review): unreachable legacy branch. It expects (image, label)
    # batches and trains with BCE, but the pretraining dataset in this
    # notebook yields images only, so it would not run as-is.
    x,y = next(iter(train_dataloader))
    optimizer = torch.optim.Adam(segmodel.parameters())
    loss = torch.nn.BCELoss()
    segmodel.cuda()
    for a in range(10000+1):
        patch = x.cuda()
        patch_gt = y.cuda()
        segmodel.zero_grad()
        tmp,out_ = segmodel(patch)
        prediction = torch.sigmoid(out_)
        l = loss(prediction,patch_gt)
        l.backward()
        optimizer.step()

        if True:
            if a % 100 == 0:
                print('a:',a)
                
                with torch.no_grad():
                    clear_output(wait=True)
                    print("iter ",a," loss : ",l)
                    segmodel.eval()
                    tmp,out_ = segmodel(patch)
                    prediction = torch.sigmoid(out_)
                    rimg = torch.cat(
                        (prediction.detach().cpu()[0,0,...],patch_gt.cpu()[0,0,...]),dim=1
                    )
                    plt.imshow(rimg)
                    plt.show()


### Visualize some metrics

In [None]:
import pandas as pd
import seaborn as sn

plt.rcParams['figure.figsize'] = [4.0, 4.0]
plt.rcParams['figure.dpi'] = 70

# Read the CSV written by the CSVLogger of the trainer created above.
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")

# Drop the step counter and per-step / unused series so that only the
# remaining curves are plotted. Not every column is logged by every run,
# so missing ones are ignored instead of raising KeyError.
metrics = metrics.drop(
    columns=[
        "step",
        "train_loss_step",
        "train_loss_epoch",
        "train_nce_loss",
        "train_contr_loss_step",
        "train_sup_loss_step",
    ],
    errors="ignore",
)

metrics.set_index("epoch", inplace=True)
display(metrics.dropna(axis=1, how="all").head())
sn.relplot(data=metrics, kind="line", height=15)

### Visualize some results from pretraining

In [None]:
# if you want to load a different, saved model
# Build a model with the same architecture hyperparameters as the checkpoint
# so that the saved state dict fits.
segmodel = SegModel(hidden_dim=32, bottleneck_ch=64, bottleneck_h=24, bottleneck_w=24, lr=1e-3, 
                    temperature=.07, weight_decay=1e-4, p_conloss=1, out_name='meow')  # preds_h,w = 6 for patch64,96, 3 for patch48
    
ckpt_path = '/home/charissa/shimogori/shimogori_adult/segmentation/contrastive_learning/lightning_code/csvlogs/202210124_pretraining_epochs500/version_0/checkpoints/epoch=499-step=177000_copy.ckpt'
# map_location='cpu' lets the checkpoint load on machines without the GPU it
# was saved from; load_state_dict then copies the weights into the model.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(ckpt_path, map_location='cpu')
segmodel.load_state_dict(checkpoint['state_dict'])

In [None]:
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [8.0, 8.0]
plt.rcParams['figure.dpi'] = 140

# inverts the colours of the input with probability 0.5
t_invert = transforms.RandomInvert(p=0.5)

segmodel.eval()
with torch.no_grad():
    # The pretraining dataset yields images only (no labels), so a batch is a
    # single tensor and cannot be unpacked into (x, y).
    x = next(iter(train_dataloader))
    x = t_invert(x)
    bottleneck, val_preds = segmodel(x)
    print(x[0,0,...].shape)
    print(val_preds[0,0,...].shape)

    # first channel of input and prediction, side by side, for two samples
    for sample in range(2):
        frame = torch.cat((x.detach().cpu()[sample,0,...], val_preds.detach().cpu()[sample,0,...]), dim=1)
        plt.imshow(frame, cmap='gray')
        plt.title('x, preds')
        plt.show()

    print('x[0,0,...]:',x[0,0,...].min(), x[0,0,...].max(), x[0,0,...].dtype)
    print('val_preds:',val_preds.detach().cpu()[0,0,...].min(), val_preds.detach().cpu()[0,0,...].max(), val_preds.detach().cpu()[0,0,...].dtype)
    # sanity checks: inputs and predictions are expected to lie in [0, 1]
    assert x[0,0,...].min() >= 0
    assert x[0,0,...].max() <= 1
    assert val_preds[0,0,...].min() >= 0
    assert val_preds[0,0,...].max() <= 1

### If you would like to count some parameters

In [None]:
# Restore a model from a Lightning checkpoint, move it to the CPU and put it
# in eval mode, then report how many trainable parameters it has.
PATH = '/home/charissa/shimogori/shimogori_adult/contrastive_learning/lightning_code/csvlogs/lsup_only/version_3/checkpoints/epoch=49-step=4950.ckpt'
segmodel2 = SegModel.load_from_checkpoint(PATH)
segmodel2 = segmodel2.cpu()
segmodel2.eval()

# only parameters that receive gradients are counted
pytorch_total_params = sum(
    map(torch.Tensor.numel, filter(lambda t: t.requires_grad, segmodel2.parameters()))
)
print(pytorch_total_params)