### Data Download & Extract


In [None]:
import os

# One-time environment setup (IPython shell magics). Everything below is skipped
# once a "kaggle/" directory exists in the working directory; the unzip output
# of this cell shows the archives extracting under that directory.
if not os.path.exists("kaggle/"):
  !pip install gdown
  # NOTE(review): presumably this Google Drive file is the kaggle.json API token
  # that is copied below - confirm.
  !gdown --id 1LLdQUA-wRLrRCE2Tha9dcHyviy1y_oaP
  # Place the token under /root/.kaggle and restrict it to owner read/write.
  !mkdir -p /root/.kaggle
  !cp kaggle.json /root/.kaggle/kaggle.json
  !chmod 600 /root/.kaggle/kaggle.json
  # Download and unpack the rendering images and the 32^3 voxel grids.
  !kaggle datasets download -d ronak555/shapenetcorerendering-part1
  !kaggle datasets download -d ronak555/shapenetvox32
  !unzip shapenetcorerendering-part1.zip
  !unzip shapenetvox32.zip
  !pip install tensorboardX
  # NOTE(review): the purpose of this second Drive download is not visible in
  # this part of the notebook - confirm what it fetches.
  !gdown --id 1zpk_-nH3o5uaNPSzYcr9YWZgasiVp8JP

Streaming output truncated to the last 5000 lines.
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c3b49ab37876c7f219fb4103277a6b93/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c3c467718eb9b2a313f96345312df593/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c3c635d741fab1615f0b5ee0fb357b4c/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c3cd2a7f997a6a40f3017d945b17b4d6/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c3e43144fd61c56f19fb4103277a6b93/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c3e5380fb498f749f79675bb6cb63c97/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c4071718e45630ee5510d59f3ab1ed64/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c40a88d13709eba91f30b807ae39b61d/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c416034b70b1d339838b39398d1628f8/model.binvox  
  inflating: kaggle/tmp/ShapeNetVox32/04379243/c418195771c7625945821c000807c3b1/

### Import Libraries

In [None]:
import cv2
import json
import numpy as np
import os
import math
import random
import scipy.io
import scipy.ndimage
import sys
import torch
import torch.backends.cudnn
import torch.utils.data
import torch.utils.data.dataset
import torchvision.models
from datetime import datetime as dt
from tensorboardX import SummaryWriter
from time import time
from enum import Enum, unique
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

### Utility Code Section

#### Binvox 

##### Voxel Model Class Definition



###### Voxel

In [None]:
class Voxels(object):
    """Container for one binvox voxel model plus its metadata.

    ``data`` holds the voxels, either as a 3-D array (dense form) or as a
    2-D array of coordinates (sparse form). ``dims`` is the grid size, for
    example ``[32, 32, 32]``. ``translate`` and ``scale`` map voxel indices
    back to the coordinates of the original model:

        x_n = (i + .5) / dims[0]        x = scale * x_n + translate[0]
        y_n = (j + .5) / dims[1]        y = scale * y_n + translate[1]
        z_n = (k + .5) / dims[2]        z = scale * z_n + translate[2]

    ``axis_order`` records the memory layout of ``data`` and must be either
    ``'xzy'`` or ``'xyz'``.
    """

    def __init__(self, data, dims, translate, scale, axis_order):
        # Only the two layouts understood by the reader/writer are accepted.
        assert (axis_order in ('xzy', 'xyz'))
        self.axis_order = axis_order
        self.scale = scale
        self.translate = translate
        self.dims = dims
        self.data = data

    def clone(self):
        """Return a new Voxels whose data, dims and translate are copies."""
        return Voxels(
            self.data.copy(),
            self.dims[:],
            self.translate[:],
            self.scale,
            self.axis_order,
        )

    def write(self, fp):
        """Serialise this model to ``fp`` using the module-level ``write``."""
        write(self, fp)

def write(voxel_model, fp):
    """Write a voxel model to ``fp`` in the binary binvox format.

    A model stored in sparse (coordinate) form is converted to dense form
    first. No sanity checks are made on the model beyond its axis order.

    Args:
        voxel_model: Object exposing ``data``, ``dims``, ``translate``,
            ``scale`` and ``axis_order`` (``'xzy'`` or ``'xyz'``).
        fp: Binary file-like object opened for writing.

    Raises:
        ValueError: If ``axis_order`` is unsupported. This is checked before
            anything is written, so ``fp`` is left untouched.
    """
    # Validate up front so a bad model never leaves a header-only file behind.
    if voxel_model.axis_order not in ('xzy', 'xyz'):
        raise ValueError('[ERROR] Unsupported voxel model axis order')

    if voxel_model.data.ndim == 2:
        # TODO avoid conversion to dense
        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims).astype(int)
    else:
        dense_voxel_data = voxel_model.data.astype(int)

    # The file stores voxels in x-z-y order; 'xyz' data must be swapped back.
    if voxel_model.axis_order == 'xzy':
        voxels_flat = dense_voxel_data.flatten()
    else:
        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()

    header = ''.join([
        '#binvox 1\n',
        'dim %s\n' % ' '.join(map(str, voxel_model.dims)),
        'translate %s\n' % ' '.join(map(str, voxel_model.translate)),
        'scale %s\n' % str(voxel_model.scale),
        'data\n',
    ])
    fp.write(header.encode('latin-1'))

    if voxels_flat.size == 0:
        # Nothing to encode; the header alone describes an empty grid.
        return

    # Run-length encode: find where the value changes, then emit
    # (value, count) byte pairs with every count capped at 255.
    change_points = np.flatnonzero(voxels_flat[1:] != voxels_flat[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    run_lengths = np.diff(np.concatenate((run_starts, [voxels_flat.size])))

    encoded = bytearray()
    for value, length in zip(voxels_flat[run_starts].tolist(), run_lengths.tolist()):
        full_chunks, remainder = divmod(length, 255)
        encoded += bytes((value, 255)) * full_chunks
        # Skipping a zero remainder avoids writing an empty (value, 0) pair
        # when a run is an exact multiple of 255.
        if remainder:
            encoded += bytes((value, remainder))
    fp.write(bytes(encoded))

###### Write

In [None]:
def write(voxel_model, fp):
    """Write a voxel model to ``fp`` in the binary binvox format.

    A model stored in sparse (coordinate) form is converted to dense form
    first. No sanity checks are made on the model beyond its axis order.

    Args:
        voxel_model: Object exposing ``data``, ``dims``, ``translate``,
            ``scale`` and ``axis_order`` (``'xzy'`` or ``'xyz'``).
        fp: Binary file-like object opened for writing.

    Raises:
        ValueError: If ``axis_order`` is unsupported. This is checked before
            anything is written, so ``fp`` is left untouched.
    """
    # Validate up front so a bad model never leaves a header-only file behind.
    if voxel_model.axis_order not in ('xzy', 'xyz'):
        raise ValueError('[ERROR] Unsupported voxel model axis order')

    if voxel_model.data.ndim == 2:
        # TODO avoid conversion to dense
        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims).astype(int)
    else:
        dense_voxel_data = voxel_model.data.astype(int)

    # The file stores voxels in x-z-y order; 'xyz' data must be swapped back.
    if voxel_model.axis_order == 'xzy':
        voxels_flat = dense_voxel_data.flatten()
    else:
        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()

    header = ''.join([
        '#binvox 1\n',
        'dim %s\n' % ' '.join(map(str, voxel_model.dims)),
        'translate %s\n' % ' '.join(map(str, voxel_model.translate)),
        'scale %s\n' % str(voxel_model.scale),
        'data\n',
    ])
    fp.write(header.encode('latin-1'))

    if voxels_flat.size == 0:
        # Nothing to encode; the header alone describes an empty grid.
        return

    # Run-length encode: find where the value changes, then emit
    # (value, count) byte pairs with every count capped at 255.
    change_points = np.flatnonzero(voxels_flat[1:] != voxels_flat[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    run_lengths = np.diff(np.concatenate((run_starts, [voxels_flat.size])))

    encoded = bytearray()
    for value, length in zip(voxels_flat[run_starts].tolist(), run_lengths.tolist()):
        full_chunks, remainder = divmod(length, 255)
        encoded += bytes((value, 255)) * full_chunks
        # Skipping a zero remainder avoids writing an empty (value, 0) pair
        # when a run is an exact multiple of 255.
        if remainder:
            encoded += bytes((value, remainder))
    fp.write(bytes(encoded))

##### 3D Model reading utility & Visualization

###### Reading

In [None]:
def read_as_3d_array(fp, fix_coords=True):
    """Decode a binary binvox stream into a dense ``Voxels`` model.

    The header is parsed with ``read_header``; the rest of the stream is
    treated as run-length (value, count) byte pairs and expanded into an
    int32 array shaped ``dims``. Nothing about the input is validated here
    beyond what ``read_header`` checks.

    Reshaping the raw stream yields indices ordered (x, z, y). When
    ``fix_coords`` is true the last two axes are swapped so the array is
    indexed (x, y, z) and the model is tagged ``'xyz'``; otherwise the raw
    layout is kept and tagged ``'xzy'``.
    """
    dims, translate, scale = read_header(fp)

    # Even bytes are voxel values, odd bytes are the matching run lengths.
    payload = np.frombuffer(fp.read(), dtype=np.uint8)
    run_values = payload[::2]
    run_lengths = payload[1::2]
    voxels = np.repeat(run_values, run_lengths).astype(np.int32).reshape(dims)

    if not fix_coords:
        return Voxels(voxels, dims, translate, scale, 'xzy')

    # xzy to xyz TODO the right thing
    return Voxels(np.transpose(voxels, (0, 2, 1)), dims, translate, scale, 'xyz')

def read_header(fp):
    """Parse the text header of a binvox stream (mainly for internal use).

    Consumes five lines from ``fp``: the ``#binvox`` magic line, then the
    dim, translate and scale lines, then one more line that is discarded.

    Returns:
        ``(dims, translate, scale)``: a list of ints, a list of floats and a
        single float.

    Raises:
        IOError: If the first line does not start with ``#binvox``.
    """
    magic = fp.readline().strip()
    if not magic.startswith(b'#binvox'):
        raise IOError('[ERROR] Not a binvox file')

    def _next_fields():
        # Drop the leading keyword and keep the space-separated values.
        return fp.readline().strip().split(b' ')[1:]

    dims = [int(token) for token in _next_fields()]
    translate = [float(token) for token in _next_fields()]
    scale = [float(token) for token in _next_fields()][0]
    # Skip the line that separates the header from the voxel payload.
    fp.readline()
    return dims, translate, scale

###### Visualization

In [None]:
def get_volume_views(volume, save_dir, n_itr):
    """Render a voxel volume to a PNG file and return the image.

    Args:
        volume: Array-like occupancy grid; it is squeezed and thresholded at
            0.5 before plotting.
        save_dir: Directory for the PNG; created if it does not exist.
        n_itr: Integer used in the file name ``voxels-%06d.png``.

    Returns:
        The saved image as loaded back by ``cv2.imread``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    volume = volume.squeeze() >= 0.5
    save_path = os.path.join(save_dir, 'voxels-%06d.png' % n_itr)

    fig = plt.figure()
    try:
        # Figure.gca() stopped accepting keyword arguments in Matplotlib 3.6;
        # add_subplot creates the same single 3-D axes on old and new versions.
        ax = fig.add_subplot(1, 1, 1, projection=Axes3D.name)
        ax.voxels(volume, edgecolor="k")
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return cv2.imread(save_path)

#### Data Loader & Transform

##### Camera methods

###### Compute extrinsic matrix

In [None]:
# Camera-distance constant; it is not referenced by the function below.
MAX_CAMERA_DISTANCE = 1.75


def compute_extrinsic_matrix(
    azimuth: float, elevation: float, distance: float
):  # pragma: no cover
    """
    Build the 4x4 world-to-camera extrinsic matrix for a camera that looks
    at the origin.

    Copied from the meshrcnn codebase:
    https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L96
    Used in the R2N2 dataset when computing calibration matrices.

    Args:
        azimuth: Rotation about the z-axis, in degrees.
        elevation: Rotation above the xy-plane, in degrees.
        distance: Distance from the origin.

    Returns:
        FloatTensor of shape (4, 4) mapping homogeneous world coordinates
        to homogeneous camera coordinates.
    """
    azimuth = float(azimuth)
    elevation = float(elevation)
    distance = float(distance)

    # Both angles are negated when converted from degrees to radians.
    az_rad = -math.pi * azimuth / 180.0
    el_rad = -math.pi * elevation / 180.0
    sin_az, cos_az = math.sin(az_rad), math.cos(az_rad)
    sin_el, cos_el = math.sin(el_rad), math.cos(el_rad)

    world_to_obj = torch.tensor(
        [
            [cos_az * cos_el, sin_az * cos_el, -sin_el],
            [-sin_az, cos_az, 0],
            [cos_az * sin_el, sin_az * sin_el, cos_el],
        ]
    )
    # Fixed axis permutation from object axes to camera axes.
    obj_to_cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])

    rotation = obj_to_cam.mm(world_to_obj)
    camera_position = torch.tensor([[distance, 0, 0]]).t()
    translation = -(obj_to_cam.mm(camera_position))

    upper_rows = torch.cat([rotation, translation], dim=1)
    extrinsic = torch.cat([upper_rows, torch.tensor([[0.0, 0, 0, 1]])])

    # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it
    # rotates the model 90 degrees about the x axis. To compensate for this quirk we
    # roll that rotation into the extrinsic matrix here
    blender_fix = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    return extrinsic.mm(blender_fix.to(extrinsic))

###### Compute camera calibration

In [None]:
def compute_camera_calibration(RT):
    """
    Split an extrinsic matrix into the rotation and translation used after
    converting from the ShapeNet world convention to the PyTorch3D one.

    Args:
        RT: Extrinsic matrix that performs ShapeNet world view to camera view
            transformation.

    Returns:
        R: Rotation matrix of shape (3, 3).
        T: Translation of shape (3).
    """
    # ShapeNet world -> PyTorch3D world: negate the x and z axes.
    axis_flip = torch.diag(torch.tensor([-1.0, 1.0, -1.0, 1.0], dtype=torch.float32))

    # Transposing first puts the translation in the last row of the product.
    transformed = RT.transpose(0, 1).mm(axis_flip)  # (4, 4)

    rotation = transformed[:3, :3]
    translation = transformed[3, :3]
    return rotation, translation

##### ShapeNet Dataset Class Definition

In [None]:
@unique
class DatasetType(Enum):
    """Which split of the dataset is being served."""
    TRAIN = 0
    TEST = 1
    VAL = 2


class ShapeNetDataset(torch.utils.data.dataset.Dataset):
    """ShapeNetDataset class used for PyTorch DataLoader.

    Each entry of ``file_list`` is a dict with the keys ``taxonomy_name``,
    ``sample_name``, ``rendering_images`` (list of image paths) and
    ``volume`` (path to a ``.mat`` or ``.binvox`` file).
    """
    def __init__(self, dataset_type, file_list, n_views_rendering, transforms=None):
        self.dataset_type = dataset_type
        self.file_list = file_list
        self.transforms = transforms
        self.n_views_rendering = n_views_rendering

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and apply ``transforms`` to its images, if set."""
        taxonomy_name, sample_name, rendering_images, volume = self.get_datum(idx)

        if self.transforms:
            rendering_images = self.transforms(rendering_images)

        return taxonomy_name, sample_name, rendering_images, volume

    def set_n_views_rendering(self, n_views_rendering):
        """Change how many views are loaded per sample."""
        self.n_views_rendering = n_views_rendering

    def get_datum(self, idx):
        """Load the images and the voxel volume of sample ``idx``.

        For the TRAIN split the views are drawn at random without
        replacement; for every other split the first ``n_views_rendering``
        paths are used.

        Returns:
            ``(taxonomy_name, sample_name, images, volume)`` where ``images``
            is a float32 array of pixel values divided by 255 and ``volume``
            is a float32 array.

        Raises:
            ValueError: If the volume file is neither ``.mat`` nor ``.binvox``.
            SystemExit: With status 2 if an image cannot be read or has fewer
                than three dimensions.
        """
        entry = self.file_list[idx]
        taxonomy_name = entry['taxonomy_name']
        sample_name = entry['sample_name']
        rendering_image_paths = entry['rendering_images']
        volume_path = entry['volume']

        # Get data of rendering images
        if self.dataset_type == DatasetType.TRAIN:
            selected_rendering_image_paths = [
                rendering_image_paths[i]
                for i in random.sample(range(len(rendering_image_paths)), self.n_views_rendering)
            ]
        else:
            selected_rendering_image_paths = [rendering_image_paths[i] for i in range(self.n_views_rendering)]

        rendering_images = []
        for image_path in selected_rendering_image_paths:
            rendering_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            # cv2.imread returns None for a missing or undecodable file, so
            # that case must be caught before any array method is called.
            if rendering_image is None or rendering_image.ndim < 3:
                print('[FATAL] %s It seems that there is something wrong with the image file %s' %
                      (dt.now(), image_path))
                sys.exit(2)

            rendering_images.append(rendering_image.astype(np.float32) / 255.)

        # Get data of volume
        _, suffix = os.path.splitext(volume_path)

        if suffix == '.mat':
            volume = scipy.io.loadmat(volume_path)['Volume'].astype(np.float32)
        elif suffix == '.binvox':
            with open(volume_path, 'rb') as f:
                volume = read_as_3d_array(f).data.astype(np.float32)
        else:
            # Fail with a clear message rather than an unbound-variable error.
            raise ValueError('[ERROR] Unsupported volume file format: %s' % volume_path)

        return taxonomy_name, sample_name, np.asarray(rendering_images), volume


# //////////////////////////////// = End of ShapeNetDataset Class Definition = ///////////////////////////////// #

##### ShapeNet Dataloader Class Definition

In [None]:
class ShapeNetDataLoader:
    """Collects ShapeNet sample files and builds ``ShapeNetDataset`` objects.

    The configuration supplies two ``%``-style path templates (rendering
    image and voxel volume) and the path of a JSON taxonomy file. Each
    taxonomy entry is read for the keys ``taxonomy_id``, ``taxonomy_name``,
    ``train``, ``test`` and ``val``.
    """
    def __init__(self, cfg):
        self.dataset_taxonomy = None
        self.rendering_image_path_template = cfg.DATASETS.SHAPENET.RENDERING_PATH
        self.volume_path_template = cfg.DATASETS.SHAPENET.VOXEL_PATH

        # Load all taxonomies of the dataset
        with open(cfg.DATASETS.SHAPENET.TAXONOMY_FILE_PATH, encoding='utf-8') as file:
            self.dataset_taxonomy = json.load(file)

    def get_dataset(self, dataset_type, n_views_rendering, transforms=None):
        """Build the dataset for one split.

        An unrecognised ``dataset_type`` selects no samples, so the returned
        dataset is empty.
        """
        files = []

        # Load data for each category
        for taxonomy in self.dataset_taxonomy:
            taxonomy_folder_name = taxonomy['taxonomy_id']
            print('[INFO] %s Collecting files of Taxonomy[ID=%s, Name=%s]' %
                  (dt.now(), taxonomy['taxonomy_id'], taxonomy['taxonomy_name']))
            samples = []
            if dataset_type == DatasetType.TRAIN:
                samples = taxonomy['train']
            elif dataset_type == DatasetType.TEST:
                samples = taxonomy['test']
            elif dataset_type == DatasetType.VAL:
                samples = taxonomy['val']

            files.extend(self.get_files_of_taxonomy(taxonomy_folder_name, samples))

        print('[INFO] %s Complete collecting files of the dataset. Total files: %d.' % (dt.now(), len(files)))
        return ShapeNetDataset(dataset_type, files, n_views_rendering, transforms)

    def get_files_of_taxonomy(self, taxonomy_folder_name, samples):
        """List the usable samples of one taxonomy.

        A sample is skipped, with a warning, when its volume file is missing,
        when its rendering folder is missing, or when the folder holds no
        image matching the rendering template.

        Returns:
            A list of dicts with the keys ``taxonomy_name``, ``sample_name``,
            ``rendering_images``, ``volume`` and ``metadata``.
        """
        files_of_taxonomy = []

        for sample_name in samples:
            # Get file path of volumes
            volume_file_path = self.volume_path_template % (taxonomy_folder_name, sample_name)
            if not os.path.exists(volume_file_path):
                print('[WARN] %s Ignore sample %s/%s since volume file not exists.' %
                      (dt.now(), taxonomy_folder_name, sample_name))
                continue

            # Get file list of rendering images
            first_img_path = self.rendering_image_path_template % (taxonomy_folder_name, sample_name, 0)
            img_folder = os.path.dirname(first_img_path)
            metadata_file_path = os.path.join(img_folder, "rendering_metadata.txt")

            rendering_images_file_path = []
            if os.path.isdir(img_folder):
                # The entry count is only an upper bound on the view index
                # (the folder may hold other files); missing indexes are skipped.
                for image_idx in range(len(os.listdir(img_folder))):
                    img_file_path = self.rendering_image_path_template % (
                        taxonomy_folder_name, sample_name, image_idx)
                    if os.path.exists(img_file_path):
                        rendering_images_file_path.append(img_file_path)

            # A missing folder is treated like an empty one: warn and skip
            # instead of letting os.listdir abort the whole collection.
            if len(rendering_images_file_path) == 0:
                print('[WARN] %s Ignore sample %s/%s since image files not exists.' %
                      (dt.now(), taxonomy_folder_name, sample_name))
                continue

            # Append to the list of rendering images
            files_of_taxonomy.append({
                'taxonomy_name': taxonomy_folder_name,
                'sample_name': sample_name,
                'rendering_images': rendering_images_file_path,
                'volume': volume_file_path,
                'metadata': metadata_file_path,
            })

        return files_of_taxonomy


# /////////////////////////////// = End of ShapeNetDataLoader Class Definition = /////////////////////////////// #

##### Data Transform

###### Compose

In [None]:
class Compose(object):
    """Chain several transforms into one callable.

    Example:
        transforms.Compose([
            transforms.RandomBackground(),
            transforms.CenterCrop(127, 127, 3),
        ])

    Transforms whose class is named ``RandomCrop`` or ``CenterCrop`` also
    receive the bounding box; all others receive only the images.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, rendering_images, bounding_box=None):
        crop_class_names = ('RandomCrop', 'CenterCrop')
        result = rendering_images
        for transform in self.transforms:
            if transform.__class__.__name__ in crop_class_names:
                result = transform(result, bounding_box)
                continue
            result = transform(result)

        return result


###### To Tensor

In [None]:
class ToTensor(object):
    """
    Convert a batch of images held in a numpy.ndarray into a torch.FloatTensor.

    The input is indexed (N, H, W, C); the output is indexed (N, C, H, W).
    Values are cast to float32 and are not rescaled.
    """
    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))
        # Move the channel axis in front of the two spatial axes.
        channels_first = rendering_images.transpose(0, 3, 1, 2)
        return torch.from_numpy(channels_first).float()

###### Normalize

In [None]:
class Normalize(object):
    """Shift and scale images in place: ``(images - mean) / std``."""
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))
        # Both steps write back into the input array, which is also returned.
        np.subtract(rendering_images, self.mean, out=rendering_images)
        np.divide(rendering_images, self.std, out=rendering_images)
        return rendering_images

###### RandomPermuteRGB

In [None]:
class RandomPermuteRGB(object):
    """Reorder the three colour channels with one random permutation.

    The same permutation is applied to every image, and the input array is
    modified in place and returned.
    """
    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))

        channel_order = np.random.permutation(3)
        for position in range(len(rendering_images)):
            rendering_images[position] = rendering_images[position][..., channel_order]

        return rendering_images

###### Center Crop

In [None]:
class CenterCrop(object):
    """Crop each image around its centre (or a bounding box) and resize it."""

    def __init__(self, img_size, crop_size):
        """Set the height and width before and after cropping.

        Args:
            img_size: ``(height, width)`` of every output image.
            crop_size: ``(height, width)`` of the centred window taken when
                no bounding box is supplied.
        """
        self.img_size_h = img_size[0]
        self.img_size_w = img_size[1]
        self.crop_size_h = crop_size[0]
        self.crop_size_w = crop_size[1]

    def __call__(self, rendering_images, bounding_box=None):
        """Crop and resize every image.

        Args:
            rendering_images: Sequence of (H, W, C) images.
            bounding_box: Optional ``[x_min, y_min, x_max, y_max]`` expressed
                as fractions of the image width and height. When given, a
                square around the box is cropped, edge-padded where it leaves
                the image, and resized.

        Returns:
            A float64 array of shape ``(N, img_size_h, img_size_w, C)``, or
            the input itself when it is empty.
        """
        if len(rendering_images) == 0:
            return rendering_images

        n_channels = rendering_images[0].shape[2]
        # Allocate the output once instead of growing it with np.append.
        processed_images = np.empty(
            shape=(len(rendering_images), self.img_size_h, self.img_size_w, n_channels))
        for img_idx, img in enumerate(rendering_images):
            img_height, img_width, _ = img.shape

            if bounding_box is not None:
                # Scale into a per-image local. Overwriting bounding_box here
                # would compound the scaling on every subsequent image.
                bbox = [
                    bounding_box[0] * img_width,
                    bounding_box[1] * img_height,
                    bounding_box[2] * img_width,
                    bounding_box[3] * img_height
                ]  # yapf: disable

                # Calculate the size of bounding boxes
                bbox_width = bbox[2] - bbox[0]
                bbox_height = bbox[3] - bbox[1]
                bbox_x_mid = (bbox[2] + bbox[0]) * .5
                bbox_y_mid = (bbox[3] + bbox[1]) * .5

                # Make the crop area as a square
                square_object_size = max(bbox_width, bbox_height)
                x_left = int(bbox_x_mid - square_object_size * .5)
                x_right = int(bbox_x_mid + square_object_size * .5)
                y_top = int(bbox_y_mid - square_object_size * .5)
                y_bottom = int(bbox_y_mid + square_object_size * .5)

                # If the crop position is out of the image, fix it with padding
                pad_x_left = 0
                if x_left < 0:
                    pad_x_left = -x_left
                    x_left = 0
                pad_x_right = 0
                if x_right >= img_width:
                    pad_x_right = x_right - img_width + 1
                    x_right = img_width - 1
                pad_y_top = 0
                if y_top < 0:
                    pad_y_top = -y_top
                    y_top = 0
                pad_y_bottom = 0
                if y_bottom >= img_height:
                    pad_y_bottom = y_bottom - img_height + 1
                    y_bottom = img_height - 1

                # Padding the image and resize the image
                processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],
                                         ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),
                                         mode='edge')
                processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))
            else:
                if img_height > self.crop_size_h and img_width > self.crop_size_w:
                    x_left = int(img_width - self.crop_size_w) // 2
                    x_right = int(x_left + self.crop_size_w)
                    y_top = int(img_height - self.crop_size_h) // 2
                    y_bottom = int(y_top + self.crop_size_h)
                else:
                    # Image is not larger than the crop window: use all of it.
                    x_left = 0
                    x_right = img_width
                    y_top = 0
                    y_bottom = img_height

                processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))

            processed_images[img_idx] = processed_image

        return processed_images


###### RandomCrop

In [None]:
class RandomCrop(object):
    """Crop each image around a randomly jittered bounding box and resize it."""

    def __init__(self, img_size, crop_size):
        """Set the height and width before and after cropping.

        Args:
            img_size: ``(height, width)`` of every output image.
            crop_size: ``(height, width)`` of the centred window taken when
                no bounding box is supplied.
        """
        self.img_size_h = img_size[0]
        self.img_size_w = img_size[1]
        self.crop_size_h = crop_size[0]
        self.crop_size_w = crop_size[1]

    def __call__(self, rendering_images, bounding_box=None):
        """Crop and resize every image.

        Args:
            rendering_images: Sequence of (H, W, C) images.
            bounding_box: Optional ``[x_min, y_min, x_max, y_max]`` expressed
                as fractions of the image width and height. When given, the
                square around the box is randomly rescaled and shifted per
                image, edge-padded where it leaves the image, and resized.
                Without a box, the deterministic centred window is used.

        Returns:
            A float64 array of shape ``(N, img_size_h, img_size_w, C)``, or
            the input itself when it is empty.
        """
        if len(rendering_images) == 0:
            return rendering_images

        n_channels = rendering_images[0].shape[2]
        # Allocate the output once instead of growing it with np.append.
        processed_images = np.empty(
            shape=(len(rendering_images), self.img_size_h, self.img_size_w, n_channels))
        for img_idx, img in enumerate(rendering_images):
            img_height, img_width, _ = img.shape

            if bounding_box is not None:
                # Scale into a per-image local. Overwriting bounding_box here
                # would compound the scaling on every subsequent image.
                bbox = [
                    bounding_box[0] * img_width,
                    bounding_box[1] * img_height,
                    bounding_box[2] * img_width,
                    bounding_box[3] * img_height
                ]  # yapf: disable

                # Calculate the size of bounding boxes
                bbox_width = bbox[2] - bbox[0]
                bbox_height = bbox[3] - bbox[1]
                bbox_x_mid = (bbox[2] + bbox[0]) * .5
                bbox_y_mid = (bbox[3] + bbox[1]) * .5

                # Make the crop area as a square
                square_object_size = max(bbox_width, bbox_height)
                square_object_size = square_object_size * random.uniform(0.8, 1.2)

                # Each edge gets its own random offset around the box centre.
                x_left = int(bbox_x_mid - square_object_size * random.uniform(.4, .6))
                x_right = int(bbox_x_mid + square_object_size * random.uniform(.4, .6))
                y_top = int(bbox_y_mid - square_object_size * random.uniform(.4, .6))
                y_bottom = int(bbox_y_mid + square_object_size * random.uniform(.4, .6))

                # If the crop position is out of the image, fix it with padding
                pad_x_left = 0
                if x_left < 0:
                    pad_x_left = -x_left
                    x_left = 0
                pad_x_right = 0
                if x_right >= img_width:
                    pad_x_right = x_right - img_width + 1
                    x_right = img_width - 1
                pad_y_top = 0
                if y_top < 0:
                    pad_y_top = -y_top
                    y_top = 0
                pad_y_bottom = 0
                if y_bottom >= img_height:
                    pad_y_bottom = y_bottom - img_height + 1
                    y_bottom = img_height - 1

                # Padding the image and resize the image
                processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],
                                         ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),
                                         mode='edge')
                processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))
            else:
                if img_height > self.crop_size_h and img_width > self.crop_size_w:
                    x_left = int(img_width - self.crop_size_w) // 2
                    x_right = int(x_left + self.crop_size_w)
                    y_top = int(img_height - self.crop_size_h) // 2
                    y_bottom = int(y_top + self.crop_size_h)
                else:
                    # Image is not larger than the crop window: use all of it.
                    x_left = 0
                    x_right = img_width
                    y_top = 0
                    y_bottom = img_height

                processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))

            processed_images[img_idx] = processed_image

        return processed_images


###### Random Flip

In [None]:
class RandomFlip(object):
    """Mirror each image left-to-right with probability one half.

    One random draw is made per image; the input array is modified in place
    and returned.
    """
    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))

        for position in range(len(rendering_images)):
            flip = random.randint(0, 1)
            if not flip:
                continue
            rendering_images[position] = np.fliplr(rendering_images[position])

        return rendering_images

###### Color Jitter

In [None]:
class ColorJitter(object):
    """Randomly perturb brightness, contrast and saturation of a batch.

    One factor per attribute is drawn from ``[1 - d, 1 + d]`` (``d`` being
    the matching constructor argument) and one random application order is
    chosen; both are shared by every image of the batch.
    """
    def __init__(self, brightness, contrast, saturation):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation

    def __call__(self, rendering_images):
        if len(rendering_images) == 0:
            return rendering_images

        height, width, channels = rendering_images[0].shape

        # Draw the three factors (brightness, contrast, saturation, in that
        # order) and then the order in which they are applied.
        names = ['brightness', 'contrast', 'saturation']
        factors = []
        for limit in (self.brightness, self.contrast, self.saturation):
            factors.append(1 + np.random.uniform(low=-limit, high=limit))
        order = np.arange(len(names))
        np.random.shuffle(order)

        jittered = []
        for image in rendering_images:
            for attr_idx in order:
                image = self._adjust_image_attr(image, names[attr_idx], factors[attr_idx])
            jittered.append(image)

        # Appending to an empty float64 array fixes the output dtype and shape.
        return np.append(np.empty(shape=(0, height, width, channels)), jittered, axis=0)

    def _adjust_image_attr(self, img, attr_name, attr_value):
        """
        Blend the image towards the reference that matches ``attr_name``.

        Args:
            img: Image in BGR format
                Numpy array of shape (h, w, 3)
            attr_name: 'brightness', 'saturation', or 'contrast'
            attr_value: Blending weight given to the image itself

        Returns:
            Output image in BGR format
            Numpy array of the same shape as input

        Raises:
            NotImplementedError: For any other ``attr_name``.
        """
        gray = self._bgr_to_gray(img)

        if attr_name == 'brightness':
            reference = 0
        elif attr_name == 'contrast':
            reference = np.mean(gray[:, :, 0])
        elif attr_name == 'saturation':
            reference = gray
        else:
            raise NotImplementedError(attr_name)
        return self._alpha_blend(img, reference, attr_value)

    def _bgr_to_gray(self, bgr):
        """
        Convert a BGR image to a three-channel grayscale image.
            Differences from cv2.cvtColor():
                1. Input image can be float
                2. Output image has three repeated channels, other than a single channel

        Args:
            bgr: Image in BGR format
                 Numpy array of shape (h, w, 3)

        Returns:
            gs: Grayscale image
                Numpy array of the same shape as input; the three channels are the same
        """
        blue, green, red = bgr[:, :, 0], bgr[:, :, 1], bgr[:, :, 2]
        luma = 0.114 * blue + 0.587 * green + 0.299 * red
        return np.dstack((luma, luma, luma))

    def _alpha_blend(self, im1, im2, alpha):
        """
        Alpha blending of two images or one image and a scalar

        Args:
            im1, im2: Image or scalar
                Numpy array and a scalar or two numpy arrays of the same shape
            alpha: Weight of im1
                Float ranging usually from 0 to 1

        Returns:
            im_blend: Blended image -- alpha * im1 + (1 - alpha) * im2
                Numpy array of the same shape as input image
        """
        return alpha * im1 + (1 - alpha) * im2


###### Random Noise

In [None]:
class RandomNoise(object):
    """
    Add one random per-channel colour offset to every image of a sample.

    The offset is ``sum_j eigvecs[i, j] * alpha[j] * eigvals[j]`` for RGB
    channel ``i``, where ``alpha`` is drawn from N(0, noise_std) once per call,
    so all images passed in a single call receive the same shift.

    Args:
        noise_std: Standard deviation of the normal distribution ``alpha`` is drawn from
        eigvals: Three weights, one per column of ``eigvecs``
        eigvecs: 3x3 matrix; row ``i`` contributes to RGB channel ``i``
    """

    def __init__(self,
                 noise_std,
                 eigvals=(0.2175, 0.0188, 0.0045),
                 eigvecs=((-0.5675, 0.7192, 0.4009), (-0.5808, -0.0045, -0.8140), (-0.5836, -0.6948, 0.4203))):
        self.noise_std = noise_std
        self.eigvals = np.array(eigvals)
        self.eigvecs = np.array(eigvecs)

    def __call__(self, rendering_images):
        """
        Return a noisy copy of ``rendering_images``.

        Args:
            rendering_images: Sequence (or array) of BGR images, each of shape (h, w, 3)

        Returns:
            Float64 numpy array of shape (n_images, h, w, 3). The input images
            are left untouched. An empty input is returned as-is.

        Raises:
            AssertionError: if the images do not have exactly 3 channels
        """
        # len() rather than truthiness: the input may be a numpy array
        if len(rendering_images) == 0:
            return rendering_images

        alpha = np.random.normal(loc=0, scale=self.noise_std, size=3)
        # Broadcasting scales column j of eigvecs by alpha[j] * eigvals[j]
        noise_rgb = np.sum(self.eigvecs * alpha * self.eigvals, axis=1)
        # Images are stored as BGR, so flip the offsets instead of flipping every image
        noise_bgr = noise_rgb[::-1]

        img_height, img_width, img_channels = rendering_images[0].shape
        assert (img_channels == 3), "Please use RandomBackground to normalize image channels"

        # Allocate the output once; growing it with np.append would copy the batch per image
        processed_images = np.empty(shape=(len(rendering_images), img_height, img_width, img_channels))
        for img_idx, img in enumerate(rendering_images):
            # Out-of-place add: the caller's image is never modified
            processed_images[img_idx] = img + noise_bgr

        return processed_images


###### Random Background

In [None]:
class RandomBackground(object):
    """
    Replace the transparent pixels of 4-channel images with a background.

    The background is either a solid colour drawn from
    ``random_bg_color_range`` or, when a folder of pictures is supplied, a
    picture chosen at random from that folder.

    Args:
        random_bg_color_range: Three ``[low, high]`` pairs (inclusive, 0-255),
            one per entry of the solid colour
        random_bg_folder_path: Optional folder whose files are used as
            background pictures
    """

    def __init__(self, random_bg_color_range, random_bg_folder_path=None):
        self.random_bg_color_range = random_bg_color_range
        self.random_bg_files = []
        if random_bg_folder_path is not None:
            self.random_bg_files = [
                os.path.join(random_bg_folder_path, rbf) for rbf in os.listdir(random_bg_folder_path)
            ]

    def __call__(self, rendering_images):
        """
        Composite every image over a background and drop the alpha channel.

        Args:
            rendering_images: Sequence (or array) of images of shape (h, w, c)

        Returns:
            The input itself when it is empty or its images do not have 4
            channels; otherwise a float64 numpy array of shape
            (n_images, h, w, 3).
        """
        if len(rendering_images) == 0:
            return rendering_images

        img_height, img_width, img_channels = rendering_images[0].shape
        # Only images carrying an alpha channel need a background
        if img_channels != 4:
            return rendering_images

        # Generate random background
        r, g, b = np.array([
            np.random.randint(self.random_bg_color_range[i][0], self.random_bg_color_range[i][1] + 1) for i in range(3)
        ]) / 255.
        solid_bg = np.array([[[r, g, b]]])

        random_bg = None
        if len(self.random_bg_files) > 0:
            random_bg_file_path = random.choice(self.random_bg_files)
            loaded_bg = cv2.imread(random_bg_file_path)
            # cv2.imread returns None for files it cannot decode; fall back to
            # the solid colour instead of crashing on `.astype`
            if loaded_bg is not None:
                random_bg = loaded_bg.astype(np.float32) / 255.

        # Apply random background; the output is allocated once rather than
        # grown with np.append, which would copy the whole batch per image
        processed_images = np.empty(shape=(len(rendering_images), img_height, img_width, img_channels - 1))
        for img_idx, img in enumerate(rendering_images):
            # 1 where the pixel is fully transparent, 0 elsewhere
            alpha = (np.expand_dims(img[:, :, 3], axis=2) == 0).astype(np.float32)
            # The coin is flipped for every image, even when no picture is available
            use_picture = random.randint(0, 1) and random_bg is not None
            bg_color = random_bg if use_picture else solid_bg
            processed_images[img_idx] = alpha * bg_color + (1 - alpha) * img[:, :, :3]

        return processed_images

#### Training Utility

###### Save Checkpoints

In [None]:
def save_checkpoints(cfg, file_path, epoch_idx, encoder, encoder_solver, decoder, decoder_solver, refiner,
                     refiner_solver, merger, merger_solver, best_iou, best_epoch):
    """
    Write the training state to ``file_path`` with ``torch.save``.

    The encoder and decoder (networks and solvers) are always stored together
    with ``epoch_idx``, ``best_iou`` and ``best_epoch``. The refiner and merger
    entries are added only when ``cfg.NETWORK.USE_REFINER`` /
    ``cfg.NETWORK.USE_MERGER`` are truthy.
    """
    print('[INFO] %s Saving checkpoint to %s ...' % (dt.now(), file_path))

    checkpoint = dict(epoch_idx=epoch_idx, best_iou=best_iou, best_epoch=best_epoch)
    checkpoint.update({
        'encoder_state_dict': encoder.state_dict(),
        'encoder_solver_state_dict': encoder_solver.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'decoder_solver_state_dict': decoder_solver.state_dict(),
    })

    # Optional sub-networks are stored only when enabled in the configuration
    if cfg.NETWORK.USE_REFINER:
        checkpoint.update({
            'refiner_state_dict': refiner.state_dict(),
            'refiner_solver_state_dict': refiner_solver.state_dict(),
        })
    if cfg.NETWORK.USE_MERGER:
        checkpoint.update({
            'merger_state_dict': merger.state_dict(),
            'merger_solver_state_dict': merger_solver.state_dict(),
        })

    torch.save(checkpoint, file_path)


##### Average Meter

In [None]:
class AverageMeter(object):
    """Keeps the latest value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` observations."""
        self.val = val
        weighted_total = self.sum + val * n
        self.sum = weighted_total
        self.count = self.count + n
        self.avg = weighted_total / self.count

##### Init weights, count, cuda

In [None]:
def var_or_cuda(x):
    """Return ``x`` moved to the GPU (non-blocking) when CUDA is available, else ``x`` unchanged."""
    if not torch.cuda.is_available():
        return x
    return x.cuda(non_blocking=True)

In [None]:
def init_weights(m):
    """
    Initialise the parameters of a single layer; intended for ``Module.apply``.

    - Conv2d / Conv3d / ConvTranspose3d: Kaiming-normal weights, zero bias
    - BatchNorm2d / BatchNorm3d: weight 1, bias 0
    - Linear: weights drawn from N(0, 0.01), zero bias

    Layers of any other type are left untouched. Parameters that are absent
    (``bias=False`` layers, ``affine=False`` batch norms) are skipped instead
    of raising.

    Args:
        m: The ``torch.nn.Module`` to initialise in place
    """
    # Exact type match (not isinstance) so subclasses of these layers keep
    # whatever initialisation they define themselves
    layer_type = type(m)

    if layer_type in (torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose3d):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif layer_type in (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d):
        # weight and bias are None when the layer was built with affine=False
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif layer_type is torch.nn.Linear:
        torch.nn.init.normal_(m.weight, 0, 0.01)
        # bias is None when the layer was built with bias=False
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

In [None]:
def count_parameters(model):
    """Return the total number of scalar elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

### Model Declaration

#### Decoder

In [None]:
class Decoder(torch.nn.Module):
    """
    Turns per-view image features into per-view 32^3 occupancy volumes.

    Each view's features are reshaped to (batch, 2048, 2, 2, 2), upsampled by
    four stride-2 transposed convolutions and mapped to one sigmoid channel.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        use_bias = cfg.NETWORK.TCONV_USE_BIAS

        def upsample_block(in_channels, out_channels):
            # Transposed conv (k=4, s=2, p=1) doubles each spatial dimension
            return torch.nn.Sequential(
                torch.nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, stride=2, bias=use_bias, padding=1),
                torch.nn.BatchNorm3d(out_channels),
                torch.nn.ReLU()
            )

        # Layer Definition (attribute names and creation order are kept so
        # that checkpoints and seeded initialisation stay compatible)
        self.layer1 = upsample_block(2048, 512)
        self.layer2 = upsample_block(512, 128)
        self.layer3 = upsample_block(128, 32)
        self.layer4 = upsample_block(32, 8)
        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(8, 1, kernel_size=1, bias=use_bias),
            torch.nn.Sigmoid()
        )

    def forward(self, image_features):
        """
        Args:
            image_features: Tensor with the view axis at dim 1; every view must
                reshape to (batch_size, 2048, 2, 2, 2)

        Returns:
            raw_features: (batch_size, n_views, 9, 32, 32, 32) -- the 8 channels
                of layer4 concatenated with the 1-channel volume
            gen_volumes: (batch_size, n_views, 32, 32, 32)
        """
        views_first = image_features.permute(1, 0, 2, 3, 4).contiguous()
        upsamplers = (self.layer1, self.layer2, self.layer3, self.layer4)
        gen_volumes = []
        raw_features = []

        for view_features in torch.split(views_first, 1, dim=0):
            hidden = view_features.view(-1, 2048, 2, 2, 2)
            # (batch, 2048, 2^3) -> (512, 4^3) -> (128, 8^3) -> (32, 16^3) -> (8, 32^3)
            for upsample in upsamplers:
                hidden = upsample(hidden)
            occupancy = self.layer5(hidden)                     # (batch, 1, 32, 32, 32)

            raw_features.append(torch.cat((hidden, occupancy), dim=1))   # (batch, 9, 32, 32, 32)
            gen_volumes.append(torch.squeeze(occupancy, dim=1))

        # Put the batch axis back in front of the view axis
        gen_volumes = torch.stack(gen_volumes).permute(1, 0, 2, 3, 4).contiguous()
        raw_features = torch.stack(raw_features).permute(1, 0, 2, 3, 4, 5).contiguous()
        return raw_features, gen_volumes

#### Encoder

In [None]:
class Encoder(torch.nn.Module):
    """
    Extracts a 256-channel feature map from every rendering view.

    The first 27 layers of a pretrained VGG16-BN act as a frozen backbone and
    are followed by three trainable convolution blocks.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Layer Definition
        vgg16_bn = torchvision.models.vgg16_bn(pretrained=True)
        backbone_layers = list(vgg16_bn.features.children())
        self.vgg = torch.nn.Sequential(*backbone_layers[:27])

        # Don't update params in VGG16 (self.vgg shares these parameters)
        for param in vgg16_bn.parameters():
            param.requires_grad = False

        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 256, kernel_size=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ELU()
        )

    def forward(self, rendering_images):
        """
        Args:
            rendering_images: (batch_size, n_views, img_c, img_h, img_w)

        Returns:
            image_features: (batch_size, n_views, 256, h', w'); the shape
                comments below assume 224x224 inputs
        """
        views_first = rendering_images.permute(1, 0, 2, 3, 4).contiguous()
        image_features = []

        for view in torch.split(views_first, 1, dim=0):
            features = self.vgg(view.squeeze(dim=0))     # (batch_size, 512, 28, 28)
            features = self.layer1(features)             # (batch_size, 512, 26, 26)
            features = self.layer2(features)             # (batch_size, 512, 8, 8): conv to 24x24, then 3x3 pooling
            features = self.layer3(features)             # (batch_size, 256, 8, 8)
            image_features.append(features)

        # Stack along a new view axis and move the batch axis back to the front
        return torch.stack(image_features).permute(1, 0, 2, 3, 4).contiguous()

#### Merger

In [None]:
class Merger(torch.nn.Module):
    """
    Fuses per-view coarse volumes into one volume with learned per-voxel weights.

    A score is predicted for every voxel of every view from that view's raw
    features; the scores are softmax-normalised across views and used to
    average the coarse volumes.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        leaky_slope = cfg.NETWORK.LEAKY_VALUE

        def conv_block(in_channels, out_channels):
            # 3x3x3 convolution with padding 1 keeps the spatial size
            return torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                torch.nn.BatchNorm3d(out_channels),
                torch.nn.LeakyReLU(leaky_slope)
            )

        # Layer Definition: 9 -> 16 -> 8 -> 4 -> 2 -> 1 channels
        self.layer1 = conv_block(9, 16)
        self.layer2 = conv_block(16, 8)
        self.layer3 = conv_block(8, 4)
        self.layer4 = conv_block(4, 2)
        self.layer5 = conv_block(2, 1)

    def forward(self, raw_features, coarse_volumes):
        """
        Args:
            raw_features: (batch_size, n_views, 9, d, h, w)
            coarse_volumes: (batch_size, n_views, d, h, w)

        Returns:
            Fused volume of shape (batch_size, d, h, w), clamped to [0, 1]
        """
        n_views_rendering = coarse_volumes.size(1)
        per_view_features = torch.split(raw_features, 1, dim=1)
        scoring_layers = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        volume_weights = []

        for view_idx in range(n_views_rendering):
            scores = torch.squeeze(per_view_features[view_idx], dim=1)   # (batch_size, 9, d, h, w)
            for layer in scoring_layers:
                scores = layer(scores)
            # Drop the single remaining channel: (batch_size, d, h, w)
            volume_weights.append(torch.squeeze(scores, dim=1))

        volume_weights = torch.stack(volume_weights).permute(1, 0, 2, 3, 4).contiguous()
        # Normalise across views so the weights of each voxel sum to one
        volume_weights = torch.softmax(volume_weights, dim=1)
        merged = torch.sum(coarse_volumes * volume_weights, dim=1)

        return torch.clamp(merged, min=0, max=1)

#### Refiner

In [None]:
class Refiner(torch.nn.Module):
    """
    Refines a coarse volume with a 3D encoder-decoder that has skip connections.

    Three convolution + pooling stages shrink the volume, two fully connected
    layers form the bottleneck, and three transposed convolutions restore the
    resolution. The result is averaged with the input volume.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        leaky_slope = cfg.NETWORK.LEAKY_VALUE
        tconv_bias = cfg.NETWORK.TCONV_USE_BIAS

        def down_block(in_channels, out_channels):
            # Convolution followed by 2x max pooling: halves each spatial dimension
            return torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, out_channels, kernel_size=4, padding=2),
                torch.nn.BatchNorm3d(out_channels),
                torch.nn.LeakyReLU(leaky_slope),
                torch.nn.MaxPool3d(kernel_size=2)
            )

        def dense_block(in_features, out_features):
            return torch.nn.Sequential(
                torch.nn.Linear(in_features, out_features),
                torch.nn.ReLU()
            )

        def up_conv(in_channels, out_channels):
            # Transposed conv (k=4, s=2, p=1) doubles each spatial dimension
            return torch.nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, stride=2,
                                            bias=tconv_bias, padding=1)

        # Layer Definition (names and creation order kept for checkpoint compatibility)
        self.layer1 = down_block(1, 32)
        self.layer2 = down_block(32, 64)
        self.layer3 = down_block(64, 128)
        self.layer4 = dense_block(8192, 2048)
        self.layer5 = dense_block(2048, 8192)
        self.layer6 = torch.nn.Sequential(up_conv(128, 64), torch.nn.BatchNorm3d(64), torch.nn.ReLU())
        self.layer7 = torch.nn.Sequential(up_conv(64, 32), torch.nn.BatchNorm3d(32), torch.nn.ReLU())
        self.layer8 = torch.nn.Sequential(up_conv(32, 1), torch.nn.Sigmoid())

    def forward(self, coarse_volumes):
        """
        Args:
            coarse_volumes: Tensor viewable as (batch_size, N_VOX, N_VOX, N_VOX);
                the shape comments below assume N_VOX = 32

        Returns:
            Refined volumes of shape (batch_size, N_VOX, N_VOX, N_VOX)
        """
        n_vox = self.cfg.CONST.N_VOX

        # Contracting path
        volumes_32_l = coarse_volumes.view((-1, 1, n_vox, n_vox, n_vox))   # (batch, 1, 32, 32, 32)
        volumes_16_l = self.layer1(volumes_32_l)                            # (batch, 32, 16, 16, 16)
        volumes_8_l = self.layer2(volumes_16_l)                             # (batch, 64, 8, 8, 8)
        volumes_4_l = self.layer3(volumes_8_l)                              # (batch, 128, 4, 4, 4)

        # Fully connected bottleneck: 8192 -> 2048 -> 8192
        bottleneck = self.layer5(self.layer4(volumes_4_l.view(-1, 8192)))

        # Expanding path; every stage adds the matching contracting-path output
        volumes_4_r = volumes_4_l + bottleneck.view(-1, 128, 4, 4, 4)
        volumes_8_r = volumes_8_l + self.layer6(volumes_4_r)
        volumes_16_r = volumes_16_l + self.layer7(volumes_8_r)
        # Average the input volume with the sigmoid output
        volumes_32_r = (volumes_32_l + self.layer8(volumes_16_r)) * 0.5

        return volumes_32_r.view((-1, n_vox, n_vox, n_vox))

### Global Configurations

In [None]:
def globalConfig():
    """
    Build the default experiment configuration.

    Returns:
        An EasyDict whose sections (DATASETS, DATASET, CONST, DIR, NETWORK,
        TRAIN, TEST) can be read and overridden with attribute or item access.
    """
    from easydict import EasyDict as edict

    # EasyDict converts the nested plain dicts into EasyDict instances
    return edict({
        'DATASETS': {
            'SHAPENET': {
                'TAXONOMY_FILE_PATH': './datasets/ShapeNet.json',
                'RENDERING_PATH': '/home/hzxie/Datasets/ShapeNet/ShapeNetRendering/%s/%s/rendering/%02d.png',
                'VOXEL_PATH': '/home/hzxie/Datasets/ShapeNet/ShapeNetVox32/%s/%s/model.binvox',
            },
        },
        'DATASET': {
            'MEAN': [0.5, 0.5, 0.5],
            'STD': [0.5, 0.5, 0.5],
            'TRAIN_DATASET': 'ShapeNet',
            'TEST_DATASET': 'ShapeNet',
        },
        # Common
        'CONST': {
            'DEVICE': '0',
            'RNG_SEED': 0,
            'IMG_W': 224,               # Image width for input
            'IMG_H': 224,               # Image height for input
            'N_VOX': 32,
            'BATCH_SIZE': 64,
            'N_VIEWS_RENDERING': 1,     # Dummy property for Pascal 3D
            'CROP_IMG_W': 128,          # Dummy property for Pascal 3D
            'CROP_IMG_H': 128,          # Dummy property for Pascal 3D
        },
        # Directories
        'DIR': {
            'OUT_PATH': './output',
            # 'RANDOM_BG_PATH': '/home/hzxie/Datasets/SUN2012/JPEGImages',
        },
        # Network
        'NETWORK': {
            'LEAKY_VALUE': .2,
            'TCONV_USE_BIAS': False,
            'USE_REFINER': True,
            'USE_MERGER': True,
        },
        # Training
        'TRAIN': {
            'RESUME_TRAIN': False,
            'NUM_WORKER': 4,            # number of data workers
            'NUM_EPOCHES': 250,
            'BRIGHTNESS': .4,
            'CONTRAST': .4,
            'SATURATION': .4,
            'NOISE_STD': .1,
            'RANDOM_BG_COLOR_RANGE': [[225, 255], [225, 255], [225, 255]],
            'POLICY': 'adam',           # available options: sgd, adam
            'EPOCH_START_USE_REFINER': 0,
            'EPOCH_START_USE_MERGER': 0,
            'ENCODER_LEARNING_RATE': 1e-3,
            'DECODER_LEARNING_RATE': 1e-3,
            'REFINER_LEARNING_RATE': 1e-3,
            'MERGER_LEARNING_RATE': 1e-4,
            'ENCODER_LR_MILESTONES': [150],
            'DECODER_LR_MILESTONES': [150],
            'REFINER_LR_MILESTONES': [150],
            'MERGER_LR_MILESTONES': [150],
            'BETAS': (.9, .999),
            'MOMENTUM': .9,
            'GAMMA': .5,
            'SAVE_FREQ': 10,            # weights will be overwritten every save_freq epoch
            'UPDATE_N_VIEWS_RENDERING': False,
        },
        # Testing options
        'TEST': {
            'RANDOM_BG_COLOR_RANGE': [[240, 240], [240, 240], [240, 240]],
            'VOXEL_THRESH': [.2, .3, .4, .5],
        },
    })

### Testing

In [None]:
def test_net(cfg,
             epoch_idx=-1,
             output_dir=None,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None):
    """
    Evaluate the networks on the test split and print a per-taxonomy IoU table.

    Args:
        cfg: Configuration object (see ``globalConfig``). ``cfg.CONST.WEIGHTS``
            is read only when the networks are not passed in.
        epoch_idx: Epoch used for the merger/refiner start checks and as the
            TensorBoard step; replaced by the checkpoint's value when weights
            are loaded here.
        output_dir: Format string with one ``%s`` slot; when set, views of the
            first three samples are rendered under ``output_dir % 'images'``.
        test_data_loader: Data loader to iterate; built from ``cfg`` when None.
            Only element 0 of each batch's taxonomy id and sample name is read,
            so a batch size of 1 is presumably expected.
        test_writer: Optional TensorBoard writer for images and scalars.
        encoder, decoder, refiner, merger: Networks to evaluate. When the
            encoder or decoder is None, all four are created and restored from
            ``cfg.CONST.WEIGHTS``.

    Returns:
        The highest mean IoU over the thresholds in ``cfg.TEST.VOXEL_THRESH``.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = Compose([
            CenterCrop(IMG_SIZE, CROP_SIZE),
            RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            ToTensor(),
        ])

        dataset_loader = ShapeNetDataLoader(cfg)
        test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset(
                            DatasetType.TEST, cfg.CONST.N_VIEWS_RENDERING, test_transforms),
                                                       batch_size=1,
                                                       num_workers=1,
                                                       pin_memory=True,
                                                       shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    mse_loss = torch.nn.MSELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = AverageMeter()
    refiner_losses = AverageMeter()

    # Switch models to evaluation mode. The refiner and merger may legitimately
    # be None when the caller passes only an encoder and decoder.
    for network in (encoder, decoder, refiner, merger):
        if network is not None:
            network.eval()

    for sample_idx, (taxonomy_id, sample_name, rendering_images, ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Get data from data loader
            rendering_images = var_or_cuda(rendering_images)
            ground_truth_volume = var_or_cuda(ground_truth_volume)

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                # Without the merger the views are fused by a plain average
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = mse_loss(generated_volume, ground_truth_volume) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = mse_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())

            # IoU per sample, one value per binarisation threshold
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Render the first three samples and, when a writer is available,
            # append the views to TensorBoard
            if output_dir and sample_idx < 3:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = get_volume_views(gv, os.path.join(img_dir, 'test'),
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume Reconstructed' % sample_idx, np.transpose(rendering_views,(2,0,1)), epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = get_volume_views(gtv, os.path.join(img_dir, 'test'),
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume GroundTruth' % sample_idx, np.transpose(rendering_views,(2,0,1)), epoch_idx)

            # Print sample loss and IoU
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
                  (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(),
                   refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

    # Output testing results: per-taxonomy mean IoU, then the sample-weighted
    # mean over all taxonomies
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/a', end='\t\t')

        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou

### Training

##### Connect to drive

In [None]:
# Mount Google Drive inside the Colab VM at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


##### Initializing New Configurations

In [None]:
# Build a fresh experiment configuration on top of the project defaults
# returned by globalConfig(), then override the values for this experiment.
cfg = globalConfig()

# Dataset locations (relative to the notebook's working directory, where the
# download cell unpacks the archives). The %-placeholders are filled in by the
# dataset loader defined elsewhere in this notebook.
cfg['DATASETS']['SHAPENET']['TAXONOMY_FILE_PATH'] = './ShapeNet.json'
cfg['DATASETS']['SHAPENET']['RENDERING_PATH'] = './kaggle/tmp/ShapeNetRendering/%s/%s/rendering/%02d.png'
cfg['DATASETS']['SHAPENET']['VOXEL_PATH'] = './kaggle/tmp/ShapeNetVox32/%s/%s/model.binvox'

# Input geometry, batch size and number of rendering views handed to the
# dataset.
cfg['CONST']['IMG_W'] = 224
cfg['CONST']['IMG_H'] = 224
cfg['CONST']['BATCH_SIZE'] = 64
cfg['CONST']['N_VIEWS_RENDERING'] = 1

# Training schedule and optimizer settings. RESUME_TRAIN only has an effect
# when cfg.CONST.WEIGHTS is also set (see the checkpoint-recovery cell).
cfg['TRAIN']['RESUME_TRAIN'] = True
cfg['TRAIN']['NUM_EPOCHES'] = 31
cfg['TRAIN']['POLICY'] = 'adam'
cfg['TRAIN']['ENCODER_LEARNING_RATE'] = 0.001
cfg['TRAIN']['DECODER_LEARNING_RATE'] = 0.001
cfg['TRAIN']['REFINER_LEARNING_RATE'] = 0.001
cfg['TRAIN']['MERGER_LEARNING_RATE'] = 0.0001
cfg['TRAIN']['SAVE_FREQ'] = 10  # a checkpoint is written every SAVE_FREQ epochs

# Voxel thresholds at which the IoU is reported during evaluation.
cfg['TEST']['VOXEL_THRESH'] = [0.2, 0.3, 0.4, 0.5]

# Experiment outputs go to the mounted Drive. The path is anchored at the
# absolute mount point used by drive.mount() and by the config-reload cell, so
# it does not depend on the notebook's current working directory.
cfg['EXPT_NAME'] = 'exp0'
cfg['DIR']['OUT_PATH'] = '/content/gdrive/MyDrive/Vision/Code/' + cfg['EXPT_NAME'] + '/'

##### Loading Old Configuration

In [None]:
# Reload the configuration that a previous run of this experiment saved as
# sample.json, then adjust the values needed to continue training.
exptname = "exp0"
path = "/content/gdrive/MyDrive/Vision/Code/" + exptname + "/sample.json"
with open(path, "r") as openfile:
    saved_cfg = json.load(openfile)


def _merge_saved_config(target, saved):
    """Recursively copy the values of ``saved`` into ``target`` in place.

    Nested dicts are merged key by key so that ``target`` keeps its own
    container type at every level; any other value overwrites the existing
    entry. Returns ``target``.
    """
    for key, value in saved.items():
        if isinstance(value, dict) and key in target and isinstance(target[key], dict):
            _merge_saved_config(target[key], value)
        else:
            target[key] = value
    return target


# json.load() yields plain dicts, but the remaining cells read the config with
# attribute access (cfg.CONST.IMG_H, cfg.TRAIN.POLICY, ...). Merging the saved
# values into a fresh globalConfig() keeps the config type the notebook expects.
cfg = _merge_saved_config(globalConfig(), saved_cfg)

cfg["TRAIN"]["RESUME_TRAIN"] = True
cfg["TRAIN"]["NUM_EPOCHES"] = 31
cfg["EXPT_NAME"] = exptname
# cfg["DIR"]["OUT_PATH"] is left as stored in the saved configuration.

TypeError: ignored

##### Defining transforms

In [None]:
# Image size and crop size shared by the training and validation pipelines.
_img_size = (cfg.CONST.IMG_H, cfg.CONST.IMG_W)
_crop_size = (cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W)

# Training pipeline: the randomised transforms (crop, background, colour
# jitter, noise, flip, RGB permutation) plus normalisation, finishing with the
# conversion to tensors. The transform classes are defined elsewhere in this
# notebook.
_train_steps = [
    RandomCrop(_img_size, _crop_size),
    RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
    ColorJitter(cfg.TRAIN.BRIGHTNESS, cfg.TRAIN.CONTRAST, cfg.TRAIN.SATURATION),
    RandomNoise(cfg.TRAIN.NOISE_STD),
    Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    RandomFlip(),
    RandomPermuteRGB(),
    ToTensor(),
]
train_transforms = Compose(_train_steps)

# Validation pipeline: centre crop instead of random crop, a background drawn
# from the TEST colour range, normalisation, then conversion to tensors.
_val_steps = [
    CenterCrop(_img_size, _crop_size),
    RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    ToTensor(),
]
val_transforms = Compose(_val_steps)

##### Dataloader

In [None]:
# One ShapeNetDataLoader per split; each produces the torch Dataset for its
# split through get_dataset().
train_dataset_loader = ShapeNetDataLoader(cfg)
val_dataset_loader = ShapeNetDataLoader(cfg)

# Training batches: shuffled, fixed batch size, incomplete last batch dropped.
_train_dataset = train_dataset_loader.get_dataset(
    DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING, train_transforms)
train_data_loader = torch.utils.data.DataLoader(
    dataset=_train_dataset,
    batch_size=cfg.CONST.BATCH_SIZE,
    num_workers=cfg.TRAIN.NUM_WORKER,
    pin_memory=True,
    shuffle=True,
    drop_last=True,
)

# Validation batches: one sample at a time, in dataset order.
_val_dataset = val_dataset_loader.get_dataset(
    DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING, val_transforms)
val_data_loader = torch.utils.data.DataLoader(
    dataset=_val_dataset,
    batch_size=1,
    num_workers=1,
    pin_memory=True,
    shuffle=False,
)

[INFO] 2022-12-12 02:27:35.697315 Collecting files of Taxonomy[ID=02691156, Name=aeroplane]
[INFO] 2022-12-12 02:27:37.071673 Collecting files of Taxonomy[ID=02828884, Name=bench]
[INFO] 2022-12-12 02:27:37.669600 Collecting files of Taxonomy[ID=02933112, Name=cabinet]
[INFO] 2022-12-12 02:27:38.068842 Collecting files of Taxonomy[ID=02958343, Name=car]
[WARN] 2022-12-12 02:27:38.351183 Ignore sample 02958343/f9c1d7748c15499c6f2bd1c4e9adb41 since volume file not exists.
[INFO] 2022-12-12 02:27:39.757437 Collecting files of Taxonomy[ID=03001627, Name=chair]
[INFO] 2022-12-12 02:27:41.458416 Collecting files of Taxonomy[ID=03211117, Name=display]
[INFO] 2022-12-12 02:27:41.708627 Collecting files of Taxonomy[ID=03636649, Name=lamp]
[INFO] 2022-12-12 02:27:42.242991 Collecting files of Taxonomy[ID=03691459, Name=speaker]
[INFO] 2022-12-12 02:27:42.634842 Collecting files of Taxonomy[ID=04090263, Name=rifle]
[INFO] 2022-12-12 02:27:43.561167 Collecting files of Taxonomy[ID=04256520, Name=s



[INFO] 2022-12-12 02:27:47.319659 Collecting files of Taxonomy[ID=02933112, Name=cabinet]
[INFO] 2022-12-12 02:27:47.373112 Collecting files of Taxonomy[ID=02958343, Name=car]
[INFO] 2022-12-12 02:27:47.673904 Collecting files of Taxonomy[ID=03001627, Name=chair]
[INFO] 2022-12-12 02:27:47.915157 Collecting files of Taxonomy[ID=03211117, Name=display]
[INFO] 2022-12-12 02:27:47.955390 Collecting files of Taxonomy[ID=03636649, Name=lamp]
[INFO] 2022-12-12 02:27:48.035441 Collecting files of Taxonomy[ID=03691459, Name=speaker]
[INFO] 2022-12-12 02:27:48.089063 Collecting files of Taxonomy[ID=04090263, Name=rifle]
[INFO] 2022-12-12 02:27:48.179137 Collecting files of Taxonomy[ID=04256520, Name=sofa]
[INFO] 2022-12-12 02:27:48.291649 Collecting files of Taxonomy[ID=04379243, Name=table]
[INFO] 2022-12-12 02:27:48.603713 Collecting files of Taxonomy[ID=04401088, Name=telephone]
[INFO] 2022-12-12 02:27:48.640972 Collecting files of Taxonomy[ID=04530566, Name=watercraft]
[INFO] 2022-12-12 02:

##### Initializing Models

In [None]:
# Build the four sub-networks from the shared configuration.
encoder, decoder, refiner, merger = Encoder(cfg), Decoder(cfg), Refiner(cfg), Merger(cfg)

# Report the parameter count of each network.
print('[DEBUG] %s Parameters in Encoder: %d.' % (dt.now(), count_parameters(encoder)))
print('[DEBUG] %s Parameters in Decoder: %d.' % (dt.now(), count_parameters(decoder)))
print('[DEBUG] %s Parameters in Refiner: %d.' % (dt.now(), count_parameters(refiner)))
print('[DEBUG] %s Parameters in Merger: %d.' % (dt.now(), count_parameters(merger)))

# Initialize weights of networks
for _net in (encoder, decoder, refiner, merger):
    _net.apply(init_weights)

# When a GPU is available, wrap every network in DataParallel and move it to
# CUDA; otherwise the plain modules are kept as they are.
if torch.cuda.is_available():
    encoder, decoder, refiner, merger = (
        torch.nn.DataParallel(_module).cuda()
        for _module in (encoder, decoder, refiner, merger)
    )




[DEBUG] 2022-12-12 02:27:51.292713 Parameters in Encoder: 7772480.
[DEBUG] 2022-12-12 02:27:51.293096 Parameters in Decoder: 71583064.
[DEBUG] 2022-12-12 02:27:51.293275 Parameters in Refiner: 34880352.
[DEBUG] 2022-12-12 02:27:51.293483 Parameters in Merger: 8571.


##### Optimizers

In [None]:
# Set up solver
# One optimizer per sub-network, each with its own learning rate. The encoder
# optimizer only receives parameters with requires_grad=True (presumably
# because part of the encoder is frozen -- confirm against Encoder).
if cfg.TRAIN.POLICY == 'adam':
    encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),
                                      lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                      betas=cfg.TRAIN.BETAS)
    decoder_solver = torch.optim.Adam(decoder.parameters(),
                                      lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                      betas=cfg.TRAIN.BETAS)
    refiner_solver = torch.optim.Adam(refiner.parameters(),
                                      lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                      betas=cfg.TRAIN.BETAS)
    merger_solver = torch.optim.Adam(merger.parameters(),
                                     lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                     betas=cfg.TRAIN.BETAS)
elif cfg.TRAIN.POLICY == 'sgd':
    encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad, encoder.parameters()),
                                     lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                     momentum=cfg.TRAIN.MOMENTUM)
    decoder_solver = torch.optim.SGD(decoder.parameters(),
                                     lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                     momentum=cfg.TRAIN.MOMENTUM)
    refiner_solver = torch.optim.SGD(refiner.parameters(),
                                     lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                     momentum=cfg.TRAIN.MOMENTUM)
    merger_solver = torch.optim.SGD(merger.parameters(),
                                    lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                    momentum=cfg.TRAIN.MOMENTUM)
else:
    # Fail here with a clear message instead of leaving the *_solver names
    # undefined and hitting a NameError in the scheduler cell.
    raise ValueError('[FATAL] %s Unknown optimizer %s.' % (dt.now(), cfg.TRAIN.POLICY))

##### Scheduler


In [None]:
# Set up learning rate scheduler to decay learning rates dynamically
def _multistep_scheduler(solver, milestones):
    """Return a MultiStepLR for ``solver`` using ``milestones`` and cfg.TRAIN.GAMMA."""
    return torch.optim.lr_scheduler.MultiStepLR(solver,
                                                milestones=milestones,
                                                gamma=cfg.TRAIN.GAMMA)


# Every optimizer gets its own scheduler with its own milestone list; the
# decay factor gamma is shared.
encoder_lr_scheduler = _multistep_scheduler(encoder_solver, cfg.TRAIN.ENCODER_LR_MILESTONES)
decoder_lr_scheduler = _multistep_scheduler(decoder_solver, cfg.TRAIN.DECODER_LR_MILESTONES)
refiner_lr_scheduler = _multistep_scheduler(refiner_solver, cfg.TRAIN.REFINER_LR_MILESTONES)
merger_lr_scheduler = _multistep_scheduler(merger_solver, cfg.TRAIN.MERGER_LR_MILESTONES)

##### Loss function


In [None]:
# Loss criteria: binary cross-entropy and mean squared error, both with
# PyTorch's default settings.
bce_loss, mse_loss = torch.nn.BCELoss(), torch.nn.MSELoss()

##### Load pretrained if exists

In [None]:
# Load pretrained model if exists
# Defaults used when training starts from scratch.
init_epoch, best_iou, best_epoch = 0, -1, -1

# Recover only when a checkpoint path is configured and resuming is enabled.
if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
    print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS)

    # Training progress stored in the checkpoint.
    init_epoch, best_iou, best_epoch = (
        checkpoint['epoch_idx'],
        checkpoint['best_iou'],
        checkpoint['best_epoch'],
    )

    # Encoder and decoder weights are always restored; refiner and merger
    # weights only when the corresponding network is enabled.
    for _net, _key in ((encoder, 'encoder_state_dict'), (decoder, 'decoder_state_dict')):
        _net.load_state_dict(checkpoint[_key])
    if cfg.NETWORK.USE_REFINER:
        refiner.load_state_dict(checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        merger.load_state_dict(checkpoint['merger_state_dict'])

    print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.' %
          (dt.now(), init_epoch, best_iou, best_epoch))

##### Summary writer

In [None]:
output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
log_dir = output_dir % 'logs'
ckpt_dir = output_dir % 'checkpoints'
train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
val_writer = SummaryWriter(os.path.join(log_dir, 'test'))
with open(cfg.DIR.OUT_PATH + "sample.json","w") as outfile:
  json.dump(cfg,outfile)
  
!mkdir -p $cfg.DIR.OUT_PATH
!gdown --id 1P_rcNSYNBYwZp-lbS2AiMDMv1ZCBFExK -O $cfg.DIR.OUT_PATH

Downloading...
From: https://drive.google.com/uc?id=1P_rcNSYNBYwZp-lbS2AiMDMv1ZCBFExK
To: /content/gdrive/MyDrive/Vision/Code/exp0/PTN
100% 596k/596k [00:00<00:00, 112MB/s]


##### Training Loop

In [None]:
# Main training loop. For every epoch: train on all batches, step the learning
# rate schedulers, evaluate with test_net(), and write checkpoints.
# The loop runs inside try/finally so the TensorBoard writers are closed even
# when training is interrupted (e.g. KeyboardInterrupt) or raises.
try:
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = AverageMeter()
        data_time = AverageMeter()
        encoder_losses = AverageMeter()
        refiner_losses = AverageMeter()

        # switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = var_or_cuda(rendering_images)
            ground_truth_volumes = var_or_cuda(ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            # Fuse the decoder output with the merger once it is enabled for
            # this epoch; before that, average over dim 1 instead (presumably
            # the rendering-view axis -- confirm against Decoder).
            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = mse_loss(generated_volumes, ground_truth_volumes) * 10

            # The refiner gets its own loss once enabled; until then the
            # encoder/decoder loss is reported for it as well.
            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = mse_loss(generated_volumes, ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                # Both losses share the encoder/decoder graph, so it must be
                # retained for the second backward pass.
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss', encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(), n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print(
                '[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time.val,
                   data_time.val, encoder_loss.item(), refiner_loss.item()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx + 1)

        # Adjust learning rate
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %
              (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time,
               encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            # The new view count applies from the next epoch, hence epoch_idx + 2.
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader, val_writer, encoder, decoder, refiner,
                       merger)

        # Update the best-model bookkeeping BEFORE writing any checkpoint, so
        # the periodic checkpoint of this epoch stores the current
        # best_iou/best_epoch. These values are restored on resume; a stale
        # best_iou would let a worse later epoch overwrite best-ckpt.pth.
        is_best = iou > best_iou
        if is_best:
            best_iou = iou
            best_epoch = epoch_idx + 1

        is_periodic = (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0
        if is_best or is_periodic:
            os.makedirs(ckpt_dir, exist_ok=True)

        # Save weights to file
        if is_periodic:
            save_checkpoints(cfg, os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)),
                             epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                             refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)
        if is_best:
            save_checkpoints(cfg, os.path.join(ckpt_dir, 'best-ckpt.pth'), epoch_idx + 1, encoder,
                             encoder_solver, decoder, decoder_solver, refiner, refiner_solver,
                             merger, merger_solver, best_iou, best_epoch)
finally:
    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()

[INFO] 2022-12-12 02:28:17.149140 [Epoch 1/31][Batch 1/478] BatchTime = 16.903 (s) DataTime = 4.508 (s) EDLoss = 4.7985 RLoss = 3.4295
[INFO] 2022-12-12 02:28:18.945793 [Epoch 1/31][Batch 2/478] BatchTime = 1.797 (s) DataTime = 0.003 (s) EDLoss = 2.9865 RLoss = 2.4391
[INFO] 2022-12-12 02:28:20.744182 [Epoch 1/31][Batch 3/478] BatchTime = 1.798 (s) DataTime = 0.002 (s) EDLoss = 2.6992 RLoss = 2.1082
[INFO] 2022-12-12 02:28:22.544312 [Epoch 1/31][Batch 4/478] BatchTime = 1.800 (s) DataTime = 0.002 (s) EDLoss = 2.4298 RLoss = 1.8748
[INFO] 2022-12-12 02:28:24.332349 [Epoch 1/31][Batch 5/478] BatchTime = 1.788 (s) DataTime = 0.002 (s) EDLoss = 2.4179 RLoss = 1.7704
[INFO] 2022-12-12 02:28:26.124426 [Epoch 1/31][Batch 6/478] BatchTime = 1.792 (s) DataTime = 0.002 (s) EDLoss = 2.3024 RLoss = 1.6189
[INFO] 2022-12-12 02:28:27.927495 [Epoch 1/31][Batch 7/478] BatchTime = 1.803 (s) DataTime = 0.002 (s) EDLoss = 2.1352 RLoss = 1.5101
[INFO] 2022-12-12 02:28:29.751749 [Epoch 1/31][Batch 8/478] B

KeyboardInterrupt: ignored