In [1]:
import torch
import torch.nn as nn
import os
from datetime import datetime
import time
import random
import cv2
import pandas as pd
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from glob import glob
from torchvision.transforms.functional import to_tensor


from tqdm import tqdm

#from utils import visualize, plot_data
from scipy.io import loadmat

from torch.nn.parallel import DataParallel


SEED = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [2]:
path = '/mnt/home/hheat/USERDIR/counting-bench/data'
train_images = path + '/images'
test_images = path + '/test_images/images'
anno = path + '/annotation'
density_maps = path + '/dmaps'

LOG_PARA = 1000

In [3]:
def get_train_transforms():
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            #A.Resize(360,640,interpolation=2),
            #A.RandomSizedCrop(min_max_height=(409, 512), height=409, width=512, p=1.0),
            #A.Cutout(num_holes=8, max_h_size=64, max_w_size=64, fill_value=0, p=1.0),
        ],
        #additional_targets={'image': 'image','image1': 'image'}
        #keypoint_params = A.KeypointParams(format='xy')
)

def get_train_image_only_transforms():
    return A.Compose(
        [
            #A.Resize(360,640),
            A.OneOf([
                A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit= 0.2, 
                                     val_shift_limit=0.2, p=0.9),
                A.RandomBrightnessContrast(brightness_limit=0.2, 
                                           contrast_limit=0.2, p=0.9),
            ],p=0.9),
            A.Blur(blur_limit=3,p=0.2),
            A.Normalize(mean=mean,std=std,p=1.0,max_pixel_value=1.0),
            ToTensorV2(p=1.0),
        ],
        additional_targets={'image': 'image'}
    )

def get_valid_trainsforms():
    return A.Compose(
        [
            #A.Resize(360,640,interpolation=2),
            A.Normalize(mean=mean,std=std,p=1.0,max_pixel_value=1.0),
            ToTensorV2(p=1.0),
        ]
    )

# def get_valid_image_only_transforms():
#     return A.Compose(
#         [
#             A.Resize(360,640),
#         ],
#         additional_targets={'image': 'image'}
#     )

mean = torch.tensor([0.4939, 0.4794, 0.4583])
std = torch.tensor([0.2177, 0.2134, 0.2144])

def denormalize(img):
    img = img * std[...,None,None] + mean[...,None,None]
    img = img.permute(1,2,0).cpu().numpy()
    return img

In [4]:
class Counting_Dataset(Dataset):
    def __init__(self,path,image_fnames,dmap_folder,gt_folder=None,transforms=None,mosaic=False,downsample=4):
        '''
            path: root path 
            image_fnames: path of images
            dmap_folder: density map folder, eg: /dmap
            gt_folder: gt folder, currently set to visdrone xml format, modify _get_gt_data() if needed
            transforms: iteratable, can be tuple / list ... etc
            mosaic: mix up image and density map to form a new image, set to false by default
            downsample: resize dmap
        '''
        super().__init__()
        self.path = path
        self.image_fnames = image_fnames
        self.dmap_folder = path + dmap_folder
        self.transforms = transforms
        self.mosaic = mosaic
        self.downsample = downsample
        self.gt_folder = gt_folder # test purpose
        
    def __len__(self):
        return len(self.image_fnames)
    
    def __getitem__(self,idx):
        image_id = self.image_fnames[idx]
        
        if self.mosaic and random.randint(0,1) < 0.5:
            image, density_map, gt_points = self._load_mosaic_image_and_density_map(idx)
        else:
            image, density_map, gt_points = self._load_image_and_density_map(idx)
        
        h,w = image.shape[0]//self.downsample, image.shape[1]//self.downsample
        image = cv2.resize(image,(w, h))
        density_map = cv2.resize(density_map,(w//(self.downsample*2),h//(self.downsample*2)))#,interpolation=cv2.INTER_NEAREST)
        
        # Warning: doesn't work for cutout, uncommet transform and make fix code to enable cutout
        # Reason: cutout doesn't apply to mask, so mask must be image. check 01a bottom for code
        if self.transforms:
            for tfms in self.transforms:
                aug = tfms(**{
                    'image': image,
                    'mask': density_map,
                    #'keypoints': gt_points
                })
                #image, density_map, gt_points = aug['image'], aug['mask'], aug['keypoints']
                image, density_map = aug['image'], aug['mask'] # issue with previous keypoints (albumentation?)
        
        
        return image, density_map, image_id, gt_points
        
    
    def _get_dmap_name(self,fn):
        mask_name = fn.split('/')[-1].split('.')[0]
        mask_path = self.dmap_folder + '/' + mask_name + '.npy'
        return mask_path
    
    def _load_image_and_density_map(self,idx):
        image_fname = self.image_fnames[idx]
        dmap_fname = self._get_dmap_name(image_fname)
        image = cv2.imread(image_fname)
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB).astype(np.float32)
        image = image/255.
        d_map = np.load(dmap_fname,allow_pickle=True)
        
        #sanity check gt
        _, points = self._get_gt_data(idx)
        # end sanity check
        
        return image, d_map, points
    
    def _load_mosaic_image_and_density_map(self,idx):
        image_1, dmap_1, points_1 = self._load_image_and_density_map(idx)
        while True:
            idx_2 = random.randint(0,len(self.image_fnames)-1)
            if idx != idx_2:
                break
        image_2, dmap_2, points_2 = self._load_image_and_density_map(idx_2)
        
        imsize = min(*image_1.shape[:2])
        xc,yc = [int(random.uniform(imsize*0.4,imsize*0.6)) for _ in range(2)]
        h,w = image_1.shape[0], image_1.shape[1]

        pos = random.randint(0,1)
        if pos == 0: #top left
            x1a,y1a,x2a,y2a = 0,0,xc,yc # img_1
            x1b,y1b,x2b,y2b = w-xc,h-yc,w,h # img_2
        elif pos == 1: # top right
            x1a,y1a,x2a,y2a = w-xc,0,w,yc
            x1b,y1b,x2b,y2b = 0,h-yc,xc,h
        elif pos == 2: # bottom left
            x1a,y1a,x2a,y2a = 0,h-yc,xc,h
            x1b,y1b,x2b,y2b = w-xc,0,w,yc
        elif pos == 3: # bottom right
            x1a,y1a,x2a,y2a = w-xc,h-yc,w,h
            x1b,y1b,x2b,y2b = 0,0,xc,yc
        
        new_image = image_1.copy()
        new_dmap = dmap_1.copy()
        new_image[y1a:y2a,x1a:x2a] = image_2[y1b:y2b,x1b:x2b]
        new_dmap[y1a:y2a,x1a:x2a] = dmap_2[y1b:y2b,x1b:x2b]
        
        #TODO: sanity check to see generate gt
        
        new_gt_points = self._get_mixed_gt_points(points_1,points_2,(x1a,y1a,x2a,y2a),(x1b,y1b,x2b,y2b),(h,w))
        
        return new_image, new_dmap, new_gt_points
    
    '''
    The follow section blocks are for sanity check 
    to compare dmap.sum() with gt points
    remove if needed
    '''
    def _get_mixed_gt_points(self,points_1,points_2,img_1_loc, img_2_loc,img_shape):
#         fn_1, points_1 = self._get_gt_data(idx_1)
#         fn_2, points_2 = self._get_gt_data(idx_2)
        x1a,y1a,x2a,y2a = img_1_loc
        x1b,y1b,x2b,y2b = img_2_loc
        h,w = img_shape
        
        result_boxes = []
        result_boxes.append(points_2)
        result_boxes = np.concatenate(result_boxes,0)
        padw = x1a-x1b
        pady = y1a-y1b

        result_boxes[:,0] += padw
        result_boxes[:,1] += pady

        np.clip(result_boxes[:,0],0,w,out=result_boxes[:,0])
        np.clip(result_boxes[:,1],0,h,out=result_boxes[:,1])
        result_boxes = result_boxes.astype(np.int32)

        result_boxes = result_boxes[np.where(result_boxes[:,0] * result_boxes[:,1] > 0)]
        result_boxes = result_boxes[np.where(result_boxes[:,0] < w)]
        result_boxes = result_boxes[np.where(result_boxes[:,1] < h)]
        
        boxes = []
        for (x,y) in points_1:
            if x >= x1a and x <= x2a and y >= y1a and y <= y2a:
                continue
            else:
                boxes.append((x,y))
        if len(boxes) == 0:
            return result_boxes
        return np.concatenate((boxes, result_boxes),axis=0)
    
    def _get_gt_data(self,idx):
        if not self.gt_folder:
            return (None,0)
        fn = self.image_fnames[idx]
        anno_path = self.path + self.gt_folder + '/' + fn.split('/')[-1].split('.')[0] + '.mat'
        test_data = loadmat(anno_path)
        points = test_data['annotation'].astype(int)
        return fn, points

In [5]:
# ADD LOG_PARA to density map

class Crop_Dataset(Counting_Dataset):
    def __init__(self,path,image_fnames,dmap_folder,gt_folder=None,transforms=None,mosaic=False,downsample=4,crop_size=512,method='train'):
        super().__init__(path,image_fnames,dmap_folder,gt_folder,transforms,mosaic,downsample)
        self.crop_size = crop_size
        if method not in ['train','valid']:
            raise Exception('Not Implement')
        self.method = method
    
    def __getitem__(self,idx):
        fn = self.image_fnames[idx]
        
        image,density_map,gt_points = self._load_image_and_density_map(idx)
        h,w = image.shape[0], image.shape[1]
        #image = cv2.resize(image,(w, h))
        
        
        if self.method == 'train':
            #h,w = image.shape[:2]
            i,j = self._random_crop(h,w,self.crop_size,self.crop_size)
            image = image[i:i+self.crop_size,j:j+self.crop_size]
            density_map = density_map[i:i+self.crop_size,j:j+self.crop_size]
            
            gt_points = gt_points - [j,i]
            mask = (gt_points[:,0] >=0 ) * (gt_points[:,0] <= self.crop_size) * (gt_points[:,1]>=0) * (gt_points[:,1]<=self.crop_size)
            gt_points = gt_points[mask]
            density_map = cv2.resize(density_map,(self.crop_size//self.downsample,self.crop_size//self.downsample))
            
        else:
            density_map = cv2.resize(density_map,(w//self.downsample,h//self.downsample))#,interpolation=cv2.INTER_NEAREST)
            #density_map = density_map[1:-1,:]
        
        gray_img = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
        masks = [density_map, gray_img]
        if self.transforms:
            for tfms in self.transforms:
                aug = tfms(**{
                    'image': image,
                    'masks': masks,
                    #'keypoints': gt_points
                })
                #image, density_map, gt_points = aug['image'], aug['mask'], aug['keypoints']
                image, masks = aug['image'], aug['masks'] # issue with previous keypoints (albumentation?)
            density_map = masks[0]
            gray_img = masks[1]
            density_map = ToTensorV2(p=1.0)(image=density_map)['image']
            gray_img = ToTensorV2(p=1.0)(image=gray_img)['image']
                
        return image, gray_img, density_map*LOG_PARA, fn, gt_points
    
    def _random_crop(self, im_h, im_w, crop_h, crop_w):
        res_h = im_h - crop_h
        res_w = im_w - crop_w
        i = random.randint(0, res_h)
        j = random.randint(0, res_w)
        return i, j

In [6]:
from random import sample

class SwitchDataset(Dataset):
    def __init__(self, base_dataset, models):
        super().__init__()
        self.base_dataset = base_dataset
        
        losses = torch.tensor([0., 0., 0.])
        self.eq_data = [list(), list(), list()]
        self.indices = []
        for idx, (_, gray_img, density_map, _, _) in enumerate(base_dataset):
            for i, model in enumerate(models):
                model = model.cuda()
                gray_img = gray_img.cuda()
                density_map = density_map.cuda()
                with torch.no_grad():
                    pred = model(gray_img.unsqueeze(0))
                losses[i] = torch.abs(torch.sum(pred)-torch.sum(density_map))
            which = torch.argmin(losses).item()
            self.eq_data[which].append(idx)

        num_per_cls = max(len(ds) for ds in self.eq_data)
        self.targs = torch.zeros(num_per_cls*3)
        self.targs[num_per_cls:num_per_cls*2] = 1
        self.targs[num_per_cls*2:num_per_cls*3] = 2

        for i, ds in enumerate(self.eq_data):
            assert len(ds) != 0
            samples = []
            samples += ds
            while len(samples) < num_per_cls:
                #print(len(samples), min(num_per_cls - len(samples), len(ds)))
                samples += sample(ds, min(num_per_cls - len(samples), len(ds)))
            self.indices += samples

        #self.targs[idx] = which
        #self.targs = self.targs.float()
        
    def __getitem__(self, idx):
        return self.base_dataset[self.indices[idx]][0], self.targs[idx]
    
    def __len__(self):
        return len(self.indices)

In [7]:
train_fp = glob(train_images + '/*.jpg')
test_fp = glob(test_images + '/*.jpg')

In [8]:
split = int(len(train_fp) * 0.8)
train_fp[0:split][:10]

['/mnt/home/hheat/USERDIR/counting-bench/data/images/11_233.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/12_240.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/08_113.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/03_319.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/06_176.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/05_105.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/11_204.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/14_253.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/14_129.jpg',
 '/mnt/home/hheat/USERDIR/counting-bench/data/images/20_248.jpg']

In [9]:
train_dataset = Crop_Dataset(path=path,
                             image_fnames=train_fp[:split],dmap_folder='/dmaps',
                             gt_folder='/annotation',
                             transforms=[get_train_transforms(),get_train_image_only_transforms()],
                             downsample=4,
                             crop_size=784
                                )

valid_dataset = Crop_Dataset(path=path,
                             image_fnames=test_fp,dmap_folder='/dmaps',
                             gt_folder='/annotation',
                             transforms=[get_valid_trainsforms()],
                             method='valid',
                             downsample=4,
                             crop_size=784
                                )

In [24]:
class TrainDifferentialConfig:
    num_workers = 8
    batch_size = 8
    n_epochs = 120 
    lr = 0.0002

    folder = 'SCNN-7.29-784'
    downsample = 4

    # -------------------
    verbose = True
    verbose_step = 1
    # -------------------

    # --------------------
    step_scheduler = True  # do scheduler.step after optimizer.step
    validation_scheduler = False  # do scheduler.step after validation stage loss

    SchedulerClass = torch.optim.lr_scheduler.OneCycleLR
    scheduler_params = dict(
        max_lr=1e-4,
        #total_steps = len(train_dataset) // 4 * n_epochs, # gradient accumulation
        epochs=n_epochs,
        steps_per_epoch=int(len(train_dataset) / batch_size),
        pct_start=0.2,
        anneal_strategy='cos', 
        final_div_factor=10**5
    )
    
#     SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
#     scheduler_params = dict(
#         mode='min',
#         factor=0.5,
#         patience=1,
#         verbose=False, 
#         threshold=0.0001,
#         threshold_mode='abs',
#         cooldown=0, 
#         min_lr=1e-8,
#         eps=1e-08
#     )

In [11]:
class TrainCoupledConfig:
    num_workers = 1
    batch_size = 1
    n_epochs = 120 
    lr = 0.0002

    folder = 'SCNN-7.29-784'
    downsample = 4

    # -------------------
    verbose = True
    verbose_step = 1
    # -------------------

    # --------------------
    step_scheduler = True  # do scheduler.step after optimizer.step
    validation_scheduler = False  # do scheduler.step after validation stage loss

    SchedulerClass = torch.optim.lr_scheduler.OneCycleLR
    scheduler_params = dict(
        max_lr=1e-4,
        #total_steps = len(train_dataset) // 4 * n_epochs, # gradient accumulation
        epochs=n_epochs,
        steps_per_epoch=int(len(train_dataset) / batch_size),
        pct_start=0.2,
        anneal_strategy='cos', 
        final_div_factor=10**5
    )
    
#     SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
#     scheduler_params = dict(
#         mode='min',
#         factor=0.5,
#         patience=1,
#         verbose=False, 
#         threshold=0.0001,
#         threshold_mode='abs',
#         cooldown=0, 
#         min_lr=1e-8,
#         eps=1e-08
#     )

In [12]:
class TrainSwitchConfig:
    num_workers = 1
    batch_size = 1
    n_epochs = 100
    lr = 0.001

    folder = 'Switch'

    # -------------------
    verbose = True
    verbose_step = 1
    # -------------------

    # --------------------
    step_scheduler = False  # do scheduler.step after optimizer.step
    validation_scheduler = True  # do scheduler.step after validation stage loss

    SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
    scheduler_params = dict(
        min_lr=1e-7,
        patience=5
    )

In [13]:
def std_conv_layer(in_channels, out_channels, kernel_size, padding=0):
    layer = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
    nn.init.normal_(layer.weight, std=0.01)
    return layer

class shallow_net_9x9(nn.Module):
    def __init__(self):
        super(shallow_net_9x9, self).__init__()
        self.net = nn.Sequential(
            std_conv_layer(1, 16, 9, 4),
            nn.MaxPool2d(2),
            std_conv_layer(16, 32, 7, 3),
            nn.MaxPool2d(2),
            std_conv_layer(32, 16, 7, 3),
            std_conv_layer(16, 8, 7, 3),
            std_conv_layer(8, 1, 1)
        )

    def forward(self, x):
        return self.net(x)

class shallow_net_7x7(nn.Module):
    def __init__(self):
        super(shallow_net_7x7, self).__init__()
        self.net = nn.Sequential(
            std_conv_layer(1, 20, 7, 3),
            nn.MaxPool2d(2),
            std_conv_layer(20, 40, 5, 2),
            nn.MaxPool2d(2),
            std_conv_layer(40, 20, 5, 2),
            std_conv_layer(20, 10, 5, 2),
            std_conv_layer(10, 1, 1)
        )

    def forward(self, x):
        return self.net(x)

class shallow_net_5x5(nn.Module):
    def __init__(self):
        super(shallow_net_5x5, self).__init__()
        self.net = nn.Sequential(
            std_conv_layer(1, 24, 5, 2),
            nn.MaxPool2d(2),
            std_conv_layer(24, 48, 3, 1),
            nn.MaxPool2d(2),
            std_conv_layer(48, 24, 3, 1),
            std_conv_layer(24, 12, 3, 1),
            std_conv_layer(12, 1, 1)
        )

    def forward(self, x):
        return self.net(x)

In [14]:
class deep_patch_classifier(nn.Module):
    def __init__(self):
        super(deep_patch_classifier, self).__init__()
        self.conv = nn.Sequential(
            std_conv_layer(3, 64, 3, 1),
            std_conv_layer(64, 64, 3, 1),
            nn.MaxPool2d(2),
            std_conv_layer(64, 128, 3, 1),
            std_conv_layer(128, 128, 3, 1),
            nn.MaxPool2d(2),
            std_conv_layer(128, 256, 3, 1),
            std_conv_layer(256, 256, 3, 1),
            std_conv_layer(256, 256, 3, 1),
            nn.MaxPool2d(2),
            std_conv_layer(256, 512, 3, 1),
            std_conv_layer(512, 512, 3, 1),
            std_conv_layer(512, 512, 3, 1),
            nn.MaxPool2d(2),
            std_conv_layer(512, 512, 3, 1),
            std_conv_layer(512, 512, 3, 1),
            std_conv_layer(512, 512, 3, 1),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )
        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.Linear(512, 3),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x

In [21]:
class SwitchCNN(nn.Module):
    def __init__(self):
        super(SwitchCNN, self).__init__()
        self.switch = deep_patch_classifier().cuda()
        self.models = [shallow_net_9x9().cuda(), shallow_net_7x7().cuda(), shallow_net_5x5().cuda()]
    
    #@torch.cuda.amp.autocast()
    def forward(self, x):
        imgs, gray_imgs = x
        with torch.no_grad():
            y = self.switch(imgs)
        which = torch.argmax(y).item()
        return self.models[which](gray_imgs), which


In [22]:
net = SwitchCNN().cuda()

In [17]:
def CELoss_Switch(preds,targs):
    return nn.CrossEntropyLoss()(preds, targs)

def MSELoss_MCNN(preds,targs):
    return nn.MSELoss()(preds,targs)

def MAELoss_MCNN(preds,targs,upsample):
    return nn.L1Loss()((preds/LOG_PARA).sum(dim=[-1,-2])*upsample*upsample, (targs/LOG_PARA).sum(dim=[-1,-2])*upsample*upsample)

In [18]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## Differential Training

In [29]:
import warnings
warnings.filterwarnings("ignore")

#opt_level ='O1' # apex

class Fitter:
    
    def __init__(self, model, device, config):
        self.config = config
        self.epoch = 0

        self.base_dir = f'/mnt/home/zpengac/USERDIR/count/drone_benchmark/{config.folder}/{model.__class__.__name__}'
        if not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir)
        
        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10**5

        self.model = model
        self.device = device

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ] 

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.lr)
        
        #self.model, self.optimizer = amp.initialize(self.model,self.optimizer,opt_level=opt_level) # apex
        self.scaler = torch.cuda.amp.GradScaler()
        
        self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)
        self.criterion = MSELoss_MCNN
        self.metric = MAELoss_MCNN
        self.log(f'Fitter prepared. Device is {self.device}')
        
        # self.iters_to_accumulate = 4 # gradient accumulation

    def fit(self, train_loader, validation_loader):
        for e in range(self.config.n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr}')

            t = time.time()
            summary_loss, mae_loss = self.train_one_epoch(train_loader)

            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, mse_loss: {summary_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, mae_loss: {mae_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            self.save(f'{self.base_dir}/last-checkpoint.bin')

            t = time.time()
            summary_loss, mae_loss = self.validation(validation_loader)

            self.log(f'[RESULT]: Val. Epoch: {self.epoch}, mse_loss: {summary_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            self.log(f'[RESULT]: Val. Epoch: {self.epoch}, mae_loss: {mae_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            if summary_loss.avg < self.best_summary_loss:
                self.best_summary_loss = summary_loss.avg
                self.model.eval()
                self.save(f'{self.base_dir}/best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin')
                for path in sorted(glob(f'{self.base_dir}/best-checkpoint-*epoch.bin'))[:-3]:
                    os.remove(path)

            if self.config.validation_scheduler:
                self.scheduler.step(metrics=summary_loss.avg)

            self.epoch += 1

    def validation(self, val_loader):
        self.model.eval()
        summary_loss = AverageMeter()
        mae_loss = AverageMeter()
        t = time.time()
        for step, (_, images, density_maps, fns, gt_pts) in enumerate(val_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Val Step {step}/{len(val_loader)}, ' + \
                        f'mse_loss: {summary_loss.avg:.8f}, ' + \
                        f'mae_loss: {mae_loss.avg:.8f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            with torch.no_grad():
                batch_size = images.shape[0]
                images = images.cuda().float()
                density_maps = density_maps.cuda().float()
                

                #preds = self.model(images)
                with torch.cuda.amp.autocast(): #native fp16
                    preds = self.model(images)
                    loss = self.criterion(preds,density_maps)
                    metric_loss = self.metric(preds,density_maps,self.config.downsample)
                mae_loss.update(metric_loss.detach().item(),batch_size)
                summary_loss.update(loss.detach().item(), batch_size)
                
            #if step == 20:
            #    break

        return summary_loss, mae_loss

    def train_one_epoch(self, train_loader):
        self.model.train()
        summary_loss = AverageMeter()
        mae_loss = AverageMeter()
        t = time.time()
        for step, (_, images, density_maps, fns, gt_pts) in enumerate(train_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Train Step {step}/{len(train_loader)}, ' + \
                        f'mse_loss: {summary_loss.avg:.8f}, ' + \
                        f'mae_loss: {mae_loss.avg:.8f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            
            images = images.cuda().float()
            batch_size = images.shape[0]
            density_maps = density_maps.cuda().float()
            
            
            self.optimizer.zero_grad()
            
            with torch.cuda.amp.autocast(): #native fp16
                preds = self.model(images)
                loss = self.criterion(preds,density_maps)
                metric_loss = self.metric(preds.detach(),density_maps.detach(),self.config.downsample)
            self.scaler.scale(loss).backward()
            
            # loss = loss / self.iters_to_accumulate # gradient accumulation
            
#             with amp.scale_loss(loss,self.optimizer) as scaled_loss: # apex
#                 scaled_loss.backward()
            #loss.backward()

            
            mae_loss.update(metric_loss.detach().item(),batch_size)
            summary_loss.update(loss.detach().item(), batch_size)
            
            #self.optimizer.step()
            self.scaler.step(self.optimizer) # native fp16
            
            if self.config.step_scheduler:
                self.scheduler.step()
            
            self.scaler.update() #native fp16
                
                
#             if (step+1) % self.iters_to_accumulate == 0: # gradient accumulation

#                 self.optimizer.step()
#                 self.optimizer.zero_grad()

#                 if self.config.step_scheduler:
#                     self.scheduler.step()
                    
            #if step == 20:
            #    break

        return summary_loss, mae_loss
    
    def save(self, path):
        self.model.eval()
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_summary_loss': self.best_summary_loss,
            'epoch': self.epoch,
            #'amp': amp.state_dict() # apex
        }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1
        
    def log(self, message):
        if self.config.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')

In [30]:
def collate_fn(batch):
    _, imgs, dmaps, fns, gt_points = zip(*batch)
    imgs = torch.stack(imgs)
    dmaps = torch.stack(dmaps).unsqueeze(1)
    return _, imgs,dmaps,fns,gt_points

def run_training(net):
    device = torch.device('cuda:0')

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TrainDifferentialConfig.batch_size,
        sampler=RandomSampler(train_dataset),
        #sampler=train_sampler,
        pin_memory=False,
        drop_last=True,
        num_workers=TrainDifferentialConfig.num_workers,
        collate_fn=collate_fn,
    )

    val_loader = torch.utils.data.DataLoader(
        valid_dataset, 
        batch_size=TrainDifferentialConfig.batch_size//4,
        num_workers=TrainDifferentialConfig.num_workers//2,
        shuffle=False,
        sampler=SequentialSampler(valid_dataset),
        #sampler=val_sampler,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    fitter = Fitter(model=net, device=device, config=TrainDifferentialConfig)
#     fitter.load(f'{fitter.base_dir}/last-checkpoint.bin')
    fitter.fit(train_loader, val_loader)

In [None]:
for model in net.models:
    run_training(model)

Fitter prepared. Device is cuda:0

2021-07-29T06:48:27.095887
LR: 4.000000000000002e-06
[RESULT]: Train. Epoch: 0, mse_loss: 0.07044318, time: 191.68533 time: 191.45735
[RESULT]: Train. Epoch: 0, mae_loss: 39.20030907, time: 191.69862
[RESULT]: Val. Epoch: 0, mse_loss: 0.05559005, time: 515.9464047, time: 515.79030
[RESULT]: Val. Epoch: 0, mae_loss: 163.26885633, time: 515.96069

2021-07-29T07:00:14.806594
LR: 4.410700900185825e-06
[RESULT]: Train. Epoch: 1, mse_loss: 0.06776489, time: 113.27615 time: 113.05630
[RESULT]: Train. Epoch: 1, mae_loss: 27.29751306, time: 113.28947
Val Step 99/1350, mse_loss: 0.05991025, mae_loss: 169.13502541, time: 12.49901

## Coupled Training

In [21]:
net.models[0].load_state_dict(torch.load('/mnt/home/zpengac/USERDIR/count/drone_benchmark/SCNN-7.17-784/shallow_net_9x9/best-checkpoint-079epoch.bin')['model_state_dict'])
net.models[1].load_state_dict(torch.load('/mnt/home/zpengac/USERDIR/count/drone_benchmark/SCNN-7.17-784/shallow_net_7x7/best-checkpoint-116epoch.bin')['model_state_dict'])
net.models[2].load_state_dict(torch.load('/mnt/home/zpengac/USERDIR/count/drone_benchmark/SCNN-7.17-784/shallow_net_5x5/best-checkpoint-101epoch.bin')['model_state_dict'])

<All keys matched successfully>

In [22]:
try_dataset = SwitchDataset(train_dataset, net.models)
print(len(try_dataset.eq_data[0]))
print(len(try_dataset.eq_data[1]))
print(len(try_dataset.eq_data[2]))

3247
1374
419


In [110]:
print(len(try_dataset.indices))

9741


In [83]:
import warnings
warnings.filterwarnings("ignore")

#opt_level ='O1' # apex

class CoupledFitter:
    
    def __init__(self, model, device, config, switch_config):
        self.config = config
        self.switch_config = switch_config
        self.epoch = 0

        self.base_dir = f'/mnt/home/zpengac/USERDIR/count/drone_benchmark/{config.folder}'
        if not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir)

        self.switch_dir = os.path.join(self.base_dir, 'Switch')
        if not os.path.exists(self.base_dir):
            os.makedirs(self.switch_dir)
        
        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10**5
        self.best_ce_loss = 10**5

        self.model = model
        #self.switch = switch
        #self.models = models
        self.device = device

        #param_optimizer = list(self.model.named_parameters())
        #no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        #optimizer_grouped_parameters = [
        #    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        #    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        #] 

        self.optimizers = []
        self.optimizers.append(torch.optim.AdamW(self.model.models[0].parameters(), lr=config.lr))
        self.optimizers.append(torch.optim.AdamW(self.model.models[1].parameters(), lr=config.lr))
        self.optimizers.append(torch.optim.AdamW(self.model.models[2].parameters(), lr=config.lr))
        self.switch_optimizer = torch.optim.AdamW(self.model.switch.parameters(), lr=switch_config.lr)
        
        #self.model, self.optimizer = amp.initialize(self.model,self.optimizer,opt_level=opt_level) # apex
        self.scalers = [torch.cuda.amp.GradScaler()]*3
        self.switch_scaler = torch.cuda.amp.GradScaler()
        
        self.schedulers = []
        self.schedulers.append(config.SchedulerClass(self.optimizers[0], **config.scheduler_params))
        self.schedulers.append(config.SchedulerClass(self.optimizers[1], **config.scheduler_params))
        self.schedulers.append(config.SchedulerClass(self.optimizers[2], **config.scheduler_params))
        self.switch_scheduler = switch_config.SchedulerClass(self.switch_optimizer, **switch_config.scheduler_params)

        self.criterion = MSELoss_MCNN
        self.metric = MAELoss_MCNN
        self.switch_criterion = CELoss_Switch

        self.log(f'Fitter prepared. Device is {self.device}')
        
        # self.iters_to_accumulate = 4 # gradient accumulation

    def fit(self, train_loader, validation_loader):
        for e in range(self.config.n_epochs):
            if self.config.verbose:
                lr = self.switch_optimizer.param_groups[0]['lr']
                lrs = [optimizer.param_groups[0]['lr'] for optimizer in self.optimizers]
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nSwitch LR: {lr}, LRs: {lrs}')

            train_switch_dataset = SwitchDataset(train_dataset, self.model.models)
            valid_switch_dataset = SwitchDataset(valid_dataset, self.model.models)

            train_switch_loader = torch.utils.data.DataLoader(
                train_switch_dataset,
                batch_size=self.switch_config.batch_size,
                sampler=RandomSampler(train_switch_dataset),
                pin_memory=False,
                drop_last=True,
                num_workers=self.switch_config.num_workers
            )

            val_switch_loader = torch.utils.data.DataLoader(
                valid_switch_dataset, 
                batch_size=self.switch_config.batch_size,
                num_workers=self.switch_config.num_workers,
                shuffle=False,
                sampler=SequentialSampler(valid_switch_dataset),
                pin_memory=True
            )

            # Train switch
            t = time.time()
            ce_loss = self.train_one_epoch_switch(train_switch_loader)

            self.log(f'[RESULT]: Train Switch. Epoch: {self.epoch}, ce_loss: {ce_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            self.save_switch(f'{self.switch_dir}/last-checkpoint.bin')

            t = time.time()
            ce_loss = self.validation_switch(val_switch_loader)

            self.log(f'[RESULT]: Val Switch. Epoch: {self.epoch}, ce_loss: {ce_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            if ce_loss.avg < self.best_ce_loss:
                self.best_ce_loss = ce_loss.avg
                self.model.eval()
                self.save_switch(f'{self.switch_dir}/best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin')
                for path in sorted(glob(f'{self.switch_dir}/best-checkpoint-*epoch.bin'))[:-3]:
                    os.remove(path)

            if self.switch_config.validation_scheduler:
                self.switch_scheduler.step(metrics=ce_loss.avg)

            # Train models
            t = time.time()
            summary_loss, mae_loss = self.train_one_epoch(train_loader)

            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, mse_loss: {summary_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, mae_loss: {mae_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            self.save(f'{self.base_dir}/last-checkpoint.bin')

            t = time.time()
            summary_loss, mae_loss = self.validation(validation_loader)

            self.log(f'[RESULT]: Val. Epoch: {self.epoch}, mse_loss: {summary_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            self.log(f'[RESULT]: Val. Epoch: {self.epoch}, mae_loss: {mae_loss.avg:.8f}, time: {(time.time() - t):.5f}')
            if summary_loss.avg < self.best_summary_loss:
                self.best_summary_loss = summary_loss.avg
                self.model.eval()
                self.save(f'{self.base_dir}/best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin')
                for path in sorted(glob(f'{self.base_dir}/best-checkpoint-*epoch.bin'))[:-3]:
                    os.remove(path)

            if self.config.validation_scheduler:
                for scheduler in self.schedulers:
                    scheduler.step(metrics=summary_loss.avg)

            self.epoch += 1

    def validation_switch(self, val_loader):
        self.model.switch.eval()
        ce_loss = AverageMeter()
        t = time.time()

        for step, (images, targs) in enumerate(val_loader):
            if self.switch_config.verbose:
                if step % self.switch_config.verbose_step == 0:
                    print(
                        f'Val Step {step}/{len(val_loader)}, ' + \
                        f'ce_loss: {ce_loss.avg:.8f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            with torch.no_grad():
                batch_size = images.shape[0]
                images = images.cuda().float()
                targs = targs.cuda().long()

                #preds = self.model(images)
                with torch.cuda.amp.autocast(): #native fp16
                    preds = self.model.switch(images)
                    loss = self.switch_criterion(preds,targs)
                ce_loss.update(loss.detach().item(),batch_size)
                
            #if step == 20:
            #    break

        return ce_loss

    def train_one_epoch_switch(self, train_loader):
        self.model.switch.train()
        ce_loss = AverageMeter()
        t = time.time()
        for step, (images, targs) in enumerate(train_loader):
            if self.switch_config.verbose:
                if step % self.switch_config.verbose_step == 0:
                    print(
                        f'Train Step {step}/{len(train_loader)}, ' + \
                        f'ce_loss: {ce_loss.avg:.8f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            
            images = images.cuda().float()
            batch_size = images.shape[0]
            targs = targs.cuda().long()
            
            self.switch_optimizer.zero_grad()
            
            with torch.cuda.amp.autocast(): #native fp16
                preds = self.model.switch(images)
                loss = self.switch_criterion(preds,targs)
            self.switch_scaler.scale(loss).backward()
            
            # loss = loss / self.iters_to_accumulate # gradient accumulation
            
#             with amp.scale_loss(loss,self.optimizer) as scaled_loss: # apex
#                 scaled_loss.backward()
            #loss.backward()

            
            ce_loss.update(loss.detach().item(),batch_size)
            
            #self.optimizer.step()
            self.switch_scaler.step(self.switch_optimizer) # native fp16
            
            if self.switch_config.step_scheduler:
                self.switch_scheduler.step()
            
            self.switch_scaler.update() #native fp16
                
                
#             if (step+1) % self.iters_to_accumulate == 0: # gradient accumulation

#                 self.optimizer.step()
#                 self.optimizer.zero_grad()

#                 if self.config.step_scheduler:
#                     self.scheduler.step()
                    
            #if step == 20:
            #    break

        return ce_loss
    
    def validation(self, val_loader):
        self.model.eval()
        #self.switch.eval()
        #for model in self.models:
        #    model.eval()
        summary_loss = AverageMeter()
        mae_loss = AverageMeter()
        t = time.time()
        for step, (images, gray_images, density_maps, fns, gt_pts) in enumerate(val_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Val Step {step}/{len(val_loader)}, ' + \
                        f'mse_loss: {summary_loss.avg:.8f}, ' + \
                        f'mae_loss: {mae_loss.avg:.8f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )

            #print(density_maps.shape)
            with torch.no_grad():
                batch_size = images.shape[0]
                images = images.cuda().float()
                gray_images = gray_images.cuda().float()
                density_maps = density_maps.cuda().float()

                #preds = self.model(images)
                with torch.cuda.amp.autocast(): #native fp16
                    preds, which = self.model((images, gray_images))
                    loss = self.criterion(preds,density_maps)
                    metric_loss = self.metric(preds,density_maps,self.config.downsample)
                mae_loss.update(metric_loss.detach().item(),batch_size)
                summary_loss.update(loss.detach().item(), batch_size)
                
            if step == 20:
                break

        return summary_loss, mae_loss

    def train_one_epoch(self, train_loader):
        self.model.train()
        #self.switch.eval()
        #for model in self.models:
        #    model.train()
        summary_loss = AverageMeter()
        mae_loss = AverageMeter()
        t = time.time()
        for step, (images, gray_images, density_maps, fns, gt_pts) in enumerate(train_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Train Step {step}/{len(train_loader)}, ' + \
                        f'mse_loss: {summary_loss.avg:.8f}, ' + \
                        f'mae_loss: {mae_loss.avg:.8f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            
            images = images.cuda().float()
            gray_images = gray_images.cuda().float()
            batch_size = images.shape[0]
            density_maps = density_maps.cuda().float()
            
            #with torch.no_grad():
            #    y = self.switch(images)
            #which = torch.argmax(y).item()

            
            with torch.cuda.amp.autocast(): #native fp16
                preds, which = self.model((images, gray_images))
                print(which)
                loss = self.criterion(preds,density_maps)
                metric_loss = self.metric(preds.detach(),density_maps.detach(),self.config.downsample)
            self.optimizers[which].zero_grad()
            self.scalers[which].scale(loss).backward()
            
            # loss = loss / self.iters_to_accumulate # gradient accumulation
            
#             with amp.scale_loss(loss,self.optimizer) as scaled_loss: # apex
#                 scaled_loss.backward()
            #loss.backward()

            
            mae_loss.update(metric_loss.detach().item(),batch_size)
            summary_loss.update(loss.detach().item(), batch_size)
            
            #self.optimizer.step()
            self.scalers[which].step(self.optimizers[which]) # native fp16
            
            if self.config.step_scheduler:
                self.schedulers[which].step()
            
            self.scalers[which].update() #native fp16
                
                
#             if (step+1) % self.iters_to_accumulate == 0: # gradient accumulation

#                 self.optimizer.step()
#                 self.optimizer.zero_grad()

#                 if self.config.step_scheduler:
#                     self.scheduler.step()
                    
            if step == 20:
                break

        return summary_loss, mae_loss

    def save_switch(self, path):
        self.model.eval()
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.switch_optimizer.state_dict(),
            'scheduler_state_dict': self.switch_scheduler.state_dict(),
            'best_ce_loss': self.best_ce_loss,
            'epoch': self.epoch,
            #'amp': amp.state_dict() # apex
        }, path)

    def save(self, path):
        self.model.eval()
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizers_state_dict': [optimizer.state_dict() for optimizer in self.optimizers],
            'schedulers_state_dict': [scheduler.state_dict() for scheduler in self.schedulers],
            'best_summary_loss': self.best_summary_loss,
            'epoch': self.epoch,
            #'amp': amp.state_dict() # apex
        }, path)

    def load_switch(self, path):
        checkpoint = torch.load(path)
        self.model.switch.load_state_dict(checkpoint['model_state_dict'])
        self.switch_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.switch_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_ce_loss = checkpoint['best_ce_loss']
        self.epoch = checkpoint['epoch'] + 1

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        for i in range(3):
            self.optimizers[i].load_state_dict(checkpoint['optimizers_state_dict'][i])
        for i in range(3):
            self.schedulers[i].load_state_dict(checkpoint['schedulers_state_dict'][i])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1
        
    def log(self, message):
        if self.config.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')

In [84]:
def collate_fn(batch):
    imgs, gray_imgs, dmaps, fns, gt_points = zip(*batch)
    imgs = torch.stack(imgs)
    gray_imgs = torch.stack(gray_imgs)
    dmaps = torch.stack(dmaps).unsqueeze(1)
    return imgs,gray_imgs,dmaps,fns,gt_points

def run_training(net):
    device = torch.device('cuda:0')

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TrainCoupledConfig.batch_size,
        sampler=RandomSampler(train_dataset),
        #sampler=train_sampler,
        pin_memory=False,
        drop_last=True,
        num_workers=TrainCoupledConfig.num_workers,
        collate_fn=collate_fn,
    )

    val_loader = torch.utils.data.DataLoader(
        valid_dataset, 
        batch_size=TrainCoupledConfig.batch_size,
        num_workers=TrainCoupledConfig.num_workers,
        shuffle=False,
        sampler=SequentialSampler(valid_dataset),
        #sampler=val_sampler,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    fitter = CoupledFitter(model=net, device=device, config=TrainCoupledConfig, switch_config=TrainSwitchConfig)
#     fitter.load(f'{fitter.base_dir}/last-checkpoint.bin')
    fitter.fit(train_loader, val_loader)

In [85]:
run_training(net)

Fitter prepared. Device is cuda:0

2021-07-18T01:58:56.963858
Switch LR: 0.0002, LRs: [4.000000000000002e-06, 4.000000000000002e-06, 4.000000000000002e-06]
[RESULT]: Train Switch. Epoch: 0, ce_loss: 1.30224302, time: 3.97359
[RESULT]: Val Switch. Epoch: 0, ce_loss: 0.55144477, time: 5.88725
[RESULT]: Train. Epoch: 0, mse_loss: 0.11719659, time: 2.885205, time: 2.70301
[RESULT]: Train. Epoch: 0, mae_loss: 46.26402069, time: 2.89385
[RESULT]: Val. Epoch: 0, mse_loss: 0.06206556, time: 5.7761011, time: 5.46912
[RESULT]: Val. Epoch: 0, mae_loss: 117.87505268, time: 5.78433

2021-07-18T02:20:17.537322
Switch LR: 0.0002, LRs: [4.0000071395886595e-06, 4.000000000000002e-06, 4.000000000000002e-06]
[RESULT]: Train Switch. Epoch: 1, ce_loss: 1.36096858, time: 2.68786
[RESULT]: Val Switch. Epoch: 1, ce_loss: 0.55144477, time: 4.29811
[RESULT]: Train. Epoch: 1, mse_loss: 0.07548444, time: 2.958964, time: 2.77052
[RESULT]: Train. Epoch: 1, mae_loss: 23.34711411, time: 2.96727
[RESULT]: Val. Epoch: 

KeyboardInterrupt: 