In [3]:
import torch
from torch import nn
import einops
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
from autoencoder import *

In [4]:
def VAELoss(input,mu,logsigma,target,beta_kl):   
    reconstruction_loss = nn.MSELoss()(target, input)
    kl_div_loss = (0.5 * (mu ** 2 + torch.exp(2 * logsigma) - 1) - logsigma).mean() * beta_kl
    return reconstruction_loss + kl_div_loss

In [5]:
def get_dataset(train: bool = True) -> Dataset:
    img_size = 28
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST(
        root = "/data",
        train = train,
        transform = transform,
        download = True,
    )

    return dataset

In [6]:
trainset_mnist = get_dataset()
testset_mnist = get_dataset(train=False)

sub_trainset_size=5000
sub_trainset_mnist=Subset(trainset_mnist,np.random.choice(len(trainset_mnist),size=sub_trainset_size))
single_loader=DataLoader(sub_trainset_mnist,batch_size=1)

sub_trainset_mnist=torch.zeros((len(sub_trainset_mnist),28*28+1))
for k,batch in enumerate(single_loader):
    sub_trainset_mnist[k,:28*28]=batch[0].reshape(28*28)
    sub_trainset_mnist[k,-1]=batch[1].item()

In [7]:
def random_noise_corruption_dataset(dataset,amplitude):
    dataset_corrupted=torch.zeros((len(dataset),28*28+1))
    for k in range(dataset_corrupted.shape[0]):
        dataset_corrupted[k,:28*28]=dataset[k,:28*28]+amplitude*(2*torch.rand(28*28)-1)
    return dataset_corrupted

def trailing_corruption(dataset,length,p=0.5):
    dataset_corrupted=torch.zeros((len(dataset),28*28+1))
    for k in range(dataset_corrupted.shape[0]):
        img=dataset[k,:28*28].reshape(28,28)
        for row in range(28):
            col=0
            while col<28:
                if col>=length:
                    if (img[row,col-length:col]<0).sum()==length and (img[row,col]>0):
                        if np.random.random()>p:
                            for l in range(length+1):
                                dataset_corrupted[k,28*row+col-l]=img[row,col]
                        else:
                            dataset_corrupted[k,28*row+col]=img[row,col]
                    else:
                        dataset_corrupted[k,28*row+col]=img[row,col]
                else:
                    dataset_corrupted[k,28*row+col]=img[row,col]
                col+=1
    return dataset_corrupted

In [8]:
def initial_mse(dataset,dataset_corrupted):
    initial_mse=0
    for n in range(dataset.shape[0]):
        initial_mse+=torch.norm(dataset[k,:28*28]-dataset_corrupted[k,:28*28],p=2)
    initial_mse=initial_mse/dataset.shape[0]
    return initial_mse

In [9]:
def plot_initial_samples(dataset,dataset_corrupted,num_sample_indices,directory,file_prefix):

    sample_indicies=np.random.randint(low=0,high=dataset.shape[0],size=num_sample_indices)

    mnist_sample=dataset[sample_indicies]
    mnist_sample_corrupted=dataset_corrupted[sample_indicies]

    fig,axs=plt.subplots(nrows=2,ncols=num_sample_indices,layout="constrained")

    for k in range(num_sample_indices):
        axs[0][k].imshow(mnist_sample[k,:28*28].reshape(28,28),cmap='gray')
        axs[0][k].axis('off')

    for k in range(num_sample_indices):
        axs[1][k].imshow(mnist_sample_corrupted[k,:28*28].reshape(28,28).detach().numpy(),cmap='gray')
        axs[1][k].axis('off')
    plt.savefig(f'{directory}/{file_prefix}_initial.png')

In [10]:
def train_ae(dataset,model,opt,criterion,train_args,progress_bar_text,ground_truth=False):
    dataset_size=dataset.shape[0]
    vector_size=dataset.shape[1]-1
    progress_bar=tqdm(range(train_args["epochs"]))
    for epoch in progress_bar:
        perm=torch.randperm(dataset_size)
        epoch_dataset=dataset[perm,:]
        epoch_ground_truth=train_args["ground_truth_set"][perm,:]
        epoch_cycles=dataset_size//train_args["batch_size"]
        count=0
        epoch_loss=0
        for k in range(epoch_cycles+1):
            if k==epoch_cycles:
                n=dataset_size-k*train_args["batch_size"]
                img=epoch_dataset[-n:,:vector_size].reshape((n,1,28,28))
                img_reconstructed = model(img)[0]
                if ground_truth:
                    loss = criterion(epoch_ground_truth[-n:,:vector_size].reshape((n,1,28,28)), img_reconstructed)
                else:
                    loss = criterion(img, img_reconstructed)
                loss.backward()
                opt.step()
                opt.zero_grad()
                epoch_loss+=loss
                count+=n
            else:
                img=epoch_dataset[k*train_args["batch_size"]:(k+1)*train_args["batch_size"],:vector_size].reshape((train_args["batch_size"],1,28,28))
                img_reconstructed = model(img)[0]
                if ground_truth:
                    loss = criterion(epoch_ground_truth[k*train_args["batch_size"]:(k+1)*train_args["batch_size"],:vector_size].reshape((train_args["batch_size"],1,28,28)), img_reconstructed)
                else:
                    loss = criterion(img, img_reconstructed)
                loss.backward()
                opt.step()
                opt.zero_grad()
                epoch_loss+=loss
                count+=train_args["batch_size"]
        progress_bar.set_description(f"{progress_bar_text} epoch={epoch+1}, loss={epoch_loss/count:.4f}")
    return epoch_loss/count

def train_vae(dataset,model,opt,train_args,progress_bar_text,ground_truth=False):
    dataset_size=dataset.shape[0]
    vector_size=dataset.shape[1]-1
    progress_bar=tqdm(range(train_args["epochs"]))
    for epoch in progress_bar:
        perm=torch.randperm(dataset_size)
        epoch_dataset=dataset[perm,:]
        epoch_ground_truth=train_args["ground_truth_set"][perm,:]
        epoch_cycles=dataset_size//train_args["batch_size"]
        count=0
        epoch_loss=0
        for k in range(epoch_cycles+1):
            if k==epoch_cycles:
                n=dataset_size-k*train_args["batch_size"]
                img=epoch_dataset[-n:,:vector_size].reshape((n,1,28,28))
                img_reconstructed,mu,logsigma,z = model(img)
                if ground_truth:
                    loss = VAELoss(epoch_ground_truth[-n:,:vector_size].reshape((n,1,28,28)),mu,logsigma, img_reconstructed,train_args["beta_kl"])
                else:
                    loss = VAELoss(epoch_ground_truth[-n:,:vector_size].reshape((n,1,28,28)),mu,logsigma, img_reconstructed,train_args["beta_kl"])
                loss.backward()
                opt.step()
                opt.zero_grad()
                epoch_loss+=loss
                count+=n
            else:
                img=epoch_dataset[k*train_args["batch_size"]:(k+1)*train_args["batch_size"],:vector_size].reshape((train_args["batch_size"],1,28,28))
                img_reconstructed,mu,logsigma,z = model(img)
                if ground_truth:
                    loss = VAELoss(epoch_ground_truth[k*train_args["batch_size"]:(k+1)*train_args["batch_size"],:vector_size].reshape((train_args["batch_size"],1,28,28)),mu,logsigma, img_reconstructed,train_args["beta_kl"])
                else:
                    loss = VAELoss(epoch_ground_truth[k*train_args["batch_size"]:(k+1)*train_args["batch_size"],:vector_size].reshape((train_args["batch_size"],1,28,28)),mu,logsigma, img_reconstructed,train_args["beta_kl"])
                loss.backward()
                opt.step()
                opt.zero_grad()
                epoch_loss+=loss
                count+=train_args["batch_size"]
        progress_bar.set_description(f"{progress_bar_text} epoch={epoch+1}, loss={epoch_loss/count:.4f}")
    return epoch_loss/count

In [11]:
def update_dataset(model,original_dataset,num_to_replace,num_samples,return_sample=True):
    with torch.no_grad():
        random_indices_sample=np.random.choice(num_to_replace,size=num_samples)
        N=original_dataset.shape[0]
        replaced_dataset=torch.zeros_like(original_dataset)
        replaced_dataset[num_to_replace:,:]=original_dataset[num_to_replace:,:]
        original_tensors_sample=torch.zeros((5,28*28))
        updated_tensors_sample=torch.zeros((5,28*28))
        count=0
        for k in range(num_to_replace):
            batch=original_dataset[k,:]
            if k in random_indices_sample:
                original_tensors_sample[count,:]=batch[:28*28]
            replaced_dataset[k,:28*28]=model(batch[:28*28].reshape(1,1,28,28))[0].reshape(28*28)
            replaced_dataset[k,-1]=batch[-1].item()
            if k in random_indices_sample:
                updated_tensors_sample[count,:]=replaced_dataset[N-num_to_replace+k,:28*28]
                count+=1
        if return_sample:
            return replaced_dataset,original_tensors_sample,updated_tensors_sample
        else:
            return replaced_dataset

In [12]:
def iterate_autoencoder(dataset,ground_truth_set,num_interations,num_replacement,num_samples,progress_bar_text,model_args={"latent_dim_size":32,"hidden_dim_size":128},train_args={"batch_size":32,"epochs":10,"lr":1e-3,"betas":(0.5, 0.999)},retrain=True):
    samples=[]
    losses=[]
    train_args["ground_truth_set"]=ground_truth_set
    for k in range(num_interations):
        if k==0:
            iter_progress_bar_text=f'Iteration {k}, '+progress_bar_text
            model=Autoencoder(model_args["latent_dim_size"],model_args["hidden_dim_size"])
            opt=torch.optim.Adam(model.parameters(),lr=train_args["lr"],betas=train_args["betas"])
            criterion=nn.MSELoss()
            losses.append(train_ae(dataset,model,opt,criterion,train_args,iter_progress_bar_text,ground_truth=True).item())
            dataset,original_sample,_=update_dataset(model,dataset,num_replacement,num_samples)
            samples.append(original_sample)
        else:
            iter_progress_bar_text=f'Iteration {k}, '+progress_bar_text
            if retrain:
                model=Autoencoder(model_args["latent_dim_size"],model_args["hidden_dim_size"])
                opt=torch.optim.Adam(model.parameters(),lr=train_args["lr"],betas=train_args["betas"])
                criterion=nn.MSELoss()
            losses.append(train_ae(dataset,model,opt,criterion,train_args,iter_progress_bar_text,ground_truth=True).item())
            dataset,_,ending_sample=update_dataset(model,dataset,num_replacement,num_samples)
            samples.append(ending_sample)
    return samples,losses

def iterate_vautoencoder(dataset,ground_truth_set,num_interations,num_replacement,num_samples,progress_bar_text,model_args={"latent_dim_size":32,"hidden_dim_size":128},train_args={"batch_size":32,"epochs":10,"lr":1e-3,"betas":(0.5, 0.999),"beta_kl":0.1},retrain=True):
    samples=[]
    losses=[]
    train_args["ground_truth_set"]=ground_truth_set
    for k in range(num_interations):
        if k==0:
            iter_progress_bar_text=f'Iteration {k}, '+progress_bar_text
            model=VAE(model_args["latent_dim_size"],model_args["hidden_dim_size"])
            opt=torch.optim.Adam(model.parameters(),lr=train_args["lr"],betas=train_args["betas"])
            losses.append(train_vae(dataset,model,opt,train_args,iter_progress_bar_text,ground_truth=True).item())
            dataset,original_sample,_=update_dataset(model,dataset,num_replacement,num_samples)
            samples.append(original_sample)
        else:
            iter_progress_bar_text=f'Iteration {k}, '+progress_bar_text
            if retrain:
                model=VAE(model_args["latent_dim_size"],model_args["hidden_dim_size"])
                opt=torch.optim.Adam(model.parameters(),lr=train_args["lr"],betas=train_args["betas"])
            losses.append(train_vae(dataset,model,opt,train_args,iter_progress_bar_text,ground_truth=True).item())
            dataset,_,ending_sample=update_dataset(model,dataset,num_replacement,num_samples)
            samples.append(ending_sample)
    return samples,losses

In [18]:
def corruption_ae(corruption,amplitude,dataset,num_sample_indices,num_iteriations,num_replacement,iterations_to_plot,retrain):

    if retrain:
        retrain_text='retrain'
    else:
        retrain_text='no_retrain'
    
    amplitude_text='0'+str(amplitude).split('.')[-1]
    if not(f'{retrain_text}_ae_{corruption}_corruption_{amplitude_text}' in os.listdir()):
        os.mkdir(f'{retrain_text}_ae_{corruption}_corruption_{amplitude_text}')
    directory_name=f'{retrain_text}_ae_{corruption}_corruption_{amplitude_text}'
    file_prefix=f'{retrain_text}_{corruption}_corruption_{amplitude_text}'

    if corruption=='random_noise':
        dataset_corrupted=random_noise_corruption_dataset(dataset,amplitude)
    elif corruption=='trailing':
        dataset_corrupted=trailing_corruption(dataset,amplitude)
    else:
        raise ValueError('Specify valid corruption method')

    plot_initial_samples(dataset,dataset_corrupted,num_sample_indices,directory_name,file_prefix)
    init_mse=initial_mse(dataset,dataset_corrupted)
    file=open(f'{directory_name}/{file_prefix}.txt','w')
    file.write(f'Initial MSE = {init_mse}\n')
    progress_bar_text=f'Amplitude {amplitude} {corruption} corruption:'
    samples,losses=iterate_autoencoder(dataset_corrupted,dataset,num_iteriations,num_replacement,num_sample_indices,progress_bar_text,retrain=retrain)
    file.write(f'Losses = {losses}')
    file.close()

    fig,ax=plt.subplots(nrows=1,ncols=1)

    fig.set_figheight(6)
    fig.set_figwidth(6)


    ax.plot(range(1,len(losses)+1),losses)
    ax.set_xticks(np.arange(0,len(losses)+1,step=2))
    plt.savefig(f'{directory_name}/{file_prefix}_losses.png')

    fig,axs=plt.subplots(nrows=len(iterations_to_plot),ncols=num_sample_indices,layout="tight")

    fig.set_figheight(10)
    fig.set_figwidth(5)

    for row,iteration in enumerate(iterations_to_plot):
        for col in range(num_sample_indices):
            axs[row][col].imshow(samples[iteration][col,:].reshape(28,28),cmap='gray')
            axs[row][col].axis('off')
            if col==0:
                axs[row][col].set_title(str(iteration))
    plt.savefig(f'{directory_name}/{file_prefix}_training_samples.png')

In [14]:
def corruption_vae(corruption,amplitude,dataset,num_sample_indices,num_iteriations,num_replacement,iterations_to_plot,retrain):

    if retrain:
        retrain_text='retrain'
    else:
        retrain_text='no_retrain'

    amplitude_text='0'+str(amplitude).split('.')[-1]
    if not(f'{retrain_text}_vae_{corruption}_corruption_{amplitude_text}' in os.listdir()):
        os.mkdir(f'{retrain_text}_vae_{corruption}_corruption_{amplitude_text}')
    directory_name=f'{retrain_text}_vae_{corruption}_corruption_{amplitude_text}'
    file_prefix=f'{retrain_text}_{corruption}_corruption_{amplitude_text}'

    if corruption=='random_noise':
        dataset_corrupted=random_noise_corruption_dataset(dataset,amplitude)
    elif corruption=='trailing':
        dataset_corrupted=trailing_corruption(dataset,amplitude)
    else:
        raise ValueError('Specify valid corruption method')

    plot_initial_samples(dataset,dataset_corrupted,num_sample_indices,directory_name,file_prefix)
    init_mse=initial_mse(dataset,dataset_corrupted)
    file=open(f'{directory_name}/{file_prefix}.txt','w')
    file.write(f'Initial MSE = {init_mse}\n')
    progress_bar_text=f'Amplitude {amplitude} {corruption} corruption:'
    samples,losses=iterate_vautoencoder(dataset_corrupted,dataset,num_iteriations,num_replacement,num_sample_indices,progress_bar_text,retrain=retrain)
    file.write(f'Losses = {losses}')
    file.close()

    fig,ax=plt.subplots(nrows=1,ncols=1)

    fig.set_figheight(6)
    fig.set_figwidth(6)


    ax.plot(range(1,len(losses)+1),losses)
    ax.set_xticks(np.arange(0,len(losses)+1,step=2))
    plt.savefig(f'{directory_name}/{file_prefix}_losses.png')

    fig,axs=plt.subplots(nrows=len(iterations_to_plot),ncols=num_sample_indices,layout="tight")

    fig.set_figheight(10)
    fig.set_figwidth(5)

    for row,iteration in enumerate(iterations_to_plot):
        for col in range(num_sample_indices):
            axs[row][col].imshow(samples[iteration][col,:].reshape(28,28),cmap='gray')
            axs[row][col].axis('off')
            if col==0:
                axs[row][col].set_title(str(iteration))
    plt.savefig(f'{directory_name}/{file_prefix}_training_samples.png')

In [20]:
corruption_ae('random_noise',0,sub_trainset_mnist,5,20,5000,[0,9,19],retrain=False)
corruption_vae('random_noise',0,sub_trainset_mnist,5,20,5000,[0,9,19],retrain=False)
corruption_ae('random_noise',0.3,sub_trainset_mnist,5,20,5000,[0,9,19],retrain=False)
corruption_vae('trailing',4,sub_trainset_mnist,5,20,5000,[0,9,19],retrain=False)

Iteration 0, Amplitude 0 random_noise corruption: epoch=10, loss=0.0037: 100%|██████████| 10/10 [00:46<00:00,  4.66s/it]
Iteration 1, Amplitude 0 random_noise corruption: epoch=10, loss=0.0034: 100%|██████████| 10/10 [00:48<00:00,  4.88s/it]
Iteration 2, Amplitude 0 random_noise corruption: epoch=10, loss=0.0032: 100%|██████████| 10/10 [00:46<00:00,  4.68s/it]
Iteration 3, Amplitude 0 random_noise corruption: epoch=10, loss=0.0031: 100%|██████████| 10/10 [00:44<00:00,  4.50s/it]
Iteration 4, Amplitude 0 random_noise corruption: epoch=10, loss=0.0030: 100%|██████████| 10/10 [00:43<00:00,  4.36s/it]
Iteration 5, Amplitude 0 random_noise corruption: epoch=10, loss=0.0029: 100%|██████████| 10/10 [00:42<00:00,  4.27s/it]
Iteration 6, Amplitude 0 random_noise corruption: epoch=10, loss=0.0027: 100%|██████████| 10/10 [00:42<00:00,  4.26s/it]
Iteration 7, Amplitude 0 random_noise corruption: epoch=10, loss=0.0026: 100%|██████████| 10/10 [00:48<00:00,  4.85s/it]
Iteration 8, Amplitude 0 random_