In [None]:
""" Experiment on MNIST data - 
The task is to reconstruct ground truth pixels using pixels+noise
If prior knowledge of the required topology is provided then we will observe the same data
being predicted as a 0, 9, 8 etc. as different priors are enforced. 
"""

import copy
import gudhi as gd
import matplotlib.pyplot as plt
import numpy as np
import numpy.fft as ft
import os

%matplotlib inline

import torch
import torchvision.datasets as datasets
import torchsummary
from torch import nn, optim
from torch.nn import functional as F
from models import Segmenter_Unet, MNIST_classifier

from topologylayer.nn import LevelSetLayer2D, TopKBarcodeLengths
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Download (if necessary) and load both MNIST splits; no transform is applied,
# so the datasets yield (PIL image, label) pairs.
mnist_root = './MNIST_data'
mnist_trainset, mnist_testset = (
    datasets.MNIST(root=mnist_root, train=is_train, download=True, transform=None)
    for is_train in (True, False)
)
img_dim = 28  # side length (pixels) used for every reshape below
print("Size of training set is {}".format(len(mnist_trainset)))
print("Size of test set is {}".format(len(mnist_testset)))

In [None]:
""" Split training set """

# Materialise the datasets as numpy arrays: images in X_*, integer labels in Y_*.
X_train = np.array([np.array(image) for image, _label in mnist_trainset])
Y_train = np.array([label for _image, label in mnist_trainset])
X_test = np.array([np.array(image) for image, _label in mnist_testset])
Y_test = np.array([label for _image, label in mnist_testset])

# The training set is split into two disjoint parts:
#   - the first N_denoise images train the denoising network
#   - the next N_classifier images train the digit classifier
#     (which is used to measure how well digits are denoised)
N_denoise = 10000
N_classifier = 50000

denoise_slice = slice(0, N_denoise)
classifier_slice = slice(N_denoise, N_denoise + N_classifier)

X_denoise, Y_denoise = X_train[denoise_slice], Y_train[denoise_slice]
X_classifier, Y_classifier = X_train[classifier_slice], Y_train[classifier_slice]

def norm(X):
    """Scale each image so that its maximum pixel value is 1.

    Parameters
    ----------
    X - array of shape (N, H, W) - stack of images

    Returns
    -------
    float64 array of the same shape, each image divided by its own maximum.
    An image whose maximum is 0 is returned as all zeros instead of NaN.
    """
    # np.float was removed in NumPy 1.24; np.float64 is the type it aliased.
    X = np.asarray(X, dtype=np.float64)
    peak = X.max(axis=(1, 2), keepdims=True)
    # Only divide where the per-image maximum is non-zero to avoid 0/0.
    return np.divide(X, peak, out=np.zeros_like(X), where=peak != 0)

# Rescale every image by its own maximum before it is used for training.
X_denoise, X_classifier = norm(X_denoise), norm(X_classifier)

# Sanity check: one shape per line, images then labels for each subset.
print(*(arr.shape for arr in (X_denoise, Y_denoise, X_classifier, Y_classifier)), sep="\n")

In [None]:
# Digit classifier and its optimiser.
classifier_net = MNIST_classifier(img_dim).to(device)
classifier_optimizer = optim.Adam(classifier_net.parameters(), lr=1e-4)

# Insert a singleton axis so the images are (N, 1, img_dim, img_dim) float tensors on `device`.
classifier_images = X_classifier.reshape(-1, 1, img_dim, img_dim)
X_classifier_torch = torch.tensor(classifier_images, dtype=torch.float).to(device)
Y_classifier_torch = torch.tensor(Y_classifier).to(device)

# The last N_classifier_val classifier samples are held out for validation.
N_classifier_val = 10000
N_classifier_train = N_classifier - N_classifier_val

def train_classifier(model, optimizer, X, Y, X_v, Y_v, batch_size=50, num_epochs=1, verbose=False):
    """ Train the classification model

    Parameters
    ----------
    model - Pytorch model producing class logits
    optimizer - Pytorch optimizer
    X - training images
    Y - training labels (class indices)
    X_v - validation images
    Y_v - validation labels (class indices)
    batch_size - int - batch size for training
    num_epochs - int - number of full epochs to train for
    verbose - bool - if True, print training information every 5 epochs

    Returns
    -------
    model - trained Pytorch model (left in training mode)

    Raises
    ------
    ValueError - if X and Y do not contain the same number of samples

    Notes
    -----
    Samples that do not fill a complete batch are skipped in each epoch
    (N // batch_size batches are used).

    """
    model.train()
    N = X.shape[0]
    if Y.shape[0] != N:
        raise ValueError('ERROR: Number of labels ({}) != Number of images ({})!'.format(Y.shape[0], N))

    criterion = torch.nn.CrossEntropyLoss()
    num_batches = N // batch_size
    for e in range(num_epochs):
        train_loss = 0.
        # Fresh random ordering every epoch (np.int was removed in NumPy 1.24,
        # the default integer dtype is what it aliased).
        batch_indices = np.random.permutation(N)

        for b in range(num_batches):
            this_batch_indices = batch_indices[b*batch_size:(b+1)*batch_size]
            X_batch = X[this_batch_indices]
            Y_batch = Y[this_batch_indices]

            optimizer.zero_grad()

            predict_batch = model(X_batch)
            ce_loss = criterion(predict_batch, Y_batch)
            train_loss += ce_loss.item()
            ce_loss.backward()

            optimizer.step()

        if ((e+1) % 5) == 0:
            # check validation loss as well; no gradients are needed here, so
            # avoid building an autograd graph over the whole validation set
            model.eval()
            with torch.no_grad():
                predict_val = model(X_v)
                validation_loss = criterion(predict_val, Y_v).item()
                validation_accuracy = torch.mean((Y_v == torch.argmax(predict_val, dim=1)).float()).item() * 100.

            if verbose:
                # max(..., 1) avoids a ZeroDivisionError when N < batch_size
                print('Epoch: {0:5d} \t Training Loss: {1:5g} \t Val Loss: {2:5g} \t Val Acc: {3:4g}%'.format(e+1,
                                                                                     train_loss / max(num_batches, 1),
                                                                                     validation_loss,
                                                                                     validation_accuracy))
            # set model back into training mode
            model.train()

    return model

torchsummary.summary(classifier_net, (1, img_dim, img_dim))

batch_size = 1000
num_epochs = 250

# Reuse a previously saved classifier when it can be loaded; otherwise train one and cache it.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you created yourself.
try:
    classifier_net = torch.load('./MNIST_classifier.pt')
except Exception as load_error:
    # Catch Exception rather than using a bare except so that a kernel interrupt
    # still propagates, and report why loading failed before the long retraining starts.
    print('Could not load ./MNIST_classifier.pt ({}: {}) - training a new classifier'.format(
        type(load_error).__name__, load_error))
    classifier_net = train_classifier(classifier_net,
                                      classifier_optimizer,
                                      X_classifier_torch[:N_classifier_train],
                                      Y_classifier_torch[:N_classifier_train],
                                      X_classifier_torch[N_classifier_train:N_classifier_train+N_classifier_val],
                                      Y_classifier_torch[N_classifier_train:N_classifier_train+N_classifier_val],
                                      batch_size,
                                      num_epochs,
                                      verbose=True)
    torch.cuda.empty_cache()
    torch.save(classifier_net, './MNIST_classifier.pt')



In [None]:
""" Add noise to MNIST digits in the Fourier domain """
def add_noise(X, num_lines_removed):
    """Degrade images by zeroing random rows and columns of their 2D Fourier transform.

    Parameters
    ----------
    X - array of shape (N, H, W) - stack of images
    num_lines_removed - int - number of randomly chosen rows, and (independently
                        chosen) columns, of each image's spectrum set to zero

    Returns
    -------
    X_recon - float array of shape (N, H, W) - magnitude of the inverse transform
              of the degraded spectrum, each image rescaled so its maximum is 1.
              An image that reconstructs to all zeros is returned as zeros.
    """
    # fftshift returns a new array, so it can be degraded in place.
    K_degraded = ft.fftshift(ft.fft2(X), axes=(1, 2))
    num_img, num_rows, num_cols = K_degraded.shape

    for n in range(num_img):
        # Rows and columns are drawn separately so non-square images are handled too.
        rows = np.random.permutation(num_rows)[:num_lines_removed]
        K_degraded[n, rows, :] = 0

        cols = np.random.permutation(num_cols)[:num_lines_removed]
        K_degraded[n, :, cols] = 0

    # The spectrum is not shifted back before inverting: a circular shift in
    # frequency only changes the phase in image space, which np.abs discards.
    X_recon = np.abs(ft.ifft2(K_degraded))
    # min already 0 due to np.abs; skip the division where the maximum is 0 to avoid NaN
    peak = X_recon.max(axis=(1, 2), keepdims=True)
    return np.divide(X_recon, peak, out=np.zeros_like(X_recon), where=peak > 0)

In [None]:
""" Train network to get original images back - then train one with digit-specific topological priors """

def train_model_supervised(model, optimizer, X, Y, X_v, Y_v, batch_size=50, num_epochs=1, verbose=False):
    """ Train the segmentation model

    Parameters
    ----------
    model - Pytorch model
    optimizer - Pytorch optimizer
    X - training images
    Y - training labels
    X_v - validation images
    Y_v - validation labels
    batch_size - int - batch size for training
    num_epochs - int - number of full epochs to train for
    verbose - bool - if True, print training information

    Returns
    -------
    model - trained Pytorch model (left in training mode)

    Raises
    ------
    ValueError - if X and Y do not contain the same number of samples

    Notes
    -----
    The training loss is binary cross-entropy, so model outputs and Y must lie
    in [0, 1]; the validation loss reported every 10 epochs is the mean squared
    error. Samples that do not fill a complete batch are skipped in each epoch.

    """
    model.train()
    N = X.shape[0]
    if Y.shape[0] != N:
        raise ValueError('ERROR: Number of labels ({}) != Number of images ({})!'.format(Y.shape[0], N))

    validation_criterion = nn.MSELoss()
    num_batches = N // batch_size
    for e in range(num_epochs):
        train_loss = 0.
        # Fresh random ordering every epoch (np.int was removed in NumPy 1.24,
        # the default integer dtype is what it aliased).
        batch_indices = np.random.permutation(N)

        for b in range(num_batches):
            this_batch_indices = batch_indices[b*batch_size:(b+1)*batch_size]
            X_batch = X[this_batch_indices]
            Y_batch = Y[this_batch_indices]

            optimizer.zero_grad()

            predict_batch = model(X_batch)
            bce_loss = F.binary_cross_entropy(predict_batch, Y_batch)
            train_loss += bce_loss.item()
            bce_loss.backward()

            optimizer.step()

        if ((e+1) % 10) == 0:
            # check validation loss as well; no gradients are needed here, so
            # avoid building an autograd graph over the whole validation set
            model.eval()
            with torch.no_grad():
                predict_val = model(X_v)
                validation_loss = validation_criterion(predict_val, Y_v).item()

            if verbose:
                # max(..., 1) avoids a ZeroDivisionError when N < batch_size
                print('Epoch: {0:5d} \t Training Loss: {1:5g} \t Validation Loss: {2:5g}'.format(e+1,
                                                                                     train_loss / max(num_batches, 1),
                                                                                     validation_loss))
            model.train()

    return model

In [None]:
""" Create noisy images"""
l = 8  # number of rows and of columns removed from each image's spectrum
X_noise = add_noise(X_denoise, l)
X_noise_torch = torch.tensor(X_noise.reshape(-1, 1, img_dim, img_dim)).float().to(device)
X_denoise_torch = torch.tensor(X_denoise.reshape(-1, 1, img_dim, img_dim)).float().to(device)

X_test_noise = add_noise(X_test, l)
X_test_noise_torch = torch.tensor(X_test_noise.reshape(-1, 1, img_dim, img_dim)).float().to(device)
# X_test still holds raw pixel values, whereas the classifier and the denoiser
# were trained on images rescaled by norm(); rescale the clean test images the
# same way so every tensor fed to the networks uses the same intensity range.
X_test_torch = torch.tensor(norm(X_test).reshape(-1, 1, img_dim, img_dim)).float().to(device)

In [None]:
""" Train U-net to denoise the MNIST images """
# Sizes of the (contiguous, non-overlapping) training and validation subsets.
N_denoise_train = 100
N_denoise_val = 100

model = Segmenter_Unet(img_dim=img_dim, num_filters=16).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

denoise_train_slice = slice(0, N_denoise_train)
denoise_val_slice = slice(N_denoise_train, N_denoise_train + N_denoise_val)

# Inputs are the degraded images, targets are the clean ones.
model = train_model_supervised(
    model,
    optimizer,
    X_noise_torch[denoise_train_slice],
    X_denoise_torch[denoise_train_slice],
    X_noise_torch[denoise_val_slice],
    X_denoise_torch[denoise_val_slice],
    batch_size=min(1000, N_denoise_train),
    num_epochs=1000,
    verbose=True,
)

In [None]:
""" Predict on test set """
# Inference only: switch to eval mode and disable gradient tracking.
model.eval()
with torch.no_grad():
    Z_predicted = model(X_test_noise_torch)

# Drop the singleton axis at position 1 to get an (N, H, W) numpy array for plotting.
Z_predicted_np = Z_predicted.cpu().numpy()[:, 0]
print(Z_predicted_np.shape, X_test_noise_torch.shape, sep="\n")

In [None]:
""" Assess quality of reconstructed images by passing them through pre-trained MNIST classifier
if the classifier can correctly tell what they are then they look good, since it is ~98% accurate on real digits """
Y_test_torch = torch.tensor(Y_test)

# train_classifier() returns the network in training mode; switch to eval mode so
# any train-only layers (e.g. dropout / batch norm, if the classifier has them)
# behave deterministically, and skip autograd bookkeeping for the whole test set.
classifier_net.eval()
with torch.no_grad():
    Z_digit_prediction = classifier_net(Z_predicted)               # denoised images
    X_noise_digit_prediction = classifier_net(X_test_noise_torch)  # noisy inputs
    X_digit_prediction = classifier_net(X_test_torch)              # clean images
print(Z_digit_prediction.shape)
print(X_noise_digit_prediction.shape)
print(X_digit_prediction.shape)

In [None]:
# Fraction of test digits classified correctly, printed one per line for:
# the noisy inputs, the clean images, and the denoised reconstructions.
print(
    *(
        (digit_logits.argmax(dim=1).cpu() == Y_test_torch).float().mean()
        for digit_logits in (X_noise_digit_prediction, X_digit_prediction, Z_digit_prediction)
    ),
    sep="\n",
)

In [None]:
"""Optimise topological loss on a single case to get some nice pictures for the paper
train network with some specific set of parameters, then observe change in output reconstruction
when topological priors are applied to the output and the network's weights adjusted """ 

# Topological priors, keyed by homology dimension. The fine-tuning loop below
# reads key 1 as the number of loops to encourage; key 0 is presumably the
# number of connected components (it is not read in this notebook).
# The names refer to an example digit of each class, not to a homology group.
H_1 = {0:1, 1:0} # no loop:   1, 2, 3, 4, 5, 7
H_0 = {0:1, 1:1} # one loop:  0, 6, 9
H_8 = {0:1, 1:2} # two loops: 8

# correct topology for each digit
H_dict = {0:H_0,
          1:H_1,
          2:H_1, # note: a 2 drawn with a loop conflicts with this prior - some interesting cases here?
          3:H_1,
          4:H_1,
          5:H_1,
          6:H_0,
          7:H_1,
          8:H_8,
          9:H_0}

# Persistence layer sized from img_dim so it stays consistent with the image
# shape used everywhere else in the notebook.
dgminfo = LevelSetLayer2D(size=(img_dim, img_dim), sublevel=False, maxdim=1)
l2_loss = nn.MSELoss()

In [None]:
# Boolean mask: True where the classifier recognises the denoised image as the true digit.
predicted_digits = torch.argmax(Z_digit_prediction, dim=1).cpu()
original_network_correct = predicted_digits == Y_test_torch
print(original_network_correct[:20])

In [None]:
i = 3  # index of the test image inspected in the following cells

# Left to right: noisy input, clean test image, network reconstruction.
f = plt.figure(figsize=(15,5))
(ax1, ax2, ax3) = f.subplots(1,3)
for ax, image in zip((ax1, ax2, ax3), (X_test_noise[i], X_test[i], Z_predicted_np[i])):
    ax.imshow(image)
    ax.set_xticks([])
    ax.set_yticks([])

logits_i = Z_digit_prediction[i]
print('Ground truth: {}'.format(Y_test[i]))
print('Predicted as: {}'.format(torch.argmax(logits_i).item()))
print('Logits:')
print(logits_i)

In [None]:
# Fine-tune a copy of the trained network on a single test image so that its
# output matches the topological prior of the true digit.
model_topo = copy.deepcopy(model)
optimizer = torch.optim.Adam(model_topo.parameters(), lr=1e-5)
num_iter_topo = 100
digit_i = Y_test[i]
H_i = H_dict[digit_i]

print(digit_i)
print(H_i)

original_model_output = model(X_test_noise_torch[i:i+1]).cpu().detach() # detach to avoid second pass error

L_sqdiff_weight = 10 # hyper-parameter
max_k = 20 # only consider this many bars - most will be 0-length anyway

# Loop-invariant pieces are built once instead of on every iteration.
top_k_dim0 = TopKBarcodeLengths(dim=0, k=max_k)
top_k_dim1 = TopKBarcodeLengths(dim=1, k=max_k)
# Sign of each squared dimension-1 bar in the loss: -1 for the first H_i[1]
# bars (minimising the loss lengthens them), +1 for the rest (shortens them).
bar_signs = torch.ones(max_k)
bar_signs[:H_i[1]] = -1

ground_truth_mask = X_test_torch[i:i+1][0,0].cpu().detach()
original_predicted_mask = original_model_output[0,0]

L_list = []
for t in range(num_iter_topo):
    optimizer.zero_grad()
    Z_cuda = model_topo(X_test_noise_torch[i:i+1])
    # the persistence layer is applied to a CPU copy of the network output
    Z_cpu = Z_cuda.cpu()
    a = dgminfo(Z_cpu)

    # every dimension-0 bar returned by the layer is penalised
    L0 = (top_k_dim0(a)**2).sum()
    dim_1_sq_bars = top_k_dim1(a)**2
    L1 = (dim_1_sq_bars * bar_signs).sum()

    # keep the fine-tuned output close to the original network's output
    L_sqdiff = l2_loss(original_model_output, Z_cpu) * L_sqdiff_weight
    L = L0 + L1 + L_sqdiff
    L.backward()
    L_list.append(L.item())
    optimizer.step()

# Z_cuda / Z_cpu hold the output computed in the last iteration (before its optimizer step).
if num_iter_topo > 0:
    topo_predicted_mask = Z_cpu[0,0].detach()

In [None]:
# Left to right: noisy input, clean test image, original reconstruction,
# reconstruction after topological fine-tuning.
f = plt.figure(figsize=(20,5))
(ax1, ax2, ax3, ax4) = f.subplots(1,4)
ax1.imshow(X_test_noise[i])
ax2.imshow(X_test[i])
ax3.imshow(Z_predicted_np[i])
ax4.imshow(Z_cpu[0,0].detach().numpy(), cmap='gray')
for ax in (ax1, ax2, ax3, ax4):
    ax.set_xticks([])
    ax.set_yticks([])

# Classify in eval mode and without autograd: the classifier may have been left
# in training mode, and no gradients are needed for this check.
classifier_net.eval()
with torch.no_grad():
    Z_topo_digit_prediction_i = classifier_net(Z_cuda)

print(Y_test[i])
print(torch.argmax(Z_topo_digit_prediction_i).item())
print(Z_topo_digit_prediction_i)

In [None]:
def diag_tidy(diag, eps=1e-1):
    """Remove short bars from a persistence diagram.

    Parameters
    ----------
    diag - iterable of (dimension, (birth, death)) pairs
    eps - float - bars with |birth - death| <= eps are discarded

    Returns
    -------
    list of the (dimension, (birth, death)) pairs that are longer than eps,
    in their original order
    """
    return [(dim, bar) for dim, bar in diag if np.abs(bar[0] - bar[1]) > eps]

# Reconstruction produced by the original network for test image i.
original_recon = Z_predicted_np[i]

plt.figure(figsize=(3,3))
plt.imshow(original_recon, cmap='gray')
plt.xticks([])
plt.yticks([])
plt.colorbar()
plt.show()

# The complex is built on the inverted image (1 - intensity).
cc = gd.CubicalComplex(dimensions=(img_dim, img_dim),
                       top_dimensional_cells=(1 - original_recon).flatten())
diag = cc.persistence()
diag_clean = diag_tidy(diag, 1e-3)

plt.figure(figsize=(3,3))
gd.plot_persistence_barcode(diag_clean)
plt.ylim(-1, len(diag_clean))
# Filtration values are 1 - intensity, so label the axis with the intensities.
barcode_ticks = np.linspace(0, 1, 6)
plt.xticks(ticks=barcode_ticks, labels=np.round(np.linspace(1, 0, 6), 2))
plt.yticks([])
plt.show()

In [None]:
# Reconstruction produced by the topologically fine-tuned network for test image i.
topo_recon = Z_cpu[0,0].detach().numpy()

plt.figure(figsize=(3,3))
plt.imshow(topo_recon, cmap='gray')
plt.xticks([])
plt.yticks([])
plt.colorbar()
plt.show()

# The complex is built on the inverted image (1 - intensity).
cc = gd.CubicalComplex(dimensions=(img_dim, img_dim),
                       top_dimensional_cells=(1 - topo_recon).flatten())
diag = cc.persistence()
diag_clean = diag_tidy(diag, 1e-3)

plt.figure(figsize=(3,3))
gd.plot_persistence_barcode(diag_clean)
plt.ylim(-1, len(diag_clean))
# Filtration values are 1 - intensity, so label the axis with the intensities.
barcode_ticks = np.linspace(0, 1, 6)
plt.xticks(ticks=barcode_ticks, labels=np.round(np.linspace(1, 0, 6), 2))
plt.yticks([])