In [1]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    Lambda,
    Resize,
    ScaleIntensityRange,
    SpatialCrop,
    ToTensor,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import monai.utils as utils

import torch
import tempfile
import shutil
import os
import sys

import numpy as np
import skimage

from scipy.stats import multivariate_normal
from scipy.stats import mode

import ffmpeg
import av

import itk
from itk import TubeTK as ttk

from time import perf_counter
from contextlib import contextmanager

from ARGUSUtils import *
from ARGUSUtils_IO import *
from ARGUSUtils_Linearization import *
from ARGUSUtils_Transforms import *

In [2]:
filename = "../Data/Final15/BAMC-PTXNoSliding/Do_Not_Use/219ns_image_1895283541879_clean.mov"
#filename = "../Data/Final15/BAMC-PTXSliding/212s_image_128692595484031_CLEAN.mov"

In [3]:
height,width = shape_video(filename)

with time_this("Load Time:"):
    vid = load_video(filename,height,width)
with time_this("Linearization Time:"):
    vid_linear = linearize_video(vid).transpose([2,1,0])

Time for Load Time: is 0.47677268367260695
   Resampling with zoom = 1.2627627627627627
Time for Linearization Time: is 3.8460789481177926


In [4]:
device = torch.device("cpu")

num_classes = 3
class_pleura = 1
class_rib = 2

net_in_dims = 3
net_in_channels = 1
net_channels=(16, 32, 64, 128, 32)
net_strides=(2, 2, 2, 2)
        
num_slices = 48
size_x = 320
size_y = 320
roi_size = (size_x,size_y,num_slices)

vfold_num = 0
model_filename_base = "./Models/BAMC_PTX_ARUNET-3D-PR-Final15/"
model_type = "best"  #"best" or "last"
model_file = model_filename_base+model_type+'_model.vfold_'+str(vfold_num)+'.pth'

In [5]:
Scale = ScaleIntensityRange(
    a_min=0, a_max=255,
    b_min=0.0, b_max=1.0)
vid_linear_scaled = Scale.__call__(vid_linear)
Crop = ARGUS_RandSpatialCropSlices(
    num_slices=num_slices,
    center_slice=30,
    axis=2)
image_data = np.empty([1, 1, vid_linear.shape[0], vid_linear.shape[1], num_slices])
image_data[0, 0] = Crop.__call__(vid_linear_scaled)
image_data_t = ToTensor().__call__(image_data.astype(np.float32))
itk.imwrite(itk.GetImageFromArray(image_data[0,0].astype(np.float32)), "ARUNet_input_image.mha")

In [6]:
output_image = vid_linear

pleura_prior = 1
min_size_comp = 110000
max_size_comp = 160000

with time_this("CPU 3D Inference Time:"):
    model = UNet(
        dimensions=net_in_dims,
        in_channels=net_in_channels,
        out_channels=num_classes,
        channels=net_channels,
        strides=net_strides,
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)    
    model.load_state_dict(torch.load(model_file))
    model.eval()
    with torch.no_grad():
        test_outputs = sliding_window_inference(
            image_data_t.to(device), roi_size, 1, model)
        prob_shape = test_outputs[0,:,:,:,:].shape
        prob = np.empty(prob_shape)
        for c in range(num_classes):
            itkProb = itk.GetImageFromArray(test_outputs[0,c,:,:,:].cpu())
            imMathProb = ttk.ImageMath.New(itkProb)
            imMathProb.Blur(5)
            itkProb = imMathProb.GetOutput()
            prob[c] = itk.GetArrayFromImage(itkProb)
        arrc1 = np.zeros(prob[0].shape)
        pmin = prob[0].min()
        pmax = prob[0].max()
        for c in range(1,num_classes):
            pmin = min(pmin, prob[c].min())
            pmax = max(pmax, prob[c].max())
        prange = pmax - pmin
        prob = (prob - pmin) / prange
        prob[class_pleura] = prob[class_pleura] * pleura_prior
        done = False
        while not done:
            done = True
            count = np.count_nonzero(arrc1>0)
            prior_factor = 1
            while count<min_size_comp:
                prior_factor *= 1.05
                prob[class_pleura] = prob[class_pleura] * 1.05
                prob[class_rib] = prob[class_rib] * 1.05
                arrc1 = np.argmax(prob,axis=0)
                count = np.count_nonzero(arrc1>0)
                done = False
            while count>max_size_comp:
                prior_factor *= 0.95
                prob[class_pleura] = prob[class_pleura] * 0.95
                prob[class_rib] = prob[class_rib] * 0.95
                arrc1 = np.argmax(prob,axis=0)
                count = np.count_nonzero(arrc1>0)
                done = False

        arrcF = np.where(arrc1==1,1,0)
        itkcF = itk.GetImageFromArray(arrcF.astype(np.float32))
        imMathCF = ttk.ImageMath.New(itkcF)
        imMathCF.Erode(5,class_pleura,0)
        imMathCF.Dilate(5,class_pleura,0)
        output_image = imMathCF.GetOutputUChar()
        
        itkSegmentConnectedComponents = itk.itkARGUS.SegmentConnectedComponents
        seg = itkSegmentConnectedComponents.New(Input=output_image)
        seg.SetKeepOnlyLargestComponent(True)
        seg.Update()

        output_image = seg.GetOutput()
        output_arr = itk.GetArrayFromImage(output_image)

itk.imwrite(output_image, "ARUNet_output_image.mha")

numObjects = 3
largest = 1
Time for CPU 3D Inference Time: is 7.290886126924306


In [7]:
with time_this("CPU ROI Extraction Time:"):
    ROI_min_x = 0
    ROI_max_x = output_arr.shape[0]-1
    while( np.count_nonzero(output_arr[ROI_min_x,:,:]==1)==0 and ROI_min_x<ROI_max_x ):
        ROI_min_x += 1
    while( np.count_nonzero(output_arr[ROI_max_x,:,:]==1)==0 and ROI_max_x>ROI_min_x):
        ROI_max_x -= 1
    ROI_mid_x = (ROI_min_x + ROI_max_x)//2
    ROI_min_x = max(ROI_mid_x-80,0)
    ROI_max_x = min(ROI_min_x+160,output_arr.shape[0]-1)
    ROI_min_x = ROI_max_x-160
    ROI_arr = output_arr[ROI_min_x:ROI_max_x,:,:].transpose([2,0,1])
    print(ROI_min_x, "-", ROI_max_x)

print(ROI_arr.shape)
ROI_image = itk.GetImageFromArray(ROI_arr)
itk.imwrite(ROI_image, "ROI_input_image.mha")

50 - 210
Time for CPU ROI Extraction Time: is 0.003480011597275734
(48, 160, 320)


In [8]:
ROI_num_classes = 3
ROI_class_not_sliding = 1
ROI_class_sliding = 2

ROI_class_prior = [1.3,1.0,0.85]

ROI_net_in_dims = 2
ROI_net_in_channels = 4
ROI_net_channels=(32, 64, 128)
ROI_net_strides=(2, 2)

ROI_num_slices = 32
ROI_size_x = 160
ROI_size_y = 320
ROI_roi_size = (ROI_size_x, ROI_size_y)

ROI_vfold_num = 0
ROI_model_filename_base = "./Models/BAMC_PTX_ROINet-StdDevExtended-ExtrudedNS-Final15/"
ROI_model_type = "best"  #"best" or "last"
ROI_model_file = model_filename_base+model_type+'_model.vfold_'+str(vfold_num)+'.pth'

In [9]:
Scale = ScaleIntensityRange(
    a_min=0, a_max=255,
    b_min=0.0, b_max=1.0)
ROI_arr_scaled = Scale.__call__(ROI_arr)

Crop = ARGUS_RandSpatialCropSlices(
    num_slices=ROI_num_slices,
    axis=0,
    reduce_to_statistics=True,
    extended=True)
ROI_arr_data = np.empty([1, ROI_net_in_channels, ROI_arr.shape[1], ROI_arr.shape[2]])
ROI_arr_data[0] = Crop.__call__(ROI_arr_scaled)
ROI_arr_data_t = ToTensor().__call__(ROI_arr_data.astype(np.float32))

In [10]:
with time_this("ROI Inference Time:"):
    ROI_model = UNet(
        dimensions=ROI_net_in_dims,
        in_channels=ROI_net_in_channels,
        out_channels=ROI_num_classes,
        channels=ROI_net_channels,
        strides=ROI_net_strides,
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)    
    ROI_model.load_state_dict(torch.load(ROI_model_file))
    ROI_model.eval()
    with torch.no_grad():
        ROI_test_outputs = sliding_window_inference(
            ROI_arr_data_t.to(device), ROI_roi_size, 1, ROI_model)
        val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
        ROI_prob_shape = ROI_test_outputs[0,:,:,:].shape
        ROI_prob = np.empty(ROI_prob_shape)
        for c in range(ROI_num_classes):
            ROI_itkProb = itk.GetImageFromArray(ROI_test_outputs[0,c,:,:].cpu())
            ROI_imMathProb = ttk.ImageMath.New(ROI_itkProb)
            ROI_imMathProb.Blur(5)
            ROI_itkProb = ROI_imMathProb.GetOutput()
            itk.imwrite(ROI_itkProb, "prob"+str(c)+".mha")
            ROI_prob[c] = itk.GetArrayFromImage(ROI_itkProb)
        ROI_arrc1 = np.zeros(ROI_prob[0].shape)
        ROI_pmin = ROI_prob[0].min()
        ROI_pmax = ROI_prob[0].max()
        for c in range(1,ROI_num_classes):
            ROI_pmin = min(ROI_pmin, ROI_prob[c].min())
            ROI_pmax = max(ROI_pmax, ROI_prob[c].max())
        ROI_prange = ROI_pmax - ROI_pmin
        ROI_prob = (ROI_prob - ROI_pmin) / ROI_prange
        for c in range(ROI_num_classes):
            ROI_prob[c] = ROI_prob[c] * ROI_class_prior[c]
        ROI_arrc1 = np.argmax(ROI_prob,axis=0)
        
        ROI_itkc1 = itk.GetImageFromArray(ROI_arrc1.astype(np.float32))
        ROI_imMathC1 = ttk.ImageMath.New(ROI_itkc1)
        for c in range(ROI_num_classes):
            ROI_imMathC1.Erode(5,c,0)
            ROI_imMathC1.Dilate(5,c,0)
        ROI_itkc1 = ROI_imMathC1.GetOutputUChar()
        ROI_arrc1 = itk.GetArrayFromImage(ROI_itkc1)
        ROI_class_not_sliding_count = np.count_nonzero(ROI_arrc1==ROI_class_not_sliding)
        ROI_class_sliding_count = np.count_nonzero(ROI_arrc1==ROI_class_sliding)
        if( ROI_class_not_sliding_count > ROI_class_sliding_count ):
            print("Not Sliding:", ROI_class_not_sliding_count, ">", ROI_class_sliding_count)
        else:
            print("Sliding:", ROI_class_not_sliding_count, "<", ROI_class_sliding_count)
        

RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict: "model.1.submodule.1.submodule.conv.unit0.conv.weight", "model.1.submodule.1.submodule.conv.unit0.conv.bias", "model.1.submodule.1.submodule.conv.unit0.adn.N.weight", "model.1.submodule.1.submodule.conv.unit0.adn.N.bias", "model.1.submodule.1.submodule.conv.unit0.adn.N.running_mean", "model.1.submodule.1.submodule.conv.unit0.adn.N.running_var", "model.1.submodule.1.submodule.conv.unit0.adn.A.weight", "model.1.submodule.1.submodule.conv.unit1.conv.weight", "model.1.submodule.1.submodule.conv.unit1.conv.bias", "model.1.submodule.1.submodule.conv.unit1.adn.N.weight", "model.1.submodule.1.submodule.conv.unit1.adn.N.bias", "model.1.submodule.1.submodule.conv.unit1.adn.N.running_mean", "model.1.submodule.1.submodule.conv.unit1.adn.N.running_var", "model.1.submodule.1.submodule.conv.unit1.adn.A.weight", "model.1.submodule.1.submodule.residual.weight", "model.1.submodule.1.submodule.residual.bias". 
	Unexpected key(s) in state_dict: "model.1.submodule.1.submodule.0.conv.unit0.conv.weight", "model.1.submodule.1.submodule.0.conv.unit0.conv.bias", "model.1.submodule.1.submodule.0.conv.unit0.adn.N.weight", "model.1.submodule.1.submodule.0.conv.unit0.adn.N.bias", "model.1.submodule.1.submodule.0.conv.unit0.adn.N.running_mean", "model.1.submodule.1.submodule.0.conv.unit0.adn.N.running_var", "model.1.submodule.1.submodule.0.conv.unit0.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.0.conv.unit0.adn.A.weight", "model.1.submodule.1.submodule.0.conv.unit1.conv.weight", "model.1.submodule.1.submodule.0.conv.unit1.conv.bias", "model.1.submodule.1.submodule.0.conv.unit1.adn.N.weight", "model.1.submodule.1.submodule.0.conv.unit1.adn.N.bias", "model.1.submodule.1.submodule.0.conv.unit1.adn.N.running_mean", "model.1.submodule.1.submodule.0.conv.unit1.adn.N.running_var", "model.1.submodule.1.submodule.0.conv.unit1.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.0.conv.unit1.adn.A.weight", "model.1.submodule.1.submodule.0.residual.weight", "model.1.submodule.1.submodule.0.residual.bias", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.weight", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.bias", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.weight", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.bias", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.running_mean", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.running_var", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.conv.weight", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.conv.bias", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.weight", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.bias", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.running_mean", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.running_var", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.0.residual.weight", "model.1.submodule.1.submodule.1.submodule.0.residual.bias", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.conv.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.conv.bias", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.bias", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.running_mean", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.running_var", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.conv.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.conv.bias", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.bias", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.running_mean", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.running_var", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.residual.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.residual.bias", "model.1.submodule.1.submodule.1.submodule.2.0.conv.weight", "model.1.submodule.1.submodule.1.submodule.2.0.conv.bias", "model.1.submodule.1.submodule.1.submodule.2.0.adn.N.weight", "model.1.submodule.1.submodule.1.submodule.2.0.adn.N.bias", "model.1.submodule.1.submodule.1.submodule.2.0.adn.N.running_mean", "model.1.submodule.1.submodule.1.submodule.2.0.adn.N.running_var", "model.1.submodule.1.submodule.1.submodule.2.0.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.1.submodule.2.0.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.conv.weight", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.conv.bias", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.weight", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.bias", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_mean", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_var", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.A.weight", "model.1.submodule.1.submodule.2.0.conv.weight", "model.1.submodule.1.submodule.2.0.conv.bias", "model.1.submodule.1.submodule.2.0.adn.N.weight", "model.1.submodule.1.submodule.2.0.adn.N.bias", "model.1.submodule.1.submodule.2.0.adn.N.running_mean", "model.1.submodule.1.submodule.2.0.adn.N.running_var", "model.1.submodule.1.submodule.2.0.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.2.0.adn.A.weight", "model.1.submodule.1.submodule.2.1.conv.unit0.conv.weight", "model.1.submodule.1.submodule.2.1.conv.unit0.conv.bias", "model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.weight", "model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.bias", "model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_mean", "model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_var", "model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked", "model.1.submodule.1.submodule.2.1.conv.unit0.adn.A.weight". 
	size mismatch for model.0.conv.unit0.conv.weight: copying a param with shape torch.Size([16, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 4, 3, 3]).
	size mismatch for model.0.conv.unit0.conv.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit0.adn.N.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit0.adn.N.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit0.adn.N.running_mean: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit0.adn.N.running_var: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit1.conv.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for model.0.conv.unit1.conv.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit1.adn.N.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit1.adn.N.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit1.adn.N.running_mean: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.conv.unit1.adn.N.running_var: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.0.residual.weight: copying a param with shape torch.Size([16, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 4, 3, 3]).
	size mismatch for model.0.residual.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.0.conv.unit0.conv.weight: copying a param with shape torch.Size([32, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 32, 3, 3]).
	size mismatch for model.1.submodule.0.conv.unit0.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit0.adn.N.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit0.adn.N.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit0.adn.N.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit0.adn.N.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit1.conv.weight: copying a param with shape torch.Size([32, 32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for model.1.submodule.0.conv.unit1.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit1.adn.N.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit1.adn.N.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit1.adn.N.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.conv.unit1.adn.N.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.0.residual.weight: copying a param with shape torch.Size([32, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 32, 3, 3]).
	size mismatch for model.1.submodule.0.residual.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for model.1.submodule.2.0.conv.weight: copying a param with shape torch.Size([64, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([192, 32, 3, 3]).
	size mismatch for model.1.submodule.2.0.conv.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.0.adn.N.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.0.adn.N.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.0.adn.N.running_mean: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.0.adn.N.running_var: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.1.conv.unit0.conv.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for model.1.submodule.2.1.conv.unit0.conv.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.1.conv.unit0.adn.N.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.1.conv.unit0.adn.N.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.1.conv.unit0.adn.N.running_mean: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.1.submodule.2.1.conv.unit0.adn.N.running_var: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for model.2.0.conv.weight: copying a param with shape torch.Size([32, 3, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
	size mismatch for model.2.1.conv.unit0.conv.weight: copying a param with shape torch.Size([3, 3, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 3, 3, 3]).