In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import random
import re

from tqdm import tqdm
import time

import pydicom as dicom
import nibabel as nib
import SimpleITK as sitk
import monai

import torch
import torch.nn as nn
import torch.optim as optim

from monai.networks.nets import EfficientNetBN
from monai.networks.nets import ResNet
#from efficientnet_pytorch import EfficientNet
import timm

import wandb


In [2]:
SEED = 344  # seed passed to seed_everything()
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def seed_everything(seed):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and all
    CUDA devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic kernel selection.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Read at interpreter start-up, so this only affects processes started later.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    # deterministic=True is not enough on its own: the cuDNN autotuner
    # (benchmark=True) can pick different algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print('Finish seeding with seed {}'.format(seed))
    
# Seed all RNGs once, up front, so model initialisation and DataLoader shuffling
# in later cells are repeatable.
seed_everything(SEED)
print('Training on device {}'.format(device))

Finish seeding with seed 344
Training on device cuda


In [3]:
# DICOM header fields kept from the pre-extracted tag tables; pruning the
# columns at read time keeps the parquet load small.
dicom_tag_columns = [
    'Columns', 'ImageOrientationPatient', 'ImagePositionPatient',
    'InstanceNumber', 'PatientID', 'PatientPosition', 'PixelSpacing',
    'RescaleIntercept', 'RescaleSlope', 'Rows', 'SeriesNumber',
    'SliceThickness', 'path', 'WindowCenter', 'WindowWidth',
]

# Per-image DICOM tags for the train and test splits.
train_dicom_tags, test_dicom_tags = (
    pd.read_parquet(tag_file, columns=dicom_tag_columns)
    for tag_file in ('autodl-tmp/train_dicom_tags.parquet',
                     'autodl-tmp/test_dicom_tags.parquet')
)

# Per-series metadata for both splits.
train_series_meta, test_series_meta = map(
    pd.read_csv,
    ('autodl-tmp/train_series_meta.csv', 'autodl-tmp/test_series_meta.csv'),
)

# Patient-level injury labels.
train_csv = pd.read_csv('autodl-tmp/train.csv')

train_csv

Unnamed: 0,patient_id,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,kidney_high,liver_healthy,liver_low,liver_high,spleen_healthy,spleen_low,spleen_high,any_injury
0,10004,1,0,0,1,0,1,0,1,0,0,0,0,1,1
1,10005,1,0,1,0,1,0,0,1,0,0,1,0,0,0
2,10007,1,0,1,0,1,0,0,1,0,0,1,0,0,0
3,10026,1,0,1,0,1,0,0,1,0,0,1,0,0,0
4,10051,1,0,1,0,1,0,0,1,0,0,0,1,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3142,9951,1,0,1,0,1,0,0,1,0,0,1,0,0,0
3143,9960,1,0,1,0,1,0,0,1,0,0,1,0,0,0
3144,9961,1,0,1,0,1,0,0,1,0,0,1,0,0,0
3145,9980,1,0,1,0,1,0,0,1,0,0,0,0,1,1


In [4]:
def _series_by_injury_flag(flag):
    """Rows of train_series_meta whose patient has ``any_injury == flag`` in train_csv."""
    patient_ids = train_csv.loc[train_csv.any_injury == flag, "patient_id"].values
    return train_series_meta.loc[train_series_meta.patient_id.isin(patient_ids)]


injury_series_meta = _series_by_injury_flag(1)
healthy_series_meta = _series_by_injury_flag(0)

In [5]:
def raw_path_gen(patient_id, series_id, train=True):
    """Build the directory path of one resampled series.

    Args:
        patient_id: patient identifier (formatted with ``str``).
        series_id: series identifier (formatted with ``str``).
        train: use the train split directory when True, the test split otherwise.

    Returns:
        ``'<split root><patient_id>/<series_id>'`` as a string.
    """
    # The test branch used to point at the train directory too; mirror the
    # train_/test_ naming used by the other split files instead.
    root = ('autodl-tmp/train_images_resample/' if train
            else 'autodl-tmp/test_images_resample/')
    return root + str(patient_id) + '/' + str(series_id)

def create_3D_scans(folder, downsample_rate=1):
    """Load a folder of ``<number>.dcm`` slices into a windowed 3-D volume.

    Files are ordered by the integer in their file name. Each slice is
    sign-corrected when needed, rescaled with RescaleSlope/RescaleIntercept,
    clipped to its display window (WindowCenter/WindowWidth) and mapped
    linearly from that window onto [0, 255].

    Args:
        folder: directory containing only ``<int>.dcm`` files.
        downsample_rate: stride applied to the slice list and to both in-plane axes.

    Returns:
        int16 ``np.ndarray`` stacked along a new first (slice) axis.

    Raises:
        ValueError: if the folder contains no slices.
    """
    # Sort numerically so "10.dcm" comes after "9.dcm".
    slice_numbers = sorted(int(fname.split('.')[0]) for fname in os.listdir(folder))
    filenames = [str(number) + '.dcm' for number in slice_numbers]

    volume = []
    for filename in filenames[::downsample_rate]:
        dcm = dicom.dcmread(os.path.join(folder, filename))
        image = dcm.pixel_array

        if dcm.PixelRepresentation == 1:
            # Signed data stored in fewer bits than allocated: shift left, then
            # shift right to sign-extend the stored bits.
            bit_shift = dcm.BitsAllocated - dcm.BitsStored
            dtype = image.dtype
            image = (image << bit_shift).astype(dtype) >> bit_shift

        # Fall back to an identity rescale; previously a missing tag raised
        # NameError on the first slice or silently reused the previous slice's values.
        slope = float(dcm.RescaleSlope) if 'RescaleSlope' in dcm else 1.0
        intercept = float(dcm.RescaleIntercept) if 'RescaleIntercept' in dcm else 0.0

        # Window tags may be multi-valued; int() on such a value raised. Use the first.
        center = float(np.atleast_1d(dcm.WindowCenter)[0])
        width = float(np.atleast_1d(dcm.WindowWidth)[0])
        low = center - width / 2
        high = center + width / 2

        image = np.clip(image * slope + intercept, low, high)
        # Map [low, high] onto [0, 255]. The old `image / np.max(image)` produced
        # negative values whenever low < 0 and divided by zero on all-zero slices.
        if high > low:
            image = (image - low) / (high - low) * 255
        else:
            image = np.zeros(image.shape)
        image = image.astype(np.int16)
        volume.append(image[::downsample_rate, ::downsample_rate])

    if not volume:
        raise ValueError('No .dcm slices found in {}'.format(folder))
    return np.stack(volume, axis=0)

def plot_image_with_seg(volume, volume_seg=None, orientation='Coronal', num_subplots=20):
    """Draw evenly spaced 2-D slices of a 3-D volume, optionally with a mask overlay.

    Args:
        volume: 3-D array. 'Axial', 'Coronal' and 'Sagittal' slice along its
            axis 0, 1 and 2 respectively.
        volume_seg: optional array of the same shape; its non-zero voxels are
            overlaid with the 'Set1' colormap at 50% opacity. ``None`` or an
            empty sequence disables the overlay.
        orientation: 'Axial', 'Coronal' or 'Sagittal'.
        num_subplots: number of slices to draw.

    Raises:
        ValueError: for an unknown ``orientation``.
    """
    plot_mask = volume_seg is not None and len(volume_seg) > 0

    # orientation -> (axis of `volume` to slice, transpose bringing it to the front)
    layouts = {
        'Axial': (0, None),
        'Coronal': (1, [1, 0, 2]),
        'Sagittal': (2, [2, 0, 1]),
    }
    if orientation not in layouts:
        raise ValueError('orientation must be one of {}, got {!r}'.format(sorted(layouts), orientation))
    axis, order = layouts[orientation]

    # Size the indices from the axis actually being sliced (Coronal used to read
    # axis 2 while slicing axis 1, indexing out of range on non-cubic volumes).
    slices = np.linspace(0, volume.shape[axis] - 1, num_subplots).astype(int)
    if order is not None:
        volume = volume.transpose(order)
        if plot_mask:
            volume_seg = volume_seg.transpose(order)

    rows = max(int(np.floor(np.sqrt(num_subplots))) - 2, 1)
    cols = int(np.ceil(num_subplots / rows))

    # squeeze=False keeps `ax` 2-D even for a 1x1 grid, so ravel() always works.
    fig, ax = plt.subplots(rows, cols, figsize=(cols * 2, rows * 4), squeeze=False)
    fig.tight_layout(h_pad=0.01, w_pad=0)
    ax = ax.ravel()
    for this_ax in ax:
        this_ax.axis('off')

    for this_ax, this_slice in zip(ax, slices):
        this_ax.imshow(volume[this_slice, :, :], cmap='gray')
        if plot_mask:
            seg_slice = volume_seg[this_slice, :, :]
            # NaN pixels are left transparent, so only labelled voxels are tinted.
            this_ax.imshow(np.where(seg_slice, seg_slice, np.nan), cmap='Set1', alpha=0.5)
            
def load_nii(patient_id, series_id, root='autodl-tmp/train_images_resample/'):
    """Read ``<root>/<patient_id>/<series_id>.nii.gz`` and return its voxels.

    Uses SimpleITK; ``GetArrayFromImage`` returns the NumPy array with the axes
    reversed relative to the image's (x, y, z) index order, i.e. (z, y, x).

    Args:
        patient_id: patient identifier (formatted with ``str``).
        series_id: series identifier (formatted with ``str``).
        root: directory holding one sub-directory per patient.

    Returns:
        ``np.ndarray`` with the volume's voxel values.
    """
    # os.path.join accepts `root` with or without a trailing separator
    # (plain concatenation silently built a wrong path without one).
    path = os.path.join(root, str(patient_id), str(series_id) + '.nii.gz')
    return sitk.GetArrayFromImage(sitk.ReadImage(path))

In [6]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Stride along the first (slice) axis used by CTDataset: volumes keep every
# ds-th slice, and the depth pad/crop sizes in its augmentation are 320 // ds.
ds = 2

class CTDataset(Dataset):
    """Series-level dataset yielding ``(volume, labels)`` pairs.

    Each item loads ``<root>/<patient_id>/<series_id>.nii.gz`` via ``load_nii``,
    keeps every ``ds``-th slice along the first axis, adds a leading channel axis
    and applies either the augmentation pipeline (``augmentation=True``) or only
    ``NormalizeIntensity(divisor=400)``. Labels are the 8 ``LABEL_COLUMNS`` of
    the patient as a float32 tensor.

    Args:
        root: directory holding the resampled ``.nii.gz`` volumes.
        augmentation: apply random zoom/rotation and pad/crop when True.
        meta: series table with ``patient_id``/``series_id`` columns, indexed with
            ``.loc[idx]`` (so it should carry a 0..n-1 index, e.g. after
            ``reset_index``). NOTE: the default is bound when the class is defined.
        device: stored on the instance; not used for loading.
        labels_df: patient-level label table; defaults to the global ``train_csv``.
    """

    # Target columns, in model output order.
    LABEL_COLUMNS = [
        'bowel_injury',
        'extravasation_injury',
        'kidney_low',
        'kidney_high',
        'liver_low',
        'liver_high',
        'spleen_low',
        'spleen_high',
    ]

    def __init__(self, root='autodl-tmp/train_images_resample/', augmentation=False, meta=train_series_meta,
                 device='cpu', labels_df=None):
        self.device = device
        self.series_meta = meta
        self.root = root
        self.t = monai.transforms.Compose([monai.transforms.RandZoom(prob=0.5, min_zoom=0.9, max_zoom=1.1),
                                           monai.transforms.RandRotate(range_x=3.14 / 24, prob=0.5),
                                           monai.transforms.SpatialPad(spatial_size=(320//ds, 280, 280), mode="edge"),
                                           monai.transforms.RandSpatialCrop(roi_size=(320//ds, 256, 256), random_size=False),
                                           monai.transforms.NormalizeIntensity(divisor = 400)
                                           ])
        self.t_val = monai.transforms.Compose([monai.transforms.NormalizeIntensity(divisor = 400)])
        self.aug = augmentation

        if labels_df is None:
            labels_df = train_csv
        # Index labels by patient once (first row per patient wins, as the old
        # `.values[0]` did) instead of scanning the whole table for every item.
        self.labels = (labels_df.drop_duplicates('patient_id')
                       .set_index('patient_id')[self.LABEL_COLUMNS]
                       .astype('float32'))

    def __len__(self):
        return len(self.series_meta)

    def __getitem__(self, idx):
        patient_id, series_id = self.series_meta.loc[idx, ["patient_id", "series_id"]].astype('int')
        img_a = load_nii(patient_id, series_id, self.root).astype('float32')
        # Keep every ds-th slice and add the channel axis MONAI transforms expect.
        img_a = np.expand_dims(img_a[::ds, :, :], 0)
        img_t = self.t(img_a) if self.aug else self.t_val(img_a)

        # copy=True: torch.from_numpy needs a writable array, not a frame view.
        label_a = self.labels.loc[patient_id].to_numpy(dtype='float32', copy=True)
        label_t = torch.from_numpy(label_a)
        return img_t, label_t

In [7]:
# Single hold-out split: the first len - 3600 series rows decide which patients
# validate; all remaining patients train. Splitting by patient rather than by
# series row keeps a patient with several series from landing on both sides.
HOLDOUT_TRAIN_SERIES = 3600
holdout_val_patients = train_series_meta.patient_id.iloc[:-HOLDOUT_TRAIN_SERIES].unique()
is_holdout_val = train_series_meta.patient_id.isin(holdout_val_patients)
train_meta = train_series_meta[~is_holdout_val].reset_index()
val_meta = train_series_meta[is_holdout_val].reset_index()

train_ds = CTDataset(meta=train_meta, augmentation=True)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8)

val_ds = CTDataset(meta=val_meta, augmentation=False)
val_dl = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=8)

In [8]:
# class EffNet(nn.Module):
#     def __init__(self, ch_out=9):
#         super(EffNet, self).__init__()
#         #self.conv_in = nn.Conv3d(1, 3, kernel_size=5, padding=2, stride=2)
#         self.net = EfficientNetBN("efficientnet-b0", pretrained=False, progress=False, spatial_dims=3, in_channels=1, num_classes=ch_out,)
#     def forward(self, x):
#         #x = self.conv_in(x)
#         #return torch.sigmoid(self.net(x))
#         return self.net(x)

class FullVolNet(nn.Module):
    """2.5-D volume classifier.

    ``slicer`` samples ``slices`` windows of 3 adjacent depth slices. A 2-D timm
    encoder embeds each window, a 2-layer bidirectional LSTM mixes the window
    embeddings, a per-window MLP head produces logits, and a 1x1 ``Conv1d`` over
    the window axis combines them into one logit vector per volume.

    Args:
        backbone: timm model name; must contain 'efficient' or 'convnext'.
        ch_in: encoder input channels. ``slicer`` always yields 3-slice windows,
            so this should be 3.
        ch_out: number of output logits.
        slices: number of windows sampled along the depth axis.
        dropout: dropout probability in the per-window head.
        pretrained: load the local ``*_8class.pt`` encoder checkpoint. The file
            name suggests an 8-way classifier, so ``ch_out=8`` is presumably
            required for the strict ``load_state_dict`` to succeed.

    Raises:
        ValueError: for a backbone outside the two supported families.
    """

    def __init__(self, backbone = "tf_efficientnetv2_s.in21k_ft_in1k",
                 ch_in = 3, ch_out = 9, slices = 15, dropout = 0.0, pretrained=True):
        super(FullVolNet, self).__init__()
        # Fail fast: only these families have a known head layout below
        # (anything else used to die later with NameError on `hdim`).
        if 'efficient' not in backbone and 'convnext' not in backbone:
            raise ValueError('Unsupported backbone {!r}: expected an efficientnet or convnext model'.format(backbone))
        self.slices = slices

        # Built without timm weights; the local checkpoint (if requested) is loaded next.
        self.encoder = timm.create_model(
            backbone,
            in_chans=ch_in,
            num_classes=ch_out,
            features_only=False,
            drop_rate=0.0,
            drop_path_rate=0.0,
            pretrained=False,
        )

        if pretrained:
            if 'efficient' in backbone:
                ckpt_path = 'pretrained/tf_efficientnetv2_s.in21k_ft_in1k_8class.pt'
            else:
                ckpt_path = 'pretrained/convnextv2_nano.fcmae_ft_in22k_in1k_384_8class.pt'
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
            # load_state_dict copies the values into the encoder's own parameters.
            self.encoder.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

        # Replace the classification layer so the encoder emits hdim-sized features.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        else:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()

        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=0.0, bidirectional=True, batch_first=True)

        # 512 = 2 directions x 256 hidden units.
        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout),
            nn.LeakyReLU(0.1),
            nn.Linear(256, ch_out),
        )

        # Treats windows as channels: a learned weighted sum of per-window logits.
        self.head2 = nn.Conv1d(slices, 1, 1)

    def slicer(self, img, slices):
        """Cut ``slices`` evenly spaced 3-slice windows along the depth axis.

        Args:
            img: 5-D tensor ``(bs, C, Z, H, W)``; depth is ``img.shape[-3]``.
            slices: number of windows.

        Returns:
            Tensor ``(bs, C * slices, 3, H, W)``.
        """
        z_length = img.shape[-3]
        # slices + 4 evenly spaced centres; two are dropped at each end so the
        # windows stay away from the volume borders.
        z_slices = np.linspace(0, z_length, slices + 4).astype('int')[2:-2]
        # Keep each window z-1..z+1 inside the volume. A no-op for long volumes;
        # for short ones it prevents empty/truncated windows that made torch.cat fail.
        z_slices = np.clip(z_slices, 1, z_length - 2)
        windows = [img[:, :, z - 1:z + 2, :, :] for z in z_slices.tolist()]
        return torch.cat(windows, 1)

    def forward(self, x):
        """Predict one logit vector per volume.

        Args:
            x: ``(bs, 1, Z, H, W)`` volume (the single-channel layout CTDataset
                yields — assumed; ``slicer`` requires 5-D input).

        Returns:
            ``(bs, ch_out)`` logits.
        """
        x = self.slicer(x, self.slices)                        # (bs, slices, 3, H, W)
        bs, nslice, ch, sz1, sz2 = x.shape
        x = x.view(bs * nslice, ch, sz1, sz2)

        feature_2d = self.encoder(x).view(bs, nslice, -1)      # per-window features
        feature_lstm, _ = self.lstm(feature_2d)                 # (bs, nslice, 512)
        feature_lstm = feature_lstm.contiguous().view(bs * nslice, -1)

        preds = self.head(feature_lstm).view(bs, nslice, -1).contiguous()
        preds = self.head2(preds)                               # (bs, 1, ch_out)
        return preds.squeeze(1)
        

In [9]:
# avail_pretrained_models = timm.list_models(pretrained=True)
# avail_pretrained_models

In [10]:
# imgs, labels = next(iter(train_dl))
# net = FullVolNet(ch_out = 8)
# img_s = net(imgs)
# img_s.shape

In [11]:
# img_a = imgs[0, 0].numpy()
# print(img_a.shape)
# plot_image_with_seg(img_a, orientation='Axial', num_subplots=7)

In [12]:
import sklearn.metrics

def transform_9class(label_in):
    """Expand the 8 injury outputs into the 13-entry per-organ layout.

    Input order: bowel_injury, extravasation_injury, kidney_low, kidney_high,
    liver_low, liver_high, spleen_low, spleen_high.
    Output order: bowel (healthy, injury), extravasation (healthy, injury), then
    (healthy, low, high) for kidney, liver and spleen. Binary organs use
    healthy = 1 - injury; three-level organs use (1 - low) * (1 - high).
    """
    bowel, extravasation = label_in[0], label_in[1]
    expanded = [1 - bowel, bowel, 1 - extravasation, extravasation]
    for low_pos in (2, 4, 6):  # kidney, liver, spleen
        low, high = label_in[low_pos], label_in[low_pos + 1]
        expanded += [(1 - low) * (1 - high), low, high]
    return expanded

def transform_13class(label_in):
    """Pass-through counterpart of transform_9class: return ``label_in`` as a Python list."""
    return label_in.tolist()


def _per_column_auc(labels, preds):
    """Per-column ROC AUC; NaN for columns whose labels contain a single class."""
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    scores = []
    for col in range(labels.shape[1]):
        if np.unique(labels[:, col]).size < 2:
            scores.append(np.nan)
        else:
            scores.append(sklearn.metrics.roc_auc_score(labels[:, col], preds[:, col]))
    return np.array(scores)


def loss_metrics(metrics, transform):
    """Print per-label F1 / ROC AUC and return the mean weighted log loss.

    Args:
        metrics: dict with ``'predict'`` (probabilities) and ``'label'`` (0/1
            targets), both 2-D arrays with one row per sample.
        transform: maps one row to the 13-entry layout
            [bowel healthy/injury, extravasation healthy/injury,
             kidney healthy/low/high, liver healthy/low/high,
             spleen healthy/low/high], e.g. ``transform_9class``.

    Returns:
        Mean over samples of a weighted binary log loss over 14 entries: the 13
        above (predictions renormalised within each organ group) plus a derived
        any_injury = max over organs of (1 - healthy).
    """
    print("F1 score: ", sklearn.metrics.f1_score(metrics["label"], np.around(metrics["predict"]), average=None, zero_division=0.0))
    # roc_auc_score raises if any label column has a single class, which aborted
    # training at the end of an epoch; such columns are reported as NaN instead.
    print("AUC score: ", _per_column_auc(metrics["label"], metrics["predict"]))

    # Index groups per organ in the 13-entry layout; the first index is "healthy".
    organ_groups = [(0, 1), (2, 3), (4, 5, 6), (7, 8, 9), (10, 11, 12)]
    healthy_idx = [group[0] for group in organ_groups]
    weight = np.array([1, 2, 1, 6, 1, 2, 4, 1, 2, 4, 1, 2, 4, 6])

    loss_list = []
    for predict_row, label_row in zip(metrics["predict"], metrics["label"]):
        predict = transform(predict_row)
        target = transform(label_row)

        label_pred = np.zeros(14)
        for group in organ_groups:
            total = sum(predict[i] for i in group)
            for i in group:
                label_pred[i] = predict[i] / total
        label_pred[13] = max(1 - label_pred[i] for i in healthy_idx)

        target_any_injury = max(1 - target[i] for i in healthy_idx)
        label_target = np.array(list(target) + [target_any_injury])

        loss_list.append(sklearn.metrics.log_loss(
            y_true=label_target,
            y_pred=label_pred,
            sample_weight=weight))

    return np.mean(loss_list)

In [13]:
import copy

import torch.cuda.amp as amp

# Module-level scaler kept for code that references it; TrainClassifer creates a
# fresh scaler per call unless one is passed explicitly.
scaler = amp.GradScaler()


def TrainClassifer(model, trn_dl, val_dl, optimizer, project, name, suffix, scheduler=None,
                   n_eopchs=20, device='cpu', scaler=None):
    """Train ``model`` with mixed precision and save the weights with the lowest val loss.

    Each epoch:
    1. Runs one pass over ``trn_dl`` with BCE-with-logits and per-label ``pos_weight``.
    2. Evaluates on ``val_dl`` and prints per-label F1/AUC plus the weighted log loss
       (via ``loss_metrics``).
    3. Steps ``scheduler`` if one is given.
    4. Logs the losses to Weights & Biases.

    After the last epoch the best-val-loss state dict is written to
    ``<project>/<name>/<suffix>.pt``.

    Args:
        model: network producing 8 logits per sample (matches ``pos_weight``).
        trn_dl: training loader yielding ``(images, labels)``.
        val_dl: validation loader yielding ``(images, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        project: W&B project name; also the top-level output directory.
        name: W&B run name; also the output sub-directory.
        suffix: checkpoint file stem, e.g. ``'fold1'``.
        scheduler: optional LR scheduler, stepped once per epoch.
        n_eopchs: number of epochs.
        device: device to train on.
        scaler: optional ``GradScaler``. When omitted a fresh one is created, so
            loss-scale state no longer carries over between runs/folds.
    """
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([10, 5, 10, 10, 6, 6, 6, 6]).to(device))
    model.to(device)
    if scaler is None:
        scaler = amp.GradScaler(enabled=torch.device(device).type == 'cuda')

    best_model = copy.deepcopy(model)
    best_val = 999.0
    best_weightloss = 999.0
    model_path = os.path.join(project, name, suffix + '.pt')
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    wandb.init(name=name, project=project)
    try:
        for epoch in range(1, n_eopchs + 1):
            model.train()
            loss_train = 0.0
            for imgs, labels in tqdm(trn_dl, position=0):
                imgs = imgs.to(device)
                labels = labels.to(device)
                with amp.autocast():
                    outputs = model(imgs)
                    loss = loss_fn(outputs, labels)

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                loss_train += loss.item()
            torch.cuda.empty_cache()

            model.eval()
            loss_val = 0.0
            predictions, targets = [], []
            with torch.no_grad():
                for imgs, labels in tqdm(val_dl):
                    imgs = imgs.to(device)
                    labels = labels.to(device)
                    outputs = model(imgs)
                    loss_val += loss_fn(outputs, labels).item()
                    predictions.extend(torch.sigmoid(outputs).cpu().numpy().tolist())
                    targets.extend(labels.cpu().numpy().tolist())

            weighted_loss = loss_metrics({'predict': np.array(predictions), 'label': np.array(targets)},
                                         transform_9class)
            torch.cuda.empty_cache()

            train_loss = loss_train / len(trn_dl)
            val_loss = loss_val / len(val_dl)
            if weighted_loss < best_weightloss:
                best_weightloss = weighted_loss
            if val_loss < best_val:
                best_val = val_loss
                # load_state_dict copies the tensors, so best_model is a snapshot
                # (replaces the round-trip through a 'model_tmp.pt' file in CWD).
                best_model.load_state_dict(model.state_dict())

            if scheduler is not None:
                scheduler.step()

            print('{} Eopch {}, Training Loss {}, Val Loss {}, Weighted Loss {}'.format(
                time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                epoch, train_loss, val_loss, weighted_loss))
            wandb.log({'training loss': train_loss,
                       'val loss': val_loss,
                       'weighted loss': weighted_loss})

        torch.save(best_model.state_dict(), model_path)
        print('Finish training: best_val:{}, best_weighted loss:{}'.format(best_val, best_weightloss))
    finally:
        # Always close the W&B run, even if training raises.
        wandb.finish()

In [14]:
import os

project = "FullVol2.5Dx2-nophase-4fold"
name = "FullVol2.5Dx2-nophase_Convnetv2_4fold"
# backbone = "tf_efficientnetv2_s.in21k_ft_in1k"
backbone = 'convnextv2_nano.fcmae_ft_in22k_in1k_384'
n_eopchs = 12

# Output directory for the fold checkpoints. makedirs also creates the parent
# `project` directory (os.mkdir raised FileNotFoundError without it) and is
# safe to re-run.
p = os.path.join(project, name)
os.makedirs(p, exist_ok=True)

In [15]:
# 4-fold CV. Fold k validates on the k-th contiguous chunk of ceil(n / 4) series
# rows (4711 series -> chunks of 1178, as the hard-coded bounds used) and trains
# on the rest. Every series of a patient follows the fold of that patient's first
# series, so no patient is in both train and val of the same fold.
N_FOLDS = 4
FOLD_SIZE = -(-len(train_series_meta) // N_FOLDS)  # ceiling division

_row_fold = np.minimum(np.arange(len(train_series_meta)) // FOLD_SIZE, N_FOLDS - 1)
series_fold = (pd.Series(_row_fold, index=train_series_meta.index)
               .groupby(train_series_meta['patient_id'].to_numpy())
               .transform('first')
               .to_numpy())


def _split_fold(fold_idx):
    """Return (train_df, val_df) for 0-based ``fold_idx``, both with a fresh 0..n-1 index."""
    in_val = series_fold == fold_idx
    return (train_series_meta[~in_val].reset_index(drop=True),
            train_series_meta[in_val].reset_index(drop=True))


def _make_fold_loader(meta, train):
    """Build a (CTDataset, DataLoader) pair; training loaders augment and shuffle."""
    dataset = CTDataset(meta=meta, augmentation=train)
    return dataset, DataLoader(dataset, batch_size=4, shuffle=train, num_workers=8)


[(train_df_fold1, val_df_fold1),
 (train_df_fold2, val_df_fold2),
 (train_df_fold3, val_df_fold3),
 (train_df_fold4, val_df_fold4)] = [_split_fold(k) for k in range(N_FOLDS)]

train_ds_fold1, train_dl_fold1 = _make_fold_loader(train_df_fold1, train=True)
val_ds_fold1, val_dl_fold1 = _make_fold_loader(val_df_fold1, train=False)

train_ds_fold2, train_dl_fold2 = _make_fold_loader(train_df_fold2, train=True)
val_ds_fold2, val_dl_fold2 = _make_fold_loader(val_df_fold2, train=False)

train_ds_fold3, train_dl_fold3 = _make_fold_loader(train_df_fold3, train=True)
val_ds_fold3, val_dl_fold3 = _make_fold_loader(val_df_fold3, train=False)

train_ds_fold4, train_dl_fold4 = _make_fold_loader(train_df_fold4, train=True)
val_ds_fold4, val_dl_fold4 = _make_fold_loader(val_df_fold4, train=False)

In [16]:
LEARNING_RATE = 1e-5

# Train one freshly initialised model per fold; each run logs to W&B and saves
# <project>/<name>/fold<k>.pt.
fold_loaders = [
    (train_dl_fold1, val_dl_fold1),
    (train_dl_fold2, val_dl_fold2),
    (train_dl_fold3, val_dl_fold3),
    (train_dl_fold4, val_dl_fold4),
]
for fold_idx, (fold_train_dl, fold_val_dl) in enumerate(fold_loaders, start=1):
    # Drop the previous fold's model/optimizer before building the next one, so
    # two models are never resident on the device at once.
    net = optimizer = None
    net = FullVolNet(backbone=backbone, ch_out=8).to(device)
    optimizer = optim.AdamW(net.parameters(), lr=LEARNING_RATE)
    TrainClassifer(model=net, trn_dl=fold_train_dl, val_dl=fold_val_dl, optimizer=optimizer,
                   project=project, name=name, suffix='fold{}'.format(fold_idx), scheduler=None,
                   n_eopchs=n_eopchs, device=device)

[34m[1mwandb[0m: Currently logged in as: [33mnorthm[0m ([33mrsna2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669607938577732, max=1.0…

100%|██████████| 884/884 [28:55<00:00,  1.96s/it]
100%|██████████| 295/295 [03:18<00:00,  1.49it/s]


F1 score:  [0.09756098 0.         0.03389831 0.         0.15767635 0.
 0.         0.        ]
AUC score:  [0.61894901 0.59243663 0.56446919 0.6528882  0.63544202 0.6637856
 0.54508369 0.75931723]
2023-10-01 12:30:13 Eopch 1, Training Loss 0.6875071620138792, Val Loss 0.6894081056623136, Weighted Loss 0.39113226187961003


100%|██████████| 884/884 [28:47<00:00,  1.95s/it]
100%|██████████| 295/295 [03:16<00:00,  1.50it/s]


F1 score:  [0.         0.02469136 0.10309278 0.         0.22413793 0.
 0.16149068 0.19117647]
AUC score:  [0.70550104 0.67343737 0.62347029 0.75618012 0.66075214 0.77344618
 0.63825413 0.79752534]
2023-10-01 13:02:19 Eopch 2, Training Loss 0.6213568830483369, Val Loss 0.6552936989372059, Weighted Loss 0.3833374996501636


100%|██████████| 884/884 [28:54<00:00,  1.96s/it]
 64%|██████▍   | 190/295 [01:58<01:45,  1.00s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 884/884 [28:52<00:00,  1.96s/it]
100%|██████████| 295/295 [03:00<00:00,  1.63it/s]


F1 score:  [0.12605042 0.23364486 0.12200436 0.15942029 0.2181146  0.1682243
 0.22545455 0.32098765]
AUC score:  [0.76912458 0.74027154 0.70137922 0.76754658 0.69438447 0.84925422
 0.67137398 0.85201252]
2023-10-01 14:06:08 Eopch 4, Training Loss 0.5234435155515758, Val Loss 0.6607550944185863, Weighted Loss 0.4897854752585575


100%|██████████| 884/884 [28:51<00:00,  1.96s/it]
100%|██████████| 295/295 [03:06<00:00,  1.58it/s]


F1 score:  [0.22222222 0.18181818 0.20155039 0.12       0.22916667 0.11111111
 0.13636364 0.33333333]
AUC score:  [0.79802527 0.74279581 0.70680212 0.76568323 0.7003219  0.84175145
 0.67592702 0.88176804]
2023-10-01 14:38:07 Eopch 5, Training Loss 0.48556525600711686, Val Loss 0.6273895484537392, Weighted Loss 0.3019396390058965


 83%|████████▎ | 736/884 [24:05<04:49,  1.95s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 295/295 [03:20<00:00,  1.47it/s]


F1 score:  [0.0952381  0.22222222 0.09756098 0.11363636 0.16949153 0.22222222
 0.17073171 0.37681159]
AUC score:  [0.82953693 0.7354943  0.75857269 0.63468944 0.66207702 0.83781999
 0.67295625 0.86079308]
2023-10-01 18:23:50 Eopch 12, Training Loss 0.19635835481698022, Val Loss 0.82998504038837, Weighted Loss 0.2799061618244913
Finish training: best_val:0.6273895484537392, best_weighted loss:0.2752880620599474


0,1
training loss,█▇▆▆▅▄▄▃▃▂▂▁
val loss,▃▂▁▂▁▂▂▃▄▄▄█
weighted loss,▅▅▅█▂▁▆▂▂▁▅▁

0,1
training loss,0.19636
val loss,0.82999
weighted loss,0.27991


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668826807290316, max=1.0…

100%|██████████| 884/884 [28:57<00:00,  1.97s/it]
100%|██████████| 295/295 [03:55<00:00,  1.25it/s]


F1 score:  [0. 0. 0. 0. 0. 0. 0. 0.]
AUC score:  [0.83890317 0.57422587 0.52912057 0.65942009 0.72120945 0.67341079
 0.64289265 0.66251514]
2023-10-01 18:56:58 Eopch 1, Training Loss 0.70589496569531, Val Loss 0.6425037007210619, Weighted Loss 0.38949963138183225


100%|██████████| 884/884 [28:59<00:00,  1.56s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 884/884 [29:04<00:00,  1.97s/it]
100%|██████████| 295/295 [03:42<00:00,  1.33it/s]


F1 score:  [0.13333333 0.1443299  0.11891892 0.06315789 0.27439024 0.09090909
 0.18937644 0.18978102]
AUC score:  [0.87847628 0.70007969 0.65155523 0.82259471 0.7518685  0.85950855
 0.68680898 0.80094358]
2023-10-01 20:02:43 Eopch 3, Training Loss 0.6129875259128361, Val Loss 0.605844645833565, Weighted Loss 0.4488674892161398


 44%|████▎     | 385/884 [12:47<16:20,  1.96s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 884/884 [29:06<00:00,  1.98s/it]
100%|██████████| 295/295 [04:02<00:00,  1.21it/s]


F1 score:  [0.10447761 0.21656051 0.13636364 0.21052632 0.3286119  0.22222222
 0.22222222 0.33175355]
AUC score:  [0.90963621 0.70299408 0.74094116 0.87813933 0.78249078 0.86411592
 0.73724089 0.88743877]
2023-10-01 22:47:46 Eopch 8, Training Loss 0.41741179917597665, Val Loss 0.5490297484448401, Weighted Loss 0.36793966694702956


100%|██████████| 884/884 [29:05<00:00,  1.97s/it]
100%|██████████| 295/295 [04:01<00:00,  1.22it/s]


F1 score:  [0.21052632 0.14864865 0.15686275 0.17721519 0.30452675 0.32142857
 0.224      0.37398374]
AUC score:  [0.88509776 0.70217441 0.72418631 0.86143375 0.77860209 0.88304621
 0.73877238 0.89889915]
2023-10-01 23:20:55 Eopch 9, Training Loss 0.3830842376008158, Val Loss 0.5448693524351564, Weighted Loss 0.2878853287837735


100%|██████████| 884/884 [28:58<00:00,  1.97s/it]
100%|██████████| 295/295 [03:57<00:00,  1.24it/s]


F1 score:  [0.07692308 0.1884058  0.20382166 0.16666667 0.31690141 0.38297872
 0.2278481  0.50847458]
AUC score:  [0.87652878 0.6801571  0.7539883  0.86835222 0.76908722 0.87059295
 0.76666174 0.9056597 ]
2023-10-01 23:53:52 Eopch 10, Training Loss 0.3456476076223732, Val Loss 0.5564768998299615, Weighted Loss 0.28158065368187274


100%|██████████| 884/884 [29:01<00:00,  1.97s/it]
100%|██████████| 295/295 [04:04<00:00,  1.21it/s]


F1 score:  [0.0952381  0.23963134 0.21818182 0.26530612 0.25498008 0.36065574
 0.19103314 0.29304029]
AUC score:  [0.90356002 0.71762295 0.72977794 0.86283995 0.74258345 0.84588675
 0.75483966 0.91204064]
2023-10-02 00:27:00 Eopch 11, Training Loss 0.30889497285632933, Val Loss 0.5908101556159682, Weighted Loss 0.3668286871660854


100%|██████████| 884/884 [28:54<00:00,  1.96s/it]
100%|██████████| 295/295 [03:24<00:00,  1.44it/s]


F1 score:  [0.08333333 0.19512195 0.18421053 0.20408163 0.23469388 0.34615385
 0.21857923 0.46296296]
AUC score:  [0.88190387 0.66870446 0.72699214 0.83415361 0.76278073 0.86394899
 0.75560541 0.88704109]
2023-10-02 00:59:20 Eopch 12, Training Loss 0.27261656022112296, Val Loss 0.6233540039572676, Weighted Loss 0.2420850278730046
Finish training: best_val:0.5448693524351564, best_weighted loss:0.2420850278730046


VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
training loss,█▇▆▆▅▅▄▃▃▂▂▁
val loss,█▆▅▃▂▂▂▁▁▂▄▇
weighted loss,▆▆█▅▃▂▄▅▃▂▅▁

0,1
training loss,0.27262
val loss,0.62335
weighted loss,0.24209


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668662211547294, max=1.0…

100%|██████████| 884/884 [28:53<00:00,  1.96s/it]
100%|██████████| 295/295 [03:30<00:00,  1.40it/s]


F1 score:  [0. 0. 0. 0. 0. 0. 0. 0.]
AUC score:  [0.68753611 0.56447265 0.58573917 0.66501035 0.70283584 0.64525037
 0.63417469 0.57923375]
2023-10-02 01:31:56 Eopch 1, Training Loss 0.7028079769527751, Val Loss 0.661520871216968, Weighted Loss 0.43362101878414966


100%|██████████| 884/884 [28:52<00:00,  1.96s/it]
100%|██████████| 295/295 [03:34<00:00,  1.37it/s]


F1 score:  [0. 0. 0. 0. 0. 0. 0. 0.]
AUC score:  [0.74335644 0.63001511 0.67917222 0.72211557 0.68552789 0.65310506
 0.66220333 0.68331843]
2023-10-02 02:04:24 Eopch 2, Training Loss 0.6583787478270574, Val Loss 0.6328361799151211, Weighted Loss 0.3450630442656914


100%|██████████| 884/884 [28:53<00:00,  1.96s/it]
100%|██████████| 295/295 [03:35<00:00,  1.37it/s]


F1 score:  [0.         0.         0.09302326 0.10909091 0.26044226 0.
 0.02197802 0.11764706]
AUC score:  [0.7477253  0.68548806 0.65030222 0.70766046 0.71177384 0.80160776
 0.69846618 0.64807692]
2023-10-02 02:36:54 Eopch 3, Training Loss 0.6186850954066304, Val Loss 0.636518648468842, Weighted Loss 0.4267546746221116


100%|██████████| 884/884 [28:52<00:00,  1.96s/it]
100%|██████████| 295/295 [03:36<00:00,  1.36it/s]


F1 score:  [0.20289855 0.06521739 0.12637363 0.07751938 0.25891182 0.08
 0.20444444 0.22325581]
AUC score:  [0.80932265 0.73407072 0.72172933 0.75772633 0.71756091 0.72250859
 0.72057478 0.77389684]
2023-10-02 03:09:25 Eopch 4, Training Loss 0.5758065237785897, Val Loss 0.6189531717765129, Weighted Loss 0.44045138492188446


100%|██████████| 884/884 [28:52<00:00,  1.96s/it]
100%|██████████| 295/295 [03:36<00:00,  1.36it/s]


F1 score:  [0.13333333 0.125      0.09195402 0.17391304 0.3141994  0.0952381
 0.20111732 0.16666667]
AUC score:  [0.7918833  0.74051375 0.70533757 0.81261058 0.75453779 0.79356897
 0.73875464 0.76560823]
2023-10-02 03:41:56 Eopch 5, Training Loss 0.5369341819972744, Val Loss 0.5794388326547913, Weighted Loss 0.3225070371346227


100%|██████████| 884/884 [28:53<00:00,  1.96s/it]
100%|██████████| 295/295 [03:24<00:00,  1.44it/s]


F1 score:  [0.2173913  0.18032787 0.16       0.11111111 0.28436019 0.
 0.11464968 0.28346457]
AUC score:  [0.83679954 0.74441825 0.75172626 0.83628835 0.74357609 0.75613648
 0.7322749  0.81110614]
2023-10-02 04:14:15 Eopch 6, Training Loss 0.4991888412067928, Val Loss 0.5679041800104966, Weighted Loss 0.3206102932744319


100%|██████████| 884/884 [28:57<00:00,  1.97s/it]
100%|██████████| 295/295 [03:23<00:00,  1.45it/s]


F1 score:  [0.15189873 0.18461538 0.19753086 0.13636364 0.31724138 0.08333333
 0.18604651 0.2244898 ]
AUC score:  [0.83120306 0.73739498 0.78153878 0.83726708 0.74613234 0.80749877
 0.72539691 0.77944246]
2023-10-02 04:46:38 Eopch 7, Training Loss 0.45458875550167865, Val Loss 0.5832916634315152, Weighted Loss 0.28254308781709175


100%|██████████| 884/884 [28:53<00:00,  1.96s/it]
100%|██████████| 295/295 [03:34<00:00,  1.38it/s]


F1 score:  [0.20588235 0.21960784 0.12738854 0.1443299  0.30227743 0.04255319
 0.22082019 0.23170732]
AUC score:  [0.8136193  0.72875189 0.73580576 0.82864672 0.77339901 0.81817624
 0.73839944 0.80864639]
2023-10-02 05:19:07 Eopch 8, Training Loss 0.41614568268888674, Val Loss 0.5753563775602034, Weighted Loss 0.36126826266745493


100%|██████████| 884/884 [28:52<00:00,  1.96s/it]
100%|██████████| 295/295 [03:33<00:00,  1.38it/s]


F1 score:  [0.19354839 0.26190476 0.11881188 0.14814815 0.28571429 0.05263158
 0.23966942 0.24193548]
AUC score:  [0.82145436 0.72731339 0.72888024 0.84302654 0.75958816 0.80737604
 0.7489048  0.79197973]
2023-10-02 05:51:36 Eopch 9, Training Loss 0.3847345663487439, Val Loss 0.5840994170787981, Weighted Loss 0.293577582931967


100%|██████████| 884/884 [28:57<00:00,  1.97s/it]
100%|██████████| 295/295 [03:34<00:00,  1.37it/s]


F1 score:  [0.13461538 0.1971831  0.1        0.17021277 0.29333333 0.06666667
 0.2202381  0.34210526]
AUC score:  [0.82770075 0.72886068 0.76088516 0.83519669 0.75550526 0.76785714
 0.71521447 0.84306798]
2023-10-02 06:24:09 Eopch 10, Training Loss 0.34339403149111375, Val Loss 0.6045430273449017, Weighted Loss 0.28728700134108054


100%|██████████| 884/884 [28:54<00:00,  1.96s/it]
100%|██████████| 295/295 [03:33<00:00,  1.38it/s]


F1 score:  [0.14634146 0.22535211 0.171875   0.21238938 0.29916898 0.04255319
 0.21621622 0.2739726 ]
AUC score:  [0.8382799  0.70484134 0.7418912  0.85785808 0.75897572 0.78583702
 0.71959529 0.81192606]
2023-10-02 06:56:39 Eopch 11, Training Loss 0.30888712296169674, Val Loss 0.5869799765868712, Weighted Loss 0.34763251600522066


100%|██████████| 884/884 [28:53<00:00,  1.96s/it]
100%|██████████| 295/295 [03:26<00:00,  1.43it/s]


F1 score:  [0.20338983 0.23684211 0.12738854 0.19753086 0.3        0.0754717
 0.23571429 0.33519553]
AUC score:  [0.81463027 0.72961015 0.71437353 0.84434406 0.77614166 0.79860088
 0.71167321 0.82938283]
2023-10-02 07:29:00 Eopch 12, Training Loss 0.27639799103208257, Val Loss 0.5941863929821273, Weighted Loss 0.32313103889542555
Finish training: best_val:0.5679041800104966, best_weighted loss:0.28254308781709175


0,1
training loss,█▇▇▆▅▅▄▃▃▂▂▁
val loss,█▆▆▅▂▁▂▂▂▄▂▃
weighted loss,█▄▇█▃▃▁▄▁▁▄▃

0,1
training loss,0.2764
val loss,0.59419
weighted loss,0.32313


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669257637113334, max=1.0…

100%|██████████| 884/884 [28:55<00:00,  1.96s/it]
100%|██████████| 295/295 [03:19<00:00,  1.48it/s]


F1 score:  [0.         0.         0.         0.         0.10909091 0.
 0.         0.        ]
AUC score:  [0.66215919 0.54497958 0.63361917 0.52425121 0.63504771 0.7292011
 0.62744691 0.58291555]
2023-10-02 08:01:28 Eopch 1, Training Loss 0.6871051449514083, Val Loss 0.6770173975471723, Weighted Loss 0.39052474032701756


100%|██████████| 884/884 [28:55<00:00,  1.96s/it]
100%|██████████| 295/295 [03:37<00:00,  1.36it/s]


F1 score:  [0.24       0.05263158 0.03571429 0.04580153 0.19178082 0.
 0.20547945 0.14876033]
AUC score:  [0.73012313 0.62753938 0.71644679 0.57455717 0.65109136 0.7566706
 0.66751977 0.72590272]
2023-10-02 08:34:03 Eopch 2, Training Loss 0.6343855618655142, Val Loss 0.6617109271429353, Weighted Loss 0.4469973246718072


100%|██████████| 884/884 [28:52<00:00,  1.96s/it]
100%|██████████| 295/295 [03:24<00:00,  1.44it/s]


F1 score:  [0.125      0.09876543 0.14285714 0.11111111 0.21052632 0.
 0.15       0.25454545]
AUC score:  [0.79487687 0.63578741 0.72966139 0.67848631 0.673681   0.84978355
 0.70476816 0.77363474]
2023-10-02 09:06:21 Eopch 3, Training Loss 0.5985044510824378, Val Loss 0.6218104404160532, Weighted Loss 0.3549195691875008


100%|██████████| 884/884 [28:54<00:00,  1.96s/it]
100%|██████████| 295/295 [03:25<00:00,  1.43it/s]


F1 score:  [0.24390244 0.18045113 0.15151515 0.07446809 0.20216606 0.13793103
 0.25454545 0.29959514]
AUC score:  [0.7657212  0.66315706 0.72867194 0.71285024 0.68000915 0.79996065
 0.76407046 0.83552671]
2023-10-02 09:38:42 Eopch 4, Training Loss 0.5614227701265079, Val Loss 0.6325238530161017, Weighted Loss 0.449982470771754


100%|██████████| 884/884 [29:00<00:00,  1.97s/it]
100%|██████████| 295/295 [03:21<00:00,  1.47it/s]


F1 score:  [0.1682243  0.19       0.13821138 0.14516129 0.25613079 0.12121212
 0.28104575 0.37234043]
AUC score:  [0.80072559 0.67694797 0.75848725 0.80193237 0.7053762  0.83325462
 0.77861935 0.84888093]
2023-10-02 10:11:06 Eopch 5, Training Loss 0.5175992300506361, Val Loss 0.5916816701323299, Weighted Loss 0.37951242733862


100%|██████████| 884/884 [28:59<00:00,  1.97s/it]
100%|██████████| 295/295 [03:31<00:00,  1.39it/s]


F1 score:  [0.16666667 0.20740741 0.15189873 0.13333333 0.2484472  0.
 0.19875776 0.39285714]
AUC score:  [0.81332454 0.690898   0.75780563 0.80869565 0.71808696 0.81180638
 0.75994328 0.85013429]
2023-10-02 10:43:39 Eopch 6, Training Loss 0.4824025288510781, Val Loss 0.6148480744053751, Weighted Loss 0.2665382193413572


100%|██████████| 884/884 [28:57<00:00,  1.97s/it]
100%|██████████| 295/295 [03:37<00:00,  1.36it/s]


F1 score:  [0.18421053 0.15873016 0.19148936 0.16071429 0.24725275 0.06896552
 0.29508197 0.425     ]
AUC score:  [0.82282322 0.65134196 0.7748241  0.82657005 0.71982965 0.85548996
 0.75283023 0.85502835]
2023-10-02 11:16:15 Eopch 7, Training Loss 0.4401513392956953, Val Loss 0.5965412277538897, Weighted Loss 0.3030272513884402


100%|██████████| 884/884 [28:57<00:00,  1.97s/it]
100%|██████████| 295/295 [03:46<00:00,  1.31it/s]


F1 score:  [0.27536232 0.21052632 0.15544041 0.13496933 0.19964349 0.07017544
 0.26344086 0.40191388]
AUC score:  [0.817854   0.6685806  0.7994723  0.7989694  0.70668322 0.83081464
 0.76810542 0.87912563]
2023-10-02 11:49:00 Eopch 8, Training Loss 0.40936375668955066, Val Loss 0.5982483788440793, Weighted Loss 0.4048062782540054


100%|██████████| 884/884 [28:53<00:00,  1.96s/it]
100%|██████████| 295/295 [03:36<00:00,  1.36it/s]


F1 score:  [0.25641026 0.17313433 0.16736402 0.13333333 0.19650655 0.16393443
 0.22540984 0.31782946]
AUC score:  [0.81578716 0.65937782 0.79371152 0.7821256  0.69908073 0.82542306
 0.76207604 0.89055506]
2023-10-02 12:21:31 Eopch 9, Training Loss 0.375929304942219, Val Loss 0.6109682053074998, Weighted Loss 0.42371163637510334


100%|██████████| 884/884 [29:00<00:00,  1.97s/it]
100%|██████████| 295/295 [04:35<00:00,  1.07it/s]


F1 score:  [0.24705882 0.18627451 0.13513514 0.15311005 0.24404762 0.1025641
 0.24355972 0.37168142]
AUC score:  [0.84373351 0.63698085 0.78049692 0.79806763 0.71120333 0.87866982
 0.75578151 0.87693226]
2023-10-02 12:55:08 Eopch 10, Training Loss 0.334140787769227, Val Loss 0.6059734647051763, Weighted Loss 0.37262667610875716


100%|██████████| 884/884 [28:57<00:00,  1.97s/it]
100%|██████████| 295/295 [03:44<00:00,  1.32it/s]


F1 score:  [0.18867925 0.21276596 0.11111111 0.         0.24561404 0.10526316
 0.20430108 0.464     ]
AUC score:  [0.82324099 0.65066568 0.76270888 0.79275362 0.70171655 0.80716253
 0.72091951 0.8515667 ]
2023-10-02 13:27:52 Eopch 11, Training Loss 0.30283181161969497, Val Loss 0.7034354956473334, Weighted Loss 0.25811512151588256


100%|██████████| 884/884 [28:56<00:00,  1.96s/it]
100%|██████████| 295/295 [04:11<00:00,  1.17it/s]


F1 score:  [0.16949153 0.2556391  0.12121212 0.13636364 0.26627219 0.14035088
 0.21290323 0.39506173]
AUC score:  [0.82689094 0.64597146 0.74023747 0.80637681 0.70319784 0.82089728
 0.72798644 0.88340794]
2023-10-02 14:01:01 Eopch 12, Training Loss 0.2698209661762841, Val Loss 0.6542712587547505, Weighted Loss 0.3025703747603121
Finish training: best_val:0.5916816701323299, best_weighted loss:0.25811512151588256


0,1
training loss,█▇▇▆▅▅▄▃▃▂▂▁
val loss,▆▅▃▄▁▂▁▁▂▂█▅
weighted loss,▆█▅█▅▁▃▆▇▅▁▃

0,1
training loss,0.26982
val loss,0.65427
weighted loss,0.30257


In [17]:
# Disabled experiment, kept for reference only. When uncommented, it builds a
# FullVolNet (8 outputs, 15 slices) on `device`. It then trains the network with
# AdamW (lr=1e-5) for 15 epochs via TrainClassifer, with no LR scheduler.
# The first (doubly commented) line is an alternative that selects a ConvNeXtV2
# backbone by name (presumably a timm model identifier; timm is imported above).
# NOTE(review): FullVolNet, TrainClassifer, train_dl and val_dl are assumed to be
# defined in earlier cells that are not visible here. Confirm before re-enabling.
# NOTE(review): the keyword is spelled `n_eopchs`, presumably to match
# TrainClassifer's actual parameter name. Do not correct the spelling here
# without also changing the function signature.
# #net = FullVolNet(backbone = 'convnextv2_nano.fcmae_ft_in22k_in1k_384', ch_out = 8, slices=15).to(device)
# net = FullVolNet(ch_out = 8, slices=15).to(device)

# optimizer = optim.AdamW(net.parameters(), lr=1e-5)
# TrainClassifer(model=net,trn_dl=train_dl,val_dl=val_dl,optimizer=optimizer, 
#                scheduler=None, n_eopchs=15, device=device)