In [None]:
# Colab-specific setup: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Imports

In [None]:
from __future__ import division
import torch as T
import torch.functional as F
import math
import PIL
import numpy as np
import pandas  as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from torch.nn import  Sequential , Linear , ReLU , PoissonNLLLoss, LSTM
from torch.autograd import Variable
from torch.distributions import Normal, MultivariateNormal, Poisson, kl_divergence 
from PIL import Image,ImageFont, ImageDraw
plt.style.use('dark_background')
# Create new tensors on the GPU by default, but only when a GPU is actually
# present, so the notebook still runs on a CPU-only runtime.
if T.cuda.is_available():
    T.set_default_tensor_type('torch.cuda.FloatTensor')

# Paths

In [None]:
# Every path below is built by prefixing `root`; all of them are left empty here.
root = ""
# Passed to LayoutVAE.load_data in the training section.
DATA_PATH           = root  + ""
# Prefix handed to LayoutVAE.save_model.
SAVE_MODEL_PATH     = root  + ""
# Prefix for the training-history plots and for LayoutVAE.save_history.
SAVE_LOG_PATH       = root  + ""
# Prefix for the layout images written by plot_layouts.
SAVE_OUTPUT_PATH    = root  + ""
# Only referenced from commented-out code inside LayoutVAE.train.
CVAE_PATH           = root  + ""

# Model Architectures

In [None]:
class fcblock(T.nn.Module):
    """Two stacked fully connected layers (n_class -> 128 -> 128), each followed by ReLU."""

    def __init__(self, n_class):
        super().__init__()
        width = 128
        stages = [
            Linear(n_class, width),
            ReLU(),
            Linear(width, width),
            ReLU(),
        ]
        # Kept under the attribute name `seq` so state_dict keys stay the same.
        self.seq = Sequential(*stages)

    def forward(self, inputs):
        """Run `inputs` through the stack; the last dimension must be `n_class`."""
        return self.seq(inputs)

class Embeder(T.nn.Module):
    """
    Embeds three `n_class`-wide inputs into one 128-dimensional vector.

    Each input goes through its own `fcblock`; the three 128-wide results are
    concatenated along dimension 1 and projected back to 128 features.
    """

    def __init__(self, n_class):
        super().__init__()
        # One independent branch per input.
        self.fcb1 = fcblock(n_class)
        self.fcb2 = fcblock(n_class)
        self.fcb3 = fcblock(n_class)
        self.fc = Linear(3 * 128, 128)

    def forward(self, inputs):
        """`inputs` is a sequence of three tensors, consumed in order by fcb1..fcb3."""
        first, second, third = inputs
        branches = (self.fcb1(first), self.fcb2(second), self.fcb3(third))
        return self.fc(T.cat(branches, 1))

class Encoder(T.nn.Module):
    """
    Maps an observation and a 128-wide conditioning vector to the mean and
    log-variance of a `latent_dim`-dimensional Gaussian.
    """

    def __init__(self, in_dim=1, latent_dim=32):
        super().__init__()
        self.act = ReLU()
        self.fc1 = Linear(in_dim, 128)
        self.fc2 = Linear(128, 128)
        # 256 = 128 observation features + 128 conditioning features.
        self.fc3 = Linear(256, latent_dim)
        self.fc4 = Linear(latent_dim, latent_dim)   # mean head
        self.fc5 = Linear(latent_dim, latent_dim)   # log-variance head

    def forward(self, inputs):
        """`inputs` = (observation, conditioning); returns `(mu, logvar)`."""
        observed, condition = inputs
        features = self.fc2(self.act(self.fc1(observed)))
        joint = T.cat((features, condition), 1)
        hidden = self.act(self.fc3(joint))
        return self.fc4(hidden), self.fc5(hidden)


class Prior(T.nn.Module):
    """
    Conditional prior: maps a 128-wide conditioning vector to the mean and
    log-variance of a `latent_dim`-dimensional Gaussian.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        self.act = ReLU()
        self.fc1 = Linear(128, latent_dim)
        self.fc2 = Linear(latent_dim, latent_dim)   # mean head
        self.fc3 = Linear(latent_dim, latent_dim)   # log-variance head

    def forward(self, inputs):
        """Returns `(mu, logvar)` computed from the conditioning vector."""
        hidden = self.act(self.fc1(inputs))
        return self.fc2(hidden), self.fc3(hidden)

class Decoder(T.nn.Module):
    """
    Decodes the concatenation of two inputs (128 + `latent_dim` features in
    total) into `output_dim` values through 128 -> 64 -> output_dim layers.
    """

    def __init__(self, output_dim, latent_dim=32):
        super().__init__()
        self.act = ReLU()
        self.fc1 = Linear(128 + latent_dim, 128)
        self.fc2 = Linear(128, 64)
        self.fc3 = Linear(64, output_dim)

    def forward(self, inputs):
        """`inputs` is a pair of tensors joined along dim 1; no activation on the output."""
        left, right = inputs
        hidden = self.act(self.fc1(T.cat((left, right), 1)))
        hidden = self.act(self.fc2(hidden))
        return self.fc3(hidden)

# Loss Function for CountVAE

In [None]:
class ELBOLoss(T.nn.Module):
    """KL divergence between two diagonal Gaussians plus a Poisson negative log likelihood."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """
        `inputs` = (mu1, logvar1, mu2, logvar2, in1, in2).

        in1 is handed to PoissonNLLLoss as the prediction and in2 (after the
        shift below) as the target. Returns `(loss, pnll, kl)`.
        """
        mu1, logvar1, mu2, logvar2, in1, in2 = inputs

        # Every strictly positive target is lowered by one; zeros stay zero.
        # NOTE(review): presumably pairs with CountVAE adding the label-presence
        # flag back onto the sampled count -- confirm.
        positive = T.gt(in2, 0) + 0.0
        target = in2 - positive

        # KL divergence of N(mu1, exp(logvar1)) from N(mu2, exp(logvar2)),
        # summed over dim 1 and averaged over the remaining entries.
        log_ratio = logvar2 - logvar1
        scaled = (logvar1.exp() + (mu2 - mu1).pow(2)) / logvar2.exp()
        kl = 0.5 * T.sum(log_ratio - 1 + scaled, dim=1).mean()

        # Poisson negative log likelihood with the module's default settings.
        criterion = PoissonNLLLoss()
        pnll = criterion(in1, target)

        return kl + pnll, pnll, kl
 
 
class Reparamatrize(T.nn.Module):
    """Draws a reparameterised sample from a diagonal Gaussian N(mu, diag(exp(logvar)))."""

    def __init__(self):
        super(Reparamatrize, self).__init__()

    def forward(self, inputs):
        """
        inputs : pair `(mu, logvar)`
            mu     = mean
            logvar = log of the diagonal elements of the covariance matrix

        Returns a float tensor shaped like `mu`.
        """
        mu, logvar = inputs

        # exp(logvar / 2) is the standard deviation. It is passed as the
        # Cholesky factor (`scale_tril`), not as the covariance itself;
        # passing it as the covariance would give samples a variance of
        # `std` instead of `std ** 2`.
        std = T.exp(logvar / 2)
        scale_tril = T.diag_embed(std, dim1=-2, dim2=-1)

        # rsample keeps the draw differentiable with respect to mu and logvar.
        p = MultivariateNormal(mu, scale_tril=scale_tril)
        z_latent = p.rsample().float()
        return z_latent

class Sampling(T.nn.Module):
    """Returns the most probable Poisson count in [0, MAX_BOX) for each given rate."""

    def __init__(self, MAX_BOX):
        super(Sampling, self).__init__()
        self.max_box = MAX_BOX

    def forward(self, lamda):
        """
        lamda : tensor of Poisson rates; it is flattened, one rate per element.

        Returns an int64 tensor with one entry per rate: the count x in
        0 .. max_box-1 that maximises P(x) = lamda^x * e^(-lamda) / x!.
        """
        rate = lamda.reshape(-1, 1)
        if not rate.is_floating_point():
            rate = rate.float()
        counts = T.arange(self.max_box, dtype=rate.dtype, device=rate.device)

        # Work in log space: log P = x*log(lamda) - lamda - log(x!).
        # This uses the correct sign on lamda and avoids exp()/pow overflow to
        # inf for large rates, which would make the argmax meaningless.
        # xlogy returns 0 when x == 0, so a zero rate still maps to count 0.
        log_pmf = T.xlogy(counts, rate) - rate - T.lgamma(counts + 1)
        sample = T.argmax(log_pmf, dim=1)

        return sample

In [None]:
class CountVAE(T.nn.Module):
    """
    Conditional VAE over per-class counts.

    Classes are visited one at a time; each step is conditioned on the full
    label set, on the label currently being visited and on the counts
    accumulated for the classes visited so far.
    """

    def __init__(self, n_class, max_box=9):
        super().__init__()

        # Sub-modules are created in this order on purpose: it fixes the order
        # in which their weights are randomly initialised.
        self.encoder = Encoder()
        self.prior   = Prior()
        self.decoder = Decoder(1)
        self.embeder = Embeder(n_class)
        self.loss    = ELBOLoss()
        self.rep     = Reparamatrize()
        self.n_class = n_class
        self.pois    = Sampling(max_box)   # only referenced from commented-out code

    def forward(self, inputs, isTrain = False):
        '''
        isTrain(boolean) default False : defines whether data is to be treated as training data or testing

        if isTrain = True :
            inputs must be a pair: the label set first, the ground-truth counts second.
            Returns (loss, kl, poisson_nll), each averaged over the classes.
        else :
            inputs is the label set alone.
            Returns the accumulated counts, shaped like the label set.
        '''
        if isTrain == True:
            return self._training_pass(inputs)
        return self._sampling_pass(inputs)

    def _training_pass(self, inputs):
        """One pass over all classes with the encoder in the loop; returns the averaged losses."""
        label_set, groundtruth_counts = inputs
        total_loss = 0
        total_nll = 0
        total_kl = 0
        running_counts = T.zeros_like(label_set)

        for idx in range(self.n_class):
            # Label set with every column except `idx` zeroed out.
            only_current = T.zeros_like(running_counts)
            present = label_set[..., idx]
            only_current[..., idx] = present
            target = groundtruth_counts[..., idx].view(-1, 1)

            # Conditional embedding shared by encoder, prior and decoder.
            condition = self.embeder([label_set, only_current, running_counts])

            # Posterior and prior parameters.
            post_mu, post_logvar = self.encoder([target, condition])
            prior_mu, prior_logvar = self.prior(condition)

            # Latent sample drawn from the posterior.
            latent = self.rep([post_mu, post_logvar])

            # The decoder output is exponentiated below to obtain the Poisson rate.
            log_rate = self.decoder([condition, latent])
            step_loss, step_nll, step_kl = self.loss(
                [post_mu, post_logvar, prior_mu, prior_logvar, log_rate, target])

            total_loss = total_loss + step_loss
            total_nll = total_nll + step_nll
            total_kl = total_kl + step_kl

            # q = self.pois(decoded)
            drawn = Poisson(T.exp(log_rate)).sample()

            # Only column `idx` changes: it receives the sampled value plus the
            # presence flag.
            increment = drawn.view(-1, 1) + present.view(-1, 1)
            running_counts = running_counts + only_current * increment

        return total_loss / self.n_class, total_kl / self.n_class, total_nll / self.n_class

    def _sampling_pass(self, inputs):
        """One pass over all classes using only the prior; returns the accumulated counts."""
        label_set = inputs
        running_counts = T.zeros_like(label_set)

        for idx in range(self.n_class):
            only_current = T.zeros_like(running_counts)
            present = label_set[..., idx]
            only_current[..., idx] = present

            condition = self.embeder([label_set, only_current, running_counts])

            # No encoder here: the latent comes from the conditional prior.
            prior_mu, prior_logvar = self.prior(condition)
            latent = self.rep([prior_mu, prior_logvar])

            rate = T.exp(self.decoder([condition, latent]))

            # q = self.pois(decoded)
            drawn = Poisson(rate).sample()

            increment = drawn.view(-1, 1) + present.view(-1, 1)
            running_counts = running_counts + only_current * increment

        return running_counts


# BboxVAE Model Architecture

### Classes
1. Conditional Embedder
2. Encoder
3. Prior
4. Decoder

### Loss
1. ELBO LOSS

### Reparamatrize

In [None]:
class EmbedBbox(T.nn.Module):
    """
    Conditional embedding for BboxVAE.

    Combines two `n_class`-wide vectors with the final hidden state of an LSTM
    run over a sequence of (n_class + 4)-wide entries, and projects the result
    to 128 features.
    """

    def __init__(self, n_class):
        super().__init__()

        self.fcb1 = fcblock(n_class)
        self.fcb2 = fcblock(n_class)
        self.seq1 = Sequential(Linear(128, 128), ReLU())

        self.n_class = n_class
        self.fc = Linear(3 * 128, 128)
        # Each sequence step carries n_class label entries followed by 4 more values.
        self.lstm = LSTM(n_class + 4, hidden_size=128)

    def forward(self, inputs):
        """`inputs` = (first vector, second vector, sequence fed to the LSTM)."""
        first, second, sequence = inputs

        # Only the final hidden state of the LSTM is used.
        _, (last_hidden, _cell) = self.lstm(sequence)
        summary = last_hidden.view(-1, 128)

        branches = (self.fcb1(first), self.fcb2(second), self.seq1(summary))
        return self.fc(T.cat(branches, 1))

# Loss Function and Reparameterization for BboxVAE

## KL Divergence
 Same as CountVAE

## MSE
 as reconstruction loss

In [None]:
class ELBOLoss_Bbox(T.nn.Module):
    """KL divergence between two diagonal Gaussians plus a mean-squared-error term."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """
        `inputs` = (mu1, logvar1, mu2, logvar2, xp, yp).

        Returns `(loss, kl, mse)` -- note the order differs from ELBOLoss.
        """
        mu1, logvar1, mu2, logvar2, xp, yp = inputs

        # KL divergence of N(mu1, exp(logvar1)) from N(mu2, exp(logvar2)),
        # summed over the last dimension and averaged over the rest.
        log_ratio = logvar2 - logvar1
        scaled = (logvar1.exp() + (mu2 - mu1).pow(2)) / logvar2.exp()
        kl = 0.5 * T.sum(log_ratio - 1 + scaled, dim=-1).mean()

        # Reconstruction term between xp and yp.
        criterion = T.nn.MSELoss()
        mse = criterion(xp, yp)

        return mse + kl, kl, mse


class Reparamatrize(T.nn.Module):
    """Reparameterised sample z = mu + std * eps with eps drawn from N(0, I)."""

    def __init__(self):
        super(Reparamatrize, self).__init__()

    def forward(self, inputs):
        """
        inputs : pair `(mu, logvar)` of equally shaped tensors, where logvar
                 is the log of the per-element variance.

        Returns a tensor shaped like `mu`.
        """
        mu, logvar = inputs
        std = T.exp(logvar / 2)
        # The noise must be standard normal. `rand_like` draws from U[0, 1),
        # which shifts every sample by +0.5*std and shrinks its spread.
        eps = T.randn_like(std)

        return eps * std + mu
        

class ReparamatrizeMulti(T.nn.Module):
    """Adds zero-mean Gaussian noise with a fixed standard deviation to a mean tensor."""

    def __init__(self, std=0.02):
        """std : fixed standard deviation of the added noise (default 0.02)."""
        super(ReparamatrizeMulti, self).__init__()
        self.std = std

    def forward(self, inputs):
        """inputs : the mean tensor. Returns `mu + std * eps` with eps ~ N(0, I)."""
        mu  = inputs
        # Standard-normal noise keeps the sample centred on `mu`; uniform
        # noise from `rand_like` would bias every output upwards.
        eps = T.randn_like(mu)

        return eps * self.std + mu

In [None]:
class BboxVAE(T.nn.Module):
    """
    Conditional VAE that produces `max_box` boxes one after another.

    Every step is conditioned on the per-class box counts, the label of the
    current box and the sequence of (label, box) pairs from earlier steps.
    """

    def __init__(self,n_class,n_dim,max_box,latent_dim=32):
        """
        n_class    : width of the one-hot box labels
        n_dim      : number of values per box
        max_box    : number of boxes handled per sample
        latent_dim : size of the latent variable
        """
        super(BboxVAE,self).__init__()

        self.embeder   = EmbedBbox(n_class)
        self.encoder = Encoder(n_dim,latent_dim=latent_dim)
        self.decoder = Decoder(n_dim,latent_dim=latent_dim)
        self.prior   = Prior(latent_dim=latent_dim)
        self.loss    = ELBOLoss_Bbox()
        self.rep     = Reparamatrize()
        self.n_dim   = n_dim
        self.n_class = n_class
        self.rep_mul = ReparamatrizeMulti()
        self.max_box = max_box

    def forward(self,inputs,isTrain=True):
        """
        isTrain == True :
            inputs = (BoxCounts, GTBBox, BoxLabel); the code indexes GTBBox and
            BoxLabel as [..., box_index, :]. Returns (loss, kl, mse), each
            averaged over the `max_box` steps.
        otherwise :
            inputs = (BoxCounts, BoxLabel). Returns the stacked boxes; each
            entry is the transposed box tensor of one step, so the result is
            (max_box, n_dim, batch) for a 2-D decoder output.
        """
        if isTrain==True :
            BoxCounts, GTBBox , BoxLabel= inputs
            los = 0
            kl1 = 0
            ll1 = 0

            # The history starts with one all-zero step, allocated with the
            # dtype and device of the inputs rather than the global default.
            PrevLabel = BoxLabel.new_zeros((1 , *BoxLabel[... ,0,:].shape))
            PrevBox   = GTBBox.new_zeros((1 , *GTBBox[...,0,:].shape))

            for i in range(self.max_box):
                GroundTruth  = GTBBox[... , i ,:].reshape(-1,self.n_dim)
                CurrentLabel = BoxLabel[... , i ,:].reshape(-1,self.n_class)

                Embedding = self.embeder([BoxCounts,CurrentLabel,T.cat([PrevLabel,PrevBox] , dim = 2)])

                mu1 , logvar1 = self.encoder([GroundTruth,Embedding])
                mu2 , logvar2 = self.prior(Embedding)

                # Only the posterior sample is decoded during training; the
                # prior enters the loss through its parameters alone, so no
                # sample is drawn from it.
                z1  = self.rep([mu1,logvar1])

                Mu     = self.decoder([Embedding,z1])
                BBox   = self.rep_mul(Mu)
                CLoss, kl_tot , ll_tot = self.loss([mu1,logvar1,mu2,logvar2, BBox , GroundTruth])

                los = los + CLoss/self.max_box
                kl1 = kl1 + kl_tot/self.max_box
                ll1 = ll1 + ll_tot/self.max_box

                # The history is extended with the ground-truth box, not the
                # generated one.
                PrevBox   = T.cat([PrevBox ,T.unsqueeze(GroundTruth,0)])
                PrevLabel = T.cat([PrevLabel , T.unsqueeze(CurrentLabel,0)])

            return los , kl1 , ll1
        else:
            BoxCounts, BoxLabel= inputs
            BBoxes = []

            PrevLabel = BoxLabel.new_zeros((1 , *BoxLabel[... ,0,:].shape))
            # Use n_dim (not a literal 4) so the history matches the decoder output.
            PrevBox   = BoxLabel.new_zeros((1 , BoxLabel.shape[0] , self.n_dim))

            for i in range(self.max_box):
                CurrentLabel = BoxLabel[... , i ,:].reshape(-1,self.n_class)
                Embedding = self.embeder([BoxCounts,CurrentLabel,T.cat([PrevLabel,PrevBox] , dim = 2)])

                # No encoder at inference: the latent comes from the prior.
                mu , logvar = self.prior(Embedding)
                z  = self.rep([mu,logvar])

                Mu    = self.decoder([Embedding,z])
                BBox  = self.rep_mul(Mu)

                # Here the history is extended with the generated box.
                PrevBox   = T.cat([PrevBox ,T.unsqueeze(BBox,0)])
                PrevLabel = T.cat([PrevLabel , T.unsqueeze(CurrentLabel,0)])
                BBoxes.append(BBox.t())
            BBoxes =T.stack(BBoxes)
            return BBoxes

# Layout VAE

In [None]:
class LayoutVAE(T.nn.Module):
    '''
    ** Layout VAE **
    * https://arxiv.org/abs/1907.10719

    Couples a CountVAE (label set -> per-class counts) with a BboxVAE
    (counts + per-box labels -> boxes) and bundles data loading, training,
    prediction, plotting and persistence helpers around them.
    '''

    def __init__(self, n_class = 6, max_box = 9,bboxvae_latent_dim = 32,bboxvae_lr=1e-4,countvae_lr=1e-6):
        '''
        n_class            : number of classes (width of the label set)
        max_box            : number of boxes per layout
        bboxvae_latent_dim : latent size handed to BboxVAE
        bboxvae_lr         : learning rate of the default BboxVAE optimiser
        countvae_lr        : learning rate of the default CountVAE optimiser
        '''
        super(LayoutVAE,self).__init__()

        self.max_box    = max_box
        self.n_class    = n_class
        self.lr_bvae    = bboxvae_lr
        self.lr_cvae    = countvae_lr
        self.countvae   = CountVAE(n_class, max_box)
        self.bboxvae    = BboxVAE(n_class,4,max_box,bboxvae_latent_dim)
        self.is_cvae_trained = 0
        self.is_bvae_trained = 0

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_cxywh(class_labels, boxes):
        '''Prefix every box with its class index (argmax of the label along dim 2).'''
        # argmax yields integers; cast so the concatenation has a single dtype.
        class_info = T.unsqueeze(T.argmax(class_labels, dim=2), dim=2).to(boxes.dtype)
        return T.cat([class_info, boxes], dim=2)

    @staticmethod
    def _blank_class_zero(cxywh):
        '''Zero, in place, every row whose class index is 0 and return the tensor.'''
        cxywh[cxywh[..., 0] == 0] = 0
        return cxywh

    # --------------------------------------------------------------- full model

    def forward(self,input):
        '''
        Takes only the label set as input.
        Label set : one row per sample with n_class entries, holding 1 where
        the corresponding class is present.

        Returns a tensor whose rows are (class index, box values...); rows of
        class 0 are zeroed. The result and the normalised counts are also
        stored on `self.predictions` and `self.pred_class_counts`.
        '''
        if self.is_cvae_trained == 0:
            print("[Warning] Count VAE is Not Trained !!")

        if self.is_bvae_trained == 0:
            print("[Warning] Bbox VAE is Not Trained !!")

        label_set   = input
        pred_class_counts = self.countvae(label_set , isTrain=False)

        # Rescale the counts so every row sums to at most max_box. The clamp
        # protects rows whose counts are all zero from a 0/0 division.
        totals = T.sum(pred_class_counts , dim = 1 ).view(-1,1).clamp(min=1)
        pred_class_counts = T.floor ( self.max_box*(pred_class_counts / totals) )

        # Boxes lost to the rounding are ADDED to the first class so that each
        # row sums to exactly max_box (assigning instead of adding would drop
        # whatever the first class already held).
        leftover = (self.max_box - T.sum(pred_class_counts , dim = 1)).clamp(min=0)
        pred_class_counts[:, 0] += leftover

        # One one-hot label per box, filled from the highest class index down.
        class_labels = pred_class_counts.new_zeros((len(label_set) , self.max_box, self.n_class))
        for i, row in enumerate(pred_class_counts.long().tolist()):
            slot = 0
            for cls in reversed(range(self.n_class)):
                class_labels[i, slot:slot + row[cls], cls] = 1
                slot += row[cls]

        pred_box = self.bboxvae([ pred_class_counts, class_labels], isTrain=False)
        pred_box = pred_box.permute(2,0,1)
        predictions = self._blank_class_zero(self._to_cxywh(class_labels, pred_box))

        self.predictions  = predictions
        self.pred_class_counts = pred_class_counts

        return predictions

    # --------------------------------------------------------------------- data

    def load_data(self, path, frac = 0.5, train_test_split = 0.1):
        '''
        Loads data from an npy file.

        path             : path of the .npy file
        frac             : fraction of the (shuffled) samples to keep
        train_test_split : fraction of the kept samples set aside as test data

        The last axis is split as [..., 0:4] -> boxes and [..., 4:] -> labels.
        On success the training tensors (class_labels, class_counts, b_boxes,
        label_set) and their test_* counterparts are stored on the instance.
        On failure a message is printed and nothing is raised.
        '''
        try :
            # Read from the `path` argument, not from a module-level constant.
            Data = np.load(path)

            # Reorder the entries of every sample by their first column, descending.
            order = np.argsort(Data[:,:,0])[:, ::-1]
            Data  = np.take_along_axis(Data, order[:, :, None], axis=1)
            np.random.shuffle(Data)

            data_size = int(frac*len(Data))
            test_size = int(train_test_split*data_size)
            Data      = T.tensor(Data[0:data_size]).float()
            test_data = Data[0:test_size]
            Data      = Data[test_size:]

            # Prepare Data
            self.class_labels = Data[...,4:]
            self.class_counts = T.sum(Data[...,4:], dim = 1)
            self.b_boxes      = Data[...,0:4]
            self.label_set    = (self.class_counts !=0) + 0.0

            # Test Data
            self.test_class_labels = test_data[...,4:]
            self.test_class_counts = T.sum(test_data[...,4:], dim = 1)
            self.test_b_boxes      = test_data[...,0:4]
            self.test_label_set    = (self.test_class_counts !=0) + 0.0

            print("[Success] Data Loaded Succesfully")

        except Exception as err:
            # Still best effort, but say what actually went wrong.
            print("[Failed] Data Loading Failed\n please check path")
            print(repr(err))

    # ----------------------------------------------------------------- training

    def train(self, optim = True, train_mode = 'bboxvae', epochs = 100, bsize = 256 , validation_split = 0.1):
        '''
        * optim : optimiser that updates the selected sub-model.
            This method shadows torch.nn.Module.train; when `optim` is a bool
            (as in `.train()`, `.train(False)` or `.eval()`) the call is
            forwarded to the base class so the standard API keeps working.
        * train_mode (str , default bboxvae) : Two options
            1. if train_mode is bboxvae, the BboxVAE model is trained
            2. if train_mode is countvae, the CountVAE model is trained
        * epochs (int , default 100 ) : number of epochs training should run
        * bsize(int default 256) : batch size; a trailing partial batch is skipped
        * validation_split(float default 0.1) : should be between 0 and 1,
            fraction of the loaded training data used for validation

        Returns a DataFrame with one row per batch, preceded by one
        placeholder row whose epoch is -1.
        '''
        if isinstance(optim, bool):
            return super(LayoutVAE, self).train(optim)

        val_size = int(len(self.class_counts)*validation_split)

        if train_mode == 'countvae':
            model = self.countvae
            tensors = [self.label_set, self.class_counts]
            second_loss = 'poisson_nll'
        else :
            model = self.bboxvae
            tensors = [self.class_counts, self.b_boxes, self.class_labels]
            second_loss = 'mse'

        # The first val_size samples validate, the rest train.
        val_data   = [x[:val_size] for x in tensors]
        train_data = [x[val_size:] for x in tensors]

        batches = len(train_data[0])//bsize

        # Column order matters: plot_history addresses columns by position.
        metric_keys = ['loss',
                       'kl_div_loss',
                       second_loss+'_loss',
                       'val_loss',
                       'val_kl_div_loss',
                       'val_'+second_loss+'_loss']

        placeholder = {'epoch':-1, 'batch':0, 'lr':0}
        placeholder.update({key: 0 for key in metric_keys})
        # Collect plain dicts and build the frame once at the end instead of
        # concatenating a DataFrame on every batch.
        records = [placeholder]

        for ep in range(epochs):

            # if train_mode=='countvae':
            #     self.countvae_pred_grpah(epoch = ep,path = CVAE_PATH)

            print(f'Epoch[{ep+1}/{epochs}]')
            for batch in range(batches):

                # Get current batch
                b = [x[batch*bsize : (batch+1)*bsize] for x in train_data]

                optim.zero_grad()

                # Train step
                loss, kl_, l_ = model(b,isTrain = True)

                # Validation step: the numbers are only logged, so skip autograd.
                with T.no_grad():
                    val_loss, val_kl_, val_l_ = model(val_data, isTrain = True)

                # Save statistics
                record = {'epoch': ep,
                          'batch': batch,
                          'lr'   : optim.param_groups[0]['lr']}
                values = (loss, kl_, l_ , val_loss , val_kl_ , val_l_)
                for key, value in zip(metric_keys, values):
                    record[key] = value.detach().item()
                records.append(record)

                # Backpropagation step and weight update
                loss.backward()
                optim.step()
                print('\r Batch: {}/{} - loss : {} - val_loss : {} - val_{} : {}'.format(batch+1,batches,
                                                                        record['loss'],
                                                                        record['val_loss'],
                                                                        second_loss,
                                                                        record['val_'+second_loss+'_loss']),
                    end="")
            print("\n")
        print('[Success] Finished Training')
        return pd.DataFrame(records)

    def load_countvae_weights(self,path):
        '''Replace the CountVAE with the object stored at `path`; prints the outcome.'''
        try :
            # NOTE: T.load unpickles arbitrary objects -- only load trusted files.
            self.countvae = T.load(path)
            self.is_cvae_trained=1
            print('[Success] Loaded Successfully')
        except Exception as err:
            print('[Failed] Load Failed')
            print(repr(err))

    def load_bboxvae_weights(self,path):
        '''Replace the BboxVAE with the object stored at `path`; prints the outcome.'''
        try :
            # NOTE: T.load unpickles arbitrary objects -- only load trusted files.
            self.bboxvae = T.load(path)
            self.is_bvae_trained=1
            print('[Success] Loaded Successfully')
        except Exception as err:
            print('[Failed] Load Failed')
            print(repr(err))

    def train_bboxvae(self,epochs=30, bsize=256, validation_split=0.1, optim=None):
        '''Train the BboxVAE (Adam with lr_bvae unless `optim` is given); returns its history.'''
        if optim is None:
            optim = T.optim.Adam(self.bboxvae.parameters(),lr=self.lr_bvae)

        # Start Training
        history = self.train(optim      = optim,
                        train_mode = 'bboxvae',
                        epochs     = epochs,
                        bsize      = bsize,
                        validation_split = validation_split
                    )
        self.is_bvae_trained = 1
        # Drop the placeholder row that train() puts first.
        self.bvae_history = history.iloc[1:]
        return self.bvae_history

    def train_countvae(self,epochs=30, bsize=256, validation_split=0.1, optim=None):
        '''Train the CountVAE (Adam with lr_cvae unless `optim` is given); returns its history.'''
        if optim is None:
            optim = T.optim.Adam(self.countvae.parameters(),lr=self.lr_cvae)

        # Start Training
        history = self.train(optim      = optim,
                        train_mode = 'countvae',
                        epochs     = epochs,
                        bsize      = bsize,
                        validation_split = validation_split
                    )
        self.is_cvae_trained = 1
        # Drop the placeholder row that train() puts first.
        self.cvae_history = history.iloc[1:]
        return self.cvae_history

    # --------------------------------------------------------------- prediction

    def pred_countvae(self,data=None):
        '''
        * Predicts counts from the CountVAE for a given label set.
        * If data is None, the label set of the loaded test split is used.
        '''
        if self.is_cvae_trained == 0:
            print("[Warning] Count VAE is Not Trained !!")
        # `is None`: comparing a tensor with `==` is an element-wise operation.
        if data is None :
            data = self.test_label_set
        return self.countvae(data , isTrain=False)

    def pred_bboxvae(self,Data=None, batch_size=64):
        '''
        * Predicts boxes from the BboxVAE for given class counts and class labels.
        * Data : pair (class_counts, class_labels); if None, those of the
          loaded test split are used.
        * batch_size : number of samples per forward pass (default 64).

        Returns (predictions, gt), both in cxywh layout. `gt` is always built
        from the first len(predictions) samples of the loaded test split, so it
        only lines up with the predictions when Data is None or comes from
        that split.
        '''
        if self.is_bvae_trained == 0:
            print("[Warning] Bbox VAE is Not Trained !!")

        if Data is None :
            Data = [self.test_class_counts,self.test_class_labels]

        counts, labels = Data[0], Data[1]
        total = len(counts)
        if total == 0:
            raise ValueError("pred_bboxvae needs at least one sample")

        chunks = []
        # Inference only: no autograd graph is needed.
        with T.no_grad():
            for start in range(0, total, batch_size):
                # Slice the data that was actually passed in; the final chunk
                # may be shorter than batch_size.
                data = [counts[start : start+batch_size],
                        labels[start : start+batch_size]]

                pred = self.bboxvae(data, isTrain=False)
                pred = pred.permute(2,0,1)

                # cxywh format, with class-0 rows blanked
                chunks.append(self._blank_class_zero(self._to_cxywh(data[1], pred)))

        predictions = T.cat(chunks, dim=0)
        gt = self._to_cxywh(self.test_class_labels[0:total], self.test_b_boxes[0:total])
        return predictions, gt

    def countvae_pred_grpah(self,path,epoch = 0, class_names = None):
        '''
        Plot the normalised per-class totals of the loaded training counts
        against those predicted for the test label set, and save the figure to
        path + "cvae-train-ep-<epoch>.png".

        class_names : optional tick labels; when omitted, a module-level
        `class_names` is used if one exists, otherwise the ticks stay numeric.
        '''
        if class_names is None:
            class_names = globals().get('class_names')

        pred_cvae = self.pred_countvae()
        pred_cvae = T.sum(pred_cvae,dim=0)
        pred_cvae = pred_cvae/T.sum(pred_cvae)
        pred_cvae = pred_cvae.to('cpu').clone().detach().numpy()

        gt_cvae = T.sum(self.class_counts,dim=0)
        gt_cvae = gt_cvae/T.sum(gt_cvae)
        gt_cvae = gt_cvae.to('cpu').clone().detach().numpy()

        fig   = plt.figure(figsize=(5 ,4), dpi=100 ,facecolor=(0,0,0))
        ax = fig.add_subplot()
        ax.plot(gt_cvae  , 'red',marker = 'o', label = 'Ground Truth',linewidth=4)
        ax.plot(pred_cvae,'blue',marker ='o',label = "Predicted" ,linewidth=4)
        ax.legend()
        ax.set_title('Ground Truth vs Predicted Distribution\n Epoch = '+str(epoch))
        ax.set_xlabel('Classes')
        # One tick per class, however many classes there are.
        ax.set_xticks(list(range(len(gt_cvae))))
        if class_names is not None:
            ax.set_xticklabels(class_names)

        plt.savefig(path+"cvae-train-ep-"+str(epoch)+".png",facecolor=(0,0,0))
        plt.close()

    def convert_to_cxywh(self,data):
        '''Turn rows laid out as (4 box values, one-hot label...) into (class index, 4 box values).'''
        return self._to_cxywh(data[...,4: ], data[...,0:4])

    # -------------------------------------------------------------- persistence

    def save_model(self,path):
        '''Save both sub-models and the whole object under the prefix `path`.'''
        T.save(self.countvae,path+'countvae.h5')
        T.save(self.bboxvae,path+'bboxvae.h5')
        # File name kept exactly as it was (including its spelling) so that
        # anything reading it back still finds it.
        T.save(self,path+'selef.h5')
        print('[Success] Saved Successfully')

    def save_history(self,path):
        '''Write the available training histories as CSV files under the prefix `path`.'''
        saved = False
        for attr, fname in (('cvae_history', 'cvae-history.csv'),
                            ('bvae_history', 'bvae-history.csv')):
            history = getattr(self, attr, None)
            if history is None:
                # That model has not been trained through this object yet.
                print('[Warning] No ' + attr + ' to save')
                continue
            history.to_csv(path+fname,index=False)
            saved = True
        if saved:
            print('[Success] Saved Successfully')

        


# Plotting Functions

In [None]:
def plot_history(history,title = 'Training Statistics', path =""):
    """
    Draw the six loss curves of a training-history frame on a 3x2 grid.

    history : DataFrame whose columns 3-5 hold the training metrics and
              columns 6-8 the matching validation metrics (addressed by
              position, as in the frames returned by the LayoutVAE training
              helpers).
    title   : text shown in the banner across the top of the figure.
    path    : file to save the figure to; nothing is written when it is empty.
    """
    fig  = plt.figure(figsize=(9,12), dpi=100 ,facecolor=(0,0,0))
    grid = plt.GridSpec(4,2,
                        hspace=0.3,wspace=0.2,
                        height_ratios =[0.25,1,1,1],
                        left=0.02,right=0.98,top=0.98,bottom=0.02
                    )

    # The first grid row (cells 0 and 1) is a banner that holds only the title.
    banner = fig.add_subplot(grid[0:2])
    banner.text(x = 0.3 ,y = 0.5 ,s = title,fontsize=30)
    banner.invert_yaxis()
    banner.axis('off')

    # Each following row shows one training metric (left) next to its
    # validation counterpart (right), both in the same colour.
    for row, colour in enumerate(['red','blue','green']):
        for column_index, cell in ((row+3, 2+2*row), (row+6, 3+2*row)):
            column = history.columns[column_index]
            ax = fig.add_subplot(grid[cell])
            ax.plot(history[column],colour)
            ax.set_title(column)
            ax.set_facecolor((0,0,0))

    # The default path is empty, and saving to an empty file name fails.
    if path:
        plt.savefig(path, facecolor=(0,0,0))

In [None]:
def generate_colors(class_names = None,n_class=6):
    """
    Map the first `n_class` class names to hex colour strings.

    class_names : sequence of names; defaults to 'class1' .. 'class<n_class>'.
    n_class     : number of entries to produce.

    The first class always maps to the empty string (no colour). The
    remaining classes walk through the palette and wrap around when there are
    more classes than palette entries.
    """
    cmap = ["","#dc143c","#ffff00","#00ff00","#ff00ff","#1e90ff","#fff5ee",
            "#00ffff","#8b008b","#ff4500","#8b4513","#808000","#483d8b",
            "#008000","#000080","#9acd32","#ffa500","#ba55d3","#00fa9a",
            "#dc143c","#0000ff","#f08080","#f0e68c","#dda0dd","#ff1493"]

    if class_names is None:
        class_names = ['class'+str(i+1) for i in range(n_class)]

    palette = cmap[1:]
    colors = dict()
    for i in range(n_class):
        # Index 0 keeps the empty placeholder; later indices cycle the palette.
        colors[class_names[i]] = cmap[0] if i == 0 else palette[(i-1) % len(palette)]

    return colors

def plot_layouts(data,colors,class_names,title="Random Predictions", path=""):
    '''
    Draw up to 16 layouts on a 4x4 grid below a title banner and a legend.

    data        : layouts in cxywh format; data[i] is a sequence of
                  (class, x, y, w, h) rows. x and w are scaled by the page
                  width, y and h by the page height.
    colors      : mapping class name -> hex colour string
    class_names : class names indexed by class id; class 0 is never drawn
    title       : banner text
    path        : file to save the figure to; nothing is written when empty
    '''
    page_w, page_h = 180, 240

    fig  = plt.figure(figsize=(9,15), dpi=100 ,facecolor=(0,0,0))
    grid = plt.GridSpec(6,4,
                        hspace=0.05,wspace=0.05,
                        height_ratios =[0.5,0.25,1,1,1,1],
                        left=0.02,right=0.98,top=0.98,bottom=0.02
                    )

    # Row 0: title banner.
    ax = fig.add_subplot(grid[0:4])
    ax.text(x = 0.2 ,y = 0.5 ,s = title,fontsize=30)
    ax.axis('off')

    # Row 1: legend with one entry per class except class 0.
    ax = fig.add_subplot(grid[4:8])
    legend = [Patch(facecolor=colors[name]+"40",
                    edgecolor=colors[name],
                    label= name)
              for name in class_names[1:]]
    ax.legend(handles=legend, ncol=3,loc=8, fontsize=25, facecolor=(0,0,0))
    ax.axis('off')

    # Rows 2-5: one cell per layout, never more than the data provides.
    for i in range(min(16, len(data))):
        ax = fig.add_subplot(grid[8+i])

        background = patches.Rectangle((0,0),page_w,page_h)
        background.set_color((0,0,0,1))
        ax.add_patch(background)

        # Read the layout from the `data` argument, not from a global.
        for box in data[i]:
            # Plain floats, so tensor rows (including GPU ones) can be drawn.
            c,x,y,w,h = (float(v) for v in box)
            if c==0:
                continue
            name = class_names[int(c)]
            rect = patches.Rectangle((x*page_w,y*page_h),w*page_w,h*page_h,linewidth=2)
            rect.set_color(colors[name]+"72")
            rect.set_linestyle('-')
            rect.set_edgecolor(colors[name])
            ax.add_patch(rect)

        ax.plot()
        ax.set_facecolor((0,0,0))
        for spine in ax.spines.values():
            spine.set_edgecolor('green')
            spine.set_linewidth(2)
        ax.invert_yaxis()
        ax.set_xticks([])
        ax.set_yticks([])

    # The default path is empty, and saving to an empty file name fails.
    if path:
        plt.savefig(path, facecolor=(0,0,0))
    plt.show()
    plt.close()

# Training



In [None]:
# Build the model with its default hyper-parameters, then load the dataset at
# DATA_PATH; frac = 0.5 is the fraction of the data to load.
layoutvae = LayoutVAE()
layoutvae.load_data(DATA_PATH, frac = 0.5)

## Countvae

In [None]:
# Optional: load previously saved CountVAE weights instead of training.
# layoutvae.load_countvae_weights(path = SAVE_MODEL_PATH + "countvae.h5")
# Train the CountVAE, then plot the history stored on layoutvae.cvae_history.
layoutvae.train_countvae(bsize = 512, epochs=100,validation_split=0.1)
plot_history(layoutvae.cvae_history,
             title="CountVAE Training",
             path = SAVE_LOG_PATH+"Cvae-train.svg"
             )

def countvae_pred_grpah(self,path,class_names=None):
    """
    Module-level variant of LayoutVAE.countvae_pred_grpah.

    self        : object providing `pred_countvae()` and `class_counts`
    path        : prefix of the output file; "cvae-train.png" is appended
    class_names : optional tick labels; when omitted, a module-level
                  `class_names` is used if one exists, otherwise the ticks
                  stay numeric.

    Plots the normalised per-class totals of `self.class_counts` against those
    of the predicted counts and saves the figure.
    """
    if class_names is None:
        # The name is only assigned in a later cell, so it may not exist yet.
        class_names = globals().get('class_names')

    pred_cvae = self.pred_countvae()
    pred_cvae = T.sum(pred_cvae,dim=0)
    pred_cvae = pred_cvae/T.sum(pred_cvae)
    pred_cvae = pred_cvae.to('cpu').clone().detach().numpy()

    gt_cvae = T.sum(self.class_counts,dim=0)
    gt_cvae = gt_cvae/T.sum(gt_cvae)
    gt_cvae = gt_cvae.to('cpu').clone().detach().numpy()

    fig   = plt.figure(figsize=(5 ,4), dpi=100 ,facecolor=(0,0,0))
    ax = fig.add_subplot()
    ax.plot(gt_cvae  , 'red',marker = 'o', label = 'Ground Truth',linewidth=4)
    ax.plot(pred_cvae,'blue',marker ='o',label = "Predicted" ,linewidth=4)
    ax.legend()
    ax.set_title('Ground Truth vs Predicted Distribution')
    ax.set_xlabel('Classes')
    # One tick per class, however many classes there are.
    ax.set_xticks(list(range(len(gt_cvae))))
    if class_names is not None:
        ax.set_xticklabels(class_names)

    plt.savefig(path+"cvae-train.png",facecolor=(0,0,0))
    plt.close()

## Bbox VAE

In [None]:
# layoutvae.load_bboxvae_weights(path = SAVE_MODEL_PATH + "bboxvae.h5")
history_df = layoutvae.train_bboxvae(bsize = 256, epochs = 150, validation_split = 0.1)
preds,gt = layoutvae.pred_bboxvae()
plot_history(layoutvae.vae_history,
             title="BBoxVAE Training",
             path = SAVE_LOG_PATH+"Bvae-train.svg"
             )

In [None]:
class_names = ['None' , 'Text' , 'Title' , 'List' , 'Table' ,'Figure']
colors = generate_colors(n_class=len(class_names) , class_names=class_names)

# pred_bboxvae returns (predictions, ground truth); plot the predictions made
# in this cell rather than a variable that is only defined further down.
preds, gt = layoutvae.pred_bboxvae()
for i in range(2):
    plot_layouts(data = preds[i*16:(i+1)*16],
                 colors=colors,
                 class_names=class_names,
                 path=SAVE_OUTPUT_PATH+"/bboxvae-preds-"+str(i)+".png"
                 )

# Save Model and Train History

In [None]:
# Persist both sub-models and the whole object, then write the training histories as CSV.
layoutvae.save_model(SAVE_MODEL_PATH)
layoutvae.save_history(SAVE_LOG_PATH)

# Complete Model

In [None]:
# Run the complete model (CountVAE followed by BboxVAE) on the test label sets
# and plot the first 32 predicted layouts, 16 per figure.
predd = layoutvae(layoutvae.test_label_set)
for i in range(2):
    plot_layouts(data = predd[i*16:(i+1)*16],
                 colors=colors,
                 class_names=class_names,
                 path=SAVE_OUTPUT_PATH+"/random-preds2-"+str(i)+".png"
                 )