In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.utils.data as loader
import time

In [None]:
# Pick the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)

In [None]:
# Preprocessing: convert each image to a tensor, then apply (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split, downloaded into the current directory when missing.
dataset = torchvision.datasets.MNIST(
    root='./',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 100 images.
data_loader = loader.DataLoader(dataset=dataset, batch_size=100, shuffle=True)

In [None]:
# We use a plain deterministic encoder, the same as the one used in an ordinary autoencoder.
class encoder(nn.Module):
    """Deterministic MLP encoder: 784 input features -> 8-dimensional latent code.

    The layers are held in ``self.block`` (an ``nn.Sequential``), so the
    parameters appear in the ``state_dict`` as ``block.<index>.weight/bias``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(784, 1000),
            nn.Dropout(p=0.25),
            nn.ReLU(inplace=True),
            nn.Linear(1000, 1000),
            nn.Dropout(p=0.25),
            nn.ReLU(inplace=True),
            nn.Linear(1000, 8),  # latent code, no activation
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample of the batch into one row and encode it."""
        flat = x.view(x.size(0), -1)
        return self.block(flat)

In [None]:
class decoder(nn.Module):
    """MLP decoder: 8-dimensional latent code -> 784 values in [0, 1].

    The layers are held in ``self.block`` (an ``nn.Sequential``), so the
    parameters appear in the ``state_dict`` as ``block.<index>.weight/bias``.
    """

    def __init__(self):
        super().__init__()
        self.block=nn.Sequential(
            nn.Linear(8,1000),
            nn.Dropout(p=.25),
            nn.ReLU(True),
            nn.Linear(1000,1000),
            nn.Dropout(p=.25),
            nn.ReLU(True),
            nn.Linear(1000,784),
        )

    def forward(self,x):
        """Decode latent codes ``x`` and squash the result with a sigmoid."""
        x=self.block(x)
        # torch.sigmoid replaces the deprecated functional alias F.sigmoid.
        return torch.sigmoid(x)

In [None]:
class discriminator(nn.Module):
    """MLP critic on latent codes: 8-dimensional input -> one probability in [0, 1].

    The layers are held in ``self.block`` (an ``nn.Sequential``), so the
    parameters appear in the ``state_dict`` as ``block.<index>.weight/bias``.
    """

    def __init__(self):
        super().__init__()
        self.block=nn.Sequential(
            nn.Linear(8,1000),
            nn.Dropout(p=.2),
            nn.ReLU(True),
            nn.Linear(1000,1000),
            nn.Dropout(p=.2),
            nn.ReLU(True),
            nn.Linear(1000,1)
        )

    def forward(self,x):
        """Score latent codes ``x``; the output has one sigmoid value per row."""
        x=self.block(x)
        # torch.sigmoid replaces the deprecated functional alias F.sigmoid.
        return torch.sigmoid(x)

In [None]:
# Build the three networks (encoder, decoder, latent discriminator) in that
# order and move each one to the selected device.
enc, dec, D_ = [net_cls().to(device) for net_cls in (encoder, decoder, discriminator)]

In [None]:
# Reconstruction phase: encoder and decoder each get an Adam optimizer (lr 6e-4).
op_enc, op_dec = (optim.Adam(net.parameters(), lr=6e-4) for net in (enc, dec))
# Adversarial phase: a second Adam optimizer over the same encoder parameters
# (the "generator") and one for the discriminator, both with lr 8e-4.
op_gen, op_disc = (optim.Adam(net.parameters(), lr=8e-4) for net in (enc, D_))

In [None]:
# recon_loss=nn.MSELoss()
# gen_loss=nn.BCEWithLogitsLoss()
# disc_loss=nn.BCELoss()

In [None]:
# Number of full passes over data_loader performed by the training loop.
num_epochs=100

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Adversarial autoencoder training.
# Every mini-batch goes through three phases:
#   1) reconstruction  - encoder + decoder minimise the pixel-wise BCE,
#   2) discriminator   - D_ learns to tell prior samples from encoder codes,
#   3) generator       - the encoder learns to fool D_.
# recloss / dloss / gloss collect, per epoch, the SUM of the per-batch losses.
recloss=[]
dloss=[]
gloss=[]
TINY=1e-8  # keeps torch.log away from log(0) in the adversarial losses
for epoch in range(num_epochs):
    reconst_loss=.0
    dis_loss=.0
    gent_loss=.0
    start=time.time()
    for batch in data_loader:
        # Phases 2 and 3 switch single networks to eval mode, so restore
        # training mode at the start of every iteration.
        enc.train()
        dec.train()
        D_.train()

        # ---- 1) Updating autoencoder network ----
        op_enc.zero_grad()
        op_dec.zero_grad()
        data=batch[0].to(device) # We only need images
        bsize=data.size(0)
        z_gen=enc(data)
        out=dec(z_gen)
        # The loader normalises pixels with Normalize((0.5,), (0.5,)), i.e. to
        # [-1, 1], while the decoder ends in a sigmoid. binary_cross_entropy
        # requires targets in [0, 1], so undo the normalisation for the target.
        target=(data.view(bsize,-1)*0.5+0.5).clamp(0.0,1.0)
        recon=F.binary_cross_entropy(out.view(bsize,-1),target)
        recon.backward()
        op_enc.step()
        op_dec.step()
        reconst_loss+=recon.item()

        # ---- 2) Updating discriminator ----
        enc.eval()
        op_disc.zero_grad()
        # "Real" codes: samples from the prior, a Gaussian with mean 0 and std 5.
        z_real=torch.randn(bsize,8,device=device)*5
        # Only D_ is updated here, so no graph through the encoder is needed.
        with torch.no_grad():
            z_gen=enc(data)
        D_real,D_gen=D_(z_real),D_(z_gen)
        D_loss=-torch.mean(torch.log(D_real+TINY)+torch.log(1-D_gen+TINY))
        D_loss.backward()
        op_disc.step()
        dis_loss+=D_loss.item()

        # ---- 3) Updating generator (encoder) ----
        enc.train()
        op_gen.zero_grad()
        D_.eval()
        z_gen=enc(data)
        D_gen=D_(z_gen)
        # Non-saturating generator loss: push D_ to label encoder codes as real.
        g_loss=-torch.mean(torch.log(D_gen+TINY))
        g_loss.backward()
        op_gen.step()
        gent_loss+=g_loss.item()

    print("[%d/%d] recon_loss: %.4f dis_loss: %.4f gen_loss: %.4f time elapsed: %.4f"%(epoch+1,num_epochs,reconst_loss,dis_loss,gent_loss,time.time()-start))
    recloss.append(reconst_loss)
    dloss.append(dis_loss)
    gloss.append(gent_loss)

In [None]:
# Decode one latent vector drawn from a Gaussian with std 5 and show it as a 28x28 image.
dec.eval()
samp = (5 * torch.randn(1, 8)).float().to(device)
decoded = dec(samp).detach()
digit = decoded.reshape(28, 28).squeeze().cpu().numpy()
plt.imshow(digit)
plt.show()

In [None]:
# Write only the decoder's parameters (its state_dict) to 'aae_decoder_2.pth'.
torch.save(dec.state_dict(),'aae_decoder_2.pth')

In [None]:
# Offer the saved decoder checkpoint as a browser download when running in Colab.
try:
    from google.colab import files
except ImportError:
    # Outside Google Colab there is nothing to download: the file is already local.
    print("google.colab is not available; skipping download of aae_decoder_2.pth")
else:
    try:
        files.download("aae_decoder_2.pth")
    except Exception as err:
        # The original cell retried once after any failure; keep that single
        # retry, but report the first error instead of swallowing it silently.
        print("download failed (%s); retrying once" % err)
        files.download("aae_decoder_2.pth")

In [None]:
# Per-epoch loss curves collected by the training loop. Each point is the sum
# of the per-batch losses of one epoch.
fig, ax = plt.subplots()
ax.plot(recloss, label='reconstruction loss')  # was mislabelled 'recombination loss'
ax.plot(dloss, label='discriminator loss')
ax.plot(gloss, label='gen loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss (summed over batches)')
ax.legend()
plt.show()

In [None]:
# Scratch check: draw a batch of prior samples, enable gradient tracking on it
# and print the resulting tensor. Relies on `bsize` left over from the training loop.
z_real = (5 * torch.randn(bsize, 8)).to(device)
z_real.requires_grad_(True)
print(z_real)

In [None]:
# Scratch check: switch on gradient tracking for z_real via attribute assignment.
z_real.requires_grad=True

In [None]:
# Scratch check: display whether z_real currently tracks gradients.
z_real.requires_grad

In [None]:
# Restore decoder weights from disk. map_location=device lets a checkpoint that
# was saved on a GPU be loaded on a CPU-only machine (and vice versa).
# NOTE(review): this reads 'aae_decoder.pth', whereas the save cell writes
# 'aae_decoder_2.pth' - presumably an earlier checkpoint is wanted here; confirm.
dec.load_state_dict(torch.load('aae_decoder.pth', map_location=device))

In [None]:
# Draw a fresh latent vector (Gaussian, std 5), decode it with the loaded
# decoder and show the 28x28 result with the reversed grey colormap.
dec.eval()
plt.set_cmap('Greys_r')
samp = (5 * torch.randn(1, 8)).float().to(device)
decoded = dec(samp).detach()
digit = decoded.reshape(28, 28).squeeze().cpu().numpy()
plt.imshow(digit)
plt.show()

In [None]:
# Remember the latent vector currently held in `samp` as the first anchor.
# NOTE(review): samp1, samp2 and samp3 are all assigned from `samp`; presumably
# the sampling cell is re-run between these cells to pick different vectors.
# On a straight top-to-bottom run all three refer to the same tensor - confirm.
samp1=samp

In [None]:
# Remember the latent vector currently held in `samp` as the second anchor.
# NOTE(review): presumably the sampling cell is re-run before this cell so that
# `samp` holds a new vector; otherwise this is the same tensor as samp1 - confirm.
samp2=samp

In [None]:
# Remember the latent vector currently held in `samp` as the third anchor
# (the end point used by the interpolation cell).
# NOTE(review): presumably the sampling cell is re-run before this cell so that
# `samp` holds a new vector; otherwise samp3 equals samp1 and the interpolation
# step is zero - confirm.
samp3=samp

In [None]:
# Linear walk in latent space from samp1 towards samp3, rendered as a strip of
# ten decoded digits. The frames sit at samp1 + k * step for k = 1..10, so the
# last frame is samp3 and samp1 itself is not drawn.
fig, panels = plt.subplots(ncols=10)
dec.eval()
plt.set_cmap('Greens')
plt.axis('off')
step = (samp3 - samp1) / 10
for k, panel in enumerate(panels, start=1):
    latent = step * k + samp1
    frame = dec(latent).detach().reshape(28, 28).cpu().numpy()
    panel.imshow(frame)
    panel.axis("off")
plt.savefig('lin_intpolate 4-5 (better).png', bbox_inches='tight')
plt.show()