In [1]:
# Free-text run description; exported as WANDB_NOTES right before wandb.init().
WANDB_DESC = "Setting up BNAF."

In [2]:
%%capture
# Install the CUDA 10.2 build of torch 1.9.0 (verified by the next cell) and W&B for experiment logging.
# NOTE(review): wandb is not version-pinned here.
! pip3 install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html
! pip install wandb

In [3]:
import torch
# Fail fast if the kernel is not running the torch build pinned in the install cell.
assert torch.__version__ == '1.9.0+cu102'
! pip list | grep torch

In [4]:
# Interactive W&B authentication; must succeed before wandb.init() further down.
! wandb login

In [5]:
### COLAB ###

# Mount Drive
# Google Drive is mounted under /content/drive; ROOT (set further down) lives there.
from google.colab import drive
drive.mount('/content/drive')

# Install missing packages
# `tree` is used by the directory-listing command in the path-setup cell.
! apt-get install tree >/dev/null

# Download dataset
# ! ./get_dataset.sh

In [6]:
### COLAB ###
# NOTE(review): this cell repeats the previous Colab setup cell verbatim; one copy is enough.

# Mount Drive
# Google Drive is mounted under /content/drive; ROOT (set further down) lives there.
from google.colab import drive
drive.mount('/content/drive')

# Install missing packages
# `tree` is used by the directory-listing command in the path-setup cell.
! apt-get install tree >/dev/null

# Download dataset
# ! ./get_dataset.sh

In [7]:
# Filename template for the input parquet files in DATA_ROOT; '{}' is filled with the file index.
INPUT_FORMAT = "Boosted_Jets_Sample-{}.snappy.parquet"

In [8]:
%matplotlib inline

import os
import cv2
import wandb
import numpy as np
import pandas as pd
import pickle as pkl
import matplotlib.pyplot as plt
import pyarrow.parquet as pq
from tqdm.auto import tqdm, trange
import torch.nn.functional as F

# Numpy display: 3 decimals, no scientific notation for small values.
np.set_printoptions(precision=3, suppress=True)

In [9]:
''' CREATE AND SET DATA/CACHE DIRECTORIES '''

# Every project path is derived from ROOT on the mounted Drive.
ROOT = '/content/drive/My Drive/_GSoC/Normalizing-Flows/'
DATA_ROOT = ROOT + "data/"      # input parquet files (named per INPUT_FORMAT)
CACHE_ROOT = ROOT + "cache/"
LOGS_ROOT = ROOT + "logs/"      # W&B run directory (passed as wandb.init(dir=...))

# Make ROOT the working directory, so relative paths used later
# ('./results/run_' in get_args, 'requirements.txt' in the W&B cell) end up inside it.
os.chdir(ROOT)
os.makedirs(DATA_ROOT, exist_ok=True)
os.makedirs(CACHE_ROOT, exist_ok=True)
os.makedirs(LOGS_ROOT, exist_ok=True)

# Initialize scratch space on /content for faster read-write
SCRATCH_ROOT = '/content/scratch/'   
os.makedirs(SCRATCH_ROOT, exist_ok=True) 

# print('Directory Structure [Excluding Log/Temp Files]:')
# List the project tree, skipping entries that match any of the '|'-separated patterns.
! tree -I 'model*|temp__*|wandb*|run_'

In [10]:
import torch
from torch.utils.data import *
# Channel order along the first axis of each 'X_jets' image (looked up by name in ParquetDataset).
CHANNELS = ['Pt', 'ECal', 'HCal']

def parse_img(track_img, reduce=False):
    '''
    Convert a single-channel jet image into the list of its non-zero hits.

    Args:
        track_img: 2-D image, anything accepted by `torch.Tensor` (list, ndarray, tensor).
        reduce:    currently unused; kept so existing call sites keep working.

    Returns:
        Float tensor of shape (num_hits, 3). Each row is (row, col, value) and the rows
        are in row-major order.
    '''
    image = torch.Tensor(track_img)
    coords = torch.nonzero(image, as_tuple=False)     # (num_hits, 2) integer indices
    rows, cols = coords.unbind(dim=1)
    values = image[rows, cols]
    # torch.cat promotes the integer coordinates to the image's float dtype.
    return torch.cat((coords, values.unsqueeze(1)), dim=1)

class ParquetDataset(Dataset):
    '''
    Sequential reader that yields zero-padded hit lists from a jets parquet file.

    Samples are addressed by parquet row group (`total_len` is the number of row
    groups), and `get_raw_instance` uses the first row of each group. A jet is
    returned only if its ECal hit count lies in [min_instances, max_instances] and
    its summed ECal energy lies inside `pt_range`. Returned jets are zero-padded to
    `max_instances` rows. Random access (`__getitem__` / `__len__`) is intentionally
    unsupported; use `get_next_instance`.
    '''
    def __init__(self, filename, channels=[1], max_instances=768, min_instances=384):
        # NOTE(review): `channels` is unused; channel selection is driven by `return_channels`.
        self.parquet = pq.ParquetFile(filename)
        self.cur_idx = 0                    # next row group to inspect
        self.total_len = self.parquet.num_row_groups
        self.cols = None                    # columns passed to read_row_group (None = all)
        self.verbose = False                # False by default
        self.max_instances = max_instances  # Number of max hits to force in each jet.
        self.min_instances = min_instances  # Number of min hits to force in each jet.
        self.allowed_range = range(min_instances, max_instances+1)

        self.pt_range = [25,140]            # inclusive [low, high] cut on summed ECal energy
        self.return_channels = ['ECal']
        self.supress_val = True             # drop hit values, keep only (x, y) coordinates
        self.cached_pts = None              # per-channel hit values of the last inspected jet

    def __getitem__(self, index):
        raise NotImplementedError('Not needed. Using `get_next_instance` instead.')

    def __len__(self):
        raise NotImplementedError('Not needed. Using `get_next_instance` instead.')

    def get_next_instance(self):
        '''
        Return the next valid sample together with its true (row-group) index.

        Scans forward from `self.cur_idx`, wrapping around to 0 at the end of the
        file, and advances `self.cur_idx` past the returned sample.

        Returns:
            (padded_jet, idx): `padded_jet` has `max_instances` rows of (x, y) hit
            coordinates ((x, y, value) when `supress_val` is False), zero-padded;
            `idx` is the row group it came from.

        Raises:
            RuntimeError: if a complete pass over the file finds no valid sample.
        '''
        # Position of the ECal channel among the returned channels (loop-invariant).
        ecal_idx = self.return_channels.index('ECal')

        while True:
            start_idx = self.cur_idx
            for idx in range(start_idx, self.total_len):
                raw_jets = self.get_raw_instance(idx)
                parsed_jets = [parse_img(j) for j in raw_jets]
                gen_pts = [j[:,2] for j in parsed_jets]
                self.cached_pts = gen_pts

                if self.supress_val:
                    parsed_jets = [j[:,:2] for j in parsed_jets]

                ### Temporary hack for ECal ###
                parsed_jet = parsed_jets[ecal_idx]
                gen_pt = gen_pts[ecal_idx].sum()
                if self.verbose:
                    print('-- Output Shape: {}'.format(parsed_jet.shape))
                    print('-- Gen-Pt Val: {}'.format(gen_pt))

                if (parsed_jet.shape[0] in self.allowed_range) and \
                   (gen_pt >= self.pt_range[0] and gen_pt <= self.pt_range[1]):
                    if self.verbose:
                        print('-- Returning instance at idx={}'.format(idx))

                    padded_jet = self.pad_instance(parsed_jet)
                    self.cur_idx = idx + 1
                    return padded_jet, idx     # Exit after finding a valid instance
                else:
                    if self.verbose:
                        print('-- Skipped instance at idx={}, shape={}, pt={}'.format(idx, parsed_jet.shape, gen_pt))

            if start_idx == 0:
                # The whole file was scanned without a match; wrapping again would never terminate.
                raise RuntimeError(
                    'No row group out of {} passes the hit-count and gen-pt cuts.'.format(self.total_len))
            # End of dataset, loop back.
            self.cur_idx = 0

    def get_gen_pt(self, raw_jets):
        '''
        Approximate the gen-pt of a jet by its summed ECal energy over all pixels.

        Zero pixels contribute nothing, so this matches the `gen_pt` computed from the
        non-zero hits in `get_next_instance` (up to floating-point summation order).

        Args:
            raw_jets: array returned by `get_raw_instance`, indexed first by position
                      in `self.return_channels`.

        Returns:
            0-dim tensor with the summed ECal energy.
        '''
        ecal_idx = self.return_channels.index('ECal')
        return torch.as_tensor(raw_jets[ecal_idx]).sum()

    def get_raw_instance(self, index):
        '''
        Read row group `index` and return the requested channels of its first jet image.

        Pixels below 1e-3 are set to 0 (zero-suppression). The result is a float32 array
        whose first axis follows the order of `self.return_channels`.
        '''
        c_idx = []
        for c in self.return_channels:
            assert c in CHANNELS
            c_idx.append(CHANNELS.index(c))

        data = self.parquet.read_row_group(index, columns=self.cols).to_pydict()
        jets = np.float32(data['X_jets'][0])
        jets[jets < 1.e-3] = 0.     # Zero-Suppression
        return jets[c_idx]          # Temporary Hack for Ecal

    def pad_instance(self, instance):
        '''Zero-pad `instance` along its first axis up to `self.max_instances` rows.'''
        assert instance.shape[0] <= self.max_instances
        assert instance.shape[0] >= self.min_instances
        pad_len = self.max_instances - instance.shape[0]
        # pad=(0, 0, 0, pad_len): leave the last axis alone, append pad_len rows of zeros.
        instance = F.pad(instance, pad=(0, 0, 0, pad_len), mode='constant', value=0)
        return instance

In [11]:
def vis(arr, is_parsed=True, title=None, scale=1000, cmap='gist_heat', reduce=False):  
    '''
    Scatter-plot a jet instance on the 0-126 detector grid.

    Args:
        arr:       parsed hits of shape (N, 3) with rows (x, y, value), or (N, 2) with
                   rows (x, y); a raw 2-D image when `is_parsed` is False.
        is_parsed: if False, `arr` is first converted with `parse_img`.
        title:     optional figure title.
        scale:     marker-size multiplier applied to |value|; ignored for (N, 2) input.
        cmap:      matplotlib colormap name.
        reduce:    forwarded to `parse_img`.

    Raises:
        ValueError: if `arr` does not have 2 or 3 columns.
    '''
    if not is_parsed:
        arr = parse_img(arr, reduce)

    n_cols = arr.shape[1]
    if n_cols == 3:
        x_pos, y_pos, val = arr[:,0], arr[:,1], arr[:,2]
    elif n_cols == 2:
        # Coordinates only: uniform colour and a fixed marker size.
        x_pos, y_pos = arr[:,0], arr[:,1]
        val = torch.ones_like(x_pos)
        scale = None
    else:
        raise ValueError('vis expects hits with 2 or 3 columns, got shape {}'.format(tuple(arr.shape)))

    if scale:
        sz = np.array(np.abs(val)) * scale     # marker area proportional to |value|
    else:
        sz = np.ones_like(val) * 10

    plt.figure(figsize=[10,6], facecolor='#f0f0f0')
    # plt.get_cmap replaces the deprecated plt.cm.get_cmap. e.g. 'gist_heat' / 'YlOrRd'
    cm = plt.get_cmap(cmap)
    sc = plt.scatter(x_pos, y_pos, c=val, s=sz, cmap=cm, alpha=0.5, edgecolors='k')
    plt.colorbar(sc)
    plt.xlim(0, 126)
    plt.ylim(0, 126)
    plt.xticks(range(0,125,25))
    plt.yticks(range(0,125,25))
    plt.grid()
    if title:
        plt.title(title)
    plt.show()

In [12]:
# Open the first input file with verbose logging to inspect which jets pass the cuts.
dataset_file = DATA_ROOT + INPUT_FORMAT.format(0)
dataset = ParquetDataset(dataset_file)
dataset.verbose = True
# total_len is the number of parquet row groups, not the number of valid jets.
print('Max Length of Dataset: ', dataset.total_len)

In [13]:
# Fetch the first jet that passes the cuts, plus its row-group index.
data_sample, true_idx = dataset.get_next_instance()

In [14]:
# (max_instances, 2): zero-padded (x, y) hit coordinates.
data_sample.shape

torch.Size([768, 2])

In [15]:
# print(dataset.allowed_range)
# print()

# for i in range(10):
#     print('[{}]'.format(i))
#     data_sample, true_idx = dataset.get_next_instance()
#     print(data_sample.shape, true_idx)
#     print(dataset.cur_idx)
#     print('\n')

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from torch.utils.data import DataLoader, TensorDataset

import math
import os
import time
import argparse
import pprint
from functools import partial

import matplotlib
# NOTE: `matplotlib.use('Agg')` was removed. Switching to the non-interactive Agg
# backend after `%matplotlib inline` stops `plt.show()` (used by `vis`) from
# rendering figures in the notebook.
import matplotlib.pyplot as plt
# tqdm.auto keeps the notebook widget progress bar (same source as the first import cell).
from tqdm.auto import tqdm

In [17]:
def get_args(argv=None):
    '''
    Build the run configuration from argparse defaults.

    Args:
        argv: optional list of CLI-style arguments. The default (None) parses an empty
              list, so the notebook kernel's own command line is never read.

    Returns:
        argparse.Namespace holding every option below, plus:
          - output_dir: `--output_dir` joined with a UTC timestamp (the directory is created).
          - device:     torch.device resolved from `--cuda` at parse time. Assigning
                        `args.cuda` afterwards does NOT update `args.device`.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true', help='Train a flow.')
    # NOTE(review): store_false => args.plot defaults to True and passing --plot disables it.
    parser.add_argument('--plot', action='store_false', help='Plot a flow and target density.')
    parser.add_argument('--restore_file', type=str, help='Path to model to restore.')
    parser.add_argument('--output_dir', default='./results/run_')
    parser.add_argument('--cuda', type=int, help='Which GPU to run on.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    # model parameters
    parser.add_argument('--data_dim', type=int, default=2, help='Dimension of the data.')
    parser.add_argument('--hidden_dim', type=int, default=100, help='Dimensions of hidden layers.')
    parser.add_argument('--n_hidden', type=int, default=3, help='Number of hidden layers.')
    # training parameters
    parser.add_argument('--step', type=int, default=0, help='Current step of training (number of minibatches processed).')
    parser.add_argument('--n_steps', type=int, default=1, help='Number of steps to train.')
    parser.add_argument('--batch_size', type=int, default=200, help='Training batch size.')
    parser.add_argument('--lr', type=float, default=1e-1, help='Initial learning rate.')
    parser.add_argument('--lr_decay', type=float, default=0.5, help='Learning rate decay.')
    parser.add_argument('--lr_patience', type=float, default=2000, help='Number of steps before decaying learning rate.')
    parser.add_argument('--log_interval', type=int, default=50, help='How often to save model and samples.')

    args = parser.parse_args([] if argv is None else argv)
    # Timestamped sub-directory (one-second resolution) so separate runs get separate folders.
    args.output_dir = os.path.join(args.output_dir, time.strftime('%Y-%m-%d_%H-%M-%S', time.gmtime()))
    os.makedirs(args.output_dir, exist_ok=True)
    args.device = torch.device('cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu')

    return args

In [18]:
''' Model components ''' 

class MaskedLinear(nn.Module):
    '''
    Block lower-triangular, weight-normalised linear layer used by BNAF.

    Args:
        in_features:  input width, split into `data_dim` equal blocks.
        out_features: output width, split into `data_dim` equal blocks.
        data_dim:     number of autoregressive blocks (dimensionality of the data).

    forward(x, sum_logdets) returns the layer output and the running log-Jacobian of
    the diagonal blocks, with shape (B, data_dim, out_features // data_dim, 1).
    '''
    def __init__(self, in_features, out_features, data_dim):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.data_dim = data_dim

        # Notation:
        # BNAF weight calculation for (eq 8): W = g(W) * M_d + W * M_o
        #   where W is block lower triangular so model is autoregressive,
        #         g = exp function; M_d is block diagonal mask; M_o is block off-diagonal mask.
        # Weight Normalization (Salimans & Kingma, eq 2): w = g * v / ||v||
        #   where g is scalar, v is k-dim vector, ||v|| is Euclidean norm
        # ------
        # Here: pre-weight norm matrix is v; then: v = exp(weight) * mask_d + weight * mask_o
        #       weight-norm scalar is g: out_features dimensional vector (here logg is used instead to avoid taking logs in the logdet calc.
        #       then weight-normed weight matrix is w = g * v / ||v||
        #
        #       log det jacobian of block lower triangular is taking block diagonal mask of
        #           log(g*v/||v||) = log(g) + log(v) - log(||v||)
        #                          = log(g) + weight - log(||v||) since v = exp(weight) * mask_d + weight * mask_o

        weight = torch.zeros(out_features, in_features)
        mask_d = torch.zeros_like(weight)
        mask_o = torch.zeros_like(weight)
        for i in range(data_dim):
            # select block slices
            h     = slice(i * out_features // data_dim, (i+1) * out_features // data_dim)
            w     = slice(i * in_features // data_dim,  (i+1) * in_features // data_dim)
            w_row = slice(0,                            (i+1) * in_features // data_dim)
            # initialize block-lower-triangular weight and construct block diagonal mask_d and lower triangular mask_o
            nn.init.kaiming_uniform_(weight[h,w_row], a=math.sqrt(5))  # default nn.Linear weight init only block-wise
            mask_d[h,w] = 1
            mask_o[h,w_row] = 1

        mask_o = mask_o - mask_d  # remove diagonal so mask_o is lower triangular 1-off the diagonal

        self.weight = nn.Parameter(weight)                          # pre-mask, pre-weight-norm
        self.logg = nn.Parameter(torch.rand(out_features, 1).log()) # weight-norm parameter
        self.bias = nn.Parameter(nn.init.uniform_(torch.rand(out_features), -1/math.sqrt(in_features), 1/math.sqrt(in_features)))  # default nn.Linear bias init
        self.register_buffer('mask_d', mask_d)
        self.register_buffer('mask_o', mask_o)

    def forward(self, x, sum_logdets):
        # 1. compute BNAF masked weight eq 8
        v = self.weight.exp() * self.mask_d + self.weight * self.mask_o
        # 2. weight normalization
        v_norm = v.norm(p=2, dim=1, keepdim=True)
        w = self.logg.exp() * v / v_norm
        # 3. compute output and logdet of the layer
        # Cast inputs to float32 (the parameters' default dtype) so e.g. float64 data works.
        x = x.type(torch.float32)
        out = F.linear(x, w, self.bias)

        logdet = self.logg + self.weight - 0.5 * v_norm.pow(2).log()
        # Boolean mask (indexing with uint8 masks is deprecated in PyTorch). Selection is
        # row-major, so the diagonal blocks come out ordered as (data_dim, out/d, in/d).
        logdet = logdet[self.mask_d.bool()]
        logdet = logdet.view(1, self.data_dim, out.shape[1]//self.data_dim, x.shape[1]//self.data_dim) \
                       .expand(x.shape[0],-1,-1,-1)  # output (B, data_dim, out_dim // data_dim, in_dim // data_dim)

        # 4. sum with sum_logdets from layers before (BNAF section 3.3)
        # Compute log det jacobian of the flow (eq 9, 10, 11) using log-matrix multiplication of the different layers.
        # Specifically for two successive MaskedLinear layers A -> B with logdets A and B of shapes
        #  logdet A is (B, data_dim, outA_dim, inA_dim)
        #  logdet B is (B, data_dim, outB_dim, inB_dim) where outA_dim = inB_dim
        #
        #  Note -- in the first layer, inA_dim = in_features//data_dim = 1 since in_features == data_dim.
        #            thus logdet A is (B, data_dim, outA_dim, 1)
        #
        #  Then:
        #  logsumexp(A.transpose(2,3) + B) = logsumexp( (B, data_dim, 1, outA_dim) + (B, data_dim, outB_dim, inB_dim) , dim=-1)
        #                                  = logsumexp( (B, data_dim, 1, outA_dim) + (B, data_dim, outB_dim, outA_dim), dim=-1)
        #                                  = logsumexp( (B, data_dim, outB_dim, outA_dim), dim=-1) where dim2 of tensor1 is broadcasted
        #                                  = (B, data_dim, outB_dim, 1)

        sum_logdets = torch.logsumexp(sum_logdets.transpose(2,3) + logdet, dim=-1, keepdim=True)

        return out, sum_logdets


    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

class Tanh(nn.Module):
    '''Elementwise tanh that also accumulates log|d tanh(x)/dx| into `sum_logdets`.'''

    def forward(self, x, sum_logdets):
        # d/dx tanh(x) = 1 / cosh(x)^2, and
        # log cosh(x) = x - log 2 + log(1 + exp(-2x)) = x - log 2 + softplus(-2x),
        # so log|d tanh/dx| = -2 * log cosh(x) without evaluating cosh directly.
        log_cosh = x - math.log(2) + F.softplus(-2*x)
        layer_logdet = (-2 * log_cosh).view_as(sum_logdets)
        return x.tanh(), sum_logdets + layer_logdet

class FlowSequential(nn.Sequential):
    """
    Container for layers of a normalizing flow.

    Each child is called as `module(x, sum_logdets)` and must return the transformed
    `x` together with the updated running log-determinant.
    """
    def forward(self, x):
        '''
        Returns:
            (z, logdet): `z` is the output of the last layer. `logdet` is the running
            log-determinant with only its two trailing axes dropped. For BNAF, whose last
            layer maps back to one unit per data dimension, this is (batch, data_dim).
        '''
        # (1, data_dim, 1, 1): broadcasts against the first layer's per-sample logdets.
        sum_logdets = torch.zeros(1, x.shape[1], 1, 1, device=x.device)
        for module in self:
            x, sum_logdets = module(x, sum_logdets)
        # Drop only the trailing singleton axes. A bare .squeeze() also removed the batch
        # (or data) axis whenever it had size 1, so `log_prob(z) + logdet` in the loss
        # broadcast to the wrong shape.
        return x, sum_logdets.squeeze(-1).squeeze(-1)

In [19]:
''' Model '''

class BNAF(nn.Module):
    '''
    Block Neural Autoregressive Flow.

    The flow is a FlowSequential of MaskedLinear layers with Tanh in between:
    data_dim -> hidden_dim, then `n_hidden` layers hidden_dim -> hidden_dim,
    then hidden_dim -> data_dim. `forward(x)` returns (z, logdet) from that stack.
    '''
    def __init__(self, data_dim, n_hidden, hidden_dim):
        super().__init__()

        # Standard-normal base distribution, stored as buffers so it follows .to(device).
        # NOTE: D.Normal takes a *scale* (std); 'var' is fine only because it is all ones.
        self.register_buffer('base_dist_mean', torch.zeros(data_dim))
        self.register_buffer('base_dist_var', torch.ones(data_dim))

        # Layer order matters: it fixes parameter-init RNG order and state_dict keys.
        layers = [MaskedLinear(data_dim, hidden_dim, data_dim), Tanh()]
        for _ in range(n_hidden):
            layers.extend([MaskedLinear(hidden_dim, hidden_dim, data_dim), Tanh()])
        layers.append(MaskedLinear(hidden_dim, data_dim, data_dim))
        self.net = FlowSequential(*layers)

        # TODO --   add permutation
        #           add residual gate
        #           add stack of flows

    @property
    def base_dist(self):
        return D.Normal(self.base_dist_mean, self.base_dist_var)

    def forward(self, x):
        return self.net(x)

def compute_kl_pq_loss(model, input_data):
    '''
    Per-sample negative log-likelihood of `input_data` under the flow.

    Computes -sum_d [log p_base(z_d) + logdet_d], where (z, logdet) = model(x) and
    p_base is `model.base_dist`. The input is first moved to the device of the base
    distribution.

    Returns:
        Tensor of shape (batch,).
    '''
    target_device = model.base_dist.loc.device
    z, logdet = model(input_data.to(target_device))
    log_likelihood = model.base_dist.log_prob(z) + logdet
    return -log_likelihood.sum(dim=1)

In [20]:
''' Training ''' 

def train_flow(model, dataset, loss_fn, optimizer, scheduler, args):
    '''
    Train `model` for `args.n_steps` steps, one jet per step.

    Each step draws the next valid jet from `dataset.get_next_instance()`, flattens it
    into a batch of one, and averages `loss_fn(model, batch)`. It then back-propagates
    and steps the optimizer and, on that loss, the ReduceLROnPlateau-style scheduler.
    Progress goes to stdout and W&B. Every `args.log_interval` steps the model and the
    optimizer/scheduler states are checkpointed into `args.output_dir`.

    Side effect: increments `args.step` (global step counter, used when resuming).
    '''
    model.train()

    init_steps = args.step
    for _ in range(args.n_steps):
        args.step += 1

        input_data, true_index = dataset.get_next_instance()
        input_data = input_data.reshape(1,-1)      # batch of one flattened jet
        loss = loss_fn(model, input_data).mean(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(loss)

        loss_val = loss.item()
        # args.lr is only the initial LR; read the live value so scheduler decays show up.
        current_lr = optimizer.param_groups[0]['lr']
        print("Step: {:03d}/{:03d} | Loss: {:14.2f}".format(args.step, init_steps+args.n_steps, loss_val))
        wandb.log({'Step': args.step, 
                   'Loss': loss_val, 
                   'Learning_Rate': current_lr,
                   'True_Index': true_index})

        if args.step % args.log_interval == 0:
            # save model
            torch.save({'step': args.step,
                        'state_dict': model.state_dict()},
                        os.path.join(args.output_dir, 'checkpoint.pt'))
            torch.save({'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()},
                        os.path.join(args.output_dir, 'optim_checkpoint.pt'))

In [21]:
# Base W&B run config; the parsed `args` are merged into it via wandb.config.update(args).
DEFAULT_CFG = dict(
    model='BNAF_Valid-Samples',
    root_dir=ROOT,
)

In [22]:
import warnings
warnings.filterwarnings('ignore')   # NOTE(review): hides every warning, including deprecations.

# Get Args
args = get_args()

# Initialize Dataset
# Built before the dimensions below are derived from it, so this cell does not depend
# on a `dataset` object left in the kernel by an earlier cell.
dataset_file = DATA_ROOT + INPUT_FORMAT.format(0)
dataset = ParquetDataset(dataset_file)

# Set custom args
args.data_dim = dataset.max_instances * 2     # 768 * 2
args.hidden_dim = dataset.max_instances * 4   # 768 * 4
args.n_steps = 1000
args.log_interval = 200
args.cuda = 0
args.lr = 0.1
args.step = 0
args.pt_range = [50,55]
dataset.pt_range = args.pt_range

# get_args() resolved args.device while args.cuda was still None (always CPU);
# re-resolve it now that a GPU index has been chosen.
args.device = torch.device('cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu')

# Set Seeds
torch.manual_seed(args.seed)
if args.device.type == 'cuda': 
    torch.cuda.manual_seed(args.seed)

# Get Model
model = BNAF(args.data_dim, args.n_hidden, args.hidden_dim).to(args.device)
if args.restore_file:
    model_checkpoint = torch.load(args.restore_file, map_location=args.device)
    model.load_state_dict(model_checkpoint['state_dict'])
    args.step = model_checkpoint['step']

# Save Config
config = 'Parsed args:\n{}\n\n'.format(pprint.pformat(args.__dict__)) + \
            'Num trainable params: {:,.0f}\n\n'.format(sum(p.numel() for p in model.parameters())) + \
            'Model:\n{}'.format(model)

config_path = os.path.join(args.output_dir, 'config.txt')
if not os.path.exists(config_path):
    with open(config_path, 'a') as f:
        print(config, file=f)

# Get Optimizer + Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args.lr_decay, patience=args.lr_patience, verbose=True)
if args.restore_file:
    optim_checkpoint = torch.load(os.path.dirname(args.restore_file) + '/optim_checkpoint.pt', map_location=args.device)
    optimizer.load_state_dict(optim_checkpoint['optimizer'])
    scheduler.load_state_dict(optim_checkpoint['scheduler'])

# Define Loss
loss_fn = compute_kl_pq_loss

In [23]:
# Init Wandb
# Silence W&B console output and attach WANDB_DESC as the run's notes.
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_NOTES"] = WANDB_DESC
# Each execution starts a new run in project 'gnf'; run files go under LOGS_ROOT/wandb.
run = wandb.init(project='gnf', config=DEFAULT_CFG, dir=LOGS_ROOT)
# Merge every training argument into the run config.
wandb.config.update(args)

# Save files for later
# The relative path resolves to ROOT because of the earlier os.chdir(ROOT).
! pip freeze > requirements.txt
wandb.save(ROOT + 'requirements.txt')
# wandb.save(config_path)       # TODO: Fix permission issue on Colab.

['/content/drive/My Drive/_GSoC/Normalizing-Flows/logs/wandb/run-20210714_095112-vgu4i3pi/files/requirements.txt',
 '/content/drive/My Drive/_GSoC/Normalizing-Flows/logs/wandb/run-20210714_095112-vgu4i3pi/files/requirements.txt']

In [24]:
# `torch.array` does not exist (AttributeError); `torch.tensor` builds a tensor from the list.
k = torch.tensor([10,5])
k.shape

In [25]:
k = torch.tensor([10., 5.])   # float32, like torch.Tensor([10, 5])
k.shape

torch.Size([2])

In [26]:
# Fix the `rehsape` typo (AttributeError). 1*2*...*9 = 362880 elements become one row.
k = torch.ones([1,2,3,4,5,6,7,8,9]).reshape(1,-1)
k.shape

In [27]:
k = torch.ones(1, 2, 3, 4, 5, 6, 7, 8, 9).flatten().unsqueeze(0)   # (1, 9!) = (1, 362880)
k.shape

torch.Size([1, 362880])

In [28]:
k = torch.arange(1, 10, dtype=torch.float32).unsqueeze(0)   # row vector 1..9, shape (1, 9)
k.shape

torch.Size([1, 9])

In [29]:
# Stacking adds a new leading axis: three (1, 9) tensors -> (3, 1, 9).
torch.stack([k,k,k]).shape

torch.Size([3, 1, 9])

In [30]:
k = torch.arange(1., 10.)   # float32 vector 1..9, shape (9,)
k.shape

torch.Size([9])

In [31]:
# Three (9,) tensors stack to (3, 9).
torch.stack([k,k,k]).shape

torch.Size([3, 9])

In [32]:
 # Draw the next valid jet and its row-group index.
 input_data, true_index = dataset.get_next_instance()

In [33]:
# Padded jet: (max_instances, 2).
input_data.shape

torch.Size([768, 2])

In [34]:
# Stacking unflattened jets gives (batch, max_instances, 2).
torch.stack([input_data,input_data,input_data]).shape

torch.Size([3, 768, 2])

In [35]:
# Flatten to (max_instances * 2,), the per-sample layout used when batching for the flow.
k = input_data.reshape(-1)

In [36]:
# Stacking flattened jets gives the (batch, 1536) layout the flow consumes (data_dim = 768 * 2).
torch.stack([k,k,k]).shape

torch.Size([3, 1536])

In [37]:
''' Training ''' 

def train_flow(model, dataset, loss_fn, optimizer, scheduler, args):
    '''
    Train `model` for `args.n_steps` steps on mini-batches of flattened jets.

    Each step draws `args.batch_size` valid jets from `dataset.get_next_instance()`,
    flattens each one to a vector and stacks them into a (batch_size, -1) batch. The
    per-sample `loss_fn(model, batch)` is averaged and back-propagated, and the
    scheduler is stepped on that loss. Progress goes to stdout and W&B. Every
    `args.log_interval` steps the model and optimizer/scheduler states are
    checkpointed into `args.output_dir`.

    Side effect: increments `args.step` (global step counter, used when resuming).
    '''
    model.train()

    init_steps = args.step
    for _ in range(args.n_steps):
        args.step += 1

        input_batch = []
        # range(): iterating over the int `args.batch_size` itself raised TypeError.
        for _ in range(args.batch_size):
            input_data, true_index = dataset.get_next_instance()
            input_batch.append(input_data.reshape(-1))
        input_batch = torch.stack(input_batch)     # (batch_size, flattened jet size)
        loss = loss_fn(model, input_batch).mean(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(loss)

        loss_val = loss.item()
        # args.lr is only the initial LR; read the live value so scheduler decays show up.
        current_lr = optimizer.param_groups[0]['lr']
        print("Step: {:03d}/{:03d} | Loss: {:14.2f}".format(args.step, init_steps+args.n_steps, loss_val))
        wandb.log({'Step': args.step, 
                   'Loss': loss_val, 
                   'Learning_Rate': current_lr,
                   'True_Index': true_index})   # index of the last jet in the batch

        if args.step % args.log_interval == 0:
            # save model
            torch.save({'step': args.step,
                        'state_dict': model.state_dict()},
                        os.path.join(args.output_dir, 'checkpoint.pt'))
            torch.save({'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()},
                        os.path.join(args.output_dir, 'optim_checkpoint.pt'))

In [38]:
# Base W&B run config; the parsed `args` are merged into it via wandb.config.update(args).
DEFAULT_CFG = dict(
    model='BNAF_Valid-Samples',
    root_dir=ROOT,
)

In [39]:
import warnings
warnings.filterwarnings('ignore')   # NOTE(review): hides every warning, including deprecations.

# Get Args
args = get_args()

# Initialize Dataset
# Built before the dimensions below are derived from it, so this cell does not depend
# on a `dataset` object left in the kernel by an earlier cell.
dataset_file = DATA_ROOT + INPUT_FORMAT.format(0)
dataset = ParquetDataset(dataset_file)

# Set custom args
args.data_dim = dataset.max_instances * 2     # 768 * 2
args.hidden_dim = dataset.max_instances * 4   # 768 * 4
args.n_steps = 1000
args.log_interval = 200
args.cuda = 0
args.lr = 0.1
args.step = 0
args.pt_range = [50,55]
dataset.pt_range = args.pt_range

# get_args() resolved args.device while args.cuda was still None (always CPU);
# re-resolve it now that a GPU index has been chosen.
args.device = torch.device('cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu')

# Set Seeds
torch.manual_seed(args.seed)
if args.device.type == 'cuda': 
    torch.cuda.manual_seed(args.seed)

# Get Model
model = BNAF(args.data_dim, args.n_hidden, args.hidden_dim).to(args.device)
if args.restore_file:
    model_checkpoint = torch.load(args.restore_file, map_location=args.device)
    model.load_state_dict(model_checkpoint['state_dict'])
    args.step = model_checkpoint['step']

# Save Config
config = 'Parsed args:\n{}\n\n'.format(pprint.pformat(args.__dict__)) + \
            'Num trainable params: {:,.0f}\n\n'.format(sum(p.numel() for p in model.parameters())) + \
            'Model:\n{}'.format(model)

config_path = os.path.join(args.output_dir, 'config.txt')
if not os.path.exists(config_path):
    with open(config_path, 'a') as f:
        print(config, file=f)

# Get Optimizer + Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args.lr_decay, patience=args.lr_patience, verbose=True)
if args.restore_file:
    optim_checkpoint = torch.load(os.path.dirname(args.restore_file) + '/optim_checkpoint.pt', map_location=args.device)
    optimizer.load_state_dict(optim_checkpoint['optimizer'])
    scheduler.load_state_dict(optim_checkpoint['scheduler'])

# Define Loss
loss_fn = compute_kl_pq_loss

In [40]:
import warnings
warnings.filterwarnings('ignore')   # NOTE(review): hides every warning, including deprecations.

# Get Args
args = get_args()

# Initialize Dataset
# Built before the dimensions below are derived from it, so this cell does not depend
# on a `dataset` object left in the kernel by an earlier cell.
dataset_file = DATA_ROOT + INPUT_FORMAT.format(0)
dataset = ParquetDataset(dataset_file)

# Set custom args
args.data_dim = dataset.max_instances * 2     # 768 * 2
args.hidden_dim = dataset.max_instances * 4   # 768 * 4
args.n_steps = 1000
args.log_interval = 200
args.cuda = 0
args.lr = 0.1
args.step = 0
args.pt_range = [50,55]
args.batch_size = 16
dataset.pt_range = args.pt_range

# get_args() resolved args.device while args.cuda was still None (always CPU);
# re-resolve it now that a GPU index has been chosen.
args.device = torch.device('cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu')

# Set Seeds
torch.manual_seed(args.seed)
if args.device.type == 'cuda': 
    torch.cuda.manual_seed(args.seed)

# Get Model
model = BNAF(args.data_dim, args.n_hidden, args.hidden_dim).to(args.device)
if args.restore_file:
    model_checkpoint = torch.load(args.restore_file, map_location=args.device)
    model.load_state_dict(model_checkpoint['state_dict'])
    args.step = model_checkpoint['step']

# Save Config
config = 'Parsed args:\n{}\n\n'.format(pprint.pformat(args.__dict__)) + \
            'Num trainable params: {:,.0f}\n\n'.format(sum(p.numel() for p in model.parameters())) + \
            'Model:\n{}'.format(model)

config_path = os.path.join(args.output_dir, 'config.txt')
if not os.path.exists(config_path):
    with open(config_path, 'a') as f:
        print(config, file=f)

# Get Optimizer + Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args.lr_decay, patience=args.lr_patience, verbose=True)
if args.restore_file:
    optim_checkpoint = torch.load(os.path.dirname(args.restore_file) + '/optim_checkpoint.pt', map_location=args.device)
    optimizer.load_state_dict(optim_checkpoint['optimizer'])
    scheduler.load_state_dict(optim_checkpoint['scheduler'])

# Define Loss
loss_fn = compute_kl_pq_loss

In [41]:
# Init Wandb
# NOTE(review): each execution of this cell starts another W&B run (see the run ids in the outputs).
# Silence W&B console output and attach WANDB_DESC as the run's notes.
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_NOTES"] = WANDB_DESC
# New run in project 'gnf'; run files go under LOGS_ROOT/wandb.
run = wandb.init(project='gnf', config=DEFAULT_CFG, dir=LOGS_ROOT)
# Merge every training argument into the run config.
wandb.config.update(args)

# Save files for later
# The relative path resolves to ROOT because of the earlier os.chdir(ROOT).
! pip freeze > requirements.txt
wandb.save(ROOT + 'requirements.txt')
# wandb.save(config_path)       # TODO: Fix permission issue on Colab.

['/content/drive/My Drive/_GSoC/Normalizing-Flows/logs/wandb/run-20210714_100332-3gnvetlu/files/requirements.txt',
 '/content/drive/My Drive/_GSoC/Normalizing-Flows/logs/wandb/run-20210714_100332-3gnvetlu/files/requirements.txt']

In [42]:
# Init Wandb
# NOTE(review): each execution of this cell starts another W&B run (see the run ids in the outputs).
# Silence W&B console output and attach WANDB_DESC as the run's notes.
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_NOTES"] = WANDB_DESC
# New run in project 'gnf'; run files go under LOGS_ROOT/wandb.
run = wandb.init(project='gnf', config=DEFAULT_CFG, dir=LOGS_ROOT)
# Merge every training argument into the run config.
wandb.config.update(args)

# Save files for later
# The relative path resolves to ROOT because of the earlier os.chdir(ROOT).
! pip freeze > requirements.txt
wandb.save(ROOT + 'requirements.txt')
# wandb.save(config_path)       # TODO: Fix permission issue on Colab.

['/content/drive/My Drive/_GSoC/Normalizing-Flows/logs/wandb/run-20210714_100551-36ppoz1d/files/requirements.txt',
 '/content/drive/My Drive/_GSoC/Normalizing-Flows/logs/wandb/run-20210714_100551-36ppoz1d/files/requirements.txt']

In [43]:
 # Train The Model
 # NOTE(review): when this cell ran (In [43]), the active train_flow was the In [37] version,
 # whose `for batch_idx in args.batch_size` raises TypeError. The corrected definition
 # only appears in the following cell, so that cell must run before this call.
train_flow(model, dataset, loss_fn, optimizer, scheduler, args)

In [44]:
''' Training ''' 

def train_flow(model, dataset, loss_fn, optimizer, scheduler, args):
    '''
    Train `model` for `args.n_steps` steps on mini-batches of flattened jets.

    Each step draws `args.batch_size` valid jets from `dataset.get_next_instance()`,
    flattens each one to a vector and stacks them into a (batch_size, -1) batch. The
    per-sample `loss_fn(model, batch)` is averaged and back-propagated, and the
    scheduler is stepped on that loss. Progress goes to stdout and W&B. Every
    `args.log_interval` steps the model and optimizer/scheduler states are
    checkpointed into `args.output_dir`.

    Side effect: increments `args.step` (global step counter, used when resuming).
    '''
    model.train()

    init_steps = args.step
    for _ in range(args.n_steps):
        args.step += 1

        input_batch = []
        for _ in range(args.batch_size):
            input_data, true_index = dataset.get_next_instance()
            input_batch.append(input_data.reshape(-1))
        input_batch = torch.stack(input_batch)     # (batch_size, flattened jet size)
        loss = loss_fn(model, input_batch).mean(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(loss)

        loss_val = loss.item()
        # args.lr is only the initial LR; read the live value so scheduler decays show up.
        current_lr = optimizer.param_groups[0]['lr']
        print("Step: {:03d}/{:03d} | Loss: {:14.2f}".format(args.step, init_steps+args.n_steps, loss_val))
        wandb.log({'Step': args.step, 
                   'Loss': loss_val, 
                   'Learning_Rate': current_lr,
                   'True_Index': true_index})   # index of the last jet in the batch

        if args.step % args.log_interval == 0:
            # save model
            torch.save({'step': args.step,
                        'state_dict': model.state_dict()},
                        os.path.join(args.output_dir, 'checkpoint.pt'))
            torch.save({'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()},
                        os.path.join(args.output_dir, 'optim_checkpoint.pt'))

In [45]:
# Base W&B run config; the parsed `args` are merged into it via wandb.config.update(args).
DEFAULT_CFG = dict(
    model='BNAF_Valid-Samples',
    root_dir=ROOT,
)

In [46]:
import warnings
warnings.filterwarnings('ignore')   # NOTE(review): hides every warning, including deprecations.

# Get Args
args = get_args()

# Initialize Dataset
# Built before the dimensions below are derived from it, so this cell does not depend
# on a `dataset` object left in the kernel by an earlier cell.
dataset_file = DATA_ROOT + INPUT_FORMAT.format(0)
dataset = ParquetDataset(dataset_file)

# Set custom args
args.data_dim = dataset.max_instances * 2     # 768 * 2
args.hidden_dim = dataset.max_instances * 4   # 768 * 4
args.n_steps = 1000
args.log_interval = 200
args.cuda = 0
args.lr = 0.1
args.step = 0
args.pt_range = [50,55]
args.batch_size = 16
dataset.pt_range = args.pt_range

# get_args() resolved args.device while args.cuda was still None (always CPU);
# re-resolve it now that a GPU index has been chosen.
args.device = torch.device('cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu')

# Set Seeds
torch.manual_seed(args.seed)
if args.device.type == 'cuda': 
    torch.cuda.manual_seed(args.seed)

# Get Model
model = BNAF(args.data_dim, args.n_hidden, args.hidden_dim).to(args.device)
if args.restore_file:
    model_checkpoint = torch.load(args.restore_file, map_location=args.device)
    model.load_state_dict(model_checkpoint['state_dict'])
    args.step = model_checkpoint['step']

# Save Config
config = 'Parsed args:\n{}\n\n'.format(pprint.pformat(args.__dict__)) + \
            'Num trainable params: {:,.0f}\n\n'.format(sum(p.numel() for p in model.parameters())) + \
            'Model:\n{}'.format(model)

config_path = os.path.join(args.output_dir, 'config.txt')
if not os.path.exists(config_path):
    with open(config_path, 'a') as f:
        print(config, file=f)

# Get Optimizer + Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args.lr_decay, patience=args.lr_patience, verbose=True)
if args.restore_file:
    optim_checkpoint = torch.load(os.path.dirname(args.restore_file) + '/optim_checkpoint.pt', map_location=args.device)
    optimizer.load_state_dict(optim_checkpoint['optimizer'])
    scheduler.load_state_dict(optim_checkpoint['scheduler'])

# Define Loss
loss_fn = compute_kl_pq_loss

In [47]:
# Init Wandb
# NOTE(review): each execution of this cell starts another W&B run.
# Silence W&B console output and attach WANDB_DESC as the run's notes.
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_NOTES"] = WANDB_DESC
# New run in project 'gnf'; run files go under LOGS_ROOT/wandb.
run = wandb.init(project='gnf', config=DEFAULT_CFG, dir=LOGS_ROOT)
# Merge every training argument into the run config.
wandb.config.update(args)

# Save files for later
# The relative path resolves to ROOT because of the earlier os.chdir(ROOT).
! pip freeze > requirements.txt
wandb.save(ROOT + 'requirements.txt')
# wandb.save(config_path)       # TODO: Fix permission issue on Colab.