# METADATA
- Notebook implementing MAF - Masked Autoregressive Flow for Density Estimation. 
    - Only consider X and Y, ignoring particle energy for now.
    - Scaling the X and Y predictions to [0, 125].
    - Considering only the jets with 'valid' number of hits, i.e. in the range [384, 768].
    - Binning the energies between [50, 55].
- Author: Aditya Ahuja

In [None]:
# Add a description for the Weights and Bisases project. 
WANDB_DESC = 'Setting up MAF.'

# PRELIMINARY

- We need PyTorch 1.9.0+cu102.
- Ideally: 

```
>>> pip list | grep torch
torch                         1.9.0+cu102
torchsummary                  1.5.1
torchtext                     0.10.0
torchvision                   0.10.0+cu102
```



In [None]:
%%capture
! pip3 install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html
! pip install wandb

In [None]:
import torch
assert torch.__version__ == '1.9.0+cu102'
! pip list | grep torch

torch                         1.9.0+cu102
torchsummary                  1.5.1
torchtext                     0.10.0
torchvision                   0.10.0+cu102


In [None]:
# Log into Weights and Biases for Logging
! wandb login

[34m[1mwandb[0m: Currently logged in as: [33madiah80[0m (use `wandb login --relogin` to force relogin)


## Preliminary Setup

In [None]:
''' COLAB'''

# Mount Drive
from google.colab import drive
drive.mount('/content/drive')

# Install missing packages
! apt-get install tree >/dev/null

# Download dataset
# ! ./get_dataset.sh

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Set Input format of files in $ROOT/data
INPUT_FORMAT = 'Boosted_Jets_Sample-{}.snappy.parquet'

In [None]:
''' IMPORT INITIAL PACKAGES ''' 

import os
import cv2
import wandb
import numpy as np
import pandas as pd
import pickle as pkl
import matplotlib.pyplot as plt
import pyarrow.parquet as pq
from tqdm.auto import tqdm, trange
import torch.nn.functional as F

# Set Numpy Print Options
np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)

In [None]:
%matplotlib inline

In [None]:
''' CREATE AND SET DATA/CACHE DIRECTORIES '''

ROOT = '/content/drive/My Drive/_GSoC/Normalizing-Flows/'
DATA_ROOT = ROOT + "data/"      # Store Datasets
CACHE_ROOT = ROOT + "cache/"    # Store intermediate cache files
LOGS_ROOT = ROOT + "logs/"      # Store Logs

# Ensure the above directories are present
os.chdir(ROOT)
os.makedirs(DATA_ROOT, exist_ok=True)
os.makedirs(CACHE_ROOT, exist_ok=True)
os.makedirs(LOGS_ROOT, exist_ok=True)

# Initialize scratch space on /content for faster read-write
# Mainly used for temporary file storage
SCRATCH_ROOT = '/content/scratch/'   
os.makedirs(SCRATCH_ROOT, exist_ok=True) 

# print('Directory Structure [Excluding Log/Temp Files]:')
! tree -I 'model*|temp__*|wandb*|run_'

.
├── cache
│   └── X_dict.pkl
├── data
│   ├── Boosted_Jets_Sample-0.snappy.parquet
│   ├── Boosted_Jets_Sample-1.snappy.parquet
│   ├── Boosted_Jets_Sample-2.snappy.parquet
│   ├── Boosted_Jets_Sample-3.snappy.parquet
│   └── Boosted_Jets_Sample-4.snappy.parquet
├── get_dataset.sh
├── logs
├── nbs
│   ├── nb3_OLD.ipynb
│   ├── Starter.ipynb
│   ├── Week_1.ipynb
│   ├── Week_2.ipynb
│   ├── Week_3.ipynb
│   ├── Week_4.ipynb
│   ├── Week_5.ipynb
│   └── Week_6.ipynb
├── README.md
├── requirements.txt
└── results

5 directories, 17 files


# LOAD DATA

### Dataset Class

- We create a Dataset class to feed data to our model.
- As we have restrictions on the jets we want to consider (based on number of hits in them), we modify the defauly PyTorch Dataset class.
- Additional functions added to the class - 
    * `get_next_instance`.
    * `get_raw_instance`. 
    * `pad_instance`. 

In [1]:
''' DEFINE THE DATA CLASS '''

import torch
from torch.utils.data import *
CHANNELS = ['Pt', 'ECal', 'HCal']

def parse_img(track_img, reduce=False):
    '''
    Returns non-zero hits from the single-channel input jet image - `track_img`. 
    '''
    track_img = torch.Tensor(track_img)        
    x_pos, y_pos = torch.nonzero(track_img, as_tuple=True)
    val = track_img[x_pos, y_pos]
    out = torch.stack((x_pos,y_pos,val),dim=1)
    return out

class ParquetDataset(Dataset):
    def __init__(self, filename, channels=[1], max_instances=768, min_instances=384):
        self.parquet = pq.ParquetFile(filename)
        self.cur_idx = 0
        self.total_len = self.parquet.num_row_groups
        self.cols = None 
        self.verbose = False                # False by default
        self.max_instances = max_instances  # Number of max hits to force in each jet.
        self.min_instances = min_instances  # Number of min hits to force in each jet.
        self.allowed_range = range(min_instances, max_instances+1)

        self.pt_range = [25,140]
        self.return_channels = ['ECal']
        self.supress_val = True
        self.cached_pts = None
        
        
    def __getitem__(self, index):
        raise NotImplementedError('Not needed. Using `get_next_instance` instead.')

    def __len__(self):
        raise NotImplementedError('Not needed.')

    def get_next_instance(self):
        '''
        - Returns the next valid sample, with it's true index.
        - Keeps looping until a valid sample is found.
        - A samples validity is based on its `pt_range` (energy range)
        as set by the `self.pt_range` attribute.
        '''

        while True:
            for idx in range(self.cur_idx, self.total_len):
                raw_jets = self.get_raw_instance(idx)
                parsed_jets = [parse_img(j) for j in raw_jets]
                gen_pts = [j[:,2] for j in parsed_jets]
                self.cached_pts = gen_pts
                
                if self.supress_val:
                    parsed_jets = [j[:,:2] for j in parsed_jets]
                
                ### Temporary hack for ECal ###
                ecal_idx = self.return_channels.index('ECal')
                parsed_jet = parsed_jets[ecal_idx]      
                gen_pt = gen_pts[ecal_idx].sum() 
                if self.verbose:
                    print('-- Output Shape: {}'.format(parsed_jet.shape))
                    print('-- Gen-Pt Val: {}'.format(gen_pt))

                # print(gen_pt, self.pt_range)
                if (parsed_jet.shape[0] in self.allowed_range) and \
                   (gen_pt >= self.pt_range[0] and gen_pt <= self.pt_range[1]):
                    if self.verbose:
                        print('-- Returning instance at idx={}'.format(idx))

                    padded_jet = self.pad_instance(parsed_jet)
                    self.cur_idx = idx + 1
                    return padded_jet, idx     # Exit after finding a valid instance
                else:
                    if self.verbose:
                        print('-- Skipped instance at idx={}, shape={}, pt={}'.format(idx, parsed_jet.shape, gen_pt))

            # End of dataset, loop back.
            self.cur_idx = 0

    def get_raw_instance(self, index):
        '''
        Parses the Parquet dataset to return raw data at a certain index.
        '''

        c_idx = []
        for c in self.return_channels:
            assert c in CHANNELS
            c_idx.append(CHANNELS.index(c))

        data = self.parquet.read_row_group(index, columns=self.cols).to_pydict()
        data['X_jets'] = np.float32(data['X_jets'][0]) 
        data['X_jets'] = data['X_jets'][0:]
        data['X_jets'][data['X_jets'] < 1.e-3] = 0.     # Zero-Suppression
        raw_jet = dict(data)['X_jets'][c_idx]           # Temporary Hack for Ecal
        return raw_jet

    def pad_instance(self, instance):
        '''
        - Pads a data instance (input: `instance`) with additional zeros. 
        - Shape of the padded instance is made equal to `self.max_instances`. 
        - Number of hits in the instance is assumed to be in the range: 
        [self.min_instances, self.max_instances] range.
        '''
        assert instance.shape[0] <= self.max_instances
        assert instance.shape[0] >= self.min_instances
        pad_len = self.max_instances - instance.shape[0]
        instance = F.pad(instance, pad=(0, 0, 0, pad_len), mode='constant', value=0)
        return instance

In [2]:

def vis(arr, is_parsed=True, title=None, scale=1000, cmap='gist_heat', reduce=False, onlySave=False, savePath=None):  
    '''
    Visualise a jet instance.
    '''

    if not is_parsed:   
        arr = parse_img(arr, reduce)

    if arr.shape[1] == 3:
        x_pos, y_pos, val = arr[:,0], arr[:,1], arr[:,2]
    elif arr.shape[1] == 2:
        x_pos, y_pos = arr[:,0], arr[:,1]
        val = torch.ones_like(x_pos)
        scale = None
    else:
        raise Exception("Wrong array dimensions.")

    if scale:
        sz = np.array(np.abs(val)) * scale
    else:
        sz = np.ones_like(val) * 10
        
    plt.figure(figsize=[10,6], facecolor='#f0f0f0')
    cm = plt.cm.get_cmap(cmap)     # 'gist_heat' / 'YlOrRd'
    sc = plt.scatter(x_pos, y_pos, c=val, s=sz, cmap=cm, alpha=0.5, edgecolors='k')
    plt.colorbar(sc)
    plt.xlim(0, 126)
    plt.ylim(0, 126)
    plt.xticks(range(0,125,25))
    plt.yticks(range(0,125,25))
    plt.grid()
    if title:
        plt.title(title)

    if onlySave:
        print('Saving to:', savePath)
        plt.savefig(savePath, dpi=150)
    else:
        plt.show()

### Try using the Dataset Class

In [None]:
''' INSTANTIATE A DATASET OBJECT  '''

dataset_file = DATA_ROOT + INPUT_FORMAT.format(0)  # Load the first file
dataset = ParquetDataset(dataset_file)             # Instantiate the dataset 
dataset.verbose = True                             # Set the Verbose flag
print('Max Length of Dataset: ', dataset.total_len)

In [None]:
data_sample, true_idx = dataset.get_next_instance()   # Load an instance
data_sample.shape

In [None]:
print(dataset.allowed_range)
print()

# Load a few instances
for i in range(10):
    print('[{}]'.format(i))
    data_sample, true_idx = dataset.get_next_instance()
    print(data_sample.shape, true_idx)
    print(dataset.cur_idx)
    print('\n')

# MASKED AUTOREGRESSIVE FLOWs

Adapted from - https://github.com/kamenbliznashki/normalizing_flows

### Initial Setup

In [None]:
''' IMPORTS '''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torchvision.transforms as T
from torchvision.utils import save_image
from torch.utils.data import DataLoader, TensorDataset

import os
import math
import copy
import time
import argparse
import pprint
from functools import partial

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
''' DEFINE THE ARG PARSER '''

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true', help='Train a flow.')
    parser.add_argument('--evaluate', action='store_true', help='Evaluate a flow.')
    parser.add_argument('--restore_file', type=str, help='Path to model to restore.')
    parser.add_argument('--generate', action='store_true', help='Generate samples from a model.')
    parser.add_argument('--data_dir', default='./data/', help='Location of datasets.')
    parser.add_argument('--output_dir', default='./results/run_')
    parser.add_argument('--results_file', default='results.txt', help='Filename where to store settings and test results.')
    parser.add_argument('--no_cuda', action='store_true', help='Do not use cuda.')
    
    # data
    parser.add_argument('--dataset', default='toy', help='Which dataset to use.')
    parser.add_argument('--flip_toy_var_order', action='store_true', help='Whether to flip the toy dataset variable order to (x2, x1).')
    parser.add_argument('--seed', type=int, default=1, help='Random seed to use.')
    
    # model
    parser.add_argument('--model', default='maf', help='Which model to use: made, maf.')
    
    # made parameters
    parser.add_argument('--n_blocks', type=int, default=5, help='Number of blocks to stack in a model (MADE in MAF; Coupling+BN in RealNVP).')
    parser.add_argument('--n_components', type=int, default=1, help='Number of Gaussian clusters for mixture of gaussians models.')
    parser.add_argument('--hidden_size', type=int, default=100, help='Hidden layer size for MADE (and each MADE block in an MAF).')
    parser.add_argument('--n_hidden', type=int, default=1, help='Number of hidden layers in each MADE.')
    parser.add_argument('--activation_fn', type=str, default='relu', help='What activation function to use in the MADEs.')
    parser.add_argument('--input_order', type=str, default='sequential', help='What input order to use (sequential | random).')
    parser.add_argument('--conditional', default=False, action='store_true', help='Whether to use a conditional model.')
    parser.add_argument('--no_batch_norm', action='store_true')
    
    # training params
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--n_epochs', type=int, default=50)
    parser.add_argument('--start_epoch', default=0, help='Starting epoch (for logging; to be overwritten when restoring file.')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
    parser.add_argument('--log_interval', type=int, default=1000, help='How often to show loss statistics and save samples.')

    args = parser.parse_args([])

    # Automatically create a new log directory for every run.
    args.output_dir = os.path.join('./results/run_', time.strftime('%Y-%m-%d_%H-%M-%S', time.gmtime()))
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    # Use Cuda if it's available
    args.device = torch.device('cuda:0' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    args.conditional = False
    args.cond_label_size = None

    return args

### Building the MAF model

In [None]:
''' MODEL CONPONENTS ''' 

def create_masks(input_size, hidden_size, n_hidden, input_order='sequential', input_degrees=None):
    # MADE paper sec 4:
    # degrees of connections between layers -- ensure at most in_degree - 1 connections
    degrees = []

    # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades);
    # else init input degrees based on strategy in input_order (sequential or random)
    if input_order == 'sequential':
        degrees += [torch.arange(input_size)] if input_degrees is None else [input_degrees]
        for _ in range(n_hidden + 1):
            degrees += [torch.arange(hidden_size) % (input_size - 1)]
        degrees += [torch.arange(input_size) % input_size - 1] if input_degrees is None else [input_degrees % input_size - 1]

    elif input_order == 'random':
        degrees += [torch.randperm(input_size)] if input_degrees is None else [input_degrees]
        for _ in range(n_hidden + 1):
            min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
            degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))]
        min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
        degrees += [torch.randint(min_prev_degree, input_size, (input_size,)) - 1] if input_degrees is None else [input_degrees - 1]

    # construct masks
    masks = []
    for (d0, d1) in zip(degrees[:-1], degrees[1:]):
        masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()]

    return masks, degrees[0]


class MaskedLinear(nn.Linear):
    """ MADE building block layer """
    def __init__(self, input_size, n_outputs, mask, cond_label_size=None):
        super().__init__(input_size, n_outputs)

        self.register_buffer('mask', mask)

        self.cond_label_size = cond_label_size
        if cond_label_size is not None:
            self.cond_weight = nn.Parameter(torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size))

    def forward(self, x, y=None):
        out = F.linear(x, self.weight * self.mask, self.bias)
        if y is not None:
            out = out + F.linear(y, self.cond_weight)
        return out

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        ) + (self.cond_label_size != None) * ', cond_features={}'.format(self.cond_label_size)


class LinearMaskedCoupling(nn.Module):
    """ Modified RealNVP Coupling Layers per the MAF paper """
    def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None):
        super().__init__()

        self.register_buffer('mask', mask)

        # scale function
        s_net = [nn.Linear(input_size + (cond_label_size if cond_label_size is not None else 0), hidden_size)]
        for _ in range(n_hidden):
            s_net += [nn.Tanh(), nn.Linear(hidden_size, hidden_size)]
        s_net += [nn.Tanh(), nn.Linear(hidden_size, input_size)]
        self.s_net = nn.Sequential(*s_net)

        # translation function
        self.t_net = copy.deepcopy(self.s_net)
        # replace Tanh with ReLU's per MAF paper
        for i in range(len(self.t_net)):
            if not isinstance(self.t_net[i], nn.Linear): self.t_net[i] = nn.ReLU()

    def forward(self, x, y=None):
        # apply mask
        mx = x * self.mask

        # run through model
        s = self.s_net(mx if y is None else torch.cat([y, mx], dim=1))
        t = self.t_net(mx if y is None else torch.cat([y, mx], dim=1))
        u = mx + (1 - self.mask) * (x - t) * torch.exp(-s)  # cf RealNVP eq 8 where u corresponds to x (here we're modeling u)

        log_abs_det_jacobian = - (1 - self.mask) * s  # log det du/dx; cf RealNVP 8 and 6; note, sum over input_size done at model log_prob

        return u, log_abs_det_jacobian

    def inverse(self, u, y=None):
        # apply mask
        mu = u * self.mask

        # run through model
        s = self.s_net(mu if y is None else torch.cat([y, mu], dim=1))
        t = self.t_net(mu if y is None else torch.cat([y, mu], dim=1))
        x = mu + (1 - self.mask) * (u * s.exp() + t)  # cf RealNVP eq 7

        log_abs_det_jacobian = (1 - self.mask) * s  # log det dx/du

        return x, log_abs_det_jacobian
        
class BatchNorm(nn.Module):
    """ RealNVP BatchNorm layer """
    def __init__(self, input_size, momentum=0.9, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps

        self.log_gamma = nn.Parameter(torch.zeros(input_size))
        self.beta = nn.Parameter(torch.zeros(input_size))

        self.register_buffer('running_mean', torch.zeros(input_size))
        self.register_buffer('running_var', torch.ones(input_size))

    def forward(self, x, cond_y=None):
        if self.training:
            self.batch_mean = x.mean(0)
            self.batch_var = x.var(0) # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False)

            # update running mean
            self.running_mean.mul_(self.momentum).add_(self.batch_mean.data * (1 - self.momentum))
            self.running_var.mul_(self.momentum).add_(self.batch_var.data * (1 - self.momentum))

            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        # compute normalized input (cf original batch norm paper algo 1)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        y = self.log_gamma.exp() * x_hat + self.beta

        # compute log_abs_det_jacobian (cf RealNVP paper)
        log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps)
#        print('in sum log var {:6.3f} ; out sum log var {:6.3f}; sum log det {:8.3f}; mean log_gamma {:5.3f}; mean beta {:5.3f}'.format(
#            (var + self.eps).log().sum().data.numpy(), y.var(0).log().sum().data.numpy(), log_abs_det_jacobian.mean(0).item(), self.log_gamma.mean(), self.beta.mean()))
        return y, log_abs_det_jacobian.expand_as(x)

    def inverse(self, y, cond_y=None):
        if self.training:
            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (y - self.beta) * torch.exp(-self.log_gamma)
        x = x_hat * torch.sqrt(var + self.eps) + mean

        log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma

        return x, log_abs_det_jacobian.expand_as(x)


class FlowSequential(nn.Sequential):
    """ Container for layers of a normalizing flow """
    def forward(self, x, y):

        sum_log_abs_det_jacobians = 0
        for module in self:
            x, log_abs_det_jacobian = module(x, y)
            sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian
            # print("sum_log_abs_det_jacobians: ", sum_log_abs_det_jacobians)
        return x, sum_log_abs_det_jacobians

    def inverse(self, u, y):
        sum_log_abs_det_jacobians = 0
        for module in reversed(self):
            u, log_abs_det_jacobian = module.inverse(u, y)
            sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian
        return u, sum_log_abs_det_jacobians
        


In [None]:
''' DEFINING THE MODEL '''

class MADE(nn.Module):
    def __init__(self, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', input_degrees=None):
        """
        Args:
            input_size -- scalar; dim of inputs
            hidden_size -- scalar; dim of hidden layers
            n_hidden -- scalar; number of hidden layers
            activation -- str; activation function to use
            input_order -- str or tensor; variable order for creating the autoregressive masks (sequential|random)
                            or the order flipped from the previous layer in a stack of mades
            conditional -- bool; whether model is conditional
        """
        super().__init__()
        # base distribution for calculation of log prob under the model
        self.register_buffer('base_dist_mean', torch.zeros(input_size))
        self.register_buffer('base_dist_var', torch.ones(input_size))

        # create masks
        masks, self.input_degrees = create_masks(input_size, hidden_size, n_hidden, input_order, input_degrees)

        # setup activation
        if activation == 'relu':
            activation_fn = nn.ReLU()
        elif activation == 'tanh':
            activation_fn = nn.Tanh()
        else:
            raise ValueError('Check activation function.')

        # construct model
        self.net_input = MaskedLinear(input_size, hidden_size, masks[0], cond_label_size)
        self.net = []
        for m in masks[1:-1]:
            self.net += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)]
        self.net += [activation_fn, MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2,1))]
        self.net = nn.Sequential(*self.net)

    @property
    def base_dist(self):
        return D.Normal(self.base_dist_mean, self.base_dist_var)

    def forward(self, x, y=None):
        # MAF eq 4 -- return mean and log std
        m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=1)
        u = (x - m) * torch.exp(-loga)
        # MAF eq 5
        log_abs_det_jacobian = - loga
        # print(u, log_abs_det_jacobian)
        # print("-----------\n")
        return u, log_abs_det_jacobian

    def inverse(self, u, y=None, sum_log_abs_det_jacobians=None):
        # MAF eq 3
        D = u.shape[1]
        x = torch.zeros_like(u)
        # run through reverse model
        for i in self.input_degrees:
            m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=1)
            x[:,i] = u[:,i] * torch.exp(loga[:,i]) + m[:,i]
        log_abs_det_jacobian = loga
        return x, log_abs_det_jacobian

    def log_prob(self, x, y=None):
        u, log_abs_det_jacobian = self.forward(x, y)
        return torch.sum(self.base_dist.log_prob(u) + log_abs_det_jacobian, dim=1)


class MAF(nn.Module):
    def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', batch_norm=True):
        super().__init__()
        # base distribution for calculation of log prob under the model
        self.register_buffer('base_dist_mean', torch.zeros(input_size))
        self.register_buffer('base_dist_var', torch.ones(input_size))

        # construct model
        modules = []
        self.input_degrees = None
        for i in range(n_blocks):
            modules += [MADE(input_size, hidden_size, n_hidden, cond_label_size, activation, input_order, self.input_degrees)]
            self.input_degrees = modules[-1].input_degrees.flip(0)
            modules += batch_norm * [BatchNorm(input_size)]

        # print(modules)
        self.net = FlowSequential(*modules)

    @property
    def base_dist(self):
        return D.Normal(self.base_dist_mean, self.base_dist_var)

    def forward(self, x, y=None):
        return self.net(x, y)

    def inverse(self, u, y=None):
        return self.net.inverse(u, y)

    def log_prob(self, x, y=None):
        # This function breaks up during training and might require more work.
        # The try-except statement has been added to continue training in that case.
        try:
            u, sum_log_abs_det_jacobians = self.forward(x, y)
            # print("[Log Prob] u: ", u.shape)
            # print("[Log Prob] Sum Log Abs Det J: ", sum_log_abs_det_jacobians.shape)
            v =  torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=1)
        except:
            print("--- [ Encountered NaN ]")
            
            v = torch.sum(torch.ones(u.shape, requires_grad=True) * 1e-2)
        # print("[Log Prob] v: ", v.shape)
        return v

In [None]:
''' DEFINING THE TRAINING FUNCTIONS ''' 

def get_batch(dataset, bs, return_indices=False):
    '''
        Get a new batch of input data from the `dataset` object of batch size `bs`. 
        
        Also returns the actual indices of the data instances in the batch 
        (corresponding to the Parquet files) if `return_indices` is set to true. 
    '''
    input_batch = []
    true_indices = []
    for batch_idx in range(bs):
        input_data, true_index = dataset.get_next_instance()
        input_data = input_data.reshape(-1)
        input_batch.append(input_data)
        true_indices.append(true_index)
    input_batch = torch.stack(input_batch)
    # print(input_batch.shape)

    if return_indices:
        return input_batch, true_indices
    else:
        return input_batch

def train(model, dataset, optimizer, args):
    '''
        Trains the input `model` on the given `dataset`, using the provided `optimizer`.

        The `args` parameter should contain the parsed training arguments from 
        the `get_args()` function.

        Note: It is recommended to instead use the `train_evaluate()` function 
        as it contains additional testing and logging functionality.  
    '''
    init_steps = args.step
    for i in range(args.n_steps):
        model.train()
        x = get_batch(dataset, args.batch_size).to(args.device)
        # print(x.shape)
        loss = - model.log_prob(x, None).mean(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        wandb.log({'Step': args.step, 
            'Loss': loss.item(), 
            'Learning_Rate': args.lr})
        
        if i % args.log_interval == 0:
            print('Iter {:4d} / {:4d} | loss {:.4f}'.format(init_steps + i, init_steps + args.n_steps, loss.item())) 


def train_batch(model, x_batch, optimizer):
    '''
        Train a single batch. Designed for use from within the 
        `train_evaluate()` function.
    '''

    model.train()
    # print(x_batch.shape)
    loss = - model.log_prob(x_batch, None).mean(0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

@torch.no_grad()
def evaluate_batch(model, x_batch, iter, args):
    '''
        Evaluate a single batch. Designed for use from within the 
        `train_evaluate()` function.
    '''
    model.eval()
    # print(x_batch.shape)
    x_batch = x_batch.to(args.device)  ## To edit
    logprobs = model.log_prob(x_batch)
    # print(logprobs)
    logprobs = logprobs.to(args.device)
    logprob_mean, logprob_std = logprobs.mean(0), 2 * logprobs.var(0).sqrt() / math.sqrt(args.batch_size)
    return logprob_mean, logprob_std

# @torch.no_grad()
# def generate_samples(model, args, iter='?', toSave=False):
#     '''
#         Generate a new samples from a trained model. 
#     '''

#     model.eval()
#     u = model.base_dist.sample((1, args.n_components)).squeeze()
#     samples, _ = model.inverse(u)
#     log_probs = model.log_prob(samples).sort(0)[1].flip(0)  # sort by log_prob; take argsort idxs; flip high to low
#     samples = samples[log_probs]

#     if toSave:
#         print(samples.shape)
#         # samples = samples.view(samples.shape[0], *args.input_dims)
#         # # samples = (torch.sigmoid(samples) - dataset_lam) / (1 - 2 * dataset_lam)
#         # filename = 'generated_samples' + '_epoch_{}'.format(iter) + '.png'
#         # save_image(samples, os.path.join(args.output_dir, filename), nrow=1, normalize=True)

#     return samples

def save_model(model, optimizer, iter, args, save_name='model_state.pt'):
    '''
        Cache the training state consisting of the `model`, `optimizer` and current 
        training interation - `iter`. 

        Saved model can be loaded using the `load_model()` function. 
    '''
    save_dict = {'iteration': iter,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict()}
    save_path = os.path.join(args.output_dir, save_name)
    torch.save(save_dict, save_path)

def load_model(model, optimizer, args, restore_path):
    '''
        Load a model cached using the `save_model()` function. 
        Argument `restore_path` is expected to be a `.pt` file.
    '''
    state = torch.load(restore_path, map_location=args.device)
    model.load_state_dict(state['model_state'])
    optimizer.load_state_dict(state['optimizer_state'])
    args.step = state['iteration'] + 1
    return model, optimizer, args

def train_evaluate(model, optimizer, train_dataset, test_dataset, args):
    '''
        Root function to Train and Evaluate the MAF model. 

        Supports functionality to log training and testing on the console and 
        onto Weights and Biases (wandb) as well as checkpoint models based on 
        test set perfomance. 
    '''

    best_eval_logprob = float('-inf')
    init_steps = args.step
    
    for i in range(args.n_steps):
        training_batch = get_batch(train_dataset, args.batch_size).to(args.device)
        loss = train_batch(model, training_batch, optimizer)
        
        testing_batch = get_batch(test_dataset, args.batch_size).to(args.device)
        logprob_mean, logprob_std = evaluate_batch(model, testing_batch, iter, args)
        # print('=====================================')
        # print(logprob_mean)
        # print('=====================================')
        # print(logprob_std)
        # print('=====================================')

        wandb.log({'Step': args.step, 
                   'Loss': loss.item(), 
                   'Learning_Rate': args.lr,
                   'LogProb_Mean': logprob_mean, 
                   'LogProb_Std': logprob_std})

        if i % args.log_interval == 0:
            print_output = '[Iter: {:4d}/{:4d}] '.format(init_steps + i, init_steps + args.n_steps)
            print_output += 'Loss: {:.3f}'.format(loss.item())
            print_output += '  |  LogP(x): {:.3f} +/- {:.3f}'.format(logprob_mean, logprob_std)                       
            print(print_output)

            save_model(model, optimizer, i, args, save_name='model_state.pt')
            if logprob_mean > best_eval_logprob:
                best_eval_logprob = logprob_mean
                save_model(model, optimizer, i, args, save_name='best_model_checkpoint.pt')

### Setting up the Model for Training

In [None]:
''' INITIAL WANDB CONFIG '''

DEFAULT_CFG = {
    'model': 'MAF_Valid-Samples',
    'root_dir': ROOT,
}

In [None]:
''' LOAD TEST/TRAIN DATASETS '''

def load_dataset(index, pt_range=[50,55], data_root=DATA_ROOT, input_format=INPUT_FORMAT):
    '''
        Loads the `index` dataset file and returns a `ParquetDataset` object 
        containing the instances in that file.
    '''
    dataset_file = data_root + input_format.format(index)
    dataset = ParquetDataset(dataset_file)
    dataset.pt_range = pt_range
    return dataset

# Initialize Datasets
PT_RANGE = [50,55]
train_dataset = load_dataset(0, pt_range=PT_RANGE)
test_dataset = load_dataset(1, pt_range=PT_RANGE)

In [None]:
''' TRAINING SETUP '''

# Filter training warnings.
import warnings
warnings.filterwarnings('ignore')

# Get Args
args = get_args()

# Set custom args
args.input_dims = None
args.input_size = train_dataset.max_instances * 2    # Can be tuned
args.hidden_size = 300   # Can be tuned
args.n_blocks = 8
args.n_components = 1
args.n_hidden = 1
args.n_steps = 5000       # Iterations to train for.
args.step = 0             # Initial training iteration.
args.lr = 1e-4            # Learning Rate
args.batch_size = 32      # Batch Size
args.cuda = 0             # GPU index
args.log_interval = 1     # Logging Interval (based on iteration frequency)
args.pt_range = PT_RANGE  # Range of energy to filter jets
args.no_batch_norm = False

# Set Seeds
torch.manual_seed(args.seed)
if args.device.type == 'cuda': 
    torch.cuda.manual_seed(args.seed)

# Get the MAF Model
model = MAF(args.n_blocks, args.input_size, args.hidden_size, args.n_hidden, args.cond_label_size,
                    args.activation_fn, args.input_order, batch_norm=not args.no_batch_norm)
model = model.to(args.device)

# Save Config
config = 'Parsed args:\n{}\n\n'.format(pprint.pformat(args.__dict__)) + \
            'Num trainable params: {:,.0f}\n\n'.format(sum(p.numel() for p in model.parameters())) + \
            'Model:\n{}'.format(model)
config_path = os.path.join(args.output_dir, 'config.txt')
if not os.path.exists(config_path):
    with open(config_path, 'a') as f:
        print(config, file=f)

# Get Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

In [None]:
''' INITIALIZE A WANDB RUN '''

# Init Wandb
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_NOTES"] = WANDB_DESC
run = wandb.init(project='gnf', config=DEFAULT_CFG, dir=LOGS_ROOT)
wandb.config.update(args)

# Save files for later
! pip freeze > requirements.txt
# wandb.save(ROOT + 'requirements.txt')       # TODO: Fix permission issue on Colab.
# wandb.save(config_path)                     # TODO: Fix permission issue on Colab.

### Model Training

In [None]:
''' TRAIN + EVALUATE THE MODEL '''

train_evaluate(model, optimizer, train_dataset, test_dataset, args, toGenerate=False)

--- [ Encountered NaN ]
[Iter:    0/5000] Loss: 7335.718  |  LogP(x): 491.520 +/- nan
[Iter:    1/5000] Loss: 7206.878  |  LogP(x): -20063.887 +/- 1499.554
[Iter:    2/5000] Loss: 7201.562  |  LogP(x): -14998.738 +/- 3175.619
[Iter:    3/5000] Loss: 6909.108  |  LogP(x): -12103.184 +/- 1325.232
[Iter:    4/5000] Loss: 7181.210  |  LogP(x): -10597.162 +/- 545.477
[Iter:    5/5000] Loss: 7169.951  |  LogP(x): -9496.451 +/- 375.776
[Iter:    6/5000] Loss: 6763.291  |  LogP(x): -9188.175 +/- 477.719
[Iter:    7/5000] Loss: 6955.591  |  LogP(x): -8744.181 +/- 386.556
[Iter:    8/5000] Loss: 7095.173  |  LogP(x): -8299.588 +/- 232.408
[Iter:    9/5000] Loss: 6568.482  |  LogP(x): -8039.018 +/- 170.622
[Iter:   10/5000] Loss: 7084.449  |  LogP(x): -7799.437 +/- 190.473
[Iter:   11/5000] Loss: 6847.776  |  LogP(x): -7980.133 +/- 232.192
[Iter:   12/5000] Loss: 6975.272  |  LogP(x): -7869.228 +/- 166.644
[Iter:   13/5000] Loss: 6709.409  |  LogP(x): -7738.040 +/- 215.357
[Iter:   14/5000] Loss:

### Saving and Loading

In [None]:
''' MANUALLY SAVE THE MODEL '''

# save_model(model, optimizer, i=-1, args=args, save_name='model_state.pt')

In [None]:
''' LOAD A MODEL '''

# RESTORE_PATH = '/content/drive/MyDrive/_GSoC/Normalizing-Flows/results/run_/2021-08-20_08-22-09/best_model_checkpoint.pt'
# RESTORE_PATH = '/content/drive/MyDrive/_GSoC/Normalizing-Flows/results/run_/2021-08-20_08-22-09/model_state.pt'
RESTORE_PATH = '/content/drive/MyDrive/_GSoC/Normalizing-Flows/results/run_/2021-08-20_11-19-50/best_model_checkpoint.pt'

# model, optimizer, args = load_model(model, optimizer, args, RESTORE_PATH)

### Generating New Instances

In [None]:
from sklearn.preprocessing import MinMaxScaler

def normalize(instance):
    '''
        Normalize the input instance between [0,125]
    '''
    scaler = MinMaxScaler()
    instance = scaler.fit_transform(instance) * 125
    return torch.Tensor(instance)

@torch.no_grad()
def generate_sample(model, args, 
                    sample_size=4, 
                    save_size=1, 
                    toNormalize=True,
                    iter=-1, toSave=False):
    '''
        Generate new samples from a trained model. `sample_size` defines the 
        number of samples to generate out of which `save_size` number of samples
        are saved into the current runs logging folder (Samples with the highest
        Log Probability are saved). 
    '''

    args.batch_size = 1     # Needed to avoid CUDA out of memory.
    model.eval()
    with torch.no_grad():
        u = model.base_dist.sample((sample_size, args.n_components)).reshape(sample_size, -1)
        samples, _ = model.inverse(u)
        log_probs_raw = model.log_prob(samples)
        order = log_probs_raw.sort(0)[1].flip(0)  # sort by log_prob; take argsort idxs; flip high to low
        print('Sample Probability: ', log_probs_raw[order])
        samples = samples.detach().cpu()
        samples = samples[order][:save_size]
        samples = ([s.reshape(args.input_size//2, 2) for s in samples])

        # Normalize the samples? 
        if toNormalize:
            samples = ([normalize(s) for s in samples])

        # Save the samples to Disk? 
        if toSave:
            for i in range(len(samples)):
                file_name = 'generated_samples_iter={}_sample={}_.png'.format(iter, i)
                file_path = os.path.join(args.output_dir, file_name)
                title = 'Pred_Idx={}'.format(i)
                vis(samples[i], title=title, onlySave=True, savePath=file_path)

    return samples

In [None]:
preds = generate_sample(model, args, sample_size=1024, save_size=4, toNormalize=False, toSave=True)

# for i in range(len(preds)):
#     vis(preds[i], title='Pred_Idx={}'.format(i))

Sample Probablities:  tensor([-5463.1377, -5463.2837, -5468.6416,  ..., -5644.5547, -5644.9629,
        -5675.0122], device='cuda:0')
Saving to: ./results/run_/2021-08-20_11-19-50/generated_samples_iter=-1_sample=0_.png
Saving to: ./results/run_/2021-08-20_11-19-50/generated_samples_iter=-1_sample=1_.png
Saving to: ./results/run_/2021-08-20_11-19-50/generated_samples_iter=-1_sample=2_.png
Saving to: ./results/run_/2021-08-20_11-19-50/generated_samples_iter=-1_sample=3_.png
