Link do artigo: https://openaccess.thecvf.com/content_cvpr_2018/papers/Sun_Pix3D_Dataset_and_CVPR_2018_paper.pdf

Link do repositório no GitHub: https://github.com/xingyuansun/pix3d

In [None]:
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy

In [None]:
# Object category used for this run. Available options:
# 'sofa', 'chair', 'desk', 'bed', 'bookcase', 'tool', 'misc', 'wardrobe', 'table'
category = "bookcase"

# Persist the chosen category to disk *before* importing the data loaders.
# NOTE(review): presumably data.dataloader_filtered reads data/category.txt at
# import time, which is why the write precedes the import — confirm.
with open("data/category.txt", "w") as file:
    file.write(category)

from data.dataloader_filtered import train_loader, test_loader, val_loader

In [None]:
# Ask PyTorch to release the unused GPU memory it is caching before the model is built.
torch.cuda.empty_cache()

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Instantiate Rec3D and move it to `device`.
# `version` is imported alongside it and is used to name the weights file.
from model.model import Rec3D, version
model = Rec3D().to(device)

In [None]:
# Resume from the saved weights for this model version and category when they
# exist; otherwise keep the freshly initialised model.
weights_path = "model/weights/weights_{}_{}.pdf".format(version, category)
try:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    print("Weights loaded")
except FileNotFoundError:
    print("No saved weights at {}; starting from scratch".format(weights_path))
except Exception as err:
    # Loading stays best-effort (training simply starts from the current
    # weights), but the reason is reported instead of being silently discarded.
    print("Could not load weights from {}: {}".format(weights_path, err))

model

In [None]:
# Training configuration: number of epochs per train_model() call, Adam
# learning rate, and one MSE criterion per model output.
EPOCHS = 50
lr = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=lr)
loss_fn_1, loss_fn_2 = nn.MSELoss(), nn.MSELoss()

# Flags read by train_step: which of the two losses are back-propagated.
train_seg, train_rec = True, True

In [None]:
# Per-epoch loss histories: index 0 holds the segmentation loss,
# index 1 the reconstruction loss.
train_losses = [[], []]
valid_losses = [[], []]

In [None]:
def train_step(model, loader, loss_fn_1, loss_fn_2, optimizer, device, *, use_seg=None, use_rec=None) -> list:
  """Run one optimisation pass over ``loader`` and return the mean losses.

  Args:
    model: module called as ``model(X, noise)``; element 0 of its output is
      compared with ``Y_mask`` and element 1 with ``Y_cloud``.
    loader: iterable of ``(X, Y_mask, Y_cloud, image)`` batches. Batches whose
      ``X`` has shape ``(1, 1)`` are skipped.
    loss_fn_1: criterion for the segmentation output.
    loss_fn_2: criterion for the reconstruction output.
    optimizer: optimizer that updates the model parameters.
    device: device the batch tensors are moved to.
    use_seg: whether the segmentation loss is back-propagated. ``None``
      (default) falls back to the module-level ``train_seg`` flag.
    use_rec: whether the reconstruction loss is back-propagated. ``None``
      (default) falls back to the module-level ``train_rec`` flag.

  Returns:
    ``[segmentation_loss, reconstruction_loss]`` averaged over the batches
    that were actually processed. A disabled loss is reported as ``0.0``, as
    is every entry when no batch was processed.
  """
  seg_on = train_seg if use_seg is None else use_seg
  rec_on = train_rec if use_rec is None else use_rec

  totals = [0.0, 0.0]
  processed = 0
  for X, Y_mask, Y_cloud, image in loader:
    if X.shape == (1,1):
      continue
    processed += 1
    model.zero_grad()
    # The second model input is a fresh uniform-random volume for every batch.
    y_pred = model(X.to(device), torch.rand((1,128,128,128)).to(device))
    batch_loss = None
    if seg_on:
      loss_1 = loss_fn_1(y_pred[0], Y_mask.to(device))
      totals[0] += loss_1.item()
      batch_loss = loss_1
    if rec_on:
      loss_2 = loss_fn_2(y_pred[1], Y_cloud.to(device))
      totals[1] += loss_2.item()
      batch_loss = loss_2 if batch_loss is None else batch_loss + loss_2
    if batch_loss is not None:
      # One backward pass over the summed loss accumulates the same gradients
      # as two separate passes, without having to retain the graph.
      batch_loss.backward()
      optimizer.step()
  # Skipped batches must not dilute the mean; max() also guards the case in
  # which nothing was processed.
  denom = max(processed, 1)
  return [totals[0] / denom, totals[1] / denom]


def validation_step(model, loader, loss_fn_1, loss_fn_2, device) -> list:
  """Evaluate ``model`` on ``loader`` without gradients and return the mean losses.

  Args:
    model: module called as ``model(X, noise)``; element 0 of its output is
      compared with ``Y_mask`` and element 1 with ``Y_cloud``.
    loader: iterable of ``(X, Y_mask, Y_cloud, image)`` batches. Batches whose
      ``X`` has shape ``(1, 1)`` are skipped.
    loss_fn_1: criterion for the segmentation output.
    loss_fn_2: criterion for the reconstruction output.
    device: device the batch tensors are moved to.

  Returns:
    ``[segmentation_loss, reconstruction_loss]`` averaged over the batches
    that were actually processed (``[0.0, 0.0]`` when none was).
  """
  totals = [0.0, 0.0]
  processed = 0
  with torch.no_grad():
    for X, Y_mask, Y_cloud, image in loader:
      if X.shape == (1,1):
        continue
      processed += 1
      # The second model input is a fresh uniform-random volume for every batch.
      y_pred = model(X.to(device), torch.rand((1,128,128,128)).to(device))
      totals[0] += loss_fn_1(y_pred[0], Y_mask.to(device)).item()
      totals[1] += loss_fn_2(y_pred[1], Y_cloud.to(device)).item()
  # Skipped batches must not dilute the mean; max() also guards the case in
  # which nothing was processed.
  denom = max(processed, 1)
  return [totals[0] / denom, totals[1] / denom]

In [None]:
from tqdm import tqdm

def train_model():
  """Train the module-level ``model`` for ``EPOCHS`` epochs.

  Each epoch runs ``train_step`` on ``train_loader`` and ``validation_step``
  on ``val_loader`` and appends the results to the module-level
  ``train_losses`` / ``valid_losses`` histories. Whenever the newest
  reconstruction validation loss equals the minimum of the whole history
  (which includes epochs recorded by earlier calls), a copy of the model is
  kept and its weights are written to the weights file.

  Returns:
    A deep copy of the model at the most recent such epoch of this call, or
    ``None`` when no epoch of this call reached the minimum.
  """
  # Bound up front: a later call may never reach the minimum already stored in
  # valid_losses, and the name must still exist when the function returns.
  best_model = None
  for epoch in tqdm(range(EPOCHS)):
    train_loss = train_step(model, train_loader, loss_fn_1, loss_fn_2, optimizer, device)
    train_losses[0].append(train_loss[0])
    train_losses[1].append(train_loss[1])
    valid_loss = validation_step(model, val_loader, loss_fn_1, loss_fn_2, device)
    valid_losses[0].append(valid_loss[0])
    valid_losses[1].append(valid_loss[1])
    if valid_losses[1][-1] == min(valid_losses[1]):
      best_model = deepcopy(model)
      torch.save(best_model.state_dict(), "model/weights/weights_{}_{}.pdf".format(version, category))
  return best_model

In [None]:
# Stage 1: both losses enabled, encoder block frozen.
train_seg = True
train_rec = True
model.freeze_encoder_block()
best_model = train_model()

In [None]:
# Stage 2: both losses enabled, encoder block unfrozen.
train_seg = True
train_rec = True
model.unfreeze_encoder_block()
best_model = train_model()

In [None]:
def plot_loss(loss_train, loss_valid):
  """Plot the segmentation loss (index 0 of each history) per epoch."""
  # Validation is drawn first, then training, so the line order is unchanged.
  for history, name in ((loss_valid, 'valid'), (loss_train, 'train')):
    plt.plot(history[0], label=name)
  plt.title('Loss per epoch [segmentation]')
  plt.xlabel('epoch')
  plt.ylabel('loss')
  plt.legend()
  plt.show()

plot_loss(train_losses, valid_losses)

In [None]:
def plot_loss(loss_train, loss_valid):
  """Plot the reconstruction loss (index 1 of each history) per epoch."""
  # Validation is drawn first, then training, so the line order is unchanged.
  for history, name in ((loss_valid, 'valid'), (loss_train, 'train')):
    plt.plot(history[1], label=name)
  plt.title('Loss per epoch [reconstruction]')
  plt.xlabel('epoch')
  plt.ylabel('loss')
  plt.legend()
  plt.show()

plot_loss(train_losses, valid_losses)

In [None]:
# Cut-offs applied to the two model outputs before they are displayed.
threshold_seg = 0.5
threshold_rec = 0.5


def _voxel_coords(volume):
    """Return an (N, 3) integer array with the coordinates of the non-zero entries.

    ``volume`` is flattened and every flat index is decomposed as an index
    into a 128 x 128 x 128 grid.
    """
    flat_idx = volume.flatten().nonzero(as_tuple=True)[0]
    coords = torch.stack(
        (flat_idx // (128*128), (flat_idx % (128*128)) // 128, flat_idx % 128),
        dim=1,
    )
    return coords.cpu().numpy()


def _draw_cloud(ax, cloud):
    """Scatter ``cloud`` on the 3D axis ``ax``, coloured by distance from (64, 64, 64)."""
    dist = ((cloud[:, 0] - 64)**2 + (cloud[:, 1] - 64)**2 + (cloud[:, 2] - 64)**2)**0.5
    ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], c=dist, cmap="viridis")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_xlim(-1, 129)
    ax.set_ylim(-1, 129)
    ax.set_zlim(-1, 129)


fig = plt.figure(figsize=(10, 10))

# Draw a 4 x 4 grid: one row per test sample with the input image, the
# thresholded segmentation, the thresholded reconstruction and the target cloud.
with torch.no_grad():
    j = 0      # row currently being drawn
    count = 0  # usable samples skipped so far
    skip = 0   # number of usable samples to skip before plotting starts
    for X, y_mask, y_cloud, image in test_loader:
        if X.shape == (1,1):
            continue
        if count < skip:
            count += 1
            continue

        # Unlike the training loop, the second model input here is all zeros.
        y = model(X.to(device), torch.zeros((1,128,128,128)).to(device))

        # Column titles are only written on the first row.
        first_row = j == 0

        ax = fig.add_subplot(4, 4, 1+4*j)
        ax.imshow(image.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        ax.set_xticks([])
        ax.set_yticks([])
        if first_row:
            ax.title.set_text("Imagem")

        ax = fig.add_subplot(4, 4, 2+4*j)
        ax.imshow((y[0].squeeze() > threshold_seg).cpu().numpy().astype(np.uint8) * 255, cmap="gray", interpolation="None")
        ax.set_xticks([])
        ax.set_yticks([])
        if first_row:
            ax.title.set_text("Segmentação")

        # Predicted cloud: entries at or above the reconstruction threshold.
        cloud = _voxel_coords(y[1] >= threshold_rec)
        ax = fig.add_subplot(4, 4, 3+4*j, projection='3d')
        _draw_cloud(ax, cloud)
        if first_row:
            ax.title.set_text("Reconstrução")

        # Target cloud: non-zero entries of the ground-truth volume.
        cloud = _voxel_coords(y_cloud)
        ax = fig.add_subplot(4, 4, 4+4*j, projection='3d')
        _draw_cloud(ax, cloud)
        if first_row:
            ax.title.set_text("Esperado")

        j += 1
        if j == 4:
            break

In [None]:
# Save the current weights of `model` to the weights file for this version and category.
# NOTE(review): train_model() writes its best-validation checkpoint to this same
# path, so running this cell replaces that checkpoint with the current weights —
# confirm this is intended.
torch.save(model.state_dict(), "model/weights/weights_{}_{}.pdf".format(version, category))