In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import random
import re

from tqdm import tqdm
import time

import pydicom as dicom
import nibabel as nib
import SimpleITK as sitk
import monai

import torch
import torch.nn as nn
import torch.optim as optim

from monai.networks.nets import EfficientNetBN
from monai.networks.nets import ResNet
#from efficientnet_pytorch import EfficientNet

import wandb


In [2]:
SEED = 344
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True # Fix the network according to random seed
    print('Finish seeding with seed {}'.format(seed))
    
seed_everything(SEED)
print('Training on device {}'.format(device))

Finish seeding with seed 344
Training on device cuda


In [3]:
dicom_tag_columns = [
    'Columns',
    'ImageOrientationPatient',
    'ImagePositionPatient',
    'InstanceNumber',
    'PatientID',
    'PatientPosition',
    'PixelSpacing',
    'RescaleIntercept',
    'RescaleSlope',
    'Rows',
    'SeriesNumber',
    'SliceThickness',
    'path',
    'WindowCenter',
    'WindowWidth'
]

train_dicom_tags = pd.read_parquet('autodl-tmp/train_dicom_tags.parquet', columns=dicom_tag_columns)
test_dicom_tags = pd.read_parquet('autodl-tmp/test_dicom_tags.parquet', columns=dicom_tag_columns)

train_series_meta = pd.read_csv('autodl-tmp/train_series_meta.csv')
test_series_meta = pd.read_csv('autodl-tmp/test_series_meta.csv')

train_csv = pd.read_csv('autodl-tmp/train.csv')

mid_z_csv = pd.read_csv('middle_z.csv')

complete_series_meta = mid_z_csv[mid_z_csv.middle_z != -1].reset_index()
complete_series_meta

Unnamed: 0,index,patient_id,series_id,middle_z
0,0,10004,21057,128
1,1,10004,51033,131
2,2,10005,18667,120
3,3,10007,47578,127
4,4,10026,29700,110
...,...,...,...,...
4392,4393,9961,2003,133
4393,4394,9961,63032,129
4394,4395,9980,40214,94
4395,4396,9980,40466,151


In [4]:
def raw_path_gen(patient_id, series_id, train=True):
    if(train):
        path = 'autodl-tmp/train_images_resample/'
    else:
        path = 'autodl-tmp/train_images_resample/'
    
    path += str(patient_id) + '/' + str(series_id)
    
    return path

def create_3D_scans(folder, downsample_rate=1): 
    filenames = os.listdir(folder)
    filenames = [int(filename.split('.')[0]) for filename in filenames]
    filenames = sorted(filenames)
    filenames = [str(filename) + '.dcm' for filename in filenames]
        
    volume = []
    #for filename in tqdm(filenames[::downsample_rate], position=0): 
    for filename in filenames[::downsample_rate]: 
        filepath = os.path.join(folder, filename)
        ds = dicom.dcmread(filepath)
        image = ds.pixel_array
        
        if ds.PixelRepresentation == 1:
            bit_shift = ds.BitsAllocated - ds.BitsStored
            dtype = image.dtype 
            image = (image << bit_shift).astype(dtype) >>  bit_shift
        
        # find rescale params
        if ("RescaleIntercept" in ds) and ("RescaleSlope" in ds):
            intercept = float(ds.RescaleIntercept)
            slope = float(ds.RescaleSlope)
    
        # find clipping params
        center = int(ds.WindowCenter)
        width = int(ds.WindowWidth)
        low = center - width / 2
        high = center + width / 2    
        
        
        image = (image * slope) + intercept
        image = np.clip(image, low, high)

        image = (image / np.max(image) * 255).astype(np.int16)
        image = image[::downsample_rate, ::downsample_rate]
        volume.append( image )
    
    volume = np.stack(volume, axis=0)
    return volume

def plot_image_with_seg(volume, volume_seg=[], orientation='Coronal', num_subplots=20):
    # simply copy
    if len(volume_seg) == 0:
        plot_mask = 0
    else:
        plot_mask = 1
        
    if orientation == 'Coronal':
        slices = np.linspace(0, volume.shape[2]-1, num_subplots).astype(np.int16)
        volume = volume.transpose([1, 0, 2])
        if plot_mask:
            volume_seg = volume_seg.transpose([1, 0, 2])
        
    elif orientation == 'Sagittal':
        slices = np.linspace(0, volume.shape[2]-1, num_subplots).astype(np.int16)
        volume = volume.transpose([2, 0, 1])
        if plot_mask:
            volume_seg = volume_seg.transpose([2, 0, 1])

    elif orientation == 'Axial':
        slices = np.linspace(0, volume.shape[0]-1, num_subplots).astype(np.int16)
           
    rows = np.max( [np.floor(np.sqrt(num_subplots)).astype(int) - 2, 1])
    cols = np.ceil(num_subplots/rows).astype(int)
    
    fig, ax = plt.subplots(rows, cols, figsize=(cols * 2, rows * 4))
    fig.tight_layout(h_pad=0.01, w_pad=0)
    
    ax = ax.ravel()
    for this_ax in ax:
        this_ax.axis('off')

    for counter, this_slice in enumerate( slices ):
        plt.sca(ax[counter])
        
        image = volume[this_slice, :, :]
        plt.imshow(image, cmap='gray')
        
        if plot_mask:
            mask = np.where(volume_seg[this_slice, :, :], volume_seg[this_slice, :, :], np.nan)
            plt.imshow(mask, cmap='Set1', alpha=0.5)
            
def load_nii(patient_id, series_id, root='autodl-tmp/train_images_resample/'):
    path = root + str(patient_id) + '/' + str(series_id) + '.nii.gz'
    img = sitk.ReadImage(path)
    img = sitk.GetArrayFromImage(img)
    
    # img = nib.load(path)
    # img = img.get_fdata().transpose(2, 1, 0)
    
    return img

def load_kidney_left_nii(patient_id, series_id, root='autodl-tmp/train_images_resample/'):
    path = root + str(patient_id) + '/' + str(series_id) + '/' + 'kidney_left.nii.gz'
    img = sitk.ReadImage(path)
    img = sitk.GetArrayFromImage(img)
    
    # img = nib.load(path)
    # img = img.get_fdata().transpose(2, 1, 0)
    
    return img

def load_kidney_right_nii(patient_id, series_id, root='autodl-tmp/train_images_resample/'):
    path = root + str(patient_id) + '/' + str(series_id) + '/' + 'kidney_right.nii.gz'
    img = sitk.ReadImage(path)
    img = sitk.GetArrayFromImage(img)
    
    # img = nib.load(path)
    # img = img.get_fdata().transpose(2, 1, 0)
    
    return img

def load_liver_nii(patient_id, series_id, root='autodl-tmp/train_images_resample/'):
    path = root + str(patient_id) + '/' + str(series_id) + '/' + 'liver.nii.gz'
    img = sitk.ReadImage(path)
    img = sitk.GetArrayFromImage(img)
    
    # img = nib.load(path)
    # img = img.get_fdata().transpose(2, 1, 0)
    
    return img

def load_spleen_nii(patient_id, series_id, root='autodl-tmp/train_images_resample/'):
    path = root + str(patient_id) + '/' + str(series_id) + '/' + 'liver.nii.gz'
    img = sitk.ReadImage(path)
    img = sitk.GetArrayFromImage(img)
    
    # img = nib.load(path)
    # img = img.get_fdata().transpose(2, 1, 0)
    
    return img

In [5]:
# patient_id, series_id = train_series_meta.loc[10, ["patient_id", "series_id"]].astype('int')
# mask_all = load_nii(10721, 63796, root='autodl-tmp/train_mask/')
# img_a = load_nii(10721, 63796, root='autodl-tmp/train_images_resample/')

# mask_kidney_left = mask_all % 2 == 1
# mask_kidney_right = (mask_all // 2) % 2 == 1
# mask_liver = (mask_all // 4) % 2 == 1
# mask_spleen = (mask_all // 8) % 2 == 1

# mask_kls = mask_kidney_left + mask_kidney_right + mask_liver + mask_spleen
# mask_z = np.sum(np.sum(mask_kls, 1), 1)
# np.argwhere(mask_z)[len(np.argwhere(mask_z)) // 2]


In [6]:
# plt.plot(mask_z)

In [7]:
# plot_image_with_seg(mask_kls, orientation='Coronal', num_subplots=7)
# plot_image_with_seg(img_a, orientation='Coronal', num_subplots=7)

In [8]:
# midz_columns = ['patient_id', 'series_id', 'middle_z']
# midz_pd = pd.DataFrame(columns=midz_columns)

# for i in tqdm(range(0, len(complete_series_meta))):
#     patient_id, series_id = complete_series_meta.loc[i, ["patient_id", "series_id"]].astype('int')
    
#     mask_all = load_nii(patient_id, series_id, root='autodl-tmp/train_mask/')
#     mask_kidney_left = mask_all % 2 == 1
#     mask_kidney_right = (mask_all // 2) % 2 == 1
#     mask_liver = (mask_all // 4) % 2 == 1
#     mask_spleen = (mask_all // 8) % 2 == 1
#     mask_kls = mask_kidney_left + mask_kidney_right + mask_liver + mask_spleen
#     mask_z = np.sum(np.sum(mask_kls, 1), 1)
#     if np.sum(mask_z) == 0:
#         middle_z = -1
#     else:
#         middle_z = np.argwhere(mask_z)[len(np.argwhere(mask_z)) // 2][0]

#     data_out = [[patient_id, series_id, middle_z]]
#     pd_out = pd.DataFrame(data=data_out, columns=midz_columns)
#     midz_pd = pd.concat([midz_pd, pd_out], axis=0, ignore_index=True)
# midz_pd.to_csv('middle_z.csv', index=None)
# midz_pd

In [9]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import scipy
import random

ds = 2

midz_pd = pd.read_csv('middle_z.csv')

class CTDataset(Dataset):
    def __init__(self, root='autodl-tmp/train_images_resample/', augmentation=False, meta=train_series_meta, device='cpu'):
        self.device = device
        self.series_meta = meta
        self.root = root
        self.t = monai.transforms.Compose([monai.transforms.RandZoom(prob=0.5, min_zoom=0.9, max_zoom=1.1),
                                           monai.transforms.RandRotate(range_x=3.14 / 24, prob=0.5),
                                           monai.transforms.SpatialPad(spatial_size=(320//ds, 280//ds, 280//ds), mode="edge"),
                                           monai.transforms.RandSpatialCrop(roi_size=(320//ds, 256//ds, 256//ds), random_size=False),
                                           monai.transforms.NormalizeIntensity(divisor = 400)
                                ])
        self.t_val = monai.transforms.Compose([monai.transforms.NormalizeIntensity(divisor = 400)])
        
        self.aug = augmentation
        
    def __len__(self):
        #return 1100
        return len(self.series_meta)
    
    def __getitem__(self, idx):
        
        z_aug = random.randint(-12, 12)

        patient_id, series_id = self.series_meta.loc[idx, ["patient_id", "series_id"]].astype('int')
        img_a = load_nii(patient_id, series_id, self.root).astype('float32')
        
        middle_z = midz_pd.loc[midz_pd.series_id == series_id, "middle_z"].values[0] // ds
        
        #img_t = torch.from_numpy(img_a).unsqueeze(0)
        #img_t = torch.from_numpy(img_a[::2, ::2, ::2]).unsqueeze(0)
        if(self.aug):
            img_a = scipy.ndimage.shift(img_a, [z_aug, 0, 0], output=None, order=0, mode='constant', cval=img_a[0, 0, 0])
            img_a = np.expand_dims(img_a[::ds, ::ds, ::ds], 0)
            img_t = self.t(img_a)
            middle_z += z_aug
        else:
            #img_t = torch.from_numpy(img_a[::ds, ::ds, ::ds]).unsqueeze(0)
            img_a = np.expand_dims(img_a[::ds, ::ds, ::ds], 0)
            img_t = self.t_val(img_a)

        label_t = torch.tensor(middle_z.astype('float32'))
        
        return img_t, label_t


In [10]:
# _, labels = next(iter(train_dl))
# print(labels.shape)

In [11]:
train_meta = complete_series_meta[-3200:].reset_index()
val_meta = complete_series_meta[0:-3200].reset_index()

# train_meta = injury_series_meta[-800:].reset_index()
# val_meta = injury_series_meta[0:-800].reset_index()

train_ds = CTDataset(meta = train_meta, augmentation=True)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8)

val_ds = CTDataset(meta = val_meta, augmentation=False)
val_dl = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=8)

In [12]:
class EffNet(nn.Module):
    def __init__(self, ch_out=9):
        super(EffNet, self).__init__()
        self.conv_in = nn.Conv3d(1, 3, kernel_size=5, padding=2, stride=2)
        self.net = EfficientNetBN("efficientnet-b0", pretrained=False, progress=False, spatial_dims=3, in_channels=3, num_classes=ch_out,)
    def forward(self, x):
        x = self.conv_in(x)
        #return torch.sigmoid(self.net(x))
        return self.net(x)

In [13]:
import copy

def TrainClassifer(model,trn_dl,val_dl,optimizer, scheduler=None,
                   n_eopchs=20, device='cpu'):
 
    loss_fn = nn.MSELoss()
    model.to(device)
    best_model = copy.deepcopy(model)
    best_val = 999.0
    PATH_MODEL = 'Slice/test_slice2.pt'
    wandb.init(name='resnet18_test_slice2', 
               project='ResNet18-slice')

    for epoch in range(1, n_eopchs + 1):
        loss_train = 0.0
        model.train()
        for imgs, labels in tqdm(trn_dl, position=0):
            imgs = imgs.to(device)
            labels = labels.to(device)

            outputs = model(imgs).squeeze(-1)
            #outputs = model(imgs.unsqueeze(1))
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()
        torch.cuda.empty_cache()
        
        loss_val = 0.0
        correct_val = 0.0
        model.eval()
        
        for imgs, labels in tqdm(val_dl):
            imgs = imgs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(imgs).squeeze(-1)
                #outputs = model(imgs.unsqueeze(1))
                loss = loss_fn(outputs, labels)
                loss_val += loss.item()
        
        torch.cuda.empty_cache()
            
        if loss_val / len(val_dl) < best_val:
            best_val = loss_val / len(val_dl)
            torch.save(model.state_dict(), 'model_tmp.pt')
            best_model.load_state_dict(torch.load('model_tmp.pt'))
            
            
        if scheduler != None:
            scheduler.step()

        print('{} Eopch {}, Training Loss {}, Val Loss {}'.format(time.strftime("%Y-%m-%d %H:%M:%S",time.localtime()),
                                                                  epoch, loss_train / len(trn_dl), loss_val / len(val_dl),))
        
        
        
        wandb.log({'training loss': loss_train / len(trn_dl),
                  'val loss': loss_val / len(val_dl),})
    torch.save(best_model.state_dict(), PATH_MODEL)
    print('Finish training: best_val:{}'.format(best_val))
    wandb.finish()

In [14]:
net = ResNet(block='basic', layers=[2, 2, 2, 2], block_inplanes=[64, 128, 256, 512],
            spatial_dims=3, n_input_channels=1, num_classes=1, conv1_t_stride=4).to(device)
# net = EffNet(ch_out=1).to(device)

optimizer = optim.AdamW(net.parameters(), lr=2e-4)
TrainClassifer(model=net,trn_dl=train_dl,val_dl=val_dl,optimizer=optimizer, 
               scheduler=None, n_eopchs=30, device=device)

[34m[1mwandb[0m: Currently logged in as: [33mnorthm[0m ([33mrsna2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016672427486628293, max=1.0…

100%|██████████| 800/800 [06:19<00:00,  2.11it/s]
100%|██████████| 300/300 [01:20<00:00,  3.74it/s]


2023-10-05 04:41:06 Eopch 1, Training Loss 628.7015599885583, Val Loss 121.93795802434285


100%|██████████| 800/800 [06:24<00:00,  2.08it/s]
100%|██████████| 300/300 [01:19<00:00,  3.76it/s]


2023-10-05 04:48:51 Eopch 2, Training Loss 53.320793522354215, Val Loss 28.623952306111654


100%|██████████| 800/800 [06:16<00:00,  2.12it/s]
100%|██████████| 300/300 [01:15<00:00,  3.96it/s]


2023-10-05 04:56:24 Eopch 3, Training Loss 36.79479534905404, Val Loss 27.282227060049774


100%|██████████| 800/800 [06:11<00:00,  2.15it/s]
100%|██████████| 300/300 [01:19<00:00,  3.80it/s]


2023-10-05 05:03:54 Eopch 4, Training Loss 30.762206713706256, Val Loss 36.48025955279668


100%|██████████| 800/800 [06:13<00:00,  2.14it/s]
100%|██████████| 300/300 [01:17<00:00,  3.87it/s]


2023-10-05 05:11:26 Eopch 5, Training Loss 27.159955179337413, Val Loss 31.836952827771505


100%|██████████| 800/800 [06:18<00:00,  2.12it/s]
100%|██████████| 300/300 [01:21<00:00,  3.69it/s]


2023-10-05 05:19:06 Eopch 6, Training Loss 26.91845420021564, Val Loss 21.30819776189824


100%|██████████| 800/800 [06:10<00:00,  2.16it/s]
100%|██████████| 300/300 [01:21<00:00,  3.67it/s]


2023-10-05 05:26:38 Eopch 7, Training Loss 25.735823681764305, Val Loss 19.644349741141003


100%|██████████| 800/800 [06:16<00:00,  2.12it/s]
100%|██████████| 300/300 [01:21<00:00,  3.67it/s]


2023-10-05 05:34:17 Eopch 8, Training Loss 23.551625190377237, Val Loss 21.803604525725046


100%|██████████| 800/800 [06:20<00:00,  2.10it/s]
100%|██████████| 300/300 [01:20<00:00,  3.70it/s]


2023-10-05 05:41:58 Eopch 9, Training Loss 23.949924967219122, Val Loss 55.27312387307485


100%|██████████| 800/800 [06:15<00:00,  2.13it/s]
100%|██████████| 300/300 [01:21<00:00,  3.67it/s]


2023-10-05 05:49:37 Eopch 10, Training Loss 22.717095074336974, Val Loss 16.251146466682354


100%|██████████| 800/800 [06:17<00:00,  2.12it/s]
100%|██████████| 300/300 [01:19<00:00,  3.77it/s]


2023-10-05 05:57:14 Eopch 11, Training Loss 21.385727366060017, Val Loss 24.116958296696346


100%|██████████| 800/800 [06:22<00:00,  2.09it/s]
100%|██████████| 300/300 [01:21<00:00,  3.68it/s]


2023-10-05 06:04:58 Eopch 12, Training Loss 22.339314028918743, Val Loss 23.253099225362142


100%|██████████| 800/800 [06:23<00:00,  2.09it/s]
100%|██████████| 300/300 [01:20<00:00,  3.71it/s]


2023-10-05 06:12:42 Eopch 13, Training Loss 20.480616237064822, Val Loss 16.823646647160253


100%|██████████| 800/800 [06:21<00:00,  2.10it/s]
100%|██████████| 300/300 [01:20<00:00,  3.75it/s]


2023-10-05 06:20:23 Eopch 14, Training Loss 19.885585325844588, Val Loss 21.913222438494365


100%|██████████| 800/800 [06:15<00:00,  2.13it/s]
100%|██████████| 300/300 [01:18<00:00,  3.82it/s]


2023-10-05 06:27:57 Eopch 15, Training Loss 20.568520724372938, Val Loss 18.730131013691427


100%|██████████| 800/800 [06:15<00:00,  2.13it/s]
100%|██████████| 300/300 [01:19<00:00,  3.76it/s]


2023-10-05 06:35:32 Eopch 16, Training Loss 20.06132139861584, Val Loss 17.04728938281536


100%|██████████| 800/800 [06:17<00:00,  2.12it/s]
100%|██████████| 300/300 [01:20<00:00,  3.71it/s]


2023-10-05 06:43:10 Eopch 17, Training Loss 17.797496595084667, Val Loss 18.550848133563996


100%|██████████| 800/800 [06:22<00:00,  2.09it/s]
100%|██████████| 300/300 [01:20<00:00,  3.72it/s]


2023-10-05 06:50:54 Eopch 18, Training Loss 19.101869901791215, Val Loss 22.069849719802537


100%|██████████| 800/800 [06:17<00:00,  2.12it/s]
100%|██████████| 300/300 [01:20<00:00,  3.73it/s]


2023-10-05 06:58:32 Eopch 19, Training Loss 17.239095314741135, Val Loss 15.877683041468263


100%|██████████| 800/800 [06:18<00:00,  2.11it/s]
100%|██████████| 300/300 [01:20<00:00,  3.73it/s]


2023-10-05 07:06:11 Eopch 20, Training Loss 17.655551286363043, Val Loss 19.256379655525087


100%|██████████| 800/800 [06:19<00:00,  2.11it/s]
100%|██████████| 300/300 [01:19<00:00,  3.79it/s]


2023-10-05 07:13:50 Eopch 21, Training Loss 16.612944446988404, Val Loss 23.298984165688356


100%|██████████| 800/800 [06:15<00:00,  2.13it/s]
100%|██████████| 300/300 [01:19<00:00,  3.75it/s]


2023-10-05 07:21:25 Eopch 22, Training Loss 16.757467159442605, Val Loss 30.133903932174047


100%|██████████| 800/800 [06:20<00:00,  2.10it/s]
100%|██████████| 300/300 [01:20<00:00,  3.75it/s]


2023-10-05 07:29:06 Eopch 23, Training Loss 16.85791667321697, Val Loss 20.362948358953


100%|██████████| 800/800 [06:20<00:00,  2.10it/s]
100%|██████████| 300/300 [01:21<00:00,  3.70it/s]


2023-10-05 07:36:48 Eopch 24, Training Loss 15.74138745892793, Val Loss 19.694037202795347


100%|██████████| 800/800 [06:21<00:00,  2.10it/s]
100%|██████████| 300/300 [01:21<00:00,  3.68it/s]


2023-10-05 07:44:32 Eopch 25, Training Loss 15.810167119428515, Val Loss 16.54695967311971


100%|██████████| 800/800 [06:19<00:00,  2.11it/s]
100%|██████████| 300/300 [01:20<00:00,  3.72it/s]


2023-10-05 07:52:12 Eopch 26, Training Loss 16.087306949691847, Val Loss 17.072936187585196


100%|██████████| 800/800 [06:16<00:00,  2.12it/s]
100%|██████████| 300/300 [01:20<00:00,  3.74it/s]


2023-10-05 07:59:49 Eopch 27, Training Loss 14.73094510352239, Val Loss 16.148107617273926


100%|██████████| 800/800 [06:18<00:00,  2.11it/s]
100%|██████████| 300/300 [01:19<00:00,  3.77it/s]


2023-10-05 08:07:27 Eopch 28, Training Loss 13.363643765053713, Val Loss 16.580306245411435


100%|██████████| 800/800 [06:14<00:00,  2.14it/s]
100%|██████████| 300/300 [01:23<00:00,  3.60it/s]


2023-10-05 08:15:06 Eopch 29, Training Loss 14.423625068883412, Val Loss 15.119095928023258


100%|██████████| 800/800 [06:17<00:00,  2.12it/s]
100%|██████████| 300/300 [01:21<00:00,  3.67it/s]


2023-10-05 08:22:45 Eopch 30, Training Loss 14.084296721015125, Val Loss 18.15269568512837
Finish training: best_val:15.119095928023258


0,1
training loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val loss,█▂▂▂▂▁▁▁▄▁▂▂▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁

0,1
training loss,14.0843
val loss,18.1527


In [20]:
net = ResNet(block='basic', layers=[2, 2, 2, 2], block_inplanes=[64, 128, 256, 512],
            spatial_dims=3, n_input_channels=1, num_classes=1, conv1_t_stride=4).to(device)
net.load_state_dict(torch.load('Slice/test_slice1.pt'))

max_diff = 0
mean_diff = 0
loss_fn = nn.MSELoss()
net.eval()
for imgs, labels in tqdm(val_dl):
    with torch.no_grad():
        preds = net(imgs.to(device))
        diff = preds.squeeze(-1) - labels.to(device)
        #print(diff)
        diff = diff.to('cpu').detach().numpy()
        diff = np.mean(np.abs(diff))
        #print(diff)
        max_diff = max(diff, max_diff)
        #diff = loss_fn(preds.squeeze(-1), labels.to(device))
        #print(diff)
        mean_diff += diff
print(max_diff)
print(mean_diff / len(val_dl))

100%|██████████| 300/300 [01:18<00:00,  3.81it/s]

20.988916
1.5932526230812072



