In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
import torchvision.utils as vutils
import models
from visualize import *
from data import *
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
from functools import partial
import warnings
warnings.filterwarnings("ignore")
data_path = './cosmo'
# invertible nn
sys.path.append('../../invertible-resnet')
sys.path.append('../../invertible-resnet/models')
from conv_iResNet import conv_iResNet as iResNet

# load dataset and model

In [2]:
# params
img_size = 256
class_num = 1

# cosmo dataset
# The only transform applied to each sample is ToTensor().
transformer = transforms.Compose([ToTensor()])
mnu_dataset = MassMapsDataset(opj(data_path, 'cosmological_parameters.txt'),  
                              opj(data_path, 'z1_256'),
                              transform=transformer)

# dataloader
# shuffle=False keeps the sample order fixed between runs.
data_loader = torch.utils.data.DataLoader(mnu_dataset, batch_size=32, shuffle=False, num_workers=4)

# load model
model = models.load_model(model_name='resnet18', device=device, inplace=False, data_path=data_path).to(device)
model = model.eval()
# freeze layers
# No parameter of `model` receives gradients after this loop.
for param in model.parameters():
    param.requires_grad = False

In [3]:
# test im
# Read batches until at least 10 samples have been collected, then stop.
# NOTE(review): the loader in this notebook is built with batch_size=32, so the
# loop presumably ends after the first batch.
batches = []
seen = 0
for data in data_loader:
    inputs, params = data['image'], data['params']
    batches.append(inputs)
    seen += inputs.size(0)
    if seen >= 10:
        break
# Every collected batch concatenated along dim 0.
init_batch = torch.cat(batches)

# output
# `inputs` is the last batch read by the loop above (not `init_batch`).
with torch.no_grad():
    model = model.to(device)
    inputs = inputs.to(device)
    # Keep column 1 of the model output.
    # NOTE(review): what that column represents is not visible here.
    outputs = model(inputs)[:,1]

## iResNet

In [74]:
# Scratch cell: build the invertible conv net and move it to `device`.
# NOTE(review): the recorded NameError shows this cell ran before the cell that
# defines InvertibleResnetConv; run the definition cell first.
t = InvertibleResnetConv(1,32, list_num_blocks=(2,2,2)).to(device)

NameError: name 'InvertibleResnetConv' is not defined

In [72]:
# Display the module `t`; it only exists once the construction cell has run successfully.
t

NameError: name 't' is not defined

In [4]:
class InvertibleResnetConv(nn.Module):
    """Multi-stage stack of invertible residual conv blocks with channel splitting.

    Each stage is an ``nn.Sequential`` of ``num_blocks`` pairs ``(ActNorm, block)``,
    where ``block`` wraps a residual branch of three spectrally-normalised
    convolutions (3x3 -> 1x1 -> 3x3) preceded by ReLUs.

    Args:
        dim: number of input channels of the first stage.
        hidden_dim: channel width inside every residual branch.
        magnitude: cap passed to ``spectral_norm`` for every convolution.
        reverse_iterations: default number of fixed-point steps used by each
            ``block`` when inverting.
        bias: whether the convolutions carry a bias term.
        n_power_iterations: power-iteration steps passed to ``spectral_norm``.
        list_num_blocks: number of (ActNorm, block) pairs per stage; one entry
            per stage.

    NOTE(review): the channel count is doubled after every stage (``dim *= 2``)
    while ``forward`` halves the channels between stages and no squeeze layer is
    inserted, so stacks with more than one stage look inconsistent -- confirm
    the intended multi-scale layout before relying on it.
    """

    def __init__(self, dim, hidden_dim = 32, magnitude=0.7, reverse_iterations=10, bias=False, n_power_iterations=10,
                 list_num_blocks=(2, 2, 2)):
        super(InvertibleResnetConv, self).__init__()

        self.dim = dim
        self.reverse_iterations = reverse_iterations
        self.nets = nn.ModuleList()

        def sn_conv(in_channels, out_channels, kernel_size):
            # "Same" padding keeps the spatial size, which the residual sum
            # g(x) + x inside `block` requires.
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                             padding=kernel_size // 2, bias=bias)
            return spectral_norm(conv, n_power_iterations=n_power_iterations, magnitude=magnitude)

        # The previous version also built a strided 7x7 stem block here that was
        # never attached to the model; it has been dropped.
        for num_blocks in list_num_blocks:
            layers = []
            for _ in range(num_blocks):
                residual = nn.Sequential(nn.ReLU(),
                                         sn_conv(dim, hidden_dim, 3),
                                         nn.ReLU(),
                                         sn_conv(hidden_dim, hidden_dim, 1),
                                         nn.ReLU(),
                                         sn_conv(hidden_dim, dim, 3),
                                        )
                layers.append(ActNorm(dim))
                layers.append(block(residual, reverse_iterations=reverse_iterations))
            dim *= 2
            self.nets.append(nn.Sequential(*layers))

    def forward(self, x_list, reverse=False, reverse_iterations=None):
        """Run the stack forward or invert it.

        Forward (``reverse=False``): ``x_list`` is a single tensor. After every
        stage except the last, the second half of the channels is moved to the
        output list and only the first half is passed on. Returns the list of
        split-off tensors followed by the final tensor.

        Reverse (``reverse=True``): ``x_list`` is the list produced by the
        forward pass. Stages are undone from last to first, concatenating the
        matching split-off tensor back along the channel axis before each one.
        ``reverse_iterations`` is forwarded to ``apply_module_reverse``.
        Returns the reconstructed tensor.
        """
        if reverse:
            for i, net in enumerate(self.nets[::-1]):
                if i == 0:
                    x = x_list[len(self.nets)-1-i]
                else:
                    # Undo the channel split done after this stage in the forward pass.
                    x = torch.cat([x, x_list[len(self.nets)-1-i]], dim=1)
                for module in net[::-1]:
                    x = apply_module_reverse(module, x, reverse_iterations)
            return x
        else:
            y_list = []
            x = x_list

            for i, net in enumerate(self.nets):
                x = net(x)
                if i < len(self.nets)-1:
                    # Factor out the second half of the channels, keep the first half.
                    y_list.append(x[:,x.shape[1]//2:])
                    x = x[:,:x.shape[1]//2]
            y_list.append(x)
            return y_list

In [23]:
def apply_module_reverse(module, x, reverse_iterations=None):
    if 'block' in str(module.__class__):
        return module(x, reverse=True, reverse_iterations=reverse_iterations)
    else:
        return module(x, reverse=True)

In [None]:
# Scratch: build a single residual branch by hand
# (ReLU -> 3x3 conv -> ReLU -> 1x1 conv -> ReLU -> 3x3 conv), every convolution
# wrapped in spectral_norm with the settings below.
dim = 1
hidden_dim = 32
bias = False
n_power_iterations = 10
magnitude = 0.7
# padding=1 on the 3x3 convolutions (default stride) keeps the spatial size unchanged.
net = nn.Sequential(nn.ReLU(),
                    spectral_norm(nn.Conv2d(dim, hidden_dim, 3, padding=1, bias=bias), n_power_iterations=n_power_iterations, magnitude=magnitude),
                    nn.ReLU(),
                    spectral_norm(nn.Conv2d(hidden_dim, hidden_dim, 1, bias=bias), n_power_iterations=n_power_iterations, magnitude=magnitude),
                    nn.ReLU(),
                    spectral_norm(nn.Conv2d(hidden_dim, dim, 3, padding=1, bias=bias), n_power_iterations=n_power_iterations, magnitude=magnitude),
                   )   

In [18]:
# Apply a SqueezeLayer to the sample batch and show the resulting shape.
a = SqueezeLayer()
a(init_batch).shape

torch.Size([32, 4, 128, 128])

In [20]:
# Shape of the sample batch before squeezing, for comparison.
init_batch.shape

torch.Size([32, 1, 256, 256])

In [6]:
class block(nn.Module):
    """Residual block ``y = g(x) + x`` that can be inverted by fixed-point iteration.

    Args:
        net: the residual branch ``g``.
        reverse_iterations: default number of fixed-point steps used when
            inverting.
    """

    def __init__(self, net, reverse_iterations=40):
        super().__init__()
        self.reverse_iterations = reverse_iterations
        self.net = net  # residual branch g
        self.normalize(self.net)  # run the spectral-norm hooks once up front

    def normalize(self, net):
        """Invoke every SpectralNorm forward-pre-hook registered inside ``net``."""
        for submodule in net.modules():
            for hook in submodule._forward_pre_hooks.values():
                if isinstance(hook, SpectralNorm):
                    hook(submodule, None)

    def calcG(self, x):
        """Evaluate the residual branch on ``x``."""
        return self.net(x)

    def forward(self, x, reverse=False, reverse_iterations=None):
        """Apply the block, or approximately invert it when ``reverse`` is true.

        In reverse mode ``x`` is treated as the block output ``y`` and the
        update ``x <- y - g(x)`` is repeated, starting from ``x = y``. A falsy
        ``reverse_iterations`` falls back to ``self.reverse_iterations``.
        """
        if not reverse:
            return self.calcG(x) + x

        steps = reverse_iterations if reverse_iterations else self.reverse_iterations
        target = x
        estimate = target
        for _ in range(steps):
            estimate = target - self.calcG(estimate)
        return estimate

In [7]:
class ActNorm(nn.Module):
    """Per-channel affine layer ``scale * (x + loc)`` with data-dependent init.

    On the first forward (non-reverse) call, ``loc`` and ``scale`` are set from
    that batch so the output has zero mean and unit standard deviation per
    channel. Both 2-D ``(N, C)`` and 4-D ``(N, C, H, W)`` inputs are supported.

    Args:
        in_channel: number of channels ``C``.

    NOTE(review): ``initialized`` is a plain Python attribute, not a buffer, so
    it is not part of ``state_dict``; a freshly constructed module that loads a
    checkpoint will re-initialise on its first forward call -- confirm whether
    that is intended.
    """

    def __init__(self, in_channel):
        super().__init__()

        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))
        self.initialized = False

    def initialize(self, input):
        """Set ``loc``/``scale`` from the per-channel statistics of ``input``.

        Raises:
            ValueError: if ``input`` is neither 2-D nor 4-D.
        """
        if input.dim() not in (2, 4):
            # A bare string cannot be raised in Python 3; use a real exception.
            raise ValueError('Input shape not supported {}'.format(input.shape))

        with torch.no_grad():
            # Bring channels to the front and flatten the rest: (C, N) or (C, N*H*W).
            flatten = input.transpose(0, 1).reshape(input.shape[1], -1)
            mean = flatten.mean(1).view_as(self.loc)
            std = flatten.std(1).view_as(self.scale)

            self.loc.data.copy_(-mean)
            # The small constant guards against division by a zero std.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Apply ``scale * (input + loc)``, or its inverse when ``reverse`` is true."""
        # Parameters are stored as (1, C, 1, 1); flatten them for 2-D inputs.
        scale = self.scale if len(input.shape) == 4 else self.scale.view(1, -1)
        loc = self.loc if len(input.shape) == 4 else self.loc.view(1, -1)

        if reverse:
            return input / scale - loc

        if not self.initialized:
            # `scale`/`loc` above alias the parameters, so the in-place copy
            # done by initialize() is visible through them.
            self.initialize(input)
            self.initialized = True

        return scale * (input + loc)

In [8]:
def spectral_norm(module, name='weight', n_power_iterations=1, magnitude=1.0, eps=1e-12):
    r"""Attach a magnitude-capped spectral-normalization hook to ``module``.

    This is a thin wrapper around :meth:`SpectralNorm.apply`. The hook
    registered there recomputes the parameter ``name`` before every forward
    call as

    .. math::
         \mathbf{W} = \dfrac{\mathbf{W}_{orig}}{\max(1, \hat\sigma / \text{magnitude})}

    where :math:`\hat\sigma` is a power-iteration estimate of the largest
    singular value of the layer's linear map, so the weight is only rescaled
    when the estimate exceeds ``magnitude``.

    See `Spectral Normalization for Generative Adversarial Networks`_ for the
    underlying technique.

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): module that owns the parameter
        name (str, optional): name of the weight parameter
        n_power_iterations (int, optional): number of power iterations used to
            estimate the spectral norm
        magnitude (float, optional): largest spectral norm left unscaled
        eps (float, optional): epsilon for numerical stability

    Returns:
        The same ``module`` instance, now carrying the spectral norm hook.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m.weight_u.size()
        torch.Size([1, 20])
    """
    SpectralNorm.apply(
        module,
        name=name,
        n_power_iterations=n_power_iterations,
        magnitude=magnitude,
        eps=eps,
    )
    return module

In [None]:
# Scratch: set up a manual power iteration on a plain Linear layer.
# The spectrally-normalised module on the next line is immediately replaced by a
# fresh, un-hooked nn.Linear, so only the second `mod` is used below.
mod = spectral_norm(nn.Linear(20, 40))
mod = nn.Linear(20, 40)
# mod = nn.Conv2d(10,10,5)
weight = mod._parameters['weight']
# forward_function applies W; iteration_function applies W and then its transpose
# (one power-iteration step on W^T W).
mod.forward_function = lambda inp,weight=weight: F.linear(inp, weight)
mod.iteration_function = lambda inp,weight=weight: F.linear(F.linear(inp, weight), weight.transpose(1,0))

In [None]:
# Random starting vector with one row and as many columns as `weight` has input features.
with torch.no_grad():
    shape = (1,weight.shape[1])
    u = torch.randn(shape).to(weight.device)  
    
# Five un-normalised power-iteration steps, then estimate the largest singular
# value as ||W u|| / ||u||.
with torch.no_grad():
    for _ in range(5):
        u = mod.iteration_function(u, weight=weight)
    u = u.clone()
    sv = torch.sqrt((mod.forward_function(u, weight=weight)**2).sum()) / torch.sqrt((u**2).sum())
    # relu(sv - 1) + 1 equals max(sv, 1): the divisor used when the cap (magnitude) is 1.0.
    sigma = F.relu(sv / 1.0 - 1.0) + 1.0

In [None]:
# Display the divisor computed by the manual power iteration.
sigma

In [None]:
# Reference value: largest singular value of the weight matrix from a full SVD.
np.linalg.svd(weight.data.numpy())[1][0]


In [9]:
class SpectralNorm(object):
    """Forward pre-hook that caps the spectral norm of a layer at ``magnitude``.

    Before each forward call the hook sets ``module.<name>`` to
    ``<name>_orig / sigma`` with ``sigma = max(1, sv / magnitude)``, where ``sv``
    is a power-iteration estimate of the largest singular value of the layer's
    linear map. ``sv`` is computed by applying the layer's own operation
    (``module.forward_function``) and its adjoint (``module.iteration_function``)
    to a probe tensor ``<name>_u``.

    NOTE(review): the probe buffer ``<name>_u`` is never written back, so every
    call restarts the power iteration from the tensor drawn in ``apply``; the
    registered ``<name>_sigma`` buffer is likewise never updated (the current
    value lives in the plain attribute ``module.sigma``).
    """

    _version = 2
    # At version 2:
    #   used Gouk 2018 method.
    #   will only normalize if largest singular value > magnitude

    def __init__(self, name='weight', n_power_iterations=1, magnitude=1.0, eps=1e-12):
        self.name = name
        self.magnitude = magnitude
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def l2norm(self, t):
        """Euclidean norm of ``t`` taken over all of its elements."""
        return torch.sqrt((t ** 2).sum())

    def compute_weight(self, module, do_power_iteration, num_iter=0):
        """Return the rescaled weight ``<name>_orig / sigma``.

        Args:
            module: module carrying ``<name>_orig`` and ``<name>_u``.
            do_power_iteration: recompute ``sigma`` when true; otherwise reuse
                the cached ``module.sigma``.
            num_iter: run at least this many power-iteration steps (the larger
                of this and ``self.n_power_iterations`` is used).
        """
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')

        # Also estimate sigma when nothing is cached yet (e.g. the module was
        # switched to eval mode, or remove() is called, before any
        # training-mode pass); reading module.sigma would fail otherwise.
        if do_power_iteration or not hasattr(module, 'sigma'):
            with torch.no_grad():
                for _ in range(max(self.n_power_iterations, num_iter)):
                    u = module.iteration_function(u, weight=weight)
                    # Renormalise every step: the estimate below is a ratio and
                    # therefore scale-invariant, but without this the probe
                    # shrinks/grows geometrically and under-/overflows to NaN.
                    u = u / (self.l2norm(u) + self.eps)
                sv = self.l2norm(module.forward_function(u, weight=weight)) / (self.l2norm(u) + self.eps)
                sigma = F.relu(sv / self.magnitude - 1.0) + 1.0
                module.sigma = sigma
        else:
            sigma = module.sigma

        return weight / sigma

    def remove(self, module):
        """Replace the hook-managed attributes by a plain, rescaled parameter.

        NOTE(review): this does not unregister the forward pre-hook itself.
        """
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_sigma')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs, n_power_iterations=0):
        # Power iteration only runs while the module is in training mode.
        setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training, num_iter=n_power_iterations))

    @staticmethod
    def apply(module, name, n_power_iterations, magnitude, eps):
        """Register a SpectralNorm hook for parameter ``name`` on ``module``.

        Supports ``nn.Linear`` and the 1/2/3-D (transposed) convolutions,
        including subclasses.

        Raises:
            RuntimeError: if a SpectralNorm hook for ``name`` already exists.
            TypeError: if ``module`` is not one of the supported layer types.

        Returns:
            The newly created :class:`SpectralNorm` hook.
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, magnitude, eps)
        weight = module._parameters[name]

        # (layer class, its functional op, the adjoint op, is_transposed)
        conv_ops = (
            (torch.nn.ConvTranspose1d, F.conv_transpose1d, F.conv1d, True),
            (torch.nn.ConvTranspose2d, F.conv_transpose2d, F.conv2d, True),
            (torch.nn.ConvTranspose3d, F.conv_transpose3d, F.conv3d, True),
            (torch.nn.Conv1d, F.conv1d, F.conv_transpose1d, False),
            (torch.nn.Conv2d, F.conv2d, F.conv_transpose2d, False),
            (torch.nn.Conv3d, F.conv3d, F.conv_transpose3d, False),
        )
        conv_match = next((entry for entry in conv_ops if isinstance(module, entry[0])), None)

        if isinstance(module, torch.nn.Linear):
            module.forward_function = lambda inp,weight=weight: F.linear(inp, weight)
            module.iteration_function = lambda inp,weight=weight: F.linear(F.linear(inp, weight), weight.transpose(1,0))
            in_channels = weight.shape[1]
            spatial = ()
        elif conv_match is not None:
            _, forward_op, adjoint_op, transposed = conv_match
            k = weight.shape[2:]
            s = module.stride
            g = module.groups
            d = module.dilation
            p = module.padding

            def forward_function(inp, weight=weight, s=s, g=g, d=d, p=p):
                return forward_op(inp, weight, stride=s, padding=p, dilation=d, groups=g)

            def iteration_function(inp, weight=weight, s=s, g=g, d=d, p=p):
                out = forward_op(inp, weight, stride=s, padding=p, dilation=d, groups=g)
                return adjoint_op(out, weight, stride=s, padding=p, dilation=d, groups=g)

            module.forward_function = forward_function
            module.iteration_function = iteration_function
            # Conv weights are laid out (out, in/groups, ...) while transposed
            # conv weights are (in, out/groups, ...); the probe must carry the
            # layer's *input* channel count in both cases.
            in_channels = weight.shape[0] if transposed else weight.shape[1] * g
            spatial = tuple(max(k[i] * d[i], 1) for i in range(len(k)))
        else:
            raise TypeError("spectral_norm supports Linear and (transposed) convolution "
                            "modules, got {}".format(type(module).__name__))

        with torch.no_grad():
            # One probe sample whose spatial extent covers the dilated kernel.
            shape = (1, in_channels) + spatial
            u = torch.randn(shape).to(weight.device)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        sigma = torch.tensor(1).to(weight.device)
        module.register_buffer(fn.name + "_sigma", sigma)

        module.register_forward_pre_hook(fn)

        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


class SpectralNormStateDictHook(object):
    """State-dict hook registered by ``SpectralNorm.apply``; currently a no-op."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        pass

# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    """Load-state-dict pre-hook registered by ``SpectralNorm.apply``; currently a no-op.

    No conversion of older checkpoints is performed here.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        pass

In [None]:
# Scratch: instantiate the invertible conv net for a 5-channel input.
model_t = InvertibleResnetConv(dim=5)

In [None]:
# Display the module structure of `model_t`.
model_t

In [None]:
# Seed torch (CPU and CUDA), numpy and Python's `random`, and ask cuDNN for
# deterministic algorithms, before building the model.
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)
torch.backends.cudnn.deterministic = True

# Build the library iResNet (imported as `iResNet` from conv_iResNet) as a
# 10-class classifier: three stages of four blocks each. in_shape is taken from
# the sample batch without its batch dimension.
# NOTE(review): this replaces the `model_t` created in the scratch cell above.
model_t = iResNet(nBlocks=[4,4,4], nStrides=[1,2,2],
                nChannels=[1,64,256], nClasses=10,
                init_ds=2,
                inj_pad=0,
                in_shape=init_batch.shape[1:],
                coeff=.9,
                numTraceSamples=1,
                numSeriesTerms=1,
                n_power_iter=5,
                density_estimation=False,
                actnorm=True,
                learn_prior=True,
                nonlin="relu")

In [None]:
# Display the `Squeeze` object to check that the name is in scope.
Squeeze

In [None]:
class conv_iResNet(nn.Module):
    """Stack of invertible residual blocks with a classifier head and a Gaussian prior.

    The forward pass optionally squeezes and pads the input, runs it through
    every block in ``self.stack`` and returns either ``(z, logpz, trace)`` when
    ``density_estimation`` is set or ``(logits, z)`` otherwise.

    NOTE(review): ``Squeeze``, ``injective_pad``, ``downsample_shape``,
    ``conv_iresnet_block`` and ``distributions`` are not defined in this
    notebook section; they must already be in scope when this cell runs
    (``conv_iresnet_block`` is needed at definition time because it is a
    default argument).
    """

    def __init__(self, in_shape, nBlocks, nStrides, nChannels, init_ds=2, inj_pad=0,
                 coeff=.9, density_estimation=False, nClasses=None,
                 numTraceSamples=1, numSeriesTerms=1,
                 n_power_iter=5,
                 block=conv_iresnet_block,
                 actnorm=True, learn_prior=True,
                 nonlin="relu"):
        """Build the block stack, the prior parameters and the classifier head.

        ``nBlocks``, ``nStrides`` and ``nChannels`` must have equal length (one
        entry per stage) and ``init_ds`` must be 1 or 2. Either ``nClasses`` or
        ``density_estimation`` has to be provided.
        """
        super(conv_iResNet, self).__init__()
        assert len(nBlocks) == len(nStrides) == len(nChannels)
        assert init_ds in (1, 2), "can only squeeze by 2"
        self.init_ds = init_ds
        self.ipad = inj_pad
        self.nBlocks = nBlocks
        self.density_estimation = density_estimation
        self.nClasses = nClasses
        # parameters for trace estimation
        # (forced to 0 when the model is not used for density estimation)
        self.numTraceSamples = numTraceSamples if density_estimation else 0
        self.numSeriesTerms = numSeriesTerms if density_estimation else 0
        self.n_power_iter = n_power_iter

        print('')
        print(' == Building iResNet %d == ' % (sum(nBlocks) * 3 + 1))
        self.init_squeeze = Squeeze(self.init_ds)
        self.inj_pad = injective_pad(inj_pad)
        # Track the shape the first block will see: after the optional squeeze
        # and with `inj_pad` extra channels.
        if self.init_ds == 2:
           in_shape = downsample_shape(in_shape)
        in_shape = (in_shape[0] + inj_pad, in_shape[1], in_shape[2])  # adjust channels

        self.stack, self.in_shapes, self.final_shape = self._make_stack(nChannels, nBlocks, nStrides,
                                                                        in_shape, coeff, block,
                                                                        actnorm, n_power_iter, nonlin)

        # make prior distribution
        self._make_prior(learn_prior)
        # make classifier
        self._make_classifier(self.final_shape, nClasses)
        assert (nClasses is not None or density_estimation), "Must be either classifier or density estimator"

    def _make_prior(self, learn_prior):
        """Create mean / log-std parameters with one entry per element of ``in_shapes[0]``.

        They are trainable only when ``learn_prior`` is true.
        """
        dim = np.prod(self.in_shapes[0])
        self.prior_mu = nn.Parameter(torch.zeros((dim,)).float(), requires_grad=learn_prior)
        self.prior_logstd = nn.Parameter(torch.zeros((dim,)).float(), requires_grad=learn_prior)

    def _make_classifier(self, final_shape, nClasses):
        """Create the BatchNorm + Linear head, or set ``logits`` to None without classes."""
        if nClasses is None:
            self.logits = None
        else:
            self.bn1 = nn.BatchNorm2d(final_shape[0], momentum=0.9)
            self.logits = nn.Linear(final_shape[0], nClasses)

    def classifier(self, z):
        """BatchNorm -> ReLU -> average pool (kernel ``out.size(2)``) -> linear logits."""
        out = F.relu(self.bn1(z))
        out = F.avg_pool2d(out, out.size(2))
        out = out.view(out.size(0), out.size(1))
        return self.logits(out)

    def prior(self):
        """Normal distribution with mean ``prior_mu`` and std ``exp(prior_logstd)``."""
        return distributions.Normal(self.prior_mu, torch.exp(self.prior_logstd))

    def logpz(self, z):
        """Log-density of ``z`` under the prior, summed over all non-batch elements."""
        return self.prior().log_prob(z.view(z.size(0), -1)).sum(dim=1)

    def _make_stack(self, nChannels, nBlocks, nStrides, in_shape, coeff, block,
                    actnorm, n_power_iter, nonlin):
        """ Create stack of iresnet blocks """
        block_list = nn.ModuleList()
        in_shapes = []
        for i, (int_dim, stride, blocks) in enumerate(zip(nChannels, nStrides, nBlocks)):
            for j in range(blocks):
                # Record the input shape seen by every block.
                in_shapes.append(in_shape)
                block_list.append(block(in_shape, int_dim,
                                        numTraceSamples=self.numTraceSamples,
                                        numSeriesTerms=self.numSeriesTerms,
                                        stride=(stride if j == 0 else 1),  # use stride if first layer in block else 1
                                        input_nonlin=(i + j > 0),  # add nonlinearity to input for all but fist layer
                                        coeff=coeff,
                                        actnorm=actnorm,
                                        n_power_iter=n_power_iter,
                                        nonlin=nonlin))
                # Only the first block of a stride-2 stage changes the shape.
                if stride == 2 and j == 0:
                    in_shape = downsample_shape(in_shape)

        return block_list, in_shapes, in_shape

    def get_in_shapes(self):
        """Return the list of per-block input shapes recorded by ``_make_stack``."""
        return self.in_shapes

    def inspect_singular_values(self):
        """Print and return the largest singular value of each bottleneck convolution.

        Every ``weight_orig`` entry of the state dict whose key contains
        "bottleneck" is divided by its matching ``*_sigma`` entry, then the
        singular values are taken from the 2-D FFT of the kernel over the
        block's spatial input shape.
        """
        i = 0
        j = 0
        params = [v for v in self.state_dict().keys()
                  if "bottleneck" in v and "weight_orig" in v
                  and not "weight_u" in v
                  and not "bn1" in v
                  and not "linear" in v]
        print(len(params))
        print(len(self.in_shapes))
        svs = [] 
        for param in params:
          input_shape = tuple(self.in_shapes[j])
          # get unscaled parameters from state dict
          convKernel_unscaled = self.state_dict()[param].cpu().numpy()
          # get scaling by spectral norm
          # (param[:-5] strips the trailing "_orig" from the key)
          sigma = self.state_dict()[param[:-5] + '_sigma'].cpu().numpy()
          convKernel = convKernel_unscaled / sigma
          # compute singular values
          input_shape = input_shape[1:]
          fft_coeff = np.fft.fft2(convKernel, input_shape, axes=[2, 3])
          t_fft_coeff = np.transpose(fft_coeff)
          D = np.linalg.svd(t_fft_coeff, compute_uv=False, full_matrices=False)
          Dflat = np.sort(D.flatten())[::-1] 
          print("Layer "+str(j)+" Singular Value "+str(Dflat[0]))
          svs.append(Dflat[0])
          # Move on to the next block's input shape after every third parameter.
          if i == 2:
            i = 0
            j+= 1
          else:
            i+=1
        return svs

    def forward(self, x, ignore_logdet=False):
        """ iresnet forward """
        if self.init_ds == 2:
            x = self.init_squeeze.forward(x)

        if self.ipad != 0:
            x = self.inj_pad.forward(x)

        z = x
        traces = []
        # Every block returns its output together with a trace term.
        for block in self.stack:
            z, trace = block(z, ignore_logdet=ignore_logdet)
            traces.append(trace)

        # no classification head
        if self.density_estimation:
            # add logdets
            tmp_trace = torch.zeros_like(traces[0])
            for k in range(len(traces)):
                tmp_trace += traces[k]

            logpz = self.logpz(z)
            return z, logpz, tmp_trace

        # classification head
        else:
            logits = self.classifier(z)
            return logits, z

    def inverse(self, z, max_iter=10):
        """ iresnet inverse """
        with torch.no_grad():
            x = z
            # Undo the blocks from last to first, then the padding and squeeze.
            for i in range(len(self.stack)):
                x = self.stack[-1 - i].inverse(x, maxIter=max_iter)

            if self.ipad != 0:
                x = self.inj_pad.inverse(x)

            if self.init_ds == 2:
                x = self.init_squeeze.inverse(x)
        return x

    def sample(self, batch_size, max_iter=10):
        """sample from prior and invert"""
        with torch.no_grad():
            # only send batch_size to prior, prior has final_shape as attribute
            samples = self.prior().rsample((batch_size,))
            samples = samples.view((batch_size,) + self.final_shape)
            return self.inverse(samples, max_iter=max_iter)

    def set_num_terms(self, n_terms):
        """Set ``numSeriesTerms`` on every layer of every block's ``stack``."""
        for block in self.stack:
            for layer in block.stack:
                layer.numSeriesTerms = n_terms

In [None]:
# Training driver for `model_t`.
# NOTE(review): `args`, `trainloader`, `trainset`, `testloader`, `viz`,
# `use_cuda`, `train_log`, `test`, `get_hms` and the initial `test_objective`
# are not defined in this notebook section; they must be provided before this
# cell can run.

# One forward pass so that data-dependent actnorm parameters get initialised.
print("initializing actnorm parameters...")
with torch.no_grad():
    model_t(init_batch, ignore_logdet=True)
print("initialized")

model_t = model_t.to(device)
optimizer = optim.Adam(model_t.parameters(), lr=.1, weight_decay=5e-4)

num_epochs = 30
elapsed_time = 0  # accumulated wall-clock training time in seconds
for epoch in range(1, 1+num_epochs):
    start_time = time.time()
    # Train the model whose parameters the optimizer holds (`model_t`), not the
    # frozen resnet stored in `model`.
    train(args, model_t, optimizer, epoch, trainloader, trainset, viz, use_cuda, train_log)
    epoch_time = time.time() - start_time
    elapsed_time += epoch_time
    print('| Elapsed time : %d:%02d:%02d' % (get_hms(elapsed_time)))

print('Testing model')
# Context manager so the log file is closed even if testing fails.
with open(os.path.join(args.save_dir, "test_log.txt"), 'w') as test_log:
    test_objective = test(test_objective, args, model_t, epoch, testloader, viz, use_cuda, test_log)
print('* Test results : objective = %.2f%%' % (test_objective))
with open(os.path.join(args.save_dir, 'final.txt'), 'w') as f:
    f.write(str(test_objective))

In [None]:
def train(args, model, optimizer, epoch, trainloader, trainset, viz, use_cuda, train_log):
    """Train ``model`` for one epoch, as a classifier or as a density estimator.

    Args:
        args: options object; reads ``densityEstimation``, ``lr``,
            ``warmup_epochs``, ``log_verbose``, ``log_every``, ``epochs``,
            ``batch``, ``noActnorm`` and ``fixedPrior``.
        model: network to train. In density mode it must return
            ``(z, logpz, trace)``, otherwise ``(logits, z)``.
        optimizer: optimizer over the model parameters.
        epoch: 1-based epoch number.
        trainloader: iterable of ``(inputs, targets)`` batches.
        trainset: training dataset, only used for its length in progress output.
        viz: plotting handle passed through to ``line_plot``.
        use_cuda: move every batch to the GPU when true.
        train_log: writable text file receiving one JSON line per logged
            density-estimation step.

    NOTE(review): ``criterion``, ``bits_per_dim``, ``line_plot``, ``Variable``
    and ``json`` are taken from the enclosing namespace.
    """
    model.train()
    correct = 0
    total = 0

    # update lr for this epoch (for classification only)
    if not args.densityEstimation:
        lr = learning_rate(args.lr, epoch)
        update_lr(optimizer, lr)
    else:
        lr = args.lr

    params = sum([np.prod(p.size()) for p in model.parameters()])
    print('|  Number of Trainable Parameters: ' + str(params))
    print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, lr))

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        cur_iter = (epoch - 1) * len(trainloader) + batch_idx
        # Linear learning-rate warm-up. The `warmup_epochs > 0` guard avoids a
        # division by zero when warm-up is disabled.
        # NOTE(review): with `<=` the ramp also covers the epoch after
        # `warmup_epochs` and overshoots args.lr there -- confirm whether `<`
        # was intended.
        if args.warmup_epochs > 0 and epoch - 1 <= args.warmup_epochs:
            this_lr = args.lr * float(cur_iter) / (args.warmup_epochs * len(trainloader))
            update_lr(optimizer, this_lr)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
        optimizer.zero_grad()

        inputs, targets = Variable(inputs, requires_grad=True), Variable(targets)

        if args.densityEstimation: # density estimation
            _, logpz, trace = model(inputs)  # Forward Propagation
            # compute loss
            logpx = logpz + trace
            loss = bits_per_dim(logpx, inputs).mean()
        else: # classification
            out, _ = model(inputs)
            loss = criterion(out, targets) # Loss

        # logging for sigmas. NOTE: needs to be done before backward-call
        if args.densityEstimation and args.log_verbose:
            if batch_idx % args.log_every == 0:
                state = model.state_dict()  # build the state dict once, not per key
                sigmas = []
                for k in state.keys():
                    # Both substrings must be present; the former
                    # `'bottleneck' and 'weight_orig' in k` only tested the second.
                    if 'bottleneck' in k and 'weight_orig' in k:
                        # k[:-5] drops the trailing "_orig"
                        sigma = state[k[:-5] + '_sigma']
                        sigmas.append(sigma.item())
                sigmas = np.array(sigmas)
                line_plot(viz, "sigma all layers", cur_iter, sigmas)

        loss.backward()  # Backward Propagation
        optimizer.step()  # Optimizer update

        if args.densityEstimation: # logging for density estimation
            if batch_idx % args.log_every == 0:
                mean_trace = trace.mean().item()
                mean_logpz = logpz.mean().item()
                sys.stdout.write('\r')
                sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\t%% bits/dim: %.3f Trace: %.3f  logp(z) %.3f'
                                 % (epoch, args.epochs, batch_idx+1,
                                    (len(trainset)//args.batch)+1, loss,  mean_trace, mean_logpz))
                sys.stdout.flush()
                line_plot(viz, "bits/dim", cur_iter, loss.item())
                line_plot(viz, "logp(z)", cur_iter, mean_logpz)
                line_plot(viz, "log|df/dz|", cur_iter, mean_trace)
                # file logging
                log_dict = {"iter": cur_iter, "loss": loss.item(), "logpz": mean_logpz, "logdet": mean_trace, "epoch": epoch}
                train_log.write("{}\n".format(json.dumps(log_dict)))
                train_log.flush()

                if args.log_verbose:
                    state = model.state_dict()
                    # grad_norm_2 = sum((p.grad.norm()**2).item() for p in model.parameters() if p.grad is not None)
                    grad_norm_inf = max(p.grad.data.abs().max().item() for p in model.parameters() if p.grad is not None)
                    # line_plot(viz, "grad_norm_2", cur_iter, grad_norm_2)
                    line_plot(viz, "grad_norm_inf", cur_iter, grad_norm_inf)
                    # log actnorm scaling
                    if not args.noActnorm:
                        actnorm_scales = []
                        actnorm_scales_min = []
                        actnorm_l2 = []
                        for k in state.keys():
                            # Same fix as above: require both substrings.
                            if 'actnorm' in k and '_log_scale' in k:
                                scale = torch.max(state[k])
                                scale_min = torch.min(state[k])
                                l2 = torch.norm(state[k])
                                actnorm_scales.append(scale.item())
                                actnorm_scales_min.append(scale_min.item())
                                actnorm_l2.append(l2.item())
                        actnorm_scales = np.array(actnorm_scales)
                        actnorm_scales_min = np.array(actnorm_scales_min)
                        actnorm_l2 = np.array(actnorm_l2)
                        line_plot(viz, "max actnorm scale per layer", cur_iter, actnorm_scales)
                        line_plot(viz, "min actnorm scale per layer", cur_iter, actnorm_scales_min)
                        line_plot(viz, "l2 norm of actnorm scale per layer", cur_iter, actnorm_l2)
                    # learned prior logging
                    # NOTE(review): the 'module.' prefix presumably assumes the
                    # model is wrapped (e.g. DataParallel) -- confirm.
                    if not args.fixedPrior:
                        prior_scales_max = torch.max(state['module.prior_logstd'])
                        prior_scales_min = torch.min(state['module.prior_logstd'])
                        line_plot(viz, "max prior scale", cur_iter, prior_scales_max.item())
                        line_plot(viz, "min prior scale", cur_iter, prior_scales_min.item())

        else: # logging for classification
            _, predicted = torch.max(out.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()
            if batch_idx % 1 == 0:
                sys.stdout.write('\r')
                sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f'
                                 % (epoch, args.epochs, batch_idx+1,
                                    (len(trainset)//args.batch)+1, loss.data.item(),
                                    100.*correct.type(torch.FloatTensor)/float(total)))
                sys.stdout.flush()

In [None]:
def learning_rate(init, epoch):
    """Step-decay schedule: multiply ``init`` by 0.2 per milestone passed.

    Milestones are epochs 60, 120 and 160, compared with strict ``>``. The
    result is ``init`` up to and including epoch 60, then ``init * 0.2``,
    ``init * 0.2**2`` and, after epoch 160, ``init * 0.2**3``.
    """
    # (milestone, exponent) pairs, checked from the latest milestone down;
    # the first milestone that ``epoch`` exceeds decides the exponent.
    schedule = ((160, 3), (120, 2), (60, 1))
    optim_factor = next(
        (power for milestone, power in schedule if epoch > milestone),
        0,
    )
    return init*math.pow(0.2, optim_factor)

def zero_grad(model):
    """Reset every existing gradient of ``model`` to zero, in place.

    Parameters whose ``.grad`` is ``None`` are skipped; every other gradient
    is detached in place and then zero-filled.
    """
    gradients = (param.grad for param in model.parameters())
    for grad in gradients:
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()
            
def update_lr(optimizer, lr):
    """Set the ``'lr'`` entry of every parameter group of ``optimizer`` to ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [None]:
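# Reference shell command (not Python): run it from a terminal rather than executing this cell.
# NOTE(review): CIFAR_main.py presumably lives in the invertible-resnet checkout added to sys.path above -- confirm.
# python CIFAR_main.py --nBlocks 16 16 16 --nStrides 1 2 2 --nChannels 512 512 512 --coeff 0.9 -densityEstimation -multiScale --lr 0.003 --weight_decay 0. \
#     --numSeriesTerms 5 --dataset cifar10 --batch 128 --warmup_epochs 1 --save_dir ./results/dens_est_cifar --vis_server your.server.local --vis_port your_port_nr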

# load dataset and model

In [None]:
# params
img_size = 256
class_num = 1

# cosmo dataset: parameter file plus the 'z1_256' directory under data_path,
# with each sample passed through the ToTensor transform
transformer = transforms.Compose([ToTensor()])
mnu_dataset = MassMapsDataset(
    opj(data_path, 'cosmological_parameters.txt'),
    opj(data_path, 'z1_256'),
    transform=transformer,
)

# dataloader: batches of 32 in dataset order (no shuffling), 4 worker processes
data_loader = torch.utils.data.DataLoader(
    dataset=mnu_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)

# load model, move it to the device and switch it to evaluation mode
model = models.load_model(
    model_name='resnet18',
    device=device,
    inplace=False,
    data_path=data_path,
).to(device).eval()

# freeze layers: no parameter of the loaded model receives gradients
for param in model.parameters():
    param.requires_grad = False

In [None]:
class Transform(nn.Module):
    """Strided 7x7 convolution (1 -> 64 channels, no bias) followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Return ``relu1(conv1(x))``."""
        return self.relu1(self.conv1(x))
    
    
class Transform_i(nn.Module):
    """Single transposed 7x7 convolution (64 -> 1 channel, stride 2, no bias)."""

    def __init__(self):
        super().__init__()
        self.convt1 = nn.ConvTranspose2d(
            in_channels=64,
            out_channels=1,
            kernel_size=7,
            stride=2,
            padding=3,
            output_padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return the transposed convolution of ``x``."""
        upsampled = self.convt1(x)
        return upsampled
    
    
def gradient_pen(gen_frames, alpha=2):
    """Mean finite-difference gradient penalty of a batch of images.

    Horizontal and vertical forward differences (step 1) are taken along the
    last two axes of ``gen_frames``; their magnitudes are raised to ``alpha``,
    the two terms are added, and the sum is averaged over all elements.

    Magnitudes are used so that the value stays a non-negative penalty for odd
    exponents as well: raising the signed differences to an odd power lets
    positive and negative steps cancel. For even integer exponents, including
    the default ``alpha=2``, the result is the same as with signed
    differences.

    Args:
        gen_frames: tensor whose last two axes are (height, width), e.g.
            (b, c, h, w); float32 or float64.
        alpha: exponent applied to the difference magnitudes.

    Returns:
        Scalar tensor holding the averaged penalty.
    """

    def gradient(x):
        # idea from tf.image.image_gradients(image)
        # https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512
        # x: (..., h, w); dx, dy have the same shape as x.
        # Forward differences with step 1. Padding the difference with zeros
        # keeps the input shape: dx is zero in its last column and dy is zero
        # in its last row.
        dx = F.pad(x[..., 1:] - x[..., :-1], [0, 1, 0, 0])
        dy = F.pad(x[..., 1:, :] - x[..., :-1, :], [0, 0, 0, 1])
        return dx, dy

    # gradient
    gen_dx, gen_dy = gradient(gen_frames)

    # condense into one tensor and avg
    return torch.mean(gen_dx.abs() ** alpha + gen_dy.abs() ** alpha)

In [None]:
# test im
# test im: first sample of the first batch; the 0:1 slice keeps the batch axis.
# The built-in next() is used instead of the Python-2 style `.next()` method,
# which DataLoader iterators do not reliably provide.
X = next(iter(data_loader))['image'][0:1].to(device)
X.requires_grad = True

# output: second element of the flattened model output, computed without
# building an autograd graph
with torch.no_grad():
    output = model(X).flatten()[1]

# optimize over mask

In [None]:
# test im: first sample of the first batch; the 0:1 slice keeps the batch axis.
# The built-in next() is used instead of the Python-2 style `.next()` method,
# which DataLoader iterators do not reliably provide.
X = next(iter(data_loader))['image'][0:1].to(device)
X.requires_grad = True

# output: second element of the flattened model output, computed without
# building an autograd graph
with torch.no_grad():
    output = model(X).flatten()[1]

In [None]:
class Mask(nn.Module):
    """Learnable elementwise mask of shape (img_size, img_size), initialised to ones."""

    def __init__(self, img_size=256):
        super().__init__()
        initial_mask = torch.ones(img_size, img_size)
        # alternative random initialisation, kept for reference:
        # self.mask = nn.Parameter(torch.clamp(abs(torch.randn(img_size, img_size)), 0, 1))
        self.mask = nn.Parameter(initial_mask)

    def forward(self, x):
        """Return the elementwise (broadcast) product of the mask and ``x``."""
        return self.mask * x

In [None]:
# loss modules: mean-squared error (criterion) and mean absolute error (l1loss)
criterion = nn.MSELoss()
l1loss = nn.L1Loss()

# learnable mask with default size, placed on the compute device
mask = Mask().to(device)

# Adam optimizer over the mask's parameters only
optimizer = optim.Adam(params=mask.parameters(), lr=0.05)

In [None]:
# Mask optimization loop.
# Each step minimizes -(second flattened model output on the masked image)
# plus lamb_l1 times the mean absolute value of the mask, then clips the mask
# back into [0, 1].
num_epochs = 1000
lamb_l1 = 5.0   # weight of the L1 term on the mask
losses = []     # loss value of every step, for plotting later

print("Starting Training Loop...")
for epoch in range(num_epochs):
    # start the step from cleared gradients
    optimizer.zero_grad()

    # forward pass through the mask and the model
    im_mask = mask(X)
    output_ = model(im_mask).flatten()[1]

    # objective: L1 penalty on the mask minus the model output
    l1_term = l1loss(mask.mask, torch.zeros_like(mask.mask))
    loss = lamb_l1 * l1_term - output_

    # backward pass and mask update
    loss.backward()
    optimizer.step()

    # projection: keep every mask entry inside [0, 1]
    mask.mask.data = mask.mask.data.clamp(0, 1)

    # progress on a single, overwritten line
    print('\rTrain Epoch: {}/{}'.format(epoch, num_epochs), end='')

    # record the loss of this step
    losses.append(loss.item())


In [None]:
# loss recorded at every step of the mask optimization loop
plt.plot(losses)
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('mask optimization loss')
plt.show()

In [None]:
# masked input: detach from autograd, move to CPU and drop singleton axes before display
cshow(mask(X).detach().cpu().squeeze())

In [None]:
# unmasked input: detach from autograd, move to CPU and drop singleton axes before display
cshow(X.detach().cpu().squeeze())