In [1]:
import os
from importlib import reload

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
import numpy as np
import pandas as pd
import scipy
from sklearn.metrics import roc_auc_score
from sklearn import mixture, preprocessing, datasets

import utils.models as models
import utils.plotting as plotting
import utils.dataloaders as dl
import utils.traintest as tt
import utils.adversarial as adv
import utils.eval as ev
import model_params as params
import utils.resnet_orig as resnet
import utils.gmm_helpers as gmm_helpers

import model_paths

import utils.hybrid as hybrid



In [2]:
# Path of the hybrid Glow checkpoint trained on CIFAR-10.
# NOTE(review): `file` is never read anywhere in this notebook, so `model`
# below keeps whatever weights hybrid.Hybrid / hybrid.Glow initialise it with.
# A load step (e.g. `model.load_state_dict(torch.load(file, ...))`) looks to be
# missing — confirm the checkpoint format (state_dict vs. pickled module) first.
file = '../glow-pytorch-master/Checkpoints/glow_hybrid_dataset_CIFAR10.pth'

# Hard-coded GPU index 7 of the original machine.
device = torch.device('cuda', 7)

# The flow is wrapped by the Hybrid model; the meaning of Glow's positional
# arguments (3, 8, 3) is defined in utils.hybrid — TODO document there.
flow = hybrid.Glow(3, 8, 3)
model = hybrid.Hybrid(flow)
model = model.to(device)

In [3]:
# CIFAR-10 training split; augm_flag=False presumably disables augmentation so
# the likelihood statistics below are computed on the plain training images.
train_loader = dl.CIFAR10(augm_flag=False, train=True)

In [4]:
# Smallest value of the model's second output over every training image.
# The result is later handed to hybrid.CalibratedHybrid.
# NOTE(review): output index [1] is presumably the per-sample log-likelihood
# log p(x) — confirm against utils.hybrid.Hybrid.forward.
#
# Start from +inf, not 0: seeding the running minimum with 0 silently caps the
# result at 0 whenever every training log-likelihood is positive.
log_p = float('inf')

# Inference only: no_grad avoids building an autograd graph for each of the
# training batches.
# NOTE(review): nothing in this notebook calls model.eval(); switching modes may
# change layer behaviour, so confirm the intended mode before adding it.
with torch.no_grad():
    for data, _ in train_loader:
        batch_min = model(data.to(device))[1].min().item()
        log_p = min(log_p, batch_min)

assert log_p != float('inf'), 'train_loader yielded no batches'
print(log_p)

-26533.330078125


In [5]:
# nn.Module.cpu() moves the parameters in place and returns the same module,
# so `model` itself now lives on the CPU from this cell onward.
model.cpu()
# Wrap the hybrid together with the training-set minimum `log_p`
# (presumably used as a calibration threshold — see utils.hybrid).
calibrated = hybrid.CalibratedHybrid(model, log_p)

In [6]:
# Persist the whole calibrated module (torch.save pickles the module object,
# not just a state_dict). It is moved to the CPU first so loading it does not
# require the GPU it was built on.
SAVE_DIR = 'SavedModels/other'
save_path = os.path.join(SAVE_DIR, 'hybrid_CIFAR10' + '.pth')

# torch.save does not create missing parent directories.
os.makedirs(SAVE_DIR, exist_ok=True)

calibrated.cpu()
torch.save(calibrated, save_path)

In [10]:
# Batch of 20 uniform-noise images in [0, 1), sized 3x32x32 like CIFAR-10
# inputs (presumably used as an out-of-distribution probe below).
# A dedicated, seeded generator makes the batch, and every output computed
# from it, reproducible across kernel restarts without touching torch's
# global RNG state.
NOISE_SEED = 0
x = torch.rand(20, 3, 32, 32, generator=torch.Generator().manual_seed(NOISE_SEED))

In [11]:
# Uncalibrated output of the wrapped hybrid model on the noise batch.
# NOTE(review): judging by the recorded output, the first element holds one row
# of 10 per-class scores per image (presumably log-probabilities) — confirm in
# utils.hybrid.
raw_hybrid_out = calibrated.hybrid(x)
raw_hybrid_out

(tensor([[-5.7306e+00, -8.5042e+00, -6.0094e+00, -7.5128e+00, -9.0975e+00,
          -4.1508e+00, -7.9005e+00, -1.1309e+01, -1.0503e-01, -2.5639e+00],
         [-3.7308e+00, -4.1552e+00, -3.3178e+00, -4.2241e+00, -1.1208e-01,
          -1.3871e+01, -4.2569e+00, -6.8436e+00, -8.5557e+00, -9.5140e+00],
         [-2.4272e+00, -6.0139e+00, -7.2211e+00, -5.3245e+00, -4.4857e-01,
          -4.3815e+00, -1.5900e+01, -8.7806e+00, -8.2164e+00, -1.3775e+00],
         [-4.7295e+00, -1.3883e+01, -5.7464e-01, -5.4368e+00, -6.8198e+00,
          -1.3395e+01, -3.4353e+00, -8.2236e+00, -1.4854e+00, -1.8084e+00],
         [-9.2404e+00, -1.6964e+00, -4.5094e-01, -1.3903e+01, -2.7686e+00,
          -3.1074e+00, -1.1687e+01, -8.8535e+00, -4.1619e+00, -2.8765e+00],
         [-3.0052e+00, -1.2056e+00, -5.2009e+00, -4.8344e+00, -6.6764e+00,
          -4.7349e+00, -8.2206e+00, -9.2794e+00, -5.8569e+00, -4.7124e-01],
         [-6.8998e+00, -1.3337e+01, -3.5728e+00, -5.0965e+00, -6.1856e+00,
          -1.0169e+

In [8]:
# NOTE(review): this cell ran as In [8], out of order. Its recorded output
# (-25301.68...) differs from the minimum printed by the training-set loop
# (-26533.33...), so it reflects a different kernel state. Re-run the notebook
# top to bottom, or delete this cell, before relying on it.
log_p

-25301.685546875

In [22]:
# Calibrated model output on the noise batch.
# NOTE(review): an assignment produces no Out[] display, so the 1-D tensor
# recorded under this cell is stale output from an earlier version of it.
a = calibrated(x)

tensor([-72127.8125, -79308.0859, -72551.8125, -75500.0625, -73640.0703,
        -72672.3516, -73314.2344, -74631.0078, -71415.0703, -77326.0391,
        -72853.0625, -64380.1406, -75419.6250, -74936.0000, -72005.6484,
        -73298.9609, -75569.8516, -72553.7109, -77355.8672, -73774.3672],
       grad_fn=<AddBackward0>)


In [23]:
# NOTE(review): in the recorded output every row is a single value repeated
# across all 10 columns, and that value matches the per-sample column shown
# for `d` further down. This looks like the per-class term is being lost
# (e.g. a (20, 1) tensor broadcast over classes) inside CalibratedHybrid —
# worth checking utils.hybrid.
a

tensor([[-72127.8125, -72127.8125, -72127.8125, -72127.8125, -72127.8125,
         -72127.8125, -72127.8125, -72127.8125, -72127.8125, -72127.8125],
        [-79308.0859, -79308.0859, -79308.0859, -79308.0859, -79308.0859,
         -79308.0859, -79308.0859, -79308.0859, -79308.0859, -79308.0859],
        [-72551.8125, -72551.8125, -72551.8125, -72551.8125, -72551.8125,
         -72551.8125, -72551.8125, -72551.8125, -72551.8125, -72551.8125],
        [-75500.0625, -75500.0625, -75500.0625, -75500.0625, -75500.0625,
         -75500.0625, -75500.0625, -75500.0625, -75500.0625, -75500.0625],
        [-73640.0703, -73640.0703, -73640.0703, -73640.0703, -73640.0703,
         -73640.0703, -73640.0703, -73640.0703, -73640.0703, -73640.0703],
        [-72672.3516, -72672.3516, -72672.3516, -72672.3516, -72672.3516,
         -72672.3516, -72672.3516, -72672.3516, -72672.3516, -72672.3516],
        [-73314.2344, -73314.2344, -73314.2344, -73314.2344, -73314.2344,
         -73314.2344, -73314.234

In [18]:
# FIXME(review): `b` is not defined anywhere in this notebook (presumably it
# leaked from a deleted cell), so this cell raises NameError after
# Restart & Run All. Its recorded value is a (1, 10) Parameter filled with
# -2.3026 ≈ log(1/10). Recreate the defining cell or delete this one.
b

Parameter containing:
tensor([[-2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026]])

In [19]:
# FIXME(review): `c` is not defined anywhere in this notebook (presumably it
# leaked from a deleted cell), so this cell raises NameError after
# Restart & Run All. Its recorded value is a (20, 1) tensor of ones.
# Recreate the defining cell or delete this one.
c

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])

In [20]:
# FIXME(review): `d` is not defined anywhere in this notebook (presumably it
# leaked from a deleted cell), so this cell raises NameError after
# Restart & Run All. Its recorded value is a (20, 1) tensor with
# grad_fn=UnsqueezeBackward0 whose entries equal the repeated row values of `a`.
# Recreate the defining cell or delete this one.
d

tensor([[-72127.8125],
        [-79308.0859],
        [-72551.8125],
        [-75500.0625],
        [-73640.0703],
        [-72672.3516],
        [-73314.2344],
        [-74631.0078],
        [-71415.0703],
        [-77326.0391],
        [-72853.0625],
        [-64380.1406],
        [-75419.6250],
        [-74936.0000],
        [-72005.6484],
        [-73298.9609],
        [-75569.8516],
        [-72553.7109],
        [-77355.8672],
        [-73774.3672]], grad_fn=<UnsqueezeBackward0>)