In [2]:
# Report the installed PyTorch build. `torch` is imported here because this
# cell sits above the notebook's import cell and would otherwise raise a
# NameError on a fresh kernel run from top to bottom.
import torch

print(torch.__version__)

1.12.1+cu113


In [1]:
import os

import lpips
import matplotlib.pyplot as plt
import numpy as np
import rembg
import torch
import torch.nn.functional as F
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground

In [2]:
output_dir = "test/"
# Create the output folder up front so the save below cannot fail when the
# directory does not exist yet.
os.makedirs(output_dir, exist_ok=True)

images = []
rembg_session = rembg.new_session()

# Strip the background, then resize the foreground (0.85 is the ratio handed
# to resize_foreground).
image = remove_background(Image.open("./examples/chair.png"), rembg_session)
image = resize_foreground(image, 0.85)

# Blend the first three channels over a constant 0.5 grey, using the fourth
# channel as the per-pixel blend weight (presumably alpha — confirm against
# remove_background's output mode).
image = np.array(image).astype(np.float32) / 255.0
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
image = Image.fromarray((image * 255.0).astype(np.uint8))

# Keep a copy on disk and queue the image for the model.
image.save(os.path.join(output_dir, "input.png"))
images.append(image)

In [5]:
# Load the model
model = TSR.from_pretrained(
    "./train",
    config_name="config.yaml",
    weight_name="model.ckpt",
)

# Set parameters
chunkSize = 8192 # Chunk size
nViews = 30 # Number of views
mcResolution = 256 # Marching cubes


# The notebook is pinned to the CPU. To use a GPU when one is available, swap in:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Compare on the device *type*: a torch.device never compares equal to a plain
# string, so the previous `device == "cuda:0"` check could never be true.
if device.type == "cuda":
    torch.cuda.empty_cache()

model.to(device)
model.renderer.set_chunk_size(chunkSize)

In [11]:
# Render nViews views for every prepared input image and save them as PNGs.
render_out = []

# Make sure the target folder exists before writing any file into it.
os.makedirs(output_dir, exist_ok=True)

# With a single input the historical "render_XXX.png" names are kept. With
# several inputs the input index is added so that later images do not
# silently overwrite the renders of earlier ones.
multiple_inputs = len(images) > 1

for i, image in enumerate(images):
    # Inference only: no autograd graph is needed to build the scene codes.
    with torch.no_grad():
        scene_codes = model([image], device=device)

    render_images = model.render(scene_codes, n_views=nViews, return_type="pil")
    render_out.append(render_images)

    name_prefix = f"render_{i:03d}_" if multiple_inputs else "render_"
    # A single image was passed to the model, so only the first entry of
    # render_images is written out.
    for ri, render_image in enumerate(render_images[0]):
        render_image.save(os.path.join(output_dir, f"{name_prefix}{ri:03d}.png"))

TEST

In [14]:
# Compute the loss
def compute_loss(render_img, gt_img):
    """Combine a pixel-wise MSE term with an LPIPS perceptual term.

    Relies on the module-level ``loss_fn_vgg`` (an LPIPS network) being
    defined before the function is called.

    Args:
        render_img: Rendered images as a float tensor. Assumed to be laid out
            as (N, 3, H, W) with values in [0, 1] — TODO confirm against the
            renderer output.
        gt_img: Ground-truth images with the same shape and value range as
            ``render_img``.

    Returns:
        ``(loss, loss_dict)`` where ``loss`` is the scalar sum of the MSE term
        and the LPIPS term (weighted by 2.0), and ``loss_dict`` maps
        ``train/loss_mse``, ``train/loss_lpips`` and ``train/loss`` to the
        corresponding tensors.
    """
    # Pixel-wise term, computed on the images in the range they are given in.
    loss_mse = F.mse_loss(render_img, gt_img)

    # LPIPS expects inputs scaled to [-1, 1], so the [0, 1] images are
    # rescaled for this term only. The network returns one distance per image;
    # averaging reduces it to a scalar so that backward() and item() work for
    # any number of views.
    loss_lpips = 2.0 * loss_fn_vgg(render_img * 2.0 - 1.0, gt_img * 2.0 - 1.0).mean()

    # TODO: add a mask term (MSE between rendered and target alphas) once alpha
    # maps are available for both sides.
    loss = loss_mse + loss_lpips

    prefix = 'train'
    loss_dict = {
        f'{prefix}/loss_mse': loss_mse,
        f'{prefix}/loss_lpips': loss_lpips,
        f'{prefix}/loss': loss,
    }

    return loss, loss_dict

In [13]:
# Set up the optimizer and the loss functions
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)    # weight_decay is L2 regularization
criterion = torch.nn.MSELoss()
# Keep the LPIPS network on the same device as the model so that its weights
# and the images it is given live in the same place.
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /home/tony/anaconda3/envs/tsr/lib/python3.8/site-packages/lpips/weights/v0.1/vgg.pth


In [None]:
# Training loop
#
# NOTE(review): this notebook never loads ground-truth views. `gt_img` must
# hold one target per entry of `images`, laid out like the stacked renders
# built below, before this cell can train anything.
gt_img = []  # TODO: fill with ground-truth view tensors, one per input image

if len(gt_img) != len(images):
    raise ValueError(
        f"gt_img has {len(gt_img)} entries but images has {len(images)}; "
        "load one ground-truth view stack per input image before training"
    )

render_out = []
for epoch in range(100):  # number of epochs
    # Keep only the latest epoch's renders so memory use stays bounded instead
    # of growing with every epoch.
    render_out = []

    for image, target_views in zip(images, gt_img):
        # Forward pass. Autograd must stay enabled here, otherwise nothing
        # upstream of the scene codes can be updated by the optimizer.
        scene_codes = model([image], device=device)

        # Ask for tensors instead of PIL images: PIL output cannot carry
        # gradients back to the model.
        # NOTE(review): confirm that model.render does not disable autograd
        # internally; if it does, the loss below will not reach the weights.
        render_images = model.render(scene_codes, n_views=nViews, return_type="pt")

        # Assumes every rendered view is an (H, W, 3) tensor — TODO confirm.
        # Stack the views of the single scene, then move channels first.
        render_views = torch.stack(list(render_images[0])).permute(0, 3, 1, 2)
        # Match the targets to the renders' dtype and device.
        target_views = torch.as_tensor(target_views).to(render_views)

        loss, loss_dict = compute_loss(render_views, target_views)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a detached copy so no autograd graph is kept alive.
        render_out.append(render_views.detach())

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")