**Heng's Starter code for training the classification model on Kaggle.**

I have not run this kernel with the GPU enabled because I have little Kaggle GPU quota left right now, so, as expected, the kernel currently fails with a CUDA error.
This is just a simple kernel for easily training the model on Kaggle. I made some minor changes to Heng's code, and it looks like it should run fine here on Kaggle.
I have not measured the training time; it may exceed Kaggle's 9-hour limit.

The kernel is based on Heng's starter kit version 20190910, which you can find [here](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/106462#latest-645576).
I have imported two utility scripts: one for the utility functions (including the plotting code) and another for the model.
You can fork and edit the utility scripts and add model classes as you like.
The model architecture can be changed in this kernel by modifying the `Net()` class below.

If you run into any problems or errors, feel free to mention them in the comments.
Finally, thank you very much to [Heng](https://www.kaggle.com/hengck23) and the other leaderboard leaders for helping newbies like me.

In [1]:
import numpy as np
import pandas as pd
import os
import glob
import random

from timeit import default_timer as timer
import cv2
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import torch.utils.data as data
from torch.utils.data.sampler import Sampler
import torchvision.models as models
import torch.nn as nn
from torch.nn import functional as F
import torch

from lib.utility_functions import *
from lib.models_all import *
from lib.rate import *

# Convenience alias for pi.
PI = np.pi

# Per-channel normalisation statistics, ordered R, G, B (these match the
# usual torchvision ImageNet mean/std values).
IMAGE_RGB_MEAN = [0.485, 0.456, 0.406]
IMAGE_RGB_STD  = [0.229, 0.224, 0.225]

# One colour tuple per index 0..4. Presumably index 0 is background and 1..4
# are the defect classes; channel order (RGB vs OpenCV BGR) is unverified.
DEFECT_COLOR = [
    (0, 0, 0),
    (0, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
]

In [2]:
# Root folder holding the competition CSV files and image sub-folders.
DATA_DIR = 'data'
# Folder holding the .npy split files (lists of '<folder>/<image_id>' uids).
SPLIT_DIR = DATA_DIR + '/split'

In [3]:
class Net(nn.Module):
    """ResNeXt-50 backbone followed by a small 1x1-convolution classification head.

    The backbone stages ``block0`` .. ``block4`` are taken from ``ResNext50``
    (provided by the ``lib`` utility scripts). The head reduces 2048 channels
    to 64 and then to ``num_class`` logits per spatial cell.
    """

    def load_pretrain(self, skip=None, is_print=True):
        """Load pretrained weights from ``PRETRAIN_FILE`` using the ``CONVERSION`` key map.

        Args:
            skip: key prefixes passed to the loader as the skip list; defaults to
                ``['logit.']`` (presumably so the new classification layer is not
                overwritten — confirm against ``load_pretrain`` in lib).
            is_print: forwarded unchanged to the loader.
        """
        if skip is None:
            # Fresh list per call instead of a shared mutable default argument.
            skip = ['logit.']
        load_pretrain(self, skip, pretrain_file=PRETRAIN_FILE, conversion=CONVERSION, is_print=is_print)

    def __init__(self, num_class=4):
        super(Net, self).__init__()

        e = ResNext50()
        self.block0 = e.block0
        self.block1 = e.block1
        self.block2 = e.block2
        self.block3 = e.block3
        self.block4 = e.block4
        e = None  # drop the reference to the rest of the backbone object

        self.feature = nn.Conv2d(2048, 64, kernel_size=1)  # 1x1 conv for channel reduction
        self.logit   = nn.Conv2d(64, num_class, kernel_size=1)

    def forward(self, x):
        """Return class logits of shape (batch, num_class, h, w).

        ``x`` must have 3 channels on dim 1 (the mean/std are broadcast with
        ``view(1, -1, 1, 1)``). ``x`` itself is not modified: the subtraction
        below allocates a new tensor.
        """
        mean = torch.tensor(IMAGE_RGB_MEAN, dtype=torch.float32, device=x.device).view(1, -1, 1, 1)
        std  = torch.tensor(IMAGE_RGB_STD,  dtype=torch.float32, device=x.device).view(1, -1, 1, 1)
        # NOTE(review): null_collate also passes IMAGE_RGB_MEAN/STD to
        # image_to_input — confirm the batch is not normalised twice.
        x = (x - mean) / std

        x = self.block0(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)

        x = F.dropout(x, 0.5, training=self.training)
        # Average over 8x13 windows with stride 8; the remaining spatial cells
        # are reduced later by the caller (do_valid takes a max over dim -1).
        x = F.avg_pool2d(x, kernel_size=(8, 13), stride=(8, 8))
        x = self.feature(x)

        return self.logit(x)

In [4]:
# Attribute container used for the per-sample `infor` record built in SteelDataset.__getitem__.
class Struct(object):
    """Simple attribute bag: every keyword argument becomes an attribute."""

    def __init__(self, is_copy=False, **kwargs):
        self.add(is_copy, **kwargs)

    def add(self, is_copy=False, **kwargs):
        """Set each keyword argument as an attribute of ``self``.

        Args:
            is_copy: when true, each value is deep-copied before being stored;
                values that cannot be deep-copied are stored by reference.
            **kwargs: attribute names and values.
        """
        # `copy` is not among the file's visible imports; without this local
        # import the NameError was swallowed below and nothing was ever copied.
        import copy

        for key, value in kwargs.items():
            if is_copy:
                try:
                    value = copy.deepcopy(value)
                except Exception:
                    # Best effort: keep the original object if it cannot be copied.
                    pass
            setattr(self, key, value)

    def __str__(self):
        return ''.join('\t%s : %s\n' % (k, str(v)) for k, v in self.__dict__.items())

# Creating masks
def run_length_decode(rle, height=256, width=1600, fill_value=1):
    """Decode a run-length-encoded string into a (height, width) float32 mask.

    The encoding is a whitespace-separated list of ``start length`` pairs.
    Starts are 1-based (hence ``start - 1``) and count down the columns first,
    i.e. column-major order, which is why the flat buffer is reshaped with
    ``order='F'``.

    Args:
        rle: the encoded string; an empty/whitespace-only string or a non-str
            value (e.g. a NaN from an unfilled CSV cell) yields an empty mask.
        height, width: output mask size.
        fill_value: value written into covered pixels.

    Returns:
        np.ndarray of shape (height, width), dtype float32.
    """
    mask = np.zeros(height * width, np.float32)
    tokens = rle.split() if isinstance(rle, str) else []
    if tokens:
        pairs = np.array([int(tok) for tok in tokens]).reshape(-1, 2)
        for start, length in pairs:
            start = start - 1  # 1-based -> 0-based
            mask[start:start + length] = fill_value
    return mask.reshape(height, width, order='F')

# Collations
def null_collate(batch):
    """Collate a list of ``(image, mask, infor)`` samples into batch tensors.

    Returns:
        input: float tensor produced by ``image_to_input`` from the stacked images.
        truth_mask: float tensor of the stacked masks, binarised at 0.5.
        truth_label: float tensor with one 0/1 flag per class and sample; a class
            is flagged when its mask values sum to more than 8 (masks are
            reshaped to 4 rows, so exactly 4 classes are assumed).
        infor: list of the per-sample info objects, in batch order.
    """
    images = np.stack([sample[0] for sample in batch])
    masks = [sample[1] for sample in batch]
    infor = [sample[2] for sample in batch]
    labels = np.array([(m.reshape(4, -1).sum(1) > 8).astype(np.int32) for m in masks])

    batch_input = image_to_input(images, IMAGE_RGB_MEAN, IMAGE_RGB_STD)
    batch_input = torch.from_numpy(batch_input).float()

    binary_masks = (np.stack(masks) > 0.5).astype(np.float32)
    truth_mask = torch.from_numpy(binary_masks).float()

    truth_label = torch.from_numpy(labels).float()

    return batch_input, truth_mask, truth_label, infor

# Metric
def metric_hit(logit, truth, threshold=0.5):
    """Hit rates of a thresholded sigmoid prediction against binary targets.

    Args:
        logit: tensor of shape (batch, num_class, H, W) with raw scores.
        truth: tensor viewable as (batch, num_class, H*W); values > 0.5 count
            as positive.
        threshold: probability above which a prediction counts as positive.

    Returns:
        tn: true-negative rate pooled over all classes (numpy scalar).
        tp: list with one true-positive rate (recall) per class.
        num_neg: total number of negative targets over all classes.
        num_pos: list with the number of positive targets per class.
    """
    batch_size, num_class, H, W = logit.shape

    with torch.no_grad():
        logit = logit.view(batch_size, num_class, -1)
        truth = truth.view(batch_size, num_class, -1)

        pred = torch.sigmoid(logit) > threshold
        pos = truth > 0.5

        tp = (pred & pos).sum(dim=[0, 2]).float()      # per class
        tn = (~pred & ~pos).sum(dim=[0, 2]).float()    # per class, pooled below
        num_pos = pos.sum(dim=[0, 2]).float()
        num_neg = batch_size * H * W - num_pos

        tp = tp.cpu().numpy()
        tn = tn.cpu().numpy().sum()
        num_pos = num_pos.cpu().numpy()
        num_neg = num_neg.cpu().numpy().sum()

        # The epsilon keeps empty classes at 0 instead of dividing by zero;
        # nan_to_num (with the fill value passed by keyword — its second
        # positional parameter is `copy`) is only a final guard.
        tp = np.nan_to_num(tp / (num_pos + 1e-12), nan=0.0)
        tn = np.nan_to_num(tn / (num_neg + 1e-12), nan=0.0)

    return tn, list(tp), num_neg, list(num_pos)

# Loss
#def criterion(logit, truth, weight=None):
#    batch_size,num_class, H,W = logit.shape
#    logit = logit.view(batch_size,num_class)
#    truth = truth.view(batch_size,num_class)
#    assert(logit.shape==truth.shape)
#
#    loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
#
#    if weight is None:
#        loss = loss.mean()
#
#    else:
#        pos = (truth>0.5).float()
#        neg = (truth<0.5).float()
#        pos_sum = pos.sum().item() + 1e-12
#        neg_sum = neg.sum().item() + 1e-12
#        loss = (weight[1]*pos*loss/pos_sum + weight[0]*neg*loss/neg_sum).sum()
#        #raise NotImplementedError
#
#    return loss


def criterion(logit, truth, weight=None):
    """Per-class weighted binary cross-entropy on image-level logits.

    Args:
        logit: tensor whose first two dims are (batch, num_class); any trailing
            dims must have size 1 because it is viewed as (batch, num_class).
        truth: tensor with batch * num_class 0/1 targets.
        weight: optional per-class weights (sequence, array or tensor of length
            num_class); defaults to all ones.

    Returns:
        Scalar tensor: mean of the weighted element-wise losses.
    """
    batch_size, num_class = logit.shape[:2]
    logit = logit.view(batch_size, num_class)
    truth = truth.view(batch_size, num_class)

    if weight is None:
        # Size the default to the actual class count (was hard-coded to 4).
        weight = [1] * num_class
    weight = torch.as_tensor(weight, dtype=torch.float32, device=truth.device).view(1, -1)

    loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
    return (loss * weight).mean()

# Learning Rate Adjustments
def adjust_learning_rate(optimizer, lr):
    """Overwrite the ``'lr'`` entry of every parameter group of *optimizer*."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

def get_learning_rate(optimizer):
    """Return the learning rate of an optimizer with exactly one parameter group.

    Raises:
        AssertionError: if the optimizer does not have exactly one group.
    """
    rates = [group['lr'] for group in optimizer.param_groups]
    assert len(rates) == 1  # only a single param_group is supported
    return rates[0]

# Learning Rate Schedule
class NullScheduler:
    """Constant learning-rate schedule: every call returns the same ``lr``.

    ``cycle`` is kept as an attribute (always 0 here).
    """

    def __init__(self, lr=0.01 ):
        super().__init__()
        self.lr = lr
        self.cycle = 0

    def __call__(self, time):
        # The schedule ignores the time/step argument entirely.
        return self.lr

    def __str__(self):
        return 'NullScheduler\n' + 'lr=%0.5f ' % self.lr

In [5]:
class SteelDataset(Dataset):
    """Steel images paired with their four per-class run-length masks.

    Args:
        split: iterable of file names under SPLIT_DIR, each loaded with
            ``np.load``; their concatenation gives the uids, of the form
            ``'<folder>/<image_id>'``.
        csv: iterable of CSV file names under DATA_DIR; they must provide the
            ``ImageId_ClassId`` and ``EncodedPixels`` columns.
        mode: mode string; ``'train'`` adds class statistics to ``__str__``.
        augment: optional callable ``(image, mask, infor) -> sample``; when
            None, ``(image, mask, infor)`` is returned unchanged.
    """

    def __init__(self, split, csv, mode, augment=None):
        self.split   = split
        self.csv     = csv
        self.mode    = mode
        self.augment = augment

        self.uid = list(np.concatenate([np.load(SPLIT_DIR + '/%s'%f , allow_pickle=True) for f in split]))
        df = pd.concat([pd.read_csv(DATA_DIR + '/%s'%f) for f in csv])
        df.fillna('', inplace=True)
        df['Class'] = df['ImageId_ClassId'].str[-1].astype(np.int32)
        df['Label'] = (df['EncodedPixels']!='').astype(np.int32)
        df = df_loc_by_list(df, 'ImageId_ClassId', [ u.split('/')[-1] + '_%d'%c  for u in self.uid for c in [1,2,3,4] ])
        self.df = df

        # O(1) RLE lookup for __getitem__ instead of four full boolean scans of
        # the DataFrame per sample. keep='first' mirrors the previous
        # `.values[0]` choice if an id appears more than once.
        # NOTE: built once here; reassigning self.df later does not refresh it.
        unique = df.drop_duplicates('ImageId_ClassId', keep='first')
        self._rle_lookup = dict(zip(unique['ImageId_ClassId'], unique['EncodedPixels']))

    def __str__(self):
        num1 = (self.df['Class']==1).sum()
        num2 = (self.df['Class']==2).sum()
        num3 = (self.df['Class']==3).sum()
        num4 = (self.df['Class']==4).sum()
        pos1 = ((self.df['Class']==1) & (self.df['Label']==1)).sum()
        pos2 = ((self.df['Class']==2) & (self.df['Label']==1)).sum()
        pos3 = ((self.df['Class']==3) & (self.df['Label']==1)).sum()
        pos4 = ((self.df['Class']==4) & (self.df['Label']==1)).sum()

        length = len(self)
        num = len(self)*4
        pos = (self.df['Label']==1).sum()
        neg = num-pos

        string  = ''
        string += '\tmode    = %s\n'%self.mode
        string += '\tsplit   = %s\n'%self.split
        string += '\tcsv     = %s\n'%str(self.csv)
        string += '\t\tlen   = %5d\n'%len(self)
        if self.mode == 'train':
            string += '\t\tnum   = %5d\n'%num
            string += '\t\tneg   = %5d  %0.3f\n'%(neg,neg/num)
            string += '\t\tpos   = %5d  %0.3f\n'%(pos,pos/num)
            string += '\t\tpos1  = %5d  %0.3f  %0.3f\n'%(pos1,pos1/length,pos1/pos)
            string += '\t\tpos2  = %5d  %0.3f  %0.3f\n'%(pos2,pos2/length,pos2/pos)
            string += '\t\tpos3  = %5d  %0.3f  %0.3f\n'%(pos3,pos3/length,pos3/pos)
            string += '\t\tpos4  = %5d  %0.3f  %0.3f\n'%(pos4,pos4/length,pos4/pos)
        return string

    def __len__(self):
        return len(self.uid)

    def __getitem__(self, index):
        """Return ``(image, mask, infor)`` (or the result of ``augment``) for *index*.

        ``mask`` stacks one 256x1600 decoded mask per class 1..4.

        Raises:
            KeyError: if an ``<image_id>_<class>`` key is missing from the CSVs.
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        folder, image_id = self.uid[index].split('/')
        rle = [self._rle_lookup[image_id + '_%d' % c] for c in (1, 2, 3, 4)]

        path = DATA_DIR + '/%s/%s'%(folder,image_id)
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None; fail here with a
            # clear message instead of deep inside augmentation/collation.
            raise FileNotFoundError('could not read image: %s' % path)
        mask  = np.array([run_length_decode(r, height=256, width=1600, fill_value=1) for r in rle])

        infor = Struct(
            index    = index,
            folder   = folder,
            image_id = image_id,
        )

        if self.augment is None:
            return image, mask, infor
        else:
            return self.augment(image, mask, infor)

In [6]:
class FiveBalanceClassSampler(Sampler):
    """Sample images equally from five groups, with replacement.

    The groups are: images with no positive label, and images positive for
    class 1, 2, 3 and 4. Each epoch yields ``5 * num_image`` dataset indices,
    interleaved as (neg, pos1, pos2, pos3, pos4) repeated ``num_image`` times.
    """

    def __init__(self, dataset):
        self.dataset = dataset

        # Four consecutive Label rows per image (classes 1..4), assuming the
        # dataset's DataFrame keeps the uid x class order it was built with.
        # The reshape fails if the row count is not a multiple of 4.
        per_image = self.dataset.df['Label'].values.reshape(-1, 4)
        no_defect = per_image.sum(1, keepdims=True) == 0
        groups = np.hstack([no_defect, per_image]).T

        (self.neg_index,
         self.pos1_index,
         self.pos2_index,
         self.pos3_index,
         self.pos4_index) = (np.where(row)[0] for row in groups)

        self.num_image = len(self.dataset.df) // 4
        self.length = 5 * self.num_image

    def __iter__(self):
        pools = (self.neg_index, self.pos1_index, self.pos2_index,
                 self.pos3_index, self.pos4_index)
        draws = [np.random.choice(pool, self.num_image, replace=True) for pool in pools]
        # Row i of the stacked array is one (neg, pos1..pos4) tuple.
        return iter(np.stack(draws, axis=1).reshape(-1))

    def __len__(self):
        return self.length

In [7]:
def do_valid(net, valid_loader, displays=None):
    """Evaluate *net* on *valid_loader* and return averaged validation statistics.

    Args:
        net: the model; it is switched to eval mode (and left there).
        valid_loader: yields ``(input, truth_mask, truth_label, infor)`` batches.
        displays: when not None, every 4th sample of each batch is drawn and
            shown with the plotting helpers.

    Returns:
        np.ndarray of 6 float32 values:
        ``[loss, pooled tn rate, tp rate class 1..4]``. Each value is averaged
        using the per-batch counts (batch size, num_neg, num_pos) as weights.
    """
    valid_num  = np.zeros(6, np.float32)
    valid_loss = np.zeros(6, np.float32)

    net.eval()
    # Move batches to wherever the model lives instead of hard-coding .cuda(),
    # so validation also runs on CPU-only kernels.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    for input, truth_mask, truth_label, infor in valid_loader:
        input       = input.to(device)
        truth_mask  = truth_mask.to(device)
        truth_label = truth_label.to(device)

        with torch.no_grad():
            logit = net(input)
            # Reduce the last spatial dim by max; the loss and metric then view
            # the result as (batch, num_class) — assumes the other spatial dim is 1.
            logit = logit.max(-1, True)[0]
            loss  = criterion(logit, truth_label)
            tn, tp, num_neg, num_pos = metric_hit(logit, truth_label)

        batch_size = len(infor)
        # Accumulate value * count so the final division gives a pooled average.
        l = np.array([loss.item(), tn, *tp])
        n = np.array([batch_size, num_neg, *num_pos])
        valid_loss += l*n
        valid_num  += n

        if displays is not None:
            probability = torch.sigmoid(logit)
            image = input_to_image(input, IMAGE_RGB_MEAN,IMAGE_RGB_STD)

            probability_label = probability.data.cpu().numpy()
            truth_label = truth_label.data.cpu().numpy()
            truth_mask  = truth_mask.data.cpu().numpy()

            for b in range(0, batch_size, 4):
                image_id = infor[b].image_id[:-4]
                result = draw_predict_result_label(image[b], truth_mask[b], truth_label[b], probability_label[b], stack='vertical')
                draw_shadow_text(result,'%05d    %s.jpg'%(valid_num[0]-batch_size+b, image_id),(5,24),0.75,[255,255,255],1)
                image_show('result',result,resize=1)

        print('\r %8d /%8d'%(valid_num[0], len(valid_loader.dataset)),end='',flush=True)

    assert(valid_num[0] == len(valid_loader.dataset))
    valid_loss = valid_loss/valid_num

    return valid_loss

In [8]:
CONVERSION=[
 'block0.0.weight',	(64, 3, 7, 7),	 'layer0.conv1.weight',	(64, 3, 7, 7),
 'block0.1.weight',	(64,),	 'layer0.bn1.weight',	(64,),
 'block0.1.bias',	(64,),	 'layer0.bn1.bias',	(64,),
 'block0.1.running_mean',	(64,),	 'layer0.bn1.running_mean',	(64,),
 'block0.1.running_var',	(64,),	 'layer0.bn1.running_var',	(64,),
 'block1.1.conv_bn1.conv.weight',	(128, 64, 1, 1),	 'layer1.0.conv1.weight',	(128, 64, 1, 1),
 'block1.1.conv_bn1.bn.weight',	(128,),	 'layer1.0.bn1.weight',	(128,),
 'block1.1.conv_bn1.bn.bias',	(128,),	 'layer1.0.bn1.bias',	(128,),
 'block1.1.conv_bn1.bn.running_mean',	(128,),	 'layer1.0.bn1.running_mean',	(128,),
 'block1.1.conv_bn1.bn.running_var',	(128,),	 'layer1.0.bn1.running_var',	(128,),
 'block1.1.conv_bn2.conv.weight',	(128, 4, 3, 3),	 'layer1.0.conv2.weight',	(128, 4, 3, 3),
 'block1.1.conv_bn2.bn.weight',	(128,),	 'layer1.0.bn2.weight',	(128,),
 'block1.1.conv_bn2.bn.bias',	(128,),	 'layer1.0.bn2.bias',	(128,),
 'block1.1.conv_bn2.bn.running_mean',	(128,),	 'layer1.0.bn2.running_mean',	(128,),
 'block1.1.conv_bn2.bn.running_var',	(128,),	 'layer1.0.bn2.running_var',	(128,),
 'block1.1.conv_bn3.conv.weight',	(256, 128, 1, 1),	 'layer1.0.conv3.weight',	(256, 128, 1, 1),
 'block1.1.conv_bn3.bn.weight',	(256,),	 'layer1.0.bn3.weight',	(256,),
 'block1.1.conv_bn3.bn.bias',	(256,),	 'layer1.0.bn3.bias',	(256,),
 'block1.1.conv_bn3.bn.running_mean',	(256,),	 'layer1.0.bn3.running_mean',	(256,),
 'block1.1.conv_bn3.bn.running_var',	(256,),	 'layer1.0.bn3.running_var',	(256,),
 'block1.1.scale.fc1.weight',	(16, 256, 1, 1),	 'layer1.0.se_module.fc1.weight',	(16, 256, 1, 1),
 'block1.1.scale.fc1.bias',	(16,),	 'layer1.0.se_module.fc1.bias',	(16,),
 'block1.1.scale.fc2.weight',	(256, 16, 1, 1),	 'layer1.0.se_module.fc2.weight',	(256, 16, 1, 1),
 'block1.1.scale.fc2.bias',	(256,),	 'layer1.0.se_module.fc2.bias',	(256,),
 'block1.1.shortcut.conv.weight',	(256, 64, 1, 1),	 'layer1.0.downsample.0.weight',	(256, 64, 1, 1),
 'block1.1.shortcut.bn.weight',	(256,),	 'layer1.0.downsample.1.weight',	(256,),
 'block1.1.shortcut.bn.bias',	(256,),	 'layer1.0.downsample.1.bias',	(256,),
 'block1.1.shortcut.bn.running_mean',	(256,),	 'layer1.0.downsample.1.running_mean',	(256,),
 'block1.1.shortcut.bn.running_var',	(256,),	 'layer1.0.downsample.1.running_var',	(256,),
 'block1.2.conv_bn1.conv.weight',	(128, 256, 1, 1),	 'layer1.1.conv1.weight',	(128, 256, 1, 1),
 'block1.2.conv_bn1.bn.weight',	(128,),	 'layer1.1.bn1.weight',	(128,),
 'block1.2.conv_bn1.bn.bias',	(128,),	 'layer1.1.bn1.bias',	(128,),
 'block1.2.conv_bn1.bn.running_mean',	(128,),	 'layer1.1.bn1.running_mean',	(128,),
 'block1.2.conv_bn1.bn.running_var',	(128,),	 'layer1.1.bn1.running_var',	(128,),
 'block1.2.conv_bn2.conv.weight',	(128, 4, 3, 3),	 'layer1.1.conv2.weight',	(128, 4, 3, 3),
 'block1.2.conv_bn2.bn.weight',	(128,),	 'layer1.1.bn2.weight',	(128,),
 'block1.2.conv_bn2.bn.bias',	(128,),	 'layer1.1.bn2.bias',	(128,),
 'block1.2.conv_bn2.bn.running_mean',	(128,),	 'layer1.1.bn2.running_mean',	(128,),
 'block1.2.conv_bn2.bn.running_var',	(128,),	 'layer1.1.bn2.running_var',	(128,),
 'block1.2.conv_bn3.conv.weight',	(256, 128, 1, 1),	 'layer1.1.conv3.weight',	(256, 128, 1, 1),
 'block1.2.conv_bn3.bn.weight',	(256,),	 'layer1.1.bn3.weight',	(256,),
 'block1.2.conv_bn3.bn.bias',	(256,),	 'layer1.1.bn3.bias',	(256,),
 'block1.2.conv_bn3.bn.running_mean',	(256,),	 'layer1.1.bn3.running_mean',	(256,),
 'block1.2.conv_bn3.bn.running_var',	(256,),	 'layer1.1.bn3.running_var',	(256,),
 'block1.2.scale.fc1.weight',	(16, 256, 1, 1),	 'layer1.1.se_module.fc1.weight',	(16, 256, 1, 1),
 'block1.2.scale.fc1.bias',	(16,),	 'layer1.1.se_module.fc1.bias',	(16,),
 'block1.2.scale.fc2.weight',	(256, 16, 1, 1),	 'layer1.1.se_module.fc2.weight',	(256, 16, 1, 1),
 'block1.2.scale.fc2.bias',	(256,),	 'layer1.1.se_module.fc2.bias',	(256,),
 'block1.3.conv_bn1.conv.weight',	(128, 256, 1, 1),	 'layer1.2.conv1.weight',	(128, 256, 1, 1),
 'block1.3.conv_bn1.bn.weight',	(128,),	 'layer1.2.bn1.weight',	(128,),
 'block1.3.conv_bn1.bn.bias',	(128,),	 'layer1.2.bn1.bias',	(128,),
 'block1.3.conv_bn1.bn.running_mean',	(128,),	 'layer1.2.bn1.running_mean',	(128,),
 'block1.3.conv_bn1.bn.running_var',	(128,),	 'layer1.2.bn1.running_var',	(128,),
 'block1.3.conv_bn2.conv.weight',	(128, 4, 3, 3),	 'layer1.2.conv2.weight',	(128, 4, 3, 3),
 'block1.3.conv_bn2.bn.weight',	(128,),	 'layer1.2.bn2.weight',	(128,),
 'block1.3.conv_bn2.bn.bias',	(128,),	 'layer1.2.bn2.bias',	(128,),
 'block1.3.conv_bn2.bn.running_mean',	(128,),	 'layer1.2.bn2.running_mean',	(128,),
 'block1.3.conv_bn2.bn.running_var',	(128,),	 'layer1.2.bn2.running_var',	(128,),
 'block1.3.conv_bn3.conv.weight',	(256, 128, 1, 1),	 'layer1.2.conv3.weight',	(256, 128, 1, 1),
 'block1.3.conv_bn3.bn.weight',	(256,),	 'layer1.2.bn3.weight',	(256,),
 'block1.3.conv_bn3.bn.bias',	(256,),	 'layer1.2.bn3.bias',	(256,),
 'block1.3.conv_bn3.bn.running_mean',	(256,),	 'layer1.2.bn3.running_mean',	(256,),
 'block1.3.conv_bn3.bn.running_var',	(256,),	 'layer1.2.bn3.running_var',	(256,),
 'block1.3.scale.fc1.weight',	(16, 256, 1, 1),	 'layer1.2.se_module.fc1.weight',	(16, 256, 1, 1),
 'block1.3.scale.fc1.bias',	(16,),	 'layer1.2.se_module.fc1.bias',	(16,),
 'block1.3.scale.fc2.weight',	(256, 16, 1, 1),	 'layer1.2.se_module.fc2.weight',	(256, 16, 1, 1),
 'block1.3.scale.fc2.bias',	(256,),	 'layer1.2.se_module.fc2.bias',	(256,),
 'block2.0.conv_bn1.conv.weight',	(256, 256, 1, 1),	 'layer2.0.conv1.weight',	(256, 256, 1, 1),
 'block2.0.conv_bn1.bn.weight',	(256,),	 'layer2.0.bn1.weight',	(256,),
 'block2.0.conv_bn1.bn.bias',	(256,),	 'layer2.0.bn1.bias',	(256,),
 'block2.0.conv_bn1.bn.running_mean',	(256,),	 'layer2.0.bn1.running_mean',	(256,),
 'block2.0.conv_bn1.bn.running_var',	(256,),	 'layer2.0.bn1.running_var',	(256,),
 'block2.0.conv_bn2.conv.weight',	(256, 8, 3, 3),	 'layer2.0.conv2.weight',	(256, 8, 3, 3),
 'block2.0.conv_bn2.bn.weight',	(256,),	 'layer2.0.bn2.weight',	(256,),
 'block2.0.conv_bn2.bn.bias',	(256,),	 'layer2.0.bn2.bias',	(256,),
 'block2.0.conv_bn2.bn.running_mean',	(256,),	 'layer2.0.bn2.running_mean',	(256,),
 'block2.0.conv_bn2.bn.running_var',	(256,),	 'layer2.0.bn2.running_var',	(256,),
 'block2.0.conv_bn3.conv.weight',	(512, 256, 1, 1),	 'layer2.0.conv3.weight',	(512, 256, 1, 1),
 'block2.0.conv_bn3.bn.weight',	(512,),	 'layer2.0.bn3.weight',	(512,),
 'block2.0.conv_bn3.bn.bias',	(512,),	 'layer2.0.bn3.bias',	(512,),
 'block2.0.conv_bn3.bn.running_mean',	(512,),	 'layer2.0.bn3.running_mean',	(512,),
 'block2.0.conv_bn3.bn.running_var',	(512,),	 'layer2.0.bn3.running_var',	(512,),
 'block2.0.scale.fc1.weight',	(32, 512, 1, 1),	 'layer2.0.se_module.fc1.weight',	(32, 512, 1, 1),
 'block2.0.scale.fc1.bias',	(32,),	 'layer2.0.se_module.fc1.bias',	(32,),
 'block2.0.scale.fc2.weight',	(512, 32, 1, 1),	 'layer2.0.se_module.fc2.weight',	(512, 32, 1, 1),
 'block2.0.scale.fc2.bias',	(512,),	 'layer2.0.se_module.fc2.bias',	(512,),
 'block2.0.shortcut.conv.weight',	(512, 256, 1, 1),	 'layer2.0.downsample.0.weight',	(512, 256, 1, 1),
 'block2.0.shortcut.bn.weight',	(512,),	 'layer2.0.downsample.1.weight',	(512,),
 'block2.0.shortcut.bn.bias',	(512,),	 'layer2.0.downsample.1.bias',	(512,),
 'block2.0.shortcut.bn.running_mean',	(512,),	 'layer2.0.downsample.1.running_mean',	(512,),
 'block2.0.shortcut.bn.running_var',	(512,),	 'layer2.0.downsample.1.running_var',	(512,),
 'block2.1.conv_bn1.conv.weight',	(256, 512, 1, 1),	 'layer2.1.conv1.weight',	(256, 512, 1, 1),
 'block2.1.conv_bn1.bn.weight',	(256,),	 'layer2.1.bn1.weight',	(256,),
 'block2.1.conv_bn1.bn.bias',	(256,),	 'layer2.1.bn1.bias',	(256,),
 'block2.1.conv_bn1.bn.running_mean',	(256,),	 'layer2.1.bn1.running_mean',	(256,),
 'block2.1.conv_bn1.bn.running_var',	(256,),	 'layer2.1.bn1.running_var',	(256,),
 'block2.1.conv_bn2.conv.weight',	(256, 8, 3, 3),	 'layer2.1.conv2.weight',	(256, 8, 3, 3),
 'block2.1.conv_bn2.bn.weight',	(256,),	 'layer2.1.bn2.weight',	(256,),
 'block2.1.conv_bn2.bn.bias',	(256,),	 'layer2.1.bn2.bias',	(256,),
 'block2.1.conv_bn2.bn.running_mean',	(256,),	 'layer2.1.bn2.running_mean',	(256,),
 'block2.1.conv_bn2.bn.running_var',	(256,),	 'layer2.1.bn2.running_var',	(256,),
 'block2.1.conv_bn3.conv.weight',	(512, 256, 1, 1),	 'layer2.1.conv3.weight',	(512, 256, 1, 1),
 'block2.1.conv_bn3.bn.weight',	(512,),	 'layer2.1.bn3.weight',	(512,),
 'block2.1.conv_bn3.bn.bias',	(512,),	 'layer2.1.bn3.bias',	(512,),
 'block2.1.conv_bn3.bn.running_mean',	(512,),	 'layer2.1.bn3.running_mean',	(512,),
 'block2.1.conv_bn3.bn.running_var',	(512,),	 'layer2.1.bn3.running_var',	(512,),
 'block2.1.scale.fc1.weight',	(32, 512, 1, 1),	 'layer2.1.se_module.fc1.weight',	(32, 512, 1, 1),
 'block2.1.scale.fc1.bias',	(32,),	 'layer2.1.se_module.fc1.bias',	(32,),
 'block2.1.scale.fc2.weight',	(512, 32, 1, 1),	 'layer2.1.se_module.fc2.weight',	(512, 32, 1, 1),
 'block2.1.scale.fc2.bias',	(512,),	 'layer2.1.se_module.fc2.bias',	(512,),
 'block2.2.conv_bn1.conv.weight',	(256, 512, 1, 1),	 'layer2.2.conv1.weight',	(256, 512, 1, 1),
 'block2.2.conv_bn1.bn.weight',	(256,),	 'layer2.2.bn1.weight',	(256,),
 'block2.2.conv_bn1.bn.bias',	(256,),	 'layer2.2.bn1.bias',	(256,),
 'block2.2.conv_bn1.bn.running_mean',	(256,),	 'layer2.2.bn1.running_mean',	(256,),
 'block2.2.conv_bn1.bn.running_var',	(256,),	 'layer2.2.bn1.running_var',	(256,),
 'block2.2.conv_bn2.conv.weight',	(256, 8, 3, 3),	 'layer2.2.conv2.weight',	(256, 8, 3, 3),
 'block2.2.conv_bn2.bn.weight',	(256,),	 'layer2.2.bn2.weight',	(256,),
 'block2.2.conv_bn2.bn.bias',	(256,),	 'layer2.2.bn2.bias',	(256,),
 'block2.2.conv_bn2.bn.running_mean',	(256,),	 'layer2.2.bn2.running_mean',	(256,),
 'block2.2.conv_bn2.bn.running_var',	(256,),	 'layer2.2.bn2.running_var',	(256,),
 'block2.2.conv_bn3.conv.weight',	(512, 256, 1, 1),	 'layer2.2.conv3.weight',	(512, 256, 1, 1),
 'block2.2.conv_bn3.bn.weight',	(512,),	 'layer2.2.bn3.weight',	(512,),
 'block2.2.conv_bn3.bn.bias',	(512,),	 'layer2.2.bn3.bias',	(512,),
 'block2.2.conv_bn3.bn.running_mean',	(512,),	 'layer2.2.bn3.running_mean',	(512,),
 'block2.2.conv_bn3.bn.running_var',	(512,),	 'layer2.2.bn3.running_var',	(512,),
 'block2.2.scale.fc1.weight',	(32, 512, 1, 1),	 'layer2.2.se_module.fc1.weight',	(32, 512, 1, 1),
 'block2.2.scale.fc1.bias',	(32,),	 'layer2.2.se_module.fc1.bias',	(32,),
 'block2.2.scale.fc2.weight',	(512, 32, 1, 1),	 'layer2.2.se_module.fc2.weight',	(512, 32, 1, 1),
 'block2.2.scale.fc2.bias',	(512,),	 'layer2.2.se_module.fc2.bias',	(512,),
 'block2.3.conv_bn1.conv.weight',	(256, 512, 1, 1),	 'layer2.3.conv1.weight',	(256, 512, 1, 1),
 'block2.3.conv_bn1.bn.weight',	(256,),	 'layer2.3.bn1.weight',	(256,),
 'block2.3.conv_bn1.bn.bias',	(256,),	 'layer2.3.bn1.bias',	(256,),
 'block2.3.conv_bn1.bn.running_mean',	(256,),	 'layer2.3.bn1.running_mean',	(256,),
 'block2.3.conv_bn1.bn.running_var',	(256,),	 'layer2.3.bn1.running_var',	(256,),
 'block2.3.conv_bn2.conv.weight',	(256, 8, 3, 3),	 'layer2.3.conv2.weight',	(256, 8, 3, 3),
 'block2.3.conv_bn2.bn.weight',	(256,),	 'layer2.3.bn2.weight',	(256,),
 'block2.3.conv_bn2.bn.bias',	(256,),	 'layer2.3.bn2.bias',	(256,),
 'block2.3.conv_bn2.bn.running_mean',	(256,),	 'layer2.3.bn2.running_mean',	(256,),
 'block2.3.conv_bn2.bn.running_var',	(256,),	 'layer2.3.bn2.running_var',	(256,),
 'block2.3.conv_bn3.conv.weight',	(512, 256, 1, 1),	 'layer2.3.conv3.weight',	(512, 256, 1, 1),
 'block2.3.conv_bn3.bn.weight',	(512,),	 'layer2.3.bn3.weight',	(512,),
 'block2.3.conv_bn3.bn.bias',	(512,),	 'layer2.3.bn3.bias',	(512,),
 'block2.3.conv_bn3.bn.running_mean',	(512,),	 'layer2.3.bn3.running_mean',	(512,),
 'block2.3.conv_bn3.bn.running_var',	(512,),	 'layer2.3.bn3.running_var',	(512,),
 'block2.3.scale.fc1.weight',	(32, 512, 1, 1),	 'layer2.3.se_module.fc1.weight',	(32, 512, 1, 1),
 'block2.3.scale.fc1.bias',	(32,),	 'layer2.3.se_module.fc1.bias',	(32,),
 'block2.3.scale.fc2.weight',	(512, 32, 1, 1),	 'layer2.3.se_module.fc2.weight',	(512, 32, 1, 1),
 'block2.3.scale.fc2.bias',	(512,),	 'layer2.3.se_module.fc2.bias',	(512,),
 'block3.0.conv_bn1.conv.weight',	(512, 512, 1, 1),	 'layer3.0.conv1.weight',	(512, 512, 1, 1),
 'block3.0.conv_bn1.bn.weight',	(512,),	 'layer3.0.bn1.weight',	(512,),
 'block3.0.conv_bn1.bn.bias',	(512,),	 'layer3.0.bn1.bias',	(512,),
 'block3.0.conv_bn1.bn.running_mean',	(512,),	 'layer3.0.bn1.running_mean',	(512,),
 'block3.0.conv_bn1.bn.running_var',	(512,),	 'layer3.0.bn1.running_var',	(512,),
 'block3.0.conv_bn2.conv.weight',	(512, 16, 3, 3),	 'layer3.0.conv2.weight',	(512, 16, 3, 3),
 'block3.0.conv_bn2.bn.weight',	(512,),	 'layer3.0.bn2.weight',	(512,),
 'block3.0.conv_bn2.bn.bias',	(512,),	 'layer3.0.bn2.bias',	(512,),
 'block3.0.conv_bn2.bn.running_mean',	(512,),	 'layer3.0.bn2.running_mean',	(512,),
 'block3.0.conv_bn2.bn.running_var',	(512,),	 'layer3.0.bn2.running_var',	(512,),
 'block3.0.conv_bn3.conv.weight',	(1024, 512, 1, 1),	 'layer3.0.conv3.weight',	(1024, 512, 1, 1),
 'block3.0.conv_bn3.bn.weight',	(1024,),	 'layer3.0.bn3.weight',	(1024,),
 'block3.0.conv_bn3.bn.bias',	(1024,),	 'layer3.0.bn3.bias',	(1024,),
 'block3.0.conv_bn3.bn.running_mean',	(1024,),	 'layer3.0.bn3.running_mean',	(1024,),
 'block3.0.conv_bn3.bn.running_var',	(1024,),	 'layer3.0.bn3.running_var',	(1024,),
 'block3.0.scale.fc1.weight',	(64, 1024, 1, 1),	 'layer3.0.se_module.fc1.weight',	(64, 1024, 1, 1),
 'block3.0.scale.fc1.bias',	(64,),	 'layer3.0.se_module.fc1.bias',	(64,),
 'block3.0.scale.fc2.weight',	(1024, 64, 1, 1),	 'layer3.0.se_module.fc2.weight',	(1024, 64, 1, 1),
 'block3.0.scale.fc2.bias',	(1024,),	 'layer3.0.se_module.fc2.bias',	(1024,),
 'block3.0.shortcut.conv.weight',	(1024, 512, 1, 1),	 'layer3.0.downsample.0.weight',	(1024, 512, 1, 1),
 'block3.0.shortcut.bn.weight',	(1024,),	 'layer3.0.downsample.1.weight',	(1024,),
 'block3.0.shortcut.bn.bias',	(1024,),	 'layer3.0.downsample.1.bias',	(1024,),
 'block3.0.shortcut.bn.running_mean',	(1024,),	 'layer3.0.downsample.1.running_mean',	(1024,),
 'block3.0.shortcut.bn.running_var',	(1024,),	 'layer3.0.downsample.1.running_var',	(1024,),
 'block3.1.conv_bn1.conv.weight',	(512, 1024, 1, 1),	 'layer3.1.conv1.weight',	(512, 1024, 1, 1),
 'block3.1.conv_bn1.bn.weight',	(512,),	 'layer3.1.bn1.weight',	(512,),
 'block3.1.conv_bn1.bn.bias',	(512,),	 'layer3.1.bn1.bias',	(512,),
 'block3.1.conv_bn1.bn.running_mean',	(512,),	 'layer3.1.bn1.running_mean',	(512,),
 'block3.1.conv_bn1.bn.running_var',	(512,),	 'layer3.1.bn1.running_var',	(512,),
 'block3.1.conv_bn2.conv.weight',	(512, 16, 3, 3),	 'layer3.1.conv2.weight',	(512, 16, 3, 3),
 'block3.1.conv_bn2.bn.weight',	(512,),	 'layer3.1.bn2.weight',	(512,),
 'block3.1.conv_bn2.bn.bias',	(512,),	 'layer3.1.bn2.bias',	(512,),
 'block3.1.conv_bn2.bn.running_mean',	(512,),	 'layer3.1.bn2.running_mean',	(512,),
 'block3.1.conv_bn2.bn.running_var',	(512,),	 'layer3.1.bn2.running_var',	(512,),
 'block3.1.conv_bn3.conv.weight',	(1024, 512, 1, 1),	 'layer3.1.conv3.weight',	(1024, 512, 1, 1),
 'block3.1.conv_bn3.bn.weight',	(1024,),	 'layer3.1.bn3.weight',	(1024,),
 'block3.1.conv_bn3.bn.bias',	(1024,),	 'layer3.1.bn3.bias',	(1024,),
 'block3.1.conv_bn3.bn.running_mean',	(1024,),	 'layer3.1.bn3.running_mean',	(1024,),
 'block3.1.conv_bn3.bn.running_var',	(1024,),	 'layer3.1.bn3.running_var',	(1024,),
 'block3.1.scale.fc1.weight',	(64, 1024, 1, 1),	 'layer3.1.se_module.fc1.weight',	(64, 1024, 1, 1),
 'block3.1.scale.fc1.bias',	(64,),	 'layer3.1.se_module.fc1.bias',	(64,),
 'block3.1.scale.fc2.weight',	(1024, 64, 1, 1),	 'layer3.1.se_module.fc2.weight',	(1024, 64, 1, 1),
 'block3.1.scale.fc2.bias',	(1024,),	 'layer3.1.se_module.fc2.bias',	(1024,),
 'block3.2.conv_bn1.conv.weight',	(512, 1024, 1, 1),	 'layer3.2.conv1.weight',	(512, 1024, 1, 1),
 'block3.2.conv_bn1.bn.weight',	(512,),	 'layer3.2.bn1.weight',	(512,),
 'block3.2.conv_bn1.bn.bias',	(512,),	 'layer3.2.bn1.bias',	(512,),
 'block3.2.conv_bn1.bn.running_mean',	(512,),	 'layer3.2.bn1.running_mean',	(512,),
 'block3.2.conv_bn1.bn.running_var',	(512,),	 'layer3.2.bn1.running_var',	(512,),
 'block3.2.conv_bn2.conv.weight',	(512, 16, 3, 3),	 'layer3.2.conv2.weight',	(512, 16, 3, 3),
 'block3.2.conv_bn2.bn.weight',	(512,),	 'layer3.2.bn2.weight',	(512,),
 'block3.2.conv_bn2.bn.bias',	(512,),	 'layer3.2.bn2.bias',	(512,),
 'block3.2.conv_bn2.bn.running_mean',	(512,),	 'layer3.2.bn2.running_mean',	(512,),
 'block3.2.conv_bn2.bn.running_var',	(512,),	 'layer3.2.bn2.running_var',	(512,),
 'block3.2.conv_bn3.conv.weight',	(1024, 512, 1, 1),	 'layer3.2.conv3.weight',	(1024, 512, 1, 1),
 'block3.2.conv_bn3.bn.weight',	(1024,),	 'layer3.2.bn3.weight',	(1024,),
 'block3.2.conv_bn3.bn.bias',	(1024,),	 'layer3.2.bn3.bias',	(1024,),
 'block3.2.conv_bn3.bn.running_mean',	(1024,),	 'layer3.2.bn3.running_mean',	(1024,),
 'block3.2.conv_bn3.bn.running_var',	(1024,),	 'layer3.2.bn3.running_var',	(1024,),
 'block3.2.scale.fc1.weight',	(64, 1024, 1, 1),	 'layer3.2.se_module.fc1.weight',	(64, 1024, 1, 1),
 'block3.2.scale.fc1.bias',	(64,),	 'layer3.2.se_module.fc1.bias',	(64,),
 'block3.2.scale.fc2.weight',	(1024, 64, 1, 1),	 'layer3.2.se_module.fc2.weight',	(1024, 64, 1, 1),
 'block3.2.scale.fc2.bias',	(1024,),	 'layer3.2.se_module.fc2.bias',	(1024,),
 'block3.3.conv_bn1.conv.weight',	(512, 1024, 1, 1),	 'layer3.3.conv1.weight',	(512, 1024, 1, 1),
 'block3.3.conv_bn1.bn.weight',	(512,),	 'layer3.3.bn1.weight',	(512,),
 'block3.3.conv_bn1.bn.bias',	(512,),	 'layer3.3.bn1.bias',	(512,),
 'block3.3.conv_bn1.bn.running_mean',	(512,),	 'layer3.3.bn1.running_mean',	(512,),
 'block3.3.conv_bn1.bn.running_var',	(512,),	 'layer3.3.bn1.running_var',	(512,),
 'block3.3.conv_bn2.conv.weight',	(512, 16, 3, 3),	 'layer3.3.conv2.weight',	(512, 16, 3, 3),
 'block3.3.conv_bn2.bn.weight',	(512,),	 'layer3.3.bn2.weight',	(512,),
 'block3.3.conv_bn2.bn.bias',	(512,),	 'layer3.3.bn2.bias',	(512,),
 'block3.3.conv_bn2.bn.running_mean',	(512,),	 'layer3.3.bn2.running_mean',	(512,),
 'block3.3.conv_bn2.bn.running_var',	(512,),	 'layer3.3.bn2.running_var',	(512,),
 'block3.3.conv_bn3.conv.weight',	(1024, 512, 1, 1),	 'layer3.3.conv3.weight',	(1024, 512, 1, 1),
 'block3.3.conv_bn3.bn.weight',	(1024,),	 'layer3.3.bn3.weight',	(1024,),
 'block3.3.conv_bn3.bn.bias',	(1024,),	 'layer3.3.bn3.bias',	(1024,),
 'block3.3.conv_bn3.bn.running_mean',	(1024,),	 'layer3.3.bn3.running_mean',	(1024,),
 'block3.3.conv_bn3.bn.running_var',	(1024,),	 'layer3.3.bn3.running_var',	(1024,),
 'block3.3.scale.fc1.weight',	(64, 1024, 1, 1),	 'layer3.3.se_module.fc1.weight',	(64, 1024, 1, 1),
 'block3.3.scale.fc1.bias',	(64,),	 'layer3.3.se_module.fc1.bias',	(64,),
 'block3.3.scale.fc2.weight',	(1024, 64, 1, 1),	 'layer3.3.se_module.fc2.weight',	(1024, 64, 1, 1),
 'block3.3.scale.fc2.bias',	(1024,),	 'layer3.3.se_module.fc2.bias',	(1024,),
 'block3.4.conv_bn1.conv.weight',	(512, 1024, 1, 1),	 'layer3.4.conv1.weight',	(512, 1024, 1, 1),
 'block3.4.conv_bn1.bn.weight',	(512,),	 'layer3.4.bn1.weight',	(512,),
 'block3.4.conv_bn1.bn.bias',	(512,),	 'layer3.4.bn1.bias',	(512,),
 'block3.4.conv_bn1.bn.running_mean',	(512,),	 'layer3.4.bn1.running_mean',	(512,),
 'block3.4.conv_bn1.bn.running_var',	(512,),	 'layer3.4.bn1.running_var',	(512,),
 'block3.4.conv_bn2.conv.weight',	(512, 16, 3, 3),	 'layer3.4.conv2.weight',	(512, 16, 3, 3),
 'block3.4.conv_bn2.bn.weight',	(512,),	 'layer3.4.bn2.weight',	(512,),
 'block3.4.conv_bn2.bn.bias',	(512,),	 'layer3.4.bn2.bias',	(512,),
 'block3.4.conv_bn2.bn.running_mean',	(512,),	 'layer3.4.bn2.running_mean',	(512,),
 'block3.4.conv_bn2.bn.running_var',	(512,),	 'layer3.4.bn2.running_var',	(512,),
 'block3.4.conv_bn3.conv.weight',	(1024, 512, 1, 1),	 'layer3.4.conv3.weight',	(1024, 512, 1, 1),
 'block3.4.conv_bn3.bn.weight',	(1024,),	 'layer3.4.bn3.weight',	(1024,),
 'block3.4.conv_bn3.bn.bias',	(1024,),	 'layer3.4.bn3.bias',	(1024,),
 'block3.4.conv_bn3.bn.running_mean',	(1024,),	 'layer3.4.bn3.running_mean',	(1024,),
 'block3.4.conv_bn3.bn.running_var',	(1024,),	 'layer3.4.bn3.running_var',	(1024,),
 'block3.4.scale.fc1.weight',	(64, 1024, 1, 1),	 'layer3.4.se_module.fc1.weight',	(64, 1024, 1, 1),
 'block3.4.scale.fc1.bias',	(64,),	 'layer3.4.se_module.fc1.bias',	(64,),
 'block3.4.scale.fc2.weight',	(1024, 64, 1, 1),	 'layer3.4.se_module.fc2.weight',	(1024, 64, 1, 1),
 'block3.4.scale.fc2.bias',	(1024,),	 'layer3.4.se_module.fc2.bias',	(1024,),
 'block3.5.conv_bn1.conv.weight',	(512, 1024, 1, 1),	 'layer3.5.conv1.weight',	(512, 1024, 1, 1),
 'block3.5.conv_bn1.bn.weight',	(512,),	 'layer3.5.bn1.weight',	(512,),
 'block3.5.conv_bn1.bn.bias',	(512,),	 'layer3.5.bn1.bias',	(512,),
 'block3.5.conv_bn1.bn.running_mean',	(512,),	 'layer3.5.bn1.running_mean',	(512,),
 'block3.5.conv_bn1.bn.running_var',	(512,),	 'layer3.5.bn1.running_var',	(512,),
 'block3.5.conv_bn2.conv.weight',	(512, 16, 3, 3),	 'layer3.5.conv2.weight',	(512, 16, 3, 3),
 'block3.5.conv_bn2.bn.weight',	(512,),	 'layer3.5.bn2.weight',	(512,),
 'block3.5.conv_bn2.bn.bias',	(512,),	 'layer3.5.bn2.bias',	(512,),
 'block3.5.conv_bn2.bn.running_mean',	(512,),	 'layer3.5.bn2.running_mean',	(512,),
 'block3.5.conv_bn2.bn.running_var',	(512,),	 'layer3.5.bn2.running_var',	(512,),
 'block3.5.conv_bn3.conv.weight',	(1024, 512, 1, 1),	 'layer3.5.conv3.weight',	(1024, 512, 1, 1),
 'block3.5.conv_bn3.bn.weight',	(1024,),	 'layer3.5.bn3.weight',	(1024,),
 'block3.5.conv_bn3.bn.bias',	(1024,),	 'layer3.5.bn3.bias',	(1024,),
 'block3.5.conv_bn3.bn.running_mean',	(1024,),	 'layer3.5.bn3.running_mean',	(1024,),
 'block3.5.conv_bn3.bn.running_var',	(1024,),	 'layer3.5.bn3.running_var',	(1024,),
 'block3.5.scale.fc1.weight',	(64, 1024, 1, 1),	 'layer3.5.se_module.fc1.weight',	(64, 1024, 1, 1),
 'block3.5.scale.fc1.bias',	(64,),	 'layer3.5.se_module.fc1.bias',	(64,),
 'block3.5.scale.fc2.weight',	(1024, 64, 1, 1),	 'layer3.5.se_module.fc2.weight',	(1024, 64, 1, 1),
 'block3.5.scale.fc2.bias',	(1024,),	 'layer3.5.se_module.fc2.bias',	(1024,),
 'block4.0.conv_bn1.conv.weight',	(1024, 1024, 1, 1),	 'layer4.0.conv1.weight',	(1024, 1024, 1, 1),
 'block4.0.conv_bn1.bn.weight',	(1024,),	 'layer4.0.bn1.weight',	(1024,),
 'block4.0.conv_bn1.bn.bias',	(1024,),	 'layer4.0.bn1.bias',	(1024,),
 'block4.0.conv_bn1.bn.running_mean',	(1024,),	 'layer4.0.bn1.running_mean',	(1024,),
 'block4.0.conv_bn1.bn.running_var',	(1024,),	 'layer4.0.bn1.running_var',	(1024,),
 'block4.0.conv_bn2.conv.weight',	(1024, 32, 3, 3),	 'layer4.0.conv2.weight',	(1024, 32, 3, 3),
 'block4.0.conv_bn2.bn.weight',	(1024,),	 'layer4.0.bn2.weight',	(1024,),
 'block4.0.conv_bn2.bn.bias',	(1024,),	 'layer4.0.bn2.bias',	(1024,),
 'block4.0.conv_bn2.bn.running_mean',	(1024,),	 'layer4.0.bn2.running_mean',	(1024,),
 'block4.0.conv_bn2.bn.running_var',	(1024,),	 'layer4.0.bn2.running_var',	(1024,),
 'block4.0.conv_bn3.conv.weight',	(2048, 1024, 1, 1),	 'layer4.0.conv3.weight',	(2048, 1024, 1, 1),
 'block4.0.conv_bn3.bn.weight',	(2048,),	 'layer4.0.bn3.weight',	(2048,),
 'block4.0.conv_bn3.bn.bias',	(2048,),	 'layer4.0.bn3.bias',	(2048,),
 'block4.0.conv_bn3.bn.running_mean',	(2048,),	 'layer4.0.bn3.running_mean',	(2048,),
 'block4.0.conv_bn3.bn.running_var',	(2048,),	 'layer4.0.bn3.running_var',	(2048,),
 'block4.0.scale.fc1.weight',	(128, 2048, 1, 1),	 'layer4.0.se_module.fc1.weight',	(128, 2048, 1, 1),
 'block4.0.scale.fc1.bias',	(128,),	 'layer4.0.se_module.fc1.bias',	(128,),
 'block4.0.scale.fc2.weight',	(2048, 128, 1, 1),	 'layer4.0.se_module.fc2.weight',	(2048, 128, 1, 1),
 'block4.0.scale.fc2.bias',	(2048,),	 'layer4.0.se_module.fc2.bias',	(2048,),
 'block4.0.shortcut.conv.weight',	(2048, 1024, 1, 1),	 'layer4.0.downsample.0.weight',	(2048, 1024, 1, 1),
 'block4.0.shortcut.bn.weight',	(2048,),	 'layer4.0.downsample.1.weight',	(2048,),
 'block4.0.shortcut.bn.bias',	(2048,),	 'layer4.0.downsample.1.bias',	(2048,),
 'block4.0.shortcut.bn.running_mean',	(2048,),	 'layer4.0.downsample.1.running_mean',	(2048,),
 'block4.0.shortcut.bn.running_var',	(2048,),	 'layer4.0.downsample.1.running_var',	(2048,),
 'block4.1.conv_bn1.conv.weight',	(1024, 2048, 1, 1),	 'layer4.1.conv1.weight',	(1024, 2048, 1, 1),
 'block4.1.conv_bn1.bn.weight',	(1024,),	 'layer4.1.bn1.weight',	(1024,),
 'block4.1.conv_bn1.bn.bias',	(1024,),	 'layer4.1.bn1.bias',	(1024,),
 'block4.1.conv_bn1.bn.running_mean',	(1024,),	 'layer4.1.bn1.running_mean',	(1024,),
 'block4.1.conv_bn1.bn.running_var',	(1024,),	 'layer4.1.bn1.running_var',	(1024,),
 'block4.1.conv_bn2.conv.weight',	(1024, 32, 3, 3),	 'layer4.1.conv2.weight',	(1024, 32, 3, 3),
 'block4.1.conv_bn2.bn.weight',	(1024,),	 'layer4.1.bn2.weight',	(1024,),
 'block4.1.conv_bn2.bn.bias',	(1024,),	 'layer4.1.bn2.bias',	(1024,),
 'block4.1.conv_bn2.bn.running_mean',	(1024,),	 'layer4.1.bn2.running_mean',	(1024,),
 'block4.1.conv_bn2.bn.running_var',	(1024,),	 'layer4.1.bn2.running_var',	(1024,),
 'block4.1.conv_bn3.conv.weight',	(2048, 1024, 1, 1),	 'layer4.1.conv3.weight',	(2048, 1024, 1, 1),
 'block4.1.conv_bn3.bn.weight',	(2048,),	 'layer4.1.bn3.weight',	(2048,),
 'block4.1.conv_bn3.bn.bias',	(2048,),	 'layer4.1.bn3.bias',	(2048,),
 'block4.1.conv_bn3.bn.running_mean',	(2048,),	 'layer4.1.bn3.running_mean',	(2048,),
 'block4.1.conv_bn3.bn.running_var',	(2048,),	 'layer4.1.bn3.running_var',	(2048,),
 'block4.1.scale.fc1.weight',	(128, 2048, 1, 1),	 'layer4.1.se_module.fc1.weight',	(128, 2048, 1, 1),
 'block4.1.scale.fc1.bias',	(128,),	 'layer4.1.se_module.fc1.bias',	(128,),
 'block4.1.scale.fc2.weight',	(2048, 128, 1, 1),	 'layer4.1.se_module.fc2.weight',	(2048, 128, 1, 1),
 'block4.1.scale.fc2.bias',	(2048,),	 'layer4.1.se_module.fc2.bias',	(2048,),
 'block4.2.conv_bn1.conv.weight',	(1024, 2048, 1, 1),	 'layer4.2.conv1.weight',	(1024, 2048, 1, 1),
 'block4.2.conv_bn1.bn.weight',	(1024,),	 'layer4.2.bn1.weight',	(1024,),
 'block4.2.conv_bn1.bn.bias',	(1024,),	 'layer4.2.bn1.bias',	(1024,),
 'block4.2.conv_bn1.bn.running_mean',	(1024,),	 'layer4.2.bn1.running_mean',	(1024,),
 'block4.2.conv_bn1.bn.running_var',	(1024,),	 'layer4.2.bn1.running_var',	(1024,),
 'block4.2.conv_bn2.conv.weight',	(1024, 32, 3, 3),	 'layer4.2.conv2.weight',	(1024, 32, 3, 3),
 'block4.2.conv_bn2.bn.weight',	(1024,),	 'layer4.2.bn2.weight',	(1024,),
 'block4.2.conv_bn2.bn.bias',	(1024,),	 'layer4.2.bn2.bias',	(1024,),
 'block4.2.conv_bn2.bn.running_mean',	(1024,),	 'layer4.2.bn2.running_mean',	(1024,),
 'block4.2.conv_bn2.bn.running_var',	(1024,),	 'layer4.2.bn2.running_var',	(1024,),
 'block4.2.conv_bn3.conv.weight',	(2048, 1024, 1, 1),	 'layer4.2.conv3.weight',	(2048, 1024, 1, 1),
 'block4.2.conv_bn3.bn.weight',	(2048,),	 'layer4.2.bn3.weight',	(2048,),
 'block4.2.conv_bn3.bn.bias',	(2048,),	 'layer4.2.bn3.bias',	(2048,),
 'block4.2.conv_bn3.bn.running_mean',	(2048,),	 'layer4.2.bn3.running_mean',	(2048,),
 'block4.2.conv_bn3.bn.running_var',	(2048,),	 'layer4.2.bn3.running_var',	(2048,),
 'block4.2.scale.fc1.weight',	(128, 2048, 1, 1),	 'layer4.2.se_module.fc1.weight',	(128, 2048, 1, 1),
 'block4.2.scale.fc1.bias',	(128,),	 'layer4.2.se_module.fc1.bias',	(128,),
 'block4.2.scale.fc2.weight',	(2048, 128, 1, 1),	 'layer4.2.se_module.fc2.weight',	(2048, 128, 1, 1),
 'block4.2.scale.fc2.bias',	(2048,),	 'layer4.2.se_module.fc2.bias',	(2048,),
 'logit.weight',	(1000, 1280),	 'last_linear.weight',	(1000, 2048),
 'logit.bias',	(1000,),	 'last_linear.bias',	(1000,),
]

PRETRAIN_FILE = 'data/pretrained_models/se_resnext50_32x4d-a260b3a4.pth'


def load_pretrain(net, skip=(), pretrain_file=PRETRAIN_FILE, conversion=None, is_print=True):
    """Copy pretrained weights into ``net`` by following a key-conversion table.

    Args:
        net: module whose ``state_dict`` is updated in place through a strict
            ``net.load_state_dict`` call.
        skip: substring or iterable of substrings; every ``net`` key containing
            one of them keeps its current value. ``'.num_batches_tracked'`` is
            always skipped. ``None`` means "skip nothing extra".
        pretrain_file: path of the pretrained ``state_dict``. It is loaded onto
            the CPU, whatever device it was saved from.
        conversion: flat sequence
            ``[net_key, net_shape, pretrain_key, pretrain_shape, ...]``. Only the
            two key columns are used. Defaults to the module-level
            ``CONVERSION`` table.
        is_print: when true, print one line per copied tensor.

    Raises:
        ValueError: if ``len(conversion)`` is not a multiple of 4.
    """
    if conversion is None:
        conversion = CONVERSION  # resolved at call time, not at def time
    if skip is None:
        skip = ()
    elif isinstance(skip, str):
        skip = (skip,)

    # Each row is (net_key, net_shape, pretrain_key, pretrain_shape). Chunk the
    # flat list in plain Python. np.array() on this ragged mix of strings and
    # shape tuples is rejected by NumPy >= 1.24.
    table = list(conversion)
    if len(table) % 4 != 0:
        raise ValueError(
            'conversion must hold 4 entries per row, got %d entries' % len(table))
    skip_patterns = ['.num_batches_tracked', *skip]

    print('\tload pretrain_file: %s'%pretrain_file)

    pretrain_state_dict = torch.load(pretrain_file, map_location=lambda storage, loc: storage)
    state_dict = net.state_dict()

    loaded = 0
    for row in range(0, len(table), 4):
        key, _, pretrain_key, _ = table[row:row + 4]
        if any(pattern in key for pattern in skip_patterns):
            continue

        if is_print:
            print('\t\t','%-48s  %-24s  <---  %-32s  %-24s'%(
                key, str(state_dict[key].shape),
                pretrain_key, str(pretrain_state_dict[pretrain_key].shape),
            ))
        state_dict[key] = pretrain_state_dict[pretrain_key]
        loaded += 1

    net.load_state_dict(state_dict)
    print('')
    print('len(pretrain_state_dict.keys()) = %d'%len(pretrain_state_dict.keys()))
    print('len(state_dict.keys())          = %d'%len(state_dict.keys()))
    print('loaded    = %d'%loaded)
    print('')

### create split

In [9]:
#image_file =  glob.glob('data/train_pseudo_images_confident/*.jpg') #train_images
#image_file = ['train_pseudo_images_confident/'+i.split('/')[-1] for i in image_file] #train_images
#random.shuffle(image_file)
#
##12568
#num_valid = 1000
#num_all   = len(image_file)
#num_train = num_all-num_valid
#
#train=np.array(image_file[num_valid:])
#valid=np.array(image_file[:num_valid])
#
#print(len(image_file))
#print(len(train))
#print(len(valid))
#print(len(valid)/len(train))
#
#np.save('data/split/train0_%d.npy'%len(train),train)
#np.save('data/split/valid0_%d.npy'%len(valid),valid)
#
#print('train0_%d.npy'%len(train))
#print('valid0_%d.npy'%len(valid))

In [10]:
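# stats printed by the split cell above for train_pseudo_images_confident:
#14045                 total images
#13045                 train images
#1000                  valid images
#0.07665772326561901   valid/train ratio
#train0_13045.npy      train split file
#valid0_1000.npy       valid split file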

### define training loop

In [11]:
def run_train():
    """Fine-tune ``Net`` on the pseudo-labelled training split.

    Steps:
      1. Build the train/valid ``SteelDataset`` loaders from the split files.
      2. Restore model weights from ``initial_checkpoint`` (non-strict). If a
         matching ``*_optimizer.pth`` exists, also restore the optimizer state
         and the iteration/epoch counters. Without a checkpoint, fall back to
         ``Net.load_pretrain``.
      3. Train with Adam under a ``OneCycleLR`` schedule until exactly
         ``num_iters`` iterations. Gradients are accumulated over
         ``iter_accum`` iterations per optimizer step.

    Side effects: prints progress, runs ``do_valid`` every ``iter_valid``
    iterations, and writes ``%08d_model.pth`` / ``%08d_optimizer.pth`` into
    ``data/classification_models/se-resnext50`` at every iteration in
    ``iter_save``. Requires CUDA.

    Raises:
        ValueError: if the training dataset is smaller than one batch.
        RuntimeError: if ``train_loader`` yields no batches, which would
            otherwise loop forever.
    """
    batch_size = 6
    checkpoint_dir = 'data/classification_models/se-resnext50'
    initial_checkpoint = 'data/classification_models/se-resnext50/00049999_model.pth'

    # --- data -------------------------------------------------------------
    train_dataset = SteelDataset(
        mode='train',
        csv=['train_pseudolabel_segmentation.csv'],
        split=['train0_13045.npy'],
        augment=train_augment,
    )
    train_loader = DataLoader(
        train_dataset,
        # other samplers tried: BalanceClassSampler, SequentialSampler, RandomSampler
        sampler=FiveBalanceClassSampler(train_dataset),
        batch_size=batch_size,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=null_collate,
    )

    valid_dataset = SteelDataset(
        mode='train',
        csv=['train_pseudolabel_segmentation.csv'],
        split=['valid0_1000.npy'],
        augment=valid_augment,
    )
    valid_loader = DataLoader(
        valid_dataset,
        sampler=SequentialSampler(valid_dataset),
        batch_size=4,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=null_collate,
    )

    if len(train_dataset) < batch_size:
        raise ValueError(
            'train_dataset has %d samples, fewer than batch_size=%d'
            % (len(train_dataset), batch_size))

    # --- model ------------------------------------------------------------
    net = Net().cuda()
    if initial_checkpoint is not None:
        state_dict = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)
        # strict=False: keys missing from, or extra in, the checkpoint are ignored.
        net.load_state_dict(state_dict, strict=False)
    else:
        net.load_pretrain(skip=['logit'], is_print=False)

    # --- schedule ---------------------------------------------------------
    num_iters = 60 * 1000
    iter_smooth = 50    # how often the running train-loss average is refreshed
    iter_log = 500
    iter_valid = 500
    iter_save = {num_iters - 1, *range(0, num_iters, 1000)}
    iter_accum = 8      # iterations whose gradients are summed per optimizer step

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()))

    start_iter = 0
    start_epoch = 0
    if initial_checkpoint is not None:
        initial_optimizer = initial_checkpoint.replace('_model.pth', '_optimizer.pth')
        if os.path.exists(initial_optimizer):
            # Load on the CPU; Optimizer.load_state_dict casts the state to each
            # parameter's device.
            checkpoint = torch.load(initial_optimizer, map_location=lambda storage, loc: storage)
            start_iter = checkpoint['iter']
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])

    max_lr = 0.0001
    scheduler = OneCycleLR(optimizer, max_lr=max_lr, div_factor=25, pct_start=0.9, total_steps=num_iters)
    lr = scheduler.get_lr()[0]

    # --- training loop ----------------------------------------------------
    train_loss = np.zeros(20, np.float32)
    valid_loss = np.zeros(20, np.float32)
    batch_loss = np.zeros(20, np.float32)
    displays = None     # forwarded to do_valid; presumably None disables image display
    show_debug = True
    i = 0               # iterations run in this session
    finished = False

    os.makedirs(checkpoint_dir, exist_ok=True)
    start = timer()
    optimizer.zero_grad()
    while not finished:
        sum_train_loss = np.zeros(20, np.float32)
        sum_count = np.zeros(20, np.float32)
        num_batches = 0

        for inputs, truth_mask, truth_label, infor in train_loader:
            iteration = i + start_iter
            # Stop exactly at num_iters instead of running out the current epoch.
            if iteration >= num_iters:
                finished = True
                break
            num_batches += 1
            batch_size = len(infor)
            epoch = (iteration - start_iter) * batch_size / len(train_dataset) + start_epoch

            if iteration % iter_valid == 0:
                valid_loss = do_valid(net, valid_loader, displays)

            if iteration % iter_log == 0:
                print('\r', end='', flush=True)
                asterisk = '*' if iteration in iter_save else ' '
                print('%0.8f  %5.1f%s %5.1f |  %5.3f   %4.4f [%4.4f,%4.4f,%4.4f,%4.4f]  |  %5.3f   %4.4f [%4.4f,%4.4f,%4.4f,%4.4f]  | %s' % (
                    lr, iteration / 1000, asterisk, epoch,
                    *valid_loss[:6],
                    *train_loss[:6],
                    time_to_str((timer() - start))))
                print('\n')

            if iteration in iter_save:
                torch.save(net.state_dict(), '%s/%08d_model.pth' % (checkpoint_dir, iteration))
                torch.save({
                    'optimizer': optimizer.state_dict(),
                    'iter': iteration,
                    'epoch': epoch,
                }, '%s/%08d_optimizer.pth' % (checkpoint_dir, iteration))

            lr = scheduler.get_lr()[0]
            if lr < 0:
                # Treated as end of schedule: stop training entirely, not just
                # this pass over the loader.
                finished = True
                break

            net.train()
            inputs = inputs.cuda()
            truth_label = truth_label.cuda()
            truth_mask = truth_mask.cuda()

            logit = net(inputs)
            logit = logit.max(-1, True)[0]
            loss = criterion(logit, truth_label)
            tn, tp, num_neg, num_pos = metric_hit(logit, truth_label)

            # Gradient accumulation: .grad sums over successive iterations and is
            # cleared only right after each optimizer step.
            (loss / iter_accum).backward()
            if iteration % iter_accum == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step(iteration)  # iteration < num_iters is guaranteed above

            # statistics: loss, tn and the per-class tp values, weighted by counts
            batch_stats = np.array([loss.item(), tn, *tp])
            batch_counts = np.array([batch_size, num_neg, *num_pos])

            batch_loss[:6] = batch_stats
            sum_train_loss[:6] += batch_stats * batch_counts
            sum_count[:6] += batch_counts
            if iteration % iter_smooth == 0:
                train_loss = sum_train_loss / (sum_count + 1e-12)
                sum_train_loss[...] = 0
                sum_count[...] = 0

            print('\r', end='', flush=True)
            asterisk = ' '
            print('%0.8f  %5.1f%s %5.1f |  %5.3f   %4.4f [%4.4f,%4.4f,%4.4f,%4.4f]  |  %5.3f   %4.2f [%4.4f,%4.4f,%4.4f,%4.4f]  | %s' % (
                lr, iteration / 1000, asterisk, epoch,
                *valid_loss[:6],
                *batch_loss[:6],
                time_to_str((timer() - start))),
                end='', flush=True)
            i += 1

            # debug: render predictions on iterations just before each multiple of 1000
            if show_debug and any((iteration + di) % 1000 == 0 for di in range(3)):
                probability = torch.sigmoid(logit)
                image = input_to_image(inputs, IMAGE_RGB_MEAN, IMAGE_RGB_STD)

                probability_label = probability.data.cpu().numpy()
                truth_label_np = truth_label.data.cpu().numpy()
                truth_mask_np = truth_mask.data.cpu().numpy()

                for b in range(batch_size):
                    # NOTE(review): the rendered result is never shown or saved.
                    result = draw_predict_result_label(
                        image[b], truth_mask_np[b], truth_label_np[b],
                        probability_label[b], stack='vertical')

        if num_batches == 0 and not finished:
            raise RuntimeError(
                'train_loader yielded no batches; check the split file and sampler')

In [12]:
# Column header matching the per-iteration log lines, then start training.
print(
    'lr          iter   epoch |  loss    tn, [tp1,tp2,tp3,tp4]       |  loss    tn, [tp1,tp2,tp3,tp4]       | time           ',
    '--------------------------------------------------------------------------------------------------------------------\n',
    sep='\n',
)
run_train()

lr          iter   epoch |  loss    tn, [tp1,tp2,tp3,tp4]       |  loss    tn, [tp1,tp2,tp3,tp4]       | time           
--------------------------------------------------------------------------------------------------------------------

0.00004000   50.0*  23.0 |  0.116   0.9692 [0.9324,0.9048,0.9091,0.9610]  |  0.000   0.0000 [0.0000,0.0000,0.0000,0.0000]  |  0 hr 00 min


0.00009999   50.5   23.2 |  0.120   0.9692 [0.8919,0.9524,0.8923,0.9481]  |  0.125   0.9717 [0.8806,0.9839,0.7303,0.9839]  |  0 hr 07 min


0.00010000   51.0*  23.5 |  0.120   0.9716 [0.9054,0.9048,0.8828,0.9610]  |  0.132   0.9727 [0.8649,0.9048,0.8452,0.9524]  |  0 hr 14 min


0.00009925   51.5   23.7 |  0.117   0.9721 [0.9054,0.9524,0.8732,0.9610]  |  0.151   0.9672 [0.8438,0.8387,0.8132,0.9254]  |  0 hr 21 min


0.00009703   52.0*  23.9 |  0.132   0.9619 [0.8919,0.9524,0.9019,0.9481]  |  0.130   0.9747 [0.9296,0.9531,0.6915,0.9508]  |  0 hr 28 min


0.00009333   52.5   24.1 |  0.128   0.9633 [0.9189,0.9524,0.8

KeyboardInterrupt: 

## continued training

### 1st try

lr          iter   epoch |  loss    tn, [tp1,tp2,tp3,tp4]       |  loss    tn, [tp1,tp2,tp3,tp4]       | time           
--------------------------------------------------------------------------------------------------------------------

0.00004000   50.0*  23.0 |  0.115   0.9698 [0.9324,0.9048,0.9115,0.9481]  |  0.000   0.0000 [0.0000,0.0000,0.0000,0.0000]  |  0 hr 00 min


0.00003425   50.5   23.2 |  0.124   0.9686 [0.9324,0.9048,0.8780,0.9610]  |  0.136   0.9747 [0.8824,0.9365,0.7097,0.8939]  |  0 hr 08 min


0.00003275   51.0*  23.5 |  0.113   0.9716 [0.9189,0.9524,0.8995,0.9740]  |  0.108   0.9770 [0.9155,0.9048,0.7865,1.0000]  |  0 hr 15 min


0.00003125   51.5   23.7 |  0.116   0.9692 [0.9189,0.9524,0.9187,0.9740]  |  0.104   0.9847 [0.8611,0.9355,0.7529,0.9531]  |  0 hr 22 min


0.00002978   52.0*  23.9 |  0.115   0.9713 [0.9189,0.9524,0.8852,0.9740]  |  0.117   0.9792 [0.9000,0.9365,0.7473,0.9531]  |  0 hr 29 min


0.00002831   52.5   24.1 |  0.118   0.9701 [0.9189,0.9524,0.9091,0.9740]  |  0.176   0.9637 [0.8406,0.8906,0.6842,0.9355]  |  0 hr 36 min


0.00002689   53.0*  24.4 |  0.123   0.9648 [0.8919,0.9524,0.9187,0.9740]  |  0.117   0.9804 [0.8923,0.9180,0.7419,0.9531]  |  0 hr 42 min

### 2nd try

lr          iter   epoch |  loss    tn, [tp1,tp2,tp3,tp4]       |  loss    tn, [tp1,tp2,tp3,tp4]       | time           
--------------------------------------------------------------------------------------------------------------------

0.00004000   50.0*  23.0 |  0.116   0.9692 [0.9324,0.9048,0.9091,0.9610]  |  0.000   0.0000 [0.0000,0.0000,0.0000,0.0000]  |  0 hr 00 min


0.00009999   50.5   23.2 |  0.120   0.9692 [0.8919,0.9524,0.8923,0.9481]  |  0.125   0.9717 [0.8806,0.9839,0.7303,0.9839]  |  0 hr 07 min


0.00010000   51.0*  23.5 |  0.120   0.9716 [0.9054,0.9048,0.8828,0.9610]  |  0.132   0.9727 [0.8649,0.9048,0.8452,0.9524]  |  0 hr 14 min


0.00009925   51.5   23.7 |  0.117   0.9721 [0.9054,0.9524,0.8732,0.9610]  |  0.151   0.9672 [0.8438,0.8387,0.8132,0.9254]  |  0 hr 21 min


0.00009703   52.0*  23.9 |  0.132   0.9619 [0.8919,0.9524,0.9019,0.9481]  |  0.130   0.9747 [0.9296,0.9531,0.6915,0.9508]  |  0 hr 28 min


0.00009333   52.5   24.1 |  0.128   0.9633 [0.9189,0.9524,0.8900,0.9481]  |  0.118   0.9672 [0.8939,0.9048,0.7849,1.0000]  |  0 hr 35 min


0.00008838   53.0*  24.4 |  0.120   0.9718 [0.8784,0.9524,0.8589,0.9351]  |  0.114   0.9738 [0.9000,0.9839,0.7865,0.9375]  |  0 hr 42 min


0.00008218   53.5   24.6 |  0.117   0.9710 [0.8784,0.9524,0.8828,0.9481]  |  0.140   0.9706 [0.8529,0.9355,0.7111,0.9839]  |  0 hr 49 min


0.00007511   54.0*  24.8 |  0.123   0.9692 [0.9189,0.9524,0.8947,0.9351]  |  0.110   0.9794 [0.9692,0.9206,0.7556,0.9667]  |  0 hr 55 min


0.00006715   54.5   25.1 |  0.109   0.9730 [0.8919,0.9524,0.9091,0.9481]  |  0.126   0.9693 [0.8824,0.8833,0.7895,0.9385]  |  1 hr 02 min


0.00005880   55.0*  25.3 |  0.108   0.9721 [0.9189,0.9524,0.8612,0.9481]  |  0.115   0.9762 [0.8906,0.8594,0.7619,0.9839]  |  1 hr 09 min


0.00005005   55.5   25.5 |  0.111   0.9716 [0.9595,0.9524,0.8660,0.9610]  |  0.126   0.9792 [0.9254,0.8939,0.7340,0.9500]  |  1 hr 16 min


0.00004144   56.0*  25.8 |  0.112   0.9707 [0.9459,0.9524,0.8684,0.9740]  |  0.126   0.9762 [0.8906,0.9167,0.7444,0.9672]  |  1 hr 23 min


0.00003295   56.5   26.0 |  0.120   0.9654 [0.9865,0.9524,0.8708,0.9740]  |  0.140   0.9749 [0.8507,0.8906,0.7253,0.9672]  |  1 hr 30 min


0.00002511   57.0*  26.2 |  0.117   0.9674 [0.9595,0.9524,0.8612,0.9740]  |  0.109   0.9825 [0.8841,0.9683,0.7416,1.0000]  |  1 hr 37 min


0.00001790   57.5   26.4 |  0.112   0.9669 [0.9459,0.9524,0.8804,0.9610]  |  0.137   0.9705 [0.8611,0.9516,0.7303,0.9683]  |  1 hr 44 min


0.00001178   58.0*  26.7 |  0.106   0.9721 [0.9459,0.9524,0.8565,0.9740]  |  0.102   0.9815 [0.9265,0.9508,0.8140,0.9531]  |  1 hr 50 min


0.00000673   58.5   26.9 |  0.110   0.9713 [0.9459,0.9524,0.8780,0.9740]  |  0.143   0.9746 [0.8472,0.9365,0.6947,0.9692]  |  1 hr 57 min


0.00000306   59.0*  27.1 |  0.113   0.9677 [0.9459,0.9524,0.8971,0.9610]  |  0.108   0.9696 [0.9231,0.9683,0.7159,1.0000]  |  2 hr 04 min

## best so far

0.00000810   48.0*  22.1 |  0.108   0.9730 [0.9054,0.9048,0.8876,0.9610]  |  0.144   0.9738 [0.8767,0.9153,0.7326,0.9701]  | 10 hr 45 min


0.00000455   48.5   22.3 |  0.122   0.9657 [0.9054,0.9524,0.9378,0.9481]  |  0.122   0.9814 [0.9420,0.8906,0.7391,0.9365]  | 10 hr 52 min


0.00000205   49.0*  22.5 |  0.116   0.9704 [0.9459,0.9048,0.8852,0.9610]  |  0.119   0.9716 [0.8732,0.9394,0.7674,0.9683]  | 10 hr 58 min


0.00000051   49.5   22.8 |  0.104   0.9771 [0.8919,0.9048,0.8636,0.9481]  |  0.117   0.9693 [0.8841,0.9077,0.7753,0.9545]  | 11 hr 05 min


0.00000000   50.0   23.0 |  0.113   0.9718 [0.9189,0.9524,0.9091,0.9481]  |  0.127   0.9661 [0.8382,0.9077,0.7444,0.9524]  | 11 hr 12 min

lr          iter   epoch |  loss    tn, [tp1,tp2,tp3,tp4]       |  loss    tn, [tp1,tp2,tp3,tp4]       | time           
--------------------------------------------------------------------------------------------------------------------

	load pretrain_file: data/pretrained_models/se_resnext50_32x4d-a260b3a4.pth

len(pretrain_state_dict.keys()) = 331
len(state_dict.keys())          = 386
loaded    = 329

0.00004000    0.0*   0.0 |  2.353   0.1147 [0.9459,1.0000,0.9163,0.9740]  |  0.000   0.0000 [0.0000,0.0000,0.0000,0.0000]  |  0 hr 00 min


0.00004259    0.5    0.2 |  0.399   0.9962 [0.0000,0.0000,0.0670,0.0000]  |  0.512   0.9967 [0.0000,0.0164,0.0000,0.0000]  |  0 hr 07 min


0.00005032    1.0*   0.5 |  0.315   0.9648 [0.0000,0.7143,0.4450,0.8312]  |  0.413   0.9499 [0.0000,0.4098,0.3864,0.6562]  |  0 hr 13 min


0.00006337    1.5    0.7 |  0.303   0.9214 [0.2568,0.9524,0.4498,0.8831]  |  0.348   0.9561 [0.0800,0.7937,0.3933,0.8033]  |  0 hr 20 min


0.00008118    2.0*   0.9 |  0.340   0.9302 [0.3649,0.8095,0.6268,0.7013]  |  0.351   0.9835 [0.2222,0.6935,0.3711,0.6613]  |  0 hr 27 min


0.00010412    2.5    1.1 |  0.247   0.9616 [0.2297,0.9048,0.5239,0.9351]  |  0.319   0.9613 [0.2703,0.8387,0.3711,0.8413]  |  0 hr 34 min


0.00013121    3.0*   1.4 |  0.303   0.9044 [0.7568,1.0000,0.6053,0.8701]  |  0.279   0.9541 [0.3514,0.7705,0.6000,0.8750]  |  0 hr 40 min


0.00016304    3.5    1.6 |  0.282   0.9044 [0.8784,0.5714,0.5718,0.8831]  |  0.315   0.9394 [0.6000,0.7419,0.4762,0.7077]  |  0 hr 47 min


0.00019824    4.0*   1.8 |  0.308   0.9381 [0.7703,0.5238,0.4689,0.9740]  |  0.281   0.9619 [0.2500,0.8889,0.5909,0.9194]  |  0 hr 54 min


0.00023756    4.5    2.1 |  0.374   0.8792 [0.8514,1.0000,0.5215,0.8961]  |  0.276   0.9543 [0.5672,0.8226,0.5172,0.9062]  |  1 hr 01 min


0.00027933    5.0*   2.3 |  0.306   0.8985 [0.8649,0.4286,0.7081,0.8701]  |  0.301   0.9429 [0.6081,0.8548,0.3407,0.8095]  |  1 hr 07 min


0.00032443    5.5    2.5 |  0.373   0.9182 [0.5541,0.4762,0.5718,0.8961]  |  0.322   0.9591 [0.2500,0.6032,0.4699,0.7742]  |  1 hr 14 min


0.00037095    6.0*   2.8 |  0.245   0.9185 [0.7703,0.9048,0.7871,0.8701]  |  0.288   0.9441 [0.6389,0.7213,0.4783,0.8730]  |  1 hr 21 min


0.00041985    6.5    3.0 |  0.227   0.9733 [0.3243,0.9048,0.6914,0.6234]  |  0.312   0.9467 [0.4638,0.7302,0.5402,0.7869]  |  1 hr 28 min


0.00046907    7.0*   3.2 |  0.244   0.9566 [0.5000,0.6190,0.5502,0.9221]  |  0.311   0.9464 [0.4000,0.7258,0.6444,0.7143]  |  1 hr 34 min


0.00051965    7.5    3.4 |  0.324   0.9000 [0.6892,0.9524,0.5646,0.9221]  |  0.302   0.9749 [0.3016,0.8788,0.3478,0.8281]  |  1 hr 41 min


0.00056943    8.0*   3.7 |  0.323   0.9352 [0.4459,0.9048,0.4115,0.8052]  |  0.312   0.9384 [0.4848,0.8871,0.4643,0.7302]  |  1 hr 48 min


0.00061946    8.5    3.9 |  0.320   0.9179 [0.7432,0.8095,0.4019,0.9351]  |  0.293   0.9456 [0.5507,0.7500,0.5393,0.8254]  |  1 hr 55 min


0.00066762    9.0*   4.1 |  0.332   0.9188 [0.8243,0.8095,0.4761,0.6494]  |  0.310   0.9518 [0.5972,0.8197,0.5165,0.8889]  |  2 hr 01 min


0.00071492    9.5    4.4 |  0.355   0.8938 [0.9054,0.3333,0.4067,0.7273]  |  0.339   0.9347 [0.5000,0.7460,0.4200,0.7576]  |  2 hr 08 min


0.00075936   10.0*   4.6 |  0.304   0.9065 [0.7432,0.7143,0.6579,0.6234]  |  0.343   0.9069 [0.5147,0.8065,0.4118,0.9016]  |  2 hr 15 min


0.00080187   10.5    4.8 |  0.378   0.9067 [0.7162,0.2381,0.7153,0.8961]  |  0.334   0.9454 [0.6176,0.7742,0.4176,0.7031]  |  2 hr 22 min


0.00084064   11.0*   5.1 |  0.288   0.9370 [0.5811,0.9524,0.5670,0.8442]  |  0.331   0.9471 [0.4932,0.8095,0.3483,0.8209]  |  2 hr 28 min


0.00087649   11.5    5.3 |  0.437   0.8724 [0.8243,0.5714,0.5718,0.3896]  |  0.362   0.9276 [0.5205,0.6885,0.3913,0.8889]  |  2 hr 35 min


0.00090790   12.0*   5.5 |  0.258   0.9367 [0.5135,0.6190,0.4785,0.9351]  |  0.339   0.9575 [0.5147,0.7500,0.3605,0.6719]  |  2 hr 42 min


0.00093553   12.5    5.7 |  0.239   0.9208 [0.8243,0.1905,0.7033,0.9221]  |  0.312   0.9573 [0.4154,0.7188,0.3587,0.9231]  |  2 hr 49 min


0.00095821   13.0*   6.0 |  0.301   0.9038 [0.7973,0.5714,0.6316,0.8831]  |  0.323   0.9406 [0.5493,0.5672,0.3407,0.9355]  |  2 hr 56 min


0.00097641   13.5    6.2 |  0.315   0.9170 [0.7568,0.9048,0.6435,0.8831]  |  0.282   0.9454 [0.5694,0.6667,0.5275,0.8065]  |  3 hr 02 min


0.00098936   14.0*   6.4 |  0.341   0.8933 [0.8108,0.6667,0.6220,0.8831]  |  0.275   0.9650 [0.5797,0.7000,0.4842,0.9048]  |  3 hr 09 min


0.00099734   14.5    6.7 |  0.302   0.9106 [0.8514,0.9048,0.3900,0.5584]  |  0.306   0.9491 [0.1571,0.7619,0.5309,0.7903]  |  3 hr 16 min


0.00100000   15.0*   6.9 |  0.377   0.9073 [0.7432,0.9524,0.3469,0.8182]  |  0.324   0.9500 [0.6812,0.7258,0.2727,0.7869]  |  3 hr 23 min


0.00099950   15.5    7.1 |  0.277   0.9352 [0.8243,0.8571,0.4641,0.8312]  |  0.329   0.9532 [0.5735,0.6129,0.2299,0.8000]  |  3 hr 30 min


0.00099802   16.0*   7.4 |  0.226   0.9422 [0.7703,0.8095,0.5861,0.9351]  |  0.265   0.9516 [0.6197,0.7538,0.5333,0.9231]  |  3 hr 36 min


0.00099549   16.5    7.6 |  0.276   0.9120 [0.7297,0.8095,0.6435,0.7662]  |  0.300   0.9530 [0.6667,0.6508,0.4574,0.7143]  |  3 hr 43 min


0.00099202   17.0*   7.8 |  0.266   0.9032 [0.7838,0.8571,0.8014,0.6753]  |  0.310   0.9570 [0.6769,0.7143,0.3400,0.8182]  |  3 hr 50 min


0.00098749   17.5    8.0 |  0.270   0.9073 [0.7838,0.9524,0.6220,0.8182]  |  0.326   0.9595 [0.4242,0.7302,0.2717,0.8462]  |  3 hr 57 min


0.00098206   18.0*   8.3 |  0.227   0.9282 [0.8514,0.9048,0.7273,0.7662]  |  0.260   0.9568 [0.6447,0.7619,0.6737,0.7656]  |  4 hr 04 min


0.00097557   18.5    8.5 |  0.273   0.9416 [0.7432,0.9524,0.5072,0.7273]  |  0.228   0.9397 [0.6970,0.8438,0.7317,0.8833]  |  4 hr 10 min


0.00096823   19.0*   8.7 |  0.241   0.9334 [0.8784,0.7619,0.7416,0.7143]  |  0.270   0.9574 [0.6269,0.6923,0.5889,0.8095]  |  4 hr 17 min


0.00095982   19.5    9.0 |  0.209   0.9405 [0.8243,0.9524,0.6220,0.8442]  |  0.260   0.9596 [0.5694,0.8033,0.4659,0.8906]  |  4 hr 24 min


0.00095062   20.0*   9.2 |  0.209   0.9405 [0.9054,0.9524,0.6938,0.8961]  |  0.273   0.9451 [0.6438,0.8197,0.5889,0.8636]  |  4 hr 31 min


0.00094036   20.5    9.4 |  0.318   0.8962 [0.9054,1.0000,0.5191,0.8571]  |  0.273   0.9597 [0.4225,0.8730,0.4535,0.8730]  |  4 hr 38 min


0.00092939   21.0*   9.7 |  0.242   0.9243 [0.9054,0.9048,0.8110,0.9351]  |  0.259   0.9470 [0.6389,0.9032,0.5102,0.8387]  |  4 hr 44 min


0.00091736   21.5    9.9 |  0.196   0.9613 [0.7703,0.8095,0.6986,0.7532]  |  0.236   0.9728 [0.6479,0.8571,0.4884,0.9167]  |  4 hr 51 min


0.00090469   22.0*  10.1 |  0.232   0.9484 [0.8378,0.6667,0.6268,0.6883]  |  0.274   0.9364 [0.5970,0.7333,0.6531,0.8889]  |  4 hr 58 min


0.00089100   22.5   10.3 |  0.205   0.9434 [0.7838,0.9048,0.7273,0.9481]  |  0.223   0.9508 [0.8209,0.9167,0.6170,0.8769]  |  5 hr 05 min


0.00087674   23.0*  10.6 |  0.188   0.9630 [0.7162,0.9524,0.6077,0.8701]  |  0.221   0.9587 [0.6761,0.8833,0.6023,0.9016]  |  5 hr 11 min


0.00086149   23.5   10.8 |  0.200   0.9569 [0.7973,0.9048,0.6316,0.8312]  |  0.238   0.9462 [0.6522,0.8689,0.6907,0.8730]  |  5 hr 18 min


0.00084576   24.0*  11.0 |  0.206   0.9490 [0.6757,0.9524,0.6794,0.8961]  |  0.259   0.9466 [0.5303,0.8871,0.6477,0.8485]  |  5 hr 25 min


0.00082907   24.5   11.3 |  0.186   0.9446 [0.8649,0.9524,0.8397,0.8701]  |  0.269   0.9515 [0.6957,0.6190,0.6022,0.8088]  |  5 hr 32 min


0.00081199   25.0*  11.5 |  0.234   0.9419 [0.8919,0.9524,0.6077,0.8442]  |  0.236   0.9362 [0.7183,0.9180,0.6044,0.8676]  |  5 hr 39 min


0.00079400   25.5   11.7 |  0.253   0.9126 [0.8919,0.9524,0.7153,0.9091]  |  0.227   0.9661 [0.7101,0.8254,0.5955,0.8438]  |  5 hr 45 min


0.00077571   26.0*  12.0 |  0.172   0.9569 [0.8243,0.9524,0.7871,0.9351]  |  0.224   0.9498 [0.6714,0.8254,0.6279,0.8594]  |  5 hr 52 min


0.00075657   26.5   12.2 |  0.201   0.9437 [0.8108,0.8571,0.8062,0.7662]  |  0.256   0.9780 [0.7042,0.8438,0.5000,0.7031]  |  5 hr 59 min


0.00073721   27.0*  12.4 |  0.193   0.9469 [0.8243,0.9048,0.7656,0.8961]  |  0.213   0.9675 [0.6957,0.8689,0.5904,0.9375]  |  6 hr 06 min


0.00071706   27.5   12.6 |  0.194   0.9595 [0.7838,0.9524,0.7416,0.8312]  |  0.266   0.9397 [0.6438,0.9167,0.6915,0.8361]  |  6 hr 12 min


0.00069680   28.0*  12.9 |  0.203   0.9563 [0.7432,0.9524,0.6842,0.8312]  |  0.237   0.9518 [0.6761,0.8413,0.7000,0.8254]  |  6 hr 19 min


0.00067581   28.5   13.1 |  0.217   0.9434 [0.8784,0.9524,0.7105,0.9221]  |  0.246   0.9410 [0.7463,0.7742,0.6087,0.9206]  |  6 hr 26 min


0.00065481   29.0*  13.3 |  0.180   0.9610 [0.8784,0.9048,0.6507,0.9221]  |  0.218   0.9663 [0.6765,0.9355,0.4318,0.9355]  |  6 hr 32 min


0.00063315   29.5   13.6 |  0.231   0.9545 [0.7297,0.9524,0.6148,0.8961]  |  0.180   0.9729 [0.8923,0.9194,0.6703,0.8852]  |  6 hr 39 min


0.00061157   30.0*  13.8 |  0.167   0.9604 [0.8378,0.9524,0.7990,0.9221]  |  0.218   0.9634 [0.6866,0.9683,0.5513,0.8594]  |  6 hr 46 min


0.00058941   30.5   14.0 |  0.135   0.9730 [0.8378,0.9048,0.7799,0.9091]  |  0.194   0.9694 [0.7059,0.9524,0.5955,0.8594]  |  6 hr 52 min


0.00056743   31.0*  14.3 |  0.165   0.9630 [0.8243,0.9524,0.7823,0.8312]  |  0.205   0.9532 [0.7273,0.9206,0.5934,0.9355]  |  6 hr 59 min


0.00054496   31.5   14.5 |  0.158   0.9525 [0.7838,0.9524,0.8708,0.9221]  |  0.200   0.9571 [0.7879,0.8154,0.7320,0.9365]  |  7 hr 06 min


0.00052275   32.0*  14.7 |  0.187   0.9481 [0.9054,0.9524,0.7871,0.9610]  |  0.215   0.9540 [0.7945,0.8226,0.5843,0.9683]  |  7 hr 12 min


0.00050014   32.5   14.9 |  0.173   0.9666 [0.7838,0.9524,0.7560,0.9091]  |  0.216   0.9607 [0.6761,0.9048,0.5882,0.9219]  |  7 hr 19 min


0.00047788   33.0*  15.2 |  0.134   0.9674 [0.9189,0.9524,0.7727,0.9351]  |  0.187   0.9686 [0.9365,0.8710,0.5506,0.9365]  |  7 hr 26 min


0.00045532   33.5   15.4 |  0.153   0.9598 [0.8784,0.9524,0.8517,0.9351]  |  0.190   0.9575 [0.8000,0.9048,0.6237,0.9677]  |  7 hr 32 min


0.00043320   34.0*  15.6 |  0.160   0.9569 [0.8784,0.9524,0.8900,0.9351]  |  0.187   0.9651 [0.7727,0.8852,0.6452,0.9375]  |  7 hr 39 min


0.00041086   34.5   15.9 |  0.159   0.9630 [0.8919,0.8571,0.7895,0.9481]  |  0.175   0.9602 [0.8143,0.8333,0.7010,0.9841]  |  7 hr 46 min


0.00038905   35.0*  16.1 |  0.149   0.9587 [0.8784,0.9048,0.8014,0.9351]  |  0.171   0.9642 [0.8676,0.7937,0.6867,0.9524]  |  7 hr 52 min


0.00036711   35.5   16.3 |  0.167   0.9598 [0.8919,0.9524,0.8158,0.9481]  |  0.190   0.9681 [0.7714,0.9077,0.6250,0.9254]  |  7 hr 59 min


0.00034579   36.0*  16.6 |  0.162   0.9622 [0.8919,0.9524,0.7608,0.9740]  |  0.168   0.9683 [0.7432,0.9219,0.6512,0.9677]  |  8 hr 05 min


0.00032444   36.5   16.8 |  0.146   0.9657 [0.7973,0.8571,0.8565,0.9481]  |  0.163   0.9682 [0.8000,0.9365,0.6264,0.9531]  |  8 hr 12 min


0.00030378   37.0*  17.0 |  0.135   0.9698 [0.9189,0.8571,0.8134,0.9351]  |  0.146   0.9716 [0.7703,0.9180,0.7857,0.9531]  |  8 hr 19 min


0.00028318   37.5   17.2 |  0.150   0.9592 [0.8649,0.9524,0.8612,0.9481]  |  0.164   0.9672 [0.7973,0.8361,0.6279,1.0000]  |  8 hr 25 min


0.00026335   38.0*  17.5 |  0.136   0.9654 [0.8919,0.9524,0.8541,0.9351]  |  0.157   0.9716 [0.8676,0.9839,0.6559,0.9508]  |  8 hr 32 min


0.00024367   38.5   17.7 |  0.132   0.9625 [0.8514,0.9524,0.8876,0.9481]  |  0.130   0.9770 [0.8769,0.9219,0.7889,0.9403]  |  8 hr 39 min


0.00022482   39.0*  17.9 |  0.133   0.9607 [0.8784,0.9524,0.9067,0.8961]  |  0.155   0.9576 [0.8056,0.9032,0.8193,0.9206]  |  8 hr 45 min


0.00020622   39.5   18.2 |  0.137   0.9598 [0.9189,0.9524,0.8947,0.9221]  |  0.146   0.9726 [0.8209,0.9365,0.7174,0.9385]  |  8 hr 52 min


0.00018850   40.0*  18.4 |  0.114   0.9674 [0.8919,0.9048,0.9019,0.9610]  |  0.164   0.9603 [0.7945,0.8525,0.7234,0.9692]  |  8 hr 59 min


0.00017114   40.5   18.6 |  0.137   0.9642 [0.9189,0.9048,0.8636,0.9610]  |  0.168   0.9685 [0.7606,0.8710,0.7059,0.9836]  |  9 hr 05 min


0.00015470   41.0*  18.9 |  0.116   0.9695 [0.8784,0.9048,0.8804,0.9481]  |  0.127   0.9812 [0.9412,0.9219,0.6768,0.9683]  |  9 hr 12 min


0.00013870   41.5   19.1 |  0.123   0.9672 [0.8784,0.9048,0.8636,0.9481]  |  0.150   0.9671 [0.8462,0.8636,0.7065,0.9844]  |  9 hr 19 min


0.00012367   42.0*  19.3 |  0.143   0.9610 [0.8784,0.9524,0.8947,0.9870]  |  0.135   0.9726 [0.8551,0.9531,0.7474,0.9667]  |  9 hr 25 min


0.00010917   42.5   19.5 |  0.153   0.9584 [0.9324,0.9048,0.8828,0.9610]  |  0.141   0.9692 [0.8971,0.9231,0.6484,0.9848]  |  9 hr 32 min


0.00009568   43.0*  19.8 |  0.126   0.9674 [0.9189,0.9048,0.8995,0.9610]  |  0.159   0.9722 [0.8082,0.8923,0.7083,0.8955]  |  9 hr 39 min


0.00008279   43.5   20.0 |  0.127   0.9695 [0.9189,0.9048,0.8756,0.9481]  |  0.147   0.9650 [0.9130,0.9194,0.6304,0.9365]  |  9 hr 45 min


0.00007094   44.0*  20.2 |  0.117   0.9748 [0.9054,0.9048,0.8852,0.9351]  |  0.120   0.9761 [0.9014,0.9048,0.7407,0.9385]  |  9 hr 52 min


0.00005977   44.5   20.5 |  0.124   0.9704 [0.8919,0.9048,0.8804,0.9610]  |  0.138   0.9836 [0.8611,0.9508,0.7222,0.9219]  |  9 hr 59 min


0.00004966   45.0*  20.7 |  0.118   0.9721 [0.8784,0.9524,0.8852,0.9610]  |  0.128   0.9684 [0.8611,0.9516,0.7024,0.9375]  | 10 hr 05 min


0.00004029   45.5   20.9 |  0.124   0.9663 [0.9054,0.9524,0.9043,0.9740]  |  0.112   0.9848 [0.8824,0.9500,0.7333,0.9841]  | 10 hr 12 min


0.00003200   46.0*  21.2 |  0.121   0.9672 [0.9189,0.9048,0.8947,0.9740]  |  0.118   0.9781 [0.8714,0.8923,0.8068,0.9375]  | 10 hr 18 min


0.00002452   46.5   21.4 |  0.115   0.9710 [0.9324,0.9048,0.8971,0.9610]  |  0.138   0.9654 [0.8382,0.9180,0.7590,0.9531]  | 10 hr 25 min


0.00001811   47.0*  21.6 |  0.120   0.9718 [0.9459,0.9524,0.8612,0.9610]  |  0.149   0.9716 [0.8485,0.9194,0.6526,0.9836]  | 10 hr 32 min


0.00001257   47.5   21.8 |  0.116   0.9698 [0.9189,0.9524,0.9211,0.9481]  |  0.137   0.9738 [0.8261,0.8889,0.7753,0.9839]  | 10 hr 38 min


0.00000810   48.0*  22.1 |  0.108   0.9730 [0.9054,0.9048,0.8876,0.9610]  |  0.144   0.9738 [0.8767,0.9153,0.7326,0.9701]  | 10 hr 45 min


0.00000455   48.5   22.3 |  0.122   0.9657 [0.9054,0.9524,0.9378,0.9481]  |  0.122   0.9814 [0.9420,0.8906,0.7391,0.9365]  | 10 hr 52 min


0.00000205   49.0*  22.5 |  0.116   0.9704 [0.9459,0.9048,0.8852,0.9610]  |  0.119   0.9716 [0.8732,0.9394,0.7674,0.9683]  | 10 hr 58 min


0.00000051   49.5   22.8 |  0.104   0.9771 [0.8919,0.9048,0.8636,0.9481]  |  0.117   0.9693 [0.8841,0.9077,0.7753,0.9545]  | 11 hr 05 min


0.00000000   50.0   23.0 |  0.113   0.9718 [0.9189,0.9524,0.9091,0.9481]  |  0.127   0.9661 [0.8382,0.9077,0.7444,0.9524]  | 11 hr 12 min
