In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import json
import nibabel as nib
import time

import sys
from tqdm.notebook import tqdm

import os

# Directory that receives the predicted displacement fields.
out_path='outputs/'
# Create it up front so that writing the first result file cannot fail
# because the directory is missing; a no-op when it already exists.
os.makedirs(out_path, exist_ok=True)

In [2]:
## Preload the dataset description and the validation volumes

data_path='../../../Learn2Reg_Dataset_release_v1.0/AbdomenCTCT/'
with open(os.path.join(data_path,'AbdomenCTCT_dataset.json')) as json_file:
    dataset_info = json.load(json_file)

# Every image that appears as 'fixed' or 'moving' in a validation pair,
# de-duplicated and sorted so that indices are stable between runs.
validation_ = dataset_info['registration_val']
val_images = {pair['fixed'] for pair in validation_} | {pair['moving'] for pair in validation_}
val_list = sorted(val_images)

# Training entries are those whose image is not used for validation.
training_ = [entry for entry in dataset_info['training'] if entry['image'] not in val_list]
H, W, D = dataset_info['tensorImageShape']['0']
num_val = len(val_list)
num_train = len(training_)
print('Training:',len(training_),'; Validation',len(val_list))


## Validation tensors: full-resolution labels, half-resolution intensities
seg_val = torch.zeros(num_val, H, W, D).long().pin_memory()
img_val = torch.zeros(num_val, 1, H//2, W//2, D//2).pin_memory()
t0 = time.time()
for idx, image_rel in enumerate(val_list):
    # The label file path is the image path with 'image' replaced by 'label'.
    label_file = os.path.join(data_path, image_rel.replace('image','label'))
    image_file = os.path.join(data_path, image_rel)
    seg_val[idx] = torch.from_numpy(nib.load(label_file).get_fdata()).long()
    # Intensities are divided by 500 and 2x average-pooled on the GPU
    # before being copied back into the pinned host buffer.
    volume = torch.from_numpy(nib.load(image_file).get_fdata()).float().cuda()
    volume = volume.unsqueeze(0).unsqueeze(0)/500
    img_val[idx] = F.avg_pool3d(volume, 2).cpu()
t1 = time.time()
print('validaion data loaded in %.2f s' % (t1-t0))

Training: 20 ; Validation 10
validaion data loaded in 3.43 s


In [3]:
### functions 
def jacobian_determinant_3d(dense_flow):
    """Jacobian determinant of the mapping ``identity + displacement``.

    The displacement is first rescaled channel-wise by ``(H-1)/2``,
    ``(W-1)/2`` and ``(D-1)/2``, then differentiated with central
    differences (stencil ``[-0.5, 0, 0.5]``) along each spatial axis.

    Args:
        dense_flow: tensor of size ``(1, 3, H, W, D)``. Only a batch size of
            one is supported because the three derivative maps are stacked
            along the batch axis to form the 3x3 Jacobian.

    Returns:
        Tensor of size ``(H-4, W-4, D-4)`` with the determinant at every
        voxel; a two-voxel border is cropped on every side.

    Raises:
        ValueError: if the batch size of ``dense_flow`` is not 1.
    """
    B,_,H,W,D = dense_flow.size()
    if B != 1:
        raise ValueError('jacobian_determinant_3d expects a batch size of 1, got %d' % B)

    device = dense_flow.device
    dense_pix = dense_flow*(torch.Tensor([H-1,W-1,D-1])/2).view(1,3,1,1,1).to(device)

    # Constant central-difference stencil, applied depthwise (groups=3) so every
    # displacement channel is differentiated independently. Plain tensors are
    # used instead of nn.Conv3d modules so that no trainable parameters are
    # created: gradients can only flow back to `dense_flow` itself.
    stencil = torch.tensor([-0.5,0,0.5],dtype=dense_pix.dtype,device=device)
    kernel_z = stencil.view(1,1,3,1,1).repeat(3,1,1,1,1)
    kernel_y = stencil.view(1,1,1,3,1).repeat(3,1,1,1,1)
    kernel_x = stencil.view(1,1,1,1,3).repeat(3,1,1,1,1)
    gradz = F.conv3d(dense_pix,kernel_z,padding=(1,0,0),groups=3)
    grady = F.conv3d(dense_pix,kernel_y,padding=(0,1,0),groups=3)
    gradx = F.conv3d(dense_pix,kernel_x,padding=(0,0,1),groups=3)

    # Rows: derivative direction, columns: displacement channel; adding the
    # identity turns the displacement gradient into the Jacobian of the mapping.
    identity = torch.eye(3,3,dtype=dense_pix.dtype,device=device).view(3,3,1,1,1)
    jacobian = torch.cat((gradz,grady,gradx),0)+identity
    # Drop the border, where the zero padding distorts the finite differences.
    jacobian = jacobian[:,:,2:-2,2:-2,2:-2]

    # 3x3 determinant by cofactor expansion along the first column.
    jac_det = jacobian[0,0,:,:,:]*(jacobian[1,1,:,:,:]*jacobian[2,2,:,:,:]-jacobian[1,2,:,:,:]*jacobian[2,1,:,:,:])-\
    jacobian[1,0,:,:,:]*(jacobian[0,1,:,:,:]*jacobian[2,2,:,:,:]-jacobian[0,2,:,:,:]*jacobian[2,1,:,:,:])+\
    jacobian[2,0,:,:,:]*(jacobian[0,1,:,:,:]*jacobian[1,2,:,:,:]-jacobian[0,2,:,:,:]*jacobian[1,1,:,:,:])

    return jac_det

def dice_coeff(outputs, labels, max_label):
    """Per-label Dice overlap between two label maps.

    Args:
        outputs: integer label tensor (e.g. a warped segmentation).
        labels: integer label tensor with the same number of elements.
        max_label: exclusive upper bound; labels ``1 .. max_label-1`` are
            scored and label 0 is skipped.

    Returns:
        Float tensor with ``max_label-1`` entries, entry ``k`` holding the
        Dice score of label ``k+1``. A label missing from both inputs scores 0.
    """
    eps = 1e-8  # keeps the denominator non-zero for absent labels
    scores = torch.zeros(max_label - 1, dtype=torch.float32)
    for slot, label_num in enumerate(range(1, max_label)):
        pred_mask = (outputs == label_num).view(-1).float()
        true_mask = (labels == label_num).view(-1).float()
        overlap = (pred_mask * true_mask).mean()
        denominator = eps + pred_mask.mean() + true_mask.mean()
        scores[slot] = (2. * overlap) / denominator
    return scores

class ConvBlock(nn.Module):
    """Two stacked Conv3d -> BatchNorm3d -> PReLU stages.

    The first stage uses a 3x3x3 convolution with padding 1 (spatial size is
    preserved) and maps ``in_channels`` to ``out_channels``; the second stage
    is a 1x1x1 convolution that keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = self._conv_norm_act(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = self._conv_norm_act(out_channels, out_channels, kernel_size=1, padding=0)

    @staticmethod
    def _conv_norm_act(in_channels, out_channels, kernel_size, padding):
        """Build one bias-free convolution followed by batch norm and PReLU."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.PReLU(),
        )

    def forward(self, x):
        """Apply both stages in sequence and return the result."""
        return self.conv2(self.conv1(x))
    
# Channel multiplier shared by the network definitions in this cell.
base_ch = 16
class UNet(nn.Module):
    """U-Net that turns a 64-channel feature volume into a 3-channel field.

    Three pooled encoder stages plus a bottleneck are mirrored by three
    decoder stages with skip connections. The 3-channel head output is
    upsampled by a factor of two and smoothed with two 5x5x5 average pools,
    so the result has twice the spatial size of the input. Input spatial
    sizes should be divisible by 8 so the upsampled tensors match the skips.
    """

    def __init__(self):
        super().__init__()
        encoder_spec = (('enc1', 64, base_ch*2), ('enc2', base_ch*2, base_ch*3),
                        ('enc3', base_ch*3, base_ch*3), ('enc4', base_ch*3, base_ch*4))
        decoder_spec = (('dec1', base_ch*7, base_ch*3), ('dec2', base_ch*6, base_ch*3),
                        ('dec3', base_ch*5, base_ch*2))
        self.encoder = nn.ModuleDict({name: ConvBlock(c_in, c_out) for name, c_in, c_out in encoder_spec})
        self.decoder = nn.ModuleDict({name: ConvBlock(c_in, c_out) for name, c_in, c_out in decoder_spec})
        self.conv1 = ConvBlock(base_ch*2, base_ch*4)
        # 1x1x1 head that ends in the 3 output channels.
        self.conv2 = nn.Sequential(
            nn.Conv3d(base_ch*4, base_ch*2, 1, bias=False), nn.BatchNorm3d(base_ch*2), nn.PReLU(),
            nn.Conv3d(base_ch*2, base_ch*2, 1, bias=False), nn.BatchNorm3d(base_ch*2), nn.PReLU(),
            nn.Conv3d(base_ch*2, 3, 1),
        )

    def forward(self, x):
        """Run encoder, decoder and head; return the smoothed, upsampled field."""
        upsample = nn.Upsample(scale_factor=2, mode='trilinear')
        skips = []
        # Encoder: keep each stage's output for the decoder, then halve the size.
        for name in ('enc1', 'enc2', 'enc3'):
            x = self.encoder[name](x)
            skips.append(x)
            x = F.max_pool3d(x, 2)
        x = self.encoder['enc4'](x)
        # Decoder: upsample, concatenate the most recent skip, convolve.
        for name in ('dec1', 'dec2', 'dec3'):
            x = torch.cat((upsample(x), skips.pop()), 1)
            x = self.decoder[name](x)
        x = self.conv1(x)
        field = upsample(self.conv2(x))
        # Two size-preserving 5x5x5 average pools smooth the field.
        for _ in range(2):
            field = F.avg_pool3d(field, 5, stride=1, padding=2)
        return field
    
class UNet2(nn.Module):
    """Feature extractor: 1-channel volume in, 32-channel feature volume out.

    Five encoder stages (four of them followed by 2x max pooling) and four
    decoder stages, of which only the first three upsample and use a skip
    connection. The output therefore has half the spatial size of the input.
    """

    def __init__(self):
        super().__init__()
        encoder_spec = (('enc1', 1, base_ch*2), ('enc2', base_ch*2, base_ch*3),
                        ('enc3', base_ch*3, base_ch*3), ('enc4', base_ch*3, base_ch*4),
                        ('enc5', base_ch*4, base_ch*4))
        decoder_spec = (('dec1', base_ch*8, base_ch*3), ('dec2', base_ch*6, base_ch*3),
                        ('dec3', base_ch*6, base_ch*3), ('dec4', base_ch*3, base_ch*2))
        self.encoder = nn.ModuleDict({name: ConvBlock(c_in, c_out) for name, c_in, c_out in encoder_spec})
        self.decoder = nn.ModuleDict({name: ConvBlock(c_in, c_out) for name, c_in, c_out in decoder_spec})
        self.conv1 = ConvBlock(base_ch*2, base_ch*4)
        # 1x1x1 head that ends in the 32 output feature channels.
        self.conv2 = nn.Sequential(
            nn.Conv3d(base_ch*4, base_ch*2, 1, bias=False), nn.BatchNorm3d(base_ch*2), nn.PReLU(),
            nn.Conv3d(base_ch*2, base_ch*2, 1, bias=False), nn.BatchNorm3d(base_ch*2), nn.PReLU(),
            nn.Conv3d(base_ch*2, 32, 1),
        )
        self.final = nn.Identity()

    def forward(self, x):
        """Return the feature volume computed from ``x``."""
        upsample = nn.Upsample(scale_factor=2, mode='trilinear')
        skips = []
        # Encoder: store each stage's output, then halve the spatial size.
        for name in ('enc1', 'enc2', 'enc3', 'enc4'):
            x = self.encoder[name](x)
            skips.append(x)
            x = F.max_pool3d(x, 2)
        x = self.encoder['enc5'](x)
        # The first three decoder stages upsample and consume a skip ...
        for name in ('dec1', 'dec2', 'dec3'):
            x = torch.cat((upsample(x), skips.pop()), 1)
            x = self.decoder[name](x)
        # ... the last one is a plain convolution block (the enc1 skip stays unused).
        x = self.decoder['dec4'](x)
        x = self.conv2(self.conv1(x))
        return self.final(x)

# Instantiate the feature extractor and move its parameters to the GPU.
resnet = UNet2().cuda()
# End the cell with a statement that only prints a blank line.
print()




In [4]:
# Load the pretrained weights of both networks from one checkpoint file.
# map_location='cpu' makes loading independent of the device the tensors
# were saved from; load_state_dict then copies them into the existing
# parameters on whatever device each model currently lives on.
# NOTE(review): torch.load unpickles the file - only load checkpoints that
# come from a trusted source.
models = torch.load('AbdomenCTCT_example_complete.pth', map_location='cpu')
unet=UNet()
unet.load_state_dict(models['unet'])
resnet.load_state_dict(models['resnet'])

<All keys matched successfully>

In [5]:
# Register every unordered pair (i < j) of validation scans with the
# pretrained networks, save the displacement field of each pair and collect
# Dice / Jacobian statistics.
unet.cuda()
resnet.cuda()
unet.eval()
resnet.eval()

# Numeric case id of every validation scan, taken from the file name.
val_list_nr=[int(x[-16:-12]) for x in val_list]

# One column per pair; rows: Dice before, Dice after, std(Jac), fraction of
# negative Jacobian determinants. The size is derived from the number of
# validation scans instead of being hard-coded.
num_pairs = len(val_list_nr)*(len(val_list_nr)-1)//2
dice45 = torch.zeros(4,num_pairs)
t_all = 0
count = 0
max_label = 14  # labels 1..13 are scored by dice_coeff

# Loop-invariant tensors, built once: the identity sampling grid and the
# per-axis scale used to convert the field before saving.
identity_grid = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H,W,D))
voxel_scale = torch.tensor([D,W,H]).cuda()-1

for i,val1 in enumerate(val_list_nr):
    for j, val2 in enumerate(val_list_nr):
        if(i>=j):
            continue

        # Only the network inference is timed; synchronise so that pending
        # GPU work does not leak into or out of the measurement.
        torch.cuda.synchronize()
        t0 = time.time()

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                feat = resnet(torch.cat((img_val[i:i+1].cuda(),img_val[j:j+1].cuda()),0))
                # Stack the features of both scans along the channel axis.
                feat_pair = torch.cat((feat[:1],feat[1:2]),1)
                output = F.interpolate(unet(feat_pair),scale_factor=2,mode='trilinear')

        torch.cuda.synchronize()
        t1 = time.time()
        t_all += (t1-t0)

        seg_fixed = seg_val[i].cuda().contiguous()
        with torch.no_grad():
            # Warp the segmentation of scan j with nearest-neighbour sampling.
            seg_warped = F.grid_sample(seg_val[j:j+1].unsqueeze(1).float().contiguous().cuda(),output.permute(0,2,3,4,1)+identity_grid,mode='nearest')
            jac_det = jacobian_determinant_3d(output[:1])
        d0 = dice_coeff(seg_fixed,seg_val[j].contiguous().cuda().long(),max_label).cpu()

        # Rescale the field, reverse the channel order and write it to disk.
        nib_disp_field=((output.permute(0,2,3,4,1) /2)*voxel_scale).flip(-1).float().squeeze().cpu()
        nib.save(nib.Nifti1Image(nib_disp_field.numpy(), np.eye(4)), os.path.join(out_path, f'disp_{str(val1).zfill(4)}_{str(val2).zfill(4)}.nii.gz'))
        print(f'Saved disp_{str(val1).zfill(4)}_{str(val2).zfill(4)}.nii.gz')

        d1 = dice_coeff(seg_fixed,seg_warped.contiguous().squeeze().long(),max_label).cpu()
        dice45[0,count] = d0.mean()
        dice45[1,count] = d1.mean()
        dice45[2,count] = jac_det.std()
        dice45[3,count] = (jac_det<0).float().mean()

        count += 1

# Averages over all pairs; max(...,1) avoids a division by zero when there
# are fewer than two validation scans.
mean_metrics = dice45.mean(1)
print('%0.4f'%((t_all)/max(num_pairs,1)),'sec/im','Dice before (%)','%0.2f'%(mean_metrics[0].item()*100),\
      'Dice after (%)','%0.2f'%(mean_metrics[1].item()*100),'std(Jac)','%0.4f'%(mean_metrics[2].item()),\
     'neg(Jac)','%0.6f'%(mean_metrics[3].item()))



Saved disp_0001_0004.nii.gz
Saved disp_0001_0007.nii.gz
Saved disp_0001_0010.nii.gz
Saved disp_0001_0013.nii.gz
Saved disp_0001_0016.nii.gz
Saved disp_0001_0019.nii.gz
Saved disp_0001_0022.nii.gz
Saved disp_0001_0025.nii.gz
Saved disp_0001_0028.nii.gz
Saved disp_0004_0007.nii.gz
Saved disp_0004_0010.nii.gz
Saved disp_0004_0013.nii.gz
Saved disp_0004_0016.nii.gz
Saved disp_0004_0019.nii.gz
Saved disp_0004_0022.nii.gz
Saved disp_0004_0025.nii.gz
Saved disp_0004_0028.nii.gz
Saved disp_0007_0010.nii.gz
Saved disp_0007_0013.nii.gz
Saved disp_0007_0016.nii.gz
Saved disp_0007_0019.nii.gz
Saved disp_0007_0022.nii.gz
Saved disp_0007_0025.nii.gz
Saved disp_0007_0028.nii.gz
Saved disp_0010_0013.nii.gz
Saved disp_0010_0016.nii.gz
Saved disp_0010_0019.nii.gz
Saved disp_0010_0022.nii.gz
Saved disp_0010_0025.nii.gz
Saved disp_0010_0028.nii.gz
Saved disp_0013_0016.nii.gz
Saved disp_0013_0019.nii.gz
Saved disp_0013_0022.nii.gz
Saved disp_0013_0025.nii.gz
Saved disp_0013_0028.nii.gz
Saved disp_0016_0019

In [7]:
##Check Evaluation

config_files_dir='../../evaluation/evaluation_configs'
output_suffix='_example.json'
for task in ['AbdomenCTCT']:
    print('Staring', task)
    _i=os.path.join(out_path)
    _d=os.path.join(data_path)
    _o=os.path.join('.',task+output_suffix)
    _c=os.path.join(config_files_dir,task+"_VAL_evaluation_config.json")
    !python /share/data_zoe3/grossbroehmer/Learn2Reg2022/L2R/evaluation/evaluation.py -i {_i} -d {_d} -o{_o} -c{_c} -v
    print(2*'\n')

Staring AbdomenCTCT
Evaluate 45 cases for: ['LogJacDetStd', 'DSC', 'HD95']
case_results [0]: {'LogJacDetStd': 0.19564424713621664, 'DSC': 0.42361463170105124, 'HD95': 16.336788892209622}
case_results [1]: {'LogJacDetStd': 0.40910771799212176, 'DSC': 0.41988416299836623, 'HD95': 13.8310298266488}
case_results [2]: {'LogJacDetStd': 0.43348170926996576, 'DSC': 0.49952440400488946, 'HD95': 10.104615792306994}
case_results [3]: {'LogJacDetStd': 0.1389824331355023, 'DSC': 0.5017108769899589, 'HD95': 10.710867758990679}
case_results [4]: {'LogJacDetStd': 0.19372547490566755, 'DSC': 0.5161475037448059, 'HD95': 12.99845110980096}
case_results [5]: {'LogJacDetStd': 0.31455199202966005, 'DSC': 0.46862124618244805, 'HD95': 13.294307036197644}
case_results [6]: {'LogJacDetStd': 0.14431763526804345, 'DSC': 0.5135910774243343, 'HD95': 8.945338469473548}
case_results [7]: {'LogJacDetStd': 0.1280621911388051, 'DSC': 0.551504405650682, 'HD95': 8.282858758622256}
case_results [8]: {'LogJacDetStd': 0.2121