In [None]:
# %% Deep learning - Section 16.155
#    AEs for occlusion

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course, it
# is singularly good.


In [142]:
# %% Libraries and modules
import numpy               as np
import matplotlib.pyplot   as plt
import torch
import torch.nn            as nn
import seaborn             as sns
import copy
import torch.nn.functional as F
import pandas              as pd
import scipy.stats         as stats
import sklearn.metrics     as skm
import time
import sys

from torch.utils.data                 import DataLoader,TensorDataset
from sklearn.model_selection          import train_test_split
from google.colab                     import files
from torchsummary                     import summary
from scipy.stats                      import zscore
from IPython                          import display
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')
plt.style.use('default')


In [124]:
# %% Data

# Load data. The path is handed straight to np.loadtxt so that numpy opens and
# closes the file itself (a bare open(...) handle would never be closed)
data = np.loadtxt('sample_data/mnist_train_small.csv',delimiter=',')

# Split labels (first column) from pixel data (all remaining columns)
labels = data[:,0]
data   = data[:,1:]

# Normalise data by the largest pixel value (original range is (0,255))
data_norm = data / np.max(data)

# Convert to tensors (float pixels, integer labels)
data_tensor   = torch.tensor(data_norm).float()
labels_tensor = torch.tensor(labels).long()


In [None]:
# %% Demonstrate occlusion

# One example image, reshaped from a flat pixel vector to 28x28
img = data_tensor[12345,:].view(28,28)

# Independent copy in which rows 10-12 are overwritten with 1 (the occluding bar)
occluded_img = copy.deepcopy(img)
occluded_img[10:13,:] = 1

# Two panels side by side on a golden-ratio sized figure
phi = (1 + np.sqrt(5)) / 2
fig,ax = plt.subplots(1,2,figsize=(phi*5,5))

for panel,picture,caption in zip(ax,(img,occluded_img),('Original image','Occluded image')):
    panel.imshow(picture,cmap='gray')
    panel.set_title(caption)
    panel.axis('off')

# Save to disk, render inline, then offer the file for download through Colab
plt.savefig('figure29_autoencoders_occlusion.png')
plt.show()
files.download('figure29_autoencoders_occlusion.png')


In [126]:
# %% Model class

# No need to create train and test datasets!

def gen_model():
    """Build a fresh autoencoder together with its loss function and optimizer.

    Returns
    -------
    tuple
        (ANN, loss_fun, optimizer): an `mnist_AE` instance made of four fully
        connected layers (784 -> 128 -> 50 -> 128 -> 784), an `nn.MSELoss`,
        and an Adam optimizer (lr=0.001) over the model parameters.
    """

    class mnist_AE(nn.Module):
        """Fully connected autoencoder with a 50-unit bottleneck."""

        def __init__(self):
            super().__init__()

            # Architecture (layers are created in encoder -> decoder order)
            self.input  = nn.Linear(784,128)
            self.encode = nn.Linear(128, 50)
            self.mid    = nn.Linear( 50,128)
            self.decode = nn.Linear(128,784)

        def forward(self,x):
            """Forward propagation: ReLU on the first three layers, sigmoid on the output."""

            for hidden in (self.input,self.encode,self.mid):
                x = F.relu(hidden(x))

            # Sigmoid scales the reconstruction between 0 and 1
            return torch.sigmoid(self.decode(x))

    # Model instance, loss function and optimizer, returned in that order
    ANN = mnist_AE()

    return ANN, nn.MSELoss(), torch.optim.Adam(ANN.parameters(),lr=0.001)


In [127]:
# %% Function to train the model

def train_model(ANN,loss_fun,optimizer,num_epochs=20,batch_size=32,data=None):
    """Train an autoencoder to reproduce its own input.

    Parameters
    ----------
    ANN : nn.Module
        Model to train; it is updated in place.
    loss_fun : callable
        Loss taking (reconstruction, input).
    optimizer : torch.optim.Optimizer
        Optimizer bound to the parameters of `ANN`.
    num_epochs : int, optional
        Number of passes over the data (default 20).
    batch_size : int, optional
        Number of samples per mini-batch (default 32).
    data : torch.Tensor, optional
        2-D tensor (samples x features) to train on. When omitted, the
        notebook-level `data_tensor` is used.

    Returns
    -------
    tuple
        (losses, ANN): one sample-weighted mean loss per epoch, and the
        trained model (same object that was passed in).

    Raises
    ------
    ValueError
        If the dataset contains no samples.
    """

    # Fall back to the notebook-level dataset when none is given
    if data is None:
        data = data_tensor

    # Parameters, initialise vars
    n_samples = data.shape[0]
    losses    = []

    if n_samples == 0:
        raise ValueError('train_model needs at least one sample to train on')

    # Loop over epochs
    for _ in range(num_epochs):

        batch_losses = []
        batch_sizes  = []

        # Reshuffle the sample order; every sample is visited once per epoch
        rand_idx = np.random.permutation(n_samples).astype(int)

        for i in range(0,n_samples,batch_size):

            # Pick the next mini-batch of shuffled samples
            sample = rand_idx[i:i+batch_size]
            X      = data[sample,:]

            # Forward propagation and loss (pass data themselves to loss_fun)
            yHat = ANN(X)
            loss = loss_fun(yHat,X)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Batch mean loss and actual batch size (last batch might be smaller)
            batch_losses.append(loss.item())
            batch_sizes.append(X.shape[0])

        # Current epoch loss, weighting each batch by its number of samples
        losses.append(np.average(batch_losses,weights=batch_sizes))

    return losses,ANN


In [128]:
# %% Train and fit

# Fix the seeds so weight initialisation (torch) and the numpy-driven batch
# shuffling / occlusion positions are the same on every Restart & Run All
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build a fresh model and train it; `losses` holds one value per epoch
ANN,loss_fun,optimizer = gen_model()
losses,ANN             = train_model(ANN,loss_fun,optimizer)


In [None]:
# %% Plotting

# Loss curve across epochs on a golden-ratio sized figure
phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))

fig.gca().plot(losses,'-')
fig.gca().set(xlabel='Epochs',ylabel='Model loss',title='Model loss over epochs')

# Save to disk, render inline, then offer the file for download through Colab
fig.savefig('figure30_autoencoders_occlusion.png')
plt.show()
files.download('figure30_autoencoders_occlusion.png')


In [None]:
# %% Occlude some images

# Get some images (independent copy, so data_tensor itself is not modified)
X      = copy.deepcopy( data_tensor[:10,:] )
n_imgs = X.shape[0]

# Reshape and occlude random rows or cols (if even, horizontal occlusion; if
# odd, vertical). `img` is a view into X, so writing to it edits X in place
for i in range(n_imgs):

    img       = X[i,:].view(28,28)
    start_loc = np.random.choice(range(10,21))

    if i%2==0:
        img[start_loc:start_loc+1,:] = 1
    else:
        img[:,start_loc:start_loc+1] = 1

# Pass occluded data to trained model (inference only, so no autograd graph
# needs to be recorded)
with torch.no_grad():
    reconstructed = ANN(X)

# Plotting: one column per image; rows are original, occluded, reconstructed.
# squeeze=False keeps `axs` two-dimensional whatever the number of images
phi = (1 + np.sqrt(5)) / 2
fig,axs = plt.subplots(3,n_imgs,figsize=(1.5*phi*5,5),squeeze=False)

for i in range(n_imgs):
    for row,images in enumerate((data_tensor,X,reconstructed)):
        axs[row,i].imshow(images[i,:].view(28,28),cmap='gray')
        axs[row,i].set_xticks([])
        axs[row,i].set_yticks([])

plt.suptitle('Original, occluded and reconstructed images')

plt.savefig('figure31_autoencoders_occlusion.png')
plt.show()
files.download('figure31_autoencoders_occlusion.png')


In [None]:
# Quantify the performance (correlate reconstructed with original)

# 2x2 correlation matrix between the pixels of original image 9 and the
# model output for its occluded version; the off-diagonal entry is r
corr = np.corrcoef(data_tensor[9,:].detach(),reconstructed[9,:].detach())

# Plotting: one dot per pixel
phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))
fig.gca().plot(data_tensor[9,:].detach(),reconstructed[9,:].detach(),'.',markersize=8)
fig.gca().set(
    xlabel='Original pixel values',
    ylabel='Reconstructed pixel values',
    title=f'Correlation r={corr[0,1] :.3f}' )

fig.savefig('figure32_autoencoders_occlusion.png')
plt.show()
files.download('figure32_autoencoders_occlusion.png')


In [None]:
# Quantify the performance (correlate reconstructed with original, exclude zero pixels)

# Variables for convenience
orig,recon = data_tensor[9,:].detach(),reconstructed[9,:].detach()

# Boolean mask: pixels above the tolerance in BOTH images
tol      = 1e-4
non_zero = torch.logical_and(orig>tol,recon>tol)

# Recompute correlation on the masked pixels only
corr_no_zero = np.corrcoef(orig[non_zero],recon[non_zero])

# Plotting: one dot per retained pixel
phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))
fig.gca().plot(orig[non_zero],recon[non_zero],'.',markersize=8)
fig.gca().set(
    xlabel='Original pixel values',
    ylabel='Reconstructed pixel values',
    title=f'Correlation r={corr_no_zero[0,1] :.3f}' )

fig.savefig('figure33_autoencoders_occlusion.png')
plt.show()
files.download('figure33_autoencoders_occlusion.png')


In [None]:
# %% Test reconstructed occluded data against non-occluded reconstructed data (more 'fair')

# Number of images is taken from the occluded reconstructions, so the slice,
# the result array and the loop below always agree
n_imgs = reconstructed.shape[0]

# Reconstructed with no occlusion (inference only, so no autograd graph)
with torch.no_grad():
    no_occlusion = ANN(data_tensor[:n_imgs,:])

# Compare reconstructed with and without occlusion
# (column 0: no occlusion, column 1: occlusion)
tol = 1e-4
r   = np.zeros((n_imgs,2))
for i in range(n_imgs):

    # Keep only pixels above tolerance in the original AND both reconstructions
    non_zero = (data_tensor[i,:]>tol) & (no_occlusion[i,:]>tol) & (reconstructed[i,:]>tol)

    # .detach() is kept because `reconstructed` comes from another cell and may
    # still be attached to an autograd graph
    r[i,0] = np.corrcoef(data_tensor[i,non_zero].detach(),no_occlusion[i,non_zero].detach())[0,1]
    r[i,1] = np.corrcoef(data_tensor[i,non_zero].detach(),reconstructed[i,non_zero].detach())[0,1]


# plot the correlation coefficients
phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))

plt.plot(r,'o-',markersize=8)
plt.legend(['No occlusion','Occlusion'])
plt.xlabel('Sample number')
plt.ylabel('Correlation with original')
plt.title('Correlation with original (occluded and non-occluded images)')

plt.savefig('figure34_autoencoders_occlusion.png')
plt.show()
files.download('figure34_autoencoders_occlusion.png')


In [None]:
# %% Exercise 1
#    Does occlusion affect some numbers more than others? Run the entire dataset through the autoencoder with occluded
#    images. Compute the image correlations for each sample. Then compute the average correlation for each number (image
#    label). Show the results in a plot. (Bonus: Also compute the standard deviation across correlations and use those
#    to draw error bars.) What do the results tell you about the difficulty of fixing occlusions in images?

# Indeed the occlusion seems to be worse for some numbers rather than others

# Occlude all images: a one-pixel-wide bar of ones at a random position in
# 10..20, along a row for even samples and along a column for odd samples
data_tensor_occluded = copy.deepcopy( data_tensor )
for i in range(data_tensor_occluded.shape[0]):

    img       = data_tensor_occluded[i,:].view(28,28)
    start_loc = np.random.choice(range(10,21))

    if i%2==0:
        img[start_loc:start_loc+1,:] = 1
    else:
        img[:,start_loc:start_loc+1] = 1

# Pass to model (inference only, so no autograd graph is recorded for the
# whole dataset)
with torch.no_grad():
    reconstructed_all = ANN(data_tensor_occluded)

# Correlations between each original image and the reconstruction of its
# occluded version, using only pixels above tolerance in both
tol = 1e-4
r   = np.zeros((reconstructed_all.shape[0]))
for i in range(reconstructed_all.shape[0]):

    non_zero = (data_tensor[i,:]>tol) & (reconstructed_all[i,:]>tol)

    r[i] = np.corrcoef(data_tensor[i,non_zero],reconstructed_all[i,non_zero])[0,1]

print(f"Warning: {np.isnan(r).sum()} correlations are nans.")

# Averages by digit (a few rs might be nans, possibly because no variance in
# either vector used to compute the correlation), hence the nan-aware stats
mean_corr = np.zeros(10)
std_corr  = np.zeros(10)

for digit in range(10):

    digit_corr       = r[labels == digit]
    mean_corr[digit] = np.nanmean(digit_corr)
    std_corr[digit]  = np.nanstd(digit_corr)

# Plotting (one colour per digit)
cmap = plt.cm.plasma(np.linspace(0.2,0.9,len(mean_corr)))

phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))

for digit in range(10):

    plt.errorbar(
        digit,
        mean_corr[digit],
        yerr=std_corr[digit],
        fmt='o',
        color=cmap[digit],
        capsize=5 )

plt.xticks(range(10))
plt.xlabel('Digit')
# The label previously contained a mis-encoded plus-minus sign
plt.ylabel('Correlation (mean ± std)')
plt.title('Reconstruction correlation by digit\n(occluded data)')
plt.grid(alpha=0.3)

plt.savefig('figure35_autoencoders_occlusion_extra1.png')
plt.show()
files.download('figure35_autoencoders_occlusion_extra1.png')


In [None]:
# %% Exercise 1
#    Continue ...

import statsmodels.stats.multicomp as mc

# Remove nans, and also |r| == 1 (arctanh maps these to +/-inf, which would
# turn the ANOVA statistics into nan)
mask         = np.isfinite(r) & (np.abs(r) < 1)
r_clean      = r[mask]
labels_clean = labels[mask]

# Fisher transform to spread out tails (corrs are bounded [-1,1])
z = np.arctanh(r_clean)

# Groups (one array of transformed correlations per digit)
groups = [z[labels_clean == d] for d in range(10)]

# Assume normality and homoscedasticity

# One-way ANOVA
f,p = stats.f_oneway(*groups)
print(f"ANOVA F = {f}, p = {p:.5f}")
print()

# Tukey post-hoc
z_all        = np.concatenate(groups)
digit_labels = np.concatenate([[d]*len(groups[d]) for d in range(10)])

comp  = mc.MultiComparison(z_all,digit_labels)
tukey = comp.tukeyhsd()

print(tukey)

# Plotting: arrange the pairwise mean differences in an antisymmetric matrix.
# NOTE(review): like the nested i<j loop this replaces, this assumes that
# tukey.meandiffs is ordered like np.triu_indices(n,1) -- confirm against the
# statsmodels documentation
md           = tukey.meandiffs
tukey_groups = tukey.groupsunique
n_groups     = len(tukey_groups)
matrix       = np.zeros((n_groups,n_groups))

upper_i,upper_j = np.triu_indices(n_groups,k=1)
matrix[upper_i,upper_j] =  md
matrix[upper_j,upper_i] = -md

phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))

sns.heatmap(matrix,annot=True,fmt=".2f",cmap="plasma",center=0,xticklabels=tukey_groups,yticklabels=tukey_groups)
plt.title("Tukey post-hoc\n(Mean differences in avg correlations between digits)")

plt.savefig('figure36_autoencoders_occlusion_extra1.png')
plt.show()
files.download('figure36_autoencoders_occlusion_extra1.png')


In [None]:
# %% Exercise 2
#    Perhaps a correlation coefficient isn't really the best performance metric. Try this: Binarize the images like we
#    did in the video "CodeChallenge: Binarized MNIST images" (section FFN). Then compute the number of pixels in the
#    original and reconstructed images that overlap (hint: try summing them). Make sure your new metric has a possible
#    range of 0 (absolutely no overlap) to 1 (perfect overlap). Does this metric seem more consistent with your visual
#    intuition?

# Seems very similar in pattern to the correlation approach

# Occlude all images: a one-pixel-wide bar of ones at a random position in
# 10..20, along a row for even samples and along a column for odd samples
data_tensor_occluded = copy.deepcopy( data_tensor )
for i in range(data_tensor_occluded.shape[0]):

    img       = data_tensor_occluded[i,:].view(28,28)
    start_loc = np.random.choice(range(10,21))

    if i%2==0:
        img[start_loc:start_loc+1,:] = 1
    else:
        img[:,start_loc:start_loc+1] = 1

# Pass to model (inference only, so no autograd graph is recorded for the
# whole dataset)
with torch.no_grad():
    reconstructed_all_obstructed = ANN(data_tensor_occluded)

# Binarise (1 where the pixel exceeds the split value, 0 otherwise)
split_val         = 0.5
reconstructed_bin = np.where(reconstructed_all_obstructed.numpy()>split_val, 1,0)
original_bin      = np.where(data_tensor.numpy()>split_val, 1,0)

# Compute overlap (pixels 1 in both / pixels 1 in either; i.e., intersection / union),
# for all images at once along the pixel axis
intersection = np.sum(original_bin & reconstructed_bin,axis=1)
union        = np.sum(original_bin | reconstructed_bin,axis=1)

# Two all-zero images count as perfect overlap (also avoids dividing by zero)
overlap             = np.ones(original_bin.shape[0])
has_pixels          = union > 0
overlap[has_pixels] = intersection[has_pixels] / union[has_pixels]

# Average per digit
mean_overlap = np.zeros(10)
std_overlap  = np.zeros(10)

for d in range(10):
    digit_idx       = labels == d
    mean_overlap[d] = np.mean(overlap[digit_idx])
    std_overlap[d]  = np.std(overlap[digit_idx])

# Plotting (colours are defined here so this cell does not rely on a palette
# left behind by another cell)
cmap = plt.cm.plasma(np.linspace(0.2,0.9,len(mean_overlap)))

phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))

for digit in range(10):

    plt.errorbar(
        digit,
        mean_overlap[digit],
        yerr=std_overlap[digit],
        fmt='o',
        color=cmap[digit],
        capsize=5 )

plt.xticks(range(10))
plt.xlabel('Digit')
plt.ylabel('Pixel overlap')
plt.title('Binarized overlap between reconstructed original and obstructed images')
plt.grid(alpha=0.3)

plt.savefig('figure37_autoencoders_occlusion_extra2.png')
plt.show()
files.download('figure37_autoencoders_occlusion_extra2.png')


In [None]:
# %% Exercise 3
#    But wait a minute, don't we already have a quantitative measure of the similarity between the AE input and output?
#    Of course we do -- it's the loss function! Mean-squared error already accounts for zeros because those get ignored
#    [zero-valued pixels have MSE=(0-0)**2 ]. In fact, question #2 is kindof a "rough MSE." Take a moment to write down
#    the formulas for MSE and correlation, and see whether they are related (hint: the relationship isn't linear because of
#    the squared term). Finally, compute MSE on our example occlusion images and compare MSE to correlation empirically
#    by making a scatter plot. (Hint 1: Use more than 10 examples to see trends. Hint 2: Consider the signs (+/-).)

# Correlation and MSE are closely related: they both measure similarity between
# vectors. MSE is built from squared differences and therefore depends on the
# scale of the data, while correlation is a scale-free measure of similarity
# (normalised covariance)

# Get some images; `originals` stays untouched, `subsample` is an independent
# copy that gets occluded in place
originals = data_tensor[:400,:]
subsample = copy.deepcopy( originals )

for i in range(subsample.shape[0]):

    img       = subsample[i,:].view(28,28)
    start_loc = np.random.choice(range(10,21))

    if i%2==0:
        img[start_loc:start_loc+1,:] = 1
    else:
        img[:,start_loc:start_loc+1] = 1

# Pass occluded data to trained model (inference only, so no autograd graph)
with torch.no_grad():
    reconstructed_occluded = ANN(subsample)

# Correlations. `orig` is the NON-occluded data, as in the other analyses of
# this notebook; taking it from `subsample` would compare the reconstruction
# with its own occluded input
orig  = originals.numpy()
recon = reconstructed_occluded.numpy()

tol = 1e-4
r   = np.zeros((recon.shape[0]))
for i in range(recon.shape[0]):

    # Keep only pixels above tolerance in both images
    non_zero = (orig[i,:]>tol) & (recon[i,:]>tol)

    r[i] = np.corrcoef(orig[i,non_zero],recon[i,non_zero])[0,1]

print(f"Warning: {np.isnan(r).sum()} correlations are nans.")

# MSE (mean over all pixels of each image)
mse = np.mean((orig - recon)**2, axis=1)


In [None]:
# %% Exercise 3
#    Continue ...

# One dot per image: correlation on the x axis, MSE on the y axis
phi = (1 + np.sqrt(5)) / 2
fig = plt.figure(figsize=(phi*5,5))

fig.gca().scatter(r,mse,c='tab:blue',alpha=0.6)
fig.gca().set(
    xlabel="Pixel-wise correlation",
    ylabel="Pixel-wise MSE",
    title="Reconstruction quality\n(correlation vs MSE)" )
fig.gca().grid(True)

fig.savefig('figure38_autoencoders_occlusion_extra3.png')
plt.show()
files.download('figure38_autoencoders_occlusion_extra3.png')
