In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import random
import re

from tqdm import tqdm
import time

import pydicom as dicom
import nibabel as nib
import SimpleITK as sitk
import monai

import torch
import torch.nn as nn
import torch.optim as optim

from monai.networks.nets import EfficientNetBN
from monai.networks.nets import ResNet
#from efficientnet_pytorch import EfficientNet
import timm

import wandb


In [2]:
SEED = 344
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator below.
    """
    # Prefer reproducible cuDNN algorithms over the fastest ones.
    torch.backends.cudnn.deterministic = True
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # Only affects interpreters started after this point; the running process's
    # hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('Finish seeding with seed {}'.format(seed))


seed_everything(SEED)
print('Training on device {}'.format(device))

Finish seeding with seed 344
Training on device cuda


In [3]:
# DICOM header fields loaded from the pre-extracted tag tables.
dicom_tag_columns = [
    'Columns', 'ImageOrientationPatient', 'ImagePositionPatient', 'InstanceNumber',
    'PatientID', 'PatientPosition', 'PixelSpacing', 'RescaleIntercept',
    'RescaleSlope', 'Rows', 'SeriesNumber', 'SliceThickness', 'path',
    'WindowCenter', 'WindowWidth',
]

train_dicom_tags, test_dicom_tags = (
    pd.read_parquet(tag_path, columns=dicom_tag_columns)
    for tag_path in ('autodl-tmp/train_dicom_tags.parquet', 'autodl-tmp/test_dicom_tags.parquet')
)

train_series_meta, test_series_meta = (
    pd.read_csv(meta_path)
    for meta_path in ('autodl-tmp/train_series_meta.csv', 'autodl-tmp/test_series_meta.csv')
)

# Per-patient targets (kidney/liver/spleen low/high, any_injury, ...).
train_csv = pd.read_csv('autodl-tmp/train.csv')

# Series whose middle_z is -1 are dropped (presumably series where it could not be
# determined — TODO confirm against the script that produced middle_z.csv).
mid_z_csv = pd.read_csv('middle_z.csv')
complete_series_meta = mid_z_csv.loc[mid_z_csv['middle_z'].ne(-1)].reset_index()
complete_series_meta

Unnamed: 0,index,patient_id,series_id,middle_z
0,0,10004,21057,128
1,1,10004,51033,131
2,2,10005,18667,120
3,3,10007,47578,127
4,4,10026,29700,110
...,...,...,...,...
4392,4393,9961,2003,133
4393,4394,9961,63032,129
4394,4395,9980,40214,94
4395,4396,9980,40466,151


In [4]:
# Split training series by whether their patient has any_injury == 1 (injured) or 0 (healthy).
injury_series_meta, healthy_series_meta = (
    train_series_meta.loc[
        train_series_meta['patient_id'].isin(
            train_csv.loc[train_csv['any_injury'] == injury_flag, 'patient_id'].values
        )
    ]
    for injury_flag in (1, 0)
)

In [5]:
def raw_path_gen(patient_id, series_id, train=True):
    """Build the resampled-image path prefix for one series.

    Args:
        patient_id: patient identifier (converted with str()).
        series_id: series identifier (converted with str()).
        train: choose the train (True) or test (False) resampled directory.

    Returns:
        '<root><patient_id>/<series_id>' with no file extension.
    """
    # Previously both branches returned the train directory, so train=False was ignored.
    # NOTE(review): test directory name assumed to mirror the train one — confirm it exists.
    root = 'autodl-tmp/train_images_resample/' if train else 'autodl-tmp/test_images_resample/'
    return root + str(patient_id) + '/' + str(series_id)

def _first_dicom_value(value):
    """Return a DICOM numeric element as float, using its first entry when multi-valued.

    Window Center/Width may store several windows; float() on such a multi-valued
    element raises TypeError, so fall back to the first value.
    """
    try:
        return float(value)
    except TypeError:
        return float(value[0])


def create_3D_scans(folder, downsample_rate=1):
    """Load a folder of '<number>.dcm' slices into a windowed 3D volume.

    Files are read in ascending numeric order. Every `downsample_rate`-th file is
    kept, and each kept slice is subsampled in-plane by the same rate. Each slice is
    sign-extended if needed, rescaled with RescaleSlope/RescaleIntercept (identity
    when absent), clipped to its WindowCenter/WindowWidth window, and mapped
    linearly onto [0, 255].

    Args:
        folder: directory containing the DICOM files.
        downsample_rate: stride applied to the slice list and to both in-plane axes.

    Returns:
        np.ndarray of dtype int16 with one entry per kept slice along axis 0.
    """
    # Sort numerically so '10.dcm' comes after '9.dcm'.
    slice_numbers = sorted(int(fname.split('.')[0]) for fname in os.listdir(folder))
    filenames = [str(number) + '.dcm' for number in slice_numbers]

    volume = []
    for filename in filenames[::downsample_rate]:
        dcm = dicom.dcmread(os.path.join(folder, filename))
        image = dcm.pixel_array

        if dcm.PixelRepresentation == 1:
            # Signed pixels stored in fewer bits than allocated: shift left, then
            # arithmetic-shift right to sign-extend.
            bit_shift = dcm.BitsAllocated - dcm.BitsStored
            dtype = image.dtype
            image = (image << bit_shift).astype(dtype) >> bit_shift

        # Identity rescale when the tags are absent (previously an unbound-variable error).
        slope = float(dcm.RescaleSlope) if 'RescaleSlope' in dcm else 1.0
        intercept = float(dcm.RescaleIntercept) if 'RescaleIntercept' in dcm else 0.0

        center = _first_dicom_value(dcm.WindowCenter)
        width = _first_dicom_value(dcm.WindowWidth)
        low = center - width / 2
        high = center + width / 2

        image = np.clip(image * slope + intercept, low, high)
        # Map the window [low, high] onto [0, 255]. Dividing by the slice maximum
        # produced negative values whenever low < 0 and a content-dependent scale.
        image = ((image - low) / max(high - low, 1e-6) * 255).astype(np.int16)

        volume.append(image[::downsample_rate, ::downsample_rate])

    return np.stack(volume, axis=0)

def plot_image_with_seg(volume, volume_seg=None, orientation='Coronal', num_subplots=20):
    """Show `num_subplots` evenly spaced 2D slices of a 3D volume, optionally with a mask.

    Args:
        volume: 3D array. 'Axial' slices along axis 0, 'Coronal' along axis 1 and
            'Sagittal' along axis 2.
        volume_seg: optional 3D array shaped like `volume`; non-zero voxels are
            overlaid with the 'Set1' colormap at 50% opacity. None or an empty
            sequence disables the overlay.
        orientation: 'Axial', 'Coronal' or 'Sagittal'.
        num_subplots: number of slices to display.

    Raises:
        ValueError: if `orientation` is not one of the three supported views.
    """
    # Axis permutation that moves the viewing axis to the front.
    view_axes = {'Axial': (0, 1, 2), 'Coronal': (1, 0, 2), 'Sagittal': (2, 0, 1)}
    if orientation not in view_axes:
        raise ValueError(
            "orientation must be 'Axial', 'Coronal' or 'Sagittal', got {!r}".format(orientation))
    axes_order = view_axes[orientation]
    plot_mask = volume_seg is not None and len(volume_seg) != 0

    volume = np.transpose(volume, axes_order)
    if plot_mask:
        volume_seg = np.transpose(volume_seg, axes_order)

    # Sample indices along the axis actually being sliced (the Coronal view used to
    # take them from axis 2, which overruns axis 1 when it is shorter).
    slices = np.linspace(0, volume.shape[0] - 1, num_subplots).astype(int)

    rows = max(int(np.floor(np.sqrt(num_subplots))) - 2, 1)
    cols = int(np.ceil(num_subplots / rows))

    # squeeze=False keeps a 2D Axes array even for a 1x1 grid, so ravel() always works.
    fig, grid = plt.subplots(rows, cols, figsize=(cols * 2, rows * 4), squeeze=False)
    fig.tight_layout(h_pad=0.01, w_pad=0)
    panels = grid.ravel()
    for panel in panels:
        panel.axis('off')

    for panel, index in zip(panels, slices):
        panel.imshow(volume[index, :, :], cmap='gray')
        if plot_mask:
            seg = volume_seg[index, :, :]
            # Zero voxels become NaN so they render transparent.
            panel.imshow(np.where(seg, seg, np.nan), cmap='Set1', alpha=0.5)
            
def load_nii(patient_id, series_id, root='autodl-tmp/train_images_resample/'):
    """Read '<root>/<patient_id>/<series_id>.nii.gz' with SimpleITK and return its voxels.

    The array comes from sitk.GetArrayFromImage, whose index order is the reverse
    of SimpleITK's (x, y, z) image order, i.e. (z, y, x).

    Args:
        patient_id: patient identifier (converted with str()).
        series_id: series identifier (converted with str()).
        root: base directory; a trailing separator is optional.

    Returns:
        np.ndarray of voxel values.
    """
    # os.path.join tolerates `root` with or without a trailing slash.
    path = os.path.join(root, str(patient_id), str(series_id) + '.nii.gz')
    return sitk.GetArrayFromImage(sitk.ReadImage(path))

In [6]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# ds: depth subsampling factor used by CTDataset (keeps every ds-th slice along
#     axis 0 and divides middle_z by ds to match).
# midz_pd: per-series middle_z table, looked up by series_id in CTDataset.__getitem__.
ds, midz_pd = 2, pd.read_csv('middle_z.csv')

class CTDataset(Dataset):
    """Series-level dataset yielding (volume, 6 organ labels, middle_z).

    For each row of `meta`:
      * loads the series with `load_nii(patient_id, series_id, root)`,
      * keeps every `ds`-th slice along axis 0 and adds a leading channel axis,
      * applies the random training pipeline when `augmentation` is True,
        otherwise intensity normalisation only,
      * reads the kidney/liver/spleen low/high targets for the patient from the
        global `train_csv`, and `middle_z` for the series from the global
        `midz_pd` (integer-divided by `ds` to match the subsampled depth).
    """

    # Target columns read from train_csv, in output order.
    LABEL_COLUMNS = [
        'kidney_low',
        'kidney_high',
        'liver_low',
        'liver_high',
        'spleen_low',
        'spleen_high',
    ]

    def __init__(self, root='autodl-tmp/train_images_resample/', augmentation=False, meta=train_series_meta, device='cpu'):
        """
        Args:
            root: directory passed to `load_nii`.
            augmentation: use the random training transforms instead of the validation ones.
            meta: DataFrame with at least 'patient_id' and 'series_id' columns. Rows are
                addressed by position, so the index need not be 0..n-1.
            device: stored on the instance; not used by this class.
        """
        self.device = device
        self.series_meta = meta
        self.root = root
        # NOTE(review): SpatialPad pads symmetrically and RandSpatialCrop crops at a random
        # offset; if a subsampled volume's depth is not exactly 320 // ds, these shift the
        # z axis and middle_z no longer points at the same anatomy — confirm volume depth.
        self.t = monai.transforms.Compose([
            monai.transforms.RandZoom(prob=0.5, min_zoom=0.9, max_zoom=1.1),
            monai.transforms.RandRotate(range_x=3.14 / 24, prob=0.5),
            monai.transforms.SpatialPad(spatial_size=(320 // ds, 280, 280), mode="edge"),
            monai.transforms.RandSpatialCrop(roi_size=(320 // ds, 256, 256), random_size=False),
            monai.transforms.NormalizeIntensity(divisor=400),
        ])
        self.t_val = monai.transforms.Compose([monai.transforms.NormalizeIntensity(divisor=400)])

        self.aug = augmentation

    def __len__(self):
        return len(self.series_meta)

    def __getitem__(self, idx):
        # Positional access works for un-reset slices of a metadata frame as well;
        # for frames with a default RangeIndex it matches the previous .loc lookup.
        patient_id = int(self.series_meta['patient_id'].iat[idx])
        series_id = int(self.series_meta['series_id'].iat[idx])
        # Rescale the stored middle slice to the depth-subsampled volume.
        middle_z = midz_pd.loc[midz_pd.series_id == series_id, "middle_z"].values[0] // ds

        volume = load_nii(patient_id, series_id, self.root).astype('float32')
        volume = np.expand_dims(volume[::ds, :, :], 0)  # add channel axis
        img_t = self.t(volume) if self.aug else self.t_val(volume)

        label_a = train_csv.loc[train_csv.patient_id == patient_id, self.LABEL_COLUMNS].values[0].astype('float32')
        label_t = torch.from_numpy(label_a)
        label_midz = torch.tensor(int(middle_z), dtype=torch.long)

        return img_t, label_t, label_midz

In [7]:
# train_meta = complete_series_meta[0:3200].reset_index()
# val_meta = complete_series_meta[3200:].reset_index()

# # train_meta = injury_series_meta[-800:].reset_index()
# # val_meta = injury_series_meta[0:-800].reset_index()

# train_ds = CTDataset(meta = train_meta, augmentation=True)
# train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8)

# val_ds = CTDataset(meta = val_meta, augmentation=False)
# val_dl = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=8)

In [8]:
# class EffNet(nn.Module):
#     def __init__(self, ch_out=9):
#         super(EffNet, self).__init__()
#         #self.conv_in = nn.Conv3d(1, 3, kernel_size=5, padding=2, stride=2)
#         self.net = EfficientNetBN("efficientnet-b0", pretrained=False, progress=False, spatial_dims=3, in_channels=1, num_classes=ch_out,)
#     def forward(self, x):
#         #x = self.conv_in(x)
#         #return torch.sigmoid(self.net(x))
#         return self.net(x)

class KLSVolNet(nn.Module):
    """2.5D kidney/liver/spleen classifier: 2D timm encoder per slice stack + BiLSTM.

    `forward` takes a batch of volumes and a per-sample centre index `mid_z`. It
    cuts a 60-slice window around each centre (see `slicer`), samples `slices`
    evenly spaced 3-slice stacks, and encodes each stack with the 2D backbone.
    A 2-layer bidirectional LSTM then runs over the stacks, and the per-stack
    logits are fused into one `ch_out`-vector per sample by a 1x1 Conv1d over the
    stack axis.
    """

    # Substrings of timm model names whose classification head this class can strip.
    SUPPORTED_BACKBONE_FAMILIES = ('efficient', 'convnext')

    def __init__(self, backbone="tf_efficientnetv2_s.in21k_ft_in1k",
                 ch_in=3, ch_out=9, slices=15, dropout=0.0, pretrained=True):
        super().__init__()
        # Fail fast: an unsupported family used to build the backbone and then crash
        # with an unbound `hdim`.
        if not any(family in backbone for family in self.SUPPORTED_BACKBONE_FAMILIES):
            raise ValueError(
                "Unsupported backbone {!r}: name must contain one of {}".format(
                    backbone, self.SUPPORTED_BACKBONE_FAMILIES))
        self.slices = slices

        # Always built without timm weights; optional weights come from local checkpoints.
        self.encoder = timm.create_model(
            backbone,
            in_chans=ch_in,
            num_classes=ch_out,
            features_only=False,
            drop_rate=0.0,
            drop_path_rate=0.0,
            pretrained=False,
        )

        if pretrained:
            # NOTE(review): one fixed checkpoint per family; judging by the file names they
            # match only tf_efficientnetv2_s / convnextv2_nano with a 6-class head.
            if 'efficient' in backbone:
                ckpt_path = 'pretrained/tf_efficientnetv2_s.in21k_ft_in1k_6class.pt'
            else:
                ckpt_path = 'pretrained/convnextv2_nano.fcmae_ft_in22k_in1k_384_6class.pt'
            # map_location='cpu': loadable on machines without the GPU it was saved from.
            self.encoder.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

        # Replace the classifier with Identity so the encoder outputs pooled features.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        else:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()

        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=0.0, bidirectional=True, batch_first=True)

        # Per-stack head on the 2 x 256 bidirectional LSTM output.
        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout),
            nn.LeakyReLU(0.1),
            nn.Linear(256, ch_out),
        )

        # Learned weighted sum (plus bias) over the `slices` per-stack predictions.
        self.head2 = nn.Conv1d(slices, 1, 1)

    def slicer(self, img, mid_z, slices):
        """Cut a 60-slice window around each mid_z and sample `slices` 3-slice stacks from it.

        Args:
            img: 5D tensor indexed (batch, channel, depth, H, W); windows are taken
                along dim -3. NOTE(review): the batch is rebuilt with torch.cat over
                dim 0, which assumes a single channel.
            mid_z: per-sample centre index along depth (tensor or sequence of ints).
            slices: number of 3-slice stacks to sample.

        Returns:
            Tensor of shape (batch, slices, 3, H, W).
        """
        depth = img.shape[-3]
        windows = []
        for i in range(len(mid_z)):
            z = int(mid_z[i])
            # Clamp so the window [z - 30, z + 30) stays inside the volume
            # (near the end it stops one slice short of the last slice).
            if z < 30:
                z = 30
            elif depth - z < 30:
                z = depth - 31
            windows.append(img[i, :, z - 30:z + 30, :, :])
        img = torch.cat(windows, 0).unsqueeze(1)

        # `slices` evenly spaced centres strictly inside the window (endpoints dropped).
        z_length = img.shape[-3]
        centres = np.linspace(0, z_length, slices + 2).astype('int')[1:-1]
        # Each stack [c - 1, c + 1] becomes the 3 input channels of the 2D encoder.
        stacks = [img[:, :, c - 1:c + 2, :, :] for c in centres]
        return torch.cat(stacks, 1)

    def forward(self, x, mid_z):
        """Return logits of shape (batch, ch_out) for volumes `x` centred at `mid_z`."""
        x = self.slicer(x, mid_z, self.slices)  # (bs, nslice, 3, H, W)
        bs, nslice, ch, sz1, sz2 = x.shape

        feature_2d = self.encoder(x.view(bs * nslice, ch, sz1, sz2))
        feature_2d = feature_2d.view(bs, nslice, -1)  # (bs, nslice, hdim)

        feature_lstm, _ = self.lstm(feature_2d)  # (bs, nslice, 512)
        preds = self.head(feature_lstm.contiguous().view(bs * nslice, -1))
        preds = self.head2(preds.view(bs, nslice, -1).contiguous())  # (bs, 1, ch_out)

        return preds.squeeze(1)
        

In [9]:
# imgs, labels, midz = next(iter(train_dl))
# net = KLSVolNet(ch_out = 6)
# img_s = net(imgs, midz)
# img_s.shape

In [10]:
# img_a = imgs[0, 0].numpy()
# plot_image_with_seg(img_a, orientation='Coronal', num_subplots=7)

In [11]:
import sklearn.metrics

def transform_9class(label_in):
    """Expand 8 values into the 13-entry (healthy, injury[, high]) layout.

    Entries 0 and 1 each become a (1 - p, p) pair. Entries (2, 3), (4, 5) and
    (6, 7) are (low, high) pairs; each becomes a triple
    ((1 - low) * (1 - high), low, high).

    Returns:
        list of 13 values.
    """
    expanded = []
    for p in label_in[:2]:
        expanded.extend([1 - p, p])
    for start in (2, 4, 6):
        low, high = label_in[start], label_in[start + 1]
        expanded.extend([(1 - low) * (1 - high), low, high])
    return expanded

def transform_13class(label_in):
    """Return `label_in` (already in the 13-entry layout) as a plain Python list."""
    return label_in.tolist()

def transform_kls9class(label_in):
    """Prefix `label_in` (as a list) with four 1s, the placeholders used for the first four entries."""
    return [1] * 4 + label_in.tolist()

def transform_kls6class(label_in):
    """Expand 6 (low, high) values into the 13-entry layout used by loss_metrics.

    The first four entries are fixed placeholders of 1. Each consecutive
    (low, high) pair of `label_in` — by the column order used in CTDataset,
    kidney, liver, spleen — becomes ((1 - low) * (1 - high), low, high).

    Returns:
        list of 13 values.
    """
    pairs = label_in.tolist()
    expanded = [1, 1, 1, 1]
    for start in range(0, 6, 2):
        low, high = pairs[start], pairs[start + 1]
        expanded += [(1 - low) * (1 - high), low, high]
    return expanded


# Index groups of the 13-entry layout produced by the transform_* helpers: two
# (healthy, injury) pairs followed by three (healthy, low, high) triples.
_LOSS_GROUPS = ((0, 1), (2, 3), (4, 5, 6), (7, 8, 9), (10, 11, 12))
# Log-loss weight per entry: the 13 layout entries plus the derived any_injury entry.
_LOSS_WEIGHTS = np.array([1, 2, 1, 6, 1, 2, 4, 1, 2, 4, 1, 2, 4, 6])


def _weighted_sample_loss(predict, target):
    """Weighted binary log loss for one sample in the 13-entry layout plus any_injury.

    Predictions are renormalised to sum to 1 within each group. any_injury is
    max(1 - healthy) over the groups' first entries, for both prediction and target.
    """
    healthy = [group[0] for group in _LOSS_GROUPS]

    label_pred = np.zeros(14)
    for group in _LOSS_GROUPS:
        total = sum(predict[k] for k in group)
        for k in group:
            label_pred[k] = predict[k] / total
    label_pred[13] = max(1 - label_pred[k] for k in healthy)

    label_target = np.array(list(target) + [max(1 - target[k] for k in healthy)])

    # labels=[0, 1] keeps log_loss well-defined even if a sample's targets are single-class.
    return sklearn.metrics.log_loss(
        y_true=label_target,
        y_pred=label_pred,
        sample_weight=_LOSS_WEIGHTS,
        labels=[0, 1])


def loss_metrics(metrics, transform):
    """Print per-output F1/AUC and return the mean per-sample weighted log loss.

    Args:
        metrics: dict with 'predict' (probabilities) and 'label' (binary targets),
            2D arrays with one row per sample.
        transform: maps one row to the 13-entry layout (e.g. transform_kls6class).

    Returns:
        Mean of the per-sample weighted log losses.
    """
    labels = metrics["label"]
    predicts = metrics["predict"]

    print("F1 score: ", sklearn.metrics.f1_score(labels, np.around(predicts), average=None, zero_division=0.0))
    try:
        auc = sklearn.metrics.roc_auc_score(labels, predicts, average=None)
    except ValueError as err:
        # e.g. an output column with a single class in this validation split;
        # report it instead of aborting the training run.
        auc = 'undefined ({})'.format(err)
    print("AUC score: ", auc)

    losses = [_weighted_sample_loss(transform(p), transform(t)) for p, t in zip(predicts, labels)]
    return np.mean(losses)
    

In [12]:
import copy

import torch.cuda.amp as amp

# Kept for backward compatibility only. TrainClassifer creates its own GradScaler
# per call so loss-scale state is not carried from one fold's run into the next.
scaler = amp.GradScaler()


def _train_one_epoch(model, loader, optimizer, loss_fn, grad_scaler, device):
    """Run one mixed-precision training epoch over `loader`; return the mean batch loss."""
    model.train()
    loss_sum = 0.0
    for imgs, labels, midz in tqdm(loader, position=0):
        imgs, labels, midz = imgs.to(device), labels.to(device), midz.to(device)
        with amp.autocast():
            outputs = model(imgs, midz)
            loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()
        loss_sum += loss.item()
    return loss_sum / len(loader)


def _evaluate(model, loader, loss_fn, device):
    """Return (mean batch loss, {'predict': sigmoid outputs, 'label': targets}) on `loader`."""
    model.eval()
    loss_sum = 0.0
    predictions, targets = [], []
    with torch.no_grad():
        for imgs, labels, midz in tqdm(loader):
            imgs, labels, midz = imgs.to(device), labels.to(device), midz.to(device)
            outputs = model(imgs, midz)
            loss_sum += loss_fn(outputs, labels).item()
            predictions.extend(torch.sigmoid(outputs).cpu().numpy().tolist())
            targets.extend(labels.cpu().numpy().tolist())
    metrics = {'predict': np.array(predictions), 'label': np.array(targets)}
    return loss_sum / len(loader), metrics


def TrainClassifer(model, trn_dl, val_dl, optimizer, project, name, suffix, scheduler=None,
                   n_eopchs=20, device='cpu'):
    """Train `model`, log each epoch to wandb and save the best checkpoint.

    Each epoch:
      * trains with AMP on `trn_dl`,
      * evaluates the BCE-with-logits loss on `val_dl`, plus the weighted log loss
        from `loss_metrics(..., transform_kls6class)`,
      * steps `scheduler` if given.
    At the end, the state_dict with the lowest validation BCE loss is saved to
    '<project>/<name>/<suffix>.pt' (the directory must already exist).

    Args:
        model: network called as model(imgs, midz), returning 6 logits per sample.
        trn_dl, val_dl: loaders yielding (imgs, labels, midz).
        optimizer: optimizer over model's parameters.
        project, name: wandb project / run name, also used in the checkpoint path.
        suffix: checkpoint file stem (e.g. 'fold1').
        scheduler: optional LR scheduler, stepped once per epoch.
        n_eopchs: number of epochs.
        device: device the model and batches are moved to.
    """
    # pos_weight up-weights the positive class of each of the 6 outputs.
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([10, 10, 6, 6, 6, 6]).to(device))
    grad_scaler = amp.GradScaler()  # fresh per run
    model.to(device)
    best_state = copy.deepcopy(model.state_dict())
    best_val = 999.0
    best_weightloss = 999.0
    path_model = project + '/' + name + '/' + suffix + '.pt'

    wandb.init(name=name, project=project)
    try:
        for epoch in range(1, n_eopchs + 1):
            train_loss = _train_one_epoch(model, trn_dl, optimizer, loss_fn, grad_scaler, device)
            torch.cuda.empty_cache()

            val_loss, metrics = _evaluate(model, val_dl, loss_fn, device)
            weighted_loss = loss_metrics(metrics, transform_kls6class)
            torch.cuda.empty_cache()

            best_weightloss = min(best_weightloss, weighted_loss)
            if val_loss < best_val:
                best_val = val_loss
                # In-memory snapshot instead of a round trip through 'model_tmp.pt' in the CWD.
                best_state = copy.deepcopy(model.state_dict())

            if scheduler is not None:
                scheduler.step()

            print('{} Eopch {}, Training Loss {}, Val Loss {}, Weighted Loss {}'.format(
                time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                epoch, train_loss, val_loss, weighted_loss))

            wandb.log({'training loss': train_loss,
                       'val loss': val_loss,
                       'weighted loss': weighted_loss})

        torch.save(best_state, path_model)
        print('Finish training: best_val:{}, best_weighted loss:{}'.format(best_val, best_weightloss))
    finally:
        # Close the wandb run even if training fails or is interrupted.
        wandb.finish()

In [13]:
import os

# Experiment configuration for the 4-fold run.
project = "KLS2.5Dx2-nophase-4fold"
name = "KLS2.5Dx2-nophase_Effnetv2_4fold"
backbone = "tf_efficientnetv2_s.in21k_ft_in1k"
# backbone = 'convnextv2_nano.fcmae_ft_in22k_in1k_384'
n_eopchs = 12

# TrainClassifer writes checkpoints to <project>/<name>/<fold>.pt. makedirs also
# creates the project directory (os.mkdir failed when it was missing) and is a
# no-op on re-run.
p = project + '/' + name
os.makedirs(p, exist_ok=True)

In [14]:
# Contiguous 4-fold split of complete_series_meta: folds 1-3 validate on 1100 rows
# each, fold 4 on all remaining rows (3300:).
# NOTE(review): the split is by row position, not by patient_id, so a patient with
# several series can straddle a fold boundary and appear in both train and val.
# A patient-grouped split (e.g. sklearn GroupKFold) would avoid that leakage.
FOLD_SIZE = 1100
N_FOLDS = 4


def _split_fold(meta, fold_idx, fold_size=FOLD_SIZE, n_folds=N_FOLDS):
    """Return (train_df, val_df) for 0-based `fold_idx`; the last fold takes all remaining rows."""
    start = fold_idx * fold_size
    stop = None if fold_idx == n_folds - 1 else start + fold_size
    in_val = np.zeros(len(meta), dtype=bool)
    in_val[start:stop] = True
    # Boolean masks keep the original row order, matching the former concat of the
    # rows before and after the validation block.
    return meta[~in_val].reset_index(drop=True), meta[in_val].reset_index(drop=True)


def _make_loaders(train_df, val_df, batch_size=4, num_workers=8):
    """Build (train_ds, train_dl, val_ds, val_dl); only the training side is augmented and shuffled."""
    train_ds = CTDataset(meta=train_df, augmentation=True)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_ds = CTDataset(meta=val_df, augmentation=False)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_ds, train_dl, val_ds, val_dl


train_df_fold1, val_df_fold1 = _split_fold(complete_series_meta, 0)
train_df_fold2, val_df_fold2 = _split_fold(complete_series_meta, 1)
train_df_fold3, val_df_fold3 = _split_fold(complete_series_meta, 2)
train_df_fold4, val_df_fold4 = _split_fold(complete_series_meta, 3)

train_ds_fold1, train_dl_fold1, val_ds_fold1, val_dl_fold1 = _make_loaders(train_df_fold1, val_df_fold1)
train_ds_fold2, train_dl_fold2, val_ds_fold2, val_dl_fold2 = _make_loaders(train_df_fold2, val_df_fold2)
train_ds_fold3, train_dl_fold3, val_ds_fold3, val_dl_fold3 = _make_loaders(train_df_fold3, val_df_fold3)
train_ds_fold4, train_dl_fold4, val_ds_fold4, val_dl_fold4 = _make_loaders(train_df_fold4, val_df_fold4)

In [15]:
# Train one freshly initialised model per fold with identical hyper-parameters.
LEARNING_RATE = 1e-5

fold_loaders = [
    ('fold1', train_dl_fold1, val_dl_fold1),
    ('fold2', train_dl_fold2, val_dl_fold2),
    ('fold3', train_dl_fold3, val_dl_fold3),
    ('fold4', train_dl_fold4, val_dl_fold4),
]

for suffix, fold_train_dl, fold_val_dl in fold_loaders:
    # Re-seed per fold so a single fold can be re-run in isolation with the same
    # weight initialisation and shuffling order as in the full run.
    seed_everything(SEED)
    net = KLSVolNet(backbone=backbone, ch_out=6).to(device)
    optimizer = optim.AdamW(net.parameters(), lr=LEARNING_RATE)
    TrainClassifer(model=net, trn_dl=fold_train_dl, val_dl=fold_val_dl, optimizer=optimizer,
                   project=project, name=name, suffix=suffix, scheduler=None,
                   n_eopchs=n_eopchs, device=device)

[34m[1mwandb[0m: Currently logged in as: [33mnorthm[0m ([33mrsna2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670822290082772, max=1.0…

100%|██████████| 825/825 [23:26<00:00,  1.70s/it]
100%|██████████| 275/275 [03:21<00:00,  1.37it/s]


F1 score:  [0.         0.         0.11111111 0.         0.         0.        ]
AUC score:  [0.54665136 0.76725746 0.63020833 0.66753598 0.62616422 0.68340796]
2023-10-04 03:07:41 Eopch 1, Training Loss 0.6720599909081604, Val Loss 0.695117086334662, Weighted Loss 0.5005865122420716


100%|██████████| 825/825 [23:25<00:00,  1.70s/it]
100%|██████████| 275/275 [03:30<00:00,  1.31it/s]


F1 score:  [0.         0.16666667 0.26415094 0.         0.24752475 0.21176471]
AUC score:  [0.66855703 0.80766924 0.70186335 0.75936122 0.70551471 0.75704639]
2023-10-04 03:34:39 Eopch 2, Training Loss 0.5949021696141272, Val Loss 0.6551582938974554, Weighted Loss 0.5096649422195338


100%|██████████| 825/825 [23:23<00:00,  1.70s/it]
100%|██████████| 275/275 [03:27<00:00,  1.33it/s]


F1 score:  [0.12195122 0.09756098 0.22826087 0.         0.2259887  0.24691358]
AUC score:  [0.73033126 0.83665378 0.69953416 0.80685148 0.72055147 0.80477466]
2023-10-04 04:01:32 Eopch 3, Training Loss 0.534466760447531, Val Loss 0.6253911812468008, Weighted Loss 0.4648979999768082


100%|██████████| 825/825 [23:22<00:00,  1.70s/it]
100%|██████████| 275/275 [03:29<00:00,  1.31it/s]


F1 score:  [0.09411765 0.22222222 0.32128514 0.05405405 0.17567568 0.27184466]
AUC score:  [0.77277433 0.87276786 0.72154287 0.79352201 0.74448529 0.83097842]
2023-10-04 04:28:27 Eopch 4, Training Loss 0.4817624850887241, Val Loss 0.6041884283721447, Weighted Loss 0.4524494226446627


100%|██████████| 825/825 [23:24<00:00,  1.70s/it]
100%|██████████| 275/275 [03:25<00:00,  1.34it/s]


F1 score:  [0.24561404 0.2195122  0.25806452 0.09756098 0.2125     0.32214765]
AUC score:  [0.81911063 0.85097948 0.71837258 0.8256866  0.74838235 0.84969539]
2023-10-04 04:55:18 Eopch 5, Training Loss 0.42391622788978345, Val Loss 0.5938822773031213, Weighted Loss 0.4498474964993964


100%|██████████| 825/825 [23:22<00:00,  1.70s/it]
100%|██████████| 275/275 [03:26<00:00,  1.33it/s]


F1 score:  [0.27272727 0.27586207 0.28115016 0.18181818 0.20359281 0.33333333]
AUC score:  [0.85403727 0.89359009 0.70884015 0.83966    0.74708333 0.8726512 ]
2023-10-04 05:22:09 Eopch 6, Training Loss 0.37613292629068545, Val Loss 0.5718219070949337, Weighted Loss 0.4627629232469721


100%|██████████| 825/825 [23:23<00:00,  1.70s/it]
100%|██████████| 275/275 [03:20<00:00,  1.37it/s]


F1 score:  [0.21818182 0.25       0.31683168 0.125      0.21582734 0.37762238]
AUC score:  [0.86202628 0.83891924 0.69215839 0.85482469 0.7553799  0.86945831]
2023-10-04 05:48:55 Eopch 7, Training Loss 0.32305688708117514, Val Loss 0.6059444340860302, Weighted Loss 0.42844373209881886


100%|██████████| 825/825 [23:23<00:00,  1.70s/it]
100%|██████████| 275/275 [03:30<00:00,  1.31it/s]


F1 score:  [0.2875817  0.22641509 0.2832618  0.24657534 0.25870647 0.38216561]
AUC score:  [0.87197317 0.86777052 0.66971834 0.86886249 0.75666667 0.87762405]
2023-10-04 06:15:49 Eopch 8, Training Loss 0.2769767564625451, Val Loss 0.5904450258341702, Weighted Loss 0.44524665848377687


100%|██████████| 825/825 [23:23<00:00,  1.70s/it]
100%|██████████| 275/275 [03:19<00:00,  1.38it/s]


F1 score:  [0.14634146 0.24489796 0.28888889 0.21276596 0.09917355 0.41176471]
AUC score:  [0.86036097 0.85797575 0.69021739 0.86161821 0.72477941 0.86389827]
2023-10-04 06:42:34 Eopch 9, Training Loss 0.2344209292880965, Val Loss 0.668400447984988, Weighted Loss 0.40730745007206925


100%|██████████| 825/825 [23:25<00:00,  1.70s/it]
100%|██████████| 275/275 [03:25<00:00,  1.34it/s]


F1 score:  [0.24161074 0.23333333 0.29230769 0.30769231 0.25210084 0.44094488]
AUC score:  [0.87145558 0.84948028 0.68163389 0.88309347 0.71563725 0.87597255]
2023-10-04 07:09:26 Eopch 10, Training Loss 0.19626132901870844, Val Loss 0.622183856530623, Weighted Loss 0.4350695621170031


100%|██████████| 825/825 [23:25<00:00,  1.70s/it]
100%|██████████| 275/275 [03:29<00:00,  1.31it/s]


F1 score:  [0.23853211 0.2962963  0.36082474 0.35087719 0.19875776 0.40559441]
AUC score:  [0.85786299 0.87513326 0.69830487 0.91058952 0.72517157 0.86764166]
2023-10-04 07:36:23 Eopch 11, Training Loss 0.16275244907899336, Val Loss 0.6465592230179094, Weighted Loss 0.41231958510394034


100%|██████████| 825/825 [23:20<00:00,  1.70s/it]
100%|██████████| 275/275 [03:21<00:00,  1.37it/s]


F1 score:  [0.21238938 0.23529412 0.23170732 0.25       0.25       0.45112782]
AUC score:  [0.85709785 0.84538246 0.65774888 0.88409157 0.72708333 0.86118247]
2023-10-04 08:03:06 Eopch 12, Training Loss 0.13050101359459487, Val Loss 0.718915954540399, Weighted Loss 0.42110679979349075
Finish training: best_val:0.5718219070949337, best_weighted loss:0.40730745007206925


0,1
training loss,█▇▆▆▅▄▃▃▂▂▁▁
val loss,▇▅▄▃▂▁▃▂▆▃▅█
weighted loss,▇█▅▄▄▅▂▄▁▃▁▂

0,1
training loss,0.1305
val loss,0.71892
weighted loss,0.42111


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669102727125087, max=1.0…

100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
100%|██████████| 275/275 [03:22<00:00,  1.36it/s]


F1 score:  [0.06153846 0.09375    0.2506812  0.         0.13402062 0.19402985]
AUC score:  [0.58962276 0.67422595 0.69128844 0.75665116 0.60977564 0.64281213]
2023-10-04 08:30:14 Eopch 1, Training Loss 0.6990965938929355, Val Loss 0.710586852051995, Weighted Loss 0.6245343411550824


100%|██████████| 825/825 [23:26<00:00,  1.70s/it]
100%|██████████| 275/275 [03:22<00:00,  1.36it/s]


F1 score:  [0.12403101 0.11764706 0.26578073 0.         0.18539326 0.28865979]
AUC score:  [0.62722637 0.74056815 0.73959435 0.79981395 0.69014423 0.76496534]
2023-10-04 08:57:05 Eopch 2, Training Loss 0.5887359294024381, Val Loss 0.6326062835346569, Weighted Loss 0.5335022198361941


100%|██████████| 825/825 [23:26<00:00,  1.71s/it]
100%|██████████| 275/275 [03:31<00:00,  1.30it/s]


F1 score:  [0.10926366 0.12       0.27151052 0.1509434  0.16632017 0.26262626]
AUC score:  [0.66144663 0.77477477 0.75485386 0.8427907  0.68796474 0.79541363]
2023-10-04 09:24:05 Eopch 3, Training Loss 0.5351786560000795, Val Loss 0.6605800152908672, Weighted Loss 0.5944054473117178


100%|██████████| 825/825 [23:28<00:00,  1.71s/it]
100%|██████████| 275/275 [03:34<00:00,  1.28it/s]


F1 score:  [0.12396694 0.19354839 0.26993865 0.28125    0.20532319 0.3826087 ]
AUC score:  [0.65944078 0.77539609 0.75473753 0.85555349 0.70482372 0.83711405]
2023-10-04 09:51:10 Eopch 4, Training Loss 0.48761062506473424, Val Loss 0.5952232907576994, Weighted Loss 0.510958840450633


100%|██████████| 825/825 [23:28<00:00,  1.71s/it]
100%|██████████| 275/275 [03:24<00:00,  1.35it/s]


F1 score:  [0.12280702 0.10434783 0.26835443 0.18421053 0.19653179 0.30681818]
AUC score:  [0.66545832 0.73594284 0.73641132 0.84554419 0.69804487 0.84485552]
2023-10-04 10:18:04 Eopch 5, Training Loss 0.43194544439966026, Val Loss 0.6176102743907408, Weighted Loss 0.5459884753732874


100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
100%|██████████| 275/275 [03:31<00:00,  1.30it/s]


F1 score:  [0.13293051 0.19834711 0.28045977 0.25316456 0.1728972  0.30687831]
AUC score:  [0.70900698 0.76490283 0.76851656 0.85748837 0.685      0.86634711]
2023-10-04 10:45:07 Eopch 6, Training Loss 0.3900458668759375, Val Loss 0.6256428903070363, Weighted Loss 0.5659652594665959


100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
100%|██████████| 275/275 [03:26<00:00,  1.33it/s]


F1 score:  [0.17518248 0.22222222 0.23188406 0.24489796 0.21602787 0.41481481]
AUC score:  [0.68186761 0.8131925  0.76983842 0.82909767 0.68294872 0.89562517]
2023-10-04 11:12:06 Eopch 7, Training Loss 0.3439701747984597, Val Loss 0.5885968101701953, Weighted Loss 0.4620685963150945


100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
100%|██████████| 275/275 [03:31<00:00,  1.30it/s]


F1 score:  [0.10880829 0.20512821 0.29223744 0.23529412 0.19101124 0.39705882]
AUC score:  [0.66308997 0.80252666 0.77192166 0.85469767 0.67022436 0.89065172]
2023-10-04 11:39:09 Eopch 8, Training Loss 0.2995577045462348, Val Loss 0.6322890168428421, Weighted Loss 0.5676339958422156


100%|██████████| 825/825 [23:29<00:00,  1.71s/it]
100%|██████████| 275/275 [03:30<00:00,  1.31it/s]


F1 score:  [0.16460905 0.18644068 0.29295775 0.20618557 0.19318182 0.37333333]
AUC score:  [0.70240943 0.76407442 0.75787826 0.8404093  0.6781891  0.88122243]
2023-10-04 12:06:11 Eopch 9, Training Loss 0.2649247878022266, Val Loss 0.6296614603427323, Weighted Loss 0.5313290667224118


100%|██████████| 825/825 [23:28<00:00,  1.71s/it]
100%|██████████| 275/275 [03:32<00:00,  1.30it/s]


F1 score:  [0.19428571 0.20634921 0.25974026 0.25531915 0.20183486 0.34319527]
AUC score:  [0.68749849 0.76479928 0.73745823 0.86634419 0.69830128 0.88196507]
2023-10-04 12:33:13 Eopch 10, Training Loss 0.2215394971406821, Val Loss 0.655096462206407, Weighted Loss 0.5378875117850449


100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
100%|██████████| 275/275 [03:30<00:00,  1.31it/s]


F1 score:  [0.16923077 0.24324324 0.26480836 0.20689655 0.17431193 0.38888889]
AUC score:  [0.6487832  0.77736357 0.72651326 0.84316279 0.67099359 0.87649653]
2023-10-04 13:00:16 Eopch 11, Training Loss 0.19577766688258358, Val Loss 0.6729699832201004, Weighted Loss 0.4792422666227153


100%|██████████| 825/825 [23:32<00:00,  1.71s/it]
100%|██████████| 275/275 [03:31<00:00,  1.30it/s]


F1 score:  [0.17910448 0.21538462 0.24271845 0.27692308 0.21428571 0.39252336]
AUC score:  [0.65902994 0.74622899 0.71735544 0.86530233 0.67123397 0.85921325]
2023-10-04 13:27:21 Eopch 12, Training Loss 0.1645331087911671, Val Loss 0.7278805334120989, Weighted Loss 0.46648807831781924
Finish training: best_val:0.5885968101701953, best_weighted loss:0.4620685963150945


0,1
training loss,█▇▆▅▅▄▃▃▂▂▁▁
val loss,▇▃▅▁▂▃▁▃▃▄▅█
weighted loss,█▄▇▃▅▅▁▆▄▄▂▁

0,1
training loss,0.16453
val loss,0.72788
weighted loss,0.46649


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668784680465856, max=1.0…

100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
 79%|███████▉  | 217/275 [02:53<01:02,  1.07s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 825/825 [23:26<00:00,  1.71s/it]
100%|██████████| 275/275 [03:18<00:00,  1.38it/s]


F1 score:  [0.23834197 0.23333333 0.31549296 0.25       0.22594142 0.42384106]
AUC score:  [0.8238428  0.84633038 0.76290376 0.87858458 0.69013977 0.8670415 ]
2023-10-04 16:36:21 Eopch 7, Training Loss 0.3678479951078242, Val Loss 0.5402472373030403, Weighted Loss 0.4780455630618948


100%|██████████| 825/825 [23:27<00:00,  1.71s/it]
100%|██████████| 275/275 [03:30<00:00,  1.31it/s]


F1 score:  [0.20731707 0.22222222 0.3042394  0.23529412 0.24573379 0.44594595]
AUC score:  [0.80999108 0.85250894 0.74994702 0.8520784  0.70236508 0.8793498 ]
2023-10-04 17:03:20 Eopch 8, Training Loss 0.31134742896665224, Val Loss 0.5585354118320075, Weighted Loss 0.4978696536498193


100%|██████████| 825/825 [23:27<00:00,  1.71s/it]
100%|██████████| 275/275 [03:25<00:00,  1.34it/s]


F1 score:  [0.21505376 0.26315789 0.28882834 0.23376623 0.21404682 0.41428571]
AUC score:  [0.8145505  0.87055916 0.74668766 0.92166535 0.70696068 0.86320304]
2023-10-04 17:30:14 Eopch 9, Training Loss 0.28226625377481634, Val Loss 0.5535002233223482, Weighted Loss 0.48610762346639386


 18%|█▊        | 152/825 [04:28<18:43,  1.67s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
100%|██████████| 275/275 [03:31<00:00,  1.30it/s]


F1 score:  [0.12060302 0.12903226 0.24608501 0.15384615 0.22429907 0.33333333]
AUC score:  [0.74426642 0.76066715 0.75726027 0.89424327 0.74428246 0.85846829]
2023-10-04 21:06:00 Eopch 5, Training Loss 0.49310236437754196, Val Loss 0.5829471138390627, Weighted Loss 0.5214706214240067


100%|██████████| 825/825 [23:27<00:00,  1.71s/it]
100%|██████████| 275/275 [03:30<00:00,  1.30it/s]


F1 score:  [0.14814815 0.16666667 0.24561404 0.13953488 0.19123506 0.41791045]
AUC score:  [0.76882396 0.75860254 0.74020874 0.88012999 0.74759081 0.87516869]
2023-10-04 21:33:01 Eopch 6, Training Loss 0.44810426432978023, Val Loss 0.5516063197092577, Weighted Loss 0.47012716699883794


100%|██████████| 275/275 [03:24<00:00,  1.35it/s]


F1 score:  [0.13483146 0.15517241 0.25507246 0.20833333 0.21774194 0.43209877]
AUC score:  [0.74397843 0.77467412 0.7569863  0.8641597  0.74551794 0.90182186]
2023-10-04 21:59:56 Eopch 7, Training Loss 0.38824111674771167, Val Loss 0.5482682138545947, Weighted Loss 0.4795232641377206


100%|██████████| 825/825 [23:30<00:00,  1.71s/it]
100%|██████████| 275/275 [03:29<00:00,  1.31it/s]


F1 score:  [0.06       0.20408163 0.25       0.17857143 0.23931624 0.47863248]
AUC score:  [0.75442455 0.75831916 0.74131768 0.84545032 0.73010186 0.89964575]
2023-10-04 22:26:57 Eopch 8, Training Loss 0.3517447210983797, Val Loss 0.5594747824966908, Weighted Loss 0.4538200726892968


100%|██████████| 825/825 [23:29<00:00,  1.71s/it]
100%|██████████| 275/275 [03:31<00:00,  1.30it/s]


F1 score:  [0.04651163 0.25       0.24561404 0.20408163 0.1981982  0.51094891]
AUC score:  [0.74672741 0.78139422 0.75191129 0.82237697 0.72791917 0.91737517]
2023-10-04 22:54:01 Eopch 9, Training Loss 0.3025226848730535, Val Loss 0.556865749176253, Weighted Loss 0.43898134647158477


100%|██████████| 825/825 [23:31<00:00,  1.71s/it]
100%|██████████| 275/275 [03:33<00:00,  1.29it/s]


F1 score:  [0.11111111 0.24489796 0.2392638  0.17142857 0.19762846 0.53658537]
AUC score:  [0.7445544  0.78657599 0.7456621  0.84730734 0.71014194 0.90873819]
2023-10-04 23:21:07 Eopch 10, Training Loss 0.26427097617225215, Val Loss 0.5686047594452446, Weighted Loss 0.44220225701152294


100%|██████████| 825/825 [23:29<00:00,  1.71s/it]
100%|██████████| 275/275 [03:33<00:00,  1.29it/s]


F1 score:  [0.12       0.26923077 0.22764228 0.24390244 0.21296296 0.53125   ]
AUC score:  [0.75586449 0.77406688 0.74734508 0.83282266 0.70503528 0.91192645]
2023-10-04 23:48:12 Eopch 11, Training Loss 0.2219030103245468, Val Loss 0.5730201077461242, Weighted Loss 0.4381416539002072


100%|██████████| 825/825 [23:29<00:00,  1.71s/it]
100%|██████████| 275/275 [03:22<00:00,  1.36it/s]


F1 score:  [0.14611872 0.18556701 0.26086957 0.25925926 0.21587302 0.368     ]
AUC score:  [0.74300974 0.76864222 0.75328115 0.81318477 0.69402575 0.92950405]
2023-10-05 00:15:05 Eopch 12, Training Loss 0.18955247396095232, Val Loss 0.5887762835621834, Weighted Loss 0.5270731996462156
Finish training: best_val:0.5482682138545947, best_weighted loss:0.4381416539002072


0,1
training loss,█▇▆▆▅▄▄▃▂▂▁▁
val loss,█▆▄▃▃▁▁▂▂▂▃▄
weighted loss,██▅▆█▄▄▂▁▁▁█

0,1
training loss,0.18955
val loss,0.58878
weighted loss,0.52707


In [16]:
# net = KLSVolNet(ch_out = 6, slices = 10).to(device)

# optimizer = optim.AdamW(net.parameters(), lr=1e-5)
# TrainClassifer(model=net,trn_dl=train_dl,val_dl=val_dl,optimizer=optimizer, 
#                scheduler=None, n_eopchs=15, device=device)

In [17]:
# out -> rescale by threshold -> transform -> average -> post processing
# 5/6 ≈ 0.833; 0.833 * 4 ≈ 3.33
# 0.2; 0.1 0.2 -> 0.25 0.5
# 0.8; 0.8 0.9 -> 0.5 0.75
# 0.5; 0.45 0.55 -> 0.375 0.625