In [1]:
import gin
%matplotlib inline

from src.aae.Renderer import Renderer
from src.aae.utils import TimerManager 
from src.aae.Visualizations import plot_img, plot_batch



import gin
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from src.aae.models import AugmentedAutoEncoder
from src.aae.dataset import OnlineRenderer
from src.aae.TransformUtils import aae_paper_views


# Number of timed repetitions in the rendering benchmarks below, and the
# per-epoch cap on batches drawn inside train_aae.
ITERS: int = 10

@gin.configurable
def train_aae(num_workers: int=gin.REQUIRED,
              num_train_iters: int=gin.REQUIRED,
              cache_save_interval: int=gin.REQUIRED,
              device: str=gin.REQUIRED):
    """Time the stages of the online-rendering data pipeline used for AAE training.

    For each of ``num_train_iters`` epochs, a fresh ``OnlineRenderer`` dataset and
    ``DataLoader`` are built. At most ``ITERS`` batches are then drawn and copied
    to ``device``. Each stage gets its own ``TimerManager``, and every timer is
    printed once all epochs have finished. No model is trained here.

    Args:
        num_workers: Bound by the gin config but currently NOT forwarded. The
            DataLoader is always built with ``num_workers=0`` (see NOTE below).
        num_train_iters: Number of epochs (freshly constructed datasets) to time.
        cache_save_interval: Accepted so the gin binding resolves; unused here.
        device: Device that the ``aug`` / ``gt`` batch tensors are moved to
            (anything accepted by ``torch.Tensor.to``, e.g. ``"cuda"``).
    """
    overall = "Overall"
    ds_init = "Dataset Initialization"
    epoch_iter = "Epoch Iteration"
    data_cast = "Cast to device"

    # One timer per stage. A dict keeps insertion order, so the final report
    # lists the stages in the order declared here.
    timers = {label: TimerManager()
              for label in (overall, ds_init, epoch_iter, data_cast)}

    def t_update(label):
        # Each stage is bracketed by a pair of calls with the same label
        # (presumably start/stop; TimerManager lives in src.aae.utils).
        timers[label](label)

    t_update(overall)

    # Cap on batches drawn per epoch, so a benchmark run stays short. A full
    # epoch was sketched as `aae_paper_views // BS` in an earlier revision.
    batch_iters = ITERS

    ######################################
    # Optimization Step
    ######################################
    for epoch in tqdm(range(num_train_iters), desc="AAE Training"):
        # A new dataset/loader is created every epoch, so its construction
        # cost is part of what gets measured.
        t_update(ds_init)
        dataset = OnlineRenderer()
        # NOTE(review): `num_workers` is ignored, so loading stays in-process.
        # Kept as-is. If OnlineRenderer is safe to use from worker processes,
        # pass `num_workers=num_workers` instead.
        dl = DataLoader(dataset,
                        batch_size=dataset.batch_size,
                        shuffle=True,
                        num_workers=0)
        t_update(ds_init)

        t_update(epoch_iter)
        batches = tqdm(dl, desc=f"Epoch: {epoch + 1}", leave=False)
        for step, data in enumerate(batches, start=1):
            t_update(data_cast)
            aug, gt, _pose = data
            # Tensor.to() is not in-place: it returns the moved tensor, so the
            # result must be rebound or the device copy is thrown away.
            aug = aug.to(device)
            gt = gt.to(device)
            t_update(data_cast)

            if step >= batch_iters:
                break
        t_update(epoch_iter)

    t_update(overall)

    for timer in timers.values():
        print(timer)

# Bindings for the @gin.configurable train_aae (its gin.REQUIRED arguments).
# gin's interactive mode is meant for notebooks, where cells that register
# configurables are re-run.
# NOTE(review): the config targets a LINEMOD object, while the mesh below is
# T-LESS obj_01. Confirm this pairing is intended.
GIN_CONFIG_FILE = '../config/train/linemod/[test]_obj_0001.gin'
MODEL_PLY_PATH = "../data/t_less/models_cad/obj_01.ply"

gin.enter_interactive_mode()
gin.parse_config_file(GIN_CONFIG_FILE)

# Renderer shared by the timing cells that follow.
r = Renderer(MODEL_PLY_PATH)

## Timing for Rendering a Single Image

In [None]:
# Render one image per iteration, ITERS times, then report the timing stats
# and show the last rendered image.
cur_name = "One Image"
timer = TimerManager()

# NOTE(review): produce_batch_images also returns augmented images, so the
# "without Augmentations" label may be inaccurate. Confirm in Renderer.
single_progress = tqdm(range(ITERS), desc="Timing for Rendering Single Image without Augmentations")
for _iteration in single_progress:
    # Paired calls with the same label bracket the timed region.
    timer(cur_name)
    aug_imgs, img, pose = r.produce_batch_images(batch_size=1)
    timer(cur_name)

print(timer)
plot_img(img)

## Timing for Rendering a Batch of Images

In [None]:
# Render a batch of BS images per iteration, ITERS times, then report the
# timing stats and show the augmented images from the last batch.
BS = 64
cur_name = "Batch Image"
timer = TimerManager()

batch_desc = f"Timing for Rendering Batch of {BS} Images with Augmentations"
for _iteration in tqdm(range(ITERS), desc=batch_desc):
    # Paired calls with the same label bracket the timed region.
    timer(cur_name)
    aug_imgs, imgs, _ = r.produce_batch_images(batch_size=BS)
    timer(cur_name)

print(timer)
plot_batch(aug_imgs)

## Timing for Data Iteration

In [2]:
# Time a single epoch of the data pipeline. The other train_aae arguments
# (num_workers, cache_save_interval, device) come from the parsed gin config.
train_aae(num_train_iters=1)

AAE Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 1:   0%|          | 0/1424 [00:00<?, ?it/s]

Overall
Avg: 248.5 | std: 0.0 | # of ticks 1
Dataset Initialization
Avg: 0.5988 | std: 0.0 | # of ticks 1
Epoch Iteration
Avg: 247.9 | std: 0.0 | # of ticks 1
Cast to device
Avg: 0.2377 | std: 0.6626 | # of ticks 10
