In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import dnnlib
from dnnlib.util import EasyDict
from utils import *
import wandb
import os
import click

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import copy
import sys

import math
import time
import gc
import traceback
import itertools
import numpy as np

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

from networks.custom_modules import MultiShapeStyleLinearApproach
from networks.stylegan import StyleGANDiscriminator

from einops import rearrange

In [3]:
from torch.profiler import profile, record_function, ProfilerActivity

In [4]:
torch.backends.cudnn.benchmark = True               # Improves training speed.
torch.backends.cuda.matmul.allow_tf32 = False       # Improves numerical accuracy.
torch.backends.cudnn.allow_tf32 = False             # Improves numerical accuracy.

In [5]:
c = dnnlib.EasyDict()
c.steps = 6

c.G_args = dnnlib.EasyDict()
c.G_args.hidden_dim = 256
c.G_args.hidden_action_dim = 512
c.G_args.planner_type = 'mlp'
c.G_args.num_actions = c.steps
c.G_args.shape_style_dim = 256
c.G_args.shape_encoding_dim = 64
c.G_args.shape_progressive = True
c.G_args.shape_num_sinusoidals = 6
c.G_args.use_textures = False
c.G_args.texture_progressive = True
c.G_args.texture_num_sinusoidals = 4
c.G_args.to_rgb_type = 'styled'
c.G_args.output_size = 128
c.feature_volume_size = 32
c.G_args.const_in = False
c.G_args.size_in = 8
c.G_args.c_model = 32
c.G_args.planner_lr_mul = 0.01
c.G_args.shape_library_lr_mul = 1
c.G_args.texture_library_lr_mul = 1


c.device = torch.device('cuda')
generator = MultiShapeStyleLinearApproach(**c.G_args).to(c.device)
discriminator_dual = StyleGANDiscriminator(c.G_args.output_size, in_channels=6, channel_multiplier=1).to(c.device) 

In [6]:
batch_size = 16
noise = torch.randn(batch_size, c.G_args.hidden_dim).to(c.device)
noise_generator = torch.zeros
input_noise = [
    noise_generator(batch_size, *_noise.shape[1:], device=c.device) for _noise in generator.to_rgb.get_noise_like()
]

In [7]:
size = c.G_args.output_size

In [16]:
def trace_handler(p):
    output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)
    print(output)
    p.export_chrome_trace("trace_" + str(p.step_num) + ".json")


forward_t = []
loss_t = []
backward_t = []

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#     with_stack=True,
#     with_flops=True,
#     with_modules=True,
#     schedule=torch.profiler.schedule(
#         skip_first=4,
#         wait=3,
#         warmup=1,
#         active=3,
#         repeat=2),
#     on_trace_ready=trace_handler
) as prof:
    for _ in range(16):
        time_1 = time.time()
        with record_function("model_forward"):
            input_noise = [
                noise_generator(batch_size, *_noise.shape[1:], device=c.device) for _noise in generator.to_rgb.get_noise_like()
            ]
            generator_output = generator(noise, c.feature_volume_size, steps=c.steps, noise=input_noise, mode='sample')
        torch.cuda.synchronize()
        time_2 = time.time()
        with record_function("loss_calculation"):
            samples = generator_output[f'renders/{c.steps}/render']
            low_res_samples = generator_output[f'renders/{c.steps}/low_res_render']
            masks = generator_output[f'renders/{c.steps}/segmentation']

            # Dual discriminator loss
            fake_pred_dual = discriminator_dual(torch.cat([samples, F.interpolate(low_res_samples, (size, size), mode='bilinear')], dim=1))
            generator_dual_loss = F.softplus(-fake_pred_dual).mean()

            # Mask consistency loss
            mask_consistency_loss = torch.zeros([], device=c.device)
            original_masks = rearrange(masks[0], 'b H W d -> b d H W')
            for i in range(1, len(masks) - 1):
                low_res_masks = rearrange(masks[i], 'b H W d -> b d H W')
                high_res_masks = rearrange(masks[i + 1], 'b H W d -> b d H W')

                if high_res_masks.shape[-1] > original_masks.shape[-1]:
                    high_res_into_low_res_masks = F.interpolate(high_res_masks, (low_res_masks.shape[-2], low_res_masks.shape[-1]), mode='bilinear')
                    mask_consistency_loss += F.mse_loss(low_res_masks, high_res_into_low_res_masks)

            generator_loss = generator_dual_loss + mask_consistency_loss * 1.0
        torch.cuda.synchronize()
        time_3 = time.time()

        with record_function("model_backward"):
            generator_loss.mean().backward()
        torch.cuda.synchronize()
        time_4 = time.time()
        
        forward_t.append(time_2 - time_1)
        loss_t.append(time_3 - time_2)
        backward_t.append(time_4 - time_3)
        
#         prof.step()
    
# prof.export_chrome_trace("trace.json")
# prof.export_stacks("profiler_stacks.txt", "self_cuda_time_total")
# print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=5))

print(f"model_forward : \t {np.mean(forward_t)} +- {np.std(forward_t)}")
print(f"loss_calc     : \t {np.mean(loss_t)} +- {np.std(loss_t)}")
print(f"model_backward: \t {np.mean(backward_t)} +- {np.std(backward_t)}")

  "See the documentation of nn.Upsample for details.".format(mode)


model_forward : 	 0.430645227432251 +- 0.015960674464912235
loss_calc     : 	 0.04855097830295563 +- 0.0008227695754198076
model_backward: 	 0.8713723570108414 +- 0.004989595833739248


In [18]:
forward_t = []
loss_t = []
backward_t = []

for _ in range(16):
    time_1 = time.time()
    with record_function("model_forward"):
        input_noise = [
            noise_generator(batch_size, *_noise.shape[1:], device=c.device) for _noise in generator.to_rgb.get_noise_like()
        ]
        generator_output = generator(noise, c.feature_volume_size, steps=c.steps, noise=input_noise, mode='sample')
    torch.cuda.synchronize()
    time_2 = time.time()
    with record_function("loss_calculation"):
        samples = generator_output[f'renders/{c.steps}/render']
        low_res_samples = generator_output[f'renders/{c.steps}/low_res_render']
        masks = generator_output[f'renders/{c.steps}/segmentation']

        # Dual discriminator loss
        fake_pred_dual = discriminator_dual(torch.cat([samples, F.interpolate(low_res_samples, (size, size), mode='bilinear')], dim=1))
        generator_dual_loss = F.softplus(-fake_pred_dual).mean()

        # Mask consistency loss
        mask_consistency_loss = torch.zeros([], device=c.device)
        original_masks = rearrange(masks[0], 'b H W d -> b d H W')
        for i in range(1, len(masks) - 1):
            low_res_masks = rearrange(masks[i], 'b H W d -> b d H W')
            high_res_masks = rearrange(masks[i + 1], 'b H W d -> b d H W')

            if high_res_masks.shape[-1] > original_masks.shape[-1]:
                high_res_into_low_res_masks = F.interpolate(high_res_masks, (low_res_masks.shape[-2], low_res_masks.shape[-1]), mode='bilinear')
                mask_consistency_loss += F.mse_loss(low_res_masks, high_res_into_low_res_masks)

        generator_loss = generator_dual_loss + mask_consistency_loss * 1.0
    torch.cuda.synchronize()
    time_3 = time.time()

    with record_function("model_backward"):
        generator_loss.mean().backward()
    torch.cuda.synchronize()
    time_4 = time.time()

    forward_t.append(time_2 - time_1)
    loss_t.append(time_3 - time_2)
    backward_t.append(time_4 - time_3)
    
print(f"model_forward : \t {np.mean(forward_t)} +- {np.std(forward_t)}")
print(f"loss_calc     : \t {np.mean(loss_t)} +- {np.std(loss_t)}")
print(f"model_backward: \t {np.mean(backward_t)} +- {np.std(backward_t)}")

model_forward : 	 0.42754605412483215 +- 0.01359851692440814
loss_calc     : 	 0.049103811383247375 +- 0.0009719820773462114
model_backward: 	 0.8699076175689697 +- 0.0037340833535047836
