<a href="https://colab.research.google.com/github/WilliamAshbee/splineexample/blob/main/pixelcnn_playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [52]:
'''
Code by Hrituraj Singh
Indian Institute of Technology Roorkee
'''

from torchvision import datasets, transforms
import configparser
import os

def get_MNIST(path):
    """
    Loads the train and test MNIST data and returns both after 
    trainsforming the images to tensors. Downloads the data if 
    not on local path
    """
    assert os.path.exists(path), 'The dataloading path does not exist!'
    train_data = datasets.MNIST(root=path,
                                train=True,
                                download=True,
                                transform=transforms.ToTensor())
    test_data = datasets.MNIST(root=path,
                               train=False,
                               download=True,
                               transform=transforms.ToTensor())
    return train_data, test_data


def parse_config(filename):
    config = configparser.ConfigParser()
    config.read(filename)
    output = {}
    for section in config.sections():
        output[section] = {}
        for key in config[section]:
            val_str = str(config[section][key])
            if(len(val_str)>0):
                val = parse_value_from_string(val_str) 
            else:
                val = None
            print(section, key,val_str, val)
            output[section][key] = val
    return output



def parse_value_from_string(val_str):
    if(is_int(val_str)):
        val = int(val_str)
    elif(is_float(val_str)):
        val = float(val_str)
    elif(is_list(val_str)):
        val = parse_list(val_str)
    elif(is_bool(val_str)):
        val = parse_bool(val_str)
    else:
        val = val_str
    return val

def is_int(val_str):
    start_digit = 0
    if(val_str[0] =='-'):
        start_digit = 1
    flag = True
    for i in range(start_digit, len(val_str)):
        if(str(val_str[i]) < '0' or str(val_str[i]) > '9'):
            flag = False
            break
    return flag

def is_float(val_str):
    flag = False
    if('.' in val_str and len(val_str.split('.'))==2):
        if(is_int(val_str.split('.')[0]) and is_int(val_str.split('.')[1])):
            flag = True
        else:
            flag = False
    elif('e' in val_str and len(val_str.split('e'))==2):
        if(is_int(val_str.split('e')[0]) and is_int(val_str.split('e')[1])):
            flag = True
        else:
            flag = False       
    else:
        flag = False
    return flag 

def is_bool(var_str):
    if( var_str=='True' or var_str == 'true' or var_str =='False' or var_str=='false'):
        return True
    else:
        return False
    
def parse_bool(var_str):
    if(var_str=='True' or var_str == 'true' ):
        return True
    else:
        return False
     
def is_list(val_str):
    if(val_str[0] == '[' and val_str[-1] == ']'):
        return True
    else:
        return False
    
def parse_list(val_str):
    sub_str = val_str[1:-1]
    splits = sub_str.split(',')
    output = []
    for item in splits:
        item = item.strip()
        if(is_int(item)):
            output.append(int(item))
        elif(is_float(item)):
            output.append(float(item))
        elif(is_bool(item)):
            output.append(parse_bool(item))
        else:
            output.append(item)
    return output


In [53]:
'''
https://github.com/singh-hrituraj/PixelCNN-Pytorch/blob/master/MaskedCNN.py
Code by Hrituraj Singh
Indian Institute of Technology Roorkee
'''
from torch import nn



class MaskedCNN(nn.Conv2d):
	"""
	Implementation of Masked CNN Class as explained in A Oord et. al. 
	Taken from https://github.com/jzbontar/pixelcnn-pytorch
	"""

	def __init__(self, mask_type, *args, **kwargs):
		self.mask_type = mask_type
		assert mask_type in ['A', 'B'], "Unknown Mask Type"
		super(MaskedCNN, self).__init__(*args, **kwargs)
		self.register_buffer('mask', self.weight.data.clone())

		_, depth, height, width = self.weight.size()
		self.mask.fill_(1)
		if mask_type =='A':
			self.mask[:,:,height//2,width//2:] = 0
			self.mask[:,:,height//2+1:,:] = 0
		else:
			self.mask[:,:,height//2,width//2+1:] = 0
			self.mask[:,:,height//2+1:,:] = 0


	def forward(self, x):
		self.weight.data*=self.mask
		return super(MaskedCNN, self).forward(x)





#from MaskedCNN import MaskedCNN
import torch.nn as nn

class PixelCNN(nn.Module):
	"""
	Network of PixelCNN as described in A Oord et. al. 
	"""
	def __init__(self, no_layers=8, kernel = 7, channels=64, device=None):
		super(PixelCNN, self).__init__()
		self.no_layers = no_layers
		self.kernel = kernel
		self.channels = channels
		self.layers = {}
		self.device = device

		self.Conv2d_1 = MaskedCNN('A',1,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_1 = nn.BatchNorm2d(channels)
		self.ReLU_1= nn.ReLU(True)

		self.Conv2d_2 = MaskedCNN('B',channels,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_2 = nn.BatchNorm2d(channels)
		self.ReLU_2= nn.ReLU(True)

		self.Conv2d_3 = MaskedCNN('B',channels,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_3 = nn.BatchNorm2d(channels)
		self.ReLU_3= nn.ReLU(True)

		self.Conv2d_4 = MaskedCNN('B',channels,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_4 = nn.BatchNorm2d(channels)
		self.ReLU_4= nn.ReLU(True)

		self.Conv2d_5 = MaskedCNN('B',channels,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_5 = nn.BatchNorm2d(channels)
		self.ReLU_5= nn.ReLU(True)

		self.Conv2d_6 = MaskedCNN('B',channels,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_6 = nn.BatchNorm2d(channels)
		self.ReLU_6= nn.ReLU(True)

		self.Conv2d_7 = MaskedCNN('B',channels,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_7 = nn.BatchNorm2d(channels)
		self.ReLU_7= nn.ReLU(True)

		self.Conv2d_8 = MaskedCNN('B',channels,channels, kernel, 1, kernel//2, bias=False)
		self.BatchNorm2d_8 = nn.BatchNorm2d(channels)
		self.ReLU_8= nn.ReLU(True)

		self.out = nn.Conv2d(channels, 256, 1)

	def forward(self, x):
		x = self.Conv2d_1(x)
		x = self.BatchNorm2d_1(x)
		x = self.ReLU_1(x)

		x = self.Conv2d_2(x)
		x = self.BatchNorm2d_2(x)
		x = self.ReLU_2(x)

		x = self.Conv2d_3(x)
		x = self.BatchNorm2d_3(x)
		x = self.ReLU_3(x)

		x = self.Conv2d_4(x)
		x = self.BatchNorm2d_4(x)
		x = self.ReLU_4(x)

		x = self.Conv2d_5(x)
		x = self.BatchNorm2d_5(x)
		x = self.ReLU_5(x)

		x = self.Conv2d_6(x)
		x = self.BatchNorm2d_6(x)
		x = self.ReLU_6(x)

		x = self.Conv2d_7(x)
		x = self.BatchNorm2d_7(x)
		x = self.ReLU_7(x)

		x = self.Conv2d_8(x)
		x = self.BatchNorm2d_8(x)
		x = self.ReLU_8(x)

		return self.out(x)



In [54]:
model = PixelCNN()

In [55]:
import torch
a = torch.zeros(64,1,32,32)

In [56]:
b = model(a)
b.shape

torch.Size([64, 256, 32, 32])

In [65]:
'''
Code by Hrituraj Singh
Indian Institute of Technology Roorkee
'''

import sys
import os
import time
import torch
from torch import optim
from torch.utils import data
from torch.autograd import Variable
import torch.nn as nn
#from utils import *
#from Model import PixelCNN


def main(config_file):
	config = parse_config(config_file)
	data_ = config['data']
	network = config['network']

	path = data_.get('path', 'Data') #Path where the data after loading is to be saved
	data_name = data_.get('data_name','MNIST') #What data type is to be loaded ex - MNIST, CIFAR
	batch_size = data_.get('batch_size', 144)

	layers = network.get('no_layers', 8) #Number of layers in the network
	kernel = network.get('kernel', 7) #Kernel size
	channels = network.get('channels', 64) #Depth of the intermediate layers
	epochs = network.get('epochs', 25) #No of epochs
	save_path = network.get('save_path', 'Models') #path where the models are to be saved
	#Loading Data
	if (data_name=='MNIST'):
		train, test = get_MNIST(path)

	train = data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers =1, pin_memory = True)
	test = data.DataLoader(train, batch_size=batch_size, shuffle=False, num_workers =1, pin_memory = True)
	#Defining the model and training it on loss function
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	net = PixelCNN().to(device)
	if torch.cuda.device_count() > 1: # If more than one GPU available, accelerate the training using multiple GPUs
		print("Let's use", torch.cuda.device_count(), "GPUs!")
		net = nn.DataParallel(net)
	
	optimizer = optim.Adam(net.parameters())
	criterion = nn.CrossEntropyLoss()

	loss_overall = []
	time_start = time.time()
	print('Training Started')

	for i in range(epochs):
		net.train(True)
		step = 0
		loss_= 0
		
		for images, labels in train:
			target = Variable(images[:,0,:,:]*255).long()
			images = images.to(device)
			target = target.to(device)

			optimizer.zero_grad()
			print('image.shape',images.shape)
			output = net(images)
			print('output.shape',output.shape)
			print('target.shape',target.shape)
			print(torch.sum(target))
			loss = criterion(output, target)
			loss.backward()
			optimizer.step()
			loss_+=loss
			step+=1

			if(step%100 == 0):
				print('Epoch:'+str(i)+'\t'+ str(step) +'\t Iterations Complete \t'+'loss: ', loss.item()/1000.0)
				loss_overall.append(loss_/1000.0)
				loss_=0
		print('Epoch: '+str(i)+' Over!')

		#Saving the model
		if not os.path.exists(save_path):
			os.makedirs(save_path)
		print("Saving Checkpoint!")
		if(i==epochs-1):
			torch.save(net.state_dict(), save_path+'/Model_Checkpoint_'+'Last'+'.pt')
		else:
			torch.save(net.state_dict(), save_path+'/Model_Checkpoint_'+str(i)+'.pt')
		print('Checkpoint Saved')
  

#print('Training Finished! Time Taken: ', time.time()-time_start)


#if __name__=="__main__":
config_file = 'config/config_train.txt'
assert os.path.exists(config_file), "Configuration file does not exit!"
main(config_file)




data path Data Data
network no_layers 8 8
network kernel 7 7
network channels 64 64
Training Started
image.shape torch.Size([144, 1, 28, 28])
output.shape torch.Size([144, 256, 28, 28])
target.shape torch.Size([144, 28, 28])
tensor(3940184)
image.shape torch.Size([144, 1, 28, 28])
output.shape torch.Size([144, 256, 28, 28])
target.shape torch.Size([144, 28, 28])
tensor(3821921)
image.shape torch.Size([144, 1, 28, 28])
output.shape torch.Size([144, 256, 28, 28])
target.shape torch.Size([144, 28, 28])
tensor(3788349)
image.shape torch.Size([144, 1, 28, 28])
output.shape torch.Size([144, 256, 28, 28])
target.shape torch.Size([144, 28, 28])
tensor(3888760)
image.shape torch.Size([144, 1, 28, 28])


KeyboardInterrupt: ignored

In [64]:
print((144*256*28*28-3936197)/(144*256*28*28))

0.8638058963005775
