# Train_vaegan

In [1]:
from __future__ import print_function
import argparse
import os
import random
import numpy as np
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from dataset.dataset import RIODatasetSceneGraph, collate_fn_vaegan_points
from model.VAE import VAE
from model.atlasnet import AE_AtlasNet
from model.discriminators import BoxDiscriminator, ShapeAuxillary
from model.losses import bce_loss
from helpers.util import bool_flag

from model.losses import calculate_model_losses

import torch.nn.functional as F
import json

In [24]:
# standard hyperparameters
batchSize = 8
lr = 0.0001
nepoch = 101

# paths and filenames
outf = 'checkpoint' # output folder
model = '' # model path
dataset = './GT' # dataset path
dataset_3RScan = '/root/dev/G3D/3RScan' # 3RScan
label_file = 'labels.instances.align.annotated.ply' # label file name
exp = './experiments/shared_model_221119' # experiment name
path2atlas = './experiments/atlasnet/model_70.pth' # atlasnet model path

# GCN parameters
residual = True # residual connection in GCN
pooling = 'avg' # pooling method in GCN

# Dataset related
large = True # large set of shape class labels
use_splits = True # whether or not to split the data into training set and validation set
use_scene_rels = True # whether or not to connect all nodes to a root scene node
with_points = False
with_feats = True
shuffle_objs = True
num_points = 1024
rio27 = False
use_canonical = True
with_angles = True
num_box_params = 6
crop_floor = False

# Training and architecture related
workers = 4 # number of data loading workers
overfiting_debug = False
weight_D_box = 0.1 # Box discriminator
with_changes = True
with_shape_disc = True
with_manipulator = True
replace_latent = True
network_type = 'shared' # choice among 'dis', 'sln', 'mlp', 'shared'

## Prepare AtlasNet

In [36]:
print(torch.cuda.is_available())

True


In [35]:
saved_atlasnet_model = torch.load(path2atlas)
point_ae = AE_AtlasNet(num_points=1024, bottleneck_size=128, nb_primitives=25)

point_ae.load_state_dict(saved_atlasnet_model, strict=True) # load the model parameters

if torch.cuda.is_available():
    point_ae = point_ae.cuda()
point_ae.eval()

AE_AtlasNet(
  (encoder): Sequential(
    (0): PointNetfeat(
      (stn): STN3d(
        (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
        (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
        (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
        (fc1): Linear(in_features=1024, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=256, bias=True)
        (fc3): Linear(in_features=256, out_features=9, bias=True)
        (relu): ReLU()
      )
      (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Linear(in_

## Dataloader with RIODatasetSceneGraph

In [26]:
dataset = RIODatasetSceneGraph(
        root='./GT',
        root_3rscan=dataset_3RScan,
        label_file=label_file,
        npoints=num_points,
        path2atlas=path2atlas,
        split='train_scans',
        shuffle_objs=(shuffle_objs and not overfiting_debug),
        use_points=with_points,
        use_scene_rels=use_scene_rels,
        with_changes=with_changes,
        vae_baseline=network_type == 'sln',
        with_feats=with_feats,
        large=large,
        atlas=point_ae,
        seed=False,
        use_splits=use_splits,
        use_rio27=rio27,
        use_canonical=use_canonical,
        crop_floor=crop_floor,
        center_scene_to_floor=crop_floor,
        recompute_feats=False)

collate_fn = collate_fn_vaegan_points

  0%|          | 12/3777 [00:00<00:40, 93.03it/s]

Checking for missing feats. This can be slow the first time.
This process needs to be only run once!


100%|██████████| 3777/3777 [01:01<00:00, 61.91it/s] 


In [34]:
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batchSize,
collate_fn=collate_fn,
shuffle=(not overfiting_debug),
num_workers=int(workers))

num_classes = len(dataset.classes)
num_relationships = len(dataset.relationships)

print('number of classes: ', num_classes)
print('number of relationships: ', num_relationships)

number of classes:  161
number of relationships:  27


In [28]:
try:
    os.makedirs(outf)
except OSError:
    print('already exist!')
    pass

already exist!


## Models

### VAE Model

In [32]:
# instantiate the model
model = VAE(type=network_type, vocab=dataset.vocab, replace_latent=replace_latent,
           with_changes=with_changes, residual=residual, gconv_pooling=pooling,
           with_angles=with_angles, num_box_params=num_box_params)
if torch.cuda.is_available():
    model = model.cuda()
print(model)

VAE(
  (vae): Sg2ScVAEModel(
    (obj_embeddings_ec_box): Embedding(162, 128)
    (obj_embeddings_ec_shape): Embedding(162, 128)
    (pred_embeddings_ec_box): Embedding(27, 256)
    (pred_embeddings_ec_shape): Embedding(27, 256)
    (obj_embeddings_dc_box): Embedding(162, 256)
    (obj_embeddings_dc_man): Embedding(162, 256)
    (obj_embeddings_dc_shape): Embedding(162, 256)
    (pred_embeddings_dc_box): Embedding(27, 512)
    (pred_embeddings_dc_shape): Embedding(27, 512)
    (pred_embeddings_dc): Embedding(27, 256)
    (pred_embeddings_man_dc): Embedding(27, 768)
    (box_embeddings): Linear(in_features=6, out_features=96, bias=True)
    (shape_embeddings): Linear(in_features=128, out_features=128, bias=True)
    (angle_embeddings): Embedding(24, 32)
    (box_mean_var): Sequential(
      (0): Linear(in_features=256, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=5

### Relationship Discriminator
* instantiate a relationship discriminator that considers the boxes and the semantic labels
* if the loss weight is larger than zero, also create an optimizer for it

In [38]:
if weight_D_box > 0:
    boxD = BoxDiscriminator(6, num_relationships, num_classes)
    optimizerDbox = optim.Adam(filter(lambda p: p.requires_grad, boxD.parameters()), lr=lr, betas=(0.9, 0.999))
    boxD.cuda()
    boxD = boxD.train()
print(boxD)

BoxDiscriminator(
  (D): Sequential(
    (0): Linear(in_features=361, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=512, out_features=1, bias=True)
    (7): Sigmoid()
  )
)


### Optimizer for model

In [41]:
params = filter(lambda p: p.requires_grad, list(model.parameters()))
optimizer = optim.Adam(params, lr=lr)
optimizer.step()
torch.autograd.set_detect_anomaly(True)
counter = 0

### Save parameters so that we can read them later on evaluation

In [40]:
if not os.path.join(exp + "/" + outf):
    os.makedirs(exp + "/" + outf)
    print('output folder created!')
else:
    print('output folder already exists!')

with open(os.path.join(exp, 'args.json'), 'w') as f:
    json.dump(__dict__)

output folder already exists!


## Training Loop

In [None]:
for epoch in range(0, nepoch):
    print('Epoch: {}/{}'.format(epoch, nepoch))
    for i, data in enumerate(dataloader, 0):
        # skip invalid data
        if data == -1:
            continue
        
        # get the data of encoder and decoder
        try:
            enc_objs, enc_triples, enc_tight_boxes, enc_objs_to_scene, enc_triples_to_scene = data['encoder']['objs'],\
                        data['encoder']['tripltes'], data['encoder']['boxes'], data['encoder']['obj_to_scene'], data['encoder']['tiple_to_scene']
            
            if with_feats:
                encoded_enc_points = data['encoder']['feats']
                encoded_enc_points = encoded_enc_points.cuda()
            elif with_points:
                enc_points = data['encoder']['points']
                enc_points = enc_points.cuda()
            
            dec_objs, dec_triples, dec_tight_boxes, dec_objs_to_scene, dec_triples_to_scene = data['decoder']['objs'],\
                        data['decoder']['tripltes'], data['decoder']['boxes'], data['decoder']['obj_to_scene'], data['decoder']['tiple_to_scene']
            
            if 'feats' in data['decoder']:
                encoded_dec_points = data['decoder']['feats']
                encoded_dec_points = encoded_dec_points.cuda()
            else:
                if 'points' in data['decoder']:
                    dec_points = data['decoder']['points']
                    dec_points = dec_points.cuda()
            
            # changed nodes
            missing_nodes = data['missing_nodes']
            manipulated_nodes = data['manipulated_nodes']
        
        except Exception as e:
            print('Exception', str(e))
            continue
        
        enc_objs, enc_triples, enc_tight_boxes = enc_objs.cuda(), enc_triples.cuda(), enc_tight_boxes.cuda()
        dec_objs, dec_triples, dec_tight_boxes = dec_objs.cuda(), dec_triples.cuda(), dec_tight_boxes.cuda()
        
        if with_points:
            enc_points, dec_points = enc_points.cuda(), dec_points.cuda()
        
        # mask : avoid batches with insufficient number of instances with valid shape classes
        mask = [ob in dataset.point_classes_idx for ob in dec_objs]
        if sum(mask) <= 1:
            continue
        
        # training
        optimizer.zero_grad()
        optimizerShapeAux.zero_grad()
        
        model = model.train()
        
        if weight_D_box > 0:
            optimizerDbox.zero_grad()
        
        # assume that with_feats == True and with_points == False
        
        # set all scene (dummy) nodes points encodings to zero
        enc_scene_nodes = enc_objs == 0
        dec_scene_nodes = dec_objs == 0
        encoded_enc_points[enc_scene_nodes] = torch.zeros([torch.sum(enc_scene_nodes), encoded_enc_points.shape[1]]).float().cuda()
        encoded_dec_points[dec_scene_nodes] = torch.zeros([torch.sum(dec_scene_nodes), encoded_dec_points.shape[1]]).float().cuda()
        
        # for the num_box_params
        if num_box_params == 7:
            # all parameters, including angle, procesed by the box_net
            enc_boxes = enc_tight_boxes
            dec_boxes = dec_tight_boxes
        elif num_box_params == 6:
            # no angle. this will be learned separately if with_angle is true -> corresponding to us
            enc_boxes = enc_tight_boxes[:, :6]
            dec_boxes = dec_tight_boxes[:, :6]
        elif num_box_params == 4:
            # height, centroid. assuming we want the other sizes to be estimated from the shape aspect ratio
            enc_boxes = enc_tight_boxes[:, 2:6]
            dec_boxes = dec_tight_boxes[:, 2:6]
        else:
            raise NotImplementedError
        
        # limit the angle bin range from 0 to 24
        enc_angles = enc_tight_boxes[:, 6].long() - 1
        enc_angles = torch.where(enc_angles > 0, enc_angles, torch.zeros_like(enc_angles))
        enc_angles = torch.where(enc_angles < 24, enc_angles, torch.zeros_like(enc_angles))
        dec_angles = dec_tight_boxes[:, 6].long() - 1
        dec_angles = torch.where(dec_angles > 0, dec_angles, torch.zeros_like(dec_angles))
        dec_angles = torch.where(dec_angles < 24, dec_angles, torch.zeros_like(dec_angles))
        
        # compute in the model
        attributes = None
        
        boxGloss = 0
        loss_genShape = 0
        loss_genShapeFake = 0
        loss_shape_fake_g = 0
        
        if with_manipulator:
            model_out = model.forward_mani(enc_objs, enc_triples, enc_boxes, enc_angles, encoded_enc_points, attributes, enc_objs_to_scene,
                                               dec_objs, dec_triples, dec_boxes, dec_angles, encoded_dec_points, attributes, dec_objs_to_scene,
                                               missing_nodes, manipulated_nodes)
            
            mu_box, logvar_box, mu_shape, logvar_shape, orig_gt_box, orig_gt_angle, orig_gt_shape, orig_box, orig_angle, orig_shape, \
            dec_man_enc_box_pred, dec_man_enc_angle_pred, dec_man_enc_shape_pred, keep = model_out
        else:
            model_out = model.forward_no_mani(dec_objs, dec_triples, dec_boxes, encoded_dec_points, angles=dec_angles,
                                  attributes=attributes)

            mu_box, logvar_box, mu_shape, logvar_shape, dec_man_enc_box_pred, dec_man_encd_angles_pred, \
            dec_man_enc_shape_pred = model_out
            
            # uses the decoder output directly (without manipulation)
            orig_gt_box = dec_boxes
            orig_box = dec_man_enc_box_pred

            orig_gt_shape = encoded_dec_points
            orig_shape = dec_man_enc_shape_pred

            orig_angle = dec_man_encd_angles_pred
            orig_gt_angle = dec_angles

            keep = []
            for i in range(len(dec_man_enc_box_pred)):
                keep.append(1)
            keep = torch.from_numpy(np.asarray(keep).reshape(-1, 1)).float().cuda()

        if args.with_manipulator and args.with_shape_disc and dec_man_enc_shape_pred is not None:
            shape_logits_fake_d, probs_fake_d = shapeClassifier(dec_man_enc_shape_pred[mask].detach())
            shape_logits_fake_g, probs_fake_g = shapeClassifier(dec_man_enc_shape_pred[mask])
            shape_logits_real, probs_real = shapeClassifier(encoded_dec_points[mask].detach())

            # auxiliary loss. can the discriminator predict the correct class for the generated shape?
            loss_shape_real = torch.nn.functional.cross_entropy(shape_logits_real, dec_objs[mask])
            loss_shape_fake_d = torch.nn.functional.cross_entropy(shape_logits_fake_d, dec_objs[mask])
            loss_shape_fake_g = torch.nn.functional.cross_entropy(shape_logits_fake_g, dec_objs[mask])
            # standard discriminator loss
            loss_genShapeFake = bce_loss(probs_fake_g, torch.ones_like(probs_fake_g))
            loss_dShapereal = bce_loss(probs_real, torch.ones_like(probs_real))
            loss_dShapefake = bce_loss(probs_fake_d, torch.zeros_like(probs_fake_d))

            loss_dShape = loss_dShapefake + loss_dShapereal + loss_shape_real + loss_shape_fake_d
            loss_genShape = loss_genShapeFake + loss_shape_fake_g
            loss_dShape.backward()
            optimizerShapeAux.step()

        vae_loss_box, vae_losses_box = calculate_model_losses(args,
                                                                orig_gt_box,
                                                                orig_box,
                                                                name='box', withangles=args.with_angles, angles_pred=orig_angle,
                                                                mu=mu_box, logvar=logvar_box, angles=orig_gt_angle,
                                                                KL_weight=0.1, writer=writer, counter=counter)
        if dec_man_enc_shape_pred is not None:
            vae_loss_shape, vae_losses_shape = calculate_model_losses(args,
                                                                    orig_gt_shape,
                                                                    orig_shape,
                                                                    name='shape', withangles=False,
                                                                    mu=mu_shape, logvar=logvar_shape,
                                                                    KL_weight=0.1, writer=writer, counter=counter)
        else:
            # set shape loss to 0 if we are only predicting layout
            vae_loss_shape, vae_losses_shape = 0, 0

        if args.with_manipulator and args.with_changes:
            oriented_gt_boxes = torch.cat([dec_boxes], dim=1)
            boxes_pred_in = keep * oriented_gt_boxes + (1-keep) * dec_man_enc_box_pred

            if args.weight_D_box == 0:
                # Generator loss
                boxGloss = 0
                # Discriminator loss
                gamma = 0.1
                boxDloss_real = 0
                boxDloss_fake = 0
                reg_loss = 0
            else:
                logits, _ = boxD(dec_objs, dec_triples, boxes_pred_in, keep)
                logits_fake, reg_fake = boxD(dec_objs, dec_triples, boxes_pred_in.detach(), keep, with_grad=True,
                                           is_real=False)
                logits_real, reg_real = boxD(dec_objs, dec_triples, oriented_gt_boxes, with_grad=True, is_real=True)
                # Generator loss
                boxGloss = bce_loss(logits, torch.ones_like(logits))
                # Discriminator loss
                gamma = 0.1
                boxDloss_real = bce_loss(logits_real, torch.ones_like(logits_real))
                boxDloss_fake = bce_loss(logits_fake, torch.zeros_like(logits_fake))
                # Regularization by gradient penalty
                reg_loss = torch.mean(reg_real + reg_fake)

            # gradient penalty
            # disc_reg = discriminator_regularizer(logits_real, in_real, logits_fake, in_fake)
            boxDloss = boxDloss_fake + boxDloss_real + (gamma/2.0) * reg_loss
            optimizerDbox.zero_grad()
            boxDloss.backward()
            # gradient clip
            # torch.nn.utils.clip_grad_norm_(boxD.parameters(), 5.0)
            optimizerDbox.step()

        loss = vae_loss_box + vae_loss_shape + 0.1 * loss_genShape
        if args.with_changes:
               loss = loss + args.weight_D_box * boxGloss #+ b_loss

        # optimize
        loss.backward()

        # Cap the occasional super mutant gradient spikes
        # Do now a gradient step and plot the losses
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None and p.requires_grad and torch.isnan(p.grad).any():
                    print('NaN grad in step {}.'.format(counter))
                    p.grad[torch.isnan(p.grad)] = 0
        optimizer.step()
        counter += 1

        if counter % 100 == 0:
            print("loss at {}: box {:.4f}\tshape {:.4f}\tdiscr RealFake {:.4f}\t discr Classifcation "
                  "{:.4f}".format(counter, vae_loss_box, vae_loss_shape, loss_genShapeFake,
                                                          loss_shape_fake_g))