In [1]:
import numpy as np
import time
import torch
import torch.nn as nn
import os
import sys
from tqdm import tqdm

from config import Config
from model import CSRNet
from dataset import create_train_dataloader,create_test_dataloader
from utils import denormalize

In [None]:
# Train CSRNet and validate it after every epoch.
#
# Each epoch: one optimisation pass over the training loader, then an
# evaluation pass over the test loader that computes the mean absolute error
# (MAE) between the summed estimated density map and the summed ground-truth
# density map. The weights are saved whenever the MAE improves, and scalars and
# sample images are sent to cfg.writer.

cfg = Config()                                                          # configuration
continue_training = False                                               # True: resume from the checkpoint below

model = CSRNet().to(cfg.device)                                         # model

if continue_training:
    # map_location moves the stored tensors onto cfg.device, so a checkpoint
    # written on a GPU can also be restored on a CPU-only machine.
    model.load_state_dict(
        torch.load('checkpoints/shaghai_tech_a_best.pth', map_location=cfg.device)
    )

# reduction='sum' is the non-deprecated spelling of size_average=False.
criterion = nn.MSELoss(reduction='sum')                                 # objective
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)             # optimizer
train_dataloader = create_train_dataloader(cfg.dataset_root, use_flip=True, batch_size=cfg.batch_size)
test_dataloader = create_test_dataloader(cfg.dataset_root)              # dataloader

# Make sure the checkpoint directory exists before the first torch.save call.
os.makedirs(cfg.checkpoints, exist_ok=True)


def _scale_by_peak(density_map):
    """Return density_map divided by its maximum, for display as an image.

    A map whose maximum is exactly zero is returned unchanged, because the
    division would otherwise fill the logged image with NaN.
    """
    peak = torch.max(density_map)
    if peak.item() == 0:
        return density_map
    return density_map / peak


min_mae = float('inf')                                                  # best validation MAE so far
min_mae_epoch = -1                                                      # epoch at which it was reached
# Epochs are numbered from 1; the upper bound is cfg.epochs + 1 so that exactly
# cfg.epochs epochs are run (assumes cfg.epochs is the intended epoch count).
for epoch in range(1, cfg.epochs + 1):                                  # start training
    # ---- training pass ----
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_dataloader):
        image = data['image'].to(cfg.device)
        gt_densitymap = data['densitymap'].to(cfg.device)
        et_densitymap = model(image)                        # forward propagation
        loss = criterion(et_densitymap, gt_densitymap)      # calculate loss
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()                                     # back propagation
        optimizer.step()                                    # update network parameters
    cfg.writer.add_scalar('Train_Loss', epoch_loss / len(train_dataloader), epoch)

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():                                   # no autograd graph is built in here
        abs_error_total = 0.0
        image_total = 0
        for data in tqdm(test_dataloader):
            image = data['image'].to(cfg.device)
            gt_densitymap = data['densitymap'].to(cfg.device)
            et_densitymap = model(image)                    # forward propagation
            # Sum each sample's map separately so that errors of different
            # images in one batch cannot cancel each other out.
            et_counts = et_densitymap.reshape(et_densitymap.size(0), -1).sum(dim=1)
            gt_counts = gt_densitymap.reshape(gt_densitymap.size(0), -1).sum(dim=1)
            abs_error_total += (et_counts - gt_counts).abs().sum().item()
            image_total += et_counts.size(0)
        epoch_mae = abs_error_total / image_total           # mean over images, not over batches
        if epoch_mae < min_mae:
            min_mae, min_mae_epoch = epoch_mae, epoch
            torch.save(model.state_dict(), os.path.join(cfg.checkpoints, str(epoch) + ".pth"))     # save checkpoints
        print('Epoch ', epoch, ' MAE: ', epoch_mae, ' Min MAE: ', min_mae, ' Min Epoch: ', min_mae_epoch)   # print information
        cfg.writer.add_scalar('Val_MAE', epoch_mae, epoch)
        # The images logged below come from the last batch of the validation loop.
        cfg.writer.add_image(str(epoch) + '/Image', denormalize(image[0].cpu()))
        cfg.writer.add_image(
            str(epoch) + '/Estimate density count:' + str('%.2f' % (et_densitymap[0].cpu().sum())),
            _scale_by_peak(et_densitymap[0]),
        )
        cfg.writer.add_image(
            str(epoch) + '/Ground Truth count:' + str('%.2f' % (gt_densitymap[0].cpu().sum())),
            _scale_by_peak(gt_densitymap[0]),
        )


  0%|          | 0/300 [00:00<?, ?it/s]

Before backend torch.Size([1, 512, 96, 128])
After backend torch.Size([1, 64, 96, 128])
After output layer:  torch.Size([1, 1, 96, 128])
Final out torch.Size([1, 1, 768, 1024])


  0%|          | 1/300 [00:22<1:49:54, 22.05s/it]

Before backend torch.Size([1, 512, 50, 74])
After backend torch.Size([1, 64, 50, 74])
After output layer:  torch.Size([1, 1, 50, 74])
Final out torch.Size([1, 1, 400, 592])


  1%|          | 2/300 [00:28<1:26:21, 17.39s/it]

Before backend torch.Size([1, 512, 80, 128])
After backend torch.Size([1, 64, 80, 128])
After output layer:  torch.Size([1, 1, 80, 128])
Final out torch.Size([1, 1, 640, 1024])


  1%|          | 3/300 [00:52<1:35:17, 19.25s/it]

Before backend torch.Size([1, 512, 54, 128])
After backend torch.Size([1, 64, 54, 128])
After output layer:  torch.Size([1, 1, 54, 128])
Final out torch.Size([1, 1, 432, 1024])


  1%|▏         | 4/300 [01:05<1:26:09, 17.46s/it]

There is a grayscale image.
Before backend torch.Size([1, 512, 92, 128])
After backend torch.Size([1, 64, 92, 128])
After output layer:  torch.Size([1, 1, 92, 128])
Final out torch.Size([1, 1, 736, 1024])


  2%|▏         | 5/300 [01:28<1:33:42, 19.06s/it]

There is a grayscale image.
Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  2%|▏         | 6/300 [01:51<1:39:26, 20.29s/it]

Before backend torch.Size([1, 512, 42, 62])
After backend torch.Size([1, 64, 42, 62])
After output layer:  torch.Size([1, 1, 42, 62])
Final out torch.Size([1, 1, 336, 496])


  2%|▏         | 7/300 [01:56<1:16:17, 15.62s/it]

Before backend torch.Size([1, 512, 80, 100])
After backend torch.Size([1, 64, 80, 100])
After output layer:  torch.Size([1, 1, 80, 100])
Final out torch.Size([1, 1, 640, 800])


  3%|▎         | 8/300 [02:10<1:14:34, 15.32s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  3%|▎         | 9/300 [02:31<1:22:22, 16.98s/it]

Before backend torch.Size([1, 512, 50, 80])
After backend torch.Size([1, 64, 50, 80])
After output layer:  torch.Size([1, 1, 50, 80])
Final out torch.Size([1, 1, 400, 640])


  3%|▎         | 10/300 [02:38<1:08:05, 14.09s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  4%|▎         | 11/300 [02:58<1:16:06, 15.80s/it]

Before backend torch.Size([1, 512, 48, 62])
After backend torch.Size([1, 64, 48, 62])
After output layer:  torch.Size([1, 1, 48, 62])
Final out torch.Size([1, 1, 384, 496])


  4%|▍         | 12/300 [03:04<1:00:49, 12.67s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  4%|▍         | 13/300 [03:24<1:12:18, 15.12s/it]

Before backend torch.Size([1, 512, 82, 124])
After backend torch.Size([1, 64, 82, 124])
After output layer:  torch.Size([1, 1, 82, 124])
Final out torch.Size([1, 1, 656, 992])


  5%|▍         | 14/300 [03:43<1:16:38, 16.08s/it]

Before backend torch.Size([1, 512, 72, 128])
After backend torch.Size([1, 64, 72, 128])
After output layer:  torch.Size([1, 1, 72, 128])
Final out torch.Size([1, 1, 576, 1024])


  5%|▌         | 15/300 [04:00<1:18:25, 16.51s/it]

Before backend torch.Size([1, 512, 42, 126])
After backend torch.Size([1, 64, 42, 126])
After output layer:  torch.Size([1, 1, 42, 126])
Final out torch.Size([1, 1, 336, 1008])


  5%|▌         | 16/300 [04:10<1:08:25, 14.46s/it]

Before backend torch.Size([1, 512, 96, 128])
After backend torch.Size([1, 64, 96, 128])
After output layer:  torch.Size([1, 1, 96, 128])
Final out torch.Size([1, 1, 768, 1024])


  6%|▌         | 17/300 [04:35<1:23:23, 17.68s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  6%|▌         | 18/300 [04:59<1:31:34, 19.48s/it]

Before backend torch.Size([1, 512, 48, 72])
After backend torch.Size([1, 64, 48, 72])
After output layer:  torch.Size([1, 1, 48, 72])
Final out torch.Size([1, 1, 384, 576])


  6%|▋         | 19/300 [05:05<1:12:51, 15.56s/it]

Before backend torch.Size([1, 512, 74, 112])
After backend torch.Size([1, 64, 74, 112])
After output layer:  torch.Size([1, 1, 74, 112])
Final out torch.Size([1, 1, 592, 896])


  7%|▋         | 20/300 [05:23<1:15:43, 16.23s/it]

Before backend torch.Size([1, 512, 96, 128])
After backend torch.Size([1, 64, 96, 128])
After output layer:  torch.Size([1, 1, 96, 128])
Final out torch.Size([1, 1, 768, 1024])


  7%|▋         | 21/300 [05:47<1:26:43, 18.65s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  7%|▋         | 22/300 [06:09<1:30:08, 19.46s/it]

Before backend torch.Size([1, 512, 74, 128])
After backend torch.Size([1, 64, 74, 128])
After output layer:  torch.Size([1, 1, 74, 128])
Final out torch.Size([1, 1, 592, 1024])


  8%|▊         | 23/300 [06:27<1:27:48, 19.02s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  8%|▊         | 24/300 [06:47<1:29:46, 19.52s/it]

Before backend torch.Size([1, 512, 74, 100])
After backend torch.Size([1, 64, 74, 100])
After output layer:  torch.Size([1, 1, 74, 100])
Final out torch.Size([1, 1, 592, 800])


  8%|▊         | 25/300 [07:01<1:21:22, 17.76s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  9%|▊         | 26/300 [07:23<1:27:16, 19.11s/it]

Before backend torch.Size([1, 512, 40, 128])
After backend torch.Size([1, 64, 40, 128])
After output layer:  torch.Size([1, 1, 40, 128])
Final out torch.Size([1, 1, 320, 1024])


  9%|▉         | 27/300 [07:33<1:13:41, 16.19s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


  9%|▉         | 28/300 [07:54<1:20:00, 17.65s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


 10%|▉         | 29/300 [08:13<1:22:35, 18.28s/it]

Before backend torch.Size([1, 512, 64, 128])
After backend torch.Size([1, 64, 64, 128])
After output layer:  torch.Size([1, 1, 64, 128])
Final out torch.Size([1, 1, 512, 1024])


 10%|█         | 30/300 [08:30<1:19:34, 17.68s/it]

Before backend torch.Size([1, 512, 82, 128])
After backend torch.Size([1, 64, 82, 128])
After output layer:  torch.Size([1, 1, 82, 128])
Final out torch.Size([1, 1, 656, 1024])


 10%|█         | 31/300 [08:51<1:24:33, 18.86s/it]

Before backend torch.Size([1, 512, 80, 128])
After backend torch.Size([1, 64, 80, 128])
After output layer:  torch.Size([1, 1, 80, 128])
Final out torch.Size([1, 1, 640, 1024])


 11%|█         | 32/300 [09:13<1:28:12, 19.75s/it]

Before backend torch.Size([1, 512, 52, 78])
After backend torch.Size([1, 64, 52, 78])
After output layer:  torch.Size([1, 1, 52, 78])
Final out torch.Size([1, 1, 416, 624])


 11%|█         | 33/300 [09:21<1:11:54, 16.16s/it]

Before backend torch.Size([1, 512, 58, 44])
After backend torch.Size([1, 64, 58, 44])
After output layer:  torch.Size([1, 1, 58, 44])
Final out torch.Size([1, 1, 464, 352])


 11%|█▏        | 34/300 [09:25<55:54, 12.61s/it]  

There is a grayscale image.
Before backend torch.Size([1, 512, 80, 128])
After backend torch.Size([1, 64, 80, 128])
After output layer:  torch.Size([1, 1, 80, 128])
Final out torch.Size([1, 1, 640, 1024])


 12%|█▏        | 35/300 [09:47<1:07:53, 15.37s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


 12%|█▏        | 36/300 [10:09<1:16:40, 17.43s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


 12%|█▏        | 37/300 [10:31<1:22:15, 18.76s/it]

Before backend torch.Size([1, 512, 64, 128])
After backend torch.Size([1, 64, 64, 128])
After output layer:  torch.Size([1, 1, 64, 128])
Final out torch.Size([1, 1, 512, 1024])


 13%|█▎        | 38/300 [10:48<1:19:41, 18.25s/it]

Before backend torch.Size([1, 512, 32, 50])
After backend torch.Size([1, 64, 32, 50])
After output layer:  torch.Size([1, 1, 32, 50])
Final out torch.Size([1, 1, 256, 400])


 13%|█▎        | 39/300 [10:52<59:53, 13.77s/it]  

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


 13%|█▎        | 40/300 [11:14<1:10:48, 16.34s/it]

Before backend torch.Size([1, 512, 82, 128])
After backend torch.Size([1, 64, 82, 128])
After output layer:  torch.Size([1, 1, 82, 128])
Final out torch.Size([1, 1, 656, 1024])


 14%|█▎        | 41/300 [11:34<1:15:20, 17.45s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


 14%|█▍        | 42/300 [11:54<1:19:03, 18.38s/it]

Before backend torch.Size([1, 512, 94, 128])
After backend torch.Size([1, 64, 94, 128])
After output layer:  torch.Size([1, 1, 94, 128])
Final out torch.Size([1, 1, 752, 1024])


 14%|█▍        | 43/300 [12:19<1:26:52, 20.28s/it]

Before backend torch.Size([1, 512, 58, 46])
After backend torch.Size([1, 64, 58, 46])
After output layer:  torch.Size([1, 1, 58, 46])
Final out torch.Size([1, 1, 464, 368])


 15%|█▍        | 44/300 [12:25<1:07:22, 15.79s/it]

Before backend torch.Size([1, 512, 96, 128])
After backend torch.Size([1, 64, 96, 128])
After output layer:  torch.Size([1, 1, 96, 128])
Final out torch.Size([1, 1, 768, 1024])


 15%|█▌        | 45/300 [12:52<1:21:31, 19.18s/it]

Before backend torch.Size([1, 512, 102, 128])
After backend torch.Size([1, 64, 102, 128])
After output layer:  torch.Size([1, 1, 102, 128])
Final out torch.Size([1, 1, 816, 1024])


 15%|█▌        | 46/300 [13:19<1:31:16, 21.56s/it]

Before backend torch.Size([1, 512, 56, 36])
After backend torch.Size([1, 64, 56, 36])
After output layer:  torch.Size([1, 1, 56, 36])
Final out torch.Size([1, 1, 448, 288])


 16%|█▌        | 47/300 [13:22<1:08:24, 16.22s/it]

Before backend torch.Size([1, 512, 52, 76])
After backend torch.Size([1, 64, 52, 76])
After output layer:  torch.Size([1, 1, 52, 76])
Final out torch.Size([1, 1, 416, 608])


 16%|█▌        | 48/300 [13:30<57:16, 13.64s/it]  

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


 16%|█▋        | 49/300 [13:52<1:07:54, 16.23s/it]

Before backend torch.Size([1, 512, 76, 126])
After backend torch.Size([1, 64, 76, 126])
After output layer:  torch.Size([1, 1, 76, 126])
Final out torch.Size([1, 1, 608, 1008])


 17%|█▋        | 50/300 [14:12<1:11:29, 17.16s/it]

Before backend torch.Size([1, 512, 86, 128])
After backend torch.Size([1, 64, 86, 128])
After output layer:  torch.Size([1, 1, 86, 128])
Final out torch.Size([1, 1, 688, 1024])


 17%|█▋        | 51/300 [14:35<1:18:52, 19.01s/it]

Before backend torch.Size([1, 512, 96, 128])
After backend torch.Size([1, 64, 96, 128])
After output layer:  torch.Size([1, 1, 96, 128])
Final out torch.Size([1, 1, 768, 1024])


 17%|█▋        | 52/300 [14:57<1:22:24, 19.94s/it]

Before backend torch.Size([1, 512, 46, 62])
After backend torch.Size([1, 64, 46, 62])
After output layer:  torch.Size([1, 1, 46, 62])
Final out torch.Size([1, 1, 368, 496])


 18%|█▊        | 53/300 [15:02<1:03:26, 15.41s/it]

Before backend torch.Size([1, 512, 96, 128])
After backend torch.Size([1, 64, 96, 128])
After output layer:  torch.Size([1, 1, 96, 128])
Final out torch.Size([1, 1, 768, 1024])


 18%|█▊        | 54/300 [15:23<1:09:54, 17.05s/it]

Before backend torch.Size([1, 512, 62, 94])
After backend torch.Size([1, 64, 62, 94])
After output layer:  torch.Size([1, 1, 62, 94])
Final out torch.Size([1, 1, 496, 752])


 18%|█▊        | 55/300 [15:32<1:00:07, 14.73s/it]

Before backend torch.Size([1, 512, 48, 78])
After backend torch.Size([1, 64, 48, 78])
After output layer:  torch.Size([1, 1, 48, 78])
Final out torch.Size([1, 1, 384, 624])


 19%|█▊        | 56/300 [15:38<49:15, 12.11s/it]  

Before backend torch.Size([1, 512, 32, 62])
After backend torch.Size([1, 64, 32, 62])
After output layer:  torch.Size([1, 1, 32, 62])
Final out torch.Size([1, 1, 256, 496])


 19%|█▉        | 57/300 [15:42<38:25,  9.49s/it]

Before backend torch.Size([1, 512, 82, 102])
After backend torch.Size([1, 64, 82, 102])
After output layer:  torch.Size([1, 1, 82, 102])
Final out torch.Size([1, 1, 656, 816])


 19%|█▉        | 58/300 [15:56<44:00, 10.91s/it]

Before backend torch.Size([1, 512, 68, 116])
After backend torch.Size([1, 64, 68, 116])
After output layer:  torch.Size([1, 1, 68, 116])
Final out torch.Size([1, 1, 544, 928])


 20%|█▉        | 59/300 [16:09<46:40, 11.62s/it]

Before backend torch.Size([1, 512, 22, 58])
After backend torch.Size([1, 64, 22, 58])
After output layer:  torch.Size([1, 1, 22, 58])
Final out torch.Size([1, 1, 176, 464])


 20%|██        | 60/300 [16:11<35:00,  8.75s/it]

Before backend torch.Size([1, 512, 98, 128])
After backend torch.Size([1, 64, 98, 128])
After output layer:  torch.Size([1, 1, 98, 128])
Final out torch.Size([1, 1, 784, 1024])


 20%|██        | 61/300 [16:32<49:43, 12.48s/it]

Before backend torch.Size([1, 512, 84, 128])
After backend torch.Size([1, 64, 84, 128])
After output layer:  torch.Size([1, 1, 84, 128])
Final out torch.Size([1, 1, 672, 1024])


 21%|██        | 62/300 [16:52<58:36, 14.77s/it]

Before backend torch.Size([1, 512, 50, 76])
After backend torch.Size([1, 64, 50, 76])
After output layer:  torch.Size([1, 1, 50, 76])
Final out torch.Size([1, 1, 400, 608])
