In [None]:
import sys
sys.path.append('..')
%env CUDA_VISIBLE_DEVICES=0

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

import torch
import torch.nn as nn
import os
from collections import defaultdict

from torch.distributions import MultivariateNormal

from src.mrartemev_nflib.flows import NormalizingFlowModel, InvertiblePermutation, Invertible1x1Conv, ActNorm, NSF_AR
from src.mrartemev_nflib.flows import MAF, AffineHalfFlow
from src.mrartemev_nflib.nn import ARMLP, MLP

from torch.utils.data import Dataset, DataLoader, TensorDataset
from itertools import repeat

from catboost import CatBoostClassifier
from sklearn.metrics import roc_auc_score

from src.nf import CalibratedModel, neg_log_likelihood
from src.nf.classifiers import train_catboost_clf
from scipy.special import logsumexp, expit


os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

In [None]:
! ls dumps_20200602/GAS/SPLINE-AR_2_ind1

In [None]:
def fix_act_norm(layer):
    if isinstance(layer, ActNorm):
        layer.data_dep_init_done = True

In [None]:
def get_model(model_type, data, num_layers, dump_path):
    flows = []
    for _ in range(num_layers):
        if model_type == 'MAF':
            flows.append(MAF(dim=data.n_dims, base_network=ARMLP))
            flows.append(InvertiblePermutation(dim=data.n_dims))
        if model_type == 'SPLINE-AR':
            flows.append(ActNorm(dim=data.n_dims))
            flows.append(Invertible1x1Conv(dim=data.n_dims))
            flows.append(NSF_AR(dim=data.n_dims, K=8, B=3, hidden_features=32, depth=1, base_network=MLP))
        if model_type == 'GLOW':
            flows.append(ActNorm(dim=data.n_dims))
            flows.append(Invertible1x1Conv(dim=data.n_dims))
            flows.append(AffineHalfFlow(dim=data.n_dims, hidden_features=32, base_network=MLP))
            flows.append(InvertiblePermutation(dim=data.n_dims))
        if model_type == 'RealNVP':
            flows.append(AffineHalfFlow(dim=data.n_dims, base_network=MLP))
            flows.append(InvertiblePermutation(dim=data.n_dims))

    lr = 0.0005

    prior = MultivariateNormal(torch.zeros(data.n_dims).to(device), torch.eye(data.n_dims).to(device))
    model = NormalizingFlowModel(prior, flows).to(device)
    # optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    checkpoint = torch.load(dump_path)
    model.load_state_dict(checkpoint['model.state_dict()'])
    model.apply(fix_act_norm)
    
    return model

In [None]:
def to_device(model, device):
    model.to(device)
    model.prior = MultivariateNormal(torch.zeros(data.n_dims).to(device),
                                     torch.eye(data.n_dims).to(device))

In [None]:
def repeater(data_loader):
    for loader in repeat(data_loader):
        for data in loader:
            yield data

In [None]:
def batched_sample(model, n, batch_size=14000):
    generated = []
    for _ in range(n // batch_size):
        generated_batch = model.sample(batch_size)
        generated.append(generated_batch.cpu().detach())
    if n % batch_size != 0:
        generated_batch = model.sample(n % batch_size)
        generated.append(generated_batch.cpu().detach())
    generated = torch.cat(generated, dim=0)
    assert n == len(generated)
    return generated

In [None]:
def logloss_with_logits(y_pred_logits, y_true):
    return -np.mean(
        y_true * np.logaddexp(0, -y_pred_logits) + \
        (1 - y_true) * np.logaddexp(0, y_pred_logits)
    )

In [None]:
def train_cb(model, X_train_tensor, X_test_tensor, clips, iters):
    clf_ds_train = np.row_stack([
        np.column_stack([X_train_tensor.cpu().detach().numpy(), np.ones(len(X_train_tensor)).reshape(-1, 1)]),
        np.column_stack([model.sample_n(len(X_train_tensor)).cpu().detach().numpy(), np.zeros(len(X_train_tensor)).reshape(-1, 1)])
    ]).astype(np.float32)

    clf_ds_test = np.row_stack([
        np.column_stack([X_test_tensor.cpu().detach().numpy(), np.ones(len(X_test_tensor)).reshape(-1, 1)]),
        np.column_stack([model.sample_n(len(X_test_tensor)).cpu().detach().numpy(), np.zeros(len(X_test_tensor)).reshape(-1, 1)])
    ]).astype(np.float32)

    samples = model.sample_n(n).cpu().detach().cpu().numpy()
    

    metrics = []
    for n_iters in iters:
        clf = CatBoostClassifier(n_iters, verbose=0).fit(
            clf_ds_train[:, :-1], clf_ds_train[:, -1],
        )
        for clip in clips:
            calibrated_model = CalibratedModel(
                lambda x: np.clip(clf.predict(x, prediction_type='RawFormulaVal'), -100, clip),
                model,
                logit=True
            )
            clf_preds = np.clip(clf.predict(samples, prediction_type='RawFormulaVal'), -100, np.log(clip))
            calibration_constant = logsumexp(clf_preds) - np.log(len(clf_preds))
            logits = clf_preds - calibration_constant
            ll = -neg_log_likelihood(calibrated_model, X_test_tensor.cpu().detach()) - calibration_constant
            auc_roc = roc_auc_score(
                clf_ds_test[:, -1],
                np.clip(clf.predict(clf_ds_test[:, :-1], prediction_type='RawFormulaVal'), -100, np.log(clip)),
            )
            log_loss = logloss_with_logits(
                np.clip(clf.predict(clf_ds_test[:, :-1], prediction_type='RawFormulaVal'), -100, np.log(clip)),
                clf_ds_test[:, -1]
            )
            metrics.append({
                'clip': clip,
                'iters': n_iters,
                'll': ll,
                'auc_roc': auc_roc,
                'logloss': log_loss,
                'overhead': np.max(logits),
                'calibration_constant': calibration_constant
            })

    return clf_ds_train, clf_ds_test, metrics

In [None]:
from utils import data_utils

data_mapping = {'BSDS300': data_utils.BSDS300,
                'GAS': data_utils.GAS,
                'MINIBOONE': data_utils.MINIBOONE,
                'POWER': data_utils.POWER,
                'HEPMASS': data_utils.HEPMASS}

In [None]:
def get_best_model(model_type, data, num_layers, dumps_path):
    best_ll, best_model, best_dump = -10000000, None, None
    X_test_tensor = torch.from_numpy(data.tst.x[:100000]).to(device)
    
    for dump_path in [dumps_path + '/final_model.checkpoint'] + [
        os.path.join(dumps_path, 'checkpoints', path) for path in os.listdir(dumps_path + '/checkpoints')
    ]:
        try:
            model = get_model(model_type, data, num_layers, dump_path)
        except FileNotFoundError:
            print(f'Not found {dump_path}')
            continue
        ll = -neg_log_likelihood(model, X_test_tensor)
        if ll > best_ll:
            best_ll = ll
            best_model = model
            best_dump = dump_path
    return best_model, best_ll, best_dump

In [None]:
arr = []
for data_name in ('MINIBOONE', 'BSDS300', 'GAS', 'HEPMASS', 'POWER'):
    data = data_mapping[data_name]()
    dim = data.n_dims
    n = min(100000, data.trn.x.shape[0])
    X_train_tensor = torch.from_numpy(data.trn.x[:n]).to(device)
    X_test_tensor = torch.from_numpy(data.tst.x[:n]).to(device)

    for model_type in ('GLOW', 'MAF', 'RealNVP', 'SPLINE-AR'):
        num_layers = 2 if model_type == 'SPLINE-AR' else 5
        model_name = f"{model_type}_{num_layers}"

        dumps_path = f'dumps_20200602/{data_name}/{model_type}_{num_layers}_ind1'
        model, ll, dump_path = get_best_model(model_type, data, num_layers, dumps_path)
        
        model.eval()
        model.sample_n = lambda n: batched_sample(model, n)
        to_device(model, 'cpu')

        clf_ds_train, clf_ds_test, metrics = train_cb(model, X_train_tensor, X_test_tensor, [10000, 2, 1.5], [20, 100, 500, 1000, 5000])
        
        arr.append({
            'data_name': data_name,
            'model_type': model_type,
            'll': ll,
            'metrics': metrics,
            'dump_path': dump_path,
        })
        print(data_name, model_type, ll, [x['ll'] for x in metrics])
#         print(arr[-1])

In [None]:
[x['dump_path'] for x in arr]

In [None]:
# metrics = defaultdict(dict)
# for x in arr:
#     metrics[x['data_name']][x['model_type']] = x['ll']
#     for y in x['metrics']:
#         if clip['clip'] == 100:
#             metrics[x['data_name']][x['model_type'] + ' C ' + str(y['iters'])] = y['ll']
#         else:
#             metrics[x['data_name']][x['model_type'] + ' C ' + str(round(y['clip'], 1)) + ' ' + str(y['iters'])] = y['ll']
# pd.DataFrame(metrics)

In [70]:
metrics = defaultdict(dict)
for x in arr:
    model_name = (x['model_type'], 'normalizing flow', '', '', )
    metrics[(x['data_name'], 'll', )][model_name] = x['ll']
    metrics[(x['data_name'], 'log overhead', )][model_name] = 0
    for y in x['metrics']:
        if y['clip'] == 10000:
            model_name = (x['model_type'], 'calibrated', y['iters'], 'inf', )
            metrics[(x['data_name'], 'll', )][model_name] = y['ll']
            metrics[(x['data_name'], 'log overhead', )][model_name] = y['overhead']
            metrics[(x['data_name'], 'AUC-ROC', )][model_name] = y['auc_roc']
            metrics[(x['data_name'], 'Logloss', )][model_name] = -y['logloss']
#             metrics[(x['data_name'], 'calib const', )][(x['model_type'], 'calibrated', )] = clip['calibration_constant']
        else:
            model_name = (x['model_type'], 'calibrated', y['iters'], round(y['clip'], 2), )
            metrics[(x['data_name'], 'll', )][model_name] = y['ll']
            metrics[(x['data_name'], 'log overhead', )][model_name] = y['overhead']
            metrics[(x['data_name'], 'AUC-ROC', )][model_name] = y['auc_roc']
            metrics[(x['data_name'], 'Logloss', )][model_name] = -y['logloss']
#             metrics[(x['data_name'], 'calib const', )][(x['model_type'], ' calibrated clip ' + str(round(clip['clip'], 2)), )] = clip['calibration_constant']
metrics = pd.DataFrame(metrics)
pd.set_option('display.max_rows', metrics.shape[0] + 1)
metrics

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,MINIBOONE,MINIBOONE,MINIBOONE,MINIBOONE,BSDS300,BSDS300,BSDS300,BSDS300,GAS,GAS,GAS,GAS,HEPMASS,HEPMASS,HEPMASS,HEPMASS,POWER,POWER,POWER,POWER
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,ll,log overhead,AUC-ROC,Logloss,ll,log overhead,AUC-ROC,Logloss,ll,log overhead,AUC-ROC,Logloss,ll,log overhead,AUC-ROC,Logloss,ll,log overhead,AUC-ROC,Logloss
GLOW,normalizing flow,,,-14.054896,0.0,,,152.595673,0.0,,,9.409331,0.0,,,-18.73378,0.0,,,0.243388,0.0,,
GLOW,calibrated,20.0,inf,0.447676,10.106144,1.0,0.000476,165.3758,10.960704,0.999988,0.000797,19.549281,9.170694,0.999967,0.011921,-1.539994,11.326348,0.999995,0.000193,7.371123,8.569397,0.999462,0.026851
GLOW,calibrated,20.0,2.0,-5.183384,7.564656,0.999863,0.202996,161.384652,7.482441,0.999885,0.203002,15.760502,5.058424,0.999506,0.20938,-8.318543,9.108397,0.999975,0.202786,5.63805,4.120818,0.996467,0.21652
GLOW,calibrated,20.0,1.5,-5.543525,7.416833,0.999863,0.255651,161.082497,7.39249,0.999875,0.255659,15.302433,4.806938,0.999367,0.261782,-8.596405,9.042848,0.999975,0.255461,5.275555,3.958276,0.995795,0.268096
GLOW,calibrated,100.0,inf,3.379196,9.648936,1.0,0.000493,167.697024,10.63188,0.999989,0.000687,25.355639,11.066067,0.999999,0.00144,1.149598,11.413613,0.999995,0.000146,9.690438,9.448851,0.999804,0.014019
GLOW,calibrated,100.0,2.0,-4.7581,7.98994,0.999726,0.203094,161.567096,7.664727,0.99989,0.202955,17.892354,7.177234,0.999885,0.203485,-7.965339,9.461596,0.999975,0.202768,6.380897,4.849534,0.998027,0.209738
GLOW,calibrated,100.0,1.5,-5.120303,7.840055,0.999726,0.255724,161.270476,7.580379,0.99988,0.255613,17.4611,6.957806,0.99982,0.256125,-8.222388,9.416862,0.999975,0.255444,6.033347,4.708298,0.997701,0.261767
GLOW,calibrated,500.0,inf,1.776171,9.459289,1.0,0.0005,170.213327,10.146879,1.0,0.000525,24.349733,11.300452,0.999997,0.001676,5.409691,11.085338,1.0,7.8e-05,10.697435,9.738464,0.999833,0.012671
GLOW,calibrated,500.0,2.0,-4.590256,8.157784,0.999863,0.20303,161.814089,7.911637,0.9999,0.202916,17.678071,6.96284,0.999885,0.20359,-6.35802,11.068912,0.999995,0.202744,6.475905,4.942275,0.998144,0.20902
GLOW,calibrated,500.0,1.5,-4.98738,7.972978,0.999863,0.255685,161.478322,7.788158,0.99989,0.255576,17.233293,6.729842,0.999855,0.256242,-6.683232,10.956017,0.999995,0.255423,6.13417,4.807628,0.997837,0.261112
