In [1]:
import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence as pack_sequence, pad_packed_sequence as unpack_sequence
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torchvision.models as models
# io utils
from pytorch3d.io import load_obj

# datastructures
from pytorch3d.structures import Meshes

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation, 
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, SoftPhongShader, AmbientLights, PointLights, TexturesUV, TexturesVertex,
)

import vtk
import sys
sys.path.insert(0,'..')
import fly_by_features as fbf
from vtk.util.numpy_support import vtk_to_numpy
from vtk.util.numpy_support import numpy_to_vtk
import pandas as pd
from sklearn.model_selection import train_test_split
from skimage import img_as_ubyte

import monai
from monai.data import ITKReader, PILReader
from monai.transforms import (
    ToTensor, LoadImage, Lambda, AddChannel, RepeatChannel, ScaleIntensityRange, RandSpatialCrop,
    Resized, Compose
)
from monai.config import print_config
from monai.metrics import DiceMetric
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Imports for monai model
import logging
import tempfile
from glob import glob

from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from monai.data import ArrayDataset, create_test_image_2d, decollate_batch
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Activations,
    AddChannel,
    AsDiscrete,
    Compose,
    LoadImage,
    RandRotate90,
    RandSpatialCrop,
    ScaleIntensity,
    EnsureType,
)
from monai.visualize import plot_2d_or_3d_image
print("imports done")

imports done


In [2]:
# Set the cuda device 
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    
 

In [3]:
# Initialize a perspective camera.
cameras = FoVPerspectiveCameras(device=device)

# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=512, 
    blur_radius=0, 
    faces_per_pixel=1, 
)
# We can add a point light in front of the object. 

lights = PointLights(device=device)
rasterizer = MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    )
phong_renderer = MeshRenderer(
    rasterizer=rasterizer,
    shader=HardPhongShader(device=device, cameras=cameras)
)


class FlyByDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def set_env_params(self, params):
        self.params = params

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        
        surf = fbf.ReadSurf(df.iloc[idx]["surf"])
        surf = fbf.GetUnitSurf(surf)
        surf, _a, _v = fbf.RandomRotation(surf)

        surf = fbf.ComputeNormals(surf)

        color_normals = ToTensor(dtype=torch.float32, device=device)(vtk_to_numpy(fbf.GetColorArray(surf, "Normals"))/255.0)
        verts = ToTensor(dtype=torch.float32, device=device)(vtk_to_numpy(surf.GetPoints().GetData()))
        faces = ToTensor(dtype=torch.int64, device=device)(vtk_to_numpy(surf.GetPolys().GetData()).reshape(-1, 4)[:,1:])
        region_id = ToTensor(dtype=torch.int64, device=device)(vtk_to_numpy(surf.GetPointData().GetScalars("UniversalID")))
        region_id = torch.clamp(region_id, min=0)
        faces_pid0 = faces[:,0:1]
        region_id_faces = torch.take(region_id, faces_pid0)
        #print("shape region_id_faces: ", region_id_faces.shape)
        
        return verts, faces, region_id, region_id_faces, faces_pid0, color_normals,df.iloc[idx]["surf"]
        
def pad_verts_faces(batch):
    names = [n for v,f,rid,ridf,fpid0,cn,n in batch]
    verts = [v for v, f, rid, ridf, fpid0, cn,n in batch]
    faces = [f for v, f, rid, ridf, fpid0, cn,n in batch]
    region_ids = [rid for v, f, rid, ridf, fpid0, cn,n in batch]
    region_ids_faces = [ridf for v, f, rid, ridf, fpid0, cn,n in batch]
    faces_pid0s = [fpid0 for v, f, rid, ridf, fpid0, cn,n in batch]
    color_normals = [cn for v, f, rid, ridf, fpid0, cn,n in batch]
    
    pad_seq_verts = pad_sequence(verts, batch_first=True, padding_value=0.0)
    pad_seq_faces = pad_sequence(faces, batch_first=True, padding_value=-1)
    pad_seq_rid = pad_sequence(region_ids, batch_first=True, padding_value=0)
    pad_seq_faces_pid0s = pad_sequence(faces_pid0s, batch_first=True, padding_value=-1)
    pad_seq_cn = pad_sequence(color_normals, batch_first=True, padding_value=0.)
    l = [f.shape[0] for f in faces]
    
    return pad_seq_verts, pad_seq_faces, pad_seq_rid, torch.cat(region_ids_faces), pad_seq_faces_pid0s, pad_seq_cn, l, names
        
df = pd.read_csv("/NIRAL/work/leclercq/data/training_UID.csv")

# Split data between training and validation 
df_train, df_val = train_test_split(df, test_size=0.1)  

# Datasets 
train_data = FlyByDataset(df_train)
val_data = FlyByDataset(df_val)

# Dataloaders
batch_size = 10
num_classes = 34
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=pad_verts_faces)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, collate_fn=pad_verts_faces)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_trans = AsDiscrete(argmax=True, to_onehot=True, num_classes=num_classes)
post_label = AsDiscrete(to_onehot=True, num_classes=num_classes)
post_pred = AsDiscrete(argmax=True, to_onehot=True, num_classes=num_classes)


# create UNet, DiceLoss and Adam optimizer
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=4,   # images: torch.cuda.FloatTensor[4,224,224,4]
    out_channels=num_classes, # background + gum + 16 teeth
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)
loss_function = monai.losses.DiceCELoss(to_onehot_y=True,softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), 1e-5)
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()



In [None]:
nb_epoch = 1_000_000
dist_cam = 1.35


camera_position = ToTensor(dtype=torch.float32, device=device)([[0, 0, dist_cam]])
R = look_at_rotation(camera_position, device=device)  # (1, 3, 3)
T = -torch.bmm(R.transpose(1, 2), camera_position[:,:,None])[:, :, 0]   # (1, 3)

# Start training
val_interval = 2
best_metric = -1
best_metric_epoch  = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()


for epoch in range (nb_epoch):
    print("-" * 20)
    print(f"epoch {epoch + 1}/{nb_epoch}")
    model.train() # Switch to training mode
    epoch_loss = 0
    step = 0
    for batch, (V, F, Y, YF, F0, CN, FL,N) in enumerate(train_dataloader):
        step += 1
        textures = TexturesVertex(verts_features=CN)
        meshes = Meshes(verts=V, faces=F, textures=textures)
        images = phong_renderer(meshes_world=meshes.clone(), R=R, T=T)
        pix_to_face, zbuf, bary_coords, dists = phong_renderer.rasterizer(meshes.clone())
        y_p = torch.take(YF, pix_to_face)*(pix_to_face >= 0)
        images = images.permute(0,3,1,2)
        y_p = y_p.permute(0,3,1,2)
        #print("shape images : ",images.shape)
        #print("shape y_p : ",y_p.shape)
        inputs, labels = images.to(device), y_p.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs,labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = int(np.ceil(len(train_data) / train_dataloader.batch_size))
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    
    writer.add_scalar("train_mean_dice", epoch_loss, epoch + 1)


    # Validation
    if (epoch) % val_interval == 0: # every two epochs : validation
        model.eval()
        with torch.no_grad():
            val_images = None
            val_yp = None
            val_outputs = None
            for batch, (V, F, Y, YF, F0, CN, FL,N) in enumerate(val_dataloader):               

                textures = TexturesVertex(verts_features=CN)
                meshes = Meshes(verts=V, faces=F, textures=textures)
                val_images = phong_renderer(meshes_world=meshes.clone(), R=R, T=T)    
                pix_to_face, zbuf, bary_coords, dists = phong_renderer.rasterizer(meshes.clone()) 
                val_y_p = torch.take(YF, pix_to_face)*(pix_to_face >= 0)
                val_images, val_y_p = val_images.permute(0,3,1,2), val_y_p.permute(0,3,1,2)            
                val_images, val_labels = val_images.to(device), val_y_p.to(device)
                
                roi_size = (96, 96)
                sw_batch_size = batch_size
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)               
                
                
                val_labels_list = decollate_batch(val_labels)                
                val_labels_convert = [
                    post_label(val_label_tensor) for val_label_tensor in val_labels_list
                ]
                
                val_outputs_list = decollate_batch(val_outputs)
                val_outputs_convert = [
                    post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list
                ]
                
                dice_metric(y_pred=val_outputs_convert, y=val_labels_convert)
                print(val_outputs.shape)
                print(val_outputs_convert[0].shape)
                print(len(val_outputs_convert))
    
                
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            imgs_output = torch.argmax(val_outputs, dim=1).detach().cpu()
            imgs_output = imgs_output.unsqueeze(1)
            print(imgs_output.shape)
            val_rgb = torch.cat((val_labels,val_labels,val_labels),dim=1)
            out_rgb = torch.cat((imgs_output,imgs_output,imgs_output),dim=1)
            print(val_rgb.shape)
            print(out_rgb.shape)
            writer.add_images("labels",val_rgb,epoch)
            writer.add_images("output", out_rgb,epoch)
            
            
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()
            
        

--------------------
epoch 1/1000000
1/6, train_loss: 4.8790
2/6, train_loss: 4.8727
3/6, train_loss: 4.8660
4/6, train_loss: 4.8532
5/6, train_loss: 4.8589
6/6, train_loss: 4.8462
epoch 1 average loss: 4.8627
torch.Size([7, 34, 512, 512])
torch.Size([34, 512, 512])
7
saved new best metric model
current epoch: 1 current mean dice: 0.0182 best mean dice: 0.0182 at epoch 1
torch.Size([7, 1, 512, 512])
torch.Size([7, 3, 512, 512])
torch.Size([7, 3, 512, 512])
--------------------
epoch 2/1000000
1/6, train_loss: 4.8381
2/6, train_loss: 4.8353
3/6, train_loss: 4.8224
4/6, train_loss: 4.8143
5/6, train_loss: 4.8187
6/6, train_loss: 4.8151
epoch 2 average loss: 4.8240
--------------------
epoch 3/1000000
1/6, train_loss: 4.8077
2/6, train_loss: 4.8074
3/6, train_loss: 4.7909
4/6, train_loss: 4.7799
5/6, train_loss: 4.7784
6/6, train_loss: 4.7897
epoch 3 average loss: 4.7923
torch.Size([7, 34, 512, 512])
torch.Size([34, 512, 512])
7
saved new best metric model
current epoch: 3 current mean di

In [None]:
"""
print("shape V : ",V.shape)
print("shape F : ",F.shape)
print("shape Y : ",Y.shape)
print("shape YF : ",YF.shape)

print("shape F0 : ",F0.shape)
print("shape F0[0] : ",F0[0].shape)
"""
print(images.type())
print("images shape: ",images.shape)
print("y_p shape: ", y_p.shape)
print(len(train_dataloader))
print(len(train_dataloader))
print(images.type())
print(y_p.type())

In [None]:


images = images.permute(0,2,3,1)
y_p = y_p.permute(0,2,3,1)

fig = go.Figure(make_subplots(rows=1, cols=2, column_widths=[0.5, 0.5], specs=[[{}, {}]]))
fig.add_trace(go.Image(z=(images[2][...,0:3]*255).cpu().numpy()), row=1, col=1)

labelmap = np.flip((y_p[2][...,0]).cpu().numpy(), axis=0)
fig.add_trace(go.Heatmap(z=labelmap), row=1, col=2)
fig.update_layout(
    width = 1400, height = 700,
    autosize = False )
fig

In [None]:
a  = 1.016516198
print(f"best metric: {a:.4f}")
print(metric_values)
a =2
print(val_outputs.shape)
print(val_outputs_convert[0].shape)

In [None]:
imgs_output = imgs_output.unsqueeze(1)