In [1]:
import os
import sys

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import SimpleITK as sitk
import nrrd
import vtk

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

import pytorch_lightning as pl
import pickle
import monai 
import glob 
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/')
sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/dl')
import dl.transforms.ultrasound_transforms as ultrasound_transforms
import dl.loaders.mr_us_dataset as mr_us_dataset
import dl.nets.us_simulation_jit as us_simulation_jit
import dl.nets.us_simu as us_simu

import importlib

from dl.nets.layers import TimeDistributed


In [2]:
# Root of the shared FAMLI network mount; the data paths in later cells are joined onto it.
mount_point: str = '/mnt/famli_netapp_shared/C1_ML_Analysis'
               
from shapeaxi import utils
from shapeaxi.saxi_dataset import SaxiDataset
from torch.utils.data import DataLoader



In [3]:

df = pd.read_csv(os.path.join(mount_point, 'source/diffusion-models/blender/studies/placenta/FAM-025-0754-2.csv'))

ds = SaxiDataset(df, mount_point=mount_point, surf_column='surf', transform=None, CN=True)
dl = DataLoader(ds, batch_size=150, num_workers=2, collate_fn=utils.pad_verts_faces)
            


In [4]:
# Probe poses: probe_params.csv lists one pickle file per probe in its 'probe_param_fn' column.
probe_params_df = pd.read_csv(os.path.join(mount_point, 'source/blender/simulated_data_export/probe_params.csv'))

probe_directions = []  # transposed float32 'probe_direction' tensors, one per CSV row
probe_origins = []     # float32 'probe_origin' tensors, one per CSV row

for probe_param_fn in probe_params_df['probe_param_fn']:
    # Use a context manager so each file handle is closed right away; the previous
    # pickle.load(open(...)) left closing to the garbage collector.
    # NOTE(review): pickle.load can execute arbitrary code — only load pickles from a trusted location.
    with open(os.path.join(mount_point, probe_param_fn), 'rb') as probe_file:
        probe_params = pickle.load(probe_file)

    probe_direction = torch.tensor(probe_params['probe_direction'], dtype=torch.float32)
    probe_origin = torch.tensor(probe_params['probe_origin'], dtype=torch.float32)

    # Stored transposed; the rendering loop later passes this as the rotation argument R of render().
    probe_directions.append(probe_direction.T)
    probe_origins.append(probe_origin)

In [5]:
from pytorch3d.structures import (
    Meshes,
    Pointclouds,)

from pytorch3d.renderer import (
        FoVPerspectiveCameras, PerspectiveCameras, look_at_rotation, 
        RasterizationSettings, MeshRenderer, MeshRasterizer, MeshRendererWithFragments,
        HardPhongShader, AmbientLights, TexturesVertex
)
from pytorch3d.ops import (sample_points_from_meshes,
                           knn_points, 
                           knn_gather)

from pytorch3d.loss import (
    chamfer_distance,
    point_mesh_edge_distance, 
    point_mesh_face_distance
)


# Target device for the renderer and everything it draws.
device = torch.device("cuda:0")

# Default perspective camera; render() supplies the per-call pose through R/T overrides.
cameras = FoVPerspectiveCameras()
# Ambient-only lighting: shading does not depend on light direction.
lights = AmbientLights()

# 128x128 output, hard (unblurred) rasterization, nearest face only per pixel.
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=0,
    faces_per_pixel=1,
    max_faces_per_bin=200000,
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

renderer = MeshRenderer(
    rasterizer=rasterizer,
    shader=HardPhongShader(cameras=cameras, lights=lights),
).to(device)

In [6]:
def render(V, F, CN, camera_position, R):
    """Render a batch of meshes from a single camera pose with the module-level ``renderer``.

    The meshes are rasterized exactly once with the requested pose, and those same
    fragments feed the shader (RGB) and the returned depth / face-index buffers.
    Previously the depth and face-index buffers came from a second rasterization with
    the default camera, so they did not match the RGB.

    Args:
        V: vertex positions, passed as ``Meshes(verts=...)``.
        F: face indices, passed as ``Meshes(faces=...)``. (Inside this function this
            shadows the module-level ``torch.nn.functional as F`` alias.)
        CN: per-vertex features, cast to float32 and used as ``TexturesVertex`` colors.
        camera_position: camera centres, indexed as ``[:, :, None]`` so it must be 2-D;
            presumably (N, 3) with N == 1 or the mesh batch size — TODO confirm.
        R: camera rotation matrices used in ``bmm``, presumably (N, 3, 3).

    Returns:
        X: tensor of shape (B, 1, 3 + K, H, W). It is the shaded RGB channels
            concatenated with the z-buffer on the channel axis, plus a singleton view
            axis at dim 1. K is ``faces_per_pixel`` of the rasterizer.
        PF: ``pix_to_face`` with the same singleton view axis, shape (B, 1, K, H, W).
    """
    textures = TexturesVertex(verts_features=CN.to(torch.float32))
    meshes = Meshes(verts=V, faces=F, textures=textures)

    # Translation that maps camera_position to the view-space origin for rotation R.
    T = -torch.bmm(R.transpose(1, 2), camera_position[:, :, None])[:, :, 0]

    # Rasterize once with the requested pose, then shade from those fragments. This
    # mirrors MeshRenderer.forward. Calling renderer.rasterizer without R/T would fall
    # back to the default camera and produce depth/face buffers for a different view.
    fragments = renderer.rasterizer(meshes, R=R, T=T)
    images = renderer.shader(fragments, meshes, R=R, T=T)

    images = torch.cat([images[..., 0:3], fragments.zbuf], dim=-1)  # (B, H, W, 3 + K)
    X = images.permute(0, 3, 1, 2).unsqueeze(1)
    PF = fragments.pix_to_face.permute(0, 3, 1, 2).unsqueeze(1)

    return X, PF

In [9]:
# The probe pose is identical for every batch, so move it to the device once, outside the loop.
probe_origin_batch = probe_origins[0].unsqueeze(0).to(device)
probe_rotation_batch = probe_directions[0].unsqueeze(0).to(device)

# Loop variables deliberately avoid the name `F`. Binding it here would overwrite the
# module-level `torch.nn.functional as F` alias with the last batch's faces tensor.
for verts, faces, vert_features in dl:
    X, PF = render(verts.to(device), faces.to(device), vert_features.to(device), probe_origin_batch, probe_rotation_batch)
    print(X.shape, PF.shape)

torch.Size([150, 1, 4, 128, 128]) torch.Size([150, 1, 1, 128, 128])
torch.Size([150, 1, 4, 128, 128]) torch.Size([150, 1, 1, 128, 128])
torch.Size([44, 1, 4, 128, 128]) torch.Size([44, 1, 1, 128, 128])


In [None]:
# Shape of the camera-position argument passed to render(); no GPU transfer is needed to read it.
probe_origins[0].unsqueeze(0).shape