In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import torch
torch.manual_seed(125)
import random
random.seed(125)
import torch_f as torch_f
import modelovae as mv
import meshSubplot as ms
import wandb
import networkx as nx
import matplotlib.pyplot as plt

In [2]:
use_gpu = True
# Run on the first CUDA device only when GPU use is requested and CUDA is present.
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Encoder

In [3]:
def encodeStructureFold(fold, root):
    '''Register the bottom-up encoding of the tree rooted at ``root`` in ``fold``.

    Nodes are added with ``fold.add`` in post-order (left subtree, right
    subtree, then the node itself). This lets nodes at the same depth go into
    the encoder at the same time, which reduces computational cost.

    Returns the handle produced by ``fold.add('sampleEncoder', ...)`` for the
    root encoding.
    '''
    def visit(current):
        if current is None:
            return None
        if current.isLeaf():
            return fold.add('leafEncoder', current.radius)

        # Keep left-before-right: fold.add presumably records operations in
        # call order, so the traversal order may matter to the fold.
        left_code = visit(current.left)
        right_code = visit(current.right)

        if left_code is not None and right_code is not None:
            # Argument order is (radius, right, left).
            return fold.add('bifurcationEncoder', current.radius, right_code, left_code)

        single_code = right_code if right_code is not None else left_code
        if single_code is None:
            return None
        return fold.add('internalEncoder', current.radius, single_code)

    root_code = visit(root)
    return fold.add('sampleEncoder', root_code)

def encode_structure(root, Grassencoder):
    '''Encode the tree rooted at ``root`` eagerly, one node at a time (no fold).

    Children are encoded first (left, then right). Each node's ``radius`` is
    reshaped to ``(-1, 4)`` and passed to the matching ``Grassencoder``
    sub-module. The root encoding is finally passed through
    ``Grassencoder.sampleEncoder``.
    '''
    def encode_node(node, Grassencoder):
        if node is None:
            return None
        if node.isLeaf():
            return Grassencoder.leafEncoder(node.radius.reshape(-1, 4))

        left_code = encode_node(node.left, Grassencoder)
        right_code = encode_node(node.right, Grassencoder)
        # Right child is preferred/first, matching the bifurcation argument order.
        present = [code for code in (right_code, left_code) if code is not None]
        if not present:
            return None
        features = node.radius.reshape(-1, 4)
        if len(present) == 2:
            return Grassencoder.bifurcationEncoder(features, right_code, left_code)
        return Grassencoder.internalEncoder(features, present[0])

    encoding = encode_node(root, Grassencoder)
    return Grassencoder.sampleEncoder(encoding)

Data loader

In [4]:

def numerar_nodos(root, count):
    '''Number the nodes of a binary tree in in-order sequence.

    Each visited node gets ``node.data`` set to the current length of
    ``count``, and ``count`` then grows by one element (``1``). After the call,
    ``len(count)`` has grown by the number of nodes in the tree. Returns None.
    '''
    pending = []
    current = root
    while pending or current is not None:
        # Descend as far left as possible, remembering the path.
        while current is not None:
            pending.append(current)
            current = current.left
        current = pending.pop()
        current.data = len(count)
        count.append(1)
        current = current.right

In [5]:
def my_collate(batch):
    """Collate function for ``DataLoader`` that performs no collation.

    The list of samples is returned unchanged (the very same object), so each
    batch stays a plain Python list instead of being stacked into tensors.
    """
    collated = batch
    return collated


class tDataset(Dataset):
    """Dataset of binary trees read from serialized files in a directory.

    Each item is a single-entry dict ``{root_node: n_nodes}`` mapping the
    deserialized tree root to its node count (the format ``train_model``
    unpacks with ``list(tree.keys())[0]``).
    """
    def __init__(self, l, dir, transform=None):
        """
        Args:
            l: list of file names to load.
            dir: directory containing those files (passed to ``mv.read_tree``).
            transform: optional callable applied to each item in ``__getitem__``.
        """
        self.names = l
        self.transform = transform
        # Serialized tree strings, one per file name in ``l``.
        self.data = [mv.read_tree(name, dir) for name in self.names]
        # One ``{root: node_count}`` dict per tree. numerar_nodos also stores an
        # in-order index in each node's ``data`` attribute.
        self.trees = []
        for serialized in self.data:
            root = mv.deserialize(serialized)
            counter = []
            numerar_nodos(root, counter)
            self.trees.append({root: len(counter)})

    def __len__(self):
        # Length of the list that __getitem__ indexes.
        return len(self.trees)

    def __getitem__(self, idx):
        tree = self.trees[idx]
        # Previously `transform` was accepted but silently ignored.
        if self.transform is not None:
            tree = self.transform(tree)
        return tree

# Number of trees per batch; reused below by the DataLoader, mv.numberNodes
# and the wandb config.
batch_size = 4

Decoder

In [6]:
def decodeStructureFoldGrass(fold, v, root):
    '''Register the decoding of ``root`` from latent code ``v``, and its loss, in ``fold``.

    The tree is walked depth first. Every operation is recorded with
    ``fold.add`` so that nodes at the same depth can be grouped and evaluated
    together, which reduces computational cost.

    For every node, a structure loss (``classifyLossEstimator`` on the
    ``nodeClassifier`` output) and an attribute loss (``calcularLossAtributo``
    on the decoded radius) are added together. The losses of the node's
    children are then accumulated, right child first, then left.

    Args:
        fold: fold object whose ``add`` records operations.
        v: latent code of the whole tree.
        root: ground-truth root node providing ``childs()``, ``left`` and ``right``.

    Returns:
        The fold handle of the total loss of the tree.

    Raises:
        ValueError: if a node reports a child count other than 0, 1 or 2, or
            reports 2 children while ``left`` or ``right`` is None.
    '''
    def decodeNode(code, node):
        label = fold.add('nodeClassifier', code)
        n_children = node.childs()

        if n_children == 0:
            radius = fold.add('featureDecoder', code)
            child_losses = []
        elif n_children == 1:
            child_code, radius = fold.add('internalDecoder', code).split(2)
            # Truthiness test (not `is None`) kept from the original implementation.
            child = node.right if node.right else node.left
            child_losses = [decodeNode(child_code, child)]
        elif n_children == 2:
            # Previously a missing child here surfaced later as UnboundLocalError.
            if node.right is None or node.left is None:
                raise ValueError('node reports 2 children but left or right is None')
            left_code, right_code, radius = fold.add('bifurcationDecoder', code).split(3)
            # Right subtree is decoded (and its loss accumulated) before the left one.
            child_losses = [decodeNode(right_code, node.right),
                            decodeNode(left_code, node.left)]
        else:
            raise ValueError('unsupported number of children: %r' % (n_children,))

        # NOTE: a depth-dependent weight on the structure loss ('vectorMult') was
        # previously commented out here; the structure loss is used unweighted.
        structure_loss = fold.add('classifyLossEstimator', label, node)
        attribute_loss = fold.add('calcularLossAtributo', node, radius)
        loss = fold.add('vectorAdder', structure_loss, attribute_loss)
        for child_loss in child_losses:
            loss = fold.add('vectorAdder', loss, child_loss)
        return loss

    return decodeNode(fold.add('sampleDecoder', v), root)


Save Model

In [7]:
def _save_checkpoint(path, payload):
    """Save ``payload`` to ``path`` with torch.save, creating the parent directory if needed."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(payload, path)


class SaveBestModel:
    """
    Class to save the best model while training. If the loss passed for the
    current epoch is lower than the lowest loss seen so far (and
    ``epoch > min_epoch``), save the model state.

    Args:
        best_valid_loss: initial best loss. The default +inf means the first
            eligible epoch is always saved.
        path: checkpoint file to (over)write.
        min_epoch: checkpoints are only considered when ``epoch > min_epoch``.
    """
    def __init__(self, best_valid_loss=float('inf'),
                 path='ablation/IntraP15eps01-best.pth', min_epoch=50):
        self.best_valid_loss = best_valid_loss
        self.path = path
        self.min_epoch = min_epoch

    def __call__(
        self, current_valid_loss,
        epoch, encoder, decoder, optimizer
    ):
        if epoch <= self.min_epoch:
            return
        if isinstance(current_valid_loss, torch.Tensor):
            # Detach so the stored best loss (and the checkpoint) does not keep a
            # reference to the autograd graph.
            current_valid_loss = current_valid_loss.detach()
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            _save_checkpoint(self.path, {
                'epoch': epoch + 1,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'loss': self.best_valid_loss,
                'optimizer_state_dict': optimizer.state_dict(),
            })


class SaveLastModel:
    """
    Class to save the model while training. Each call overwrites the same file.

    Args:
        path: checkpoint file to (over)write.
    """
    def __init__(self, path='ablation/IntraP15eps01-last.pth'):
        self.path = path

    def __call__(self, epoch, encoder, decoder, optimizer):
        _save_checkpoint(self.path, {
            'epoch': epoch + 1,
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        })


In [8]:
def escalon_beta(e, corte):
    """Return a constant beta schedule: an array holding ``corte`` copies of ``e``."""
    # Same dtype rule np.linspace applies to (start, stop, float(num)).
    schedule_dtype = np.result_type(e, e, float(corte))
    return np.full(corte, e, dtype=schedule_dtype)

In [9]:
def set_Level(tree, n_nodes):
    '''Annotate the nodes of ``tree`` with their level and store tree-wide level stats.

    For every index ``x`` in ``range(n_nodes)``, the level returned by
    ``mv.getLevel`` is stored as ``.level`` on the node returned by
    ``mv.searchNode``. A falsy level is treated as "node not present" and
    reported with a print.

    Afterwards, every value collected by ``tree.getTreeLevel`` is converted to
    ``max_level - value`` and the sum is stored via ``tree.setTreeLevel``. The
    maximum level is stored via ``tree.setMaxLevel``.
    '''
    max_level = 0
    for x in range(n_nodes):
        # Query the level once and reuse it for the node annotation.
        level = mv.getLevel(tree, x)
        if level > max_level:
            max_level = level
        if level:
            mv.searchNode(tree, x).level = level
        else:
            print(x, "is not present in tree")
    tree_level = []
    tree.getTreeLevel(tree, tree_level)
    tree.setTreeLevel(tree, sum(max_level - node_level for node_level in tree_level))
    tree.setMaxLevel(tree, max_level)

In [10]:
def _batch_losses(batch, Grassencoder, Grassdecoder):
    '''Encode and decode one batch with fold batching; return ``(recon_loss, kldiv_loss)``.

    ``batch`` is a list of single-entry ``{root: n_nodes}`` dicts. Each tree's
    reconstruction loss is divided by its node count and then averaged over
    the batch. The KL term is summed per tree and averaged over the batch.
    '''
    roots = [list(tree.keys())[0] for tree in batch]
    node_counts = [tree[root] for tree, root in zip(batch, roots)]

    enc_fold = torch_f.Fold(device)
    enc_fold_nodes = [encodeStructureFold(enc_fold, root) for root in roots]
    enc_fold_nodes = enc_fold.apply(Grassencoder, [enc_fold_nodes])
    enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0)

    dec_fold = torch_f.Fold(device)
    dec_fold_nodes = []
    kld_fold_nodes = []
    for root, fnode in zip(roots, enc_fold_nodes):
        # The encoder output is split in two along dim 1: root code and KL term.
        root_code, kl_div = torch.chunk(fnode, 2, 1)
        dec_fold_nodes.append(decodeStructureFoldGrass(dec_fold, root_code, root))
        kld_fold_nodes.append(kl_div)

    total_loss = dec_fold.apply(Grassdecoder, [dec_fold_nodes, kld_fold_nodes])
    # NOTE(review): assumes total_loss[0] holds one value per tree, shape (B,);
    # a (B, 1) tensor would broadcast against node_counts into (B, B) — confirm.
    recon_loss = torch.div(total_loss[0], torch.tensor(node_counts, device=device))
    recon_loss = recon_loss.sum() / len(batch)    # avg. reconstruction loss per example
    kldiv_loss = sum(torch.sum(kl) for kl in kld_fold_nodes) / len(batch)
    return recon_loss, kldiv_loss


def train_model(epochs, data_loader, Grassencoder, Grassdecoder, opt):
    '''Train the encoder/decoder pair for ``epochs`` epochs.

    Each step minimises ``recon_loss + beta * kldiv_loss / 10``.

    Per-epoch averages are handled as follows:
      * logged to wandb every 10 epochs;
      * printed every 100 epochs;
      * the epoch-average total loss is passed to ``SaveBestModel`` every epoch.

    ``SaveLastModel`` is called every 100 epochs.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    '''
    n_batches = len(data_loader)
    if n_batches == 0:
        raise ValueError('data_loader is empty')

    save_last_model = SaveLastModel()
    save_best_model = SaveBestModel()
    # Constant beta schedule sized to the run (a fixed 400000-entry schedule
    # would raise IndexError for longer runs).
    betas = escalon_beta(.001, epochs)

    for epoch in range(epochs):
        beta = betas[epoch]

        epochTotalLoss = 0
        epochReconLoss = 0
        epochKLDivLoss = 0
        # Logged as beta * KL; note the term actually added to the loss is this / 10.
        epochKLDivLossBeta = 0

        for batch in data_loader:
            recon_loss, kldiv_loss = _batch_losses(batch, Grassencoder, Grassdecoder)
            total_loss = recon_loss + beta*kldiv_loss/10

            opt.zero_grad()
            total_loss.backward()
            opt.step()
            epochTotalLoss += total_loss.item()
            epochReconLoss += recon_loss.item()
            epochKLDivLoss += kldiv_loss.item()
            epochKLDivLossBeta += beta*kldiv_loss.item()

        epochTotalLoss /= n_batches
        epochReconLoss /= n_batches
        epochKLDivLoss /= n_batches
        epochKLDivLossBeta /= n_batches

        # Select the best checkpoint on the epoch average. Previously it used
        # the last (randomly shuffled) batch's loss tensor.
        save_best_model(epochTotalLoss, epoch, Grassencoder, Grassdecoder, opt)
        if epoch % 10 == 0:
            wandb.log({'epoch': epoch+1, 'loss': epochTotalLoss, 'kl_div': epochKLDivLoss, 'kl_div (*beta)': epochKLDivLossBeta, 'recon_loss': epochReconLoss, 'beta': beta})
        if epoch % 100 == 0:
            save_last_model(epoch, Grassencoder, Grassdecoder, opt)
            print('Epoch [%d / %d] average total loss: %.10f , kl: %.10f (*beta: %.10f), reconstruction loss: %.10f' % (epoch+1, epochs, epochTotalLoss, epochKLDivLoss, epochKLDivLossBeta, epochReconLoss))
    return


Training loop


In [11]:
torch.set_printoptions(precision=10)
p = 15
eps = 1
# Directory holding the serialized trees for this (p, eps) configuration.
data_dir = "data/paper/IntraP" + str(p) + "eps0" + str(eps)
# Sort before slicing: os.listdir order is arbitrary, so an unsorted [:100]
# could select a different subset of trees on another machine/filesystem.
t_list = sorted(os.listdir(data_dir))[:100]
dataset = tDataset(t_list, data_dir)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle=True, collate_fn=my_collate)


mult = mv.numberNodes(data_loader, batch_size)
feature_size = 64
latent_size = feature_size
hidden_size_encoder = 512
hidden_size_decoder = 256

Grassencoder = mv.GRASSEncoder(input_size = 4, feature_size=feature_size, hidden_size=hidden_size_encoder)
Grassencoder = Grassencoder.to(device)
Grassdecoder = mv.GRASSDecoder(latent_size=latent_size, hidden_size=hidden_size_decoder, mult = mult)
Grassdecoder = Grassdecoder.to(device)

mv.setLevel(data_loader)

##loop parameters
epochs = 20000
learning_rate = 1e-4
params = list(Grassencoder.parameters()) + list(Grassdecoder.parameters())
opt = torch.optim.Adam(params, lr=learning_rate)
total_paramse = sum(param.numel() for param in Grassencoder.parameters())
total_paramsd = sum(param.numel() for param in Grassdecoder.parameters())
print("total parameters encoder ", total_paramse)
print("total parameters decoder", total_paramsd)
print("total parameters", total_paramse + total_paramsd)

Grassencoder.train()
Grassdecoder.train()

config = {
    "learning_rate": learning_rate,
    "epochs": epochs,
    "batch_size": batch_size,
    "dataset": t_list,
    # Exact count; len(data_loader)*batch_size over-counts a partial last batch.
    "number of trees": len(dataset),
    "optim": opt,
    "latent_size" : latent_size,
    "params": total_paramse + total_paramsd,
    "prof": p,
}
wandb.init(project="MIA", entity="paufeldman", config = config)

train_model(epochs, data_loader, Grassencoder, Grassdecoder, opt)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


total parameters encoder  626560
total parameters decoder 379911
total parameters 1006471


[34m[1mwandb[0m: Currently logged in as: [33mpaufeldman[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch [1 / 20000] average reconstruction error: 0.1549149275 , kl(*beta): 0.0762657556 (0.0000762658), reconstruction loss: 0.1549072987
Epoch [101 / 20000] average reconstruction error: 0.0564641863 , kl(*beta): 9.4217435837 (0.0094217436), reconstruction loss: 0.0555220126
Epoch [201 / 20000] average reconstruction error: 0.0443855344 , kl(*beta): 13.1489133835 (0.0131489134), reconstruction loss: 0.0430706429
Epoch [301 / 20000] average reconstruction error: 0.0391188500 , kl(*beta): 17.3959865189 (0.0173959865), reconstruction loss: 0.0373792513
Epoch [401 / 20000] average reconstruction error: 0.0314095104 , kl(*beta): 20.3844422150 (0.0203844422), reconstruction loss: 0.0293710662
Epoch [501 / 20000] average reconstruction error: 0.0260086403 , kl(*beta): 21.7398106384 (0.0217398106), reconstruction loss: 0.0238346588
Epoch [601 / 20000] average reconstruction error: 0.0222258460 , kl(*beta): 21.4896871948 (0.0214896872), reconstruction loss: 0.0200768774
Epoch [701 / 20000] aver