In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm.notebook import tqdm
import sys
import os
import json
import torchvision


In [2]:
#load data
import os
import json
import nibabel as nib
data_path='../../Learn2Reg_Dataset_release_v1.1/NLST/'
with open(os.path.join(data_path,'NLST_dataset.json')) as f:
    dataset_json = json.load(f)
list_val=dataset_json['registration_val']
ori_shape=dataset_json['tensorImageShape']['0']

In [3]:
##load validation data

H=ori_shape[0]//2
W=ori_shape[1]//2
D=ori_shape[2]//2

img_fix=torch.zeros((len(list_val),1,H,W,D)).float()
img_mov=torch.zeros((len(list_val),1,H,W,D)).float()
mask_fix=torch.zeros((len(list_val),1,H,W,D)).int()
mask_mov=torch.zeros((len(list_val),1,H,W,D)).int()


for idx,value in enumerate(list_val):
    img_fix[idx,0,...] = F.interpolate(torch.from_numpy(nib.load(os.path.join(data_path,value['fixed'])).get_fdata()).float().unsqueeze(0).unsqueeze(0),scale_factor=.5,mode='trilinear').squeeze()
    img_mov[idx,0,...] = F.interpolate(torch.from_numpy(nib.load(os.path.join(data_path,value['moving'])).get_fdata()).float().unsqueeze(0).unsqueeze(0),scale_factor=.5,mode='trilinear').squeeze()
    mask_fix[idx,0,...] = F.interpolate(torch.from_numpy(nib.load(os.path.join(data_path,value['fixed'].replace('image','mask'))).get_fdata()).float().unsqueeze(0).unsqueeze(0),scale_factor=.5,mode='nearest').squeeze()
    mask_mov[idx,0,...] = F.interpolate(torch.from_numpy(nib.load(os.path.join(data_path,value['moving'].replace('image','mask'))).get_fdata()).float().unsqueeze(0).unsqueeze(0),scale_factor=.5,mode='nearest').squeeze()
print('done')
img_fix*=mask_fix
img_mov*=mask_mov


done


In [4]:

def get_layer(model, name):
    layer = model
    for attr in name.split("."):
        layer = getattr(layer, attr)
    return layer


def set_layer(model, name, layer):
    try:
        attrs, name = name.rsplit(".", 1)
        model = get_layer(model, attrs)
    except ValueError:
        pass
    setattr(model, name, layer)
import torchvision
resnet = torchvision.models.resnet18(pretrained=False)

###
resnet = torchvision.models.resnet18(pretrained=False)
resnet.layer4 = nn.Identity()
resnet.avgpool = nn.Identity()#nn.PixelShuffle(2)
resnet.maxpool = nn.MaxPool3d(2)

resnet.fc = nn.Sequential(nn.Unflatten(1,(8*32//2,28,24,28)),nn.Upsample(scale_factor=2,mode='trilinear'))
#,nn.Upsample(scale_factor=2,mode='trilinear'),nn.Conv3d(32,3,3,padding=1))
#print(resnet.conv1)
resnet.conv1 = nn.Conv2d(2,64,5,stride=1,padding=2)
resnet.layer2[0].conv1.stride = (1,1)
resnet.layer2[0].downsample[0].stride=1

count = 0; count2 = 0
for name, module in resnet.named_modules():
    if isinstance(module, nn.Conv2d):
        before = get_layer(resnet, name)
        after = nn.Conv3d(before.in_channels//2,before.out_channels//2,int(torch.tensor(before.kernel_size)[0]),stride=int(torch.tensor(before.stride).view(-1)[0]),padding=before.padding[0])
        set_layer(resnet, name, after); count += 1
    if isinstance(module, nn.BatchNorm2d):
        before = get_layer(resnet, name)
        after = nn.BatchNorm3d(before.num_features//2)
        set_layer(resnet, name, after); count2 += 1
print(count,'# Conv2d > Conv3d','and',count2,'#BatchNorms')
resnet.cuda()
print()


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Sequential(nn.Conv3d(in_channels,out_channels,3,padding=1,bias=False),\
                                   nn.InstanceNorm3d(out_channels),nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv3d(out_channels,out_channels,1,bias=False),\
                                   nn.InstanceNorm3d(out_channels),nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv1(x)
        return self.conv2(x)
    
class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.ModuleDict({'enc1':ConvBlock(256,32),'enc2':ConvBlock(32,48),'enc3':ConvBlock(48,48),\
                                      'enc4':ConvBlock(48,64)})
        self.decoder = nn.ModuleDict({'dec1':ConvBlock(64+48,48),\
                                      'dec2':ConvBlock(48+48,48),'dec3':ConvBlock(48+32,32)})
        self.conv1 = ConvBlock(32,64)
        self.conv2 = nn.Sequential(nn.Conv3d(64,32,1,bias=False),nn.InstanceNorm3d(32),nn.ReLU(inplace=True),\
                                 nn.Conv3d(32,32,1,bias=False),nn.InstanceNorm3d(32),nn.ReLU(inplace=True),\
                                 nn.Conv3d(32,3,1))
    def forward(self, x):
        y = []
        upsample = nn.Upsample(scale_factor=2,mode='trilinear')
        for i in range(4):
            x = self.encoder['enc'+str(i+1)](x)
            if(i<3):
                y.append(x)
                x = F.max_pool3d(x,2) 
        for i in range(3):
            #if(i<3):
            x = torch.cat((upsample(x),y.pop()),1)
            x = self.decoder['dec'+str(i+1)](x)
        x = self.conv1(x)
        return upsample(self.conv2(x))

unet = UNet()
unet.cuda()
print()

15 # Conv2d > Conv3d and 15 #BatchNorms




In [8]:
resnet.load_state_dict(torch.load('NLST_Example_resnet_trained_github.pth').state_dict())
unet.load_state_dict(torch.load('NLST_Example_unet_trained_github.pth').state_dict())

<All keys matched successfully>

In [9]:

t_inf = 0

idx_test = range(101,111)

out_path='outputs/'

resnet.eval()
unet.eval()
with torch.inference_mode():
    with torch.cuda.amp.autocast():
        for idx,val in enumerate(idx_test):
            torch.cuda.synchronize()
            t0 = time.time()
            input = torch.cat((resnet(img_fix[idx:idx+1].cuda().half()),resnet(img_mov[idx:idx+1].cuda().half())),1).cuda()
            output = unet(input)
            
            disp_field= F.interpolate(output,scale_factor=2,mode='trilinear')
            disp_field=((disp_field.permute(0,2,3,4,1))*(torch.tensor([D,W,H]).cuda()-1)).flip(-1).float().squeeze().cpu()
            nib.save(nib.Nifti1Image(disp_field.numpy(), np.eye(4)), os.path.join(out_path, f'disp_{str(val).zfill(4)}_{str(val).zfill(4)}.nii.gz'))
            
            t_inf += time.time()-t0
print(t_inf,'sec')

41.21005630493164


In [10]:
displacement_fields= 'outputs/'
data='../../Learn2Reg_Dataset_v11/' #Secret Validation Data for now
output_dir=displacement_fields
output_suffix='_NLSTexamp.json'
for task in ['NLST']:
    print('Staring', task)
    _i=os.path.join(displacement_fields)
    _d=os.path.join(data,task)
    _o=os.path.join(output_dir,task+output_suffix)
    _c=os.path.join('../../L2R/evaluation/evaluation_configs/NLST_VAL_evaluation_config.json')
    !python ../../L2R/evaluation/evaluation.py -i {_i} -d {_d} -o{_o} -c{_c} -v
    print(2*'\n')

Staring NLST
Evaluate 10 cases for: ['LogJacDetStd', 'TRE_kp']
Will use masks for evaluation.
case_results [0]: {'LogJacDetStd': 0.05709194370830525, 'TRE_kp': 3.5281835027863107}
case_results [1]: {'LogJacDetStd': 0.04133854798122079, 'TRE_kp': 4.762279556245802}
case_results [2]: {'LogJacDetStd': 0.042281379128080124, 'TRE_kp': 5.8877467888313255}
case_results [3]: {'LogJacDetStd': 0.03195278304994823, 'TRE_kp': 7.240518992697461}
case_results [4]: {'LogJacDetStd': 0.04000418510951741, 'TRE_kp': 6.029647681055934}
case_results [5]: {'LogJacDetStd': 0.04428664051736225, 'TRE_kp': 9.883984352557025}
case_results [6]: {'LogJacDetStd': 0.03627716745890451, 'TRE_kp': 4.438268441937357}
case_results [7]: {'LogJacDetStd': 0.07834462087820306, 'TRE_kp': 4.054353200010246}
case_results [8]: {'LogJacDetStd': 0.04351126247084735, 'TRE_kp': 7.671051105658131}
case_results [9]: {'LogJacDetStd': 0.03881089500435125, 'TRE_kp': 4.035733856128705}
{
    "LogJacDetStd": {
        "30": 0.0396461980779