In [1]:
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

import cv2
from PIL import Image as im
from sklearn.metrics import jaccard_score

import collections
from typing import DefaultDict, Tuple, List, Dict
from functools import partial

In [112]:
import sys
# Make the locally cloned early-stopping-pytorch repository importable.
# The path is assembled with os.path.join instead of a literal containing
# backslashes: '\.' and '\e' are invalid escape sequences, and a hard-coded
# backslash path only resolves on Windows.
sys.path.append(os.path.join('..', '..', 'early-stopping-pytorch'))
from pytorchtools import EarlyStopping

# Adjust printing view dimensions: never truncate arrays/tensors and allow wide lines
np.set_printoptions(threshold=sys.maxsize, linewidth=300)
torch.set_printoptions(threshold=sys.maxsize, linewidth=300, profile='full')

In [3]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder whose latent layer is a separate sub-module.

    Keyword arguments (all required):
        input_shape:  number of input (and output) features
        n_units:      width of the hidden layer on each side of the latent layer
        latent_units: width of the latent (bottleneck) layer
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.input_shape = kwargs["input_shape"]
        self.n_units = kwargs["n_units"]
        self.latent_units = kwargs["latent_units"]

        n_in, n_hidden, n_latent = self.input_shape, self.n_units, self.latent_units

        # input -> hidden
        self.encoder = nn.Sequential(nn.Linear(n_in, n_hidden), nn.ReLU())
        # hidden -> latent; conceptually part of the encoder, but kept apart so
        # its output can be used on its own for the sparsity penalty
        self.bottleneck = nn.Sequential(nn.Linear(n_hidden, n_latent), nn.ReLU())
        # latent -> hidden -> input, squashed into [0, 1] by the sigmoid
        self.decoder = nn.Sequential(
            nn.Linear(n_latent, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_in),
            nn.Sigmoid(),
        )

    def forward(self, X):
        """Return ``(latent, reconstruction)`` for the feature batch ``X``."""
        latent = self.bottleneck(self.encoder(X))
        return latent, self.decoder(latent)

In [198]:
# Load the dataset
# data_count samples, each a 30x30 matrix stored in its own .dat file
is_pca = False
data_count = 20
data = np.ndarray(shape=(data_count, 30, 30))
n_features = data.shape[1] * data.shape[2]

# Sample files are numbered from 1, array rows from 0
for i in range(data_count):
    data[i] = np.loadtxt(f'data/jet_matrices/sample{i+1}.dat', unpack = False)

print("Done loading data.")

Done loading data.


In [199]:
# Load parameters corresponding to the 4 variables input into
# the Helmholtz Resonator function, where output is each sample in dataset.
params = np.ndarray(shape=(data_count,4))

path = r'data/param_lhs.dat'
with open(path) as f:
    # One comma-separated row per line; zip() stops once every row of
    # `params` has been filled, so surplus lines in the file are ignored.
    for i, line in zip(range(params.shape[0]), f):
        params[i] = np.fromstring(line, dtype=float, sep=',')

print("Done loading parameters.")

Done loading parameters.


In [200]:
# Flatten data and convert to Torch Tensor

# data_count samples, n_features features each.
# reshape() unrolls every matrix row by row, the same order ndarray.flatten() uses.
X = data.reshape(data_count, n_features)

# Convert from numpy array to Pytorch tensor and cast all scalars to float32.
# May affect training behavior (ie. reconstructions made of non-binary scalar values)
X = torch.from_numpy(X).float()

In [201]:
# Pair data samples with their corresponding parameter
# in order to keep organized during random splitting.
# Each entry is a two-element list: [flattened sample, parameter row].
X_with_params = [[X[i], params[i]] for i in range(data_count)]

# PCA

In [None]:
def de_correlate_data(X):
    """Return a copy of ``X`` in which every column is shuffled independently.

    Each feature keeps its own set of values, but the row-wise pairing
    between features is broken. ``X`` itself is not modified. Shuffling uses
    NumPy's global random state.
    """
    X_pert = np.copy(X)
    for col_idx, column in enumerate(X.T):
        X_pert[:, col_idx] = np.random.permutation(column)
    return X_pert

In [None]:
# Plot cumulative explained variance w.r.t. number of components

def pca_run(X):
    """Fit a PCA that keeps 95% of the variance and plot the cumulative curve.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features) accepted by
        ``sklearn.decomposition.PCA.fit``.

    Side effects
    ------------
    Sets the global matplotlib figure size to (12, 6) and shows one figure.
    Returns nothing.
    """
    pca = PCA(n_components=0.95).fit(X)

    plt.rcParams["figure.figsize"] = (12,6)

    fig, ax = plt.subplots()
    y = np.cumsum(pca.explained_variance_ratio_)
    # n_components = number of components needed to reach cum. variance threshold
    n_components = y.size
    # 1-based component numbers for the x axis
    xi = np.arange(1, n_components+1, step=1)

    ax.set_ylim(0.0, 1.1)
    ax.plot(xi, y, marker='o', linestyle='--', color='b')

    ax.set_xlabel('Number of Components')
    ax.set_xticks(np.arange(0, n_components+1, step=1))
    ax.set_ylabel('Cumulative variance (%)')
    ax.set_title('The Number of Components Needed to Explain Variance')

    # (threshold, label height, label x as a fraction of n_components, line colour, text colour)
    # The label x positions scale with the number of components; fixed
    # positions such as x=24 fall outside the axes when few components are kept.
    thresholds = (
        (0.95, 0.915, 0.0, 'r', 'red'),
        (0.9, 0.85, 0.69, 'b', 'blue'),
        (0.8, 0.75, 0.34, 'g', 'green'),
    )
    for level, label_y, x_fraction, line_color, text_color in thresholds:
        ax.axhline(y=level, color=line_color, linestyle='-')
        ax.text(x_fraction * n_components, label_y, f'{level:.0%} cut-off threshold',
                color=text_color, fontsize=13)

    ax.grid(axis='x')
    plt.show()

# Run with original data.
pca_run(X.numpy())

# Run with permuted data.
# Shuffling each feature column independently de-correlates the features, so
# performing worse than the original data indicates existence of correlation
# in the original data's features.
X_pert = de_correlate_data(X)
pca_run(X_pert)

In [None]:
# Fit the PCA in this cell: the `pca` and `xi` used by pca_run() are locals of
# that function, so relying on them here fails on a fresh kernel.
pca_95 = PCA(n_components=0.95).fit(X.numpy())
# 1-based component numbers for the x axis
xi = np.arange(1, pca_95.explained_variance_ratio_.size + 1, step=1)

plt.rcParams["figure.figsize"] = (12,6)
fig, ax = plt.subplots()
ax.bar(xi, pca_95.explained_variance_ratio_, width=0.4)
ax.set_ylabel("Percent of Total Variance")
ax.set_xlabel("Principal Component")
ax.set_title("Significance of Each Principal Component Towards Variance ")
plt.show()

In [None]:
# PCA

# Toggle to indicate to training that PCA is in use
is_pca = True
# -- DEFINE NUMBER OF COMPONENTS HERE --
n_components = 5

# Fit and project in a single pass on the numpy view of X. Fitting once with
# fit() and again with fit_transform() repeated the work, and handing the
# torch tensor itself to scikit-learn is what made re-running earlier cells necessary.
pca = PCA(n_components=n_components)
X_pca = pca.fit_transform(X.numpy())
X_pca = torch.from_numpy(X_pca)
# Convert all scalars to floats. May affect training behavior (ie. reconstructions made of non-binary scalar values)
X_pca = X_pca.float()
# Replace former n_features with number of components.
# NOTE: this overwrites the notebook-level n_features used by the model and training cells.
n_features = X_pca.shape[1]

# Summarise instead of dumping the full data matrix
print(f"PCA input shape: {tuple(X.shape)}, projected shape: {tuple(X_pca.shape)}")

# Training & Validation

In [203]:
# Hyperparameters

# Feature matrix to train on: PCA projections when PCA is toggled on, raw features otherwise
X_2 = X_pca if is_pca else X

batch_size = 32

# 70/15/15 split
train_size = int(0.7 * len(X_2))
# Everything that is not training data is shared between validation and test
val_test_size = len(X_2) - train_size
# Test gets the floor half; validation takes whatever is left over
test_size = val_test_size // 2
val_size = val_test_size - test_size

In [204]:
# Initiate data loaders

# Pair the *selected* feature matrix (X_2: PCA projections or raw features)
# with the parameters. Splitting X_with_params instead would always feed the
# raw features to the loaders, even when PCA is enabled.
paired_samples = [[X_2[i], params[i]] for i in range(len(X_2))]

# Fixed seeds keep the train/validation/test membership reproducible
train, val = torch.utils.data.random_split(paired_samples, [train_size, val_test_size], generator=torch.Generator().manual_seed(5))
val, test = torch.utils.data.random_split(val, [val_size, test_size], generator=torch.Generator().manual_seed(5))

# Pinned host memory only helps when batches are copied to a GPU
use_pinned_memory = torch.cuda.is_available()

# Training uses one sample per step (batch_size=1), not the `batch_size` hyperparameter
train_loader = torch.utils.data.DataLoader(
    train, batch_size=1, shuffle=True, num_workers=0, pin_memory=use_pinned_memory
)

val_loader = torch.utils.data.DataLoader(
    val, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_pinned_memory
)

test_loader = torch.utils.data.DataLoader(
    test, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_pinned_memory
)

# Same as test_loader but with stochastic batch size
test_loader_stoch = torch.utils.data.DataLoader(
    test, batch_size=1, shuffle=False, num_workers=0, pin_memory=use_pinned_memory
)

# Use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [205]:
#############################    
#   TRAINING & VALIDATION   #
#############################

# Custom exception type for out-of-range BCE values; a docstring alone is a
# valid class body, so no `pass` is needed.
class ExceededRangeError(Exception):
    """Raised when values outside range [0.0, 1.0] are found in BCE loss"""

def KL(p, q):
    """Kullback-Leibler divergence "from q to p": ``sum_i p_i * log(p_i / q_i)``.

    Parameters
    ----------
    p, q : torch tensors or numpy arrays holding frequency distributions.
        Either 1-D (``[n]``) or 2-D (``[rows, n]``); for 2-D input only the
        first row is used, which is what the training loop passes in
        (a single ``[1, n]`` row).

    Returns
    -------
    0-dim torch tensor (gradients flow through it when the inputs track them).
    """
    p = torch.as_tensor(p)
    q = torch.as_tensor(q)
    # Keep the historical "first row only" behaviour for batched input
    if p.dim() > 1:
        p = p[0]
    if q.dim() > 1:
        q = q[0]
    # Vectorised; also avoids shadowing the builtin `sum`
    return torch.sum(p * torch.log(torch.div(p, q)))

def demo():
    """Print KL(P,Q) and KL(Q,P) for two small example distributions.

    Side effect: changes NumPy's global print options (precision=4,
    suppress=True).
    """
    print("\nBegin Kullback-Liebler from scratch demo ")
    np.set_printoptions(precision=4, suppress=True)

    p = np.array([9.0/25.0, 12.0/25.0, 4.0/25.0], dtype=np.float32)
    q = np.array([1.0/3.0, 1.0/3.0, 1.0/3.0], dtype=np.float32)

    print("\nThe P distribution is: ")
    print(p)
    print("The Q distribution is: ")
    print(q)

    # KL() indexes its arguments as [0, i] on torch tensors, so the 1-D numpy
    # arrays are wrapped as single-row tensors before the call.
    p_row = torch.from_numpy(p).unsqueeze(0)
    q_row = torch.from_numpy(q).unsqueeze(0)

    kl_pq = float(KL(p_row, q_row))
    kl_qp = float(KL(q_row, p_row))

    print("\nKL(P,Q) = %0.6f " % kl_pq)
    print("KL(Q,P) = %0.6f " % kl_qp)

    print("\nEnd demo ")
    
    
    

class TrainedModel:
    """Plain record bundling a model with the summary numbers of its training run."""

    def __init__(self, model, avg_train_loss, avg_val_loss, epochs):
        # Attributes are stored exactly as given, in the original order
        (self.model,
         self.avg_train_loss,
         self.avg_val_loss,
         self.epochs) = (model, avg_train_loss, avg_val_loss, epochs)

# Training and Validation are combined in order to allow for early stopping
def train_validate(model, epochs, lr, is_early_stopping=False, is_pca=False, is_sparse=False, patience=None, beta=None, rho=None):
    """Train ``model`` and validate it after every epoch.

    Training and validation share one function so early stopping can watch the
    validation loss. Uses the notebook-level names ``train_loader``,
    ``val_loader``, ``n_features`` and ``device``.

    Parameters
    ----------
    model : module whose forward pass returns ``(bottleneck, decoded)``.
        The weights of its ``nn.Linear`` layers are re-initialised here.
    epochs : maximum number of epochs.
    lr : Adam learning rate.
    is_early_stopping : stop once ``EarlyStopping`` (given ``patience``) triggers.
    is_pca : only used to print a notice.
    is_sparse : add ``beta * KL`` sparsity penalty to the training loss.
    patience : passed to ``EarlyStopping`` when early stopping is enabled.
    beta : weight of the sparsity penalty (required when ``is_sparse``).
    rho : ``[1, latent_units]`` tensor of target activations (required when
        ``is_sparse``).

    Returns
    -------
    TrainedModel holding the model (restored from 'checkpoint.pt'), the average
    training and validation losses of the last epoch that ran, and the number
    of epochs actually run.

    Raises
    ------
    RuntimeError : re-raised when the reconstruction loss cannot be computed.
    """
    # Define Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Binary Cross Entropy Loss
    criterion = nn.BCELoss()
    # KLDivLoss expects log-probabilities as input and, with log_target=False,
    # probabilities as target. "batchmean" divides the summed loss by the batch size.
    kl_divergence = nn.KLDivLoss(reduction="batchmean", log_target=False)

    # Reset model state if previously trained.
    # NOTE: only the Linear weights are re-initialised; biases keep their values.
    torch.manual_seed(1)
    def weights_init(m):
        if isinstance(m, torch.nn.Linear):
            nn.init.xavier_uniform_(m.weight.data)
            print("existing instance")

    model.apply(weights_init)

    # Toggle Early Stopping (if using).
    if is_early_stopping:
        early_stopping = EarlyStopping(patience=patience, verbose=True)
        print("Using Early Stopping")
    if is_pca:
        print("Using PCA")

    # Defined up front so the summary below also works when no epoch runs
    opt_epochs = epochs
    avg_train_loss = float("nan")
    avg_val_loss = float("nan")
    ran_any_epoch = False

    print("Training...")
    for epoch in range(epochs):
        ran_any_epoch = True

        #############################
        #          TRAINING         #
        #############################

        # Prepare model for training (previously this switched to eval mode)
        model.train()
        train_losses = []
        for batch in train_loader:
            # remove params, keep data
            batch = batch[0]
            # reshape mini-batch data to [batch_size, n_features]
            # and load it to the active device
            batch = batch.view(-1, n_features).to(device)

            # reset the gradients back to zero
            # PyTorch accumulates gradients on subsequent backward passes
            optimizer.zero_grad()

            # compute reconstructions
            # also retrieve bottleneck output for computing sparsity penalty
            bottleneck, decoded = model(batch)

            try:
                # compute training reconstruction loss
                train_loss = criterion(decoded, batch)
            except RuntimeError:
                print('Runtime Error during loss calculation. BCE loss has values outside range [0.0, 1.0]')
                for k, sample in enumerate(decoded):
                    print(k)
                    print(sample)
                # Without a loss for this batch there is nothing valid to
                # back-propagate; continuing would use an undefined or stale loss.
                raise

            # add sparsity penalty to loss, if toggled
            if is_sparse:
                # total activation of every latent unit over the batch
                rho_hat = torch.sum(bottleneck, dim=0, keepdim=True)
                # Turn activations and targets into distributions over the
                # latent units: log-probabilities for the input, probabilities
                # for the target, as KLDivLoss requires. Feeding log-softmax
                # values into a ratio-based KL takes the log of a negative
                # number and yields NaN.
                rho_hat_log = torch.nn.functional.log_softmax(rho_hat, dim=1)
                rho_dist = torch.nn.functional.softmax(rho, dim=1)
                s = kl_divergence(rho_hat_log, rho_dist)

                sparsity_penalty = beta * s
                train_loss = train_loss + sparsity_penalty
                # Check whether KL divergence is behaving correctly (ie. should be nonnegative).
                if torch.all(sparsity_penalty < 0):
                    print('Error: sparsity penalty is negative.')
                    print(f'rho: {rho}')
                    print(f'rho_hat: {rho_hat}')
                    print(f'sparsity: {sparsity_penalty}')
                    print(f'training loss: {train_loss}')

            # compute accumulated gradients
            train_loss.backward()

            # perform parameter update based on current gradients
            optimizer.step()

            # add the mini-batch training loss to epoch loss
            train_losses.append(train_loss.item())

        # compute the epoch training loss
        avg_train_loss = np.average(train_losses)

        #############################
        #         VALIDATION        #
        #############################

        val_losses = []

        # Prepare model for evaluation
        model.eval()

        # since we're not training, we don't need to calculate the gradients for our outputs
        with torch.no_grad():
            for batch in val_loader:
                # remove params, keep data
                batch = batch[0]
                batch = batch.view(-1, n_features).to(device)
                bottleneck, reconstructions = model(batch)
                # Reconstruction loss
                val_loss = criterion(reconstructions, batch)
                val_losses.append(val_loss.item())

        avg_val_loss = np.average(val_losses)

        # display the epoch training loss and validation loss
        print("Epoch : {}/{}, Training Loss = {:.6f}, Validation Loss = {:.6f}".format(epoch + 1, epochs, avg_train_loss, avg_val_loss))

        if is_early_stopping:
            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                opt_epochs = epoch + 1
                print("Early stopping...")
                # Exit training loop
                break
        else:
            torch.save(model.state_dict(), 'checkpoint.pt')

    # Load the most recently written checkpoint. Skipped when no epoch ran,
    # because any 'checkpoint.pt' on disk would then belong to an earlier run.
    if ran_any_epoch:
        model.load_state_dict(torch.load('checkpoint.pt'))
    print(f"Epochs: {opt_epochs}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}")
    # Record the number of epochs actually run, not the requested maximum
    trained_model = TrainedModel(model, avg_train_loss, avg_val_loss, opt_epochs)

    return trained_model

In [206]:
# BASIC AUTOENCODER Execute training & validating
# The names below are notebook-level globals; later cells reuse basic_model.

# Adam learning rate
lr = 1e-3
# maximum number of training epochs
epochs = 5
# number of hidden units in encoder hidden layer
n_units = 50
# number of hidden units in latent space
latent_units = 4
# Boolean for whether to use Early Stopping
is_early_stopping = False
# early stopping patience; how long to wait after last time validation loss improved.
patience = 20

# Build the model on the active device; n_features is the input width
basic_model = AutoEncoder(input_shape=n_features,
                    n_units=n_units,
                    latent_units=latent_units
                   ).to(device)

# No sparsity arguments are passed, so train_validate uses its defaults for them
basic_trained = train_validate(model=basic_model,
                            epochs=epochs,
                            lr=lr,
                            is_early_stopping=is_early_stopping, 
                            is_pca=is_pca,
                            patience=patience)


existing instance
existing instance
existing instance
existing instance
Training...
Epoch : 1/5, Training Loss = 0.610824, Validation Loss = 0.461002
Epoch : 2/5, Training Loss = 0.254193, Validation Loss = 0.180619
Epoch : 3/5, Training Loss = 0.099973, Validation Loss = 0.167758
Epoch : 4/5, Training Loss = 0.090283, Validation Loss = 0.145618
Epoch : 5/5, Training Loss = 0.085692, Validation Loss = 0.140694
Epochs: 5, Training Loss: 0.08569211486194815, Validation Loss: 0.1406935453414917


In [None]:
# SPARSE AUTOENCODER Execute training & validating
# NOTE: this cell reassigns the notebook-level lr/epochs/n_units/latent_units/
# is_early_stopping/patience set by the basic autoencoder cell.
lr = 1e-3
epochs = 200
n_units = 50
latent_units = 20
is_early_stopping = True
# early stopping patience; how long to wait after last time validation loss improved.
patience = 5

# Enable the sparsity penalty in train_validate
is_sparse = True
# weight of the sparsity penalty in the training loss
beta = 2
# target value, repeated once per latent unit below
rho = 0.1
# Shape [1, latent_units], moved to the same device as the model
rho_tensor = torch.FloatTensor([rho for _ in range(latent_units)]).unsqueeze(0)
rho_tensor = rho_tensor.to(device)

sparse_model = AutoEncoder(input_shape=n_features,
                    n_units=n_units,
                    latent_units=latent_units
                   ).to(device)

trained_sparse = train_validate(model=sparse_model,
                                epochs=epochs,
                                lr=lr,
                                is_early_stopping=is_early_stopping, 
                                is_pca=is_pca,
                                is_sparse=is_sparse,
                                patience=patience,
                                beta=beta,
                                rho=rho_tensor)

# Parameter - Latent Weights Correlation

In [207]:
# Get average latent layer weights:
# one mean per row of the bottleneck Linear layer's weight matrix
avg_weights = [
    torch.mean(neuron)
    for neuron in basic_model.state_dict()['bottleneck.0.weight']
]
avg_weights

[tensor(0.0349, device='cuda:0'),
 tensor(0.0194, device='cuda:0'),
 tensor(0.0032, device='cuda:0'),
 tensor(0.0271, device='cuda:0')]

In [255]:
def save_activations(
        activations: DefaultDict,
        name: str,
        module: nn.Module,
        inp: Tuple,
        out: torch.Tensor
) -> None:
    """Forward hook that records a layer's output under ``name``.

    ``activations`` and ``name`` are meant to be bound in advance (e.g. with
    ``functools.partial``); ``module``, ``inp`` and ``out`` are supplied by
    PyTorch on each forward pass. A detached CPU copy of ``out`` is appended
    to ``activations[name]``; ``module`` and ``inp`` are not used.
    """
    bucket = activations[name]
    bucket.append(out.detach().cpu())
    

def register_activation_hooks(
        model: nn.Module,
        layers_to_save: List[str],
        handles: List = None
) -> DefaultDict[str, List[torch.Tensor]]:
    """Registers forward hooks in specified layers.
    Parameters
    ----------
    model:
        PyTorch model
    layers_to_save:
        Module names within ``model`` whose activations we want to save.
    handles:
        Optional list. When given, the handle returned by each
        ``register_forward_hook`` call is appended to it so the caller can
        remove the hooks again with ``handle.remove()``.

    Returns
    -------
    activations_dict:
        dict mapping each hooked layer name to the list of outputs recorded
        for it, one entry per forward pass.
    """
    activations_dict = collections.defaultdict(list)

    for name, module in model.named_modules():
        if name in layers_to_save:
            handle = module.register_forward_hook(
                partial(save_activations, activations_dict, name)
            )
            if handles is not None:
                handles.append(handle)
    return activations_dict


# Save activations per layer per sample
def get_activations(model):
    """Run the test set through ``model`` one sample at a time and pair the
    hooked layer's output with that sample's parameters.

    Uses the notebook-level ``test_loader_stoch`` and ``device``. The hooked
    layer is 'bottleneck.0', i.e. the bottleneck's first sub-module.

    Returns
    -------
    list of ``[activation, params]`` pairs, one per test sample.
    """
    # Enter which layer to retrieve activations from
    layer_name = 'bottleneck.0'

    # register fwd hook in the specified layer, keeping the handle so the hook
    # can be removed afterwards; otherwise every call leaves another hook on the model
    hook_handles = []
    saved_activations = register_activation_hooks(
        model, layers_to_save=[layer_name], handles=hook_handles
    )
    activations_with_params = []

    try:
        with torch.no_grad():
            # Evaluate one sample at a time
            for i, sample in enumerate(test_loader_stoch, 0):
                # Split the pair: parameters stay on the CPU, data goes to the device
                sample_params = sample[1]
                features = sample[0].to(device)
                # The forward pass triggers the hook; its return value is not needed
                model(features)

                # keep track of which activations correspond with which parameters
                pair = [saved_activations[layer_name][i], sample_params]
                activations_with_params.append(pair)
    finally:
        for handle in hook_handles:
            handle.remove()

    return activations_with_params


def plot_correlations(activations_with_params):
    """Print every activation of each pair, each followed by all of that pair's parameters.

    Each element of ``activations_with_params`` is indexed as
    ``[activations, params]``; only row 0 of either entry is read. Nothing is
    plotted or returned.
    """
    for pair in activations_with_params:
        activations, params = pair[0], pair[1]
        n_units = activations.shape[1]
        for unit in range(n_units):
            print(activations[0][unit])
            for param_idx in range(params.shape[1]):
                print(params[0][param_idx])

In [256]:
# Collect one [activation, params] pair per test sample for the basic model
activations_with_params = get_activations(basic_model)
# print(activations_with_params)

# plot_correlations only prints the values; it does not draw a figure
plot_correlations(activations_with_params)

tensor(10.9300)
tensor(6353.6500, dtype=torch.float64)
tensor(43.9109, dtype=torch.float64)
tensor(0.0196, dtype=torch.float64)
tensor(1156.8900, dtype=torch.float64)
tensor(-4.2650)
tensor(6353.6500, dtype=torch.float64)
tensor(43.9109, dtype=torch.float64)
tensor(0.0196, dtype=torch.float64)
tensor(1156.8900, dtype=torch.float64)
tensor(40.1753)
tensor(6353.6500, dtype=torch.float64)
tensor(43.9109, dtype=torch.float64)
tensor(0.0196, dtype=torch.float64)
tensor(1156.8900, dtype=torch.float64)
tensor(-3.4796)
tensor(6353.6500, dtype=torch.float64)
tensor(43.9109, dtype=torch.float64)
tensor(0.0196, dtype=torch.float64)
tensor(1156.8900, dtype=torch.float64)
tensor(10.2434)
tensor(1228.1500, dtype=torch.float64)
tensor(37.3265, dtype=torch.float64)
tensor(0.0479, dtype=torch.float64)
tensor(1409.3700, dtype=torch.float64)
tensor(-3.7820)
tensor(1228.1500, dtype=torch.float64)
tensor(37.3265, dtype=torch.float64)
tensor(0.0479, dtype=torch.float64)
tensor(1409.3700, dtype=torch.float64

# Testing

In [12]:
def test(model, n_features):
    """Evaluate ``model`` on the notebook-level ``test_loader``.

    Parameters
    ----------
    model : module whose forward pass returns ``(bottleneck, reconstructions)``.
    n_features : number of features per sample; every batch is reshaped to
        ``[batch_size, n_features]``, as in training and validation.

    Returns
    -------
    (batches, recons, test_losses): three parallel lists with, per batch, the
    input tensor, the reconstruction tensor and the BCE loss as a float.
    Also prints each batch loss and the average.
    """
    # Decoupled into three lists due to issue with placing torch tensors into multidimensional lists
    batches = []
    recons = []
    test_losses = []
    criterion = nn.BCELoss()

    # Prepare model for evaluation
    model.eval()

    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for i, batch in enumerate(test_loader, 0):
            # remove params, keep data
            batch = batch[0]
            # reshape to [batch_size, n_features] and move data to device
            batch = batch.view(-1, n_features).to(device)
            bottleneck, reconstructions = model(batch)
            # Reconstruction loss
            test_loss = criterion(reconstructions, batch)
            # Store samples, predictions, and loss for visualization purposes
            batches.append(batch)
            recons.append(reconstructions)
            test_losses.append(test_loss.item())
            print(f'Batch {i}: {test_loss.item()}')

    avg_test_loss = np.average(test_losses)
    print(f"Average Test Reconstruction Loss: {avg_test_loss}")

    return batches, recons, test_losses

In [66]:
# BASIC MODEL TEST
# Keeps the per-batch inputs, reconstructions and losses for the visualization cell
batches, recons, test_losses = test(basic_model, n_features)

Batch 0: 0.009814923629164696
Batch 1: 0.01434624195098877
Batch 2: 0.011674277484416962
Batch 3: 0.008035548962652683
Batch 4: 0.011244864203035831
Batch 5: 0.01551626343280077
Batch 6: 0.010843966156244278
Batch 7: 0.009246154688298702
Batch 8: 0.009967377409338951
Batch 9: 0.012493015266954899
Batch 10: 0.007324742618948221
Batch 11: 0.00857092346996069
Batch 12: 0.010051090270280838
Batch 13: 0.009157484397292137
Batch 14: 0.010780120268464088
Batch 15: 0.009737378917634487
Batch 16: 0.011055518873035908
Batch 17: 0.012654055841267109
Batch 18: 0.010263180360198021
Batch 19: 0.0109745804220438
Batch 20: 0.010743393562734127
Batch 21: 0.010629914700984955
Batch 22: 0.011865353211760521
Batch 23: 0.01139965932816267
Batch 24: 0.008230835199356079
Batch 25: 0.013583836145699024
Batch 26: 0.011747307144105434
Batch 27: 0.008694911375641823
Batch 28: 0.011957325972616673
Batch 29: 0.011084092780947685
Batch 30: 0.013173844665288925
Batch 31: 0.008395952172577381
Batch 32: 0.010515286587

In [None]:
# SPARSE MODEL TEST
# NOTE: overwrites batches/recons/test_losses from the basic model test
batches, recons, test_losses = test(sparse_model, n_features)

# Visualization

In [70]:
# Calculate Intersection over Union
def iou(cur_img, original, recon):
    """Jaccard similarity (intersection over union) of a sample and its reconstruction.

    Parameters
    ----------
    cur_img : unused; kept so existing callers keep working.
    original : CPU tensor of the original sample; values are truncated to int.
    recon : CPU tensor of the reconstruction; values below 0.5 count as 0,
        everything else as 1. The tensor is not modified.

    Returns
    -------
    (score, score_micro): the per-class Jaccard scores (``average=None``) and
    the micro-averaged Jaccard score.
    """
    # Threshold into a new tensor. Writing 0.0/1.0 back into `recon` element
    # by element was slow and overwrote the caller's reconstruction whenever
    # it shared memory with this tensor.
    # `~(recon < 0.5)` keeps NaN in the positive class, as the element loop did.
    recon_binary = ~(recon < 0.5)

    original_flat = original.flatten().numpy().astype(int)
    recon_flat = recon_binary.flatten().numpy().astype(int)

    # Jaccard Scores of positive and negative classes
    score = jaccard_score(original_flat, recon_flat, average=None)
    # Micro-averaged Jaccard Score over both classes
    score_micro = jaccard_score(original_flat, recon_flat, average='micro')
    return score, score_micro
    
#     # convert arrays to grayscale
#     original = np.array(original * 255, dtype = np.uint8)
#     recon = np.array(recon * 255, dtype = np.uint8)
    
#     path1 = f'{path}\original{cur_img}.png'
#     path2 = f'{path}\\recon{cur_img}.png'
    
#     cv2.imwrite(path1, original)
#     cv2.imwrite(path2, recon)
    
#     original_img = cv2.imread(path1, 0)
#     recon_img = cv2.imread(path2, 0)
    
#     intersect = cv2.bitwise_and(original_img, recon_img)
#     union = cv2.bitwise_or(original_img, recon_img)

#     plt.imshow(intersect, cmap='gray', vmin=0, vmax=255)
#     plt.axis('off')
#     title1 = f'original{cur_img}'
#     plt.title(title1)
#     plt.show()
#     title2 = f'recon{cur_img}'

    
# Plot original image alongside its reconstruction
def plot(cur_batch, tot_batches, original, recon, loss):
    """Show an original image and its reconstruction side by side.

    The figure title reports the 1-based batch number, the batch count and
    the batch reconstruction loss. ``plt.gray()`` is called for each panel,
    which also sets the default colormap to gray.
    """
    fig = plt.figure(figsize=(8, 8))
    plt.title("Batch : {}/{}, Batch Reconstruction Loss = {:.6f}".format(cur_batch+1, tot_batches, loss))
    plt.axis('off')

    # (subplot position, image, caption): left panel first, then right
    panels = (
        (1, original, "original"),
        (2, recon, "reconstructed"),
    )
    for position, image, caption in panels:
        fig.add_subplot(1, 2, position)
        plt.imshow(image)
        plt.axis('off')
        plt.title(caption)
        plt.gray()

    plt.show()
    

# Main function for visualization
def visualize(n, batches, recons, test_losses, is_compare=False, is_iou=False):
    """Inspect up to ``n`` test samples and their reconstructions.

    Uses the notebook-level ``is_pca``, ``n_features`` and ``data`` to decide
    how each flat sample is reshaped.

    Parameters
    ----------
    n : maximum number of samples to process.
    batches, recons, test_losses : parallel per-batch lists as returned by ``test``.
    is_compare : print each original and plot it next to its reconstruction.
    is_iou : print the Jaccard scores of each sample/reconstruction pair.

    Returns nothing.
    """
    count = 0
    # Walk the batches that actually exist. Indexing with range(n) raised
    # IndexError whenever there were fewer than n batches, even though n
    # counts samples rather than batches.
    for i, (loss, batch, reconstructions) in enumerate(zip(test_losses, batches, recons)):
        # Iterate through all examples in ith batch
        for j in range(len(batch)):
            # If n samples have been processed, exit
            if count >= n:
                return
            # Back to the 2-D image shape for plotting,
            # or a single row of components if using PCA.
            if is_pca:
                target_shape = (1, n_features)
            else:
                target_shape = (data.shape[1], data.shape[2])
            original = batch[j].reshape(target_shape).cpu()
            recon = reconstructions[j].reshape(target_shape).cpu()

            if is_iou:
                score, score_micro = iou(count, original, recon)
                print(f'Jaccard Similarity (Pos. & Neg.): {score}')
                print(f"Jaccard Similarity (Both avg'd): {score_micro}")
            if is_compare:
                print(original)
                plot(i, len(recons), original, recon, loss)

            count += 1
            
# Report Jaccard scores for the first 5 test samples; plotting is switched off
visualize(n=5, batches=batches, recons=recons, test_losses=test_losses, is_compare=False, is_iou=True)

Jaccard Similarity (Pos. & Neg.): [0.99188641 0.99026764]
Jaccard Similarity (Both avg'd): 0.9911504424778761
Jaccard Similarity (Pos. & Neg.): [0.99788136 0.997669  ]
Jaccard Similarity (Both avg'd): 0.9977802441731409
Jaccard Similarity (Pos. & Neg.): [0.99628942 0.99449036]
Jaccard Similarity (Both avg'd): 0.9955654101995566
Jaccard Similarity (Pos. & Neg.): [0.996      0.99502488]
Jaccard Similarity (Both avg'd): 0.9955654101995566
Jaccard Similarity (Pos. & Neg.): [0.9916388  0.98371336]
Jaccard Similarity (Both avg'd): 0.988950276243094
