In [1]:
#Importing Modules
import torch
import torch.nn as nn
from torchvision.utils import save_image
import torch.optim as optim
from torchvision import datasets,transforms
import torchvision
from torch.utils.data import DataLoader

In [2]:
#Creating the model
class VAEmodel(nn.Module):
    """Fully connected variational autoencoder with a 2-dimensional latent space.

    The encoder flattens each input to 784 features, passes it through a
    shared trunk and then through two separate heads that predict the mean
    and the log-variance of the latent Gaussian.  The decoder maps a latent
    sample back to a 28x28 image; its last layer is a tanh, so every output
    value lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Shared encoder trunk: 784 -> 196 -> 48.
        self.common_fc = nn.Sequential(
            nn.Linear(784, 196), nn.Tanh(),
            nn.Linear(196, 48), nn.Tanh(),
        )

        # Head predicting the latent mean: 48 -> 16 -> 2.
        self.mean_fc = nn.Sequential(
            nn.Linear(48, 16), nn.Tanh(),
            nn.Linear(16, 2),
        )

        # Head predicting the latent log-variance: 48 -> 16 -> 2.
        self.log_fc = nn.Sequential(
            nn.Linear(48, 16), nn.Tanh(),
            nn.Linear(16, 2),
        )

        # Decoder mirroring the encoder: 2 -> 16 -> 48 -> 196 -> 784.
        self.decoder_fc = nn.Sequential(
            nn.Linear(2, 16), nn.Tanh(),
            nn.Linear(16, 48), nn.Tanh(),
            nn.Linear(48, 196), nn.Tanh(),
            nn.Linear(196, 784), nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, draw a latent sample and decode it.

        Returns the tuple ``(mean, log_var, reconstruction)``.
        """
        mean, log_var = self.encode(x)
        latent = self.sample(mean, log_var)
        return mean, log_var, self.decode(latent)

    def encode(self, x):
        """Return ``(mean, log_var)`` of the latent distribution for a batch ``x``.

        Every dimension after the first is flattened before the first linear
        layer, which expects 784 features per sample.
        """
        features = self.common_fc(x.flatten(start_dim=1))
        return self.mean_fc(features), self.log_fc(features)

    def sample(self, mean, log_var):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = (log_var * 0.5).exp()  # log-variance -> standard deviation
        noise = torch.randn_like(std)
        return noise * std + mean

    def decode(self, z):
        """Map latent vectors ``z`` of shape (batch, 2) to images of shape (batch, 28, 28)."""
        flat = self.decoder_fc(z)
        return flat.reshape(z.shape[0], 28, 28)






In [3]:

def train_vae(n_epochs=8, batch_size=64, lr=1e-3, kl_weight=0.00001,
              data_root='./data', out_path="saved.png", n_preview=100,
              return_history=False):
    """Train ``VAEmodel`` on MNIST and save a grid of test reconstructions.

    Pixel values are rescaled with ``(x - 127.5) / 127.5`` so that they match
    the tanh output range of the decoder.  The loss is the mean squared
    reconstruction error plus ``kl_weight`` times the KL divergence between
    the predicted latent Gaussian and a standard normal.

    After training, ``n_preview`` random test images are passed through the
    model and written to ``out_path``.  Each tile of the saved grid shows the
    input image on the left and its (colour-inverted) reconstruction on the
    right.

    Args:
        n_epochs: Number of passes over the training set.
        batch_size: Mini-batch size used for training.
        lr: Learning rate of the Adam optimiser.
        kl_weight: Multiplier applied to the KL term of the loss.
        data_root: Directory in which MNIST is stored / downloaded.
        out_path: File name of the saved preview image.
        n_preview: Number of test images shown in the preview grid.
        return_history: When true, return ``(model, history)`` instead of
            ``None``; ``history`` is a dict with the per-step lists
            ``"recon"``, ``"kl"`` and ``"loss"``.

    Returns:
        ``None`` by default, or ``(model, history)`` if ``return_history``.
    """
    # Only the raw pixel tensors (`.data`) are read below, which bypasses any
    # torchvision transform, so none is passed; scaling is done by hand.
    mnist_tr = datasets.MNIST(root=data_root, download=True, train=True)   # training split
    mnist_ts = datasets.MNIST(root=data_root, download=True, train=False)  # test split

    # Shuffle so that every epoch sees the samples in a different order.
    train_data = DataLoader(mnist_tr.data, batch_size=batch_size, shuffle=True)
    test_data = (mnist_ts.data.float() - 127.5) / 127.5  # scale to [-1, 1]

    model = VAEmodel()
    opti = optim.Adam(model.parameters(), lr=lr)
    criteria = nn.MSELoss()

    # Per-step loss values: reconstruction (MSE), KL divergence, weighted total.
    history = {"recon": [], "kl": [], "loss": []}

    for epoch in range(n_epochs):
        for image in train_data:
            image = (image.float() - 127.5) / 127.5
            mean, log_var, out = model(image)

            # Closed-form KL(N(mean, exp(log_var)) || N(0, I)), summed over the
            # latent dimensions and averaged over the batch.
            kl_loss = torch.mean(
                0.5 * torch.sum(torch.exp(log_var) + mean ** 2 - 1 - log_var, dim=-1)
            )
            r_loss = criteria(out, image)
            total = r_loss + kl_weight * kl_loss

            opti.zero_grad()
            total.backward()
            opti.step()

            history["recon"].append(r_loss.item())
            history["kl"].append(kl_loss.item())
            history["loss"].append(total.item())

        print(f'Epoch{epoch} done ')

    # `high` is exclusive in torch.randint, so len(test_data) keeps the last
    # test image selectable.
    idxs = torch.randint(0, len(test_data), (n_preview,))
    ims = test_data[idxs]
    with torch.no_grad():  # inference only, no autograd graph needed
        _, _, gen = model(ims)

    ims = (ims + 1) / 2        # back to [0, 1] for saving
    gen = 1 - (gen + 1) / 2    # to [0, 1] and inverted, so reconstructions stand out
    print(ims.shape)

    # Put every input next to its reconstruction and add the channel axis that
    # save_image expects: (n_preview, 1, 28, 56).
    pairs = torch.cat([ims, gen], dim=-1).unsqueeze(1)
    save_image(pairs, out_path, nrow=10)

    if return_history:
        return model, history
    return None

# Run the whole pipeline: fetch MNIST, train the VAE and write the image grid to disk.
train_vae()




Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to \data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 8924362.04it/s] 


Extracting \data/MNIST/raw/train-images-idx3-ubyte.gz to \data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to \data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 2266548.67it/s]

Extracting \data/MNIST/raw/train-labels-idx1-ubyte.gz to \data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to \data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 12515343.84it/s]


Extracting \data/MNIST/raw/t10k-images-idx3-ubyte.gz to \data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to \data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5418239.13it/s]

Extracting \data/MNIST/raw/t10k-labels-idx1-ubyte.gz to \data/MNIST/raw






Epoch0 done 
Epoch1 done 
Epoch2 done 
Epoch3 done 
Epoch4 done 
Epoch5 done 
Epoch6 done 
Epoch7 done 
torch.Size([100, 28, 28])
