In [0]:
from google.colab import drive
# Mount Google Drive at /content/drive; ROOT_DIR further down points at a
# folder inside this mount, so this has to run before any data/log access.
drive.mount('/content/drive')

In [None]:

import argparse
import logging
import os
from time import gmtime, strftime
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as utils
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter

# Debugging aid: asks CUDA to run kernel launches synchronously.
# NOTE(review): this presumably slows GPU training - confirm it is still
# wanted outside of debugging sessions.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Base directory under which datasets, TensorBoard logs and model checkpoints
# are stored (a Google Drive folder when running in Colab).
#ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = '/content/drive/My Drive/Colab Notebooks'
#ROOT_DIR = '/content/sample_data/LTPA'
# Channel count of the first convolution's input (3 -> presumably RGB images).
vgg_in_channels = 3

# Layer specifications per architecture: an integer is the output channel
# count of a 3x3 convolution, 'M' stands for a 2x2 max pooling layer.
# Only the 'VGGAttention' entry is read by the code visible in this file.
vgg_config = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    'VGGAttention': [64, 64, 128, 128, 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M', 512, 'M', 512, 'M'],
}

class ProjectLogger:
    """Small wrapper around :mod:`logging` that colours every message.

    Each level method passes the message through ``termcolor.colored`` with a
    level-specific default colour before handing it to the wrapped logger.
    """

    # Timestamp layout used in the log file: 12-hour clock plus AM/PM marker.
    _FILE_DATE_FORMAT = '%m/%d/%Y %I:%M:%S %p'

    def __init__(self,
                 log_file: Optional[str] = None,
                 level: int = logging.DEBUG,
                 printing: bool = True, attrs: Optional[List[str]] = None,
                 name: str = 'project_logger',
                 ):
        """ Basic logger that can write to a file on disk or to stderr.
        :param log_file: name of the file to log to (appended to); no file
            logging when empty or None
        :param level: logging verbosity level
        :param printing: flag for whether to log to stderr
        :param attrs: default termcolor text attributes applied to every
            message unless a call supplies its own
        :param name: name of the underlying ``logging`` logger; instances that
            share a name share the same logger and its handlers
        """
        root_logger = logging.getLogger(name)
        root_logger.setLevel(level)
        self.printing = printing
        self.attrs = attrs

        # Set up writing to a file. Loggers are process-wide singletons per
        # name, so only attach a handler for this path once; otherwise every
        # new ProjectLogger instance would duplicate each line in the file.
        if log_file:
            target = os.path.abspath(log_file)
            already_attached = any(
                isinstance(hdlr, logging.FileHandler)
                and hdlr.baseFilename == target
                for hdlr in root_logger.handlers
            )
            if not already_attached:
                file_handler = logging.FileHandler(log_file, mode='a')
                file_handler.setFormatter(logging.Formatter(
                    '%(levelname)s: %(asctime)s %(message)s',
                    datefmt=self._FILE_DATE_FORMAT,
                ))
                root_logger.addHandler(file_handler)

        # Set up printing to stderr. FileHandler subclasses StreamHandler, so
        # it has to be excluded explicitly when looking for a console handler.
        has_console = any(
            isinstance(hdlr, logging.StreamHandler)
            and not isinstance(hdlr, logging.FileHandler)
            for hdlr in root_logger.handlers
        )
        if printing and not has_console:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter("%(message)s"))
            root_logger.addHandler(console_handler)

        self.log = root_logger

    def debug(self, msg, color='grey', attrs: Optional[List[str]] = None):
        """Log ``msg`` at DEBUG level, coloured ``color``."""
        self.log.debug(colored(msg, color, attrs=attrs or self.attrs))

    def info(self, msg, color='green', attrs: Optional[List[str]] = None):
        """Log ``msg`` at INFO level, coloured ``color``."""
        self.log.info(colored(msg, color, attrs=attrs or self.attrs))

    def warning(self, msg, color='blue', attrs: Optional[List[str]] = None):
        """Log ``msg`` at WARNING level, coloured ``color``."""
        self.log.warning(colored(msg, color, attrs=attrs or self.attrs))

    def error(self, msg, color='magenta', attrs: Optional[List[str]] = None):
        """Log ``msg`` at ERROR level, coloured ``color``."""
        self.log.error(colored(msg, color, attrs=attrs or self.attrs))

    def critical(self, msg, color='red', attrs: Optional[List[str]] = None):
        """Log ``msg`` at CRITICAL level, coloured ``color``."""
        self.log.critical(colored(msg, color, attrs=attrs or self.attrs))


def plot_attention(image, attention_estimator, up_factor, nrow):
    """Overlay colour-mapped attention maps on an image grid.

    :param image: channel-first image tensor (presumably the grid produced by
        ``make_grid`` for the same samples - verify against the caller)
    :param attention_estimator: 4-D tensor of attention scores, one map per
        sample; softmax-normalised here over each map's spatial positions
    :param up_factor: spatial scale factor used to upsample the maps
    :param nrow: number of maps per row in the assembled grid
    :return: channel-first float tensor, 0.6 * image + 0.4 * heat map
    """
    batch, channels, dim_a, dim_b = attention_estimator.shape

    # normalise every map over its flattened spatial positions
    flat_scores = attention_estimator.view(batch, channels, -1)
    scores = torch.softmax(flat_scores, dim=2).view(batch, channels,
                                                    dim_a, dim_b)
    scores = F.interpolate(scores, scale_factor=up_factor, mode='bilinear',
                           align_corners=False)

    # tile the maps exactly like the image grid, then go to uint8 HWC
    grid = utils.make_grid(scores, nrow=nrow, normalize=True, scale_each=True)
    heat = grid.permute((1, 2, 0)).mul(255).byte().cpu().numpy()

    # OpenCV colour maps come out as BGR; convert to RGB and rescale to [0, 1]
    heat = cv2.cvtColor(cv2.applyColorMap(heat, cv2.COLORMAP_JET),
                        cv2.COLOR_BGR2RGB)
    heat = np.float32(heat) / 255

    background = image.permute((1, 2, 0)).cpu().numpy()
    blended = 0.6 * background + 0.4 * heat
    return torch.from_numpy(blended).permute(2, 0, 1)


class VGGAttention(nn.Module):
    """VGG-style classifier with three spatial attention heads.

    Local feature maps are taken at three depths of the convolutional stack.
    A global descriptor computed from the final features is matched against
    each of them to obtain one attention map per depth, and the three
    attention-weighted feature vectors are concatenated for classification.
    """

    # supported ways of matching the global and local descriptors
    _MODES = ('dp', 'pc')

    def __init__(self, mode: str = 'pc'):
        """
        :param mode:
        dp for dot product for matching the global and local descriptors
        pc for the use of parametrised compatibility
        :raises ValueError: if ``mode`` is neither 'dp' nor 'pc'
        """
        super(VGGAttention, self).__init__()
        # fail early: an unknown mode would otherwise only surface as an
        # UnboundLocalError deep inside the first forward pass
        if mode not in self._MODES:
            raise ValueError(
                f"mode must be one of {self._MODES}, got {mode!r}")
        self.mode = mode

        # features through VGG
        self.features = self._make_layers()

        # every conv in the config expands to 3 modules (conv, bn, relu):
        # l1 = first 7 conv blocks + 1st max-pool, l2 / l3 = the next 3 conv
        # blocks + max-pool each
        self.l1 = nn.Sequential(*list(self.features)[:22])
        self.l2 = nn.Sequential(*list(self.features)[22:32])
        self.l3 = nn.Sequential(*list(self.features)[32:42])

        # remaining layers before fully-connected
        self.conv_remain = nn.Sequential(*list(self.features)[42:50])

        # 1st fully-connected back to attention estimator layers; ga1..ga3
        # project the global descriptor to the channel count of l1..l3
        self.fc1 = nn.Linear(512, 512)
        self.ga1 = nn.Linear(512, 256)
        self.ga2 = nn.Linear(512, 512)
        self.ga3 = nn.Linear(512, 512)

        # last fully-connected after weight combinations
        self.fc2 = nn.Linear(256 + 512 + 512, 10)

        if mode == 'pc':
            # 1x1 convolutions turning (local + global) into a score map
            self.u1 = nn.Conv2d(256, 1, 1)
            self.u2 = nn.Conv2d(512, 1, 1)
            self.u3 = nn.Conv2d(512, 1, 1)

    def forward(self, x):
        """Classify ``x`` and return the attention maps alongside the logits.

        :param x: input batch; the spatial size must shrink to 1x1 after the
            five max-pools (e.g. 32x32) so that the flattened features match
            ``fc1``'s 512 inputs
        :return: list ``[out, ae1, ae2, ae3]`` - class scores followed by the
            attention maps of the three levels
        """
        l1 = self.l1(x)
        l2 = self.l2(l1)
        l3 = self.l3(l2)
        conv_remain = self.conv_remain(l3)

        fc1 = self.fc1(conv_remain.view(conv_remain.size(0), -1))
        ga1 = self.ga1(fc1)
        ga2 = self.ga2(fc1)
        ga3 = self.ga3(fc1)

        ae1 = self._get_compatibility_score(l1, ga1, level=1)
        ae2 = self._get_compatibility_score(l2, ga2, level=2)
        ae3 = self._get_compatibility_score(l3, ga3, level=3)

        g1 = self._get_weighted_combination(l1, ae1)
        g2 = self._get_weighted_combination(l2, ae2)
        g3 = self._get_weighted_combination(l3, ae3)

        g = torch.cat((g1, g2, g3), dim=1)
        out = self.fc2(g)

        # need the attention estimators for the image plots
        return [out, ae1, ae2, ae3]

    @staticmethod
    def _make_layers():
        """Build the convolutional stack from ``vgg_config['VGGAttention']``.

        Integers become conv(3x3) + batch-norm + ReLU, 'M' becomes a 2x2
        max-pool; a kernel-size-1 average pool is appended at the end.
        """
        layers = []
        in_channels = vgg_in_channels
        for x in vgg_config['VGGAttention']:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

    def _get_compatibility_score(self, l, g, level):
        """Attention map for local features ``l`` given global vector ``g``.

        :param l: local feature map, indexed as (batch, channels, h, w)
        :param g: global descriptor, one value per channel of ``l``
        :param level: 1, 2 or 3 - selects the 1x1 conv used in 'pc' mode
        :return: single-channel score map, softmax-normalised over the
            spatial positions of each sample
        :raises ValueError: on an unknown mode or level
        """
        # broadcast the global vector over all spatial positions
        g = g.unsqueeze(2).unsqueeze(3)

        if self.mode == 'dp':
            # dot product between local and global descriptor per position
            scores = (l * g).sum(1).unsqueeze(1)
        elif self.mode == 'pc':
            heads = {1: self.u1, 2: self.u2, 3: self.u3}
            if level not in heads:
                raise ValueError(f"level must be 1, 2 or 3, got {level!r}")
            scores = heads[level](l + g)
        else:
            raise ValueError(
                f"mode must be one of {self._MODES}, got {self.mode!r}")

        # softmax over the flattened spatial positions, then restore shape
        size = scores.size()
        flat = scores.view(scores.size(0), scores.size(1), -1)
        return torch.softmax(flat, dim=2).view(size)

    @staticmethod
    def _get_weighted_combination(l, ae):
        """Sum the local features over space, weighted by attention ``ae``."""
        g = l * ae
        return g.view(g.size(0), g.size(1), -1).sum(2)


class AttentionNetwork:
    """Training / evaluation driver with TensorBoard logging.

    Runs the epoch loop, writes loss, accuracy and learning-rate scalars,
    saves a checkpoint after every training epoch and renders the model's
    attention maps for a fixed set of train and test sample images.
    """

    def __init__(self,
                 opt: argparse.Namespace,
                 model: torch.nn.Module,
                 criterion: torch.nn.Module,
                 optimizer: optim.Optimizer,
                 scheduler: lr_scheduler.LambdaLR,
                 device: torch.device,
                 early_stop: int = 3,
                 loglevel: int = 20,
                 ):
        """
        :param opt: parsed options; ``epochs`` and ``logs_path`` are read
        :param model: network returning ``(pred, ae1, ae2, ae3)``
        :param criterion: loss applied to ``(pred, targets)``
        :param optimizer: optimizer stepped once per batch
        :param scheduler: learning-rate scheduler stepped once per epoch
        :param device: device the batches are moved to
        :param early_stop: patience value; stored, but early stopping is
            currently not enforced
        :param loglevel: verbosity passed to ``ProjectLogger``
        """
        self.opt = opt
        # snapshot taken while `opt` is the only attribute set; its string
        # form becomes part of the TensorBoard run directory name below
        self.params = vars(self).copy()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.writer = SummaryWriter(
            log_dir=f'{ROOT_DIR}/{opt.logs_path}/tensorboard/'
                    f'{strftime("%Y-%m-%d", gmtime())}/'
                    f'{str(self.params)}_{strftime("%Y-%m-%d %H-%M-%S", gmtime())}')
        self.logger = ProjectLogger(level=loglevel)
        # [train samples, test samples] used for the attention plots; filled
        # once from the first batches that are seen
        self.images = []
        # image grids of the samples above, built on first use
        self.train_image = None
        self.test_image = None
        self.image_dim = int(5 * 5)
        self.step = 0
        self.min_up_factor = 2
        self.early_stop = early_stop

    def train_validate(self,
                       train_loader: torch.utils.data.DataLoader,
                       test_loader: torch.utils.data.DataLoader):
        """Run ``opt.epochs`` rounds of training followed by evaluation.

        The TensorBoard writer is closed when the loop ends, including when
        an exception aborts it.
        """
        tb_layout = {
            'Training': {
                'Losses': ['Multiline',
                           ['epoch_train_loss', 'epoch_test_loss']],
                'Accuracy': ['Multiline',
                             ['epoch_train_acc', 'epoch_test_acc']],
            }
        }
        try:
            # NOTE: early stopping based on `self.early_stop` is disabled
            for epoch in range(self.opt.epochs):
                self.train(train_loader=train_loader, epoch=epoch)
                self.test(test_loader=test_loader, epoch=epoch)
                self.logger.info(f'epoch {epoch} completed')
            self.writer.add_custom_scalars(tb_layout)
        finally:
            self.writer.close()

    def train(self,
              train_loader: torch.utils.data.DataLoader,
              epoch: int):
        """Train for one epoch, checkpoint the model, return the mean loss.

        :param train_loader: yields ``(inputs, targets)`` batches
        :param epoch: epoch index used as TensorBoard step and file suffix
        :return: mean of the per-batch training losses
        """
        # learning rate that is in effect for this epoch
        self.writer.add_scalar('train_learning_rate',
                               self.optimizer.param_groups[0]['lr'], epoch)
        epoch_loss, epoch_acc = [], []
        self.model.train()
        for batch_idx, (inputs, targets) in enumerate(train_loader, 0):
            inputs, targets = inputs.to(self.device), targets.to(
                self.device)
            # keep one fixed set of training samples for the attention plots
            if batch_idx == 0 and not self.images:
                self.images.append(inputs[0:self.image_dim, :, :, :])
            # clear grads on both, in case the optimizer only covers a subset
            # of the model's parameters
            self.model.zero_grad()
            self.optimizer.zero_grad()
            pred, _, _, _ = self.model(inputs)
            loss = self.criterion(pred, targets)
            loss.backward()
            self.optimizer.step()
            predict = torch.argmax(pred, 1)
            total = targets.size(0)
            correct = torch.eq(predict, targets).cpu().sum().item()
            acc = correct / total
            if batch_idx % 10 == 0:
                self.logger.info(
                    f"[epoch {epoch}][batch_idx {batch_idx}]"
                    f"loss: {round(loss.item(), 4)} "
                    f"accuracy: {100 * acc}% "
                )
            epoch_loss += [loss.item()]
            epoch_acc += [acc]
            self.step += 1
            self.writer.add_scalar('train_loss', loss.item(), self.step)
            self.writer.add_scalar('train_acc', acc, self.step)

        # advance the schedule only after this epoch's optimizer steps;
        # stepping it first would skip the first learning-rate value
        self.scheduler.step()

        # log/add on tensorboard
        epoch_train_loss = np.mean(epoch_loss, axis=0)
        epoch_train_acc = np.mean(epoch_acc, axis=0)
        self.writer.add_scalar('epoch_train_loss', epoch_train_loss, epoch)
        self.writer.add_scalar('epoch_train_acc', epoch_train_acc, epoch)
        self.logger.info(f"[epoch {epoch}] train_acc: "
                         f"{100 * epoch_train_acc}%")
        self.logger.info(f"[epoch {epoch}] train_loss: "
                         f"{epoch_train_loss}")

        # save model params (paths come from this instance's options, not
        # from a module-level `opt`)
        state_dir = f'{ROOT_DIR}/{self.opt.logs_path}/model_states'
        os.makedirs(state_dir, exist_ok=True)
        torch.save(self.model.state_dict(),
                   f'{state_dir}/net_epoch_{epoch}.pth')
        return epoch_train_loss

    def test(self,
             test_loader: torch.utils.data.DataLoader,
             epoch: int):
        """Evaluate for one epoch and log attention maps of the samples.

        :param test_loader: yields ``(inputs, targets)`` batches
        :param epoch: epoch index used as TensorBoard step
        :return: mean of the per-batch test losses
        """
        epoch_loss, epoch_acc = [], []
        self.model.eval()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader, 0):
                inputs, targets = inputs.to(
                    self.device), targets.to(self.device)
                # second sample set: taken from the test data itself, right
                # after the training samples have been stored
                if batch_idx == 0 and len(self.images) == 1:
                    self.images.append(inputs[0:self.image_dim, :, :, :])
                pred, _, _, _ = self.model(inputs)
                loss = self.criterion(pred, targets)
                predict = torch.argmax(pred, 1)
                total = targets.size(0)
                correct = torch.eq(predict, targets).cpu().sum().item()
                acc = correct / total
                epoch_loss += [loss.item()]
                epoch_acc += [acc]

            # log/add on tensorboard
            epoch_test_loss = np.mean(epoch_loss, axis=0)
            epoch_test_acc = np.mean(epoch_acc, axis=0)
            self.writer.add_scalar('epoch_test_loss', epoch_test_loss, epoch)
            self.writer.add_scalar('epoch_test_acc', epoch_test_acc, epoch)
            self.logger.info(f"[epoch {epoch}] test_acc: {100 * epoch_test_acc}%")
            self.logger.info(f"[epoch {epoch}] test_loss: {epoch_test_loss}")

            # still inside no_grad: the plots convert tensors to numpy
            self._log_attention_maps(epoch)
            return epoch_test_loss

    def _log_attention_maps(self, epoch: int):
        """Write sample grids (once) and their attention maps to TensorBoard.

        Must be called with gradients disabled and the model in eval mode.
        """
        if len(self.images) < 2:
            self.logger.warning(
                'no train/test sample images captured yet; '
                'skipping attention maps')
            return

        nrow = int(np.sqrt(self.image_dim))

        # initial image..
        if self.train_image is None or self.test_image is None:
            self.train_image = utils.make_grid(self.images[0], nrow=nrow,
                                               normalize=True,
                                               scale_each=True)
            self.test_image = utils.make_grid(self.images[1], nrow=nrow,
                                              normalize=True,
                                              scale_each=True)
            self.writer.add_image('train_image', self.train_image, epoch)
            self.writer.add_image('test_image', self.test_image, epoch)

        sample_sets = (('train', self.images[0], self.train_image),
                       ('test', self.images[1], self.test_image))
        for split, samples, grid in sample_sets:
            _, ae1, ae2, ae3 = self.model(samples)
            for level, ae in enumerate((ae1, ae2, ae3), start=1):
                # every deeper level halves the map size, so its upsampling
                # factor doubles: min_up_factor * 1, * 2, * 4
                attn = plot_attention(
                    grid, ae,
                    up_factor=self.min_up_factor * 2 ** (level - 1),
                    nrow=nrow)
                self.writer.add_image(f'{split}_attention_map_{level}', attn,
                                      epoch)


def get_data_loader(opt, im_size=32) -> Tuple[torch.utils.data.DataLoader,
                                              torch.utils.data.DataLoader]:
    """Create the CIFAR10 train and test data loaders.

    The dataset is downloaded to ``{ROOT_DIR}/data/CIFAR10`` if missing.
    Training images are randomly cropped (padding 4) and horizontally
    flipped; both splits are converted to tensors and normalised with the
    same per-channel constants.

    :param opt: options object; only ``batch_size`` is read
    :param im_size: side length of the random training crop
    :return: ``(train_loader, test_loader)``; only the train loader shuffles
    """
    # presumably the per-channel mean / std of CIFAR10 - shared by both
    # splits so they cannot drift apart
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(im_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    data_root = f'{ROOT_DIR}/data/CIFAR10'
    # os.cpu_count() may return None; fall back to loading in the main process
    num_workers = os.cpu_count() or 0

    train_data = torchvision.datasets.CIFAR10(root=data_root,
                                              train=True,
                                              download=True,
                                              transform=transform_train)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)
    test_data = torchvision.datasets.CIFAR10(root=data_root,
                                             train=False,
                                             download=True,
                                             transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)
    return train_loader, test_loader


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="LTPA")
    parser.add_argument("--attention_mode", type=str, default="pc",
                        help="mode for running the attention model [pc or dp]")
    parser.add_argument("--epochs", type=int, default=300,
                        help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="batch size")
    parser.add_argument("--lr", type=float, default=0.1,
                        help="initial learning rate")
    parser.add_argument("--logs_path", type=str, default="logs",
                        help='path of log files')
    # parse an explicitly empty argument list: the defaults above are always
    # used and whatever the notebook kernel put into sys.argv is ignored
    opt = parser.parse_args([])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # always seed the default generator: the model is created on the CPU and
    # the train loader shuffles, so seeding only CUDA leaves runs irreproducible
    torch.manual_seed(999)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(999)
    else:
        # os.cpu_count() may return None
        torch.set_num_threads(os.cpu_count() or 1)
        print(f'Using {device}: {torch.get_num_threads()} threads')

    # load data
    train_loader, test_loader = get_data_loader(opt, im_size=32)

    # model + loss function + optimizer + scheduler
    net = VGGAttention(mode=opt.attention_mode)
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        model = nn.DataParallel(net, device_ids=list(
            range(torch.cuda.device_count()))).to(device)
    else:
        model = net.to(device)
    criterion.to(device)
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9,
                          weight_decay=5e-4)
    # halve the learning rate every 25 epochs
    scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: 0.5 ** (epoch // 25))

    # time to train/validate
    obj = AttentionNetwork(opt=opt,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           device=device)
    obj.train_validate(train_loader=train_loader, test_loader=test_loader)
