In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
from datetime import datetime

sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import jaccard_score
from tqdm import tqdm

from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input
from imps.metrics import compute_iou
from imps.data.synth_scene import SynthSceneDataset, N_CLASS, CLASS_NAMES

import warnings


def warn(*args, **kwargs):
    """Do-nothing stand-in for ``warnings.warn``: accepts any arguments, returns ``None``."""
    return None


# Replace the module attribute so later ``warnings.warn(...)`` lookups hit the no-op above.
# Code that already bound the original function (``from warnings import warn``) is unaffected.
warnings.warn = warn

# --- Run configuration ---
# Root folder of the scenes; passed to SynthSceneDataset.
DATA_ROOT = "/mnt/data.nas/shareddata/6G-futurelab/synthetic_room_dataset/rooms_04"
# Device string passed to both the dataset and the SQN model.
DEVICE = "cuda"
# NOTE(review): not referenced in the cells shown here; presumably toggles a positional
# embedding somewhere else. Confirm before relying on it.
POS_EMBEDDING = True
# When LOG is true, TensorBoard scalars and the model checkpoint are written under
# LOGDIR/<experiment name>.
LOG = True
LOGDIR = "./logs"

# Classes taken from: https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/scripts/dataset_synthetic_room/build_dataset.py#L29
# https://gist.github.com/tejaskhot/15ae62827d6e43b91a4b0c5c850c168e

def get_occupancy_loss(logits, labels, pos_weight):
    """Mean binary cross-entropy between occupancy logits and 0/1 occupancy labels.

    Args:
        logits: Raw (pre-sigmoid) occupancy scores. The first dimension is kept as the
            batch dimension; everything else is flattened.
        labels: Occupancy targets with the same number of elements per batch item as
            ``logits``. Float, integer and bool tensors holding 0/1 values are accepted.
        pos_weight: Weight of the positive-class term, forwarded unchanged to
            ``torch.nn.BCEWithLogitsLoss``. It must broadcast against the flattened
            ``(batch, -1)`` logits.

    Returns:
        Scalar tensor: the element-wise loss averaged over all batch items and points.
    """
    n_batch = logits.shape[0]
    flat_logits = logits.reshape(n_batch, -1)
    # BCEWithLogitsLoss rejects integer/bool targets, so align the target dtype with the
    # logits. This is a no-op when the labels are already floating point.
    flat_labels = labels.reshape(n_batch, -1).to(flat_logits.dtype)

    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean')
    return criterion(flat_logits, flat_labels)

def get_semantic_loss(logits, labels, class_weights):
    """Class-weighted cross-entropy over the points that carry a semantic label.

    Points rejected by ``get_semantic_filter`` do not contribute to the loss. With its
    defaults these are label ``-1`` (walls) and labels ``> 4`` (free space).

    Args:
        logits: Per-point class scores, indexed with the boolean mask built from
            ``labels`` (assumed ``(batch, n_points, n_class)`` -- TODO confirm against the model).
        labels: Semantic label per point; after masking it must be a valid
            ``CrossEntropyLoss`` target.
        class_weights: Per-class weights forwarded to ``torch.nn.CrossEntropyLoss``.

    Returns:
        Scalar tensor: the weighted per-point loss averaged over the number of selected
        points (not over the sum of their weights). Zero when no point is selected.
    """
    semantic_points = get_semantic_filter(labels)
    # Boolean-mask indexing merges the masked leading dimensions into one.
    sem_logits = logits[semantic_points]
    sem_labels = labels[semantic_points]

    if sem_labels.numel() == 0:
        # The mean of an empty tensor is NaN, which would propagate into the total loss
        # and the optimizer state. Return a zero that is still attached to the graph.
        return sem_logits.sum() * 0.0

    # reduction='none' followed by a plain mean is deliberate: reduction='mean' would
    # divide by the sum of the class weights instead of the point count.
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='none')
    return criterion(sem_logits, sem_labels).mean()

def get_semantic_filter(labels, ignore_label=-1, max_class_id=4):
    """Boolean mask of the entries that hold a usable semantic class.

    Args:
        labels: Tensor or array of semantic labels. Anything that supports element-wise
            ``==``, ``>``, ``|`` and ``~`` works.
        ignore_label: Label value to exclude. Defaults to ``-1`` (walls in this notebook).
        max_class_id: Largest label kept; anything greater is excluded. Defaults to ``4``
            (larger labels are free space in this notebook).

    Returns:
        Boolean mask with the same shape as ``labels``; ``True`` where the label is kept.
    """
    is_ignored = labels == ignore_label
    is_beyond_classes = labels > max_class_id
    return ~(is_ignored | is_beyond_classes)

def get_semantic_iou(logits, labels, n_class):
    """Per-class IoU between arg-max predictions and gold labels.

    Args:
        logits: NumPy array of class scores; the predicted class is the arg-max over the
            last axis.
        labels: NumPy array of gold class ids, aligned with the predictions.
        n_class: Number of classes; ids ``0 .. n_class - 1`` are scored.

    Returns:
        NumPy array of length ``n_class`` holding one binary Jaccard score per class id.
    """
    predicted = np.argmax(logits, axis=-1)

    def _class_iou(class_id):
        # One-vs-rest: score "is this class" as a binary problem.
        gold_mask = (labels == class_id).astype(int)
        pred_mask = (predicted == class_id).astype(int)
        return jaccard_score(gold_mask, pred_mask, pos_label=1)

    return np.array([_class_iou(class_id) for class_id in range(n_class)])

def evaluate(sqn, eval_scene_data):
    """Score the model on one scene: occupancy IoU and mean semantic IoU.

    Only the first entry of ``query``, ``query_occ`` and ``query_semantics`` is
    evaluated; each is re-wrapped with a leading dimension of size 1.

    Args:
        sqn: The model. It is switched to eval mode and left in that mode, so callers
            that keep training must call ``sqn.train()`` again.
        eval_scene_data: Mapping with the keys ``'features'``, ``'input_points'``,
            ``'input_neighbors'``, ``'input_pools'``, ``'query'``, ``'query_occ'``,
            ``'query_semantics'`` and ``'class_weights'``.

    Returns:
        Tuple ``(occ_iou, sem_miou)``: the Jaccard score of thresholded occupancy
        predictions, and the mean per-class semantic IoU over the classes whose
        ``class_weights`` entry is non-zero.
    """
    with torch.no_grad():
        sqn.eval()

        # Keep only the first query set, restoring a leading dimension of size 1.
        eval_query = eval_scene_data['query'][0].unsqueeze(0)
        eval_occ_labels = eval_scene_data['query_occ'][0].unsqueeze(0)
        eval_sem_labels = eval_scene_data['query_semantics'][0].unsqueeze(0)

        # Call the module itself rather than .forward() so registered hooks run.
        occ_logits, sem_logits = sqn(
            eval_scene_data['features'],
            eval_scene_data['input_points'],
            eval_scene_data['input_neighbors'],
            eval_scene_data['input_pools'],
            eval_query,
        )

        # Occupancy: threshold the sigmoid probability at 0.5.
        occ_prob = torch.sigmoid(occ_logits).detach().cpu().numpy().squeeze()
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
        occ_pred = (occ_prob > 0.5).astype(int)
        occ_gold = eval_occ_labels.detach().cpu().numpy().squeeze()
        occ_iou = jaccard_score(occ_gold, occ_pred)

        # Semantics: score only the points that carry a semantic class.
        semantic_filter = get_semantic_filter(eval_sem_labels)
        sem_pred = sem_logits[semantic_filter].squeeze().detach().cpu().numpy()
        sem_gold = eval_sem_labels[semantic_filter].squeeze().detach().cpu().numpy()

        sem_iou = get_semantic_iou(sem_pred, sem_gold, N_CLASS)
        # Classes with zero weight are left out of the mean.
        existing_classes = (eval_scene_data['class_weights'] != 0).cpu().numpy()
        sem_miou = sem_iou[existing_classes].mean()

    return occ_iou, sem_miou

In [2]:
# Suffix the run tag with the start time (whole seconds since the epoch) so runs started
# at different seconds get different log directories.
now = datetime.now()
exp_name = 'test-rand' + f'-{int(now.timestamp())}'
exp_dir = os.path.join(LOGDIR, exp_name)

dataset = SynthSceneDataset(DATA_ROOT, DEVICE)
# Scene used by every evaluation step; loaded once here with an explicit seed.
eval_scene_data = dataset.get_scene_data(
    dataset.train_dirs[3],
    '00',
    iou_nums=['00', '01'],
    seed=31,
)

sqn = SQN(
    d_feature=60,
    d_in=64,
    encoder_dims=[32, 64, 128],
    device=DEVICE,
    skip_connections=True,
    second_head=5,
)
optimizer = optim.Adam(sqn.parameters(), lr=1e-3)
# Multiply the learning rate by 0.9 after every 250 scheduler steps.
scheduler = StepLR(optimizer, step_size=250, gamma=0.9)

# The TensorBoard writer only exists when logging is enabled.
if LOG:
    writer = SummaryWriter(exp_dir)

In [3]:
# Training loop: one optimizer step per iteration on a single scene.
for e in tqdm(range(10000)):
    # evaluate() leaves the model in eval mode, so switch back at the start of every step.
    sqn.train()
    optimizer.zero_grad()
    
    # Same scene and arguments as eval_scene_data, but without an explicit seed
    # (presumably this re-samples the points each iteration -- confirm in SynthSceneDataset).
    scene_data = dataset.get_scene_data(dataset.train_dirs[3], '00', iou_nums=['00', '01'])
    
    occ_logits, sem_logits  = sqn.forward(scene_data['features'], scene_data['input_points'], 
                                          scene_data['input_neighbors'], scene_data['input_pools'], 
                                          scene_data['query'])

    occ_loss = get_occupancy_loss(occ_logits, scene_data['query_occ'], scene_data['pos_w'])
    sem_loss = get_semantic_loss(sem_logits, scene_data['query_semantics'], scene_data['class_weights'])
        
    # Combined objective: the semantic term is down-weighted by 0.1.
    loss = occ_loss + 0.1*sem_loss
    loss.backward()
    optimizer.step()
    
    if LOG:
        # The 'loss' tag records the occupancy loss only, not the combined `loss`.
        writer.add_scalar('loss', occ_loss, e)
        writer.add_scalar('sem-loss', sem_loss, e)
    
    # Every 100 iterations (including iteration 0), score the fixed evaluation scene.
    if (e % 100) == 0:
            train_occ_iou, train_sem_miou = evaluate(sqn, eval_scene_data)

            if LOG:
                writer.add_scalar('occ-iou', train_occ_iou, e)
                writer.add_scalar('sem-miou', train_sem_miou, e)
        
    # Stop stepping the scheduler once the learning rate is no longer above 2e-5.
    if scheduler.get_last_lr()[-1] > 2e-5:
        scheduler.step()
    
    if LOG:
        writer.add_scalar('lr', scheduler.get_last_lr()[-1], e)
        # The checkpoint is written to the same path on every iteration, so it always
        # holds the latest weights (and costs one disk write per step).
        torch.save(sqn.state_dict(), os.path.join(exp_dir, 'model'))

 16%|█████████▌                                                | 1648/10000 [2:34:01<13:00:36,  5.61s/it]


KeyboardInterrupt: 

In [None]:
# sqn.load_state_dict(torch.load('/mnt/data.nas/staff/eteke/sqn-single-experiments/occ-semantic-emb-skip-seed=31-lr=sch-1638545527/model'))
# sqn.load_state_dict(torch.load('/mnt/data.nas/staff/eteke/sqn-single-experiments/occ-semantic-emb-skip-seed=31-lr=sch-q-batch-1638895595/model'))
# eval_scene_data = dataset.get_scene_data(dataset.train_dirs[3], '00', ['03'])