In [1]:
import yaml
import os

In [2]:
import torch
import matplotlib.pyplot as plt

In [3]:
from data.data_generation import generate_dataset
from helper_functions_sampling import get_all_points, PointPooling3D

In [4]:
from train_script import PoseuilleFlowAnalytic
from torch.utils.data import DataLoader
import torch
from networks_models import ConvUNetBis, SmallLinear
import numpy as np
from tqdm import tqdm
from loss_functions import PDELoss, ReconLoss
import pdb

In [None]:
def ravel_multi_index(coords, shape):
    r"""Convert a tensor of coordinate vectors into a tensor of flat indices.

    This is a `torch` implementation of `numpy.ravel_multi_index` for a
    C-order (row-major) layout. No bounds checking is performed.

    Args:
        coords: A tensor of coordinate vectors, shape (*, D).
        shape: The source shape, a sequence of D ints (tuple, list, or
            ``torch.Size``).

    Returns:
        The raveled indices, shape (*,), with the dtype and device of
        ``coords``.
    """
    # Row-major strides: stride[d] = prod(shape[d+1:]). The appended 1 makes the
    # last axis contiguous. tuple() accepts lists/torch.Size, and building on
    # coords' device avoids a CPU/GPU mismatch in the product below.
    dims = torch.tensor(tuple(shape) + (1,), dtype=coords.dtype, device=coords.device)
    strides = dims[1:].flip(0).cumprod(dim=0).flip(0)
    return (coords * strides).sum(dim=-1)

In [5]:
# Load the run configuration from YAML. Keys read by later cells include the
# loss-weight settings, PDE constants, nvox and model_name.
config_path = "training_config.yml"
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

In [6]:
# Global tensor settings: float32 for all data, and the default CUDA device
# when torch can see one, otherwise the CPU.
dtype = torch.float32
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [7]:
# Per-region loss weights. In each region the PDE and reconstruction weights
# sum to 1; the config stores only one of the pair per region.
boundary_recon = config["boundary_recon"]
flow_pde = config["flow_pde"]
background_recon = config["background_recon"]

boundary_pde = 1 - boundary_recon      # boundary: recon weight is configured
flow_recon = 1 - flow_pde              # flow: PDE weight is configured
background_pde = 1 - background_recon  # background: recon weight is configured

In [8]:
# Loss objects: a reconstruction loss and a PDE loss. The PDE loss takes the
# physical constants rho, mu and gx/gy/gz from the config (presumably density,
# viscosity and body-force components — confirm in loss_functions).
recon_loss_function = ReconLoss()
pde_physics = {name: config[name] for name in ("rho", "mu", "gx", "gy", "gz")}
pde_loss_function = PDELoss(**pde_physics)

In [9]:
nvox = config["nvox"]  # volume edge length passed to the data generator
# NOTE(review): nsamples and batch_size are hard-coded overrides of
# config["nsamples"] / config["batch_size"] to keep this evaluation small;
# restore the config values for a full run.
nsamples, batch_size = 8, 4
samples, segmentation_maps = generate_dataset(nsamples=nsamples, nvox=nvox)
val_dataset = PoseuilleFlowAnalytic(samples, segmentation_maps)
# shuffle=False keeps sample order, so predictions line up with `samples`.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

100%|██████████| 8/8 [00:00<00:00, 2273.34it/s]

Normalizaing samples.
Creating segmentation maps.





In [10]:
# UNet feature extractor on the target device, switched to inference mode.
# NOTE(review): input_size is hard-coded to 64 rather than derived from nvox —
# confirm the two must agree.
model = ConvUNetBis(input_size=64, channels_in=samples.shape[1], channels_init=4).to(device=device)
model.eval()
n_model_params = sum(param.numel() for param in model.parameters())
print("There are ", n_model_params, " parameters to train.")

There are  15887  parameters to train.


In [11]:
# Per-point head. Each input row is the point location (x, y, z), the
# interpolated segmentation value, and the UNet features: channels_out + 4.
smallLinear = SmallLinear(num_features=model.channels_out + 4, num_outputs=4)
# Put the head on the same device as the UNet. The feature vectors fed to it
# are built from the UNet output, so a CPU-resident head fails on GPU.
smallLinear = smallLinear.to(device=device)
smallLinear.eval()
print("There are ", sum(p.numel() for p in smallLinear.parameters()), " parameters to train.")

There are  579  parameters to train.


In [12]:
# Restore the trained weights of both networks from one checkpoint file.
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU also loads in a CPU-only kernel.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model_name = config["model_name"]
checkpoint_path = os.path.join("trainings", "saved_models", model_name)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
smallLinear.load_state_dict(checkpoint["linear_model_state_dict"])

<All keys matched successfully>

In [13]:
# Point sampler: pp3d is called below as pp3d(latent_vectors, map_sample,
# flow_sample, pts). It returns per-point features, segmentation values and
# flow values, interpolated at the normalized coordinates `pts`.
POINT_INTERPOLATION = "trilinear"
pp3d = PointPooling3D(interpolation=POINT_INTERPOLATION)

In [14]:
# Running sums of the total, reconstruction and PDE losses.
epoch_total_loss = epoch_recon_loss = epoch_pde_loss = 0

In [49]:
# Dense output buffer: one 4-channel volume per sample with the samples'
# spatial size, filled batch by batch in the inference loop. It is allocated
# directly as float32 instead of going through a float64 numpy array and a cast.
predictions = torch.zeros((len(samples), 4, *samples.shape[2:]), dtype=torch.float32)

In [90]:
# Inference over the validation set. For each batch: sample UNet features at
# every point from get_all_points, run the per-point head, and scatter the
# per-point outputs back into the dense `predictions` buffer.
# NOTE: variables from the LAST batch (inputs, outputs_linear, pts_maps,
# pts_flows, pts_ints_l) remain in the namespace and feed the loss cells below.
spatial_shape = tuple(predictions.shape[2:])  # spatial size of every volume
num_voxels = int(np.prod(spatial_shape))
num_outputs = predictions.shape[1]            # channels predicted per point
zoom = 1  # dividing the normalized points by a value > 1 zooms into the volume
sample_offset = 0                             # index in `predictions` of the batch's first sample

for batch_idx, (flow_sample, map_sample) in enumerate(val_dataloader):
    flow_sample = flow_sample.to(device=device, dtype=dtype)  # move to device, e.g. GPU
    map_sample = map_sample.to(device=device, dtype=dtype)
    # The last batch may be smaller than batch_size, so size everything from the data.
    current_batch = flow_sample.shape[0]

    # =====================forward======================
    # compute latent vectors
    latent_vectors = model(flow_sample)

    # Integer point coordinates per sample, with two singleton dims inserted:
    # batch x 1 x 1 x num_points x coordinates_size.
    # get_all_points receives a CPU tensor (a no-op when device is CPU).
    pts_ints = torch.Tensor(
        np.array([get_all_points(map_one.squeeze().cpu()) for map_one in map_sample])
    ).unsqueeze(1).unsqueeze(1)
    # Normalize the points to [-1, 1] (required by grid_sample) and place them on
    # the features' device. pts_ints stays on CPU for the indexing below.
    pts = ((2 * (pts_ints / (nvox - 1)) - 1) / zoom).to(device=device)

    # features, segmentation maps, flow values at the sampled points
    pts_vectors, pts_maps, pts_flows = pp3d(latent_vectors, map_sample, flow_sample, pts)
    pts_locations = pts.squeeze()
    pts_locations.requires_grad = True  # needed for the PDE loss

    # create feature vectors for each sampled point
    feature_vector = torch.cat([pts_locations, pts_maps, pts_vectors], dim=-1)
    feature_vector = feature_vector.reshape((-1, model.channels_out + 4))  # x, y, z, seg inter, features

    # Split the input columns so the PDE loss can differentiate w.r.t. each one.
    # `col` is used here so it is not confused with the batch index.
    inputs = [feature_vector[..., col:col + 1] for col in range(feature_vector.shape[-1])]
    x_ = torch.cat(inputs, dim=-1)

    # forward through linear model
    outputs_linear = smallLinear(x_)

    # Points as integers for indexing: (current_batch, num_points, coordinates_size).
    # Explicit indexing (not squeeze) keeps the batch dim when current_batch == 1.
    pts_ints_l = pts_ints[:, 0, 0].long()
    # Detached CPU copy of the outputs, grouped per sample. Detaching keeps the
    # autograd graph out of `predictions`; the reshape is hoisted out of the loop.
    outputs_per_sample = outputs_linear.detach().cpu().reshape(current_batch, -1, num_outputs)

    for k in tqdm(range(current_batch)):
        flat_one_points = ravel_multi_index(pts_ints_l[k], spatial_shape)
        volume = torch.zeros((num_outputs, num_voxels), dtype=outputs_per_sample.dtype)
        volume[:, flat_one_points] = outputs_per_sample[k].T
        predictions[sample_offset + k] = volume.reshape((num_outputs, *spatial_shape))
    sample_offset += current_batch

100%|██████████| 4/4 [00:00<00:00, 96.78it/s]
100%|██████████| 4/4 [00:00<00:00, 235.88it/s]


In [91]:
# Visual check on a 2D slice (index 1 along the last spatial axis) of channel 0.
# The top row shows the first few samples and the bottom row the matching
# predictions, all on one shared colour scale.
# NOTE(review): channel 0 of `samples` is compared with channel 0 of the
# predictions — confirm they represent the same quantity.
predictions_np = predictions.detach().cpu().numpy()
n_show = min(5, len(samples))
rows = {
    "Sample": [np.asarray(samples[j, 0, ..., 1]) for j in range(n_show)],
    "Prediction": [predictions_np[j, 0, ..., 1] for j in range(n_show)],
}
all_slices = np.stack([image for slices in rows.values() for image in slices])
vmin, vmax = all_slices.min(), all_slices.max()

fig, axes = plt.subplots(2, n_show, figsize=(3 * n_show, 6), squeeze=False)
for row_axes, (label, slices) in zip(axes, rows.items()):
    for ax, image in zip(row_axes, slices):
        ax.imshow(image, vmin=vmin, vmax=vmax)
        ax.set_title(label)
        ax.axis("off")
fig.tight_layout()
plt.show()

In [70]:
# =====================losses======================
# Per-point weights for the two loss terms (column 0: PDE, column 1:
# reconstruction), chosen from each point's interpolated segmentation value.
# Uses `pts_maps` from the LAST batch of the inference loop.
num_loss_terms = 2
seg_interpolations_rand = pts_maps.reshape((-1, 1))
# 1-D values, so the masks stay 1-D even when there is a single point.
seg_values = seg_interpolations_rand[:, 0]
weights = torch.ones(
    (len(seg_interpolations_rand), num_loss_terms),
    dtype=torch.float,
    device=seg_interpolations_rand.device,  # masks and weights must share a device
)
# Later assignments override earlier ones. Order: boundary (>= 0.75), then
# flow (>= 1.25), then background (< 0.5).
# NOTE(review): points with 0.5 <= value < 0.75 keep weight 1 for both terms —
# confirm that gap is intentional.
# boundary weights
weights[seg_values >= 0.75] = torch.tensor([boundary_pde, boundary_recon], dtype=weights.dtype, device=weights.device)
# flow weights
weights[seg_values >= 1.25] = torch.tensor([flow_pde, flow_recon], dtype=weights.dtype, device=weights.device)
# background weights
weights[seg_values < 0.5] = torch.tensor([background_pde, background_recon], dtype=weights.dtype, device=weights.device)

In [71]:
# Per-point PDE and reconstruction losses, scaled by the region weights from
# the previous cell. The scalar loss is the mean of their sum.
# (Inputs/outputs come from the last batch of the inference loop.)
pde_terms = pde_loss_function.compute_loss(inputs, outputs_linear)
pde_loss = weights[:, 0] * pde_terms
recon_terms = recon_loss_function.compute_loss(pts_flows, outputs_linear)
recon_loss = weights[:, 1] * recon_terms
loss = (pde_loss + recon_loss).mean()

In [72]:
# Add this batch's losses to the running sums as Python floats.
# NOTE: not idempotent — re-running this cell counts the batch again.
batch_total = loss.item()
batch_recon = recon_loss.mean().item()
batch_pde = pde_loss.mean().item()
epoch_total_loss = epoch_total_loss + batch_total
epoch_recon_loss = epoch_recon_loss + batch_recon
epoch_pde_loss = epoch_pde_loss + batch_pde

In [73]:
# Report the accumulated losses.
print(
    "Total loss: ", epoch_total_loss,
    "Reconstruction loss: ", epoch_recon_loss,
    "PDE loss: ", epoch_pde_loss,
)
# To save the dense predictions for offline inspection:
# np.save("predictions.npy", predictions.detach().cpu().numpy())

Total loss:  0.0017090195324271917 Reconstruction loss:  0.0017056771321222186 PDE loss:  3.3423875720473006e-06


In [28]:
# # vmap the prediction
# from functorch import vmap
# def get_volume_predictions(pts_one_long, outputs_linear_one):
#     flat_one_points = ravel_multi_index(pts_one_long, (64, 64, 64))
#     a = torch.zeros((4, 64*64*64), dtype=outputs_linear_one.dtype, requires_grad=False)
#     a.index_put(flat_one_points.chunk(len(flat_one_points)), outputs_linear_one.T)
#     return a.reshape((4, 64, 64, 64))
# batched_get_volume_predictions = vmap(get_volume_predictions)
# d = batched_get_volume_predictions(pts_ints_l, outputs_linear.reshape(batch_size, 64**3, 4))

In [None]:
# # check reshaping
# outputs_linear.shape
# outputs_linear.chunk(4, dim=0)[0]
# outputs_linear.reshape(batch_size, 64**3, 4)[0]