In [3]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import models,transforms,datasets
from torch.utils import data
%matplotlib inline
import torchvision.transforms as standard_transforms
from torch.utils.data import DataLoader
import random
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from PIL import Image, ImageOps, ImageFilter
import numbers
import pdb
import pandas as pd
import torch.nn.functional as F
import torchvision.transforms.functional as F2
import torch.utils.model_zoo as model_zoo
from glob import glob
from torchvision import transforms, models
from torch.utils.data.dataloader import default_collate
from torch import optim
from torch.nn import Module
import time
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

Code can be run both on laptop and on GCP.

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using gpu: %s ' % torch.cuda.is_available())

Using gpu: True 


unzip data

In [25]:
!unzip data.zip

Archive:  data.zip
replace data/bayes/train/IMG_290.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C


## Models

We use 2 nn, a CSRNet and a VGG19 extented.

#### VGG19 extended

In [5]:
__all__ = ['vgg19']
model_urls = {
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}

class VGGExtended(nn.Module):
    def __init__(self, features):
        super(VGGExtended, self).__init__()
        self.features = features
        self.reg_layer = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = F.upsample_bilinear(x, scale_factor=2)
        x = self.reg_layer(x)
        return torch.abs(x)


def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate,dilation = d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfg = {
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
}

def vgg19():
    """VGG 19-layer model (configuration "E")
        model pre-trained on ImageNet
    """
    model = VGGExtended(make_layers(cfg['E']))
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']), strict=False)
    return model

#### CSRNet

In [6]:
class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat  = [512, 512, 512,256,128,64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat,in_channels = 512,dilation = True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained = True)
            self._initialize_weights()
            self.frontend.load_state_dict(mod.features[0:23].state_dict())
            
    def forward(self,x):
        size = x.size()
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        x = F.upsample(x, size = size[2:])
        return x
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0) 

## Processing helpers

In [7]:
def random_cropBayes(im_h, im_w, crop_h, crop_w):
    res_h = im_h - crop_h
    res_w = im_w - crop_w
    i = random.randint(0, res_h)
    j = random.randint(0, res_w)
    return i, j, crop_h, crop_w


def cal_innner_area(c_left, c_up, c_right, c_down, bbox):
    inner_left = np.maximum(c_left, bbox[:, 0])
    inner_up = np.maximum(c_up, bbox[:, 1])
    inner_right = np.minimum(c_right, bbox[:, 2])
    inner_down = np.minimum(c_down, bbox[:, 3])
    inner_area = np.maximum(inner_right-inner_left, 0.0) * np.maximum(inner_down-inner_up, 0.0)
    return inner_area

## Datasets

#### Ground Truth dataset

In [8]:
class GTDataset(data.Dataset):
    def __init__(self, data_path, mode, main_transform=None, img_transform=None, gt_transform=None):
        self.img_path = data_path + '/img'
        self.gt_path = data_path + '/den'
        self.data_files = [filename for filename in os.listdir(self.img_path) \
                           if os.path.isfile(os.path.join(self.img_path,filename))]
        self.num_samples = len(self.data_files) 
        self.main_transform=main_transform  
        self.img_transform = img_transform
        self.gt_transform = gt_transform     
    
    def __getitem__(self, index):
        fname = self.data_files[index]
        img, den = self.read_image_and_gt(fname)      
        if self.main_transform is not None:
            img, den = self.main_transform(img,den) 
        if self.img_transform is not None:
            img = self.img_transform(img)         
        if self.gt_transform is not None:
            den = self.gt_transform(den)               
        return img, den

    def __len__(self):
        return self.num_samples

    def read_image_and_gt(self,fname):
        img = Image.open(os.path.join(self.img_path,fname))
        if img.mode == 'L':
            img = img.convert('RGB')

        den = pd.read_csv(os.path.join(self.gt_path,os.path.splitext(fname)[0] + '.csv'), sep=',',header=None).values
        
        den = den.astype(np.float32, copy=False)    
        den = Image.fromarray(den)  
        return img, den    

    def get_num_samples(self):
        return self.num_samples

#### Bayes method Dataset

In [9]:
class BayesDataset(data.Dataset):
    def __init__(self, root_path, crop_size,
                 downsample_ratio, is_gray=False,
                 method='train'):

        self.root_path = root_path
        self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg')))
        if method not in ['train', 'val', 'test']:
            raise Exception("not implement")
        self.method = method

        self.c_size = crop_size
        self.d_ratio = downsample_ratio
        assert self.c_size % self.d_ratio == 0
        self.dc_size = self.c_size // self.d_ratio

        if is_gray:
            self.trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        else:
            self.trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Pour CSRNet à vérifier si besoin de modif
            ])

    def __len__(self):
        return len(self.im_list)

    def __getitem__(self, item):
        img_path = self.im_list[item]
        gd_path = img_path.replace('jpg', 'npy')
        img = Image.open(img_path).convert('RGB')
        if self.method == 'train':
            keypoints = np.load(gd_path)
            return self.train_transform(img, keypoints)
        else:
            keypoints = np.load(gd_path)
            img = self.trans(img)
            name = os.path.basename(img_path).split('.')[0]
            return img, len(keypoints), name

    def train_transform(self, img, keypoints):
        """random crop image patch and find people in it"""
        
        """
        Les keypoints correspondent aux coordonnées des têtes
        MAIS une troisième coordonnée a été calculée lors du preprocessing des données,
        elle correspont à "dis" et semble important pour calculer pas mal de choses
        """
        
        wd, ht = img.size
        st_size = min(wd, ht)
        assert st_size >= self.c_size
        assert len(keypoints) > 0
        i, j, h, w = random_cropBayes(ht, wd, self.c_size, self.c_size)
        img = F2.crop(img, i, j, h, w)
        
        nearest_dis = np.clip(keypoints[:, 2], 4.0, 128.0)
       
        points_left_up = keypoints[:, :2] - nearest_dis[:, None] / 2.0
        points_right_down = keypoints[:, :2] + nearest_dis[:, None] / 2.0
        bbox = np.concatenate((points_left_up, points_right_down), axis=1)
        inner_area = cal_innner_area(j, i, j+w, i+h, bbox)
        origin_area = nearest_dis * nearest_dis
        ratio = np.clip(1.0 * inner_area / origin_area, 0.0, 1.0)
        mask = (ratio >= 0.3)

        target = ratio[mask]
        keypoints = keypoints[mask]
        keypoints = keypoints[:, :2] - [j, i]  # change coodinate
        if len(keypoints) > 0:
            if random.random() > 0.5:
                img = F2.hflip(img)
                keypoints[:, 0] = w - keypoints[:, 0]
        else:
            if random.random() > 0.5:
                img = F2.hflip(img)
        return self.trans(img), torch.from_numpy(keypoints.copy()).float(), \
               torch.from_numpy(target.copy()).float(), st_size

## DataLoader

#### Loading Data GT

In [10]:
#CSRNet
LABEL_FACTOR = 1


def random_crop_GT(img,den,dst_size):
    # dst_size: ht, wd

    _,ts_hd,ts_wd = img.shape

    x1 = random.randint(0, ts_wd - dst_size[1])//LABEL_FACTOR*LABEL_FACTOR
    y1 = random.randint(0, ts_hd - dst_size[0])//LABEL_FACTOR*LABEL_FACTOR
    x2 = x1 + dst_size[1]
    y2 = y1 + dst_size[0]

    label_x1 = x1//LABEL_FACTOR
    label_y1 = y1//LABEL_FACTOR
    label_x2 = x2//LABEL_FACTOR
    label_y2 = y2//LABEL_FACTOR

    return img[:,y1:y2,x1:x2], den[label_y1:label_y2,label_x1:label_x2]



def share_memory(batch):
    out = None
    if False:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.numel() for x in batch])
        storage = batch[0].storage()._new_shared(numel)
        out = batch[0].new(storage)
    return out

crop_size = 256

def GT_collate(batch):
    # @GJY 
    r"""Puts each data field into a tensor with outer dimension batch size"""

    transposed = list(zip(*batch)) # imgs and dens
    imgs, dens = [transposed[0],transposed[1]]


    error_msg = "batch must contain tensors; found {}"
    if isinstance(imgs[0], torch.Tensor) and isinstance(dens[0], torch.Tensor):
        
        cropped_imgs = []
        cropped_dens = []
        for i_sample in range(len(batch)):
            _img, _den = random_crop_GT(imgs[i_sample],dens[i_sample],[crop_size,crop_size])
            cropped_imgs.append(_img)
            cropped_dens.append(_den)


        cropped_imgs = torch.stack(cropped_imgs, 0, out=share_memory(cropped_imgs))
        cropped_dens = torch.stack(cropped_dens, 0, out=share_memory(cropped_dens))

        return [cropped_imgs,cropped_dens]

    raise TypeError((error_msg.format(type(batch[0]))))


def loading_data_GT(batch_size=5, num_workers=8):
    mean_std = ([0.410824894905, 0.370634973049, 0.359682112932], [0.278580576181, 0.26925137639, 0.27156367898])
    log_para = 100.
    factor = 1
    # DATA_PATH = "/home/simon/Bureau/framework-crowd-counting/ProcessedData/shanghaitech_part_A"
    DATA_PATH = "data/gt"
    
    
    train_main_transform = Compose([
        RandomHorizontallyFlip()
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    gt_transform = standard_transforms.Compose([
        GTScaleDown(factor),
        LabelNormalize(log_para)
    ])

    train_set = GTDataset(DATA_PATH+'/train', 'train',main_transform=train_main_transform, img_transform=img_transform, gt_transform=gt_transform)
    train_loader =None
    if batch_size == 1:
        train_loader = DataLoader(train_set, batch_size=1, shuffle=True, drop_last=True)
    elif batch_size > 1:
        train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, collate_fn=GT_collate, shuffle=True, drop_last=True)
    
    val_set = GTDataset(DATA_PATH+'/val', 'val', main_transform=None, img_transform=img_transform, gt_transform=gt_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers, shuffle=True, drop_last=False)
    
    test_set = GTDataset(DATA_PATH+'/test', 'test', main_transform=None, img_transform=img_transform, gt_transform=gt_transform)
    test_loader = DataLoader(test_set, batch_size=1, num_workers=num_workers, shuffle=True, drop_last=False)
    
    return train_loader, val_loader, test_loader




#### Loading Data Bayes

In [28]:
#bayes
downsample_ratio = 1 # Mettre à 8 pour le réseau du répo (à 1 pour CSRNet puisque on ne modifie pas la dim avec le réseau)
data_dir = "data/bayes"
#data_dir = "/home/simon/Bureau/framework-crowd-counting/processed_data_bcc/SHHA"
#data_dir = "/Users/VictoRambaud/dev/crowd_counting2/ProcessedData/SHHA"
crop_size = 256
is_gray = False

def train_collate(batch):
    transposed_batch = list(zip(*batch))
    images = torch.stack(transposed_batch[0], 0)
    points = transposed_batch[1]  # the number of points is not fixed, keep it as a list of tensor
    targets = transposed_batch[2]
    st_sizes = torch.FloatTensor(transposed_batch[3])
    return images, points, targets, st_sizes


def loading_data_Bayes(batch_size = 5, num_workers = 8):
    datasets_bayes = {x: BayesDataset(os.path.join(data_dir, x),
                              crop_size,
                              downsample_ratio,
                              is_gray, x) for x in ['train', 'val', 'test']}

    dataloaders_bayes = {x: DataLoader(datasets_bayes[x],
                                collate_fn=(train_collate if x == 'train' else default_collate),
                                batch_size=(batch_size if x == 'train' else 1),
                                shuffle=(True if x == 'train' else False),
                                num_workers=num_workers,
                                pin_memory=(True if x == 'train' else False))
                                for x in ['train', 'val', 'test']}
    
    dataloaders_bayes_test = "To do"
    
    return dataloaders_bayes["train"], dataloaders_bayes["val"], dataloaders_bayes["test"]
    

In [29]:
loading_data_Bayes()

(<torch.utils.data.dataloader.DataLoader at 0x7f7c527a3f50>,
 <torch.utils.data.dataloader.DataLoader at 0x7f7c527a3e90>,
 <torch.utils.data.dataloader.DataLoader at 0x7f7c527a32d0>)

## Bayes : computing losses

In [30]:
class Post_Prob(Module):
    def __init__(self, sigma, c_size, stride, background_ratio, use_background, device):
        super(Post_Prob, self).__init__()
        assert c_size % stride == 0

        self.sigma = sigma
        self.bg_ratio = background_ratio
        self.device = device
        # coordinate is same to image space, set to constant since crop size is same
        self.cood = torch.arange(0, c_size, step=stride,
                                 dtype=torch.float32, device=device) + stride / 2
        self.cood.unsqueeze_(0)
        self.softmax = torch.nn.Softmax(dim=0)
        self.use_bg = use_background

    def forward(self, points, st_sizes):
        num_points_per_image = [len(points_per_image) for points_per_image in points]
        all_points = torch.cat(points, dim=0)

        if len(all_points) > 0:
            x = all_points[:, 0].unsqueeze_(1)
            y = all_points[:, 1].unsqueeze_(1)
            x_dis = -2 * torch.matmul(x, self.cood) + x * x + self.cood * self.cood
            y_dis = -2 * torch.matmul(y, self.cood) + y * y + self.cood * self.cood
            y_dis.unsqueeze_(2)
            x_dis.unsqueeze_(1)
            dis = y_dis + x_dis
            dis = dis.view((dis.size(0), -1))

            dis_list = torch.split(dis, num_points_per_image)
            prob_list = []
            for dis, st_size in zip(dis_list, st_sizes):
                if len(dis) > 0:
                    if self.use_bg:
                        min_dis = torch.clamp(torch.min(dis, dim=0, keepdim=True)[0], min=0.0)
                        d = st_size * self.bg_ratio
                        bg_dis = (d - torch.sqrt(min_dis))**2
                        dis = torch.cat([dis, bg_dis], 0)  # concatenate background distance to the last
                    dis = -dis / (2.0 * self.sigma ** 2)
                    prob = self.softmax(dis)
                else:
                    prob = None
                prob_list.append(prob)
        else:
            prob_list = []
            for _ in range(len(points)):
                prob_list.append(None)
        return prob_list
    
    
class Bay_Loss(Module):
    def __init__(self, use_background, device):
        super(Bay_Loss, self).__init__()
        self.device = device
        self.use_bg = use_background

    def forward(self, prob_list, target_list, pre_density):
        loss = 0
        
        """
            - prob list semble être la listes des p(yn|xm) ie la contribution du pixel xm sur la n-ieme tête
            (les lignes de cette matrice sont de taille 4096 = 64*64)
            - pre density est la prédiction de la densité (sortie du réseau) - de taille 64x64 ici
            - target list a pour longueur le nombre de têtes - correspond aux E[cn] "réel" (le calcul reste un mystère)
            - On obtient les E[cn] estimées grâce à un produit terme à terme de prob_list et pre_density
        """
        
        
        for idx, prob in enumerate(prob_list):  # iterative through each sample
            if prob is None:  # image contains no annotation points
                pre_count = torch.sum(pre_density[idx])
                target = torch.zeros((1,), dtype=torch.float32, device=self.device)
            else:
                N = len(prob)
                if self.use_bg:
                    target = torch.zeros((N,), dtype=torch.float32, device=self.device)
                    target[:-1] = target_list[idx]
                else:
                    target = target_list[idx]
                pre_count = torch.sum(pre_density[idx].view((1, -1)) * prob, dim=1)  # flatten into vector
            
            loss += torch.sum(torch.abs(target - pre_count))
        loss = loss / len(prob_list)
        return loss

## Utils

In [31]:
# ===============================img tranforms============================

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask, bbx=None):
        if bbx is None:
            for t in self.transforms:
                img, mask = t(img, mask)
            return img, mask
        for t in self.transforms:
            img, mask, bbx = t(img, mask, bbx)
        return img, mask, bbx

class RandomHorizontallyFlip(object):
    def __call__(self, img, mask, bbx=None):
        if random.random() < 0.5:
            if bbx is None:
                return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)
            w, h = img.size
            xmin = w - bbx[:,3]
            xmax = w - bbx[:,1]
            bbx[:,1] = xmin
            bbx[:,3] = xmax
            return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT), bbx
        if bbx is None:
            return img, mask
        return img, mask, bbx



# ===============================label tranforms============================

class LabelNormalize(object):
    def __init__(self, para):
        self.para = para

    def __call__(self, tensor):
        # tensor = 1./(tensor+self.para).log()
        tensor = torch.from_numpy(np.array(tensor))
        tensor = tensor*self.para
        return tensor

    
class GTScaleDown(object):
    def __init__(self, factor=8):
        self.factor = factor

    def __call__(self, img):
        w, h = img.size
        if self.factor==1:
            return img
        tmp = np.array(img.resize((w//self.factor, h//self.factor), Image.BICUBIC))*self.factor*self.factor
        img = Image.fromarray(tmp)
        return img

## Trainers

Dossier à créer sur Google Cloud !

In [32]:
save_dir = "best_model_weights"

Tensorboard settings

In [57]:
from torch.utils.tensorboard import SummaryWriter
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter()

# image = Image.open('../ProcessedData/SHHA/train/IMG_1.jpg')
# trans1 = transforms.ToTensor()
# img = trans1(image).to(device)
# img = img.unsqueeze(0)

# write to tensorboard

# writer.add_graph(model, img)
# writer.close()



In [43]:
# tensorboard --logdir=runs

#### Trainer GT

In [44]:
LOG_PARA = 100. # C'est quoi ce LOG_PARA ??
seed = 1


class Trainer_GT():
    def __init__(self, dataloader, net, loss, optimizer, validation_frequency=1, max_epoch=100):
        self.train_loader, self.val_loader, self.test_loader = dataloader()
        self.net = net
        self.loss = loss 
        self.optimizer = optimizer
        self.best_mae = 1e20
        self.best_mse = 1e20
        self.epoch = 0
        self.validation_frequency = validation_frequency
        self.max_epoch = max_epoch



    def train(self):
        for epoch in range(0, self.max_epoch):
            self.epoch = epoch
            # si on veut un lr sheduler il faut le mettre là
                
            # training    
            self.train_epoch()

            # validation
            if epoch%self.validation_frequency==0:
                self.validate()
                
        print(f'Train finished | best_mse: {self.best_mse} | best_mae: {self.best_mae}')


    def train_epoch(self): # training for all datasets
        self.net.train()
        epoch_loss = 0
        
        for step, data in enumerate(self.train_loader, 0):
            img, gt_map = data
            img = Variable(img).to(device)
            gt_map = Variable(gt_map).to(device)

            self.optimizer.zero_grad()
            # pred_map = self.net(img, gt_map)
            
            pred_density_map = self.net(img)
            loss = self.loss(pred_density_map, gt_map)
            loss.backward()
            self.optimizer.step()
            
            gt_count = [int(gt_map[i].sum().data / LOG_PARA) for i in range(gt_map.size()[0])]
            pre_count = [int(pred_density_map[i].sum().data/LOG_PARA) for i in range(pred_density_map.size()[0])]
            epoch_loss += float(loss)
            
            print(f'epoch: {self.epoch} | step: {step} | count: {gt_count} | prediction: {pre_count} | loss: {loss}')
        
        writer.add_scalar('train loss GT',
            epoch_loss,
            self.epoch)


    def validate(self):
        epoch_start = time.time()
        self.net.eval()
        epoch_res = []

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map = data

            with torch.no_grad():
                img = Variable(img).to(device)
                assert img.size(0) == 1
                gt_map = Variable(gt_map).to(device)
                pred_density_map = self.net(img)
                
                pred_cnt = int(gt_map[0].sum().data / LOG_PARA)
                gt_count = int(pred_density_map[0].sum().data/LOG_PARA)
                res = gt_count - pred_cnt
                    
                epoch_res.append(res)


        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        
        print('Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                     .format(self.epoch, mse, mae, time.time()-epoch_start))
        
        writer.add_scalar('val MAE GT',
                    mae,
                    self.epoch)
        writer.add_scalar('val MSE GT',
                        mse,
                        self.epoch)

        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            print("save best mse {:.2f} mae {:.2f} model epoch {}".format(self.best_mse,
                                                                            self.best_mae,
                                                                                 self.epoch))
            torch.save(self.net.state_dict(), os.path.join(save_dir, 'best_model_gt.pth'))


In [None]:
# Launch GT Train !
lr = 1e-6

gt_net = CSRNet().to(device)
loss = nn.MSELoss().to(device)
optimizer = optim.Adam(gt_net.parameters(), lr=lr)
# optimizer = optim.SGD(gt_net.parameters(), lr=lr, momentum=0.95,weight_decay=5e-4)  

gt_trainer = Trainer_GT(loading_data_GT, gt_net, loss, optimizer, max_epoch=250)
gt_trainer.train()

epoch: 0 | step: 0 | count: [31, 21, 30, 2, 1] | prediction: [0, 0, 0, 0, 0] | loss: 0.020717572420835495
epoch: 0 | step: 1 | count: [49, 11, 0, 42, 55] | prediction: [0, 0, 0, 0, 0] | loss: 0.0355154424905777
epoch: 0 | step: 2 | count: [0, 78, 42, 344, 40] | prediction: [0, 0, 0, 0, 0] | loss: 0.15167298913002014
epoch: 0 | step: 3 | count: [405, 8, 4, 0, 104] | prediction: [0, 0, 0, 0, 0] | loss: 0.3549301028251648
epoch: 0 | step: 4 | count: [20, 40, 16, 24, 20] | prediction: [0, 0, 0, 0, 0] | loss: 0.02362821437418461
epoch: 0 | step: 5 | count: [59, 24, 36, 0, 29] | prediction: [0, 0, 0, 0, 0] | loss: 0.029076790437102318
epoch: 0 | step: 6 | count: [213, 286, 581, 93, 41] | prediction: [0, 0, 0, 0, 0] | loss: 0.7049778699874878
epoch: 0 | step: 7 | count: [25, 69, 56, 47, 87] | prediction: [0, 0, 0, 0, 0] | loss: 0.0628519058227539
epoch: 0 | step: 8 | count: [5, 8, 44, 59, 77] | prediction: [0, 0, 0, 0, 0] | loss: 0.039926838129758835
epoch: 0 | step: 9 | count: [52, 28, 14, 7

epoch: 1 | step: 28 | count: [37, 186, 6, 86, 361] | prediction: [4, 4, 4, 4, 2] | loss: 0.32947012782096863
epoch: 1 | step: 29 | count: [11, 22, 83, 0, 40] | prediction: [4, 3, 6, 2, 4] | loss: 0.03238391503691673
epoch: 1 | step: 30 | count: [1, 30, 80, 24, 74] | prediction: [4, 3, 4, 5, 6] | loss: 0.04066537693142891
epoch: 1 | step: 31 | count: [11, 235, 5, 183, 5] | prediction: [3, 4, 4, 4, 4] | loss: 0.11160626262426376
epoch: 1 | step: 32 | count: [132, 22, 45, 120, 51] | prediction: [7, 6, 5, 3, 5] | loss: 0.0733201876282692
epoch: 1 | step: 33 | count: [53, 9, 5, 375, 44] | prediction: [4, 5, 3, 5, 6] | loss: 0.33575117588043213
epoch: 1 | step: 34 | count: [29, 130, 25, 38, 16] | prediction: [5, 6, 5, 5, 4] | loss: 0.04682972654700279
epoch: 1 | step: 35 | count: [14, 30, 0, 86, 0] | prediction: [5, 6, 3, 8, 2] | loss: 0.025759048759937286
epoch: 1 | step: 36 | count: [34, 249, 26, 8, 1] | prediction: [5, 3, 4, 4, 5] | loss: 0.07798546552658081
epoch: 1 | step: 37 | count: [

epoch: 3 | step: 4 | count: [147, 52, 38, 56, 8] | prediction: [41, 65, 55, 46, 48] | loss: 0.06603101640939713
epoch: 3 | step: 5 | count: [61, 66, 70, 16, 64] | prediction: [45, 56, 29, 49, 61] | loss: 0.05406080186367035
epoch: 3 | step: 6 | count: [17, 20, 14, 1, 60] | prediction: [48, 37, 46, 43, 65] | loss: 0.0231565423309803
epoch: 3 | step: 7 | count: [107, 61, 16, 3, 124] | prediction: [65, 64, 36, 26, 68] | loss: 0.07100804895162582
epoch: 3 | step: 8 | count: [1, 24, 36, 124, 134] | prediction: [45, 43, 46, 58, 74] | loss: 0.05842021852731705
epoch: 3 | step: 9 | count: [27, 130, 34, 123, 5] | prediction: [69, 69, 42, 82, 42] | loss: 0.057912588119506836
epoch: 3 | step: 10 | count: [8, 22, 10, 70, 10] | prediction: [51, 49, 51, 45, 24] | loss: 0.024718990549445152
epoch: 3 | step: 11 | count: [70, 28, 30, 21, 36] | prediction: [62, 45, 45, 52, 55] | loss: 0.03782180696725845
epoch: 3 | step: 12 | count: [72, 111, 5, 444, 29] | prediction: [49, 21, 43, 67, 51] | loss: 0.2041

epoch: 4 | step: 29 | count: [232, 54, 50, 144, 130] | prediction: [77, 53, 71, 93, 62] | loss: 0.1417577862739563
epoch: 4 | step: 30 | count: [64, 9, 14, 12, 35] | prediction: [40, 56, 41, 46, 49] | loss: 0.02591846138238907
epoch: 4 | step: 31 | count: [302, 38, 50, 47, 84] | prediction: [68, 42, 51, 76, 81] | loss: 0.12151595950126648
epoch: 4 | step: 32 | count: [6, 66, 0, 0, 20] | prediction: [43, 85, 20, 4, 73] | loss: 0.022065648809075356
epoch: 4 | step: 33 | count: [16, 9, 382, 0, 71] | prediction: [67, 59, 77, 44, 89] | loss: 0.138270303606987
epoch: 4 | step: 34 | count: [89, 3, 131, 146, 29] | prediction: [76, 50, 62, 22, 65] | loss: 0.0756261795759201
epoch: 4 | step: 35 | count: [48, 7, 44, 6, 496] | prediction: [71, 32, 54, 45, 65] | loss: 0.15555952489376068
epoch: 4 | step: 36 | count: [4, 49, 0, 15, 0] | prediction: [40, 69, 19, 73, 41] | loss: 0.018044380471110344
epoch: 4 | step: 37 | count: [47, 63, 2, 231, 123] | prediction: [60, 81, 42, 85, 101] | loss: 0.087516

epoch: 6 | step: 5 | count: [5, 55, 46, 45, 228] | prediction: [36, 65, 51, 49, 75] | loss: 0.09035713225603104
epoch: 6 | step: 6 | count: [4, 78, 11, 4, 3] | prediction: [41, 75, 50, 35, 40] | loss: 0.022089514881372452
epoch: 6 | step: 7 | count: [1, 58, 147, 126, 69] | prediction: [50, 47, 44, 77, 47] | loss: 0.07642290741205215
epoch: 6 | step: 8 | count: [64, 282, 30, 7, 45] | prediction: [67, 73, 67, 22, 47] | loss: 0.08827582746744156
epoch: 6 | step: 9 | count: [0, 58, 0, 81, 8] | prediction: [29, 53, 26, 51, 43] | loss: 0.027950044721364975
epoch: 6 | step: 10 | count: [18, 53, 75, 30, 23] | prediction: [55, 48, 47, 65, 50] | loss: 0.04655570909380913
epoch: 6 | step: 11 | count: [367, 4, 61, 24, 13] | prediction: [42, 42, 65, 43, 57] | loss: 0.21139541268348694
epoch: 6 | step: 12 | count: [106, 17, 9, 82, 16] | prediction: [76, 43, 46, 71, 46] | loss: 0.04249562695622444
epoch: 6 | step: 13 | count: [52, 7, 9, 198, 28] | prediction: [34, 44, 53, 64, 36] | loss: 0.0887459367

epoch: 7 | step: 30 | count: [330, 25, 48, 176, 0] | prediction: [56, 27, 56, 65, 33] | loss: 0.30582955479621887
epoch: 7 | step: 31 | count: [47, 58, 0, 49, 12] | prediction: [61, 55, 32, 62, 37] | loss: 0.03294162079691887
epoch: 7 | step: 32 | count: [122, 7, 203, 92, 12] | prediction: [72, 53, 69, 48, 49] | loss: 0.08202808350324631
epoch: 7 | step: 33 | count: [73, 56, 26, 5, 19] | prediction: [62, 65, 67, 32, 48] | loss: 0.0371735654771328
epoch: 7 | step: 34 | count: [2, 34, 47, 65, 70] | prediction: [40, 70, 61, 56, 66] | loss: 0.042447779327631
epoch: 7 | step: 35 | count: [37, 139, 140, 94, 57] | prediction: [73, 74, 73, 78, 61] | loss: 0.13815902173519135
epoch: 7 | step: 36 | count: [88, 9, 172, 0, 24] | prediction: [69, 51, 72, 14, 56] | loss: 0.07904574275016785
epoch: 7 | step: 37 | count: [73, 81, 27, 0, 81] | prediction: [65, 62, 54, 41, 54] | loss: 0.052910592406988144
epoch: 7 | step: 38 | count: [1, 72, 29, 28, 24] | prediction: [51, 80, 58, 60, 57] | loss: 0.03144

epoch: 9 | step: 6 | count: [368, 1, 13, 50, 54] | prediction: [50, 34, 47, 73, 81] | loss: 0.23484225571155548
epoch: 9 | step: 7 | count: [23, 14, 26, 31, 73] | prediction: [58, 49, 55, 70, 44] | loss: 0.09439650923013687
epoch: 9 | step: 8 | count: [7, 45, 0, 434, 51] | prediction: [50, 74, 41, 103, 50] | loss: 0.15899673104286194
epoch: 9 | step: 9 | count: [43, 15, 210, 0, 65] | prediction: [71, 40, 71, 32, 79] | loss: 0.061031971126794815
epoch: 9 | step: 10 | count: [12, 32, 5, 7, 4] | prediction: [51, 62, 48, 49, 46] | loss: 0.01578138768672943
epoch: 9 | step: 11 | count: [39, 73, 40, 45, 4] | prediction: [52, 27, 50, 60, 38] | loss: 0.04238596186041832
epoch: 9 | step: 12 | count: [17, 302, 265, 1, 100] | prediction: [56, 76, 89, 33, 43] | loss: 0.25215351581573486
epoch: 9 | step: 13 | count: [26, 61, 4, 22, 3] | prediction: [39, 80, 37, 63, 35] | loss: 0.024706415832042694
epoch: 9 | step: 14 | count: [0, 15, 28, 40, 152] | prediction: [23, 56, 51, 51, 67] | loss: 0.0698821

epoch: 10 | step: 30 | count: [8, 27, 235, 17, 55] | prediction: [44, 54, 67, 47, 54] | loss: 0.13241754472255707
epoch: 10 | step: 31 | count: [64, 9, 0, 0, 49] | prediction: [73, 52, 20, 14, 65] | loss: 0.025608310475945473
epoch: 10 | step: 32 | count: [5, 25, 4, 30, 181] | prediction: [44, 56, 33, 60, 105] | loss: 0.049135494977235794
epoch: 10 | step: 33 | count: [41, 220, 14, 4, 20] | prediction: [35, 70, 37, 46, 51] | loss: 0.05922741815447807
epoch: 10 | step: 34 | count: [154, 11, 35, 332, 0] | prediction: [80, 49, 65, 71, 12] | loss: 0.1253020465373993
epoch: 10 | step: 35 | count: [11, 48, 41, 32, 21] | prediction: [48, 74, 56, 59, 42] | loss: 0.030515097081661224
epoch: 10 | step: 36 | count: [103, 41, 107, 16, 46] | prediction: [70, 78, 79, 56, 78] | loss: 0.05905577540397644
epoch: 10 | step: 37 | count: [50, 7, 6, 0, 75] | prediction: [76, 48, 42, 48, 64] | loss: 0.028637150302529335
epoch: 10 | step: 38 | count: [26, 0, 181, 24, 22] | prediction: [65, 36, 85, 53, 55] | 

epoch: 12 | step: 5 | count: [133, 22, 8, 13, 271] | prediction: [75, 49, 50, 53, 91] | loss: 0.09282117336988449
epoch: 12 | step: 6 | count: [24, 65, 75, 46, 28] | prediction: [52, 47, 65, 53, 53] | loss: 0.04931217432022095
epoch: 12 | step: 7 | count: [12, 2, 68, 4, 72] | prediction: [49, 49, 53, 43, 74] | loss: 0.033834073692560196
epoch: 12 | step: 8 | count: [36, 45, 163, 64, 59] | prediction: [49, 71, 69, 57, 74] | loss: 0.06198442354798317
epoch: 12 | step: 9 | count: [30, 38, 51, 5, 125] | prediction: [65, 70, 61, 51, 93] | loss: 0.046023640781641006
epoch: 12 | step: 10 | count: [96, 125, 7, 5, 6] | prediction: [81, 80, 54, 35, 43] | loss: 0.04674266651272774
epoch: 12 | step: 11 | count: [3, 0, 42, 51, 127] | prediction: [39, 48, 70, 66, 53] | loss: 0.05511917173862457
epoch: 12 | step: 12 | count: [12, 63, 85, 80, 252] | prediction: [47, 65, 56, 63, 72] | loss: 0.09754688292741776
epoch: 12 | step: 13 | count: [13, 52, 221, 28, 105] | prediction: [50, 69, 75, 53, 69] | los

epoch: 13 | step: 29 | count: [88, 24, 19, 174, 97] | prediction: [63, 46, 60, 71, 41] | loss: 0.15548564493656158
epoch: 13 | step: 30 | count: [7, 73, 10, 217, 0] | prediction: [49, 66, 55, 65, 11] | loss: 0.12029559910297394
epoch: 13 | step: 31 | count: [70, 16, 11, 1, 33] | prediction: [63, 56, 55, 33, 71] | loss: 0.03181975334882736
epoch: 13 | step: 32 | count: [106, 0, 46, 5, 18] | prediction: [77, 41, 44, 44, 53] | loss: 0.03651515021920204
epoch: 13 | step: 33 | count: [128, 41, 49, 95, 126] | prediction: [64, 50, 54, 72, 60] | loss: 0.07814543694257736
epoch: 13 | step: 34 | count: [0, 43, 35, 0, 54] | prediction: [10, 61, 54, 41, 72] | loss: 0.02748166024684906
epoch: 13 | step: 35 | count: [51, 67, 116, 27, 37] | prediction: [71, 43, 81, 54, 50] | loss: 0.057320430874824524
epoch: 13 | step: 36 | count: [9, 60, 66, 59, 125] | prediction: [43, 60, 67, 59, 74] | loss: 0.07648783922195435
epoch: 13 | step: 37 | count: [24, 24, 14, 4, 15] | prediction: [57, 65, 31, 47, 41] | l

epoch: 15 | step: 5 | count: [50, 228, 7, 38, 135] | prediction: [55, 55, 39, 54, 74] | loss: 0.10434511303901672
epoch: 15 | step: 6 | count: [35, 16, 72, 76, 402] | prediction: [53, 49, 65, 62, 91] | loss: 0.15281589329242706
epoch: 15 | step: 7 | count: [6, 184, 13, 0, 8] | prediction: [27, 76, 49, 32, 51] | loss: 0.045114144682884216
epoch: 15 | step: 8 | count: [58, 12, 3, 177, 71] | prediction: [58, 46, 37, 73, 59] | loss: 0.06992355734109879
epoch: 15 | step: 9 | count: [10, 0, 85, 62, 1] | prediction: [51, 40, 61, 52, 31] | loss: 0.043604668229818344
epoch: 15 | step: 10 | count: [61, 72, 32, 2, 55] | prediction: [59, 65, 47, 40, 69] | loss: 0.04363233596086502
epoch: 15 | step: 11 | count: [102, 49, 21, 25, 36] | prediction: [59, 52, 52, 53, 54] | loss: 0.040838830173015594
epoch: 15 | step: 12 | count: [87, 340, 1, 8, 29] | prediction: [56, 95, 32, 35, 61] | loss: 0.1441202163696289
epoch: 15 | step: 13 | count: [14, 151, 42, 30, 169] | prediction: [47, 59, 62, 51, 65] | loss

epoch: 16 | step: 29 | count: [26, 92, 39, 98, 14] | prediction: [49, 69, 62, 54, 52] | loss: 0.052337583154439926
epoch: 16 | step: 30 | count: [86, 43, 110, 68, 35] | prediction: [64, 56, 58, 71, 55] | loss: 0.08257324248552322
epoch: 16 | step: 31 | count: [31, 1, 120, 16, 8] | prediction: [53, 36, 56, 65, 51] | loss: 0.04792940616607666
epoch: 16 | step: 32 | count: [31, 6, 154, 132, 27] | prediction: [56, 40, 88, 51, 59] | loss: 0.11758802831172943
epoch: 16 | step: 33 | count: [108, 120, 13, 65, 90] | prediction: [72, 80, 57, 61, 89] | loss: 0.06737085431814194
epoch: 16 | step: 34 | count: [31, 35, 20, 1, 0] | prediction: [50, 67, 66, 54, 36] | loss: 0.020728055387735367
epoch: 16 | step: 35 | count: [20, 0, 129, 133, 32] | prediction: [59, 39, 63, 68, 46] | loss: 0.07565368711948395
epoch: 16 | step: 36 | count: [15, 0, 0, 33, 37] | prediction: [54, 47, 8, 80, 54] | loss: 0.019911479204893112
epoch: 16 | step: 37 | count: [16, 170, 37, 10, 9] | prediction: [43, 85, 61, 63, 49] 

epoch: 18 | step: 4 | count: [31, 19, 15, 20, 6] | prediction: [58, 46, 29, 57, 43] | loss: 0.0200248621404171
epoch: 18 | step: 5 | count: [66, 62, 6, 62, 410] | prediction: [74, 62, 45, 60, 88] | loss: 0.19686992466449738
epoch: 18 | step: 6 | count: [101, 44, 35, 0, 7] | prediction: [74, 68, 60, 42, 45] | loss: 0.03655681386590004
epoch: 18 | step: 7 | count: [26, 3, 92, 57, 9] | prediction: [84, 43, 74, 65, 48] | loss: 0.05685364454984665
epoch: 18 | step: 8 | count: [0, 312, 1, 19, 12] | prediction: [35, 108, 38, 48, 45] | loss: 0.07979366928339005
epoch: 18 | step: 9 | count: [35, 22, 16, 14, 26] | prediction: [52, 55, 60, 46, 62] | loss: 0.024373287335038185
epoch: 18 | step: 10 | count: [8, 56, 95, 20, 93] | prediction: [54, 63, 71, 49, 69] | loss: 0.047543007880449295
epoch: 18 | step: 11 | count: [514, 22, 6, 11, 88] | prediction: [95, 65, 34, 44, 68] | loss: 0.22632773220539093
epoch: 18 | step: 12 | count: [28, 9, 7, 5, 16] | prediction: [59, 54, 64, 49, 58] | loss: 0.01758

epoch: 19 | step: 28 | count: [40, 0, 93, 367, 22] | prediction: [48, 48, 53, 92, 53] | loss: 0.12016250193119049
epoch: 19 | step: 29 | count: [266, 130, 148, 91, 83] | prediction: [59, 63, 77, 71, 69] | loss: 0.18513107299804688
epoch: 19 | step: 30 | count: [15, 31, 22, 136, 38] | prediction: [40, 56, 49, 73, 43] | loss: 0.0464010164141655
epoch: 19 | step: 31 | count: [418, 58, 7, 54, 15] | prediction: [97, 55, 49, 52, 47] | loss: 0.1481553465127945
epoch: 19 | step: 32 | count: [36, 5, 10, 32, 264] | prediction: [60, 41, 47, 53, 101] | loss: 0.08613281697034836
epoch: 19 | step: 33 | count: [149, 19, 73, 14, 53] | prediction: [72, 40, 71, 52, 65] | loss: 0.059517525136470795
epoch: 19 | step: 34 | count: [145, 17, 4, 16, 3] | prediction: [70, 57, 46, 56, 46] | loss: 0.03522429242730141
epoch: 19 | step: 35 | count: [50, 12, 2, 123, 0] | prediction: [55, 45, 44, 83, 31] | loss: 0.04652257263660431
epoch: 19 | step: 36 | count: [3, 39, 6, 61, 163] | prediction: [42, 43, 44, 58, 69] 

epoch: 21 | step: 3 | count: [27, 36, 12, 101, 29] | prediction: [70, 58, 54, 87, 54] | loss: 0.04075601324439049
epoch: 21 | step: 4 | count: [22, 29, 0, 300, 0] | prediction: [53, 60, 26, 97, 33] | loss: 0.1558971405029297
epoch: 21 | step: 5 | count: [38, 21, 53, 6, 153] | prediction: [62, 60, 56, 50, 62] | loss: 0.0640801265835762
epoch: 21 | step: 6 | count: [13, 9, 91, 1, 51] | prediction: [50, 49, 67, 43, 58] | loss: 0.030820636078715324
epoch: 21 | step: 7 | count: [65, 79, 170, 69, 0] | prediction: [65, 65, 95, 89, 9] | loss: 0.0682143047451973
epoch: 21 | step: 8 | count: [11, 4, 117, 22, 107] | prediction: [52, 43, 68, 52, 71] | loss: 0.07579867541790009
epoch: 21 | step: 9 | count: [74, 2, 192, 202, 64] | prediction: [73, 40, 88, 97, 61] | loss: 0.11071637272834778
epoch: 21 | step: 10 | count: [170, 50, 1, 134, 6] | prediction: [51, 59, 49, 66, 56] | loss: 0.0713047981262207
epoch: 21 | step: 11 | count: [39, 29, 122, 326, 45] | prediction: [65, 58, 77, 102, 65] | loss: 0.

epoch: 22 | step: 27 | count: [22, 23, 232, 0, 0] | prediction: [55, 66, 100, 33, 62] | loss: 0.059598349034786224
epoch: 22 | step: 28 | count: [6, 51, 28, 29, 193] | prediction: [45, 65, 51, 57, 102] | loss: 0.05697768181562424
epoch: 22 | step: 29 | count: [15, 36, 4, 17, 88] | prediction: [55, 49, 54, 66, 63] | loss: 0.031984105706214905
epoch: 22 | step: 30 | count: [51, 57, 54, 35, 13] | prediction: [60, 64, 72, 50, 55] | loss: 0.038719501346349716
epoch: 22 | step: 31 | count: [39, 157, 68, 0, 50] | prediction: [62, 85, 69, 31, 63] | loss: 0.06991454213857651
epoch: 22 | step: 32 | count: [33, 27, 0, 31, 1] | prediction: [64, 68, 47, 59, 43] | loss: 0.021159693598747253
epoch: 22 | step: 33 | count: [21, 131, 156, 0, 6] | prediction: [67, 76, 71, 52, 41] | loss: 0.06073020026087761
epoch: 22 | step: 34 | count: [65, 39, 208, 8, 39] | prediction: [76, 55, 100, 44, 57] | loss: 0.06829317659139633
epoch: 22 | step: 35 | count: [15, 21, 13, 53, 42] | prediction: [48, 54, 56, 55, 55]

epoch: 24 | step: 2 | count: [0, 28, 88, 42, 41] | prediction: [41, 55, 65, 65, 69] | loss: 0.039778925478458405
epoch: 24 | step: 3 | count: [0, 22, 8, 55, 38] | prediction: [54, 61, 46, 57, 61] | loss: 0.025824060663580894
epoch: 24 | step: 4 | count: [26, 112, 28, 14, 6] | prediction: [58, 79, 52, 47, 45] | loss: 0.037577927112579346
epoch: 24 | step: 5 | count: [71, 118, 169, 19, 61] | prediction: [70, 81, 75, 50, 59] | loss: 0.08291309326887131
epoch: 24 | step: 6 | count: [7, 50, 58, 18, 63] | prediction: [49, 59, 58, 60, 68] | loss: 0.036987900733947754
epoch: 24 | step: 7 | count: [39, 4, 0, 94, 27] | prediction: [64, 54, 37, 74, 83] | loss: 0.03302765265107155
epoch: 24 | step: 8 | count: [40, 32, 15, 329, 15] | prediction: [63, 58, 53, 90, 44] | loss: 0.16688929498195648
epoch: 24 | step: 9 | count: [101, 8, 90, 4, 88] | prediction: [72, 49, 77, 49, 70] | loss: 0.05037681385874748
epoch: 24 | step: 10 | count: [36, 94, 37, 74, 48] | prediction: [59, 65, 59, 76, 59] | loss: 0.

epoch: 25 | step: 26 | count: [7, 105, 129, 10, 12] | prediction: [51, 73, 70, 50, 52] | loss: 0.04980204626917839
epoch: 25 | step: 27 | count: [6, 57, 5, 10, 36] | prediction: [42, 71, 51, 52, 61] | loss: 0.028319640085101128
epoch: 25 | step: 28 | count: [35, 29, 8, 13, 12] | prediction: [62, 53, 50, 44, 60] | loss: 0.021741697564721107
epoch: 25 | step: 29 | count: [39, 65, 41, 7, 20] | prediction: [68, 61, 60, 39, 54] | loss: 0.03363553062081337
epoch: 25 | step: 30 | count: [64, 16, 26, 24, 81] | prediction: [59, 53, 53, 52, 68] | loss: 0.038284782320261
epoch: 25 | step: 31 | count: [98, 18, 31, 90, 0] | prediction: [84, 54, 57, 61, 43] | loss: 0.04852430149912834
epoch: 25 | step: 32 | count: [6, 33, 5, 39, 223] | prediction: [48, 56, 54, 65, 104] | loss: 0.058334968984127045
epoch: 25 | step: 33 | count: [19, 10, 114, 8, 7] | prediction: [50, 46, 66, 54, 49] | loss: 0.04098635911941528
epoch: 25 | step: 34 | count: [8, 57, 124, 66, 12] | prediction: [48, 56, 76, 72, 43] | loss

#### Trainer Bayes

In [39]:
class Trainer_Bayes():
    def __init__(self, dataloader, net, loss, optimizer,  validation_frequency=1, max_epoch=100):
        self.train_loader, self.val_loader, _ = dataloader()
        self.net = net
        self.loss = loss 
        self.optimizer = optimizer
        self.best_mae = 1e20
        self.best_mse = 1e20
        self.epoch = 0
        self.validation_frequency = validation_frequency
        self.max_epoch = max_epoch
        

    def train(self):
        for epoch in range(0, self.max_epoch):
            self.epoch = epoch

            # training    
            self.train_epoch()

            # validation
            if epoch % self.validation_frequency == 0:
                self.validate()
                
        print(f'Train finished | best_mse: {self.best_mse} | best_mae: {self.best_mae}')
                

    def train_epoch(self):
        self.net.train()
        epoch_loss = 0
        
        for step, (inputs, points, targets, st_sizes) in enumerate(self.train_loader):
            inputs = inputs.to(device)
            st_sizes = st_sizes.to(device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(device) for p in points]
            targets = [t.to(device) for t in targets]
            
            with torch.set_grad_enabled(True):
                outputs = self.net(inputs)
                prob_list = post_prob(points, st_sizes)
                loss = self.loss(prob_list, targets, outputs)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                N = inputs.size(0) # batch size
                pre_count = torch.sum(outputs.view(N, -1), dim=1).detach().cpu().numpy()
                res = pre_count - gd_count
                epoch_loss += float(loss)
                
                print(f'epoch: {self.epoch} | step: {step} | gd_count: {gd_count} | prediction: {pre_count} | loss: {loss}')
       
        writer.add_scalar('train loss Bayes',
            epoch_loss,
            self.epoch)

    def validate(self):
        epoch_start = time.time()
        self.net.eval()  # Set model to evaluate mode
        epoch_res = []

        # Iterate over data.
        for inputs, count, name in self.val_loader:
            inputs = inputs.to(device)
            # inputs are images with different sizes
            assert inputs.size(0) == 1 # 'the batch size should equal to 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs = self.net(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)


        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))

        # ...log the running loss
        writer.add_scalar('val MAE Bayes',
                            mae,
                            self.epoch)
        writer.add_scalar('val MSE Bayes',
                        mse,
                        self.epoch)

        print('Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                     .format(self.epoch, mse, mae, time.time()-epoch_start))

        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            print("save best mse {:.2f} mae {:.2f} model epoch {}".format(self.best_mse,
                                                                            self.best_mae,
                                                                                 self.epoch))
            torch.save(self.net.state_dict(), os.path.join(save_dir, 'best_model_bayes.pth'))

In [40]:
lr = 0.00001
sigma = 0.1
use_background = False
background_ratio = 1

bayes_net = CSRNet().to(device)
optimizer = optim.Adam(bayes_net.parameters(), lr=lr)

post_prob = Post_Prob(sigma,
                           crop_size,
                           downsample_ratio,
                           background_ratio,
                           use_background,
                           device)
loss = Bay_Loss(use_background, device)

bayes_trainer = Trainer_Bayes(loading_data_Bayes, bayes_net, loss, optimizer, max_epoch=50)
bayes_trainer.train()

epoch: 0 | step: 0 | gd_count: [764.  35. 121.  55.  41.] | prediction: [ 1.1734737  -0.5358329  -2.1619778   0.8747963   0.83166647] | loss: 195.96067810058594
epoch: 0 | step: 1 | gd_count: [ 46.  60.  24.  23. 107.] | prediction: [30.892334 25.346998 38.396698 21.787317 26.02096 ] | loss: 34.24599075317383
epoch: 0 | step: 2 | gd_count: [ 72. 121.  12.  15. 194.] | prediction: [33.293617 49.386833 42.65613  42.45251  45.40686 ] | loss: 71.55133056640625
epoch: 0 | step: 3 | gd_count: [75. 27. 22. 34. 33.] | prediction: [54.414    60.753166 59.412373 49.57749  55.834785] | loss: 31.879133224487305
epoch: 0 | step: 4 | gd_count: [ 49.  22.  64. 217.  79.] | prediction: [45.937218 28.814945 40.03018  42.442444 47.0007  ] | loss: 50.07511520385742
epoch: 0 | step: 5 | gd_count: [247.  95.  82.  69.   1.] | prediction: [27.493107 32.38881  43.11589  36.847103 26.54713 ] | loss: 76.5000991821289
epoch: 0 | step: 6 | gd_count: [143.  52.  33. 113. 170.] | prediction: [32.694454 32.587234 3

epoch: 1 | step: 6 | gd_count: [224.  46. 182.  60.  17.] | prediction: [126.595604  58.205795 155.24347   77.96738   23.088873] | loss: 39.5947265625
epoch: 1 | step: 7 | gd_count: [496. 190.  39.  47. 295.] | prediction: [ 74.255264  78.59016   33.871853  56.757874 186.10313 ] | loss: 133.2845001220703
epoch: 1 | step: 8 | gd_count: [142.   7.  47.  34.  54.] | prediction: [125.91382     6.7408504  46.88152    32.856987   66.66794  ] | loss: 21.045455932617188
epoch: 1 | step: 9 | gd_count: [254.   9.  83.   4. 300.] | prediction: [142.62837     5.8829107 102.90018     2.9487004 240.65588  ] | loss: 49.03963851928711
epoch: 1 | step: 10 | gd_count: [166.  61. 245.  69.  16.] | prediction: [69.87828  25.91167  81.17125  20.37563   7.554511] | loss: 70.96468353271484
epoch: 1 | step: 11 | gd_count: [210.  76. 239.  59. 120.] | prediction: [ 51.785645  60.7522   134.7822    32.487625 101.76242 ] | loss: 66.97354888916016
epoch: 1 | step: 12 | gd_count: [ 77.  18.  18. 128.  52.] | predi

epoch: 2 | step: 10 | gd_count: [129.  53.  35.  32.   5.] | prediction: [81.909874 31.894379 45.75521  43.799004 11.534307] | loss: 22.69855308532715
epoch: 2 | step: 11 | gd_count: [  3.  18.  33.  97. 118.] | prediction: [  6.7417803  19.756748   33.16746    88.46847   106.46133  ] | loss: 16.880205154418945
epoch: 2 | step: 12 | gd_count: [15. 63.  3. 19. 86.] | prediction: [ 9.57955   50.385887   2.7935371 18.87825   89.87189  ] | loss: 14.299142837524414
epoch: 2 | step: 13 | gd_count: [ 21. 202.  66.  53.  57.] | prediction: [ 26.40318  148.8189    34.932262  38.110413  27.678074] | loss: 26.815332412719727
epoch: 2 | step: 14 | gd_count: [ 19.   4.  53. 192.   4.] | prediction: [ 18.715279    7.1553926  53.43641   104.874985    1.6479177] | loss: 23.211610794067383
epoch: 2 | step: 15 | gd_count: [131.  49.  38.  25. 283.] | prediction: [ 81.42499   19.083776  21.6892     9.817995 269.53445 ] | loss: 36.66590118408203
epoch: 2 | step: 16 | gd_count: [ 34.  32.  24. 270. 156.] |

epoch: 3 | step: 15 | gd_count: [ 90.  39. 201.  25.  34.] | prediction: [42.46851  22.81834  97.86815   9.874573 22.816462] | loss: 33.63671875
epoch: 3 | step: 16 | gd_count: [172.  11.   0.  21.  62.] | prediction: [47.961506 10.592445 13.482571 14.606262 55.67321 ] | loss: 33.142024993896484
epoch: 3 | step: 17 | gd_count: [58. 32. 38.  1. 80.] | prediction: [55.991917 24.378296 32.564877  7.662466 40.954285] | loss: 15.885504722595215
epoch: 3 | step: 18 | gd_count: [36.  0. 40. 54. 39.] | prediction: [33.962883  5.096125 38.595676 26.98198  39.15174 ] | loss: 13.570784568786621
epoch: 3 | step: 19 | gd_count: [15. 50. 47. 64. 69.] | prediction: [16.835644 57.327473 56.256527 70.040955 75.5211  ] | loss: 20.264509201049805
epoch: 3 | step: 20 | gd_count: [238.  19. 131.  31.  59.] | prediction: [101.20888   15.154784  81.00964   31.806545  51.720516] | loss: 43.66622543334961
epoch: 3 | step: 21 | gd_count: [ 68.  98. 142.  23.  21.] | prediction: [ 41.37961  112.50323   82.65572 

epoch: 4 | step: 20 | gd_count: [ 70.  86.  34. 208.  18.] | prediction: [ 34.14246   71.51673   36.140625 155.66164   16.549746] | loss: 31.723691940307617
epoch: 4 | step: 21 | gd_count: [ 68. 139.  64.  13.  42.] | prediction: [ 54.373363 115.21193   50.45674   19.552317  25.45932 ] | loss: 20.823532104492188
epoch: 4 | step: 22 | gd_count: [ 83.  44. 152.   0. 114.] | prediction: [58.087433  36.921623  86.85936   -1.4357965 78.23598  ] | loss: 29.713117599487305
epoch: 4 | step: 23 | gd_count: [ 26.  20.  16.  17. 134.] | prediction: [ 14.610017   9.311903 153.45158    8.082491 136.5534  ] | loss: 41.78010177612305
epoch: 4 | step: 24 | gd_count: [ 35.   6.  79.  12. 414.] | prediction: [ 24.112793    7.5841756  47.01381     2.7801871 285.01257  ] | loss: 37.77671432495117
epoch: 4 | step: 25 | gd_count: [ 23.  53. 129.   3.  32.] | prediction: [ 14.369108   49.384735  120.3902      3.4530272  25.286526 ] | loss: 11.095305442810059
epoch: 4 | step: 26 | gd_count: [ 16.  21. 109.  2

epoch: 5 | step: 25 | gd_count: [  5.   0. 115.   0.   4.] | prediction: [  8.8522835   2.6001945 128.84302     3.126608    6.4346066] | loss: 11.361165046691895
epoch: 5 | step: 26 | gd_count: [ 0. 61. 96. 39. 28.] | prediction: [ 1.5402966 56.49762   72.50682   41.67343   20.55759  ] | loss: 12.478275299072266
epoch: 5 | step: 27 | gd_count: [167. 102.   8.  49.   3.] | prediction: [142.1933     72.50534     8.405947   22.497746    0.9549729] | loss: 26.802616119384766
epoch: 5 | step: 28 | gd_count: [79. 85. 36.  7. 13.] | prediction: [94.341064 61.574665 29.083176  8.753103 10.477627] | loss: 14.48108196258545
epoch: 5 | step: 29 | gd_count: [117.  33.  12. 106. 101.] | prediction: [165.61913    19.985674    9.7449665  75.44882    63.34909  ] | loss: 31.157245635986328
epoch: 5 | step: 30 | gd_count: [24. 38. 75. 57. 27.] | prediction: [17.52157  41.794895 50.671196 44.29342  14.486773] | loss: 14.641341209411621
epoch: 5 | step: 31 | gd_count: [ 77.  71.  51.  76. 141.] | predicti

epoch: 6 | step: 30 | gd_count: [143.  60.  72.  12.   2.] | prediction: [104.15965    44.401474   60.4842     10.738521    6.1708593] | loss: 17.751285552978516
epoch: 6 | step: 31 | gd_count: [432.  11. 192.  40. 337.] | prediction: [115.5075    12.692657 137.63025   33.844086 154.8566  ] | loss: 111.13063049316406
epoch: 6 | step: 32 | gd_count: [ 12.  41.  80.  16. 142.] | prediction: [  8.021019  30.853287  51.95163   25.142698 116.579636] | loss: 21.398923873901367
epoch: 6 | step: 33 | gd_count: [ 29. 132.  41.  31.   9.] | prediction: [ 31.943308 101.4279    41.98281   36.140026   9.849774] | loss: 14.035514831542969
epoch: 6 | step: 34 | gd_count: [24. 62. 38.  5. 66.] | prediction: [38.418068 63.335686 36.18016   8.955727 75.73056 ] | loss: 13.833935737609863
epoch: 6 | step: 35 | gd_count: [11.  7. 34. 21. 13.] | prediction: [13.040897  6.516342 36.258354 17.378567 14.697576] | loss: 5.023139953613281
epoch: 6 | step: 36 | gd_count: [ 28. 309.  83.   1.  91.] | prediction: [

epoch: 7 | step: 35 | gd_count: [ 34. 450. 160.   0. 110.] | prediction: [ 46.537163  404.11115    49.525276    2.3113713  98.02929  ] | loss: 62.68502426147461
epoch: 7 | step: 36 | gd_count: [25.  6.  0. 69. 56.] | prediction: [14.381216   7.151761   2.5818849 75.17552   49.621254 ] | loss: 13.936271667480469
epoch: 7 | step: 37 | gd_count: [232.  19. 146. 102.  15.] | prediction: [168.88316    8.453027 160.64645   60.611     14.091144] | loss: 34.65073776245117
epoch: 7 | step: 38 | gd_count: [23. 71. 49. 64. 11.] | prediction: [ 8.543174 53.63323  29.28818  44.14187   7.24712 ] | loss: 14.475946426391602
epoch: 7 | step: 39 | gd_count: [232.  52.  15. 263.  39.] | prediction: [213.7023    29.625107  12.298586 124.77117   25.625645] | loss: 46.555824279785156
epoch: 7 | step: 40 | gd_count: [72.  4. 39.  3. 81.] | prediction: [33.78591    1.2209344 27.846348   1.0934583 59.35771  ] | loss: 13.540634155273438
epoch: 7 | step: 41 | gd_count: [16. 45. 43.  8.  9.] | prediction: [10.322

epoch: 8 | step: 40 | gd_count: [ 79. 115.  55.  58. 116.] | prediction: [ 51.95101   89.909935  50.29191   59.11248  117.090225] | loss: 21.101909637451172
epoch: 8 | step: 41 | gd_count: [13. 60. 41. 65. 84.] | prediction: [10.713011 40.367535 19.719452 57.850704 90.71675 ] | loss: 14.869235038757324
epoch: 8 | step: 42 | gd_count: [17. 25. 98. 42.  4.] | prediction: [13.403273 22.212866 61.902    17.130062  2.73224 ] | loss: 13.392512321472168
epoch: 8 | step: 43 | gd_count: [19. 20. 96. 29. 10.] | prediction: [11.81189    9.448845  86.174934  22.55803    4.2325797] | loss: 9.8380126953125
epoch: 8 | step: 44 | gd_count: [ 53.  25.   0. 241.   4.] | prediction: [ 59.057228    25.376625    -0.62382674 171.07529      6.7968225 ] | loss: 19.59474754333496
epoch: 8 | step: 45 | gd_count: [  6. 109.  74.  38.   0.] | prediction: [  1.8425343  56.934     103.981445   30.171877    7.4803524] | loss: 22.27608299255371
epoch: 8 | step: 46 | gd_count: [ 39. 101.  53.   5.  47.] | prediction: 

epoch: 9 | step: 45 | gd_count: [ 46.  57. 266. 118. 274.] | prediction: [ 53.97571   65.0277   225.04785  125.418304 328.8323  ] | loss: 58.57196044921875
epoch: 9 | step: 46 | gd_count: [ 42.  58.  63. 605.  93.] | prediction: [ 44.99035   62.497425  66.96901  483.32547   92.985374] | loss: 66.33150482177734
epoch: 9 | step: 47 | gd_count: [111. 125. 112.  29.  51.] | prediction: [137.47461  109.26146   95.89746   15.537491  57.947754] | loss: 28.740766525268555
Epoch 9 Val, MSE: 281.28 MAE: 139.29, Cost 8.5 sec
epoch: 10 | step: 0 | gd_count: [31. 24. 78.  4. 17.] | prediction: [37.720642 20.970337 60.83126   5.4887   19.928604] | loss: 9.188592910766602
epoch: 10 | step: 1 | gd_count: [ 31.   0.  85.  28. 163.] | prediction: [27.07143    1.5272303 60.781494  26.405579  80.54994  ] | loss: 24.799936294555664
epoch: 10 | step: 2 | gd_count: [102.  17.   7.  40.  39.] | prediction: [63.996086  14.603846   6.6386547 30.804806  24.360699 ] | loss: 11.938624382019043
epoch: 10 | step: 3 

epoch: 11 | step: 1 | gd_count: [39. 24. 37. 46. 62.] | prediction: [27.7998   22.229454 37.21581  47.375237 63.76584 ] | loss: 12.048484802246094
epoch: 11 | step: 2 | gd_count: [295.  13.  11. 115.  37.] | prediction: [321.42163     7.927743    7.4883065 231.018      24.403955 ] | loss: 47.22829055786133
epoch: 11 | step: 3 | gd_count: [15.  0. 77.  0. 38.] | prediction: [ 7.188755   -0.5902641  75.24182     0.21343674 28.523056  ] | loss: 8.153709411621094
epoch: 11 | step: 4 | gd_count: [  8. 223. 100.  73.   5.] | prediction: [  2.4773138 230.13213    68.6175     48.94032     2.8429577] | loss: 24.03530502319336
epoch: 11 | step: 5 | gd_count: [ 8. 11. 10. 36. 77.] | prediction: [ 6.9583735  9.1481     3.3476825 20.693737  57.527096 ] | loss: 8.58241081237793
epoch: 11 | step: 6 | gd_count: [ 3. 73.  3. 52. 15.] | prediction: [ 0.5170214 91.84959    1.9374294 30.177177  15.289881 ] | loss: 11.331661224365234
epoch: 11 | step: 7 | gd_count: [ 45.  95. 149.  97. 228.] | prediction: 

epoch: 12 | step: 5 | gd_count: [10. 21. 89. 50. 58.] | prediction: [13.391606 23.425331 85.23958  49.821033 58.309624] | loss: 13.322183609008789
epoch: 12 | step: 6 | gd_count: [126.  70.  51.  34.  26.] | prediction: [139.34396   64.806946  64.29504   40.6699    34.886982] | loss: 26.334171295166016
epoch: 12 | step: 7 | gd_count: [83. 48. 51. 65.  9.] | prediction: [83.35133  33.465668 37.019257 42.650127  9.19079 ] | loss: 13.516563415527344
epoch: 12 | step: 8 | gd_count: [ 11.  17. 259. 118. 132.] | prediction: [ 11.632462  18.994448 203.38663   89.71688   86.969284] | loss: 29.450571060180664
epoch: 12 | step: 9 | gd_count: [382.  36. 117.  43. 126.] | prediction: [233.34767   29.67635   75.462906  31.42393  101.55441 ] | loss: 50.90731430053711
epoch: 12 | step: 10 | gd_count: [ 47.  56.  11.  83. 103.] | prediction: [ 35.626984   52.08575     7.1089306  70.82733   133.79333  ] | loss: 20.234512329101562
epoch: 12 | step: 11 | gd_count: [ 72.  22. 276.   9. 260.] | prediction:

epoch: 13 | step: 10 | gd_count: [46. 78. 21. 61.  2.] | prediction: [45.683914 57.883392 15.132906 60.005028  5.077318] | loss: 12.516983985900879
epoch: 13 | step: 11 | gd_count: [839.  25.  76.  55.  15.] | prediction: [520.0475    22.966188  69.83238   57.07938   11.096958] | loss: 88.4639663696289
epoch: 13 | step: 12 | gd_count: [ 26. 110. 473.  53. 186.] | prediction: [ 25.745157 102.103806 217.93816   48.55542  214.76413 ] | loss: 73.99052429199219
epoch: 13 | step: 13 | gd_count: [ 78.   9.   0. 115. 292.] | prediction: [5.3497429e+01 1.0053075e+01 9.0545408e-02 8.3615219e+01 3.5570508e+02] | loss: 38.4477653503418
epoch: 13 | step: 14 | gd_count: [ 62.  31. 134. 203. 178.] | prediction: [ 61.2229    29.957518  98.339874 130.10889  169.70865 ] | loss: 37.197357177734375
epoch: 13 | step: 15 | gd_count: [ 40.  11.  10.  30. 114.] | prediction: [41.95079    6.8663225  4.678816  23.073416  85.741196 ] | loss: 11.779826164245605
epoch: 13 | step: 16 | gd_count: [ 46. 158.   6. 144

epoch: 14 | step: 14 | gd_count: [127. 224.   9.  75.  21.] | prediction: [ 99.605194  192.86809     6.4824557  93.99048    25.834457 ] | loss: 31.1446533203125
epoch: 14 | step: 15 | gd_count: [ 75. 338.  12.  40.  17.] | prediction: [ 72.59004  262.28525    8.345288  36.341232  24.884544] | loss: 28.507719039916992
epoch: 14 | step: 16 | gd_count: [  8.  65.  42. 101. 580.] | prediction: [  2.6410222  87.01695    48.587753  109.68813   392.2726   ] | loss: 66.47216796875
epoch: 14 | step: 17 | gd_count: [ 57. 114.  34. 153.  45.] | prediction: [140.64886  108.674255  40.53659  132.15222   37.078724] | loss: 37.287052154541016
epoch: 14 | step: 18 | gd_count: [117. 127.  71.  52. 116.] | prediction: [101.545105 167.20558   65.05738   59.696613 143.61134 ] | loss: 35.1878662109375
epoch: 14 | step: 19 | gd_count: [28. 66. 16. 19. 44.] | prediction: [28.546358 67.829    13.095875 17.319786 45.096916] | loss: 10.007534980773926
epoch: 14 | step: 20 | gd_count: [199.   7.  34.  34.  48.] 

epoch: 15 | step: 19 | gd_count: [67. 61.  8. 61.  9.] | prediction: [51.12573  56.741837  6.411185 40.65114  12.506857] | loss: 11.740550994873047
epoch: 15 | step: 20 | gd_count: [174.  29. 114. 442.  39.] | prediction: [133.68503   22.914078 178.46884  168.38329   35.03045 ] | loss: 84.98029327392578
epoch: 15 | step: 21 | gd_count: [139.  73. 209. 124.  29.] | prediction: [118.6142   66.84443 145.97168 116.43622  31.77066] | loss: 35.921234130859375
epoch: 15 | step: 22 | gd_count: [ 65.  34.   4.  56. 170.] | prediction: [ 56.10828    47.054607    7.7870646  38.376404  207.10246  ] | loss: 26.112884521484375
epoch: 15 | step: 23 | gd_count: [ 29.  56.  60. 420.  15.] | prediction: [ 34.688164  53.612953  81.47001  231.80888   17.661882] | loss: 51.3578987121582
epoch: 15 | step: 24 | gd_count: [ 19.  82. 188.  47.  86.] | prediction: [ 29.3222    94.291336 202.19891   49.606987  55.03015 ] | loss: 30.752826690673828
epoch: 15 | step: 25 | gd_count: [ 67.   6.  42. 199. 137.] | pre

epoch: 16 | step: 24 | gd_count: [ 60.  22. 198. 119. 122.] | prediction: [ 49.490047  15.145399 187.75766   97.69223  119.438324] | loss: 28.64165687561035
epoch: 16 | step: 25 | gd_count: [ 94.   7.  29.   7. 128.] | prediction: [ 91.73491     3.5561404  23.6637      3.5885775 102.120026 ] | loss: 15.37863826751709
epoch: 16 | step: 26 | gd_count: [ 34. 138.   6.  21.   0.] | prediction: [ 13.506582  104.55664     0.5243222   7.55394     0.6558585] | loss: 13.142003059387207
epoch: 16 | step: 27 | gd_count: [ 50.  82.   4.   0. 170.] | prediction: [ 24.042442    74.45981      2.2751741   -0.50042415 123.034424  ] | loss: 21.243711471557617
epoch: 16 | step: 28 | gd_count: [592.  89. 201. 209. 136.] | prediction: [368.69397   71.82326  130.50075   88.483536  90.42105 ] | loss: 91.57911682128906
epoch: 16 | step: 29 | gd_count: [10. 45. 22. 33. 17.] | prediction: [ 6.0456767 30.95521   12.343114  27.924967  12.85045  ] | loss: 6.123208522796631
epoch: 16 | step: 30 | gd_count: [107.  5

epoch: 17 | step: 28 | gd_count: [259.  54.  23. 150.  46.] | prediction: [223.29565   51.400284  21.0017   117.38532   74.6996  ] | loss: 32.7799072265625
epoch: 17 | step: 29 | gd_count: [118. 188.  56. 212. 159.] | prediction: [ 98.226776 194.45111   39.74209  139.89166  144.47913 ] | loss: 38.99576950073242
epoch: 17 | step: 30 | gd_count: [48. 19. 31. 53. 78.] | prediction: [50.003708 22.846977 28.705318 56.315994 89.36804 ] | loss: 15.4281587600708
epoch: 17 | step: 31 | gd_count: [ 33.  37.  40. 142.  16.] | prediction: [35.17833  39.024254 47.01889  80.81822  14.401876] | loss: 18.774028778076172
epoch: 17 | step: 32 | gd_count: [18. 33. 25. 21. 29.] | prediction: [19.160347 40.470554 19.282988 25.45287  29.954632] | loss: 8.047608375549316
epoch: 17 | step: 33 | gd_count: [309.  97.  38.  20.   7.] | prediction: [256.22522   61.252747  35.18735   23.542988   9.448749] | loss: 26.65614128112793
epoch: 17 | step: 34 | gd_count: [37. 78. 88. 20. 54.] | prediction: [29.743649 75.4

epoch: 18 | step: 33 | gd_count: [ 64.  74.  28.   8. 200.] | prediction: [ 44.906418   41.17031    29.557213    2.6882887 149.43018  ] | loss: 24.610668182373047
epoch: 18 | step: 34 | gd_count: [ 27.  16.  32. 109.  43.] | prediction: [23.86883  16.962248 28.94026  65.44157  38.165535] | loss: 13.670987129211426
epoch: 18 | step: 35 | gd_count: [54. 38. 49. 14. 76.] | prediction: [58.852203 38.546616 44.981827 14.945168 70.848175] | loss: 16.149681091308594
epoch: 18 | step: 36 | gd_count: [12.  5. 39. 60.  0.] | prediction: [ 9.129662   3.9469562 40.213303  46.235275  -0.5879103] | loss: 5.402559280395508
epoch: 18 | step: 37 | gd_count: [ 9. 63. 65. 35. 43.] | prediction: [ 4.5163364 49.96476   60.997658  26.338318  35.86268  ] | loss: 9.425804138183594
epoch: 18 | step: 38 | gd_count: [151.  11.   5.  44.  58.] | prediction: [113.8079      7.303581    1.1134715  23.108616   42.007957 ] | loss: 15.966293334960938
epoch: 18 | step: 39 | gd_count: [12. 64. 44. 27. 12.] | prediction: 

epoch: 19 | step: 38 | gd_count: [  9.  32. 110.  39.  15.] | prediction: [ 4.372707 12.59095  77.60701  27.31486  12.97106 ] | loss: 13.880915641784668
epoch: 19 | step: 39 | gd_count: [246.  42.  67.   5.  61.] | prediction: [2.4682607e+02 3.6078999e+01 4.6535511e+01 1.6605595e-01 2.9692894e+01] | loss: 23.7128849029541
epoch: 19 | step: 40 | gd_count: [  6. 177. 125.  92.  63.] | prediction: [  1.4366765 103.520424  152.9111     77.78824    49.34861  ] | loss: 35.6036262512207
epoch: 19 | step: 41 | gd_count: [ 67. 122.  10.   3. 194.] | prediction: [ 59.841232    88.5901       5.8448124    0.33510774 155.40086   ] | loss: 24.07880401611328
epoch: 19 | step: 42 | gd_count: [32. 53.  3.  8. 70.] | prediction: [24.978119  39.30371   -1.934061   5.3978896 47.196648 ] | loss: 9.257918357849121
epoch: 19 | step: 43 | gd_count: [ 26.  23. 197. 302.  38.] | prediction: [ 21.142242  12.049675 106.14493  182.28831   30.343395] | loss: 50.65007781982422
epoch: 19 | step: 44 | gd_count: [145. 

epoch: 20 | step: 42 | gd_count: [ 15. 118.  98.  15.  56.] | prediction: [ 14.971062  118.7424     39.18715     6.7801867  59.663628 ] | loss: 22.193580627441406
epoch: 20 | step: 43 | gd_count: [109.  30. 127.  12. 179.] | prediction: [109.47516   26.111582  97.04195   11.076023 132.20695 ] | loss: 28.663360595703125
epoch: 20 | step: 44 | gd_count: [ 19. 140. 127.  57. 137.] | prediction: [ 13.898226 126.79443  131.3329    57.24135   97.41777 ] | loss: 25.675521850585938
epoch: 20 | step: 45 | gd_count: [124.  95.  64.  26.  49.] | prediction: [121.49026   65.42034   78.35327   20.512447  44.143005] | loss: 21.506277084350586
epoch: 20 | step: 46 | gd_count: [ 36.  64. 105.  83.  17.] | prediction: [16.96409  70.74466  87.493256 59.851463 14.746305] | loss: 16.07781982421875
epoch: 20 | step: 47 | gd_count: [ 41.   4.  19. 506.  41.] | prediction: [ 34.42952    6.966336  16.74018  296.0108    40.5566  ] | loss: 50.41986846923828
Epoch 20 Val, MSE: 268.94 MAE: 133.82, Cost 8.5 sec
ep

epoch: 21 | step: 47 | gd_count: [18. 38.  3. 78. 10.] | prediction: [15.909943 22.298225  8.876205 38.73013   8.217007] | loss: 11.197053909301758
Epoch 21 Val, MSE: 220.42 MAE: 117.19, Cost 8.4 sec
epoch: 22 | step: 0 | gd_count: [ 23.  58.  56.  11. 131.] | prediction: [ 23.153786  50.068645  62.00824    8.638318 164.35559 ] | loss: 19.64859390258789
epoch: 22 | step: 1 | gd_count: [102.  51.  40.   7.  40.] | prediction: [77.04428   40.60285   33.268074   9.0960045 27.944328 ] | loss: 13.873321533203125
epoch: 22 | step: 2 | gd_count: [ 22.  34.   8. 211.  11.] | prediction: [ 15.652381  27.443138   9.82733  148.97382   13.472164] | loss: 16.389217376708984
epoch: 22 | step: 3 | gd_count: [ 16.  34. 103.  91. 377.] | prediction: [ 16.252201  27.419994 103.55408   70.47656  305.99246 ] | loss: 31.45124626159668
epoch: 22 | step: 4 | gd_count: [67.  2.  4. 85. 44.] | prediction: [75.2277     9.1735325 14.940752  72.48741   36.044037 ] | loss: 15.831130981445312
epoch: 22 | step: 5 | 

epoch: 23 | step: 3 | gd_count: [ 27. 160.  71.   0. 127.] | prediction: [ 11.969589  127.29126    59.12942     4.1302786 136.05157  ] | loss: 21.573637008666992
epoch: 23 | step: 4 | gd_count: [65. 13. 59. 23.  0.] | prediction: [54.316883  12.139849  54.63555   13.31029    1.4185176] | loss: 8.635583877563477
epoch: 23 | step: 5 | gd_count: [ 78.   6.  71.  32. 433.] | prediction: [ 64.73157     2.7284777  38.268433   20.22221   392.05157  ] | loss: 36.93195343017578
epoch: 23 | step: 6 | gd_count: [53. 19. 29.  4. 35.] | prediction: [40.10605  10.476978 21.655525  2.49341  31.492065] | loss: 6.18536901473999
epoch: 23 | step: 7 | gd_count: [220.  30.  84.  19.  72.] | prediction: [133.07333   25.353436  63.574684  17.51389   55.60781 ] | loss: 26.617231369018555
epoch: 23 | step: 8 | gd_count: [ 73. 159.  53.  98.  38.] | prediction: [55.00485  93.980286 34.77858  83.10739  25.821243] | loss: 28.72047996520996
epoch: 23 | step: 9 | gd_count: [135. 123.  73.  15.  38.] | prediction: 

epoch: 24 | step: 7 | gd_count: [15. 61. 18. 86. 37.] | prediction: [16.236221 40.931828 17.652267 46.393448 39.35756 ] | loss: 15.481358528137207
epoch: 24 | step: 8 | gd_count: [ 4. 53. 35. 47. 47.] | prediction: [ 7.4325523 32.923325  42.404015  23.534918  50.19882  ] | loss: 15.370747566223145
epoch: 24 | step: 9 | gd_count: [ 81.  18.  24. 266.  20.] | prediction: [ 86.30215   15.604904  26.373169 283.1963    23.1953  ] | loss: 24.93033218383789
epoch: 24 | step: 10 | gd_count: [ 60. 136. 202. 129.  12.] | prediction: [ 53.914543 138.77591  152.62735  102.06403    8.594066] | loss: 31.337018966674805
epoch: 24 | step: 11 | gd_count: [72. 52. 12.  1. 50.] | prediction: [63.888794  66.119286   9.140617   1.5809476 51.045364 ] | loss: 12.1945161819458
epoch: 24 | step: 12 | gd_count: [ 15. 134. 133.  17.  95.] | prediction: [ 26.135994 125.37422  123.16231   12.843283  79.006065] | loss: 22.48685073852539
epoch: 24 | step: 13 | gd_count: [ 4. 30. 31.  7.  3.] | prediction: [ 2.008330

epoch: 25 | step: 12 | gd_count: [ 19.  32.   0.  97. 227.] | prediction: [  6.7872      19.673672    -0.17847744  61.621758   150.97302   ] | loss: 27.328214645385742
epoch: 25 | step: 13 | gd_count: [ 75. 175.  29. 150.  10.] | prediction: [ 62.26751  151.5881    18.610924  87.60306    6.584139] | loss: 25.510343551635742
epoch: 25 | step: 14 | gd_count: [ 92.   9.   2.  16. 117.] | prediction: [ 74.810745    2.5346212   0.5340011   8.543373  106.73608  ] | loss: 14.562960624694824
epoch: 25 | step: 15 | gd_count: [  0. 141. 235.  76.  80.] | prediction: [  1.4417207 113.7953    238.73651    43.137352   65.318344 ] | loss: 31.23407554626465
epoch: 25 | step: 16 | gd_count: [  7. 202.  60. 162.  71.] | prediction: [  3.5995016 113.93883    42.10028   120.86751    77.6602   ] | loss: 38.79234313964844
epoch: 25 | step: 17 | gd_count: [ 14. 815.   1.  25.  31.] | prediction: [ 11.834541  472.44452     2.8281417  23.567451   35.601646 ] | loss: 76.99822235107422
epoch: 25 | step: 18 | gd

epoch: 26 | step: 16 | gd_count: [75.  0. 53. 14. 39.] | prediction: [58.31701    0.2488674 60.66291   13.080491  56.91922  ] | loss: 12.512436866760254
epoch: 26 | step: 17 | gd_count: [ 83.  15. 120.  73.  59.] | prediction: [ 69.60817  12.02169 131.67535  83.10365  68.97107] | loss: 21.2098445892334
epoch: 26 | step: 18 | gd_count: [ 18.  55.  85.  22. 216.] | prediction: [ 16.530193  47.41913   79.742256  19.870104 169.08438 ] | loss: 20.24386978149414
epoch: 26 | step: 19 | gd_count: [244.  91.  40.  48.   8.] | prediction: [181.77983    69.26013    56.65941    44.071358    5.7765665] | loss: 26.183795928955078
epoch: 26 | step: 20 | gd_count: [ 8. 56. 60. 24. 52.] | prediction: [ 1.0734625 41.469368  51.052113  15.362431  52.35468  ] | loss: 11.119131088256836
epoch: 26 | step: 21 | gd_count: [38. 82. 23. 67. 49.] | prediction: [ 19.552393  113.9975      5.5124903  54.856262   40.281574 ] | loss: 23.38399314880371
epoch: 26 | step: 22 | gd_count: [139.   8.  59.  30.  23.] | pred

epoch: 27 | step: 21 | gd_count: [  0. 107.  90. 278.  21.] | prediction: [  1.5670087  50.81207    62.717728  214.91907    22.536552 ] | loss: 34.76363754272461
epoch: 27 | step: 22 | gd_count: [  1.  24.  67. 108. 173.] | prediction: [  6.5843854  21.97416    71.013885   99.57764   152.03821  ] | loss: 20.933134078979492
epoch: 27 | step: 23 | gd_count: [ 18.  37. 107.  19.  60.] | prediction: [18.389427 24.940315 91.53895  28.487785 56.379246] | loss: 14.018144607543945
epoch: 27 | step: 24 | gd_count: [ 70.  38.  12.   3. 107.] | prediction: [ 72.646706   38.22084    18.378277    2.6761947 133.1838   ] | loss: 15.905435562133789
epoch: 27 | step: 25 | gd_count: [131.  83.  32.  19. 247.] | prediction: [ 82.9414    93.0379    34.620277  22.705343 316.80682 ] | loss: 42.71382522583008
epoch: 27 | step: 26 | gd_count: [ 92.   9. 119. 149. 269.] | prediction: [ 80.53283   12.597582  74.09837  136.46962  253.36667 ] | loss: 32.610939025878906
epoch: 27 | step: 27 | gd_count: [  6.  35. 

epoch: 28 | step: 25 | gd_count: [159.  18. 139.  35.  39.] | prediction: [95.69234  10.401699 92.67594  22.138859 37.330692] | loss: 27.36228370666504
epoch: 28 | step: 26 | gd_count: [87.  6. 61. 70. 29.] | prediction: [56.830917  4.636548 72.925934 58.034615 30.763056] | loss: 14.330798149108887
epoch: 28 | step: 27 | gd_count: [  6.  40. 145.  35. 210.] | prediction: [  3.2862372   8.297009   98.36218    21.488672  155.65921  ] | loss: 33.329925537109375
epoch: 28 | step: 28 | gd_count: [ 23. 297.   7.  28.  34.] | prediction: [ 16.327568  176.80507     1.1634486  21.283384   27.973488 ] | loss: 28.474523544311523
epoch: 28 | step: 29 | gd_count: [ 1. 38. 12.  7. 53.] | prediction: [ 1.5714293 33.32019   10.363661   4.925205  56.753334 ] | loss: 6.088871002197266
epoch: 28 | step: 30 | gd_count: [ 17. 252.  68.  25. 469.] | prediction: [ 12.612631 290.2813    48.36554   19.237352 238.42543 ] | loss: 71.12267303466797
epoch: 28 | step: 31 | gd_count: [ 16. 214.  91.  60.  11.] | pre

epoch: 29 | step: 29 | gd_count: [ 0. 83. 42. 19. 77.] | prediction: [ 2.0561223 82.223145  32.71946   16.485634  62.738457 ] | loss: 11.230334281921387
epoch: 29 | step: 30 | gd_count: [  9. 155.  61.  60. 100.] | prediction: [  6.3925476 114.54741    48.37018    52.978386   76.365005 ] | loss: 20.175527572631836
epoch: 29 | step: 31 | gd_count: [111.  57.  28. 125.  18.] | prediction: [ 93.05729   41.614365  23.548904 125.73505   17.846495] | loss: 18.061077117919922
epoch: 29 | step: 32 | gd_count: [271.  12.  62.  13.  12.] | prediction: [271.73907   10.890682  60.3521    20.22252   12.447409] | loss: 22.03272819519043
epoch: 29 | step: 33 | gd_count: [ 55.  11. 151.  17.   8.] | prediction: [ 53.529537  12.319565 127.019455  14.427942   8.750786] | loss: 12.643518447875977
epoch: 29 | step: 34 | gd_count: [ 30. 147. 119.  48. 198.] | prediction: [ 27.522726 115.45419  128.93224   38.58963  202.21507 ] | loss: 32.13191223144531
epoch: 29 | step: 35 | gd_count: [188.  27.  41.  61. 

epoch: 30 | step: 34 | gd_count: [317. 164. 208.  94.  66.] | prediction: [211.90437  128.61232  219.03258  126.969154  51.711857] | loss: 55.620357513427734
epoch: 30 | step: 35 | gd_count: [ 26. 269.  62.  26.  11.] | prediction: [ 16.945065 172.43307   66.21688   19.525496   9.327723] | loss: 24.113300323486328
epoch: 30 | step: 36 | gd_count: [53. 85.  9. 18. 15.] | prediction: [43.4858   68.38399   5.319749 13.817495 13.209251] | loss: 7.197824954986572
epoch: 30 | step: 37 | gd_count: [56. 17. 67. 14.  7.] | prediction: [54.383274  79.57127   59.871956  10.343613   6.1882405] | loss: 17.50821304321289
epoch: 30 | step: 38 | gd_count: [ 42. 113.   3.  66.  38.] | prediction: [23.845201 95.0952    3.671833 46.918755 37.661064] | loss: 15.379325866699219
epoch: 30 | step: 39 | gd_count: [190.  17. 222.  38. 157.] | prediction: [122.79718   11.750329 109.07852   34.481445  92.803024] | loss: 52.50236892700195
epoch: 30 | step: 40 | gd_count: [ 30. 219.  75.  16.  77.] | prediction: [

epoch: 31 | step: 38 | gd_count: [207.  88.  63. 210.  17.] | prediction: [197.43008   96.53529   71.049774 278.02435   14.683864] | loss: 42.21645736694336
epoch: 31 | step: 39 | gd_count: [14.  5.  0. 59. 77.] | prediction: [15.055458   5.741584   2.9296575 71.7481    68.72778  ] | loss: 8.819254875183105
epoch: 31 | step: 40 | gd_count: [44. 25. 12.  5. 36.] | prediction: [39.359673 23.429089 11.753797 10.321104 34.016975] | loss: 6.02640438079834
epoch: 31 | step: 41 | gd_count: [ 62.   9. 104.  88.  28.] | prediction: [34.283875  6.48291  65.518036 74.511856 18.471127] | loss: 18.385334014892578
epoch: 31 | step: 42 | gd_count: [ 59.   2.  20. 149.  70.] | prediction: [ 51.984455    3.1211112  17.906527  111.70961    54.835182 ] | loss: 15.5278959274292
epoch: 31 | step: 43 | gd_count: [  6. 517.  54.   0.  27.] | prediction: [ 10.34543   250.0423     43.2828      2.5120296  26.254875 ] | loss: 61.7443962097168
epoch: 31 | step: 44 | gd_count: [155.  61.  39. 192.  33.] | predicti

epoch: 32 | step: 43 | gd_count: [47.  9. 68. 36. 16.] | prediction: [49.318962   7.0723305 60.952023  20.74744   14.20174  ] | loss: 8.735835075378418
epoch: 32 | step: 44 | gd_count: [ 34. 187.  72.  43.  51.] | prediction: [ 29.387585 115.82422   52.240063  46.398403  49.53543 ] | loss: 27.187726974487305
epoch: 32 | step: 45 | gd_count: [36. 67. 48. 79. 62.] | prediction: [ 23.80077   59.700615  43.601162 103.40335   51.793354] | loss: 21.61993408203125
epoch: 32 | step: 46 | gd_count: [130.  98.  43.   5.  69.] | prediction: [112.204056   96.042366   27.493805    2.9090564  77.81401  ] | loss: 16.215600967407227
epoch: 32 | step: 47 | gd_count: [ 35. 114.  38.  40.  83.] | prediction: [25.049658 78.445    26.596508 29.711044 49.92556 ] | loss: 19.365371704101562
Epoch 32 Val, MSE: 207.78 MAE: 110.33, Cost 8.5 sec
epoch: 33 | step: 0 | gd_count: [ 24.  79.  13. 198. 160.] | prediction: [ 19.456465  55.461197   6.160983 189.30862  134.77858 ] | loss: 27.66424560546875
epoch: 33 | st

epoch: 33 | step: 47 | gd_count: [34. 98. 59. 84. 16.] | prediction: [26.963451 70.82299  59.54339  65.84434  13.392997] | loss: 17.683629989624023
Epoch 33 Val, MSE: 253.83 MAE: 138.49, Cost 8.5 sec
epoch: 34 | step: 0 | gd_count: [  7.  22. 116. 267.  27.] | prediction: [  6.1733193  20.41203   101.780716  113.82512    24.797316 ] | loss: 41.846900939941406
epoch: 34 | step: 1 | gd_count: [176.  95.  49.  16. 469.] | prediction: [154.69885   89.350204  37.93546   11.66651  432.25317 ] | loss: 47.22480010986328
epoch: 34 | step: 2 | gd_count: [ 12.  44. 284.  39. 405.] | prediction: [  5.669673  32.39248  259.25876   30.369434 310.5689  ] | loss: 53.855384826660156
epoch: 34 | step: 3 | gd_count: [ 40. 120.  11.  23.  51.] | prediction: [ 32.97963  122.98223    7.14962   20.578623  27.40244 ] | loss: 13.6797513961792
epoch: 34 | step: 4 | gd_count: [ 44. 244.  48.   5.  14.] | prediction: [ 44.197113 214.56918   36.675537   2.850001   9.790057] | loss: 19.709644317626953
epoch: 34 | s

epoch: 35 | step: 3 | gd_count: [46. 29. 13.  8. 49.] | prediction: [57.812695  30.299923  11.692175   5.3639183 54.12972  ] | loss: 12.231708526611328
epoch: 35 | step: 4 | gd_count: [ 0. 54. 49. 43. 53.] | prediction: [ 1.5293161 47.66744   50.69568   35.772827  41.18955  ] | loss: 9.499671936035156
epoch: 35 | step: 5 | gd_count: [ 48. 100.  86.  20.  17.] | prediction: [37.012817 89.10089  73.712715 18.089935 12.045616] | loss: 13.057021141052246
epoch: 35 | step: 6 | gd_count: [109.   3. 168.  49. 156.] | prediction: [ 89.55727    2.463254 151.6404    39.12113  166.96228 ] | loss: 24.38555335998535
epoch: 35 | step: 7 | gd_count: [122.  55.  49.  73.   5.] | prediction: [77.51235  48.93766  42.985798 48.440605  4.019988] | loss: 16.52790641784668
epoch: 35 | step: 8 | gd_count: [112. 123.  32.  58.  41.] | prediction: [85.21423  85.06212  26.255966 53.963562 36.23127 ] | loss: 21.763219833374023
epoch: 35 | step: 9 | gd_count: [67. 19. 88. 33. 57.] | prediction: [75.4623   19.8642

epoch: 36 | step: 8 | gd_count: [  0. 131.   1. 165. 166.] | prediction: [ -0.26030332  68.790634     0.7565813  130.58041    121.6398    ] | loss: 33.3397216796875
epoch: 36 | step: 9 | gd_count: [115.  26.  84.  24.  19.] | prediction: [98.861725 18.678371 68.073456 27.159678 11.727253] | loss: 10.79654598236084
epoch: 36 | step: 10 | gd_count: [ 77.  13.  71.   1. 273.] | prediction: [ 65.26558      7.0679774   63.809868     0.50961506 295.59894   ] | loss: 27.15992546081543
epoch: 36 | step: 11 | gd_count: [ 98.   5. 154.  93. 457.] | prediction: [ 92.499954    2.6188815 166.60156    90.349655  322.05225  ] | loss: 58.02480697631836
epoch: 36 | step: 12 | gd_count: [63.  4. 22. 22. 64.] | prediction: [51.372543   1.6486952 15.480858  16.753513  41.161575 ] | loss: 9.349817276000977
epoch: 36 | step: 13 | gd_count: [260. 198.   2.   1.  49.] | prediction: [214.1499    215.3489      1.7276536   1.2992278  36.9139   ] | loss: 32.35512161254883
epoch: 36 | step: 14 | gd_count: [  2.  3

epoch: 37 | step: 12 | gd_count: [120.  27.  49.  41.  62.] | prediction: [138.30165   25.399574  40.638493  47.557606  61.503925] | loss: 16.714746475219727
epoch: 37 | step: 13 | gd_count: [141. 105.  11.   2.   0.] | prediction: [107.90818    97.60043     2.8695054  -1.1143556  -0.6371467] | loss: 16.226675033569336
epoch: 37 | step: 14 | gd_count: [105. 123. 186.  38.  49.] | prediction: [ 82.35242   91.088715 128.77052   26.427517  41.984123] | loss: 30.435302734375
epoch: 37 | step: 15 | gd_count: [ 30.  45.   3.  10. 159.] | prediction: [ 15.166824    31.006025     0.74344563   6.63806    129.58348   ] | loss: 13.159812927246094
epoch: 37 | step: 16 | gd_count: [ 51.  98. 181.   6.  51.] | prediction: [ 52.757633   53.73941   168.39784     1.3185182  35.602547 ] | loss: 30.96295738220215
epoch: 37 | step: 17 | gd_count: [102.  78.  92.  19.  52.] | prediction: [77.14545   73.72356   90.12045   11.5362015 36.19281  ] | loss: 22.13569450378418
epoch: 37 | step: 18 | gd_count: [264

epoch: 38 | step: 17 | gd_count: [ 30. 124.   3. 157.  18.] | prediction: [ 24.318188  108.895905    2.2025552 181.59436    21.053288 ] | loss: 25.745447158813477
epoch: 38 | step: 18 | gd_count: [  0.  14.  51.  22. 179.] | prediction: [  2.6810484   9.346016   42.508785   19.074623  166.45117  ] | loss: 14.273719787597656
epoch: 38 | step: 19 | gd_count: [391.  14.  50.  28. 162.] | prediction: [409.286       5.0828676  43.086807   21.100967  143.4119   ] | loss: 44.536373138427734
epoch: 38 | step: 20 | gd_count: [21. 82. 20. 35.  6.] | prediction: [13.620977 76.884     9.662094 21.779072  2.94866 ] | loss: 10.205486297607422
epoch: 38 | step: 21 | gd_count: [132.  30.  25. 117.  35.] | prediction: [94.04736  20.043072 16.948503 87.93756  21.070923] | loss: 19.750341415405273
epoch: 38 | step: 22 | gd_count: [ 6. 24. 85. 21. 14.] | prediction: [ 2.4010563 17.242733  49.269684  17.750574  11.118317 ] | loss: 10.181220054626465
epoch: 38 | step: 23 | gd_count: [ 25. 147.  32. 164.   6

epoch: 39 | step: 22 | gd_count: [ 44.  59.  80. 210. 147.] | prediction: [ 28.253998  67.09372   69.68695  185.6177   123.36314 ] | loss: 28.069196701049805
epoch: 39 | step: 23 | gd_count: [ 25.  10.  29. 242.  13.] | prediction: [ 16.402897   7.324732  20.954243 116.965775   7.087556] | loss: 27.929548263549805
epoch: 39 | step: 24 | gd_count: [ 65. 120. 221.  19.  61.] | prediction: [ 55.31895   98.32148  186.20749   17.845345  52.908836] | loss: 22.854564666748047
epoch: 39 | step: 25 | gd_count: [245.  56.  55.  30.  55.] | prediction: [191.6224    52.139     51.166695  16.886944  54.803715] | loss: 27.676773071289062
epoch: 39 | step: 26 | gd_count: [ 20.  71. 394.  14.  60.] | prediction: [ 14.613216  56.619217 385.4643    11.796516  61.186142] | loss: 31.869674682617188
epoch: 39 | step: 27 | gd_count: [31. 80. 45. 78. 88.] | prediction: [25.56486  85.885544 47.428574 68.84801  82.44597 ] | loss: 16.161029815673828
epoch: 39 | step: 28 | gd_count: [ 63.   6.  11.  33. 105.] | 

epoch: 40 | step: 27 | gd_count: [68. 36. 33. 30. 40.] | prediction: [53.463844 26.844156 27.616137 22.055748 24.952007] | loss: 10.436076164245605
epoch: 40 | step: 28 | gd_count: [ 60. 139.   4.  59.  29.] | prediction: [ 39.546745  116.27669     2.3117185  51.027225   21.908527 ] | loss: 14.846304893493652
epoch: 40 | step: 29 | gd_count: [ 15.   6. 106.  17.  14.] | prediction: [11.373831  7.608902 81.85702  12.427715  9.46877 ] | loss: 7.448346138000488
epoch: 40 | step: 30 | gd_count: [319.   4.  63.  45.  30.] | prediction: [257.1875      4.6507263  49.6643     33.884327   24.378468 ] | loss: 22.878881454467773
epoch: 40 | step: 31 | gd_count: [  9.  28. 181. 305.  63.] | prediction: [  7.5255737  24.923157  135.87616   303.51532    47.29217  ] | loss: 33.218421936035156
epoch: 40 | step: 32 | gd_count: [  8. 172.  45. 103.  14.] | prediction: [  6.087406 129.6761    36.8255    81.2573    15.706171] | loss: 19.808866500854492
epoch: 40 | step: 33 | gd_count: [ 25. 238.  24.  15.

epoch: 41 | step: 31 | gd_count: [ 22.  70.  37.  37. 266.] | prediction: [ 22.157516  48.419106  23.770933  33.534695 251.93527 ] | loss: 24.686758041381836
epoch: 41 | step: 32 | gd_count: [106.   2.   4.  14.  40.] | prediction: [78.9483      0.38449222  2.5666134   5.1251764  21.836657  ] | loss: 9.138470649719238
epoch: 41 | step: 33 | gd_count: [ 53. 248.  19.   4.  73.] | prediction: [ 22.928177   163.71826     11.807068     0.56980133  44.708702  ] | loss: 30.397342681884766
epoch: 41 | step: 34 | gd_count: [84. 16.  2. 39. 31.] | prediction: [48.490005   9.781319   0.6269689 15.957876  22.463223 ] | loss: 13.225563049316406
epoch: 41 | step: 35 | gd_count: [102.  13.  26.  46.  28.] | prediction: [97.377235  6.750347 20.92186  47.56566  28.832367] | loss: 13.568473815917969
epoch: 41 | step: 36 | gd_count: [103.  55.  37.  23. 375.] | prediction: [ 87.03429   62.23702   28.480043  13.802697 321.84732 ] | loss: 34.01212692260742
epoch: 41 | step: 37 | gd_count: [50. 36. 97. 47.

epoch: 42 | step: 36 | gd_count: [ 15.  53.  38.  25. 310.] | prediction: [ 11.685359  46.902992  31.277912  17.456741 381.43884 ] | loss: 29.006031036376953
epoch: 42 | step: 37 | gd_count: [ 48.  78.  71.  66. 154.] | prediction: [ 39.965332  68.1162    68.296936  42.6073   132.29005 ] | loss: 21.32602882385254
epoch: 42 | step: 38 | gd_count: [107. 131.  45. 548.  25.] | prediction: [101.76688   98.790726  40.73666  368.947     11.39092 ] | loss: 58.8978157043457
epoch: 42 | step: 39 | gd_count: [ 16.  13.  31.   4. 149.] | prediction: [ 15.344751   12.761702   21.46648     4.4625053 143.52505  ] | loss: 9.271035194396973
epoch: 42 | step: 40 | gd_count: [205.  33. 266.  37.  88.] | prediction: [174.77142   30.166348 168.02051   25.706404 100.0296  ] | loss: 43.9036865234375
epoch: 42 | step: 41 | gd_count: [ 29.  33.  13. 212.   6.] | prediction: [ 27.441639  22.687336  14.472511 224.12       4.973027] | loss: 15.9669189453125
epoch: 42 | step: 42 | gd_count: [ 24.  70. 317.  22.  

epoch: 43 | step: 40 | gd_count: [ 44. 132.  40.  56. 178.] | prediction: [ 33.936085  95.84685   23.883608  42.851418 128.85254 ] | loss: 23.095420837402344
epoch: 43 | step: 41 | gd_count: [ 84. 300.  38.  61.  37.] | prediction: [ 67.55139  221.78801   36.01353   43.250675  30.360771] | loss: 30.60991859436035
epoch: 43 | step: 42 | gd_count: [ 67. 118.  59.  44.   8.] | prediction: [52.79267   75.39339   42.546753  37.089592   3.5345762] | loss: 16.36072540283203
epoch: 43 | step: 43 | gd_count: [58. 61. 40. 34.  3.] | prediction: [55.528526  42.854767  23.663673  28.239204   2.3506813] | loss: 9.965537071228027
epoch: 43 | step: 44 | gd_count: [118.  16.  26. 251. 120.] | prediction: [ 76.5331    12.405196  26.954065 198.56146   95.39329 ] | loss: 28.255918502807617
epoch: 43 | step: 45 | gd_count: [ 2. 34. 76. 12. 85.] | prediction: [ 2.8570185 32.254734  66.03962    9.480377  88.01718  ] | loss: 10.915719985961914
epoch: 43 | step: 46 | gd_count: [28. 88. 55. 54. 12.] | predicti

epoch: 44 | step: 44 | gd_count: [ 63. 508.   7.  43.  46.] | prediction: [ 52.29851   371.0645      4.6852627   8.2727785  44.113106 ] | loss: 51.01762008666992
epoch: 44 | step: 45 | gd_count: [ 36.   0.   1.  36. 747.] | prediction: [ 23.079218    2.1540504   2.5178523  40.550972  704.09985  ] | loss: 62.30221176147461
epoch: 44 | step: 46 | gd_count: [ 69.  38.  60. 110.  51.] | prediction: [80.25795  28.088118 69.46735  95.00923  59.452168] | loss: 19.534963607788086
epoch: 44 | step: 47 | gd_count: [183. 109.  18.  36.  23.] | prediction: [185.20291   66.22389   14.729334  17.771889  24.36396 ] | loss: 21.58099937438965
Epoch 44 Val, MSE: 181.67 MAE: 108.52, Cost 8.5 sec
epoch: 45 | step: 0 | gd_count: [ 53.  33.  56.  47. 161.] | prediction: [ 44.670776  46.705864  38.18737   45.802174 155.37244 ] | loss: 20.471593856811523
epoch: 45 | step: 1 | gd_count: [ 79. 175. 293.  96.  12.] | prediction: [ 65.951584 148.44305  280.41833   65.14988    5.465749] | loss: 36.35939025878906
e

Epoch 45 Val, MSE: 264.78 MAE: 133.39, Cost 8.5 sec
epoch: 46 | step: 0 | gd_count: [32.  6. 68. 32. 32.] | prediction: [19.734245   3.9846282 45.74904   23.044815  27.742914 ] | loss: 8.96134090423584
epoch: 46 | step: 1 | gd_count: [ 12.  29. 121. 133.  88.] | prediction: [ 11.575041  25.188353 152.91068   74.001785  79.09802 ] | loss: 32.83039474487305
epoch: 46 | step: 2 | gd_count: [ 34.  20.  27. 110.  26.] | prediction: [ 26.19965   19.022255  20.306757 126.481476  23.743889] | loss: 10.730093955993652
epoch: 46 | step: 3 | gd_count: [15. 34. 45. 75. 33.] | prediction: [10.856439 26.695833 42.865067 93.18663  33.280018] | loss: 14.823206901550293
epoch: 46 | step: 4 | gd_count: [30. 28. 39. 20. 69.] | prediction: [25.781158 22.13929  49.55521  16.06704  74.68388 ] | loss: 9.478469848632812
epoch: 46 | step: 5 | gd_count: [ 24.  56. 171. 160.  70.] | prediction: [ 17.255625  39.454735 135.71597  136.52231   58.2566  ] | loss: 26.316516876220703
epoch: 46 | step: 6 | gd_count: [24

epoch: 47 | step: 4 | gd_count: [  5. 141. 244.  13. 143.] | prediction: [  3.6434398  97.99263   220.79813     9.878296  109.683525 ] | loss: 34.03936767578125
epoch: 47 | step: 5 | gd_count: [ 35.   9. 223.  52.  13.] | prediction: [ 26.25079    4.527002 230.09622   35.825367   8.294422] | loss: 21.445423126220703
epoch: 47 | step: 6 | gd_count: [ 14.  91.  37.   0. 159.] | prediction: [ 10.598456   62.88378    27.731201    2.4118419 106.6203   ] | loss: 21.46632194519043
epoch: 47 | step: 7 | gd_count: [ 78.  28.  37. 107.   3.] | prediction: [57.87015   12.253308  21.37495   60.135487  -3.0687742] | loss: 17.7388916015625
epoch: 47 | step: 8 | gd_count: [37. 14. 72. 39.  5.] | prediction: [23.085995    7.3158092  51.279255   32.536648    0.79461455] | loss: 9.597407341003418
epoch: 47 | step: 9 | gd_count: [ 51. 100.  32.  33.  16.] | prediction: [41.00747  85.63037  23.587662 26.341484  5.492161] | loss: 13.393445014953613
epoch: 47 | step: 10 | gd_count: [29. 31. 45. 55. 29.] | p

epoch: 48 | step: 9 | gd_count: [ 42. 110. 123.  22. 210.] | prediction: [ 38.855804  89.20523  125.151535  18.13282  150.8963  ] | loss: 29.790775299072266
epoch: 48 | step: 10 | gd_count: [300.   7.  30.  22.  16.] | prediction: [218.45581     7.8913636  30.678436   16.480572   10.76992  ] | loss: 24.60469627380371
epoch: 48 | step: 11 | gd_count: [167.  47.  41. 128.  89.] | prediction: [165.1922    39.638123  40.859947 150.37123   78.98662 ] | loss: 26.00154685974121
epoch: 48 | step: 12 | gd_count: [129. 206.  41.  70.  28.] | prediction: [112.6374   208.96814   36.97307   48.55235   29.167767] | loss: 24.51656150817871
epoch: 48 | step: 13 | gd_count: [158.  77.  19.   0.   8.] | prediction: [128.90483    68.443436   17.169554   -2.0871403   3.1561136] | loss: 14.840819358825684
epoch: 48 | step: 14 | gd_count: [173.  26.  35. 146. 109.] | prediction: [116.43529   21.039232  31.507774 110.90381   99.24677 ] | loss: 28.99724769592285
epoch: 48 | step: 15 | gd_count: [273.  34.  27

epoch: 49 | step: 13 | gd_count: [ 99. 180. 120.  92. 191.] | prediction: [ 91.92086  192.33786   97.424995  91.18108  119.72247 ] | loss: 39.369049072265625
epoch: 49 | step: 14 | gd_count: [27.  0. 12.  2. 72.] | prediction: [22.725733   -0.86657465  9.640396    1.6718063  63.424854  ] | loss: 5.855239391326904
epoch: 49 | step: 15 | gd_count: [121.  86.  24.  19.  17.] | prediction: [101.21237    87.19855    24.34212    13.5581875  16.700148 ] | loss: 12.446191787719727
epoch: 49 | step: 16 | gd_count: [ 73. 130. 140.  29.  44.] | prediction: [ 72.60195  116.94959  126.419586  29.55299   36.302616] | loss: 21.887250900268555
epoch: 49 | step: 17 | gd_count: [ 50.   0.  20.  57. 210.] | prediction: [ 50.90311      0.47870982  21.503788    61.325115   169.13383   ] | loss: 22.606725692749023
epoch: 49 | step: 18 | gd_count: [ 27.  27.   9. 184.   1.] | prediction: [ 26.430435   24.800194    9.728435  161.88425     2.3574572] | loss: 14.671191215515137
epoch: 49 | step: 19 | gd_count: 

## Results on test set

#### Test GT

In [None]:
_, _, test_dataloader = loading_data_GT()

gt_net.load_state_dict(torch.load(os.path.join(save_dir, 'best_model_gt.pth'), device))
gt_net.eval()
errors = []


for vi, data in enumerate(test_dataloader, 0):
    img, gt_map = data

    with torch.no_grad():
        img = Variable(img).to(device)
        assert img.size(0) == 1
        gt_map = Variable(gt_map).to(device)
        pred_density_map = gt_net(img)
        pred_cnt = int(gt_map[0].sum().data / LOG_PARA)
        gt_count = int(pred_density_map[0].sum().data/LOG_PARA)
        error = gt_count - pred_cnt
        print(vi, error, gt_count, pred_cnt)

        errors.append(error)


errors = np.array(errors)
mse = np.sqrt(np.mean(np.square(errors)))
mae = np.mean(np.abs(errors))

log_str = 'Final Test: mae {}, mse {}'.format(mae, mse)
print(log_str)

#### Test Bayes

In [41]:
_, _, test_dataloader = loading_data_Bayes()

bayes_net.load_state_dict(torch.load(os.path.join(save_dir, 'best_model_bayes.pth'), device))
errors = []

for inputs, count, name in test_dataloader:
    inputs = inputs.to(device)
    assert inputs.size(0) == 1
    with torch.set_grad_enabled(False):
        outputs = bayes_net(inputs)
        error = count[0].item() - torch.sum(outputs).item()
        print(name, error, count[0].item(), torch.sum(outputs).item())
        errors.append(error)

errors = np.array(errors)
mse = np.sqrt(np.mean(np.square(errors)))
mae = np.mean(np.abs(errors))
log_str = 'Final Test: mae {}, mse {}'.format(mae, mse)
print(log_str)


('IMG_1',) -202.23348999023438 172 374.2334899902344
('IMG_10',) 88.87063598632812 502 413.1293640136719
('IMG_100',) -24.1907958984375 389 413.1907958984375
('IMG_101',) -192.61929321289062 211 403.6192932128906
('IMG_102',) -32.866729736328125 223 255.86672973632812
('IMG_103',) -151.6229248046875 431 582.6229248046875
('IMG_104',) -328.8505859375 1175 1503.8505859375
('IMG_105',) 34.80546569824219 265 230.1945343017578
('IMG_106',) 158.375244140625 1232 1073.624755859375
('IMG_107',) 106.19613647460938 289 182.80386352539062
('IMG_108',) 1.051666259765625 182 180.94833374023438
('IMG_109',) 7.990203857421875 379 371.0097961425781
('IMG_11',) 63.0145263671875 1068 1004.9854736328125
('IMG_110',) 266.922607421875 1021 754.077392578125
('IMG_111',) -93.29608154296875 452 545.2960815429688
('IMG_112',) -96.02096557617188 256 352.0209655761719
('IMG_113',) -166.2392578125 66 232.2392578125
('IMG_114',) -18.207733154296875 141 159.20773315429688
('IMG_115',) -81.3922119140625 1191 1272.39

('IMG_77',) 19.209259033203125 207 187.79074096679688
('IMG_78',) 9.372802734375 271 261.627197265625
('IMG_79',) 41.407928466796875 186 144.59207153320312
('IMG_8',) 633.2510986328125 1326 692.7489013671875
('IMG_80',) 11.48370361328125 157 145.51629638671875
('IMG_81',) -77.814453125 356 433.814453125
('IMG_82',) 28.390884399414062 217 188.60911560058594
('IMG_83',) -13.132156372070312 218 231.1321563720703
('IMG_84',) -47.686431884765625 122 169.68643188476562
('IMG_85',) 5.472686767578125 97 91.52731323242188
('IMG_86',) 92.7808837890625 579 486.2191162109375
('IMG_87',) 31.757720947265625 417 385.2422790527344
('IMG_88',) 0.8637466430664062 86 85.1362533569336
('IMG_89',) -26.011383056640625 383 409.0113830566406
('IMG_9',) 111.75747680664062 371 259.2425231933594
('IMG_90',) 545.5450439453125 2256 1710.4549560546875
('IMG_91',) -7.98394775390625 101 108.98394775390625
('IMG_92',) 29.25 1366 1336.75
('IMG_93',) 28.058013916015625 255 226.94198608398438
('IMG_94',) 11.5452117919921

In [None]:
## Fonction Victor - à garder pour le github final

def test_bayes(net, test_data, has_loader=False):
    """
    net : the trained network
    has_loader : if false, we are just giving a single input and want to get its results, 
        else we are giving a dataloader
    test_data : just input (np.array) and count if has_loader == False, data_loader if not
    """
    
    net.eval()  # Set model to evaluate mode
    
    if not has_loader:
        img, count = test_data[0], len(test_data[1])
        # img must be a np array
        img = img.to(device)
        img = np.asarray(img)
        #the chanels must be in first position in order to work
        if img.shape[0] != 3:
            img = np.moveaxis(img, (0,1,2), (1,2,0))
        img = torch.Tensor(img).unsqueeze(0)
        with torch.set_grad_enabled(False):
            outputs = net(img)
            res = np.abs(count - torch.sum(outputs).item())
        return outputs, res
    
    else:
        full_res = []

        # Iterate over data.
        for inputs, count, name in test_data:
            print(name)
            inputs = inputs.to(device)
            # inputs are images with different sizes
            assert inputs.size(0) == 1, 'the batch size should equal to 1 in test mode'
            with torch.set_grad_enabled(False):
                outputs = net(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                full_res.append(res)


        res = np.array(full_resres)
        mse = np.sqrt(np.mean(np.square(res)))
        mae = np.mean(np.abs(res))

        print('MSE: {:.2f} MAE: {:.2f}'
                     .format(mse, mae))
        return