# Utils

> A set of utility functions for common tasks, such as fast point-cloud visualization.

In [None]:
#| default_exp utils

In [None]:
#| export 
from collections.abc import Iterable
from typing import Mapping

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
#| export 
def _pick_default_device():
    "Return the preferred torch device name: Apple 'mps' first, then 'cuda', else 'cpu'."
    if torch.backends.mps.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

# Kept as a single top-level assignment so nbdev still exports the name.
def_device = _pick_default_device()

In [None]:
#| export 
def to_device(x, device=def_device):
    """Recursively move every tensor inside `x` to `device`.

    - torch.Tensor: moved with `.to(device)`.
    - Mapping: rebuilt as a plain dict, with each value processed recursively.
    - namedtuple / other iterable containers: rebuilt with their own type from
      the recursively processed elements.
    - str / bytes and non-iterable leaves (ints, floats, None, ...): returned unchanged.
    """
    if isinstance(x, torch.Tensor): return x.to(device)
    # Recurse into values so nested containers and non-tensor entries in a dict
    # are handled, instead of calling `.to` on every value directly.
    if isinstance(x, Mapping): return {k: to_device(v, device) for k, v in x.items()}
    # Strings are iterable, but rebuilding one from a generator would produce
    # "<generator object ...>", so leave them alone.
    if isinstance(x, (str, bytes)): return x
    # namedtuple constructors take the fields positionally, not a single iterable.
    if isinstance(x, tuple) and hasattr(x, '_fields'):
        return type(x)(*(to_device(o, device) for o in x))
    if isinstance(x, Iterable): return type(x)(to_device(o, device) for o in x)
    # Non-iterable leaf (int, float, None, ...): nothing to move.
    return x

In [None]:
#| export 
class DataLoaders:
    "Hold a training and a validation data loader; positional args beyond the first two are ignored."
    def __init__(self, *dls):
        # Slicing to two and unpacking raises ValueError if fewer than two loaders are given.
        train_dl, valid_dl = dls[:2]
        self.train = train_dl
        self.valid = valid_dl

In [None]:
#| export
def pc_to_o3d(pc): # point cloud as np.array or torch.tensor
    "Turn a point cloud, given as an (N, 3) np.array or torch.tensor, into an [Open3D.geometry.PointCloud](http://www.open3d.org/docs/0.16.0/python_api/open3d.geometry.PointCloud.html)"
    # Vector3dVector takes a numpy array, so tensors (possibly on GPU/MPS or
    # requiring grad) are detached and copied to host memory first.
    if isinstance(pc, torch.Tensor): pc = pc.detach().cpu().numpy()
    # Vector3dVector is documented to take a float64 (n, 3) array; convert up front.
    pc = np.ascontiguousarray(pc, dtype=np.float64)
    return o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pc))

In [None]:
#| export 
def quick_vis(pc): # point cloud as np.array or torch.tensor
    "Open an Open3D window showing `pc`: an (N, 3) or (3, N) cloud, optionally with a leading batch dim of size 1."
    if isinstance(pc, torch.Tensor): pc = pc.detach().cpu().numpy()
    pc = np.asarray(pc)
    # `squeeze` is not in-place: the result must be reassigned, otherwise a
    # (1, N, 3) input reaches Open3D still 3-D.
    if pc.ndim == 3 and pc.shape[0] == 1: pc = pc.squeeze(0)
    # Accept channel-first (3, N) clouds by transposing them to (N, 3).
    if pc.shape[-1] != 3: pc = pc.T

    o3d.visualization.draw_geometries([pc_to_o3d(pc)])

In [None]:
#| hide
#| eval: false
# Manual smoke test: render 100 random 3-D points with quick_vis, which opens an
# Open3D window (presumably why this cell is marked `eval: false` above).
pc = torch.randn(100, 3)
quick_vis(pc)

# Loss Functions

In [None]:
#| export 
def cal_loss(pred, gold, smoothing=True, eps=0.2):
    ''' Calculate cross entropy loss, apply label smoothing if needed.

    pred: unnormalized class scores of shape (N, C); classes live on dim 1.
    gold: integer class indices, flattened to shape (N,).
    smoothing: if True, use label smoothing; otherwise plain `F.cross_entropy`.
    eps: smoothing amount (default 0.2, the previously hard-coded value). The
        true class gets target probability 1 - eps and the remaining eps is
        spread evenly over the other C - 1 classes.
    Returns the mean loss over the batch as a scalar tensor.
    '''

    gold = gold.contiguous().view(-1)

    if smoothing:
        n_class = pred.size(1)

        # One-hot targets built by scattering 1s at each row's gold index.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        # Cross entropy against the smoothed distribution, averaged over rows.
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')

    return loss