## Image Transformer

### Imports

In [13]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import argparse
import sys
import time
import os
import logging
import yaml
import shutil
import numpy as np
import tensorboardX
import torch.optim as optim
import torchvision
#from image_transformer import ImageTransformer
import matplotlib
import itertools
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchviz import make_dot
from tqdm import tqdm
import torch.nn as nn

### Configuration

In [14]:
def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

def parse_args_and_config():
    """
    :return args, config: namespace objects that stores information in args and config files.
    """
    parser = argparse.ArgumentParser(description=globals()['__doc__'])

    parser.add_argument('--config', type=str, default='transformer_tiny.yml', help='Path to the config file')
    parser.add_argument('--doc', type=str, default='0', help='A string for documentation purpose')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('--sample', action='store_true', help='Sample at train time')

    args, unknown = parser.parse_known_args() #special modification for Jupyter notebooks. 
    #args = parser.parse_args()

    args.log = os.path.join('transformer_logs', args.doc)
    # parse config file
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    new_config = dict2namespace({**config, **vars(args)})

    if os.path.exists(args.log):
        shutil.rmtree(args.log)

    os.makedirs(args.log)

    with open(os.path.join(args.log, 'config.yml'), 'w') as f:
        yaml.dump(new_config, f, default_flow_style=False)

    # setup logger
    level = getattr(logging, args.verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(args.verbose))

    handler1 = logging.StreamHandler()
    handler2 = logging.FileHandler(os.path.join(args.log, 'stdout.txt'))
    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
    handler1.setFormatter(formatter)
    handler2.setFormatter(formatter)
    logger = logging.getLogger()
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    logger.setLevel(level)

    # add device information to args
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(new_config.seed)
    torch.cuda.manual_seed_all(new_config.seed)
    np.random.seed(new_config.seed)
    logging.info("Run name: {}".format(args.doc))

    return args, new_config

def get_lr(step, config):
    warmup_steps = config.optim.warmup
    lr_base = config.optim.lr * 0.002 # for Adam correction
    ret = 5000. * config.model.hidden_size ** (-0.5) * \
          np.min([(step + 1) * warmup_steps ** (-1.5), (step + 1) ** (-0.5)])
    return ret * lr_base

In [15]:
args, config = parse_args_and_config()
tb_logger = tensorboardX.SummaryWriter(log_dir=os.path.join('transformer_logs', args.doc))

INFO - <ipython-input-14-fea14b691037> - 2022-07-17 14:59:30,431 - Using device: cpu
INFO - <ipython-input-14-fea14b691037> - 2022-07-17 14:59:30,431 - Using device: cpu
INFO - <ipython-input-14-fea14b691037> - 2022-07-17 14:59:30,433 - Run name: 0
INFO - <ipython-input-14-fea14b691037> - 2022-07-17 14:59:30,433 - Run name: 0


### Define Model

ImageTransformer

In [39]:
NUM_PIXELS = 256

def logsoftmax(x):
    m = torch.max(x, -1, keepdim=True).values
    return x - m - torch.log(torch.exp(x - m).sum(-1, keepdim=True))

def logsumexp(x):
    m = x.max(-1).values
    return m + torch.log(torch.exp(x - m[...,None]).sum(-1))

class ImageTransformer(nn.Module):
    """ImageTransformer with DMOL or categorical distribution."""
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.layers = nn.ModuleList([DecoderLayer(hparams) for _ in range(hparams.nlayers)])

        self.input_dropout = nn.Dropout(p=hparams.dropout)

        self.embeds = nn.Embedding(NUM_PIXELS * self.hparams.channels, self.hparams.hidden_size)

        self.output_dense = nn.Linear(self.hparams.hidden_size, NUM_PIXELS, bias=True)

        
    def add_timing_signal(self, X, min_timescale=1.0, max_timescale=1.0e4):
        num_dims = len(X.shape) - 2 # 2 corresponds to batch and hidden_size dimensions
        num_timescales = self.hparams.hidden_size // (num_dims * 2)
        log_timescale_increment = np.log(max_timescale / min_timescale) / (num_timescales - 1)
        inv_timescales = min_timescale * torch.exp((torch.arange(num_timescales).float() * -log_timescale_increment))
        inv_timescales = inv_timescales.to(X.device)
        total_signal = torch.zeros_like(X) # Only for debugging purposes
        for dim in range(num_dims):
            length = X.shape[dim + 1] # add 1 to exclude batch dim
            position = torch.arange(length).float().to(X.device)
            scaled_time = position.view(-1, 1) * inv_timescales.view(1, -1)
            signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
            prepad = dim * 2 * num_timescales
            postpad = self.hparams.hidden_size - (dim + 1) * 2 * num_timescales
            signal = F.pad(signal, (prepad, postpad))
            for _ in range(1 + dim):
                signal = signal.unsqueeze(0)
            for _ in range(num_dims - 1 - dim):
                signal = signal.unsqueeze(-2)
            X += signal
            total_signal += signal
        return X

    def shift_and_pad_(self, X):
        # Shift inputs over by 1 and pad
        shape = X.shape
        X = X.view(shape[0], shape[1] * shape[2], shape[3])
        X = X[:, :-1, :]
        X = F.pad(X, (0, 0, 1, 0)) # Pad second to last dimension
        X = X.view(shape)
        return X

    def forward(self, X, sampling=False):
        #original x shape  torch.Size([8, 3, 8, 8])
        #permute step:  torch.Size([8, 8, 8, 3])
        #after flattening channels to width torch.Size([8, 8, 24])
        #after channel_addition torch.Size([8, 8, 24])
        #after embedding  torch.Size([8, 8, 24, 8])
        #after shift and padding torch.Size([8, 8, 24, 8])
        #after viewing torch.Size([8, 192, 8])
        #after output dense torch.Size([8, 8, 24, 256])
        #after unpaccking channels torch.Size([8, 3, 8, 8, 256])

        # Reshape inputs
        if sampling:
            curr_infer_length = X.shape[1]
            print("first curr infer length", curr_infer_length)
            row_size = self.hparams.image_size * self.hparams.channels
            nrows = curr_infer_length // row_size + 1
            print("row_size, nrows, X.size() ", row_size, nrows, X.size())
            #row_size = 24, nrows = 1, X.size(), x.size = torch.Size([10, 3])

            X = F.pad(X, (0, nrows * row_size - curr_infer_length))
            X = X.view(X.shape[0], -1, row_size)
            print("X size after padding and viewing", X.size())
            #X size after padding and viewing torch.Size([10, 1, 24])

        else:
            X = X.permute([0, 2, 3, 1]).contiguous()
            X = X.view(X.shape[0], X.shape[1], X.shape[2] * X.shape[3]) # Flatten channels into width

        print("X before being embedded", X)

        #elif self.hparams.distr == "cat":
        # Convert to indexes, and use separate embeddings for different channels
        X = (X * (NUM_PIXELS - 1)).long()
        channel_addition = (torch.tensor([0, 1, 2]) * NUM_PIXELS).to(X.device).repeat(X.shape[2] // 3).view(1, 1, -1)
        X += channel_addition
        X = self.embeds(X) * (self.hparams.hidden_size ** 0.5)
    

        X = self.shift_and_pad_(X)
        X = self.add_timing_signal(X)
        shape = X.shape
        #print("after shift and padding", X.size()) #after shift and padding torch.Size([8, 8, 24, 8])
        X = X.view(shape[0], -1, shape[3])
        #after viewing torch.Size([8, 192, 8])

        print("After sample viewing: ", X.size())

        X = self.input_dropout(X)
        for layer in self.layers:
            X = layer(X)

        X = self.output_dense(X).view(shape[:3] + (-1,))

        if not sampling and self.hparams.distr == "cat": # Unpack the channels
            X = X.view(X.shape[0], X.shape[1], X.shape[2] // self.hparams.channels, self.hparams.channels, X.shape[3])
            X = X.permute([0, 3, 1, 2, 4])
        
        print("Returned X shape", X.size())
        print("End \n \n \n")
        return X

    def loss(self, preds, targets): # Assumes targets have been rescaled to [-1., 1.], categorical dist.
        targets = (targets * (NUM_PIXELS - 1)).long()
        ce = F.cross_entropy(preds.permute(0, 4, 1, 2, 3), targets, reduction='none')
        return ce

    def accuracy(self, preds, targets): #for Categorical Distribution
        targets = (targets * (NUM_PIXELS - 1)).long()
        argmax_preds = torch.argmax(preds, dim=-1)
        acc = torch.eq(argmax_preds, targets).float().sum() / np.prod(argmax_preds.shape)
        return acc

    def sample_from_cat(self, logits, argmax=False):
        if argmax:
            sel = torch.argmax(logits, -1, keepdim=False).float() / 255.
        else:
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) * (1. - 2 * 1e-5) + 1e-5))
            sel = torch.argmax(logits + gumbel_noise, -1, keepdim=False).float() / 255.
        return sel

    def sample(self, n, device, argmax=False):
        total_len = (self.hparams.image_size ** 2)
        total_len *= self.hparams.channels
        samples = torch.zeros((n, 3)).to(device)

        print("Samples size", samples.size())

        for curr_infer_length in tqdm(range(total_len)):
            outputs = self.forward(samples, sampling=True)
            outputs = outputs.view(n, -1, outputs.shape[-1])[:,curr_infer_length:curr_infer_length+1,:]
            
            x = self.sample_from_cat(outputs, argmax=argmax)
            if curr_infer_length == 0:
                samples = x
            else:
                samples = torch.cat([samples, x], 1)

        samples = samples.view(n, self.hparams.image_size, self.hparams.image_size, self.hparams.channels)
        samples = samples.permute(0, 3, 1, 2)
        return samples

    def sample_from_preds(self, preds, argmax=False):
        if self.hparams.distr == "dmol":
            samples = self.sample_from_dmol(preds)
            samples = samples.permute(0, 3, 1, 2)
        elif self.hparams.distr == "cat":
            samples = self.sample_from_cat(preds, argmax=argmax)
        return samples

model = ImageTransformer(config.model).to(config.device)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.sample(1, device= device)

  3%|▎         | 5/192 [00:00<00:05, 34.47it/s]

Samples size torch.Size([1, 3])
first curr infer length 3
row_size, nrows, X.size()  24 1 torch.Size([1, 3])
X size after padding and viewing torch.Size([1, 1, 24])
X before being embedded tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.]]])
After sample viewing:  torch.Size([1, 24, 8])
Returned X shape torch.Size([1, 1, 24, 256])
End 
 
 

first curr infer length 1
row_size, nrows, X.size()  24 1 torch.Size([1, 1])
X size after padding and viewing torch.Size([1, 1, 24])
X before being embedded tensor([[[0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 24, 8])
Returned X shape torch.Size([1, 1, 24, 256])
End 
 
 

first curr infer length 2
row_size, nrows, X.size()  24 1 torch.Size([1, 2])
X size after padding and viewin

  6%|▋         | 12/192 [00:00<00:03, 50.62it/s]

Returned X shape torch.Size([1, 1, 24, 256])
End 
 
 

first curr infer length 9
row_size, nrows, X.size()  24 1 torch.Size([1, 9])
X size after padding and viewing torch.Size([1, 1, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 24, 8])
Returned X shape torch.Size([1, 1, 24, 256])
End 
 
 

first curr infer length 10
row_size, nrows, X.size()  24 1 torch.Size([1, 10])
X size after padding and viewing torch.Size([1, 1, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 24, 8])
Returned X shape torch.Size([1, 

  9%|▉         | 18/192 [00:00<00:04, 36.21it/s]

Returned X shape torch.Size([1, 1, 24, 256])
End 
 
 

first curr infer length 15
row_size, nrows, X.size()  24 1 torch.Size([1, 15])
X size after padding and viewing torch.Size([1, 1, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 24, 8])
Returned X shape torch.Size([1, 1, 24, 256])
End 
 
 

first curr infer length 16
row_size, nrows, X.size()  24 1 torch.Size([1, 16])
X size after padding and viewing torch.Size([1, 1, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 24, 8])
Returned X shape torch.Size([1

 15%|█▌        | 29/192 [00:00<00:04, 33.94it/s]

Returned X shape torch.Size([1, 2, 24, 256])
End 
 
 

first curr infer length 25
row_size, nrows, X.size()  24 2 torch.Size([1, 25])
X size after padding and viewing torch.Size([1, 2, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 48, 8])
Returned X shape torch.Size([1, 2, 24, 256])
End 
 
 

first curr infer length 26
row_size, nrows, X.size()  24 2 torch.Size([1, 26])
X size after padding and viewing torch.Size([1, 2, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
    

 26%|██▌       | 49/192 [00:01<00:02, 59.86it/s]

Returned X shape torch.Size([1, 2, 24, 256])
End 
 
 

first curr infer length 36
row_size, nrows, X.size()  24 2 torch.Size([1, 36])
X size after padding and viewing torch.Size([1, 2, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 48, 8])
Returned X shape torch.Size([1, 2, 24, 256])
End 
 
 

first curr infer length 37
row_size, nrows, X.size()  24 2 torch.Size([1, 37])
X size after padding and viewing torch.Size([1, 2, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
    

 34%|███▍      | 66/192 [00:01<00:01, 67.96it/s]

Returned X shape torch.Size([1, 3, 24, 256])
End 
 
 

first curr infer length 53
row_size, nrows, X.size()  24 3 torch.Size([1, 53])
X size after padding and viewing torch.Size([1, 3, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 72, 8])
Returned X shape torch.Size([1, 3, 24, 256])
End 
 
 

first curr inf

 43%|████▎     | 83/192 [00:01<00:01, 69.92it/s]

Returned X shape torch.Size([1, 3, 24, 256])
End 
 
 

first curr infer length 70
row_size, nrows, X.size()  24 3 torch.Size([1, 70])
X size after padding and viewing torch.Size([1, 3, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.0000, 0.0000]]])
After sample viewing:  torch.Size([1, 72, 8])
Returned X shape torch.Size([1, 3, 24, 256])
End 
 
 

first curr inf

 47%|████▋     | 91/192 [00:01<00:01, 65.37it/s]

Returned X shape torch.Size([1, 4, 24, 256])
End 
 
 

first curr infer length 85
row_size, nrows, X.size()  24 4 torch.Size([1, 85])
X size after padding and viewing torch.Size([1, 4, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373, 0

 57%|█████▋    | 109/192 [00:01<00:01, 72.75it/s]

X size after padding and viewing torch.Size([1, 5, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373, 0.9294, 0.6980, 0.8824, 0.0314,
          0.2431, 0.4471, 0.5529, 0.1765, 0.4039, 0.0745, 0.6314, 0.6000],
         [0.7137, 0.2980, 0.

 61%|██████    | 117/192 [00:02<00:01, 66.57it/s]

Returned X shape torch.Size([1, 5, 24, 256])
End 
 
 

first curr infer length 115
row_size, nrows, X.size()  24 5 torch.Size([1, 115])
X size after padding and viewing torch.Size([1, 5, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373,

 68%|██████▊   | 130/192 [00:02<00:01, 50.02it/s]

Returned X shape torch.Size([1, 6, 24, 256])
End 
 
 

first curr infer length 121
row_size, nrows, X.size()  24 6 torch.Size([1, 121])
X size after padding and viewing torch.Size([1, 6, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373,

 71%|███████   | 136/192 [00:02<00:01, 51.94it/s]

Returned X shape torch.Size([1, 6, 24, 256])
End 
 
 

first curr infer length 133
row_size, nrows, X.size()  24 6 torch.Size([1, 133])
X size after padding and viewing torch.Size([1, 6, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373,

 78%|███████▊  | 149/192 [00:02<00:01, 38.41it/s]

Returned X shape torch.Size([1, 6, 24, 256])
End 
 
 

first curr infer length 142
row_size, nrows, X.size()  24 6 torch.Size([1, 142])
X size after padding and viewing torch.Size([1, 6, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373,

 86%|████████▋ | 166/192 [00:03<00:00, 50.78it/s]

24 7 torch.Size([1, 158])
X size after padding and viewing torch.Size([1, 7, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373, 0.9294, 0.6980, 0.8824, 0.0314,
          0.2431, 0.4471, 0.5529, 0.1765, 0.4039, 0.0745, 0.6314, 0.6000],
  

 93%|█████████▎| 179/192 [00:03<00:00, 55.74it/s]

Returned X shape torch.Size([1, 8, 24, 256])
End 
 
 

first curr infer length 170
row_size, nrows, X.size()  24 8 torch.Size([1, 170])
X size after padding and viewing torch.Size([1, 8, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373,

100%|██████████| 192/192 [00:03<00:00, 52.16it/s]

Returned X shape torch.Size([1, 8, 24, 256])
End 
 
 

first curr infer length 183
row_size, nrows, X.size()  24 8 torch.Size([1, 183])
X size after padding and viewing torch.Size([1, 8, 24])
X before being embedded tensor([[[0.9961, 0.7529, 0.2353, 0.3020, 0.3922, 0.1843, 0.2392, 0.1176,
          0.5255, 0.7569, 0.6824, 0.2510, 0.0588, 0.7098, 0.1373, 0.7529,
          0.9647, 0.3216, 0.4196, 0.4118, 0.4353, 0.1529, 0.9608, 0.4941],
         [0.3725, 0.1137, 0.7333, 0.5725, 0.7333, 0.8196, 0.0353, 0.6314,
          0.6431, 0.8510, 0.6039, 0.8353, 0.5843, 0.0275, 0.6118, 0.5765,
          0.9098, 0.7059, 0.5294, 0.5569, 0.7686, 0.3765, 0.8588, 0.3725],
         [0.9882, 0.4000, 0.1255, 0.7647, 0.2431, 0.9137, 0.2980, 0.5922,
          0.3843, 0.7412, 0.0039, 0.6627, 0.7176, 0.4000, 0.8353, 0.9373,
          0.8588, 0.1725, 0.7255, 0.8196, 0.6549, 0.5412, 0.5333, 0.3451],
         [0.9647, 0.4314, 0.2980, 0.7961, 0.5529, 0.8275, 0.6863, 0.3216,
          0.4863, 0.6196, 0.9333, 0.7373,




tensor([[[[0.9961, 0.3020, 0.2392, 0.7569, 0.0588, 0.7529, 0.4196, 0.1529],
          [0.3725, 0.5725, 0.0353, 0.8510, 0.5843, 0.5765, 0.5294, 0.3765],
          [0.9882, 0.7647, 0.2980, 0.7412, 0.7176, 0.9373, 0.7255, 0.5412],
          [0.9647, 0.7961, 0.6863, 0.6196, 0.9294, 0.0314, 0.5529, 0.0745],
          [0.7137, 0.8549, 0.9333, 0.5176, 0.2667, 0.7882, 0.8941, 0.1882],
          [0.7412, 0.5725, 0.7843, 0.1333, 0.9176, 0.4627, 0.7686, 0.4588],
          [0.4941, 0.5961, 0.3686, 0.3765, 0.7922, 0.7569, 0.7059, 0.8431],
          [0.9373, 0.9059, 0.8588, 0.1294, 0.7059, 0.4078, 0.7569, 0.9804]],

         [[0.7529, 0.3922, 0.1176, 0.6824, 0.7098, 0.9647, 0.4118, 0.9608],
          [0.1137, 0.7333, 0.6314, 0.6039, 0.0275, 0.9098, 0.5569, 0.8588],
          [0.4000, 0.2431, 0.5922, 0.0039, 0.4000, 0.8588, 0.8196, 0.5333],
          [0.4314, 0.5529, 0.3216, 0.9333, 0.6980, 0.2431, 0.1765, 0.6314],
          [0.2980, 0.2314, 0.5882, 0.4667, 0.4784, 0.4314, 0.0157, 0.3137],
          

### Dataset

In [19]:
transform = transforms.Compose([
    transforms.Resize(config.model.image_size),
    transforms.ToTensor()
])

dataset = datasets.CIFAR10('datasets/transformer', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=config.train.batch_size, shuffle=True, num_workers=4)
input_dim = config.model.image_size ** 2 * config.model.channels
model = ImageTransformer(config.model).to(config.device)
optimizer = optim.Adam(model.parameters(), lr=1., betas=(0.9, 0.98), eps=1e-9)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: get_lr(step, config))

Files already downloaded and verified


### Training Loop

In [20]:
gain = config.model.initializer_gain

for name, p in model.named_parameters():
    if "layernorm" in name:
        continue
    if p.dim() > 1:
        nn.init.xavier_uniform_(p, gain=np.sqrt(gain)) # Need sqrt for inconsistency between pytorch / TF
    else:
        a =  np.sqrt(3. * gain / p.shape[0])
        nn.init.uniform_(p, -a, a)

def revert_samples(input):
    if config.model.distr == "cat":
        return input
    elif config.model.distr == "dmol":
        return input * 0.5 + 0.5

step = 0
losses_per_dim = torch.zeros(config.model.channels, config.model.image_size, config.model.image_size).to(config.device)

for _ in range(config.train.epochs):
        for _, (imgs, l) in enumerate(loader):
            imgs = imgs.to(config.device)
            model.train()

            scheduler.step()
            optimizer.zero_grad()
            preds = model(imgs)
            loss = model.loss(preds, imgs)
            decay = 0. if step == 0 else 0.99
            if config.model.distr == "dmol":
                losses_per_dim[0,:,:] = losses_per_dim[0,:,:] * decay + (1 - decay) * loss.detach().mean(0) / np.log(2)
            else:
                losses_per_dim = losses_per_dim * decay + (1 - decay) * loss.detach().mean(0) / np.log(2)
            loss = loss.view(loss.shape[0], -1).sum(1)
            loss = loss.mean(0)

            # Show computational graph
            # dot = make_dot(loss, dict(model.named_parameters()))
            # dot.render('test.gv', view=True)

            loss.backward()

            total_norm = 0
            for p in model.parameters():
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
            total_norm = (total_norm ** (1. / 2))

            if config.train.clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad_norm)

            total_norm_post = 0
            for p in model.parameters():
                param_norm = p.grad.data.norm(2)
                total_norm_post += param_norm.item() ** 2
            total_norm_post = (total_norm_post ** (1. / 2))

            optimizer.step()
            bits_per_dim = loss / (np.log(2.) * input_dim)
            acc = model.accuracy(preds, imgs)

            if step % config.train.log_iter == 0:
                logging.info('step: {}; loss: {:.3f}; bits_per_dim: {:.3f}, acc: {:.3f}, grad norm pre: {:.3f}, post: {:.3f}'
                             .format(step, loss.item(), bits_per_dim.item(), acc.item(), total_norm, total_norm_post))
                tb_logger.add_scalar('loss', loss.item(), global_step=step)
                tb_logger.add_scalar('bits_per_dim', bits_per_dim.item(), global_step=step)
                tb_logger.add_scalar('acc', acc.item(), global_step=step)
                tb_logger.add_scalar('grad_norm', total_norm, global_step=step)

            if step % config.train.sample_iter == 0:
                logging.info("Sampling from model: {}".format(args.doc))
                if config.model.distr == "cat":
                    channels = ['r','g','b']
                    color_codes = ['Reds', "Greens", 'Blues']
                    for idx, c in enumerate(channels):
                        ax = sns.heatmap(losses_per_dim[idx,:,:].cpu().numpy(), linewidth=0.5, cmap=color_codes[idx])
                        tb_logger.add_figure("losses_per_dim/{}".format(c), ax.get_figure(), close=True, global_step=step)
                else:
                    ax = sns.heatmap(losses_per_dim[0,:,:].cpu().numpy(), linewidth=0.5, cmap='Blues')
                    tb_logger.add_figure("losses_per_dim", ax.get_figure(), close=True, global_step=step)

                model.eval()
                with torch.no_grad():
                    imgs = revert_samples(imgs)
                    imgs_grid = torchvision.utils.make_grid(imgs[:8, ...], 3)
                    tb_logger.add_image('imgs', imgs_grid, global_step=step)

                    # Evaluate model predictions for the input
                    pred_samples = revert_samples(model.sample_from_preds(preds))
                    pred_samples_grid = torchvision.utils.make_grid(pred_samples[:8, ...], 3)
                    tb_logger.add_image('pred_samples/random', pred_samples_grid, global_step=step)
                    pred_samples = revert_samples(model.sample_from_preds(preds, argmax=True))
                    pred_samples_grid = torchvision.utils.make_grid(pred_samples[:8, ...], 3)
                    tb_logger.add_image('pred_samples/argmax', pred_samples_grid, global_step=step)

                    if args.sample:
                        samples = revert_samples(model.sample(config.train.sample_size, config.device))
                        samples_grid = torchvision.utils.make_grid(samples[:8, ...], 3)
                        tb_logger.add_image('samples', samples_grid, global_step=step)

                    # Argmax samples are not useful for unconditional generation
                    # if config.model.distr == "cat":
                    #     argmax_samples = model.sample(1, config.device, argmax=True)
                    #     samples_grid = torchvision.utils.make_grid(argmax_samples[:8, ...], 3)
                    #     tb_logger.add_image('argmax_samples', samples_grid, global_step=step)
                torch.save(model.state_dict(), os.path.join('transformer_logs', args.doc, "model.pth"))
            step += 1

original x shape  torch.Size([8, 3, 8, 8])
permute step:  torch.Size([8, 8, 8, 3])
after flattening channels to width torch.Size([8, 8, 24])
after channel_addition torch.Size([8, 8, 24])
after embedding  torch.Size([8, 8, 24, 8])
after shift and padding torch.Size([8, 8, 24, 8])
after viewing torch.Size([8, 192, 8])
after output dense torch.Size([8, 8, 24, 256])
after unpaccking channels torch.Size([8, 3, 8, 8, 256])




NotImplementedError: 