In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm.notebook import tqdm
import sys
import os
import json
import torchvision


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

def gpu_usage():
    """Print current and peak allocated CUDA memory, in GB (1e9 bytes)."""
    current_gb = torch.cuda.memory_allocated() * 1e-9
    peak_gb = torch.cuda.max_memory_allocated() * 1e-9
    print(f'gpu usage (current/max): {current_gb:.2f} / {peak_gb:.2f} GB')

In [2]:
from utils_nlst import AdamRegMIND,thin_plate_dense,MINDSSC





In [3]:
# Load data.
# Dataset root: must contain NLST_dataset.json plus the imagesTr/ and keypointsTr/
# folders. A previously used location was ../../Learn2Reg_Dataset_release_v1.0/NLST/
data_path = './'

#write dataloader

import os
import json
import nibabel as nib
from utils_voxelmorph_plusplus import *


class NLST(torch.utils.data.Dataset):
    """NLST registration dataset laid out as described by NLST_dataset.json.

    Each item is one entry of ``training_paired_images`` and is returned as
    ``(fixed_img, moving_img, fixed_kp, moving_kp, fixed_mind, moving_mind)``.

    Args:
        root_dir: dataset root. It must contain ``NLST_dataset.json``; the image paths
            listed there are joined onto it.
        masked: multiply the images by their masks (the image path with
            'images' replaced by 'masks').
        downsampled: return images at half resolution (trilinear for images,
            nearest for masks).
        half: stored on the instance; not used by this class.
        mind: also compute masked MINDSSC descriptors (on the GPU), average-pooled
            by 2. Otherwise both MIND entries are ``None``.
    """

    def __init__(self, root_dir, masked=True, downsampled=False, half=False, mind=False):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'imagesTr')
        self.keypoint_dir = os.path.join(root_dir, 'keypointsTr')
        self.masked = masked
        # Read the description from root_dir itself. This previously used the
        # notebook-global ``data_path``, which silently ignored the root_dir argument.
        with open(os.path.join(root_dir, 'NLST_dataset.json')) as f:
            self.dataset_json = json.load(f)
        self.shape = self.dataset_json['tensorImageShape']['0']
        self.H, self.W, self.D = self.shape
        self.downsampled = downsampled
        self.half = half
        self.mind = mind

    def __len__(self):
        return self.dataset_json['numPairedTraining']

    def get_shape(self):
        """Spatial shape of the returned images (halved when downsampled)."""
        if self.downsampled:
            return [x//2 for x in self.shape]
        return self.shape

    @staticmethod
    def _load_volume(path):
        """Load a NIfTI volume as a float64 torch tensor."""
        return torch.from_numpy(nib.load(path).get_fdata())

    @staticmethod
    def _load_keypoints(image_path):
        """Load the keypoint CSV whose path mirrors image_path (images->keypoints, nii.gz->csv)."""
        kp_path = image_path.replace('images', 'keypoints').replace('nii.gz', 'csv')
        return torch.from_numpy(np.genfromtxt(kp_path, delimiter=','))

    def _normalize_keypoints(self, kp):
        """Reverse the coordinate order and map coordinates into [-1, 1] w.r.t. the full-res shape."""
        return (kp.flip(-1) / torch.tensor(self.shape)) * 2 - 1

    @staticmethod
    def _masked_mind(img, mask):
        """MINDSSC(img) on the GPU, multiplied by mask and average-pooled by 2 (returned on CPU)."""
        mind = mask.float() * MINDSSC(img.cuda().float().unsqueeze(0).unsqueeze(0), 1, 2).cpu().squeeze()
        return F.avg_pool3d(mind, 2, stride=2)

    def _downsample(self, vol, mode):
        """Resample a full-resolution (H, W, D) volume to half resolution."""
        return F.interpolate(vol.view(1, 1, self.H, self.W, self.D),
                             size=(self.H//2, self.W//2, self.D//2), mode=mode).squeeze()

    def __getitem__(self, idx):
        pair = self.dataset_json['training_paired_images'][idx]
        fix_path = os.path.join(self.root_dir, pair['fixed'])
        mov_path = os.path.join(self.root_dir, pair['moving'])

        fixed_img = self._load_volume(fix_path)
        moving_img = self._load_volume(mov_path)

        fixed_kp = self._normalize_keypoints(self._load_keypoints(fix_path))
        moving_kp = self._normalize_keypoints(self._load_keypoints(mov_path))

        # Read each mask from disk at most once; it is shared by the MIND and masking steps.
        mask_fix = mask_mov = None
        if self.mind or self.masked:
            mask_fix = self._load_volume(fix_path.replace('images', 'masks'))
            mask_mov = self._load_volume(mov_path.replace('images', 'masks'))

        fixed_mind = moving_mind = None
        if self.mind:
            fixed_mind = self._masked_mind(fixed_img, mask_fix)
            moving_mind = self._masked_mind(moving_img, mask_mov)

        if self.downsampled:
            fixed_img = self._downsample(fixed_img, 'trilinear')
            moving_img = self._downsample(moving_img, 'trilinear')
            if self.masked:
                # Masks are resampled with nearest so they stay binary-valued.
                fixed_img *= self._downsample(mask_fix, 'nearest')
                moving_img *= self._downsample(mask_mov, 'nearest')
        elif self.masked:
            fixed_img = mask_fix * fixed_img
            moving_img = mask_mov * moving_img

        return fixed_img, moving_img, fixed_kp, moving_kp, fixed_mind, moving_mind

# Dataset used below: half-resolution masked images plus pooled MIND descriptors.
NLST_dataset = NLST(data_path, masked=True, downsampled=True, mind=True, half=False)


In [4]:
##load validation data
# Validation pairs are dataset indices 90..99; every buffer size below is derived
# from this range instead of repeating the case count as a literal.
val_indices = range(90, 100)
n_val = len(val_indices)
# Channels per MIND volume returned by the dataset (12, as in the original buffers;
# presumably the MINDSSC channel count — the copy below fails if it differs).
MIND_CHANNELS = 12

H,W,D=NLST_dataset.get_shape()

# NOTE: the dataset always average-pools MIND by 2, so sizing the MIND buffers with
# get_shape() is only consistent when NLST_dataset was built with downsampled=True.
img_fix = torch.zeros((n_val,1,H,W,D), dtype=torch.float32)
img_mov = torch.zeros((n_val,1,H,W,D), dtype=torch.float32)
mind_fix = torch.zeros((n_val,MIND_CHANNELS,H,W,D), dtype=torch.float32)
mind_mov = torch.zeros((n_val,MIND_CHANNELS,H,W,D), dtype=torch.float32)

# Keypoint counts differ per case, so keypoints are kept as per-case lists.
kpts_fix=[]
kpts_mov=[]

for idx,value in enumerate(val_indices):
    img_fix[idx,0,...], img_mov[idx,0,...], tmp_kpts_fix, tmp_kpts_mov, mind_fix[idx,...], mind_mov[idx,...] = NLST_dataset[value]
    kpts_fix.append(tmp_kpts_fix)
    kpts_mov.append(tmp_kpts_mov)
print('done')





done


In [5]:
# Leftover sanity-check output from development; safe to delete.
print('1')

1


In [6]:

def get_layer(model, name):
    """Return the attribute of `model` addressed by the dotted path `name` (e.g. "layer2.0.conv1")."""
    current = model
    remaining = name.split(".")
    while remaining:
        current = getattr(current, remaining.pop(0))
    return current


def set_layer(model, name, layer):
    """Replace the attribute of `model` at dotted path `name` with `layer`.

    A name without dots is set directly on `model`. Otherwise the parent is
    resolved with get_layer and the last path component is assigned on it.
    """
    # rpartition separates the leaf explicitly. The former try/except ValueError
    # (relying on a failed tuple unpack) also wrapped get_layer, so any ValueError
    # raised while resolving the parent was swallowed and the layer was set on the root.
    parent_path, sep, leaf = name.rpartition(".")
    parent = get_layer(model, parent_path) if sep else model
    setattr(parent, leaf, layer)
import torchvision
# ResNet-18 feature extractor converted from 2D to 3D. It is built once here; the
# cell previously constructed resnet18 twice, the first copy being discarded.
resnet = torchvision.models.resnet18(pretrained=False)
resnet.layer4 = nn.Identity()
resnet.avgpool = nn.Identity()
resnet.maxpool = nn.MaxPool3d(2)

# torchvision's ResNet.forward flattens after avgpool, so fc un-flattens back to a
# volume and upsamples it 2x. 8*32//2 = 128 is layer3's 256 channels after the halving
# below. 28x24x28 is the input size divided by the total stride of 4 (maxpool and
# layer3) — presumably the 112x96x112 half-resolution images. It is hardcoded, so
# other input sizes need new values here.
resnet.fc = nn.Sequential(nn.Unflatten(1,(8*32//2,28,24,28)),nn.Upsample(scale_factor=2,mode='trilinear'))
# 2 in-channels become 1 after halving: each resnet call sees a single image.
resnet.conv1 = nn.Conv2d(2,64,5,stride=1,padding=2)
# Keep layer2 at full resolution.
resnet.layer2[0].conv1.stride = (1,1)
resnet.layer2[0].downsample[0].stride=1

# Swap every Conv2d/BatchNorm2d for a 3D counterpart with half the channels.
# Iterate over a snapshot so the model is not mutated under a live generator.
# Conv3d keeps its default bias=True (the Conv2d bias setting is not carried over);
# keep it that way so the saved checkpoints still match.
count = 0; count2 = 0
for name, module in list(resnet.named_modules()):
    if isinstance(module, nn.Conv2d):
        # stride is an int for the manually edited downsample conv, a tuple otherwise.
        stride = module.stride if isinstance(module.stride, int) else module.stride[0]
        after = nn.Conv3d(module.in_channels//2, module.out_channels//2, module.kernel_size[0],
                          stride=stride, padding=module.padding[0])
        set_layer(resnet, name, after); count += 1
    elif isinstance(module, nn.BatchNorm2d):
        set_layer(resnet, name, nn.BatchNorm3d(module.num_features//2)); count2 += 1
print(count,'# Conv2d > Conv3d','and',count2,'#BatchNorms')
resnet.cuda()
print()


class ConvBlock(nn.Module):
    """Two Conv3d -> InstanceNorm3d -> ReLU stages (3x3x3 then 1x1x1), both without conv bias.

    The submodule names (conv1.0, conv2.0, ...) form the state_dict layout used by
    the trained checkpoints, so they must not change.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_norm_relu(c_in, kernel, pad):
            return nn.Sequential(
                nn.Conv3d(c_in, out_channels, kernel, padding=pad, bias=False),
                nn.InstanceNorm3d(out_channels),
                nn.ReLU(inplace=True),
            )

        self.conv1 = conv_norm_relu(in_channels, 3, 1)
        self.conv2 = conv_norm_relu(out_channels, 1, 0)

    def forward(self, x):
        return self.conv2(self.conv1(x))
    
class UNet(nn.Module):
    """Small 3D U-Net mapping 256 input channels to a 3-channel output.

    Four encoder blocks with max-pooling between them, three decoder blocks that
    2x-upsample and concatenate the matching skip, 1x1x1 heads, and a final 2x
    trilinear upsampling. The module names (encoder.encK, decoder.decK, conv1, conv2)
    define the checkpoint layout and are kept as-is.
    """
    def __init__(self):
        super().__init__()
        enc_channels = [(256, 32), (32, 48), (48, 48), (48, 64)]
        dec_channels = [(64 + 48, 48), (48 + 48, 48), (48 + 32, 32)]
        self.encoder = nn.ModuleDict({f'enc{k}': ConvBlock(c_in, c_out)
                                      for k, (c_in, c_out) in enumerate(enc_channels, start=1)})
        self.decoder = nn.ModuleDict({f'dec{k}': ConvBlock(c_in, c_out)
                                      for k, (c_in, c_out) in enumerate(dec_channels, start=1)})
        self.conv1 = ConvBlock(32, 64)
        self.conv2 = nn.Sequential(
            nn.Conv3d(64, 32, 1, bias=False), nn.InstanceNorm3d(32), nn.ReLU(inplace=True),
            nn.Conv3d(32, 32, 1, bias=False), nn.InstanceNorm3d(32), nn.ReLU(inplace=True),
            nn.Conv3d(32, 3, 1),
        )

    @staticmethod
    def _up2(x):
        # Same call nn.Upsample(scale_factor=2, mode='trilinear') makes internally.
        return F.interpolate(x, scale_factor=2.0, mode='trilinear')

    def forward(self, x):
        skips = []
        for level in range(1, 5):
            x = self.encoder[f'enc{level}'](x)
            if level < 4:
                skips.append(x)
                x = F.max_pool3d(x, 2)
        for level in range(1, 4):
            x = torch.cat((self._up2(x), skips.pop()), 1)
            x = self.decoder[f'dec{level}'](x)
        return self._up2(self.conv2(self.conv1(x)))

# Build the U-Net directly on the GPU (Module.cuda returns the module itself).
unet = UNet().cuda()
print()


15 # Conv2d > Conv3d and 15 #BatchNorms




In [7]:
def _load_pickled_weights(net, checkpoint_path):
    """Copy into `net` the weights of the object stored at checkpoint_path.

    The stored object exposes .state_dict(), i.e. it is presumably a whole pickled
    nn.Module rather than a plain state_dict.
    NOTE(review): torch.load unpickles arbitrary objects here — only use trusted
    files. PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    full module; pass weights_only=False there.
    """
    return net.load_state_dict(torch.load(checkpoint_path).state_dict())

_load_pickled_weights(resnet, 'NLST_Example_resnet_trained.pth')
_load_pickled_weights(unet, 'NLST_Example_unet_trained.pth')

<All keys matched successfully>

In [8]:

# Validation: mean keypoint TRE for (0) no registration, (1) network prediction and
# (2) network + AdamRegMIND refinement. Displacement fields are written to out_path.
t_inf = 0  # accumulated per-case wall time (inference + refinement + saving); not reported below

# NOTE(review): case ids are one higher than the dataset indices 90..99 loaded above —
# presumably 1-based NLST case numbering; confirm against the evaluation json.
idx_test = range(91,101)
tre_net = torch.zeros(3,len(idx_test))

out_path='outputs/NLST'
# nib.save does not create missing directories.
os.makedirs(out_path, exist_ok=True)

# Full-resolution grid (H, W, D are the half-resolution dims). The same grid is used
# for the thin-plate spline and for converting the saved field to voxels.
full_shape = (H*2, W*2, D*2)
full_shape_gpu = torch.tensor(full_shape).cuda()
# Keypoints are normalized by the full-res shape into [-1, 1], so a normalized
# displacement times the half-res size is a full-res voxel displacement. The extra
# factor 1.5 is presumably the voxel spacing in mm.
kpt_scale = torch.tensor([H,W,D]).view(1,3)
kpt_scale_gpu = kpt_scale.cuda()

resnet.eval()
unet.eval()
# NOTE(review): no torch.no_grad()/inference_mode here, so an autograd graph is built
# through the networks; confirm whether AdamRegMIND needs dense_flow to carry grads.
with torch.cuda.amp.autocast():
    for idx,val in enumerate(idx_test):
        keypts_fix = kpts_fix[idx].cuda().float()
        keypts_mov = kpts_mov[idx].cuda().float()
        disp_gt = keypts_mov-keypts_fix

        torch.cuda.synchronize()
        t0 = time.time()
        tre_net[0,idx] = disp_gt.mul(kpt_scale_gpu).pow(2).sum(1).sqrt().mean().item()*1.5
        feats = torch.cat((resnet(img_fix[idx:idx+1].cuda().half()), resnet(img_mov[idx:idx+1].cuda().half())),1)
        output = unet(feats)
        pred_xyz = F.grid_sample(output, keypts_fix.half().view(1,-1,1,1,3), mode='bilinear').squeeze().t()

        # Densify the keypoint predictions to the full-resolution grid, then refine.
        dense_flow = thin_plate_dense(keypts_fix.unsqueeze(0), pred_xyz.unsqueeze(0), full_shape, 4, 0.1)
        disp_hr = AdamRegMIND(mind_fix[idx:idx+1].cuda().half(), mind_mov[idx:idx+1].cuda().half(), dense_flow)

        pred_xyz2 = F.grid_sample(disp_hr, keypts_fix.view(1,-1,1,1,3), mode='bilinear').squeeze().t()

        # Normalized displacement -> voxel units, channels last, reversed axis order.
        disp_field = ((disp_hr.permute(0,2,3,4,1)/2)*(full_shape_gpu-1)).flip(-1).float().squeeze().cpu()

        nib.save(nib.Nifti1Image(disp_field.numpy(), np.eye(4)), os.path.join(out_path, f'disp_{str(val).zfill(4)}_{str(val).zfill(4)}.nii.gz'))

        t_inf += time.time()-t0
        tre1 = (disp_gt-pred_xyz).cpu().mul(kpt_scale).pow(2).sum(1).sqrt()*1.5
        tre_net[1,idx] = tre1.mean().item()

        tre2 = (disp_gt-pred_xyz2).cpu().mul(kpt_scale).pow(2).sum(1).sqrt()*1.5
        tre_net[2,idx] = tre2.mean().item()

        print(tre_net[0,idx],'-->',tre1.mean().item(),'-->',tre2.mean().item())
print(tre_net)



tensor(5.5810) --> 2.4929215908050537 --> 0.6259491443634033
tensor(5.5583) --> 4.479621887207031 --> 1.5308703184127808
tensor(5.6216) --> 3.975072145462036 --> 1.4929213523864746
tensor(6.2135) --> 3.493563652038574 --> 0.6372571587562561
tensor(6.2629) --> 3.6341845989227295 --> 0.8057507276535034
tensor(10.7615) --> 5.707571983337402 --> 1.5308440923690796
tensor(5.8886) --> 5.357746124267578 --> 1.9684263467788696
tensor(5.9513) --> 4.727171421051025 --> 0.7388262152671814
tensor(11.9153) --> 7.6124982833862305 --> 4.951704025268555
tensor(5.9239) --> 5.435888767242432 --> 1.262049913406372
tensor([[ 5.5810,  5.5583,  5.6216,  6.2135,  6.2629, 10.7615,  5.8886,  5.9513,
         11.9153,  5.9239],
        [ 2.4929,  4.4796,  3.9751,  3.4936,  3.6342,  5.7076,  5.3577,  4.7272,
          7.6125,  5.4359],
        [ 0.6259,  1.5309,  1.4929,  0.6373,  0.8058,  1.5308,  1.9684,  0.7388,
          4.9517,  1.2620]])


In [None]:
import os
displacement_fields= './outputs/NLST'
data='./'
output_dir=displacement_fields
output_suffix='_CTCTexamp.json'
for task in ['NLST']:
    
    print('Staring', task)
    _i=os.path.join(displacement_fields)
    _d='./'#os.path.join(data,task)
    _o=os.path.join(output_dir,task+output_suffix)
    _c=os.path.join('.',task+"_90_99.json")
    !python evaluation.py -i {_i} -d {_d} -o{_o} -c{_c} -v
    print(2*'\n')