In [48]:
import os
import yaml
import argparse
import numpy as np
from pathlib import Path
from models import *
# from experiment import VAEXperiment
# import torch.backends.cudnn as cudnn
# from pytorch_lightning import Trainer
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
# from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from dataset import VAEDataset
# from pytorch_lightning.plugins import DDPPlugin
import torchvision.utils as vutils

# Run everything on the GPU when one is visible, otherwise on the CPU.
# NOTE(review): `torch` is not imported explicitly in this cell; it is presumably
# re-exported by `from models import *` — confirm, or import it directly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# import wandb
# wandb.init(project="AESPN")
# wandb.config = config

## All the parameters are here
config_file = "ae.yaml"
with open(config_file, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report the parse error and stop. Swallowing it would leave `config`
        # undefined and surface later as an unrelated NameError.
        print(exc)
        raise

# For reproducibility. The seed comes from the YAML file; the second positional
# argument is forwarded to seed_everything unchanged.
seed_everything(config['exp_params']['manual_seed'], True)




Global seed set to 1265


1265

In [49]:
####
# Dataset
####
# Build the data module from the "data_params" section of the YAML config,
# let it prepare its splits, then grab the train and test loaders.
data = VAEDataset(**config["data_params"])
data.setup()
train_loader, test_loader = data.train_dataloader(), data.test_dataloader()

# Quick Dataset 

In [36]:
### Goal: eventually pick a 50-50 split of smiling / not-smiling celebrities.
# Attribute names, whitespace-separated; position in the list is the attribute index.
attributes = "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
attr = attributes.split()

smiling_position = attr.index("Smiling")
print("Smiling Attribute Index: ", smiling_position)

# Scratch inspection of the training split (values are not displayed because
# these are not the last expression of the cell).
len(data.train_dataset)
data.train_dataset

# Tiny three-element subset used only to try out torch.utils.data.Subset.
subset_indices = [1,3,0]
my_subset = torch.utils.data.Subset(data.train_dataset, subset_indices)
len(my_subset)

Smiling Attribute Index:  31


3

In [43]:
import pandas as pd

# Attribute table: whitespace-separated, one column per attribute name.
# `sep=r"\s+"` is the non-deprecated spelling of `delim_whitespace=True`.
attr_df = pd.read_csv("Data/celeba/attr_records.txt", sep=r"\s+")

# For each attribute of interest, print its index in `attr` and how many rows
# carry the value 1 versus -1. Each index line is labelled with its own
# attribute name (the Male line used to be mislabelled as "Smiling").
# Tuple layout: (column name, wording used in the count lines).
for column, wording in (("Smiling", "smiling"),
                        ("Attractive", "attractive"),
                        ("Male", "Male")):
    print(column + " Attribute Index: ", attr.index(column))
    print("num " + wording, int((attr_df[column] == 1).sum()))
    print("num not " + wording, int((attr_df[column] == -1).sum()))

    if column != "Smiling":
        # Keep the blank separation between groups out of the printed output;
        # nothing else to do here.
        pass

Smiling Attribute Index:  31
num smiling 97669
num not smiling 104930
Attractive Attribute Index:  2
num attractive 103833
num not attractive 98766
Smiling Attribute Index:  20
num Male 84434
num not Male 118165


# Model + Training

In [2]:
###
# AE Model
###

import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
# from .types_ import *

class NormalAE(BaseVAE):
    """
    Deterministic convolutional autoencoder with an attribute-conditioned decoder.

    The encoder maps an image to a `latent_dim` code; `encode` then appends the
    attribute vector, so the decoder (and anything consuming the codes) sees
    vectors of width `latent_dim + attr_dim`.

    :param in_channels: (int) Number of channels of the input images
    :param latent_dim: (int) Width of the learned latent code (without attributes)
    :param hidden_dims: (List) Channel widths of the encoder stages; defaults to
                        [32, 64, 128, 256, 512]. The list is copied, never mutated.
    :param attr_dim: (int) Number of attribute entries appended to the code
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 attr_dim: int = 2,
                 **kwargs) -> None:
        super(NormalAE, self).__init__()

        self.latent_dim = latent_dim
        self.attr_dim = attr_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Work on a copy: the list is reversed below and must not change
            # under the caller's feet.
            hidden_dims = list(hidden_dims)

        # Channel count of the encoder output; decode() needs it to un-flatten.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder: one stride-2 conv stage per entry of hidden_dims
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # The factor 4 matches the 2 x 2 feature map that decode() reshapes to.
        self.fc_last = nn.Linear(hidden_dims[-1]*4, latent_dim)

        # Build Decoder
        modules = []

        # The decoder consumes the latent code with the attributes appended.
        self.decoder_input = nn.Linear(latent_dim + attr_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor, attr: Tensor) -> Tensor:
        """
        Encodes the input image and appends the attribute vector to the code.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :param attr: (Tensor) Attribute entries per sample [N x attr_dim]
        :return: (Tensor) Code with attributes appended [N x (latent_dim + attr_dim)]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        result = self.fc_last(result)

        # Concatenate the attributes after the learned code. The explicit cast
        # keeps the result in the code's dtype even for integer attribute labels.
        result = torch.cat([result, attr.to(dtype=result.dtype)], dim=1)

        return result

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given codes (latent code plus attributes) onto the image space.
        :param z: (Tensor) [B x (latent_dim + attr_dim)]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Un-flatten to the encoder's final feature-map layout.
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, input: Tensor, attr: Tensor ,**kwargs) -> List[Tensor]:
        """
        Encodes `input` together with `attr` and decodes the result.
        :return: (List[Tensor]) [reconstruction, input, code]
        """
        z = self.encode(input, attr)
        return  [self.decode(z), input, z]

    def generate(self, x: Tensor, attr: Tensor = None, **kwargs) -> Tensor:
        """
        Given an input image x and its attributes, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :param attr: (Tensor) [B x attr_dim]; required, because forward() cannot
                     run without it
        :return: (Tensor) [B x C x H x W]
        """
        if attr is None:
            raise TypeError("generate() requires the attribute tensor `attr`")
        return self.forward(x, attr)[0]


# Instantiate the autoencoder from the "model_params" section of the config and
# place it on the selected device.
model_params = config['model_params']
model = NormalAE(
    in_channels=model_params["in_channels"],
    latent_dim=model_params["latent_dim"],
).to(device)

####
# Optimizers
####
# Reconstruction criterion and Adam optimizer; the learning rate is read from
# the "exp_params" section of the config.
MSELoss = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=config["exp_params"]["LR"])


# # **config['model_params']
# model_params = config['model_params']
# in_channels = model_params["in_channels"]
# latent_dim = model_params["latent_dim"]
# model = vae_models["VanillaVAE"](model_params["in_channels"], model_params["latent_dim"])
# model(a[0])[0].shape


In [12]:
####
# Training the model
####


from tqdm import tqdm

# Columns of the per-image attribute vector fed to the model as conditioning:
# 20 is "Male" and 31 is "Smiling" in the attribute name list.
attr_indices = [20, 31]

print("Train Loader Len: ", str(len(train_loader)))

num_epochs = 50

# Neither torch.save nor save_image creates missing parent directories.
os.makedirs("saved_AE_models", exist_ok=True)
os.makedirs("AE_test", exist_ok=True)

for epoch in range(num_epochs):
    model.train()

    for i, batch in enumerate(train_loader):
        batch_X, batch_attr = batch
        batch_X = batch_X.to(device)
        batch_attr = batch_attr.to(device)

        optimizer.zero_grad()
        result = model(batch_X, batch_attr[:, attr_indices])
        recons = result[0]  # forward() returns [reconstruction, input, code]

        loss = MSELoss(recons, batch_X)
        print( "==> Iter: ", i, " Loss: ", loss.item(), " Epoch: ", epoch)

        loss.backward()
        optimizer.step()

    # One checkpoint per epoch. The last epoch's file is the final model, so
    # no extra save is needed after the loop.
    model_save_path = "saved_AE_models/" + 'April19_' + str(epoch) + '_SmileGender.pth' #wandb.run.name
    torch.save(model.state_dict(), model_save_path)

    ### Visual check
    # NOTE: these are reconstructions of the LAST TRAINING batch of the epoch
    # (computed in train mode), not of held-out test data, despite the file name.
    test_samples = recons[0:144].detach()
    vutils.save_image(test_samples.cpu(),
                              "AE_test/" + str(epoch) + "_test.png",
                              normalize=True,
                              nrow=12)
# torch.save(model.state_dict(), "model.pth")

Train Loader Len:  159
==> Iter:  0  Loss:  0.3931291699409485  Epoch:  0
==> Iter:  1  Loss:  0.2030644416809082  Epoch:  0
==> Iter:  2  Loss:  0.0970953032374382  Epoch:  0
==> Iter:  3  Loss:  0.10531716793775558  Epoch:  0
==> Iter:  4  Loss:  0.09014032781124115  Epoch:  0
==> Iter:  5  Loss:  0.07975980639457703  Epoch:  0
==> Iter:  6  Loss:  0.07991255819797516  Epoch:  0
==> Iter:  7  Loss:  0.07909481972455978  Epoch:  0
==> Iter:  8  Loss:  0.08054643124341965  Epoch:  0
==> Iter:  9  Loss:  0.08062172681093216  Epoch:  0
==> Iter:  10  Loss:  0.07805599272251129  Epoch:  0
==> Iter:  11  Loss:  0.08061002194881439  Epoch:  0
==> Iter:  12  Loss:  0.07810728251934052  Epoch:  0
==> Iter:  13  Loss:  0.07722081989049911  Epoch:  0
==> Iter:  14  Loss:  0.07695186138153076  Epoch:  0
==> Iter:  15  Loss:  0.07625767588615417  Epoch:  0
==> Iter:  16  Loss:  0.07477015256881714  Epoch:  0
==> Iter:  17  Loss:  0.0769302174448967  Epoch:  0
==> Iter:  18  Loss:  0.0768935903906

KeyboardInterrupt: 

In [13]:
###
# Code Snippet for Loading an existing model
###

# NewModel = NormalAE(model_params["in_channels"], model_params["latent_dim"])
# NewModel.load_state_dict(torch.load("model.pth"))

# result = model(batch_X, batch_attr)
# recons = result[0]
# vutils.save_image(recons[0:144].cpu().data,
#                               "test.png",
#                               normalize=True,
#                               nrow=12)



# Extracting the latent dataset used to train the Einsum Network

In [46]:
import os
import numpy as np
import torch
from EinsumNetwork import Graph, EinsumNetwork
import ES_datasets
import utils

###
# Can I learn an SPN from a dataset of 168 x 160,000
## I think I'll sample around 50,000 images and try generating it from that?
###

# Epoch number of the autoencoder checkpoint to evaluate.
model_version = "20"
model_save_path = "saved_AE_models/" + 'April19_' + model_version +'_SmileGender.pth' #wandb.run.name + 
EvalModel = NormalAE(model_params["in_channels"], model_params["latent_dim"]).to(device)
# map_location lets a checkpoint written on a GPU machine load on a CPU-only
# one (and vice versa). The call stays last so the key-matching report is shown.
EvalModel.load_state_dict(torch.load(model_save_path, map_location=device))


<All keys matched successfully>

In [51]:
"""
Now I need to get a concatenated matrix of all the latent vectors distributions. 
It should take around 5 minutes~ on a GPU with model.eval()

"""
from torch.utils.data import DataLoader, Dataset

# Sequential (unshuffled) pass over the training split so that row k of the
# resulting matrix corresponds to training item k.
spn_dataloader = DataLoader(
            data.train_dataset,
            batch_size=1024,
            shuffle=False,
            )


# Attribute columns appended to every code: 20 is "Male", 31 is "Smiling".
attr_indices = [20, 31]

# Per-batch code matrices, concatenated once at the end (concatenating inside
# the loop re-copies everything gathered so far on every iteration).
latent_chunks = []

EvalModel.eval()
# Inference only: no autograd graph is needed, and only the encoder has to run
# because the code is all that is kept.
with torch.no_grad():
    for i, batch in enumerate(spn_dataloader):
        print( "==> Iter: ", i, " / ", str(len(spn_dataloader)))

        batch_X, batch_attr = batch
        batch_X = batch_X.to(device)
        batch_attr = batch_attr.to(device)

        codes = EvalModel.encode(batch_X, batch_attr[:, attr_indices])
        latent_chunks.append(codes.cpu())

# The name `recons` is kept because the saving cell reads it; it holds the
# codes (latent entries followed by the attribute entries), not images.
recons = torch.cat(latent_chunks, 0) if latent_chunks else None


==> Iter:  0  /  159
==> Iter:  1  /  159
==> Iter:  2  /  159
==> Iter:  3  /  159
==> Iter:  4  /  159
==> Iter:  5  /  159
==> Iter:  6  /  159
==> Iter:  7  /  159
==> Iter:  8  /  159
==> Iter:  9  /  159
==> Iter:  10  /  159
==> Iter:  11  /  159
==> Iter:  12  /  159
==> Iter:  13  /  159
==> Iter:  14  /  159
==> Iter:  15  /  159
==> Iter:  16  /  159
==> Iter:  17  /  159
==> Iter:  18  /  159
==> Iter:  19  /  159
==> Iter:  20  /  159
==> Iter:  21  /  159
==> Iter:  22  /  159
==> Iter:  23  /  159
==> Iter:  24  /  159
==> Iter:  25  /  159
==> Iter:  26  /  159
==> Iter:  27  /  159
==> Iter:  28  /  159
==> Iter:  29  /  159
==> Iter:  30  /  159
==> Iter:  31  /  159
==> Iter:  32  /  159
==> Iter:  33  /  159
==> Iter:  34  /  159
==> Iter:  35  /  159
==> Iter:  36  /  159
==> Iter:  37  /  159
==> Iter:  38  /  159
==> Iter:  39  /  159
==> Iter:  40  /  159
==> Iter:  41  /  159
==> Iter:  42  /  159
==> Iter:  43  /  159
==> Iter:  44  /  159
==> Iter:  45  /  15

In [53]:
# Persist the matrix of codes gathered above (held in `recons`) as a single
# torch file, so the Einsum network can be trained without re-encoding.
latents_filepath = "saved_AE_latents/" + 'April19_' + model_version +'_SmileGender_latents.pth' #wandb.run.name + 
# torch.save does not create the parent directory.
os.makedirs(os.path.dirname(latents_filepath), exist_ok=True)
torch.save(recons, latents_filepath)
# z = torch.load(latents_filepath)



# Training the Einsum Network

In [54]:
# Rebuild the path of the saved code matrix and load it back as the Einsum
# network's training data; this depends only on `model_version` and the file on disk.
latents_filepath = "saved_AE_latents/" + 'April19_' + model_version +'_SmileGender_latents.pth'
train_x = torch.load(latents_filepath)
# Last expression of the cell: displays (number of rows, variables per row).
train_x.shape

torch.Size([162770, 130])

In [55]:
### Training the EinSum

import os
import numpy as np
import torch
from EinsumNetwork import Graph, EinsumNetwork
# import datasets
import utils

# Hyper-parameters for the Einsum network trained on the autoencoder codes.

# Same device selection as at the top of the notebook, repeated here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Leaf distribution family. The codes are real-valued, hence the normal family;
# the alternatives below are left for reference.
# exponential_family = EinsumNetwork.BinomialArray
# exponential_family = EinsumNetwork.CategoricalArray
exponential_family = EinsumNetwork.NormalArray

# classes = [7]
# classes = [2, 3, 5, 7]
# classes = [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]
# classes = None

# Passed to the network arguments as both num_sums and num_input_distributions.
K = 10

# Name of the intended graph structure.
structure = 'poon-domingos'
# structure = 'binary-trees'

# 'poon-domingos'
# The code vector is laid out as a height x width grid: 13 * 10 = 130 cells,
# which equals the number of columns of the loaded latent matrix.
pd_num_pieces = [4] # [4]
# pd_num_pieces = [7]
# pd_num_pieces = [7, 28]
width = 10 #28
height = 13 #28

# 'binary-trees'
# Parameters for the 'binary-trees' structure.
depth = 3
num_repetitions = 20

# NOTE: the training cell assigns num_epochs again before its loop.
num_epochs = 5
batch_size = 1024
online_em_frequency = 1
online_em_stepsize = 0.05


In [57]:
############################################################################

# Extra constructor arguments required by the chosen leaf distribution family.
if exponential_family == EinsumNetwork.BinomialArray:
    exponential_family_args = {'N': 255}
elif exponential_family == EinsumNetwork.CategoricalArray:
    exponential_family_args = {'K': 256}
elif exponential_family == EinsumNetwork.NormalArray:
    exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}
else:
    exponential_family_args = None

# The rescaling used by the image demos (divide by 255, subtract 0.5) is
# deliberately not applied: train_x holds autoencoder codes, not pixel values.

# Number of random variables = width of one code vector.
num_var = train_x.shape[1]

# Make EinsumNetwork
######################################
# Honor the `structure` setting instead of always building a Poon-Domingos graph.
if structure == 'poon-domingos':
    # The code vector is treated as a height x width grid, so the grid must
    # have exactly one cell per variable.
    if height * width != num_var:
        raise ValueError(
            "height * width = {} does not match the {} variables in train_x".format(
                height * width, num_var))
    pd_delta = [[height / d, width / d] for d in pd_num_pieces]
    graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
elif structure == 'binary-trees':
    graph = Graph.random_binary_trees(num_var=num_var, depth=depth, num_repetitions=num_repetitions)
else:
    raise AssertionError("Unknown Structure")

args = EinsumNetwork.Args(
        num_var=num_var,
        num_dims=1,
        num_classes=1,
        num_sums=K,
        num_input_distributions=K,
        exponential_family=exponential_family,
        exponential_family_args=exponential_family_args,
        online_em_frequency=online_em_frequency,
        online_em_stepsize=online_em_stepsize)

einet = EinsumNetwork.EinsumNetwork(graph, args)
einet.initialize()
einet.to(device)
print(einet)


EinsumNetwork(
  (einet_layers): ModuleList(
    (0): FactorizedLeafLayer(
      (ef_array): NormalArray()
    )
    (1-2): 2 x EinsumLayer()
    (3): EinsumMixingLayer()
    (4): EinsumLayer()
    (5): EinsumMixingLayer()
    (6): EinsumLayer()
    (7): EinsumMixingLayer()
    (8): EinsumLayer()
    (9): EinsumMixingLayer()
    (10): EinsumLayer()
    (11): EinsumMixingLayer()
  )
)


In [58]:
from tqdm import tqdm

# Train
######################################

# Train
######################################

# The whole code matrix is moved to the device once; mini-batches are then
# selected by index.
train_x = train_x.to(device)
train_N = train_x.shape[0]

num_epochs = 100
progress = tqdm(range(num_epochs))
for epoch_count in progress:

    ##### evaluate
    # Log-likelihood of the training set BEFORE this epoch's updates. It is
    # shown on the progress bar so the full evaluation pass is not wasted.
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet, train_x, batch_size=batch_size)
    progress.set_postfix(train_ll=train_ll / train_N)
    einet.train()
    #####

    # Fresh random partition of the row indices into mini-batches.
    idx_batches = torch.randperm(train_N, device=device).split(batch_size)
    for idx in idx_batches:
        batch_x = train_x[idx, :]
        outputs = einet.forward(batch_x)
        ll_sample = EinsumNetwork.log_likelihoods(outputs)
        log_likelihood = ll_sample.sum()
        # backward() runs before em_process_batch(), in the same order as the
        # original loop.
        log_likelihood.backward()

        einet.em_process_batch()

    einet.em_update()

# Output folders used by the saving and sampling cells that follow.
samples_dir = 'PC_samples_' + model_version
model_dir = 'PC_models_' + model_version

utils.mkdir_p(model_dir)
utils.mkdir_p(samples_dir)

100%|██████████| 100/100 [06:10<00:00,  3.71s/it]


In [59]:
# save model
# Two files are written under `model_dir`: the structure graph and the trained network.
graph_file = os.path.join(model_dir, "einet.pc")
Graph.write_gpickle(graph, graph_file)
print("Saved PC graph to {}".format(graph_file))
model_file = os.path.join(model_dir, "einet.mdl")
# The module object itself is passed to torch.save (not a state_dict), so the
# reload cell gets a ready-to-use network back from torch.load.
torch.save(einet, model_file)
print("Saved model to {}".format(model_file))

Saved PC graph to PC_models_20/einet.pc
Saved model to PC_models_20/einet.mdl


In [11]:
# Drop the in-memory network so the reload below is a genuine round-trip test.
del einet

# reload model
# map_location makes a model saved on a GPU loadable on a CPU-only machine.
# NOTE: this unpickles a whole module object; only load files you wrote yourself.
einet = torch.load(model_file, map_location=device)
print("Loaded model from {}".format(model_file))

Loaded model from PC_models/einet.mdl


In [61]:
"""
Sampling and reconstructing
"""
einet.eval()

# Draw 25 code vectors from the Einsum network.
z_samples = einet.sample(num_samples=25)# .cpu().numpy()

model_version = "20"
model_save_path = "saved_AE_models/" + 'April19_' + model_version +'_SmileGender.pth' #wandb.run.name + 
EvalModel = NormalAE(model_params["in_channels"], model_params["latent_dim"]).to(device)
EvalModel.load_state_dict(torch.load(model_save_path, map_location=device))
# A freshly constructed module starts in training mode; switch to eval so the
# BatchNorm layers decode with their stored running statistics instead of the
# statistics of this 25-sample batch.
EvalModel.eval()

samples_dir = 'PC_samples_' + model_version
# save_image does not create the folder; make this cell runnable on its own.
os.makedirs(samples_dir, exist_ok=True)

## How to reconstruct an image
# Decoding is inference only, so no autograd graph is built.
with torch.no_grad():
    a = EvalModel.decode(z_samples)
vutils.save_image(a, samples_dir + "/samples_"+ model_version +".png", normalize=True, nrow=12)
print(a.shape)

torch.Size([25, 3, 64, 64])


In [62]:
## Random samples
# Baseline for comparison: decode vectors drawn from a standard normal rather
# than from the Einsum network. All 130 entries are noise, including the
# trailing entries where encode() places the attribute values.
z_samples = torch.normal(0, 1, size=(25, 130)).to(device)

## How to reconstruct an image
a = EvalModel.decode(z_samples)
vutils.save_image(a, samples_dir + "/AE_gaussian_"+ model_version +".png", normalize=True, nrow=12)
print(a.shape)

torch.Size([25, 3, 64, 64])


In [31]:
import numpy as np

# All 130 variable indices of a code vector.
scope = list(range(130))

## smiles
# Keep index 129 (the smile entry) as evidence and marginalize the rest.
# list.remove() works in place and returns None, so its result must not be
# assigned: that used to pass None to set_marginalization_idx.
keep_idx = [129]
marginalize_idx = [i for i in scope if i not in keep_idx]

# ## other thing
# keep_idx = [128]
# marginalize_idx = [i for i in scope if i not in keep_idx]

## setting the marginalizing idx
einet.set_marginalization_idx(marginalize_idx)



In [67]:
####
# smile not smile
###

# Index 129 (the smile entry of the code) is kept as evidence; every other
# index is handed to the network as marginalized.
keep_idx = [129]
scope = [var for var in np.array(range(130)) if var != 129]
marginalize_idx = scope
einet.set_marginalization_idx(marginalize_idx)

z_samples = torch.normal(0, 1, size=(25, 130)).to(device)

# First with the smile entry set to 1, then to 0: run MPE on the evidence,
# decode the completed vectors and save the image grid.
# NOTE(review): 0 is assumed to encode "not smiling" — confirm against the
# attribute encoding produced by the dataset.
for smile_value, file_stem in ((1, "/mpe_smiles_"), (0, "/mpe_no_smiles_")):
    z_samples[:, 129] = smile_value
    mpe_reconstruction = einet.mpe(x=z_samples)
    a = EvalModel.decode(mpe_reconstruction)
    vutils.save_image(a, samples_dir + file_stem + model_version + ".png", normalize=True, nrow=12)

In [68]:
####
# male female
###

# Index 128 (the male entry of the code) is kept as evidence; every other
# index is handed to the network as marginalized.
keep_idx = [128]
scope = [var for var in np.array(range(130)) if var != 128]
marginalize_idx = scope
einet.set_marginalization_idx(marginalize_idx)

z_samples = torch.normal(0, 1, size=(25, 130)).to(device)

# First with the male entry set to 1, then to 0: run MPE on the evidence,
# decode the completed vectors and save the image grid.
# NOTE(review): 0 is assumed to encode "female" — confirm against the
# attribute encoding produced by the dataset.
for male_value, file_stem in ((1, "/mpe_male_"), (0, "/mpe_female_")):
    z_samples[:, 128] = male_value
    mpe_reconstruction = einet.mpe(x=z_samples)
    a = EvalModel.decode(mpe_reconstruction)
    vutils.save_image(a, samples_dir + file_stem + model_version + ".png", normalize=True, nrow=12)

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 129]

In [65]:
z_samples = torch.normal(0, 1, size=(25, 130)).to(device)

idx = 128 ## male
# Keep only `idx` as evidence; all remaining indices are marginalized.
keep_idx = [idx]
scope = [var for var in np.array(range(130)) if var != idx]
marginalize_idx = scope
einet.set_marginalization_idx(marginalize_idx)

# For each value of the male entry (1, then 0): draw `num_samples` conditional
# samples, average them element-wise, decode the averages and save the grid.
for evidence_value, file_stem in ((1, "/sample_male_"), (0, "/sample_female_")):
    z_samples[:, idx] = evidence_value ## true or false
    num_samples = 10

    # The first draw starts the running sum; the rest are added in place.
    samples = einet.sample(x=z_samples)
    for k in range(1, num_samples):
        samples += einet.sample(x=z_samples)

    ## taking the average
    samples /= num_samples
    samples = samples.squeeze()
    print(samples.shape)

    ## How to reconstruct an image
    a = EvalModel.decode(samples)
    vutils.save_image(a, samples_dir + file_stem + model_version + ".png", normalize=True, nrow=12)

torch.Size([25, 130])
torch.Size([25, 130])


In [66]:
z_samples = torch.normal(0, 1, size=(25, 130)).to(device)

idx = 129 ## smiles
# Keep only `idx` as evidence; all remaining indices are marginalized.
keep_idx = [idx]
scope = [var for var in np.array(range(130)) if var != idx]
marginalize_idx = scope
einet.set_marginalization_idx(marginalize_idx)

# For each value of the smile entry (1, then 0): draw `num_samples` conditional
# samples, average them element-wise, decode the averages and save the grid.
for evidence_value, file_stem in ((1, "/sample_smiles_"), (0, "/sample_no_smiles_")):
    z_samples[:, idx] = evidence_value ## true or false
    num_samples = 10

    # The first draw starts the running sum; the rest are added in place.
    samples = einet.sample(x=z_samples)
    for k in range(1, num_samples):
        samples += einet.sample(x=z_samples)

    ## taking the average
    samples /= num_samples
    samples = samples.squeeze()
    print(samples.shape)

    ## How to reconstruct an image
    a = EvalModel.decode(samples)
    vutils.save_image(a, samples_dir + file_stem + model_version + ".png", normalize=True, nrow=12)

torch.Size([25, 130])
torch.Size([25, 130])
