## Model

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [2]:
from torch.utils import model_zoo
from torch.nn import functional as F

class VGG19(nn.Module):
    def __init__(self, ftrs):
        super(VGG19, self).__init__()
        self.ftrs = ftrs
        self.layer2 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(nn.Conv2d(128, 1, 1), nn.ReLU())

    def forward(self, x):
        x = self.ftrs(x)
        x = torch.nn.functional.upsample_bilinear(x, scale_factor=2)
        x = self.layer2(x)
        mu = self.layer3(x)
        B, C, H, W = mu.size()
        mu_sum = mu.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
        mu_norm = mu / (mu_sum + 1000000)
        return mu, mu_norm

def make_model(config):
    lyrs = []
    in_ch = 3
    for v in config:
        if v == 'M':
            lyrs += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_ch, v, kernel_size=3, padding=1)
            if batch_norm:
                lyrs += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                lyrs += [conv2d, nn.ReLU(inplace=True)]
            in_ch = v
    return nn.Sequential(*lyrs)

config = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]

def vgg19():
    model = VGG19(make_model(config))
    model.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'), strict=False)
    return model

## Dataset

In [3]:
import numpy as np
import scipy.io as sio
from torchvision import transforms
import torch.utils.data as data
import os
from glob import glob
import torchvision.transforms.functional as F
import random
from PIL import Image
import torchvision

In [5]:
def cropper(h, w, ch, cw): 
    out_w = w - cw
    out_h = h - ch
    i = random.randint(0, out_h)
    j = random.randint(0, out_w)
    return i, j, ch, cw


def gen_discrete_map(h, w, points):
    discrete_map = np.zeros([h, w], dtype=np.float32)
    no_gt = points.shape[0]
    if no_gt == 0:
        return discrete_map
    points_np = np.array(points).round().astype(int)
    p_h = np.minimum(points_np[:, 1], np.array([h-1]*no_gt).astype('int64'))
    p_w = np.minimum(points_np[:, 0], np.array([w-1]*no_gt).astype('int64'))
    p_index = torch.from_numpy(p_h* im_width + p_w)

    discrete_map = torch.zeros(w * h).scatter_add_(0, index=p_index, src=torch.ones(w*h)).view(h, w).numpy()
    assert np.sum(discrete_map) == no_gt
    return discrete_map
    
class Custom_dataloader():
    def __init__(self, path, csize, downsample_ratio=8, method='train'):
        self.path = path
        self.csize = csize
        self.d_ratio = downsample_ratio
        assert self.csize % self.d_ratio == 0
        self.dc_size = self.csize // self.d_ratio
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.method = method

        self.images = sorted(glob(os.path.join(self.path, 'images', '*.jpg')))
        print(f'Total images: {len(self.images)}')

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        img_path = self.images[item]
        name = os.path.basename(img_path).split('.')[0]
        gd_path = os.path.join(self.path, 'ground-truth', 'GT_{}.mat'.format(name))
        img = Image.open(img_path).convert('RGB')
        keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0]

        if self.method == 'train':
            return self.train_transform(img, keypoints)
        elif self.method == 'val':
            img = self.trans(img)
            img.resize_((3, 225, 225))
            return img, len(keypoints), name

    def train_transform(self, img, keypoints):
        wd, ht = img.size
        st_size = 1.0 * min(wd, ht)
        if st_size < self.csize:
            rr = 1.0 * self.csize / st_size
            wd = round(wd * rr)
            ht = round(ht * rr)
            st_size = 1.0 * min(wd, ht)
            img = img.resize((wd, ht), Image.BICUBIC)
            keypoints = keypoints * rr
        assert st_size >= self.csize, print(wd, ht)
        assert len(keypoints) >= 0
        i, j, h, w = cropper(ht, wd, self.csize, self.csize)
        img = torchvision.transforms.functional.crop(img, i, j, h, w)
        if len(keypoints) > 0:
            keypoints = keypoints - [j, i]
            idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
                       (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
            keypoints = keypoints[idx_mask]
        else:
            keypoints = np.empty([0, 2])

        gt_discrete = gen_discrete_map(h, w, keypoints)
        down_w = w // self.d_ratio
        down_h = h // self.d_ratio
        gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
        assert np.sum(gt_discrete) == len(keypoints)

        if len(keypoints) > 0:
            if random.random() > 0.5:
                img = torchvision.transforms.functional.hflip(img)
                gt_discrete = np.fliplr(gt_discrete)
                keypoints[:, 0] = w - keypoints[:, 0] - 1
        else:
            if random.random() > 0.5:
                img = torchvision.transforms.functional.hflip(img)
                gt_discrete = np.fliplr(gt_discrete)
        gt_discrete = np.expand_dims(gt_discrete, 0)
        

        return self.trans(img), torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(
            gt_discrete.copy()).float()

## Train

In [6]:
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DataLoader
from torch import optim

import time
from datetime import datetime

###  Helper modules

In [7]:
import logging


def get_logger(log_file):
    logger = logging.getLogger(log_file)
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    fh.setFormatter(formatter)
    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger


def print_config(config, logger):
    for k, v in config.items():
        logger.info("{}:\t{}".format(k.ljust(15), v))

In [8]:
def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10):

    lr = max(initial_lr * (0.1 ** (epoch // decay_epoch)), 1e-6)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class Save_Handle(object):
    def __init__(self, max_num):
        self.save_list = []
        self.max_num = max_num

    def append(self, save_path):
        if len(self.save_list) < self.max_num:
            self.save_list.append(save_path)
        else:
            remove_path = self.save_list[0]
            del self.save_list[0]
            self.save_list.append(save_path)
            if os.path.exists(remove_path):
                os.remove(remove_path)


class AverageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = 1.0 * self.sum / self.count

    def get_avg(self):
        return self.avg

    def get_count(self):
        return self.count


def set_trainable(model, requires_grad):
    for param in model.parameters():
        param.requires_grad = requires_grad

def get_num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

## Losses

In [9]:
from torch.nn import Module

M_EPS = 1e-16

def sinkhorn(a, b, C, reg=1e-1, maxIter=1000, stopThr=1e-9,
                   verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs):

    device = a.device
    na, nb = C.shape

    assert na >= 1 and nb >= 1, 'C needs to be 2d'
    assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C"
    assert reg > 0, 'reg should be greater than 0'
    assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'

    if log:
        log = {'err': []}

    if warm_start is not None:
        u = warm_start['u']
        v = warm_start['v']
    else:
        u = torch.ones(na, dtype=a.dtype).to(device) / na
        v = torch.ones(nb, dtype=b.dtype).to(device) / nb

    K = torch.empty(C.shape, dtype=C.dtype).to(device)
    torch.div(C, -reg, out=K)
    torch.exp(K, out=K)

    b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)

    it = 1
    err = 1

    KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
    Kv = torch.empty(u.shape, dtype=u.dtype).to(device)

    while (err > stopThr and it <= maxIter):
        upre, vpre = u, v
        torch.matmul(u, K, out=KTu)
        v = torch.div(b, KTu + M_EPS)
        torch.matmul(K, v, out=Kv)
        u = torch.div(a, Kv + M_EPS)

        if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
                torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
            print('Warning: numerical errors at iteration', it)
            u, v = upre, vpre
            break

        if log and it % eval_freq == 0:
            b_hat = torch.matmul(u, K) * v
            err = (b - b_hat).pow(2).sum().item()
            log['err'].append(err)

        if verbose and it % print_freq == 0:
            print('iteration {:5d}, constraint error {:5e}'.format(it, err))

        it += 1

    if log:
        log['u'] = u
        log['v'] = v
        log['alpha'] = reg * torch.log(u + M_EPS)
        log['beta'] = reg * torch.log(v + M_EPS)

    # transport plan
    P = u.reshape(-1, 1) * K * v.reshape(1, -1)
    if log:
        return P, log
    else:
        return P

class OT_Loss(Module):
    def __init__(self, csize, stride, norm_cood, device, num_of_iter_in_ot=100, reg=10.0):
        super(OT_Loss, self).__init__()
        assert csize % stride == 0

        self.csize = csize
        self.device = device
        self.norm_cood = norm_cood
        self.num_of_iter_in_ot = num_of_iter_in_ot
        self.reg = reg


        self.cood = torch.arange(0, csize, step=stride,
                                 dtype=torch.float32, device=device) + stride / 2
        self.density_size = self.cood.size(0)
        self.cood.unsqueeze_(0) 
        if self.norm_cood:
            self.cood = self.cood / csize * 2 - 1 
        self.output_size = self.cood.size(1)


    def forward(self, normed_density, unnormed_density, points):
        batch_size = normed_density.size(0)
        assert len(points) == batch_size
        assert self.output_size == normed_density.size(2)
        loss = torch.zeros([1]).to(self.device)
        ot_obj_values = torch.zeros([1]).to(self.device)
        wd = 0 
        for idx, im_points in enumerate(points):
            if len(im_points) > 0:
                if self.norm_cood:
                    im_points = im_points / self.csize * 2 - 1 
                x = im_points[:, 0].unsqueeze_(1)  
                y = im_points[:, 1].unsqueeze_(1)
                x_dis = -2 * torch.matmul(x, self.cood) + x * x + self.cood * self.cood 
                y_dis = -2 * torch.matmul(y, self.cood) + y * y + self.cood * self.cood
                y_dis.unsqueeze_(2)
                x_dis.unsqueeze_(1)
                dis = y_dis + x_dis
                dis = dis.view((dis.size(0), -1)) 
                source_prob = normed_density[idx][0].view([-1]).detach()
                target_prob = (torch.ones([len(im_points)]) / len(im_points)).to(self.device)

                P, log = sinkhorn(target_prob, source_prob, dis, self.reg, maxIter=self.num_of_iter_in_ot, log=True)
                beta = log['beta'] 
                ot_obj_values += torch.sum(normed_density[idx] * beta.view([1, self.output_size, self.output_size]))

                source_density = unnormed_density[idx][0].view([-1]).detach()
                source_count = source_density.sum()
                im_grad_1 = (source_count) / (source_count * source_count+1e-8) * beta 
                im_grad_2 = (source_density * beta).sum() / (source_count * source_count + 1e-8) 
                im_grad = im_grad_1 - im_grad_2
                im_grad = im_grad.detach().view([1, self.output_size, self.output_size])

                loss += torch.sum(unnormed_density[idx] * im_grad)
                wd += torch.sum(dis * P).item()

        return loss, wd, ot_obj_values

In [11]:
def train_collate(batch):
    transposed_batch = list(zip(*batch))
    images = torch.stack(transposed_batch[0], 0)
    points = transposed_batch[1]
    gt_discretes = torch.stack(transposed_batch[2], 0)
    return images, points, gt_discretes


class Trainer():
    def __init__(self):
        self.crop_size = 512
        self.data_dir = '../ShanghaiTech/part_A/'
        self.lr = 1e-5
        self.weight_decay = 1e-4
#         self.resume = './Vgglogs/137_ckpt.tar'
        self.resume = None

        self.max_epoch = 500
        self.val_epoch_i = 5
        self.val_start = 50
        self.sigma = 0.00001
        self.num_workers = 0
        self.batch_size = 10
        self.downsample_ratio = 8
        self.norm_cood = 0
        self.num_of_iter_in_ot = 100
        self.reg = 10
        self.wot = 0.1
        self.wtv = 0.01

    def setup(self):

        self.save_dir = './Vgglogs'
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
        self.logger = get_logger(os.path.join(self.save_dir, 'train-{:s}.log'.format(time_str)))
#         print_config(vars(args), self.logger)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            raise Exception("no gpu")

        self.datasets = {'train': Custom_dataloader(os.path.join(self.data_dir, 'train_data'),
                                           self.crop_size, self.downsample_ratio, 'train'),
                         'val': Custom_dataloader(os.path.join(self.data_dir, 'test_data'),
                                         self.crop_size, self.downsample_ratio, 'val'),
                         }

        self.dataloaders = {x: DataLoader(self.datasets[x],
                                          collate_fn=(train_collate
                                                      if x == 'train' else default_collate),
                                          batch_size=(self.batch_size
                                                      if x == 'train' else 1),
                                          shuffle=(True if x == 'train' else False),
                                          num_workers=self.num_workers ,
                                          pin_memory=(True if x == 'train' else False))
                            for x in ['train', 'val']}
        self.model = vgg19()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        self.start_epoch = 0
        if self.resume:
            self.logger.info('loading pretrained model from ' + self.resume)
            suf = self.resume.rsplit('.', 1)[-1]
#             print('suf: ', suf)
            if suf == 'tar':
#                 print("yo")
                checkpoint = torch.load(self.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
#                 print("yo again")

            elif suf == 'pth':
                self.model.load_state_dict(torch.load(self.resume, self.device))
        else:
            self.logger.info('random initialization')

        self.ot_loss = OT_Loss(self.crop_size, self.downsample_ratio, self.norm_cood, self.device, self.num_of_iter_in_ot,
                               self.reg)
        self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
        self.mse = nn.MSELoss().to(self.device)
        self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0

    def train(self):

        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self.logger.info('-' * 5 + 'Epoch {}/{}'.format(epoch, self.max_epoch) + '-' * 5)
            self.epoch = epoch
            self.train_eopch()
            if epoch % self.val_epoch_i == 0 and epoch >= self.val_start:
                self.val_epoch()

    def train_eopch(self):
        epoch_ot_loss = AverageMeter()
        epoch_ot_obj_value = AverageMeter()
        epoch_wd = AverageMeter()
        epoch_count_loss = AverageMeter()
        epoch_tv_loss = AverageMeter()
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  

        for step, (inputs,points , gt_discrete) in enumerate(self.dataloaders['train']):
            inputs = inputs.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            gt_discrete = gt_discrete.to(self.device)
            N = inputs.size(0)

            with torch.set_grad_enabled(True):
                outputs, outputs_normed = self.model(inputs)
                
                ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs, points)
                ot_loss = ot_loss * self.wot
                ot_obj_value = ot_obj_value * self.wot
                epoch_ot_loss.update(ot_loss.item(), N)
                epoch_ot_obj_value.update(ot_obj_value.item(), N)
                epoch_wd.update(wd, N)

                count_loss = self.mae(outputs.sum(1).sum(1).sum(1),
                                      torch.from_numpy(gd_count).float().to(self.device))
                epoch_count_loss.update(count_loss.item(), N)
                gd_count_tensor = torch.from_numpy(gd_count).float().to(self.device).unsqueeze(1).unsqueeze(
                    2).unsqueeze(3)
                gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
                tv_loss = (self.tv_loss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(
                    1) * torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.wtv
                epoch_tv_loss.update(tv_loss.item(), N)

                loss = ot_loss + count_loss + tv_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                pred_count = torch.sum(outputs.view(N, -1), dim=1).detach().cpu().numpy()
                pred_err = pred_count - gd_count
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(pred_err * pred_err), N)
                epoch_mae.update(np.mean(abs(pred_err)), N)

        self.logger.info(
            'Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, '
            'Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                .format(self.epoch, epoch_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_wd.get_avg(),
                        epoch_ot_obj_value.get_avg(), epoch_count_loss.get_avg(), epoch_tv_loss.get_avg(),
                        np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                        time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
        torch.save({
            'epoch': self.epoch,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_state_dict': model_state_dic
        }, save_path)
        self.save_list.append(save_path)

    def val_epoch(self):
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluate mode
        epoch_res = []
        for inputs, count, name in self.dataloaders['val']:
            inputs = inputs.to(self.device)
            assert inputs.size(0) == 1, 'the batch size should equal to 1 in validation mode'
            with torch.set_grad_enabled(False):
                _, outputs = self.model(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)

        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        self.logger.info('Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                         .format(self.epoch, mse, mae, time.time() - epoch_start))

        model_state_dic = self.model.state_dict()
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            self.logger.info("save best mse {:.2f} mae {:.2f} model epoch {}".format(self.best_mse,
                                                                                     self.best_mae,
                                                                                     self.epoch))
            torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model_{}.pth'.format(self.best_count)))
            self.best_count += 1

In [12]:
del trainer

NameError: name 'trainer' is not defined

In [None]:
trainer = Trainer()
trainer.setup()
trainer.train()