In [None]:
import torch
from torch.utils.data import DataLoader
import pytorch3d
import sys

import pandas as pd 

from shapeaxi import saxi_nets
from shapeaxi import saxi_dataset
from shapeaxi import saxi_dataset
from shapeaxi import saxi_transforms

import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from pytorch3d.structures import Meshes 
from pytorch3d.renderer import (
        FoVPerspectiveCameras, look_at_view_transform, look_at_rotation, 
        RasterizationSettings, MeshRenderer, MeshRasterizer, MeshRendererWithFragments, BlendParams,
        SoftSilhouetteShader, HardPhongShader, SoftPhongShader, AmbientLights, PointLights, TexturesUV, TexturesVertex, TexturesAtlas
)

: 

In [None]:
# Alternative ViT-based encoder configuration, kept for reference:
# model = saxi_nets.SaxiRingClassification(subdivision_level=3, out_classes=3, radius=1.1, hidden_dim=512, out_size=128, dropout_lvl=0.2, image_size=224, base_encoder="ViT", base_encoder_params="in_channels=4, img_size=(224,224), patch_size=(16,16),spatial_dims=2")

# In the cells shown here only the model's `renderer`, `ico_verts` and `device`
# are used (see `render` below); no checkpoint is loaded, so the classifier
# weights are untrained.
ring_config = dict(
    subdivision_level=2,
    out_classes=3,
    radius=1.05,
    hidden_dim=512,
    out_size=128,
    dropout_lvl=0.2,
    image_size=224,
    base_encoder="resnet18",
    base_encoder_params="pretrained=False,spatial_dims=2,n_input_channels=4,num_classes=512",
)
model = saxi_nets.SaxiRingClassification(**ring_config).to('cuda').eval()

In [None]:
# Evaluation CSV (file name indicates the test split of the airway sample set).
# NOTE(review): the file name says "4classes" while the model is built with
# out_classes=3; irrelevant for rendering, but verify before using predictions.
TEST_CSV = '/CMF/data/lumargot/DCBIA/Airway_Obst_Classif_Sample/airway_4classes_test.csv'
# Scale factor handed to EvalTransform; value kept exactly as in the original cell.
EVAL_SCALE_FACTOR = 0.02764634543775486

test_df = pd.read_csv(TEST_CSV)
ds = saxi_dataset.SaxiDataset(
    test_df,
    transform=saxi_transforms.EvalTransform(scale_factor=EVAL_SCALE_FACTOR),
)

In [None]:
def create_figure(image_data1, image_data2, titles=('Image 1', 'Image 2')):
    """Build an animated, side-by-side heatmap figure for two image stacks.

    Frame ``k`` shows ``image_data1[k]`` (left) and ``image_data2[k]`` (right)
    as ``go.Heatmap`` traces. Both panels share one "jet" colour axis whose
    range is the global min/max over both stacks, so colours are comparable
    across panels and frames. Play/Pause buttons and a slider step through
    the frames, and each panel's y axis is locked to its x axis so image
    pixels are drawn square.

    Args:
        image_data1: stack for the left panel; must support ``.shape``,
            ``.min()``, ``.max()`` and integer indexing (e.g. a NumPy array
            shaped (n_frames, H, W)).
        image_data2: stack for the right panel, with the same number of
            frames as ``image_data1``.
        titles: subplot titles for the (left, right) panels.

    Returns:
        plotly.graph_objects.Figure: the animated figure (not shown).

    Raises:
        ValueError: if the two stacks have different numbers of frames.
    """
    n_frames = image_data1.shape[0]
    if image_data2.shape[0] != n_frames:
        raise ValueError(
            "image_data1 and image_data2 must have the same number of frames, "
            f"got {n_frames} and {image_data2.shape[0]}"
        )

    fig = make_subplots(rows=1, cols=2, subplot_titles=titles)

    # Initial (frame 0) traces; both point at the shared coloraxis.
    fig.add_trace(go.Heatmap(z=image_data1[0], coloraxis="coloraxis"), row=1, col=1)
    fig.add_trace(go.Heatmap(z=image_data2[0], coloraxis="coloraxis"), row=1, col=2)

    # One animation frame per index, explicitly targeting trace 0 (left) and
    # trace 1 (right).
    fig.frames = [
        go.Frame(
            data=[
                go.Heatmap(z=image_data1[k], coloraxis="coloraxis"),
                go.Heatmap(z=image_data2[k], coloraxis="coloraxis"),
            ],
            traces=[0, 1],
            name=str(k),
        )
        for k in range(n_frames)
    ]

    # Global colour range so a given value maps to the same colour everywhere.
    vmin = min(image_data1.min(), image_data2.min())
    vmax = max(image_data1.max(), image_data2.max())

    fig.update_layout(
        autosize=False,
        width=1200,
        height=600,
        coloraxis={"colorscale": "jet", "cmin": vmin, "cmax": vmax},
        updatemenus=[{
            "buttons": [
                {
                    "args": [None, {"frame": {"duration": 500, "redraw": True},
                                    "fromcurrent": True, "mode": "immediate"}],
                    "label": "Play",
                    "method": "animate"
                },
                {
                    # [None] (a list) is plotly's idiom for "stop the animation".
                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                      "mode": "immediate"}],
                    "label": "Pause",
                    "method": "animate"
                }
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }],
        sliders=[{
            "steps": [
                {
                    "args": [[str(k)], {"frame": {"duration": 300, "redraw": True},
                                        "mode": "immediate"}],
                    "label": str(k),
                    "method": "animate"
                } for k in range(n_frames)
            ],
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Frame:",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"}
        }]
    )

    # Keep each panel's aspect ratio fixed: tie the y scale to that panel's own
    # x axis ("x" for subplot 1, "x2" for subplot 2) so pixels stay square
    # independently of the 1200x600 canvas.
    # NOTE(review): go.Heatmap draws row 0 at the bottom; if these are rendered
    # images with row 0 at the top they appear vertically flipped - consider
    # adding autorange="reversed" to the y axes.
    fig.update_yaxes(scaleanchor="x", scaleratio=1, row=1, col=1)
    fig.update_yaxes(scaleanchor="x2", scaleratio=1, row=1, col=2)

    return fig

In [None]:

def render(self, V, F, CN):
    """Render a mesh from every camera position in ``self.ico_verts``.

    ``self`` is used like a model object: it must expose ``renderer`` (called
    as ``renderer(meshes_world=..., R=..., T=...)`` and having a ``rasterizer``
    attribute), ``ico_verts`` (iterable of 3-vectors used as camera positions)
    and ``device``.

    Per view, the output image stacks (with faces_per_pixel == 1):
      * channels 0-2: RGB from ``self.renderer``;
      * channel 3: squared distance from the camera to the first vertex of
        the face seen at each pixel (0 on background);
      * channel 4: squared rasterizer z-buffer value (0 on background).

    Args:
        V: vertex tensor passed to ``Meshes`` (presumably (N, n_verts, 3) -
            TODO confirm).
        F: face index tensor passed to ``Meshes`` (presumably (N, n_faces, 3)).
        CN: per-vertex features for ``TexturesVertex``.

    Returns:
        X: tensor (N, n_views, C, H, W) of the stacked channels above.
        PF: ``pix_to_face`` per view, shaped (N, n_views, K, H, W).
    """
    textures = TexturesVertex(verts_features=CN)
    meshes = Meshes(verts=V, faces=F, textures=textures)
    # pix_to_face indexes the packed (batch-concatenated) face list, so look
    # faces and vertices up in packed form; this is also correct for N > 1.
    verts_packed = meshes.verts_packed()
    faces_packed = meshes.faces_packed()

    X = []
    PF = []
    for camera_position in self.ico_verts:
        camera_position = camera_position.unsqueeze(0).to(self.device)  # (1, 3)
        R = look_at_rotation(camera_position, device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), camera_position[:, :, None])[:, :, 0]  # (1, 3)

        images = self.renderer(meshes_world=meshes.clone(), R=R, T=T)
        # Rasterize with the SAME camera as the shaded image; without R/T the
        # rasterizer falls back to its default cameras and pix_to_face/zbuf
        # would not line up with `images`.
        fragments = self.renderer.rasterizer(meshes.clone(), R=R, T=T)

        pix_to_face = fragments.pix_to_face  # (N, H, W, K), -1 where no face
        zbuf = fragments.zbuf                # (N, H, W, K), -1 where no face
        foreground = pix_to_face >= 0

        # First vertex of the face hit at each pixel. Background (-1) is
        # clamped to a valid index and then zeroed by the foreground mask.
        hit_verts = verts_packed[faces_packed[pix_to_face.clamp(min=0)][..., 0]]  # (N, H, W, K, 3)
        z_buf_n = torch.square(hit_verts - camera_position).sum(dim=-1) * foreground
        zbuf = zbuf * foreground

        images = torch.cat([images[:, :, :, 0:3], z_buf_n, torch.square(zbuf)], dim=-1)

        X.append(images.permute(0, 3, 1, 2))
        PF.append(pix_to_face.permute(0, 3, 1, 2))

    X = torch.stack(X, dim=1)
    PF = torch.stack(PF, dim=1)

    return X, PF


# Render one dataset sample from every camera in model.ico_verts and compare the
# two depth-like channels produced by `render` side by side.
SAMPLE_IDX = 11

V, F, CN = ds[SAMPLE_IDX]

# Visualization only: no autograd graph is needed, which saves GPU memory.
# `render` already places cameras on model.device, so put the mesh there too.
with torch.no_grad():
    X, PF = render(
        model,
        V.unsqueeze(0).to(model.device),
        F.unsqueeze(0).to(model.device),
        CN.unsqueeze(0).to(model.device),
    )

# X is (batch, n_views, channels, H, W). Channel 3 is the squared
# camera-to-vertex distance, channel 4 the squared z-buffer. Integer indexing
# keeps the view axis even with a single view (unlike .squeeze()).
image_data_zbuf_c = X[0, :, 3].cpu().numpy()
image_data_zbuf = X[0, :, 4].cpu().numpy()

fig = create_figure(image_data_zbuf_c, image_data_zbuf)

fig.show()

In [None]:
import gc

# Collect unreachable Python objects first (reference cycles can keep CUDA
# tensors alive until the GC runs), then return PyTorch's cached-but-unused GPU
# memory to the driver. Tensors still bound to names (e.g. X, PF, model) are
# not released by this.
gc.collect()
torch.cuda.empty_cache()