In [1]:
import os
os.makedirs("/kaggle/working/unsup3d_thi", exist_ok=True)
%cd /kaggle/working/unsup3d_thi
!git clone http://github.com/21120558/unsup3d.git

/kaggle/working/unsup3d_thi
Cloning into 'unsup3d'...
remote: Enumerating objects: 186008, done.[K
remote: Counting objects: 100% (823/823), done.[K
remote: Compressing objects: 100% (355/355), done.[K
remote: Total 186008 (delta 475), reused 773 (delta 432), pack-reused 185185 (from 1)[K
Receiving objects: 100% (186008/186008), 1.29 GiB | 24.60 MiB/s, done.
Resolving deltas: 100% (477/477), done.
Updating files: 100% (185431/185431), done.


In [2]:
%cd /kaggle/working
!pip install "git+https://github.com/facebookresearch/pytorch3d.git"

/kaggle/working
Collecting git+https://github.com/facebookresearch/pytorch3d.git
  Cloning https://github.com/facebookresearch/pytorch3d.git to /tmp/pip-req-build-p6z3mu3_
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/pytorch3d.git /tmp/pip-req-build-p6z3mu3_
  Resolved https://github.com/facebookresearch/pytorch3d.git to commit 58566963d620cbe067ec53eae62ca262aecfbe27
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting iopath (from pytorch3d==0.7.8)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting portalocker (from iopath->pytorch3d==0.7.8)
  Downloading portalocker-2.10.1-py3-none-any.whl.metadata (8.5 kB)
Downloading portalocker-2.10.1-py3-none-any.whl (18 kB)
Building wheels for collected packages: pytorch3d, iopath
  Building wheel for pytorch3d (set

In [3]:
%cd /kaggle/working/unsup3d_thi/unsup3d/unsup3d

/kaggle/working/unsup3d_thi/unsup3d/unsup3d


In [4]:
%%writefile trainer.py
import os
import glob
from datetime import datetime
import numpy as np
import torch
from . import meters
from . import utils
from .dataloaders import get_data_loaders


class Trainer():
    def __init__(self, cfgs, model):
        self.device = cfgs.get('device', 'cpu')
        self.num_epochs = cfgs.get('num_epochs', 30)
        self.batch_size = cfgs.get('batch_size', 64)
        self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results')
        self.save_checkpoint_freq = cfgs.get('save_checkpoint_freq', 1)
        self.keep_num_checkpoint = cfgs.get('keep_num_checkpoint', 2)  # -1 for keeping all checkpoints
        self.resume = cfgs.get('resume', True)
        self.use_logger = cfgs.get('use_logger', True)
        self.log_freq = cfgs.get('log_freq', 1000)
        self.archive_code = cfgs.get('archive_code', True)
        self.checkpoint_name = cfgs.get('checkpoint_name', None)
        self.test_result_dir = cfgs.get('test_result_dir', None)
        self.cfgs = cfgs

        self.metrics_trace = meters.MetricsTrace()
        self.make_metrics = lambda m=None: meters.StandardMetrics(m)
        self.model = model(cfgs)
        self.model.trainer = self
        self.train_loader, self.val_loader, self.test_loader = get_data_loaders(cfgs)

    def load_checkpoint(self, optim=True):
        """Search the specified/latest checkpoint in checkpoint_dir and load the model and optimizer."""
        if self.checkpoint_name is not None:
            checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
        else:
            checkpoints = sorted(glob.glob(os.path.join(self.checkpoint_dir, '*.pth')))
            if len(checkpoints) == 0:
                return 0
            checkpoint_path = checkpoints[-1]
            self.checkpoint_name = os.path.basename(checkpoint_path)
        print(f"Loading checkpoint from {checkpoint_path}")
        cp = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_model_state(cp)
        if optim:
            self.model.load_optimizer_state(cp)
        self.metrics_trace = cp['metrics_trace']
        epoch = cp['epoch']
        return epoch

    def save_checkpoint(self, epoch, optim=True):
        """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch."""
        utils.xmkdir(self.checkpoint_dir)
        checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint{epoch:03}.pth')
        state_dict = self.model.get_model_state()
        if optim:
            optimizer_state = self.model.get_optimizer_state()
            state_dict = {**state_dict, **optimizer_state}
        state_dict['metrics_trace'] = self.metrics_trace
        state_dict['epoch'] = epoch
        print(f"Saving checkpoint to {checkpoint_path}")
        torch.save(state_dict, checkpoint_path)
        if self.keep_num_checkpoint > 0:
            utils.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint)

    def save_clean_checkpoint(self, path):
        """Save model state only to specified path."""
        torch.save(self.model.get_model_state(), path)

    def test(self):
        """Perform testing."""
        self.model.to_device(self.device)
        self.current_epoch = self.load_checkpoint(optim=False)
        if self.test_result_dir is None:
            self.test_result_dir = os.path.join(self.checkpoint_dir, f'test_results_{self.checkpoint_name}'.replace('.pth',''))
        print(f"Saving testing results to {self.test_result_dir}")

        with torch.no_grad():
            m = self.run_epoch(self.test_loader, epoch=self.current_epoch, is_test=True)

        score_path = os.path.join(self.test_result_dir, 'eval_scores.txt')
        self.model.save_scores(score_path)

    def train(self):
        """Perform training."""
        ## archive code and configs
        if self.archive_code:
            utils.archive_code(os.path.join(self.checkpoint_dir, 'archived_code.zip'), filetypes=['.py', '.yml'])
        utils.dump_yaml(os.path.join(self.checkpoint_dir, 'configs.yml'), self.cfgs)

        ## initialize
        start_epoch = 0
        self.metrics_trace.reset()
        self.train_iter_per_epoch = len(self.train_loader)
        self.model.to_device(self.device)
        self.model.init_optimizers()

        ## resume from checkpoint
        if self.resume:
            start_epoch = self.load_checkpoint(optim=True)

        ## initialize tensorboardX logger
        if self.use_logger:
            from tensorboardX import SummaryWriter
            self.logger = SummaryWriter(os.path.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S")))

            ## cache one batch for visualization
            self.viz_input = self.val_loader.__iter__().__next__()

        ## run epochs
        print(f"{self.model.model_name}: optimizing to {self.num_epochs} epochs")
        for epoch in range(start_epoch, self.num_epochs):
            self.current_epoch = epoch
            metrics = self.run_epoch(self.train_loader, epoch)
            self.metrics_trace.append("train", metrics)

            with torch.no_grad():
                metrics = self.run_epoch(self.val_loader, epoch, is_validation=True)
                self.metrics_trace.append("val", metrics)

            if (epoch+1) % self.save_checkpoint_freq == 0:
                self.save_checkpoint(epoch+1, optim=True)
            self.metrics_trace.plot(pdf_path=os.path.join(self.checkpoint_dir, 'metrics.pdf'))
            self.metrics_trace.save(os.path.join(self.checkpoint_dir, 'metrics.json'))

        print(f"Training completed after {epoch+1} epochs.")

    def run_epoch(self, loader, epoch=0, is_validation=False, is_test=False):
        """Run one epoch."""
        is_train = not is_validation and not is_test
        metrics = self.make_metrics()

        if is_train:
            print(f"Starting training epoch {epoch}")
            self.model.set_train()
        else:
            print(f"Starting validation epoch {epoch}")
            self.model.set_eval()

        for iter, input in enumerate(loader):
            #if iter == 200:
            #    break
            m = self.model.forward(input)
            if is_train:
                self.model.backward()
            elif is_test:
                self.model.save_results(self.test_result_dir)

            metrics.update(m, self.batch_size)
            print(f"{'T' if is_train else 'V'}{epoch:02}/{iter:05}/{metrics}")

            if self.use_logger and is_train:
                total_iter = iter + epoch*self.train_iter_per_epoch
                if total_iter % self.log_freq == 0:
                    self.model.forward(self.viz_input)
                    self.model.visualize(self.logger, total_iter=total_iter, max_bs=25)
        return metrics


Overwriting trainer.py


In [5]:
%cd /kaggle/working/unsup3d_thi/unsup3d/unsup3d/renderer

/kaggle/working/unsup3d_thi/unsup3d/unsup3d/renderer


In [7]:
%%writefile renderer.py
import torch
import math
from .utils import *

from pytorch3d.utils import ico_sphere
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    DirectionalLights,
)

EPS = 1e-7


class Renderer():
    def __init__(self, cfgs):
        self.device = cfgs.get('device', 'cpu')
        self.image_size = cfgs.get('image_size', 64)
        self.min_depth = cfgs.get('min_depth', 0.9)
        self.max_depth = cfgs.get('max_depth', 1.1)
        self.rot_center_depth = cfgs.get('rot_center_depth', (self.min_depth + self.max_depth) / 2)
        self.fov = cfgs.get('fov', 10)
        self.tex_cube_size = cfgs.get('tex_cube_size', 2)
        self.renderer_min_depth = cfgs.get('renderer_min_depth', 0.1)
        self.renderer_max_depth = cfgs.get('renderer_max_depth', 10.)

        #### camera intrinsics
        #             (u)   (x)
        #    d * K^-1 (v) = (y)
        #             (1)   (z)
        R, t = look_at_view_transform(2, 0, 60)
        ## renderer for visualization
        #R = [[[1., 0., 0.],
        #      [0., 1., 0.],
        #     [0., 0., 1.]]]
        #R = torch.FloatTensor(R).to(self.device)
        #t = torch.zeros(1, 3, dtype=torch.float32).to(self.device)
        fx = (self.image_size - 1) / 2 / (math.tan(self.fov / 2 * math.pi / 180))
        fy = (self.image_size - 1) / 2 / (math.tan(self.fov / 2 * math.pi / 180))
        cx = (self.image_size - 1) / 2
        cy = (self.image_size - 1) / 2
        K = [[fx, 0., cx],
             [0., fy, cy],
             [0., 0., 1.]]
        K = torch.FloatTensor(K).to(self.device)
        self.inv_K = torch.inverse(K).unsqueeze(0)
        self.K = K.unsqueeze(0)

        ambient_color = (1.0, 1.0, 1.0)  
        diffuse_color = (0.0, 0.0, 0.0)  
        specular_color = (0.0, 0.0, 0.0) 

        directional_lights = DirectionalLights(
            ambient_color=(ambient_color,),
            diffuse_color=(diffuse_color,),
            specular_color=(specular_color,),
            direction=((0, 1, 0),),  
            device=self.device
        )

        cameras = FoVPerspectiveCameras(
            device=self.device,
            R=R, 
            T=t,  
            fov = self.fov,
            zfar=self.renderer_max_depth, 
            znear=self.renderer_min_depth
        )

        self.rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(
                image_size=self.image_size,
                blur_radius=0.0,
                faces_per_pixel=1,
            )
        )

        self.renderer = MeshRenderer(
            rasterizer=self.rasterizer,
            shader=SoftPhongShader(
                device=self.device,
                cameras=cameras,
                lights=directional_lights
            )
        )

    def set_transform_matrices(self, view):
        self.rot_mat, self.trans_xyz = get_transform_matrices(view)

    def rotate_pts(self, pts, rot_mat):
        centroid = torch.FloatTensor([0.,0.,self.rot_center_depth]).to(pts.device).view(1,1,3)
        pts = pts - centroid  # move to centroid
        pts = pts.matmul(rot_mat.transpose(2,1))  # rotate
        pts = pts + centroid  # move back
        return pts

    def translate_pts(self, pts, trans_xyz):
        return pts + trans_xyz

    def depth_to_3d_grid(self, depth):
        b, h, w = depth.shape
        grid_2d = get_grid(b, h, w, normalize=False).to(depth.device)  # N x h x w x 2
        depth = depth.unsqueeze(-1)
        grid_3d = torch.cat((grid_2d, torch.ones_like(depth)), dim=3)
        grid_3d = grid_3d.matmul(self.inv_K.to(depth.device).transpose(2,1)) * depth
        return grid_3d

    def grid_3d_to_2d(self, grid_3d):
        b, h, w, _ = grid_3d.shape
        grid_2d = grid_3d / grid_3d[...,2:]
        grid_2d = grid_2d.matmul(self.K.to(grid_3d.device).transpose(2,1))[:,:,:,:2]
        WH = torch.FloatTensor([w-1, h-1]).to(grid_3d.device).view(1,1,1,2)
        grid_2d = grid_2d / WH *2.-1.  # normalize to -1~1
        return grid_2d

    def get_warped_3d_grid(self, depth):
        b, h, w = depth.shape
        grid_3d = self.depth_to_3d_grid(depth).reshape(b, -1, 3)
        grid_3d = self.rotate_pts(grid_3d, self.rot_mat)
        grid_3d = self.translate_pts(grid_3d, self.trans_xyz)
        return grid_3d.reshape(b, h, w, 3) # return 3d vertices

    def get_inv_warped_3d_grid(self, depth):
        b, h, w = depth.shape
        grid_3d = self.depth_to_3d_grid(depth).reshape(b,-1,3)
        grid_3d = self.translate_pts(grid_3d, -self.trans_xyz)
        grid_3d = self.rotate_pts(grid_3d, self.rot_mat.transpose(2,1))
        return grid_3d.reshape(b,h,w,3) # return 3d vertices

    def get_warped_2d_grid(self, depth):
        b, h, w = depth.shape
        grid_3d = self.get_warped_3d_grid(depth)
        grid_2d = self.grid_3d_to_2d(grid_3d)
        return grid_2d

    def get_inv_warped_2d_grid(self, depth):
        b, h, w = depth.shape
        grid_3d = self.get_inv_warped_3d_grid(depth)
        grid_2d = self.grid_3d_to_2d(grid_3d)
        return grid_2d

    def warp_canon_depth(self, canon_depth):
        b, h, w = canon_depth.shape
        grid_3d = self.get_warped_3d_grid(canon_depth).reshape(b, -1, 3) # b x (hxw) x 3
        faces = get_face_idx(b, h, w).to(canon_depth.device)

        meshes = Meshes(verts=grid_3d, faces=faces)
        warped_depth = self.rasterizer(meshes).zbuf.squeeze(3)

        ############################################################
        # warped_depth = self.renderer.render_depth(grid_3d, faces)
        ############################################################

        # allow some margin out of valid range
        margin = (self.max_depth - self.min_depth) / 2
        warped_depth = warped_depth.clamp(min=self.min_depth - margin, max=self.max_depth + margin)
        warped_depth[warped_depth == self.min_depth - margin] = self.max_depth


        return warped_depth.flip(1).flip(2)

    def get_normal_from_depth(self, depth):
        b, h, w = depth.shape
        grid_3d = self.depth_to_3d_grid(depth)

        tu = grid_3d[:,1:-1,2:] - grid_3d[:,1:-1,:-2]
        tv = grid_3d[:,2:,1:-1] - grid_3d[:,:-2,1:-1]
        normal = tu.cross(tv, dim=3)

        zero = torch.FloatTensor([0,0,1]).to(depth.device)
        normal = torch.cat([zero.repeat(b,h-2,1,1), normal, zero.repeat(b,h-2,1,1)], 2)
        normal = torch.cat([zero.repeat(b,1,w,1), normal, zero.repeat(b,1,w,1)], 1)
        normal = normal / (((normal**2).sum(3, keepdim=True))**0.5 + EPS)
        return normal

    def render_yaw(self, im, depth, v_before=None, v_after=None, rotations=None, maxr=90, nsample=9, crop_mesh=None):
        b, c, h, w = im.shape
        grid_3d = self.depth_to_3d_grid(depth)

        if crop_mesh is not None:
            top, bottom, left, right = crop_mesh  # pixels from border to be cropped
            if top > 0:
                grid_3d[:,:top,:,1] = grid_3d[:,top:top+1,:,1].repeat(1,top,1)
                grid_3d[:,:top,:,2] = grid_3d[:,top:top+1,:,2].repeat(1,top,1)
            if bottom > 0:
                grid_3d[:,-bottom:,:,1] = grid_3d[:,-bottom-1:-bottom,:,1].repeat(1,bottom,1)
                grid_3d[:,-bottom:,:,2] = grid_3d[:,-bottom-1:-bottom,:,2].repeat(1,bottom,1)
            if left > 0:
                grid_3d[:,:,:left,0] = grid_3d[:,:,left:left+1,0].repeat(1,1,left)
                grid_3d[:,:,:left,2] = grid_3d[:,:,left:left+1,2].repeat(1,1,left)
            if right > 0:
                grid_3d[:,:,-right:,0] = grid_3d[:,:,-right-1:-right,0].repeat(1,1,right)
                grid_3d[:,:,-right:,2] = grid_3d[:,:,-right-1:-right,2].repeat(1,1,right)

        grid_3d = grid_3d.reshape(b,-1,3)
        im_trans = []

        # inverse warp
        if v_before is not None:
            rot_mat, trans_xyz = get_transform_matrices(v_before)
            grid_3d = self.translate_pts(grid_3d, -trans_xyz)
            grid_3d = self.rotate_pts(grid_3d, rot_mat.transpose(2,1))

        if rotations is None:
            rotations = torch.linspace(-math.pi/180*maxr, math.pi/180*maxr, nsample)
        for i, ri in enumerate(rotations):
            ri = torch.FloatTensor([0, ri, 0]).to(im.device).view(1,3)
            rot_mat_i, _ = get_transform_matrices(ri)
            grid_3d_i = self.rotate_pts(grid_3d, rot_mat_i.repeat(b,1,1))

            if v_after is not None:
                if len(v_after.shape) == 3:
                    v_after_i = v_after[i]
                else:
                    v_after_i = v_after
                rot_mat, trans_xyz = get_transform_matrices(v_after_i)
                grid_3d_i = self.rotate_pts(grid_3d_i, rot_mat)
                grid_3d_i = self.translate_pts(grid_3d_i, trans_xyz)

            faces = get_face_idx(b, h, w).to(im.device)
            textures = im.permute(0, 2, 3, 1).reshape(b, -1, 3)

            meshes = Meshes(verts=grid_3d_i, faces=faces)
            meshes.textures = TexturesVertex(verts_features=textures)

            warped_images = self.renderer(meshes).clamp(min=-1., max=1.)
            warped_images = warped_images[:, :, :, :3].permute(0, 3, 1, 2)

            ###############################################################
            # warped_images = self.renderer.render_rgb(grid_3d_i, faces, textures).clamp(min=-1., max=1.)
            ###############################################################
            im_trans += [warped_images]
        return torch.stack(im_trans, 1)  # b x t x c x h x w


Overwriting renderer.py


In [8]:
%cd /kaggle/working/unsup3d_thi/unsup3d
!python run.py --config experiments/train_celeba.yml --gpu 0 --num_workers 4

/kaggle/working/unsup3d_thi/unsup3d
Loading configs from experiments/train_celeba.yml
Environment: GPU 0 seed 0 number of workers 4
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|█████████████████████████████████████████| 528M/528M [00:03<00:00, 174MB/s]
Loading training data from data/celeba_cropped/train
Loading validation data from data/celeba_cropped/val
Archiving code to results/celeba/archived_code.zip
Saving configs to results/celeba/configs.yml
unsup3d_celeba: optimizing to 30 epochs
Starting training epoch 0
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
T00/00000/    4.4Hz	loss: 49.12128
add_video needs package moviepy
add_video needs package moviepy
T00/00001/    5.3Hz	loss: 36.68722
T00/00002/   14.0Hz	loss: 29.77112
T00/00003/   21.9Hz	loss: 25.52618
T00/00004/   28.9Hz	loss: 22.53014
T00/00005/   35.3Hz	loss: 20.33954
T00/00006/   41.1Hz	loss: 18.65780
T00/00007/  

In [10]:
%cd /kaggle/working/unsup3d_thi/unsup3d
!python -m demo.demo --input /kaggle/working/unsup3d_thi/unsup3d/demo/images/human_face --result /kaggle/working/unsup3d_thi/unsup3d/demo/results/human_face --checkpoint /kaggle/working/unsup3d_thi/unsup3d/results/celeba/checkpoint002.pth

/kaggle/working/unsup3d_thi/unsup3d
Loading checkpoint from /kaggle/working/unsup3d_thi/unsup3d/results/celeba/checkpoint002.pth
Processing /kaggle/working/unsup3d_thi/unsup3d/demo/images/human_face/001_face.png
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Saving results to /kaggle/working/unsup3d_thi/unsup3d/demo/results/human_face/001_face
Processing /kaggle/working/unsup3d_thi/unsup3d/demo/images/human_face/002_face.png
Saving results to /kaggle/working/unsup3d_thi/unsup3d/demo/results/human_face/002_face
Processing /kaggle/working/unsup3d_thi/unsup3d/demo/images/human_face/003_face.png
Saving results to /kaggle/working/unsup3d_thi/unsup3d/demo/results/human_face/003_face
Processing /kaggle/working/unsup3d_thi/unsup3d/demo/images/human_face/004_face.png
Saving results to /kaggle/working/unsup3d_thi/unsup3d/demo/results/human_face/004_face
Processing /kaggle/working/unsup3d_thi/unsup3d/demo/images/human_face/005_face.png
Saving results to /kaggle/working/uns

In [1]:
%cd /kaggle/working/
!zip -r file.zip /kaggle/working/unsup3d_thi/unsup3d/demo/results/human_face/002_face

/kaggle/working

zip error: Nothing to do! (try: zip -r file.zip . -i /kaggle/working/unsup3d_thi/unsup3d/demo/results/human_face/002_face)
