In [None]:
  # Notebook cell: mount Google Drive at /gdrive (forcing a remount if it is already mounted).
  from google.colab import drive
drive.mount('/gdrive',force_remount=True)

import os
# All relative paths used later in this notebook are resolved from this project folder.
os.chdir('/gdrive/MyDrive/Colab Notebooks/DDPS_Test')

Mounted at /gdrive


In [None]:
# Option.py --------------------------------------------------------------------

#import argparse
import os
import torch
import numpy as np
import datetime
from PIL import Image
import sys
# Append "<cwd>/models" to the import search path so modules stored in that folder can be imported.
current_working_directory = os.getcwd()
sys.path.append(current_working_directory + '/models')

class Options:
    """Run configuration for the display photometric-stereo notebook.

    ``__init__`` only stores plain default values.  :meth:`parse` derives the
    dependent settings (device, focal length, shutter time, resized camera
    resolution), creates the output folder, loads the monitor calibration and
    builds the initial illumination patterns and the reference plane.
    """

    # torch.logit clamps its input to [eps, 1 - eps]; without this, pattern
    # values of exactly 0 or 1 (e.g. 'tri_gradient') map to -inf / +inf.
    _LOGIT_EPS = 1e-6

    def __init__(self):
        """Populate every option with its default value."""
        # --- run / loop settings ---
        self.name = 'debug'
        self.num_threads = 0
        self.n_epoch = 1
        self.batch_size = 1
        self.batch_size_valid = 2
        self.valid_N = 2
        self.batch_size_test = 1
        self.freq_print = 1
        self.freq_valid = 10

        # --- illumination patterns and their optimisation ---
        self.light_N = 4  # number of illumination patterns
        self.initial_pattern = 'tri_random'  # one of the keys handled in _init_illums()
        self.used_loss = 'cosine'
        self.lr_light = 3e-1
        self.lr_decay_light = 0.3
        self.step_size_light = 10
        self.fix_light_pattern = False
        self.fix_light_position = True

        # --- checkpoints and paths ---
        self.load_epoch = 0
        self.load_step_start = 0
        self.load_latest = False
        self.dataset_root = './ApplyData/'
        self.save_dir = './results_apply/'
        # alternatives: './calibration/monitor_pixels_h16w9.npy' / './calibration/monitor_pixels_h9w16.npy'
        self.light_geometric_calib_fn = './calibration/monitor_pixels_h12w6.npy'
        self.load_monitor_light = False
        self.light_fn = ''

        # --- reference point ---
        self.pt_ref_x = 0.
        self.pt_ref_y = 0.
        self.pt_ref_z = 0.1  # alternatives: 0.1 / 0.5

        # --- camera ---
        self.resizing_factor = 1
        self.cam_R = 608  # camera rows; alternatives: 608 / 512
        self.cam_C = 456  # camera columns; alternatives: 456 / 612
        self.usePatch = False
        self.roi_min_c = 511
        self.roi_min_r = 288
        self.roi_width = 200
        self.roi_height = 200
        self.cam_focal_length_pix = 8.110470904e+5  # alternatives: 8.110470904e+5 / 1.1485e+3
        self.cam_pitch = 1.0315e-5  # alternatives: 1.0315e-5 / 3.45e-6*2
        self.cam_gain = 0.0
        self.cam_gain_base = 1.0960
        self.static_interval = 1/24
        self.noise_sigma = 0.02
        self.rendering_scalar = 3.25  # alternatives: 5.681e2/28 / 1.3e2/40

        # --- monitor (light source) ---
        self.light_R = 12  # superpixel rows; alternatives: 16 / 9
        self.light_C = 6  # superpixel columns; alternatives: 9 / 16
        self.light_pitch = 0.315e-3 * 120 * 2
        self.monitor_gamma = 2.2
        self.vertical_gap = 226.8 * 1e-3 / 3
        self.horizontal_gap = 80.64 * 1e-3

        # --- loss weights ---
        self.weight_normal = 1
        self.weight_recon = 5

        self.initialized = True

    def parse(self, save=True, isTrain=True, parse_args=None):
        """Derive dependent options, create the output folder and build tensors.

        Args:
            save: when true, write all options to ``<tb_dir>/opt.txt``.
            isTrain: unused; kept for interface compatibility.
            parse_args: unused; kept for interface compatibility.

        Returns:
            ``self``, with the derived attributes (``device``, ``dtype``,
            ``tb_dir``, ``light_pos``, ``illums``, ``reference_plane``, ...) set.

        Note:
            ``cam_R``, ``cam_C`` and ``cam_pitch`` are rescaled in place, so
            calling ``parse`` repeatedly with ``resizing_factor != 1`` applies
            the resize again each time.
        """
        if not getattr(self, 'initialized', False):
            # The original called a non-existent initialize(); restoring the
            # defaults is the only initialisation this class defines.
            Options.__init__(self)

        # Use the GPU when CUDA is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32

        self.cam_focal_length = self.cam_focal_length_pix * self.cam_pitch

        self.cam_shutter_time = self.static_interval / self.light_N

        self.original_R = self.cam_R
        self.original_C = self.cam_C

        self.cam_R //= self.resizing_factor
        self.cam_C //= self.resizing_factor
        self.cam_pitch *= self.resizing_factor

        if not self.usePatch:
            # Without patch cropping the ROI is the whole (resized) frame.
            self.roi_height = self.cam_R
            self.roi_width = self.cam_C

        # Output folder: <save_dir>/<name>/<timestamp>
        self.tb_dir = os.path.join(self.save_dir, self.name, datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
        os.makedirs(self.tb_dir, exist_ok=True)

        # Only the banner is printed; the full option list goes to opt.txt below.
        print('------------ Options -------------')
        args = vars(self)
        print('-------------- End ----------------')

        if save:
            file_name = os.path.join(self.tb_dir, 'opt.txt')
            with open(file_name, 'wt') as opt_file:
                opt_file.write('------------ Options -------------\n')
                for k, v in sorted(args.items()):
                    opt_file.write('%s: %s\n' % (str(k), str(v)))
                opt_file.write('-------------- End ----------------\n')

        # Light position: one 3-vector per monitor superpixel.
        light_pos = np.load(self.light_geometric_calib_fn)
        light_pos = np.reshape(light_pos, (self.light_R, self.light_C, 3))
        self.light_pos = torch.tensor(light_pos, dtype=self.dtype, device=self.device)
        self.light_pos_np = light_pos  # numpy copy of the same coordinates

        self._init_illums()

        # Reference point (marked as unused in the original source).
        self.pt_ref = torch.tensor([self.pt_ref_x, self.pt_ref_y, self.pt_ref_z], dtype=self.dtype, device=self.device)

        # Reference plane: a constant-depth point cloud at pt_ref_z covering the whole frame.
        reference_plane = self.compute_ptcloud_from_depth(self.pt_ref_z, 0, 0, self.cam_R, self.cam_C, self.cam_R, self.cam_C, self.cam_pitch, self.cam_focal_length)
        reference_plane = torch.tensor(reference_plane, device=self.device, dtype=self.dtype)
        self.reference_plane = torch.tile(reference_plane.unsqueeze(0), (self.batch_size, 1, 1, 1))

        return self

    def _init_illums(self):
        """Build ``self.illums`` (light_N, light_R, light_C, 3) in logit space.

        The tensor starts at ``vmin`` everywhere and is then filled according
        to ``self.initial_pattern``; an unrecognised key leaves the constant
        ``vmin`` pattern.  The global torch seed is reset to 0 first so the
        random patterns are reproducible.  Several patterns write to fixed
        pattern indices (0..3 or 0..1) and therefore need ``light_N`` to be
        at least that large.
        """
        self.illums = torch.zeros((self.light_N, self.light_R, self.light_C, 3), dtype=self.dtype, device=self.device)
        vmin = 0.1
        vmax = 0.9
        self.illums[:, :, :, :] = vmin

        torch.manual_seed(0)
        if self.initial_pattern == 'tri_random':
            self.illums[:, :, :, :] = torch.rand(self.light_N, self.light_R, self.light_C, 3)

        elif self.initial_pattern == 'gray_noise':
            self.illums[:, :, :, :] = torch.normal(mean=0.5, std=0.01, size=(self.light_N, self.light_R, self.light_C, 3))

        elif self.initial_pattern == 'mono_random':
            temp = torch.rand(self.light_N, self.light_R, self.light_C, 1)
            self.illums[:, :, :, :] = torch.tile(temp, (1, 1, 1, 3))

        elif self.initial_pattern == 'mono_gradient':
            # Gray ramps along the rows (patterns 0/1) and the columns (patterns 2/3).
            row_indices = torch.arange(self.light_R).float().unsqueeze(1)
            col_indices = torch.arange(self.light_C).float().unsqueeze(0)
            ones = torch.ones(self.light_R, self.light_C)
            dist = row_indices / self.light_R
            self.illums[0] = torch.tile((0.1 + 0.8 * (ones - dist)).unsqueeze(-1), (1, 1, 3))
            self.illums[1] = torch.flip(self.illums[0], dims=[0])
            dist = col_indices / self.light_C
            self.illums[2] = torch.tile((0.1 + 0.8 * (ones - dist)).unsqueeze(-1), (1, 1, 3))
            self.illums[3] = torch.flip(self.illums[2], dims=[1])

        elif self.initial_pattern == 'tri_gradient':
            # Colour gradient: channel 2 ramps over rows, channel 0 over columns,
            # channel 1 with the distance from the panel centre.
            row_indices = torch.arange(self.light_R).float().unsqueeze(1)
            col_indices = torch.arange(self.light_C).float().unsqueeze(0)
            ones = torch.ones(self.light_R, self.light_C)
            dist = row_indices / self.light_R
            self.illums[0, :, :, 2] = ones - dist
            self.illums[1, :, :, 2] = dist
            dist = col_indices / self.light_C
            self.illums[0, :, :, 0] = ones - dist
            self.illums[1, :, :, 0] = dist
            dist = torch.sqrt((row_indices - self.light_R/2)**2 + (col_indices - self.light_C/2)**2)
            dist_norm = (dist.max() - dist) / dist.max()
            intensity = dist_norm * ones
            self.illums[0, :, :, 1] = intensity
            self.illums[1, :, :, 1] = ones - intensity

        # The misspelled keys are the historical ones; the correct spelling is accepted too.
        elif self.initial_pattern in ('mono_comnplementary', 'mono_complementary'):
            # Binary half-panel patterns.
            self.illums[1, :, :self.light_C // 2, :] = vmax
            self.illums[0, :, self.light_C // 2:, :] = vmax
            self.illums[2, :self.light_R // 2, :, :] = vmax
            self.illums[3, self.light_R // 2:, :, :] = vmax

        elif self.initial_pattern in ('tri_comnplementary', 'tri_complementary'):
            # Colour binary half-panel patterns.
            self.illums[0, :, :self.light_C // 2, 0] = vmax
            self.illums[0, :, self.light_C // 2:, 1] = vmax
            self.illums[0, :self.light_R // 2, :, 2] = vmax
            self.illums[1, :, :self.light_C // 2, 1] = vmax
            self.illums[1, :, self.light_C // 2:, 0] = vmax
            self.illums[1, self.light_R // 2:, :, 2] = vmax

        elif self.initial_pattern == 'OLAT':
            # One corner superpixel lit per pattern.
            self.illums[0, :1, -1:, :] = vmax
            self.illums[1, :1, :1, :] = vmax
            self.illums[2, -1:, -1:, :] = vmax
            self.illums[3, -1:, :1, :] = vmax

        elif self.initial_pattern == 'grouped_OLAT':
            # One 3x3 corner group lit per pattern.
            self.illums[0, :3, -3:, :] = vmax
            self.illums[1, :3, :3, :] = vmax
            self.illums[2, -3:, -3:, :] = vmax
            self.illums[3, -3:, :3, :] = vmax

        # Store the patterns in logit space (consumers apply torch.sigmoid).
        self.illums[:, :, :, :] = torch.logit(self.illums, eps=self._LOGIT_EPS)

    def compute_ptcloud_from_depth(self, depth, roi_rmin, roi_cmin, roi_height, roi_width, full_height, full_width, pitch, focal_length):
        '''
        Make a point cloud by unprojecting the ROI pixels with their depth.

        Args:
            depth: a scalar (any 0-d number: int, float, numpy scalar) applied
                to every pixel, or a 2-D depth map indexed [row, col] over
                the full frame, from which the ROI is cropped.
            roi_rmin, roi_cmin: first row / column of the ROI.
            roi_height, roi_width: ROI size in pixels.
            full_height, full_width: full frame size; its centre is used as
                the principal point.
            pitch: pixel pitch.
            focal_length: focal length, in the same unit as ``pitch``.

        Returns:
            float32 numpy array of shape (roi_height, roi_width, 3) holding
            X, Y and Z for each ROI pixel.
        '''
        if np.ndim(depth) != 0:
            # Depth map: crop it to the ROI.  Scalars are broadcast as they are.
            depth = depth[roi_rmin:roi_rmin + roi_height, roi_cmin:roi_cmin + roi_width]
        XYZ = np.zeros((roi_height, roi_width, 3), dtype=np.float32)
        r, c = np.meshgrid(np.linspace(roi_rmin, roi_rmin + roi_height - 1, roi_height),
                           np.linspace(roi_cmin, roi_cmin + roi_width - 1, roi_width),
                           indexing='ij')
        center_c, center_r = full_width / 2, full_height / 2
        scale = depth / focal_length
        XYZ[:, :, 0] = scale * (c - center_c) * pitch
        XYZ[:, :, 1] = scale * (r - center_r) * pitch
        XYZ[:, :, 2] = depth

        return XYZ

# trainer.py  ------------------------------------------------------------------

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
# from scipy import interpolate
#import utils
from PIL import Image
import gc
import os
import datetime

def model_results_to_np(model_results, batch_ind=None):
    """Convert the estimated normal / albedo / depth tensors to numpy arrays.

    Args:
        model_results: mapping holding the tensors 'normal_est', 'albedo_est'
            and 'depth_est'; other entries are ignored.
        batch_ind: index applied to the first axis of every tensor.  When
            omitted, every entry along the first axis of 'normal_est' is
            selected.

    Returns:
        dict with the same three keys, each a detached CPU numpy array.
    """
    if batch_ind is None:
        batch_ind = range(0, model_results['normal_est'].shape[0])

    return {
        key: model_results[key][batch_ind, ...].detach().cpu().numpy()
        for key in ('normal_est', 'albedo_est', 'depth_est')
    }

class Trainer:
    """Runs the reconstructor on captured samples and holds optimisation state.

    NOTE(review): ``run_optimizers_one_step`` / ``run_schedulers_one_step``
    use ``self.optimizer_light`` and ``self.scheduler_light``, which are never
    created in this class; they must be attached by the caller before those
    methods are used with ``optimize_light`` enabled.
    """

    def __init__(self, opt, reconstructor):
        """Store the options / reconstructor and set up the optimisation variables.

        Args:
            opt: parsed options object; ``fix_light_pattern``,
                ``fix_light_position``, ``light_pos``, ``cam_gain``,
                ``device``, ``dtype``, ``load_monitor_light`` and
                ``light_fn`` are read here.
            reconstructor: object exposing
                ``forward(imgs, patterns, incident, camera_gain, mask)``.
        """
        self.opt = opt
        self.reconstructor = reconstructor

        self.model_results = None
        self.optimize_light = not opt.fix_light_pattern
        self.optimize_position = not opt.fix_light_position

        # Optimization variable parameterization
        self.monitor_superpixel_positions = torch.nn.Parameter(self.opt.light_pos)
        # The gain is stored as logit(gain / 48); the renderer in this file
        # undoes this with sigmoid(.) * 48.
        self.camera_gain = torch.logit(torch.tensor([self.opt.cam_gain], device=self.opt.device, dtype=self.opt.dtype)/48)

        # Load monitor
        if opt.load_monitor_light:
            self.monitor_light_patterns = torch.load(opt.light_fn)

    def run_model(self, dataset, idx=None):
        """Reconstruct normal, albedo and depth for one sample.

        Args:
            dataset: an indexable dataset, or - when ``idx`` is omitted - the
                sample dict itself (keys 'imgs', 'patterns', 'mask').
            idx: index of the sample inside ``dataset``.

        Returns:
            dict with detached 'normal_est', 'albedo_est' and 'depth_est'.
        """
        # Accept an already-fetched sample so run_optimizers_one_step(data) can call us.
        data = dataset if idx is None else dataset[idx]
        batch_size = 1

        depth = self.opt.reference_plane[:batch_size, :, :, :]
        depth = depth.to(device=self.opt.device)

        # Captured images and the patterns they were taken under
        I_diffuse = data['imgs']
        self.monitor_light_patterns = data['patterns']
        mask = data['mask']
        incident, _ = self.compute_light_direction(depth, self.monitor_superpixel_positions)

        # Reconstruct normal and albedo
        recon_output = self.reconstructor.forward(I_diffuse, self.monitor_light_patterns, incident, self.camera_gain, mask)

        return {
            'normal_est': recon_output['normal'].detach(),
            'albedo_est': recon_output['albedo'].detach(),
            'depth_est': recon_output['depth'].detach(),
        }

    def compute_light_direction(self, ptcloud, light_pos):
        """Return unit light (incident) and view (exitant) directions.

        Args:
            ptcloud: tensor of 3-D points with xyz on the last axis
                (called with the (batch, R, C, 3) reference plane in this file).
            light_pos: tensor of light positions with xyz on the last axis
                (called with the (light_R, light_C, 3) superpixel grid).

        Returns:
            ``(incident, exitant)``: the normalised point-to-light vectors for
            every light/point pair, and the normalised point-to-origin vectors
            with two leading singleton axes.
        """
        # Compute for point light sources; 1e-8 guards against division by zero.
        incident = light_pos.unsqueeze(2).unsqueeze(2).unsqueeze(2) - ptcloud.unsqueeze(0).unsqueeze(0)
        incident = incident / (torch.linalg.norm(incident, axis=-1).unsqueeze(-1) + 1e-8)
        exitant = -(ptcloud / (torch.linalg.norm(ptcloud, axis=-1).unsqueeze(-1) + 1e-8))

        return incident, exitant[None, None, ...]

    def run_optimizers_one_step(self, data):
        """Run one forward/backward pass on ``data`` and step the light optimizer.

        Raises:
            KeyError: if the model results carry no 'losses' entry
                (``run_model`` in this class returns estimates only).
        """
        if self.optimize_light:
            self.optimizer_light.zero_grad()

        # ----------------------------------------------
        self.model_results = self.run_model(data)
        if 'losses' not in self.model_results:
            raise KeyError("model results contain no 'losses' entry; run_model() only returns detached estimates")
        self.loss_dict = self.model_results['losses']
        loss_sum = sum(self.loss_dict.values())
        loss_sum.requires_grad_(True)
        loss_sum.backward()

        # Step optimizer
        if self.optimize_light:
            self.optimizer_light.step()
        # ----------------------------------------------

    def run_schedulers_one_step(self):
        """Advance the light learning-rate scheduler when patterns are optimised."""
        if self.optimize_light:
            self.scheduler_light.step()

    def get_losses(self):
        """Return the loss dict stored by the last optimisation step."""
        return self.loss_dict

    def get_model_results(self):
        """Return the results stored by the last optimisation step (or None)."""
        return self.model_results

# Utils.py ---------------------------------------------------------------------

import torch
import os
import numpy as np
import matplotlib.pyplot as plt
#import trainer as trainer_utils
import datetime


def print_training_status(epoch, n_epoch, step, debug_step):
    """Print a one-line progress summary (epoch, iteration and debug iteration)."""
    print(f'(epoch: {epoch}/{n_epoch}, iters: {step}, debug_iters: {debug_step}) ')


def display(opt, writer, model_results, step, mode, batch_index, dir, rendered_image, rendered_error): # @@@@
    """Log the results of one batch element through ``writer.add_image``.

    Args:
        opt: options object; only ``light_N`` is read.
        writer: object exposing ``add_image(tag, image, step)``.
        model_results: dict of result tensors ('normal_est', 'albedo_est',
            'depth_est' and optionally 'normal_gt').
        step: step value passed to the writer.
        mode: string prefix used in the image tags.
        batch_index: batch element to log.
        dir: unused; kept for interface compatibility.
        rendered_image: indexable as [batch][pattern], each entry a tensor
            laid out (rows, cols, channels).
        rendered_error: indexable as [batch]; shown after a sigmoid.
    """
    results = model_results_to_np(model_results, batch_ind=batch_index)
    #results = trainer_utils.model_results_to_np(model_results, batch_ind=batch_index)
    tag_root = f'model_results_{mode+str(batch_index)}'

    # model_results_to_np() never returns 'normal_gt', so take it straight from
    # model_results and skip the image when no ground truth is available.
    if 'normal_gt' in model_results:
        normal_gt = model_results['normal_gt'][batch_index, ...].detach().cpu().numpy()
        writer.add_image(f'{tag_root}/[normal_gt]', ((-normal_gt.transpose((2,0,1))+1)/2), step)

    # Normals are mapped from [-1, 1] to [0, 1] with a sign flip for display.
    writer.add_image(f'{tag_root}/[normal_est]', ((-results['normal_est'].transpose((2,0,1))+1)/2), step)
    writer.add_image(f'{tag_root}/[albedo_est]', results['albedo_est'].transpose((2,0,1)), step)

    # Replicate the error map over three channels to obtain an RGB image.
    normal_error = torch.sigmoid(rendered_error[batch_index]).detach().cpu()
    normal_error = torch.stack([normal_error, normal_error, normal_error]).squeeze(1)
    writer.add_image(f'{tag_root}/[normal_err]', normal_error, step)

    for i in range(opt.light_N):
        rendered_img = rendered_image[batch_index][i].detach().cpu().numpy()
        writer.add_image(f'pattern[{i+1}]/{mode+str(batch_index)}[rendered_img]', rendered_img.transpose((2,0,1)), step)

def display_pattern(opt, writer, model_results, step, mode, dir, pattern):
    """Log every illumination pattern as a channel-first image.

    ``pattern`` is passed through a sigmoid, moved to numpy and each of the
    first ``opt.light_N`` entries is transposed from (rows, cols, channels) to
    (channels, rows, cols) before being handed to ``writer.add_image``.
    ``model_results``, ``mode`` and ``dir`` are not used.
    """
    images = torch.sigmoid(pattern).detach().cpu().numpy()
    for idx in range(opt.light_N):
        channel_first = images[idx].transpose((2,0,1))
        writer.add_image(f'pattern[{idx+1}]/pattern_img', channel_first, step)


def save_model(monitor_light_patterns, superpixel_position, epoch, step, dir):
    """Save the light patterns and the superpixel positions to ``dir``.

    Each object is written twice: once under an epoch/step-stamped name and
    once under a fixed '*_latest.pth' name.  Nothing at all is written (not
    even the positions) when ``monitor_light_patterns`` is None.
    """
    if monitor_light_patterns is None:
        return

    stamp = f'epoch{str(epoch).zfill(5)}_step{str(step).zfill(5)}'

    torch.save(monitor_light_patterns, os.path.join(dir, f'monitor_light_patterns_{stamp}.pth'))
    torch.save(monitor_light_patterns, os.path.join(dir, 'monitor_light_patterns_latest.pth'))

    torch.save(superpixel_position, os.path.join(dir, f'superpixel_position_{stamp}.pth'))
    torch.save(superpixel_position, os.path.join(dir, 'superpixel_position_latest.pth'))

def save_model_pattern(monitor_light_patterns, epoch, step, dir):
    """Save the light patterns to ``dir`` (stamped copy plus a 'latest' copy).

    Does nothing when ``monitor_light_patterns`` is None.
    """
    if monitor_light_patterns is None:
        return
    stamped_name = f'monitor_light_patterns_epoch{str(epoch).zfill(5)}_step{str(step).zfill(5)}.pth'
    for file_name in (stamped_name, 'monitor_light_patterns_latest.pth'):
        torch.save(monitor_light_patterns, os.path.join(dir, file_name))

def save_model_position(superpixel_position, epoch, step, dir):
    """Save the superpixel positions to ``dir`` (stamped copy plus a 'latest' copy).

    Does nothing when ``superpixel_position`` is None.
    """
    if superpixel_position is None:
        return
    stamped_name = f'superpixel_position_epoch{str(epoch).zfill(5)}_step{str(step).zfill(5)}.pth'
    for file_name in (stamped_name, 'superpixel_position_latest.pth'):
        torch.save(superpixel_position, os.path.join(dir, file_name))

def load_monitor_light_patterns(outdir, epoch, latest=False):
    """Load saved monitor light patterns from ``outdir``.

    Args:
        outdir: folder containing the checkpoints.
        epoch: epoch to load; ignored when ``latest`` is true.
        latest: load 'monitor_light_patterns_latest.pth' instead.

    Returns:
        Whatever ``torch.load`` returns for the selected file.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.

    For a given epoch the legacy name 'monitor_light_patterns_epoch<N>.pth' is
    tried first.  If it is missing, the name written by the save helpers in
    this file ('..._epoch<NNNNN>_step<NNNNN>.pth') is searched and the file
    with the highest zero-padded step is used.
    """
    import glob  # local import: only needed for the stamped-name fallback

    if latest:
        fn = os.path.join(outdir, 'monitor_light_patterns_latest.pth')
    else:
        fn = os.path.join(outdir, 'monitor_light_patterns_epoch%d.pth' % (epoch))
        if not os.path.isfile(fn):
            stamped = os.path.join(
                glob.escape(os.fspath(outdir)),
                'monitor_light_patterns_epoch%s_step*.pth' % str(epoch).zfill(5))
            candidates = sorted(glob.glob(stamped))
            if candidates:
                # Zero-padded step numbers sort lexicographically.
                fn = candidates[-1]

    if not os.path.isfile(fn):
        raise FileNotFoundError('%s not exists yet!' % fn)
    return torch.load(fn)


def load_camera_gain(outdir, epoch, latest=False):
    """Load a saved camera gain from ``outdir``.

    Reads 'camera_gain_latest.pth' when ``latest`` is true, otherwise
    'camera_gain_epoch<epoch>.pth'.

    Raises:
        FileNotFoundError: if the selected file does not exist.
    """
    base_name = 'camera_gain_latest.pth' if latest else 'camera_gain_epoch%d.pth' % (epoch)
    fn = os.path.join(outdir, base_name)
    if os.path.isfile(fn):
        return torch.load(fn)
    raise FileNotFoundError('%s not exists yet!' % fn)


def cut_edge_batch(target_img):
    """Centre-crop axes 2 and 3 of a torch tensor to multiples of 8.

    The surplus ``size % 8`` is removed from both ends of each axis; when it
    is odd the extra row / column is taken from the far end.
    """
    height, width = target_img.size(2), target_img.size(3)
    surplus_h, surplus_w = height % 8, width % 8

    top, left = surplus_h // 2, surplus_w // 2
    bottom = height - (surplus_h - top)
    right = width - (surplus_w - left)

    return target_img[:, :, top:bottom, left:right]


def cut_edge(target_img):
    """Centre-crop axes 0 and 1 of a numpy array to multiples of 8.

    The surplus ``size % 8`` is removed from both ends of each axis; when it
    is odd the extra row / column is taken from the far end.  Trailing axes
    are left untouched.
    """
    height, width = target_img.shape[0], target_img.shape[1]
    surplus_h, surplus_w = height % 8, width % 8

    top, left = surplus_h // 2, surplus_w // 2
    bottom = height - (surplus_h - top)
    right = width - (surplus_w - left)

    return target_img[top:bottom, left:right]


def visualize3D(opt, rgb, ptcloud, light_pos, camera_pos, reference_plane, reference_point):
    """Save a 2x2 figure showing the scene, monitor and camera from four views.

    Every panel scatters the point cloud and the reference plane (both
    coloured by ``rgb``), the monitor light positions (red stars), the camera
    centre (blue dot) and the reference point (green dot).  The first panel
    additionally marks ``light_pos[0, 0]`` with a blue star.  The figure is
    written to ``opt.tb_dir`` under a timestamped name ending in 'pointcloud'.
    """
    # (elevation, azimuth) of the four panels, in subplot order
    view_angles = ((-45, -90), (0, 0), (0, 90), (90, 0))

    fig = plt.figure(figsize=(18, 15))
    for panel, (elev, azim) in enumerate(view_angles, start=1):
        ax = fig.add_subplot(2, 2, panel, projection='3d')

        # draw 3D scene
        ax.scatter(ptcloud[..., 0], ptcloud[..., 1], ptcloud[..., 2], c=rgb.reshape(-1, 3), marker='.', s=1, alpha=0.25)
        ax.set_xlabel('X [m]')
        ax.set_ylabel('Y [m]')
        ax.set_zlabel('Z [m]')
        ax.view_init(elev, azim)

        # draw monitor light source
        ax.scatter(light_pos[..., 0], light_pos[..., 1], light_pos[..., 2], c='red', marker='*', s=10)
        if panel == 1:
            # only the first view highlights the first superpixel
            ax.scatter(light_pos[0,0, 0], light_pos[0,0, 1], light_pos[0,0, 2], c='blue', marker='*', s=20)

        # draw camera center, reference plane and reference point
        ax.scatter(camera_pos[0], camera_pos[1], camera_pos[2], c='blue', marker='o', s=20)
        ax.scatter(reference_plane[..., 0], reference_plane[..., 1], reference_plane[..., 2], c=rgb.reshape(-1, 3), marker='.', s=1, alpha=0.25)
        ax.scatter(reference_point[0], reference_point[1], reference_point[2], c='green', marker='o', s=20)

        ax.set_xlim([-1., 1.])
        ax.set_ylim([-1., 1.])
        ax.set_zlim([-0., 2.])

    out_path = os.path.join(opt.tb_dir, datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + 'pointcloud')
    plt.savefig(out_path, facecolor='#eeeeee', bbox_inches='tight', dpi=300)
    plt.close()

def visualize_patterns(opt, monitor_light_patterns):
    """Save a figure of every light pattern and its separate R, G and B channels.

    Args:
        opt: options object; ``light_N`` and ``tb_dir`` are read.
        monitor_light_patterns: torch tensor of patterns in logit space,
            indexed [pattern, row, col, channel]; a sigmoid is applied here.

    The figure has one row per pattern (RGB, R only, G only, B only) and is
    written to ``opt.tb_dir`` under a timestamped name ending in
    '_0_Monitor_patterns'.
    """
    # Move to numpy up front: numpy / matplotlib cannot consume CUDA or
    # grad-tracking tensors directly.
    monitor_light_radiance = torch.sigmoid(monitor_light_patterns).detach().cpu().numpy()

    plt.figure(figsize=(10, 2*opt.light_N), constrained_layout=True)
    plt.suptitle('Pattern')
    for illum_idx in range(opt.light_N):
        pattern = monitor_light_radiance[illum_idx]

        plt.subplot(opt.light_N, 4, illum_idx*4+1)
        plt.imshow(pattern)
        plt.title(f'Light pattern {illum_idx+1} RGB Channel')

        for channel, channel_name in enumerate(('R', 'G', 'B')):
            # Keep a single colour channel, zero the other two.
            single = np.zeros_like(pattern)
            single[:, :, channel] = pattern[:, :, channel]
            single = np.clip(single, 0.0, 1.0)

            plt.subplot(opt.light_N, 4, illum_idx*4+2+channel)
            plt.imshow(single)
            plt.title(f'Light pattern {illum_idx+1} {channel_name} Channel')

    plt.savefig(os.path.join(opt.tb_dir, datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + '_0_Monitor_patterns'), facecolor='#eeeeee', bbox_inches='tight', dpi=300)
    plt.close()

def visualize_GT_data(opt, file_names, monitor_light_patterns, I_diffuse):
    """Save, for each batch element, a figure pairing every pattern with its image.

    Each figure has one row per light pattern: the pattern on the left and
    ``I_diffuse[batch, pattern]`` on the right.  Files are written to
    ``opt.tb_dir`` with a timestamped name.
    """
    n_lights = opt.light_N
    for batch_idx, scene_name in enumerate(file_names):
        plt.figure(figsize=(8, n_lights*2), constrained_layout=True)
        plt.suptitle(f'Batch {batch_idx+1}: '+scene_name)

        for illum_idx in range(n_lights):
            left_cell = illum_idx*2+1

            plt.subplot(n_lights, 2, left_cell)
            plt.imshow((monitor_light_patterns[illum_idx]))
            plt.title(f'Light pattern {illum_idx+1}')

            plt.subplot(n_lights, 2, left_cell+1)
            plt.imshow(I_diffuse[batch_idx, illum_idx])
            plt.title(f'B{batch_idx+1}_L{illum_idx+1}_Diffuse')

        out_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + f'_1_{batch_idx+1}th_Data_Rendered_images.png'
        plt.savefig(os.path.join(opt.tb_dir, out_name), facecolor='#eeeeee', bbox_inches='tight', dpi=300)
        plt.close()

def visualize_EST_normal(opt, file_names, normal_gt, normal_est, mask):
    """Save, per batch element, ground-truth vs. estimated normals and their errors.

    Args:
        opt: options object; only ``tb_dir`` is read.
        file_names: one name per batch element (used in the figure title).
        normal_gt, normal_est: arrays indexed [batch, row, col, xyz].
        mask: array indexed [batch, row, col], multiplied onto the error maps.

    Each 2x2 figure shows the two normal maps (left column) and the masked
    cosine and angular (degrees) errors (right column).
    """
    batch_size = len(file_names)
    for batch_idx in range(batch_size):
        plt.figure(figsize=(8,8), constrained_layout=True)
        plt.suptitle(f'Batch {batch_idx + 1}: ' + file_names[batch_idx])

        # Normals are mapped from [-1, 1] to [0, 1] with a sign flip for display.
        plt.subplot(2, 2, 1)
        plt.imshow(((-normal_gt[batch_idx]+1)/2))
        plt.title(f'Batch{batch_idx+1}_Normal_GT')
        plt.colorbar()

        plt.subplot(2, 2, 3)
        plt.imshow(((-normal_est[batch_idx]+1)/2))
        plt.title(f'Batch{batch_idx+1}_Normal_EST')
        plt.colorbar()

        cos_sim = (normal_gt[batch_idx] * normal_est[batch_idx]).sum(-1)
        normal_cos_error = 1 - np.abs(cos_sim)
        # Rounding can push the dot product of unit vectors slightly outside
        # [-1, 1], where arccos returns NaN; clip first.
        normal_angular_error = np.rad2deg(np.arccos(np.clip(cos_sim, -1.0, 1.0)))

        plt.subplot(2, 2, 2)
        plt.imshow(normal_cos_error * mask[batch_idx])
        plt.colorbar()
        plt.title(f'Batch{batch_idx+1}_Normal_Cosine_LOSS')

        plt.subplot(2, 2, 4)
        plt.imshow(normal_angular_error * mask[batch_idx])
        plt.colorbar()
        plt.title(f'Batch{batch_idx+1}_Normal_Angular_LOSS')

        plt.savefig(os.path.join(opt.tb_dir, datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + f'_1_{batch_idx+1}th_Data_EST_Normal.png'), facecolor='#eeeeee', bbox_inches='tight', dpi=300)
        plt.close()

def visualize_unsup_error(opt, file_names, monitor_light_patterns, I_diffuse, I_diffuse_est, I_diffuse_error):
    """Save, per batch element, pattern / input / re-rendered / error images.

    Each figure has one row per light pattern with four columns: the pattern,
    ``I_diffuse``, ``I_diffuse_est`` and ``I_diffuse_error`` for that batch
    element and pattern.  Files are written to ``opt.tb_dir`` with a
    timestamped name.
    """
    n_lights = opt.light_N
    for batch_idx, scene_name in enumerate(file_names):
        plt.figure(figsize=(8, n_lights*2), constrained_layout=True)
        plt.suptitle(f'Batch {batch_idx+1}: '+scene_name)

        for illum_idx in range(n_lights):
            first_cell = illum_idx*4+1

            plt.subplot(n_lights, 4, first_cell)
            plt.imshow((monitor_light_patterns[illum_idx]))
            plt.title(f'Light pattern {illum_idx+1}')

            plt.subplot(n_lights, 4, first_cell+1)
            plt.imshow(I_diffuse[batch_idx, illum_idx])
            plt.title(f'B{batch_idx+1}_L{illum_idx+1}_I_GT')

            plt.subplot(n_lights, 4, first_cell+2)
            plt.imshow(I_diffuse_est[batch_idx, illum_idx])
            plt.title(f'B{batch_idx+1}_L{illum_idx+1}_I_Rerender')

            plt.subplot(n_lights, 4, first_cell+3)
            plt.imshow(I_diffuse_error[batch_idx, illum_idx])
            plt.title(f'B{batch_idx+1}_L{illum_idx+1}_Error')

        out_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + f'_2_{batch_idx+1}th_Data_Rerender.png'
        plt.savefig(os.path.join(opt.tb_dir, out_name), facecolor='#eeeeee', bbox_inches='tight', dpi=300)
        plt.close()

# Dataset.py -------------------------------------------------------------------

import torch
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
import re
import os
#from pyntcloud import PyntCloud
from os.path import join
import random
#import models.utils as utils


class CreateBasisDataset(torch.utils.data.Dataset):
    """Dataset of captured scenes: light patterns, captured images and a mask.

    Each scene lives in ``<dataset_root>/<scene>/`` with a 'Pattern' and an
    'Image' folder holding '000.png', '001.png', ... (one per light pattern)
    and a 'geo_normal_img_mask.npy' file.
    """

    def __init__(self, opt, list_name):
        """Read the scene names from ``<dataset_root>/<list_name>.txt``."""
        self.opt = opt
        list_path = os.path.join(opt.dataset_root, list_name)
        self.scene_names = file_load(list_path)

    def __getitem__(self, i):
        """Load scene ``i``.

        Returns:
            dict with 'id', 'scene', 'patterns' (logit-space tensor of shape
            (light_N, light_R, light_C, 3)), 'imgs' (tensor of shape
            (1, light_N, cam_R, cam_C, 3) clamped to [1e-8, 1]) and 'mask'
            (eroded 2-D numpy array).
        """
        opt = self.opt
        scene_name = self.scene_names[i]
        scene_dir = os.path.join(opt.dataset_root, scene_name)
        pattern_dir = os.path.join(scene_dir, 'Pattern')
        image_dir = os.path.join(scene_dir, 'Image')
        # Patterns and images share the same zero-padded file names.
        frame_names = [str(f).zfill(3)+'.png' for f in range(opt.light_N)]

        vmin, vmax = 0.1, 0.9
        patterns = torch.zeros((opt.light_N, opt.light_R, opt.light_C, 3), dtype=opt.dtype, device=opt.device)
        for idx, frame_name in enumerate(frame_names):
            raw = np.array(Image.open(os.path.join(pattern_dir, frame_name))).astype(np.float32)
            raw = raw/255
            pat = torch.tensor(raw[:,:,:3], dtype=opt.dtype, device=opt.device)
            # Keep values inside [vmin, vmax] so the logit stays finite.
            pat = torch.clamp(pat, vmin, vmax)
            patterns[idx] = torch.logit(pat)

        imgs = torch.zeros((1, opt.light_N, opt.cam_R, opt.cam_C, 3), dtype=opt.dtype, device=opt.device)
        for idx, frame_name in enumerate(frame_names):
            raw = np.array(Image.open(os.path.join(image_dir, frame_name))).astype(np.float32)
            raw = raw/255
            raw = cv2.medianBlur(raw, 3)  # 3x3 median filter
            imgs[0, idx] = torch.tensor(raw[:,:,:3], dtype=opt.dtype, device=opt.device)
        imgs = torch.clamp(imgs, 1e-8, 1)

        # Binarise the mask, keep its first channel and shrink it with a 7x7 erosion.
        mask = np.load(os.path.join(scene_dir, 'geo_normal_img_mask.npy'))
        mask[mask!=0] = 1
        mask = mask[:,:,0]
        mask = cv2.erode(mask, np.ones((7,7), np.uint8), iterations=1)

        return {
            'id': i,
            'scene': scene_name,
            'patterns' : patterns,
            'imgs': imgs,
            'mask': mask
        }

    def __len__(self):
        """Number of scenes in the list."""
        return len(self.scene_names)


def file_load(path):
    """Read a data list: one entry per line of ``<path>.txt``.

    Args:
        path: file path without the '.txt' extension.

    Returns:
        list of the file's lines with the line terminator removed.  Empty
        lines are kept as empty strings.
    """
    # Only the newline is stripped (the original dropped the last character
    # unconditionally, truncating a final line that has no trailing newline),
    # and the file is closed even if reading fails.
    with open("{0}.txt".format(path), 'r') as f:
        return [line.rstrip('\n') for line in f]


def read_RGB(path):
    """Open the image at ``path`` and return its pixel data as an int32 numpy array."""
    image = Image.open(path)
    image.load()  # force the pixel data to be read now
    return np.asarray(image, dtype="int32")

# renderer.py ------------------------------------------------------------------

import matplotlib.pyplot as plt
#from models.dataset import *
import torch.nn.functional as nn
import matplotlib.pyplot as plt
import torch
from scipy.io import loadmat
#import utils
import datetime
import torchvision.transforms as T

class LightDeskRenderer:
    """Renders images under monitor patterns as a weighted sum of basis images."""

    def __init__(self, opt):
        """Keep a reference to the options object."""
        self.opt = opt

    def render(self, basis_images, monitor_light_pattern_nonlinear, camera_gain):
        """Render one image per light pattern from per-superpixel basis images.

        Args:
            basis_images: tensor reshaped here as
                (batch, light_R*light_C, cam_R, cam_C, 3).
            monitor_light_pattern_nonlinear: logit-space patterns, reshaped
                here as (light_N, light_R, light_C, 3).
            camera_gain: tensor; the applied gain is
                ``cam_gain_base ** (sigmoid(camera_gain) * 48)``.

        Returns:
            tensor of shape (batch, light_N, cam_R, cam_C, 3) clamped to [1e-8, 1].
        """
        opt = self.opt
        n_batch = basis_images.shape[0]
        n_superpixels = opt.light_R*opt.light_C

        # Pattern -> linear radiance (sigmoid then gamma), flipped along the
        # column axis and flattened to [1, lightN, R*C, 3, 1].
        radiance = torch.sigmoid(monitor_light_pattern_nonlinear) ** opt.monitor_gamma
        radiance = torch.flip(radiance, dims=[2])
        radiance = radiance.reshape(opt.light_N, -1, 3).unsqueeze(0).unsqueeze(-1)

        # Basis images -> [batch, 1, R*C, 3, H*W]
        basis = basis_images.permute(0,1,4,2,3)
        basis = basis.reshape(n_batch, n_superpixels, 3, -1).unsqueeze(1)

        # Weight each basis image by its superpixel radiance and sum over superpixels.
        rendered = (radiance * basis).sum(axis=2)
        rendered = rendered.reshape(n_batch, opt.light_N, 3, opt.cam_R, opt.cam_C)
        rendered = rendered.permute(0,1,3,4,2)

        gain_scalar = opt.cam_gain_base ** (torch.sigmoid(camera_gain)*48)

        I_diffuse = gain_scalar * (opt.rendering_scalar * opt.cam_shutter_time * rendered)
        I_diffuse = torch.clamp(I_diffuse, 1e-8, 1)
        '''
        Sigma = self.opt.Gaussian_sigma
        ksize = int(8*Sigma+1)
        transform = T.GaussianBlur(kernel_size=(ksize,ksize), sigma=(Sigma,Sigma))

        I_diffuse_G = torch.zeros((batch_size, I_diffuse.shape[1], I_diffuse.shape[2], I_diffuse.shape[3], 3), dtype=self.opt.dtype, device=self.opt.device)
        for i in range(len(I_diffuse)):
            I_Temp = I_diffuse[i].permute(0, 3, 1, 2) # [#Pattern, R, C, 3] -> [#Pattern, 3, R, C]
            I_diffuse_G[i] = transform(I_Temp).permute(0, 2, 3, 1) # [#Pattern, 3, R, C] -> [#Pattern, R, C, 3]
        I_diffuse = torch.clamp(I_diffuse_G, 1e-8, 1)
        '''
        return I_diffuse

# reconstructor.py -------------------------------------------------------------

import matplotlib.pyplot as plt
import numpy as np
#import utils
from PIL import Image
import cv2
import torch
import torch.nn as nn
import datetime
import os

class Reconstructor:
    """Photometric-stereo reconstructor.

    Recovers per-pixel surface normals, diffuse albedo and a normalised depth
    map from images observed under the monitor light patterns. All settings
    are read from ``opt`` (light_N, light_R, light_C, monitor_gamma,
    cam_gain_base, rendering_scalar, cam_shutter_time, cam_pitch, device).
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

    def forward(self, im_diffuse, light_pattern_nonlinear, incident, camera_gain, mask):
        """Run photometric stereo and pack the results into a dict.

        Args:
            im_diffuse: observed images; dim 0 is the batch.
            light_pattern_nonlinear, incident, camera_gain, mask: forwarded
                unchanged to :meth:`photometric_stereo`.

        Returns:
            dict with keys ``'normal'``, ``'albedo'`` and ``'depth'``.
        """
        batch_size = im_diffuse.shape[0]

        normal, albedo, depth = self.photometric_stereo(
            batch_size, im_diffuse, light_pattern_nonlinear, incident, camera_gain, mask)

        return {
            'normal': normal,
            'albedo': albedo,
            'depth': depth,
        }

    def photometric_stereo(self, batch_size, im_rendered, light_pattern_nonlinear, incident, camera_gain, mask):
        """Solve for normals and albedo by least squares, then integrate a depth map.

        Args:
            batch_size: number of batch items in ``im_rendered``.
            im_rendered: images of shape [batch, #patterns, camR, camC, rgb]
                (rendered images while training, real photos while testing).
            light_pattern_nonlinear: pattern parameters of shape
                [#patterns, lightR, lightC, rgb]; sigmoid and monitor gamma
                are applied here.
            incident: tensor reshaped to [lightR*lightC, batch*camR*camC, 3];
                presumably the per-pixel incident light directions -- TODO confirm.
            camera_gain: gain parameter; ``sigmoid(camera_gain) * 48`` is the
                exponent applied to ``opt.cam_gain_base``.
            mask: indexed as ``mask[row][col]``; pixels whose entry equals 0
                are skipped during depth integration.

        Returns:
            Tuple ``(surface_normal, diffuse_albedo, depth_map)``:
            surface_normal is ``(1 - n) / 2`` of the unit normals, shape
            [batch, camR, camC, 3]; diffuse_albedo has shape
            [batch, camR, camC, 3]; depth_map has shape [batch, camR, camC, 3]
            with the same value in all three channels. Each is clamped to
            [1e-8, 1].
        """
        light_num = self.opt.light_N
        light_count = self.opt.light_R * self.opt.light_C
        input_R, input_C = im_rendered.shape[2:4]
        pixel_num = batch_size * input_R * input_C

        monitor_light_radiance = torch.sigmoid(light_pattern_nonlinear) ** self.opt.monitor_gamma

        # incident: -> [batch*R*C, lightR*lightC, xyz]
        incident = incident.reshape(light_count, pixel_num, 3).permute(1, 0, 2)

        # Initial per-channel albedo guess: the brightest observation of each pixel.
        diffuse_albedo = torch.max(im_rendered, dim=1).values
        r = diffuse_albedo[:, :, :, 0].reshape(pixel_num, 1, 1)
        g = diffuse_albedo[:, :, :, 1].reshape(pixel_num, 1, 1)
        b = diffuse_albedo[:, :, :, 2].reshape(pixel_num, 1, 1)

        # im_rendered: [batch, #pattern, R, C, 3] -> [batch*R*C, 3*#pattern, 1],
        # channel-major (all red observations, then green, then blue).
        im_rendered = im_rendered.permute(0, 2, 3, 4, 1)
        im_rendered = im_rendered.reshape(pixel_num, 3 * light_num).unsqueeze(-1)

        im_rendered_red = im_rendered[:, :light_num]
        im_rendered_green = im_rendered[:, light_num:2 * light_num]
        im_rendered_blue = im_rendered[:, 2 * light_num:3 * light_num]

        # light pattern: [#pattern, r, c, 3] -> [3*#pattern, r*c], channel-major,
        # then repeated for every pixel.
        monitor_light_radiance = monitor_light_radiance.permute(3, 0, 1, 2)
        monitor_light_radiance = monitor_light_radiance.reshape(3 * light_num, light_count)
        monitor_light_radiance = torch.tile(monitor_light_radiance.unsqueeze(0), (pixel_num, 1, 1))

        # M: [batch*R*C, 3*#pattern, xyz]; rows are split per colour channel.
        M = monitor_light_radiance @ incident

        # Scale each channel's rows by the albedo guess of that channel.
        M_temp = torch.zeros_like(M, dtype=torch.float32, device=self.opt.device)
        M_temp[:, 0:light_num, :] = r * M[:, 0:light_num, :]
        M_temp[:, light_num:2 * light_num, :] = g * M[:, light_num:2 * light_num, :]
        M_temp[:, 2 * light_num:3 * light_num, :] = b * M[:, 2 * light_num:3 * light_num, :]

        # Solve M_temp @ x = observations for the (unnormalised) normal.
        x = torch.linalg.lstsq(M_temp, im_rendered).solution.squeeze(-1)

        # Normalise to unit length; the epsilon keeps black pixels finite.
        x = x / (torch.linalg.norm(x, dim=-1).unsqueeze(-1) + 1e-8)
        x = x.unsqueeze(-1)

        # Re-estimate the albedo of each channel given the solved normal.
        foreshortening = torch.clamp(incident @ x, 1e-8, None)
        r_expose = monitor_light_radiance[:, 0:light_num, :] @ foreshortening
        g_expose = monitor_light_radiance[:, light_num:2 * light_num, :] @ foreshortening
        b_expose = monitor_light_radiance[:, 2 * light_num:3 * light_num, :] @ foreshortening

        r_new = torch.linalg.lstsq(r_expose, im_rendered_red).solution
        g_new = torch.linalg.lstsq(g_expose, im_rendered_green).solution
        b_new = torch.linalg.lstsq(b_expose, im_rendered_blue).solution

        diffuse_albedo = torch.stack(
            [
                r_new.reshape(batch_size, input_R, input_C),
                g_new.reshape(batch_size, input_R, input_C),
                b_new.reshape(batch_size, input_R, input_C),
            ],
            dim=-1,
        )

        # Undo the camera gain / exposure scaling applied by the image formation.
        gain_scalar = self.opt.cam_gain_base ** (torch.sigmoid(camera_gain) * 48)
        diffuse_albedo = diffuse_albedo / (
            gain_scalar * (self.opt.rendering_scalar * self.opt.cam_shutter_time) + 1e-8)
        diffuse_albedo = torch.clamp(diffuse_albedo, 1e-8, 1)

        normal_map = x.reshape(batch_size, input_R, input_C, 3)
        # Map normal components from [-1, 1] to [0, 1] for image output.
        surface_normal = torch.clamp((1 - normal_map) / 2, 1e-8, 1)

        depth_map = self._integrate_depth(normal_map, mask)

        return surface_normal, diffuse_albedo, depth_map

    def _integrate_depth(self, normal_map, mask):
        """Integrate the unit normals of batch item 0 into a min-max normalised depth map.

        Pixels are swept from the bottom-right corner towards the top-left.
        Each valid pixel averages the depth propagated from its right
        neighbour (x gradient) and from its lower neighbour (y gradient); a
        missing or masked neighbour contributes the reference depth 0.

        Args:
            normal_map: unit normals of shape [batch, R, C, 3].
            mask: indexed as ``mask[row][col]``; entries equal to 0 are skipped.

        Returns:
            Tensor of shape [batch, R, C, 3] clamped to [1e-8, 1]. Only batch
            item 0 is integrated; other batch items stay at the minimum.
        """
        batch_size, input_R, input_C = normal_map.shape[:3]
        depth = torch.zeros(batch_size, input_R, input_C, 3,
                            dtype=torch.float32, device=self.opt.device)
        # Reference depth used where no valid neighbour exists.
        z0 = torch.zeros(1, dtype=torch.float32, device=self.opt.device)
        # Depth change per pixel step. The original code kept this factor in a
        # variable that the channel loop overwrote, so every pixel after the
        # first was integrated with twice the pixel pitch.
        step = self.opt.cam_pitch

        # NOTE(review): the sweep stops at index 1, so row 0 and column 0 are
        # never integrated and keep depth 0 before normalisation -- confirm
        # this is intended.
        # NOTE(review): a valid pixel whose normal z-component is 0 produces
        # inf/NaN here, which then propagates to the pixels that depend on it.
        for i in range(input_R - 1, 0, -1):
            for j in range(input_C - 1, 0, -1):
                if mask[i][j] == 0:
                    continue
                normal = normal_map[0][i][j]

                if j == input_C - 1 or mask[i][j + 1] == 0:
                    right = z0[0]
                else:
                    right = depth[0][i][j + 1][0]

                if i == input_R - 1 or mask[i + 1][j] == 0:
                    below = z0[0]
                else:
                    below = depth[0][i + 1][j][0]

                ret = right + step * normal[0] / normal[2]
                ret = ret + (below + step * normal[1] / normal[2])

                # Same value in all three channels so it can be saved as RGB.
                depth[0, i, j, :] = ret / 2

        depth_min = torch.min(depth)
        depth_range = torch.max(depth) - depth_min
        depth = depth - depth_min
        # Guard against a flat or fully masked map, which would give 0/0 = NaN.
        if depth_range > 0:
            depth = depth / depth_range
        return torch.clamp(depth, 1e-8, 1)


In [None]:
# Train.py (Main) --------------------------------------------------------------

import time
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def train(Opt, apply_dataset):
    """Reconstruct every scene of ``apply_dataset`` and save the estimates as PNGs.

    For each scene name the trainer's ``run_model`` is called, the results are
    converted with ``model_results_to_np`` and the normal, albedo and depth
    estimates are scaled by 255, cast to uint8 and written to ``opt.tb_dir``.

    Args:
        Opt: options object; ``tb_dir`` is read here and the object is passed
            on to ``Reconstructor`` and ``Trainer``.
        apply_dataset: dataset exposing ``scene_names``.
    """
    opt = Opt
    banner = '================================================================================='
    print(banner)
    print('Reconstruction output: %s' % opt.tb_dir)
    print(banner)

    # Reconstruction model and the trainer that drives it
    recon_model = Reconstructor(opt)
    trainer = Trainer(opt, recon_model)

    estimate_keys = ('normal_est', 'albedo_est', 'depth_est')
    for scene_idx, scene_name in enumerate(apply_dataset.scene_names):
        model_results = trainer.run_model(apply_dataset, scene_idx)
        results = model_results_to_np(model_results, batch_ind=0)

        # Convert all estimates to 8-bit first, then write them out in order.
        images = {key: (results[key] * 255).astype(np.uint8) for key in estimate_keys}
        for key, pixels in images.items():
            out_path = opt.tb_dir + f'/{key}_{scene_name}.png'
            Image.fromarray(pixels).save(out_path, 'PNG')
            print(out_path + " saved")

# Main
# Entry point of this cell: build the run options, create the 'Apply_Data'
# dataset and reconstruct/save every scene through train().
# NOTE(review): Options.parse() is defined outside this chunk; presumably it
# fills derived fields such as opt.tb_dir that train() reads -- confirm.
opt = Options().parse(save=True, isTrain=True)
apply_dataset = CreateBasisDataset(opt, 'Apply_Data')
train(opt,apply_dataset)


------------ Options -------------
-------------- End ----------------
Reconstruction output: ./results_apply/debug/20241124_043045
torch.Size([277248, 12, 3]) torch.Size([277248, 12, 1])
tensor(0.0034, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0., device='cuda:0', grad_fn=<MinBackward1>) tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.