In [None]:
# Imports
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [None]:
# Loading Binarized MNIST dataset
# ToTensor() produces grey levels in [0, 1]; thresholding at 0.5 makes every pixel
# exactly 0 or 1, which is what a Bernoulli (BCE-trained) NADE actually models.
BATCH_SIZE = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: (img > 0.5).float()),
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation data needs no shuffling; a fixed order keeps evaluation reproducible.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Number of mini-batches that make up one pass over the training set
n_train_batches = len(train_loader)
print(n_train_batches)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (model.to(device) would fail on a machine without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
class NADE(nn.Module):
    """Autoregressive density model over flattened images.

    Two fully connected layers are masked to enforce the autoregressive order:
    hidden unit i only receives inputs x_0..x_{i-1} (strictly lower-triangular
    input weights), and output i only receives hidden unit i (diagonal hidden
    weights). Output i is therefore a sigmoid probability for pixel i that
    depends only on the pixels before it.

    The masks are kept as non-persistent buffers and applied inside ``forward``.
    This means the constraint holds no matter what the optimizer writes into the
    masked weight entries, and the gradients of those entries are zero.
    """

    def custom_weight_input_layer(self, module):
        """Initialise ``module.weight`` from N(0, 1), keeping only the strictly lower triangle."""
        torch.nn.init.normal_(module.weight)
        # Hidden_i may only receive inputs x_j with j < i
        with torch.no_grad():
            module.weight.copy_(torch.tril(module.weight, diagonal=-1))

    def custom_weight_hidden_layer(self, module):
        """Initialise ``module.weight`` from N(0, 1), keeping only the diagonal."""
        torch.nn.init.normal_(module.weight)
        # Output_i may only be affected by Hidden_i
        with torch.no_grad():
            module.weight.copy_(torch.triu(torch.tril(module.weight)))

    def __init__(self, input_dim):
        """Build the model.

        Args:
            input_dim: number of pixels D; both layers are D -> D.
        """
        super().__init__()

        # Input -> hidden layer with "input_dim" units, strictly lower-triangular weights
        self.input_layer = nn.Linear(input_dim, input_dim)
        self.input_layer.apply(self.custom_weight_input_layer)
        # Hidden -> output layer, diagonal weights
        self.hidden_layer = nn.Linear(input_dim, input_dim)
        self.hidden_layer.apply(self.custom_weight_hidden_layer)

        ones = torch.ones(input_dim, input_dim)
        # persistent=False keeps the state_dict keys identical to the unmasked
        # version of this class, so existing checkpoints still load.
        self.register_buffer("input_mask", torch.tril(ones, diagonal=-1), persistent=False)
        self.register_buffer("hidden_mask", torch.triu(torch.tril(ones)), persistent=False)

    @staticmethod
    def _apply_mask(weight, mask):
        """Zero the masked entries of ``weight`` and, if present, of its gradient, in place."""
        with torch.no_grad():
            weight.mul_(mask)
            if weight.grad is not None:
                weight.grad.mul_(mask)

    def zero_triu_gradient(self):
        """Zero the diagonal/upper-triangular input weights and their gradients."""
        self._apply_mask(self.input_layer.weight, self.input_mask)

    def zero_off_diagonal_gradient(self):
        """Zero the off-diagonal hidden-layer weights and their gradients."""
        self._apply_mask(self.hidden_layer.weight, self.hidden_mask)

    def forward(self, x):
        """Return per-pixel probabilities with the same shape as ``x`` (..., input_dim)."""
        # Masking here (not only at init) guarantees output i never sees x_j for j >= i,
        # even after optimizer updates touch the masked weight entries.
        hidden = torch.sigmoid(torch.nn.functional.linear(
            x, self.input_layer.weight * self.input_mask, self.input_layer.bias))
        return torch.sigmoid(torch.nn.functional.linear(
            hidden, self.hidden_layer.weight * self.hidden_mask, self.hidden_layer.bias))

    def sample(self, x, output_dim, stochastic=False):
        """Generate one image pixel by pixel in autoregressive order.

        Args:
            x: value assigned to pixel 0 (the model is not queried for it).
            output_dim: number of pixels to generate; must match ``input_dim``.
            stochastic: if False (default), each pixel is set to its predicted
                probability. If True, each pixel is drawn from a Bernoulli
                distribution with that probability.

        Returns:
            Tensor of shape (1, output_dim) on the same device as the model.
        """
        param_device = next(self.parameters()).device
        # No gradients are needed; this also avoids building a graph across all steps.
        with torch.no_grad():
            sampled_image = torch.zeros(1, output_dim, device=param_device)
            sampled_image[0, 0] = x
            for pixel in range(1, output_dim):
                prob = self(sampled_image)[0, pixel]
                sampled_image[0, pixel] = torch.bernoulli(prob) if stochastic else prob
        return sampled_image
        

In [None]:
# One input per pixel of a flattened 28 x 28 image
model = NADE(input_dim=28 * 28)
model.to(device)

In [None]:
# Pixel-wise binary cross-entropy between predicted probabilities and the input image
criterion = nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
num_epochs = 100

for epoch in range(1, num_epochs + 1):
    total_loss = 0.0
    for images, _ in train_loader:

        optimizer.zero_grad()
        # Flatten each image into a vector of pixels
        images = images.view(images.shape[0], -1).to(device)
        outputs = model(images)
        # The input image is also the target: the model predicts each pixel from the ones before it
        loss = criterion(outputs, images)
        loss.backward()
        optimizer.step()
        # Re-apply the autoregressive masks AFTER the update, so the weights used
        # by the next forward pass (and by sampling) keep the required structure.
        model.zero_triu_gradient()
        model.zero_off_diagonal_gradient()

        total_loss += loss.item()

    # Each loss.item() is a per-batch mean (BCELoss default reduction), so average
    # over the number of batches, not the batch size.
    total_loss /= len(train_loader)
    print("Epoch [{}/{}], Loss {:.4f}".format(epoch, num_epochs, total_loss))
        

In [None]:
import numpy as np
# Generate one image: pixel 0 is fixed to 0.8, the remaining pixels are filled in one at a time
new_image = model.sample(x=0.8, output_dim=784)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Reshape the flat pixel vector back into a 28 x 28 image and display it
image_2d = new_image.view(28, 28)
img = image_2d.detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(1, 1))
ax.imshow(img, cmap='gray')
ax.axis('off')  # hide ticks and frame
plt.show()
  