# Contents

>[Contents](#scrollTo=ISfvEc6QTF3-)

>[Imports and Setup](#scrollTo=LzQFZNKL0uHb)

>[Create the Sphere Mesh](#scrollTo=ToNMP7mC797Y)

>[Create the renderer](#scrollTo=avCqYaNn8fi7)

>[Create a basic ML model](#scrollTo=NOyG2PbN-NXI)

>[Experiments](#scrollTo=0ni9zstuIsfs)

>>[Optimizing only the weights (baseline)](#scrollTo=lAuhSR-TSelq)

>>[Optimizing the colors and the weights simultaneously](#scrollTo=aZLeUrEfSDlY)



# Imports and Setup

In [None]:
# Build tool; presumably installed so that a from-source PyTorch3D build (see the
# install cell below) compiles faster — TODO confirm it is still needed.
!pip install ninja

In [None]:
import os
import sys
import torch
# Probe for PyTorch3D and only run the installation below when the import fails.
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    # The prebuilt-wheel route is only attempted for PyTorch 2.2.x on Linux;
    # every other combination falls back to building from source.
    if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        # Drop any local suffix after "+" and the dots, e.g. "2.2.1+cu121" -> "221".
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        # Wheel index directory: py3<minor>_cu<CUDA digits>_pyt<torch digits>.
        # NOTE(review): torch.version.cuda is None on CPU-only PyTorch builds, in which
        # case .replace would fail here — confirm a CUDA build is assumed.
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

In [None]:
import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte

from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.transforms import Rotate, Translate, RotateAxisAngle
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
)
from pytorch3d.loss import mesh_laplacian_smoothing

import torch
import torch.nn as nn
import torch.optim as optim

# Create the Sphere Mesh

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Level-4 icosphere whose initial vertex colors are its own coordinates,
# rescaled per axis to the [0, 1] range.
sphere_mesh = ico_sphere(level=4, device=device)

verts = sphere_mesh.verts_packed()
verts_min = verts.min(dim=0, keepdim=True).values
verts_max = verts.max(dim=0, keepdim=True).values
normalized_verts = (verts - verts_min) / (verts_max - verts_min)

# A leading batch dimension is added and the colors are wrapped in an
# nn.Parameter so that they can be optimized later on.
vertex_colors = nn.Parameter(normalized_verts[None])
textures = TexturesVertex(verts_features=vertex_colors)
sphere_mesh.textures = textures

# Create the renderer
The renderer needs a rasterizer and a shader.

In [None]:
# Side length (in pixels) of the square rendered images. The RotationPredictor CNN
# defined further down is sized for 128x128 inputs (three 2x2 poolings -> 16x16
# feature maps), so keep the two in sync.
image_size = 128

In [None]:
# Perspective camera; only the translation T is overridden.
cameras = FoVPerspectiveCameras(
    device=device,
    T=torch.tensor([[0.0, 0.0, 3.0]], device=device),
)

# Rasterization without blur, keeping a single face per pixel.
raster_settings = RasterizationSettings(
    image_size=image_size,
    blur_radius=0.0,
    faces_per_pixel=1,
)

# One point light for the Phong shader.
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# The renderer pairs a rasterizer with a shader.
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights),
)

In [None]:
# Render the unrotated sphere once and display the first (only) image of the batch.
image = renderer(sphere_mesh, cameras=cameras, lights=lights)
plt.imshow(image[0].detach().cpu().numpy())

# Create a basic ML model

In [None]:
class RotationPredictor(nn.Module):
    """Small CNN that regresses one scalar (the rotation angle) from an RGB image.

    Three conv + ReLU + 2x2 max-pool stages are followed by two fully
    connected layers. Every pooling halves the spatial size, so a square
    input of side ``image_size`` reaches the first linear layer as
    ``128 x (image_size // 8) x (image_size // 8)`` features.

    Args:
        image_size: Side length in pixels of the square input images. The
            default of 128 reproduces the original fixed ``128 * 16 * 16``
            layout.

    Raises:
        ValueError: If ``image_size`` is smaller than 8, which would leave
            no features after the three poolings.
    """

    def __init__(self, image_size=128):
        super().__init__()
        if image_size < 8:
            raise ValueError(f"image_size must be at least 8, got {image_size}")

        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Three stride-2 poolings shrink each side by a factor of 8 (floored).
        feature_side = image_size // 8
        self.fc1 = nn.Linear(128 * feature_side * feature_side, 512)
        self.fc2 = nn.Linear(512, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of images ``(N, 3, H, W)`` to predictions ``(N, 1)``."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        # Flatten per sample. Unlike ``view(-1, K)``, this keeps the batch
        # dimension fixed, so an input of the wrong spatial size fails in fc1
        # with a shape error instead of being silently regrouped across samples.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

## Helper functions
We also need some helper functions for rotating and visualizing the spheres while training.

In [None]:
def create_rotated_sphere(sphere_mesh, angle_degrees, device):
    """Return a copy of ``sphere_mesh`` rotated about the Y axis.

    The angle is given in degrees and converted to radians. The packed
    vertices are multiplied by the transposed Y-rotation matrix, and a new
    ``Meshes`` object is built from the rotated vertices, the original faces
    and the original textures object (which is shared, not copied).
    """
    theta = torch.tensor([angle_degrees * np.pi / 180.0], device=device)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)

    y_rotation = torch.tensor(
        [
            [cos_t, 0, sin_t],
            [0, 1, 0],
            [-sin_t, 0, cos_t],
        ],
        device=device,
    )

    return Meshes(
        verts=[sphere_mesh.verts_packed() @ y_rotation.T],
        faces=[sphere_mesh.faces_packed()],
        textures=sphere_mesh.textures,
    )

In [None]:
def visualize_sphere(sphere_mesh, vertex_colors, renderer, cameras, lights):
    """Render the mesh with the given per-vertex colors and show the image.

    A detached copy of ``vertex_colors`` is reshaped to ``(1, num_verts, 3)``
    and wrapped in a fresh ``TexturesVertex`` on a temporary mesh, so the
    textures of ``sphere_mesh`` itself are left untouched. Rendering runs
    under ``torch.no_grad()``; the RGB channels of the first image are shown
    in a 10x10 inch matplotlib figure without axes.
    """
    with torch.no_grad():
        packed_verts = sphere_mesh.verts_packed()
        colors = vertex_colors.detach().clone().reshape(1, packed_verts.shape[0], 3)

        preview_mesh = Meshes(
            verts=[packed_verts],
            faces=[sphere_mesh.faces_packed()],
            textures=TexturesVertex(verts_features=colors),
        )

        rendered = renderer(preview_mesh, cameras=cameras, lights=lights)
        rgb = rendered[0, ..., :3].cpu().numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(rgb)
    plt.axis('off')
    plt.show()


## Train and test loops

In [None]:
def train_model(model, optimizer, mesh, vertex_colors, renderer, cameras, lights, num_epochs=500, batch_size=32, lr=0.001, criterion=nn.MSELoss(), use_smoothing_loss=False):
    """Train ``model`` to predict the sphere's rotation angle from rendered images.

    Each epoch draws ``batch_size`` random integer angles in [0, 360) degrees,
    renders ``mesh`` rotated by each angle, and compares the model output with
    that angle expressed in radians. The gradients of all samples in the batch
    are accumulated (averaged) and applied with a single ``optimizer.step()``.
    Every 100 epochs the mean loss of the batch is printed and the current
    vertex colors are visualized.

    Side effects: ``mesh.textures`` is replaced in place. The module-level
    ``device`` is used for the rotation matrix and the target tensor.

    Args:
        model: Network mapping a ``(1, 3, H, W)`` image to a ``(1, 1)`` angle.
        optimizer: Optimizer holding the parameters to update.
        mesh: Mesh that is re-textured, rotated and rendered.
        vertex_colors: Per-vertex colors used to build the textures.
        renderer: Callable rendering a mesh to an image batch.
        cameras: Cameras forwarded to ``renderer``.
        lights: Lights forwarded to ``renderer``.
        num_epochs: Number of optimizer steps.
        batch_size: Number of rendered samples accumulated per step.
        lr: Unused; kept for backward compatibility (the learning rate is
            taken from ``optimizer``).
        criterion: Loss between predicted and target angle.
        use_smoothing_loss: When true, ``mesh_laplacian_smoothing`` of
            ``mesh`` is added to every sample's loss.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    for epoch in range(num_epochs):
        total_loss = 0.0

        # Clear gradients once per batch. Clearing them per sample (as before)
        # discarded every gradient except the last one before the step.
        optimizer.zero_grad()

        for _ in range(batch_size):
            # Fresh texture object wrapping the current colors for this forward pass.
            textures = TexturesVertex(verts_features=vertex_colors)
            mesh.textures = textures
            angle = np.random.randint(0, 360)
            rotated_sphere = create_rotated_sphere(mesh, angle, device)

            image = renderer(rotated_sphere, cameras=cameras, lights=lights)
            # Drop the alpha channel and move channels first: (N, H, W, 3) -> (N, 3, H, W).
            image = image[..., :3].contiguous().permute(0, 3, 1, 2)

            predicted_angle = model(image)
            target_angle = torch.tensor([[angle * np.pi / 180.0]], device=device)

            loss = criterion(predicted_angle, target_angle)
            if use_smoothing_loss:
                # NOTE(review): mesh_laplacian_smoothing regularizes the mesh's vertex
                # positions, which are not optimized here, so this term appears to shift
                # the loss without influencing the colors — confirm whether a Laplacian
                # penalty on vertex_colors was intended.
                loss = loss + mesh_laplacian_smoothing(mesh, method="uniform")

            # Scale so that the accumulated gradient is the batch mean.
            (loss / batch_size).backward()

            total_loss += loss.item()

        optimizer.step()

        if (epoch + 1) % 100 == 0:
            # Mean over the samples actually drawn (was a hard-coded 36).
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/batch_size:.4f}")
            visualize_sphere(mesh, vertex_colors, renderer, cameras, lights)

In [None]:
def test_model(model, sphere_mesh, renderer, cameras, lights, num_tests=500):
    """Return the root-mean-square angle error of ``model`` in degrees.

    For each of ``num_tests`` trials a random integer angle in [0, 360)
    degrees is drawn, the rotated sphere is rendered, and the prediction
    (converted from radians to degrees) is compared with the drawn angle.
    Runs under ``torch.no_grad()`` and uses the module-level ``device``.
    """
    running_squared_error = 0
    with torch.no_grad():
        for _ in range(num_tests):
            true_degrees = np.random.randint(0, 360)
            rotated = create_rotated_sphere(sphere_mesh, true_degrees, device)

            rendered = renderer(rotated, cameras=cameras, lights=lights)
            # Keep RGB only and move channels first for the CNN.
            net_input = rendered[..., :3].contiguous().permute(0, 3, 1, 2)

            predicted_degrees = model(net_input).item() * 180 / np.pi
            running_squared_error += (true_degrees - predicted_degrees) ** 2

    rmse = np.sqrt(running_squared_error / num_tests)
    return rmse


# Experiments

## Optimizing only the weights (baseline)

In [None]:
# Baseline: a fresh network whose weights are the only parameters given to Adam.
model = RotationPredictor()
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

criterion = nn.MSELoss()

In [None]:
# Start the experiment from a freshly colored sphere: vertex colors are the
# vertex coordinates rescaled per axis to [0, 1].
sphere_mesh = ico_sphere(level=4, device=device)

verts = sphere_mesh.verts_packed()
verts_min = verts.min(dim=0, keepdim=True).values
verts_max = verts.max(dim=0, keepdim=True).values
normalized_verts = (verts - verts_min) / (verts_max - verts_min)

# Not handed to the optimizer in this experiment, so the colors stay fixed.
vertex_colors = nn.Parameter(normalized_verts[None])
textures = TexturesVertex(verts_features=vertex_colors)
sphere_mesh.textures = textures

In [None]:
# 5000 optimizer steps, 16 rendered samples per step.
train_model(
    model,
    optimizer,
    sphere_mesh,
    vertex_colors,
    renderer,
    cameras,
    lights,
    num_epochs=5000,
    batch_size=16,
)

In [None]:
# Evaluate the trained baseline on freshly drawn random rotations.
num_tests = 500
rmse = test_model(model, sphere_mesh, renderer, cameras, lights, num_tests=num_tests)
print(f"Root mean squared error over {num_tests} tests: {rmse:.2f} (degrees).")

## Optimizing the colors and the weights simultaneously
The vertex colors are wrapped in an `nn.Parameter` and handed to the optimizer as a second parameter group, so they are learned during training together with the network weights.

In [None]:
# Build the mesh and its learnable colors FIRST. The optimizer must be given this
# exact nn.Parameter object; creating the optimizer before re-creating
# vertex_colors would leave it updating the previous experiment's tensor while
# the colors actually rendered during training never change.
sphere_mesh = ico_sphere(4, device)

verts = sphere_mesh.verts_packed()
verts_min = verts.min(dim=0, keepdim=True)[0]
verts_max = verts.max(dim=0, keepdim=True)[0]
normalized_verts = (verts - verts_min) / (verts_max - verts_min)

vertex_colors = nn.Parameter(normalized_verts.unsqueeze(0))
textures = TexturesVertex(verts_features=vertex_colors)
sphere_mesh.textures = textures

model = RotationPredictor().to(device)

# Two parameter groups: the network weights and the vertex colors created above.
optimizer = optim.Adam([
    {'params': model.parameters()},
    {'params': [vertex_colors], 'lr': 0.001}
], lr=0.001)

criterion = nn.MSELoss()

In [None]:
# Same schedule as the baseline: 5000 steps of 16 rendered samples each.
train_model(
    model,
    optimizer,
    sphere_mesh,
    vertex_colors,
    renderer,
    cameras,
    lights,
    num_epochs=5000,
    batch_size=16,
)

In [None]:
# Evaluate on freshly drawn random rotations. The same variable drives both the
# run and the report, so the printed count cannot drift from the real one.
num_tests = 500
rmse = test_model(model, sphere_mesh, renderer, cameras, lights, num_tests=num_tests)
print(f"Root mean squared error over {num_tests} tests: {rmse:.2f} (degrees).")

## Optimizing the colors and the weights simultaneously with smoothing.

To prevent chaotic color changes, a Laplacian smoothing loss can be added to the training loss by passing `use_smoothing_loss=True` to `train_model`.

In [None]:
# Build the mesh and its learnable colors FIRST. The optimizer must be given this
# exact nn.Parameter object; creating the optimizer before re-creating
# vertex_colors would leave it updating the previous experiment's tensor while
# the colors actually rendered during training never change.
sphere_mesh = ico_sphere(4, device)

verts = sphere_mesh.verts_packed()
verts_min = verts.min(dim=0, keepdim=True)[0]
verts_max = verts.max(dim=0, keepdim=True)[0]
normalized_verts = (verts - verts_min) / (verts_max - verts_min)

vertex_colors = nn.Parameter(normalized_verts.unsqueeze(0))
textures = TexturesVertex(verts_features=vertex_colors)
sphere_mesh.textures = textures

model = RotationPredictor().to(device)

# Two parameter groups: the network weights and the vertex colors created above.
optimizer = optim.Adam([
    {'params': model.parameters()},
    {'params': [vertex_colors], 'lr': 0.001}
], lr=0.001)

criterion = nn.MSELoss()

In [None]:
# Same schedule as before, with the smoothing term switched on.
train_model(
    model,
    optimizer,
    sphere_mesh,
    vertex_colors,
    renderer,
    cameras,
    lights,
    num_epochs=5000,
    batch_size=16,
    use_smoothing_loss=True,
)

In [None]:
# Evaluate on freshly drawn random rotations. The same variable drives both the
# run and the report, so the printed count cannot drift from the real one.
num_tests = 500
rmse = test_model(model, sphere_mesh, renderer, cameras, lights, num_tests=num_tests)
print(f"Root mean squared error over {num_tests} tests: {rmse:.2f} (degrees).")