In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt
import ipywidgets as ipyw
from monai.losses import GeneralizedDiceLoss
import sys
from torch.utils.data import Subset
from numpy.random import choice
from scipy import ndimage as nd
from scipy.ndimage import binary_opening
import random
import nibabel as nib
import warnings
warnings.filterwarnings("ignore")


class DoubleConv3D(nn.Module):
    """(Conv3d 3x3x3 -> BatchNorm3d -> ReLU), applied twice.

    Spatial size is preserved (padding=1); channels go
    in_channels -> out_channels -> out_channels. The convolutions carry no bias
    because the BatchNorm that follows each one would cancel it anyway.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # The attribute name and the Sequential indices define the state_dict keys
        # (double_conv.0.weight, double_conv.1.*, ...) used by saved checkpoints.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class UNet3D(nn.Module):
    """3D U-Net returning a sigmoid-activated per-voxel map.

    Encoder: one DoubleConv3D followed by 2x max-pooling per entry of
    ``features``. Bottleneck: DoubleConv3D to ``2 * features[-1]`` channels.
    Decoder: per stage, a stride-2 transposed convolution, concatenation with
    the matching encoder output (skip first), then a DoubleConv3D. When odd
    input sizes make an upsampled map differ spatially from its skip
    connection, it is resized with nearest-neighbour interpolation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (each passed through sigmoid).
        features: encoder channel widths, shallow to deep. A tuple by default so
            the default value cannot be shared and mutated between instances.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super(UNet3D, self).__init__()
        features = list(features)  # private copy; caller's sequence is never touched
        self.encoder_layers = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        for feature in features:
            self.encoder_layers.append(DoubleConv3D(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv3D(features[-1], features[-1] * 2)

        # Each decoder stage upsamples 2x to `feature` channels, concatenates the
        # matching skip (also `feature` channels) and reduces 2*feature -> feature.
        self.up_transpose = nn.ModuleList()
        self.decoder_layers = nn.ModuleList()
        decoder_in_channels = features[-1] * 2
        for feature in reversed(features):
            self.up_transpose.append(
                nn.ConvTranspose3d(decoder_in_channels, feature, kernel_size=2, stride=2)
            )
            self.decoder_layers.append(DoubleConv3D(feature * 2, feature))
            decoder_in_channels = feature

        self.conv_final = nn.Conv3d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on a (batch, channels, D, H, W) tensor; returns values in (0, 1)."""
        skip_connections = []
        for encoder in self.encoder_layers:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        for up, decoder, skip in zip(self.up_transpose, self.decoder_layers,
                                     reversed(skip_connections)):
            x = up(x)
            # Max-pooling floors odd sizes, so the upsampled map can be one voxel
            # short; only the spatial dims need to match for concatenation.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='nearest')
            x = torch.cat((skip, x), dim=1)
            x = decoder(x)

        return torch.sigmoid(self.conv_final(x))

In [2]:
# Inference device: GPU index 3 when present, otherwise CPU so the notebook
# still runs (slowly) on machines with fewer GPUs.
device = 'cuda:3' if torch.cuda.is_available() and torch.cuda.device_count() > 3 else 'cpu'

In [3]:
# Healthy-cohort subject IDs (77 entries). Order is kept as-is: subjects are
# processed in this order, which fixes the sequence of random draws.
healthy_sub = """
    00021 00051 00069 00057 00075 00170 00106 00162 00008 00117 00113
    00039 00049 00129 00153 00042 00150 00135 00088 00168 00137 00127
    00013 00165 00005 00085 C8 00119 00084 00169 00031 00022 C1
    00152 00079 00046 00114 00067 00041 00143 00054 00036 00037 00149
    00028 00025 00148 00096 00017 C2 00163 00052 00104 00029 00164
    00019 00147 00011 00007 00157 00030 00155 00102 C10 00110 00154
    00012 00124 00118 C6 00056 00002 C11 00070 00093 00035 00111
""".split()

In [4]:
# Lobe templates in a fixed order; the list position is the lobe index drawn in
# `random_crop_mask` and used to pick the per-lobe thresholds in `generate`:
# 0 insula, 1 temporal, 2 parietal, 3 frontal, 4 occipital.
lobe_masks = [
    nib.load(f'/workspace/Features/Features/templates/{lobe}_lobe_mask.nii.gz').get_fdata()
    for lobe in ('insula', 'temporal', 'parietal', 'frontal', 'occipital')
]
# Probability of drawing each lobe (same order). The values are 5, 71, 28, 61 and
# 4 out of 169 and sum to 1 -- presumably lobe frequencies of a reference lesion
# set; TODO confirm the source.
prob_lobe = [
    0.029585798816568046,  # insula
    0.42011834319526625,   # temporal
    0.16568047337278108,   # parietal
    0.3609467455621302,    # frontal
    0.023668639053254437,  # occipital
]


Mask generation — `generate(subs, split, path_to_save, suffix)`; for each subject:
1. Randomly select one of the 5 brain lobes (weighted by `prob_lobe`).
2. Select a random point at the intersection of the lobe mask and the gray-matter mask (excluding the central stripe of width 9).
3. Cut out a 40x40x40 patch centred on that point.
4. Feed the patch (MRI + gray-matter channels) to the split-specific U-Net.
5. Sample a binarization threshold from the range defined for the selected lobe.
6. Binarize the U-Net output with that threshold.
7. Check the mask volume (measured after steps 8–9); if it is less than 100 voxels, start again from step 1.
8. Remove small disconnected regions, keeping only the largest connected component.
9. Remove voxels that fall inside the exclusion mask (cerebellum, etc.).
10. Intersect the mask with the gray + white matter mask.
11. Insert the patch back into a full brain-sized volume at its original position.
12. If the mask lies in both hemispheres, remove the part lying in the hemisphere that holds the smaller share.
13. Save the result as `{path_to_save}/sub-{subject}{suffix}.nii.gz`.

In [5]:
def random_crop_mask(sub, size, probs, lobe_masks, cond=False):
    """Sample a box-shaped patch mask centred on gray matter of a random lobe.

    A lobe index is drawn with probabilities ``probs``. A centre voxel is then
    drawn uniformly among voxels where the subject's GM map is > 0.5 and the
    lobe template is > 50, with indices 94..102 along axis 0 (the central
    stripe) excluded. The T1w-based GM map is used unless it yields no
    candidate voxel, in which case the T2w-based one is used.

    Args:
        sub: subject ID used to build the segmentation file paths.
        size: patch extent per axis, e.g. [40, 40, 40].
        probs: probability of each entry of ``lobe_masks`` (must sum to 1).
        lobe_masks: lobe template volumes with the same shape as the GM map.
        cond: unused; kept for backward compatibility.

    Returns:
        (mask, is_T1, ind_lobe): an array shaped like the GM map holding a box of
        exactly ``size`` ones (shifted inward when it would cross the volume
        border), 1 if the T1w GM map was used else 0, and the drawn lobe index.

    Raises:
        ValueError: if neither GM map has a candidate voxel in the drawn lobe, or
            ``size`` exceeds the volume along some axis.
    """
    # Draw the lobe; same RNG consumption as choice([0, 1, ..., n-1], 1, p=probs).
    ind_lobe = choice(len(lobe_masks), 1, p=probs)[0]
    lobe_mask = lobe_masks[ind_lobe].copy()  # copy: the template is modified below
    lobe_mask[94:103] = 0  # drop the central stripe (width 9) as a centre candidate

    base = f'/workspace/Features/Features/prep_wf/sub-{sub}'
    is_T1 = 1
    GM = nib.load(f'{base}/c1sub-{sub}_space-MNI152NLin2009asym_T1w.nii').get_fdata()
    points = np.argwhere((GM > 0.5) & (lobe_mask > 50))
    if len(points) == 0:
        # The T1w GM map can be unusable for some subjects; fall back to T2w.
        is_T1 = 0
        GM = nib.load(f'{base}/c1sub-{sub}_space-MNI152NLin2009asym_T2w.nii').get_fdata()
        points = np.argwhere((GM > 0.5) & (lobe_mask > 50))
    if len(points) == 0:
        raise ValueError(f'sub-{sub}: no gray-matter voxels in lobe {ind_lobe}')

    # Uniformly pick the patch centre among the candidate voxels.
    center = points[choice(len(points), 1)[0]]

    # Build the box axis by axis; clamp it inside the volume so that a centre near
    # the border never yields a negative (wrapping) or truncated slice.
    box = []
    for c, extent, dim in zip(center, size, GM.shape):
        if extent > dim:
            raise ValueError(f'patch size {list(size)} exceeds volume shape {GM.shape}')
        start = min(max(int(c) - extent // 2, 0), dim - extent)
        box.append(slice(start, start + extent))

    mask = np.zeros_like(GM)
    mask[tuple(box)] = 1
    return mask, is_T1, ind_lobe

In [6]:
# Function to delete the small areas that not connected with the main one
def postprocessing(input_scan):
    labels_scan = np.zeros_like(input_scan)
    output_scan = input_scan.copy()
    morphed = nd.binary_opening(output_scan!=0, iterations=1)
    # label connected components
    pred_labels, _ = nd.label(input_scan, structure=np.ones((3,3,3)))
    label_list = np.unique(pred_labels)
    num_elements_by_lesion = nd.labeled_comprehension(input_scan, pred_labels, label_list, np.sum, float, 0)
    max_elements_ind = np.array(num_elements_by_lesion).argmax()
    current_voxels = np.stack(np.where(pred_labels == max_elements_ind), axis=1)
    labels_scan[current_voxels[:,0], current_voxels[:,1], current_voxels[:,2]] = 1
    return labels_scan

In [7]:
def _load_gm_wm(sub, is_T1):
    """Return the (GM, WM) maps -- c1/c2 files -- of `sub`, from the T1w or T2w segmentation."""
    modality = 'T1w' if is_T1 else 'T2w'
    base = f'/workspace/Features/Features/prep_wf/sub-{sub}'
    gm = nib.load(f'{base}/c1sub-{sub}_space-MNI152NLin2009asym_{modality}.nii').get_fdata()
    wm = nib.load(f'{base}/c2sub-{sub}_space-MNI152NLin2009asym_{modality}.nii').get_fdata()
    return gm, wm


def generate(subs, split, path_to_save, suffix='_t1_brain-final', max_attempts=1000):
    """Generate, save and return one synthetic lesion mask per subject.

    Per subject:
    1. Draw a 40^3 patch in a random lobe (``random_crop_mask``).
    2. Run the checkpoint ``unet3d_{split}.pth`` on the [MRI, GM] patch.
    3. Binarise at a randomly drawn lobe-specific threshold.
    4. Keep the largest connected component outside the exclusive mask.
    5. Retry until it has at least 100 voxels.
    The patch is then pasted into a full-size volume, reduced to one side of
    axis-0 index 98, intersected with the GM|WM mask (> 0.5) and saved as
    ``{path_to_save}/sub-{sub}{suffix}.nii.gz``.

    Args:
        subs: subject IDs.
        split: checkpoint index (``unet3d_{split}.pth``).
        path_to_save: output directory; created if missing.
        suffix: filename suffix of the saved masks.
        max_attempts: patch draws allowed per subject before it is given up
            and reported in ``no_mask``.

    Returns:
        (new_mask_volume, new_masks, no_mask, ind_lobe):
        - new_mask_volume, new_masks: voxel sum and full-size array of each
          saved mask (identical to what is written to disk).
        - no_mask: IDs of subjects that failed.
        - ind_lobe: lobe index of each saved mask (aligned with new_masks).
    """
    patch_size = (40, 40, 40)
    min_volume = 100  # smallest accepted lesion, in voxels
    # Per-lobe threshold grid thrs1[i], thrs1[i]+step, ..., thrs2[i], sampled with
    # weights rising linearly from just above probs[i] to 1 (index i = lobe index).
    thrs1 = [0.05, 0.05, 0.05, 0.05, 0.05]
    thrs2 = [0.55, 0.55, 0.6, 0.75, 0.4]
    probs = [0.6, 0.4, 0.6, 0.7, 0.6]
    step = 0.05
    new_mask_volume = []
    new_masks = []
    no_mask = []
    ind_lobe = []

    os.makedirs(path_to_save, exist_ok=True)
    exclusive_mask = nib.load('/workspace/Features/Features/templates/exclusive_mask_MNI1mm_resampled.nii.gz').get_fdata()

    model = UNet3D(in_channels=2, out_channels=1).to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load(f'unet3d_{split}.pth', map_location=device))
    # NOTE(review): model.eval() is never called, so BatchNorm normalises with the
    # statistics of each single patch. The thresholds above were presumably tuned
    # with this behaviour, so it is kept -- confirm whether eval() was intended.

    for sub in tqdm(subs):
        try:
            mri_img = nib.load(f'/workspace/Features/Features/prep_wf/sub-{sub}/sub-{sub}_t1_brain-final.nii.gz')
            mri = mri_img.get_fdata()
            tissue = {}  # is_T1 -> (gm, wm); loaded at most once per subject

            for _ in range(max_attempts):
                patch_mask, is_T1, ind = random_crop_mask(sub, list(patch_size), prob_lobe, lobe_masks)
                if is_T1 not in tissue:
                    tissue[is_T1] = _load_gm_wm(sub, is_T1)
                gm, wm = tissue[is_T1]
                in_patch = patch_mask > 0

                # Two-channel network input [MRI, GM] cropped to the patch.
                patch = np.stack([mri[in_patch].reshape(patch_size),
                                  gm[in_patch].reshape(patch_size)])
                patch = torch.tensor(patch.astype(np.float32))
                with torch.no_grad():
                    gen_mask = model(patch.unsqueeze(0).to(device)).cpu().numpy()

                thr1, thr2, prob = thrs1[ind], thrs2[ind], probs[ind]
                prob_ = np.linspace(prob, 1, num=round((thr2-thr1+step*1.9)/step))[1:]
                thresh = choice(np.arange(thr1, thr2+step*0.9, step), 1, p=prob_/prob_.sum())[0]
                gen_mask_th = (gen_mask > thresh).astype(np.float64)[0, 0]
                # Largest component only, then drop voxels inside the exclusive mask.
                gen_mask_th = postprocessing(gen_mask_th) * (exclusive_mask[in_patch] < 0.5).reshape(patch_size)
                if np.sum(gen_mask_th) >= min_volume:
                    break
            else:
                raise RuntimeError(f'no mask with >= {min_volume} voxels after {max_attempts} attempts')

            # Paste the patch back into a full-size volume.
            new_mask = np.zeros_like(mri)
            new_mask[in_patch] = gen_mask_th.ravel()
            # If the mask lies on both sides of axis-0 index 98 (that slice itself is
            # left untouched), keep only the larger side.
            low_side, high_side = new_mask[:98].sum(), new_mask[99:].sum()
            if low_side * high_side != 0:
                if low_side > high_side:
                    new_mask[99:] = 0
                else:
                    new_mask[:98] = 0
            # Restrict to GM or WM of the accepted attempt's segmentation.
            final_mask = new_mask * ((gm > 0.5) | (wm > 0.5))

            nib.save(nib.Nifti1Image(final_mask, affine=mri_img.affine),
                     os.path.join(path_to_save, f'sub-{sub}{suffix}.nii.gz'))
            new_mask_volume.append(final_mask.sum())
            new_masks.append(final_mask)
            ind_lobe.append(ind)
        except Exception as err:
            # Report the reason instead of silently swallowing it; the subject is
            # listed in no_mask and processing continues.
            tqdm.write(f'sub-{sub}: skipped ({type(err).__name__}: {err})')
            no_mask.append(sub)
    return new_mask_volume, new_masks, no_mask, ind_lobe

In [8]:
# Generate masks for every healthy subject with each of the 8 split-specific
# checkpoints (unet3d_0.pth ... unet3d_7.pth).
for split in range(8):
    # Seed the global NumPy RNG (the one behind `choice`) so the lobe, centre and
    # threshold draws of each split are reproducible.
    np.random.seed(split)
    save_dir = f'/workspace/Features/Features/generated_mri/i2sb/3dunet_masks_new/split{split}/masks'
    new_mask_volume, new_masks, no_mask, ind_lobe = generate(healthy_sub, split, save_dir)
    print(f'Problems for {len(no_mask)} subs')
    if no_mask:
        print('  skipped:', ', '.join(no_mask))

100%|██████████| 77/77 [01:18<00:00,  1.01s/it]


Problems for 0 subs


100%|██████████| 77/77 [01:23<00:00,  1.09s/it]


Problems for 0 subs


100%|██████████| 77/77 [01:14<00:00,  1.03it/s]


Problems for 0 subs


100%|██████████| 77/77 [01:11<00:00,  1.07it/s]


Problems for 0 subs


100%|██████████| 77/77 [01:11<00:00,  1.08it/s]


Problems for 0 subs


100%|██████████| 77/77 [01:12<00:00,  1.06it/s]


Problems for 0 subs


100%|██████████| 77/77 [01:11<00:00,  1.08it/s]


Problems for 0 subs


100%|██████████| 77/77 [01:11<00:00,  1.08it/s]

Problems for 0 subs



