In [1]:
import torch

from IPython.display import display as d

In [2]:
from collections import namedtuple
from typing import Optional, List, Dict

In [3]:
from src.data.components.dataset import SpatialTemporalDomain
from abc import ABC

In [4]:
COORDS_LIMITS = {"x": [-3, 3]}
BATCH_SIZE = 128

In [5]:
dataset = SpatialTemporalDomain(coords_limits=COORDS_LIMITS, time_limits=[0, 100], n_samples=200_000)

In [6]:
_coords = namedtuple(typename="coords", field_names=["x", "y", "z"])
class Coords(_coords):
    x: torch.Tensor
    y: torch.Tensor
    z: torch.Tensor

    def __new__(
        cls, 
        x: Optional[torch.Tensor] = torch.tensor([]),
        y: Optional[torch.Tensor] = torch.tensor([]),
        z: Optional[torch.Tensor] = torch.tensor([])
    ):
        return super().__new__(cls, x=x, y=y, z=z)


_model_input = namedtuple(typename="STD", field_names=["coords", "time"])
class ModelBatch(_model_input):
    coords: Coords
    time: Optional[torch.Tensor]

    def __new__(
        cls,
        coords: Coords = Coords(),
        time: Optional[torch.Tensor] = torch.tensor([])
    ):
        return super().__new__(cls, coords=coords, time=time)
    

_model_output = namedtuple(typename="SpatialTemporalDomainSolution", field_names=["model_batch", "solution"])
class ModelOutput(_model_output):
    model_batch: ModelBatch
    solution: torch.Tensor

    def __new__(
        cls,
        model_batch: ModelBatch = ModelBatch(),
        solution: torch.Tensor = torch.tensor([])
    ):
        return super().__new__(cls, model_batch=model_batch, solution=solution)

class Collator(ABC):
    def __init__(self):
        pass

    def __call__(self, batch: List[Dict]) -> ModelBatch:
        pass


class BaseCollator(Collator):
    def __call__(self, batch: List[Dict]) -> ModelBatch:
        coords = dict()

        _item = batch[0]

        for key in _item["coords"].keys():
            coords[key] = torch.stack([item["coords"][key] for item in batch], dim=0)
            coords[key].requires_grad_(True)


        if _item["time"].__len__() > 0:
            time = torch.stack([item["time"] for item in batch], dim=0)
            time.requires_grad_(True)
        else:
            time = torch.tensor([])

        return ModelBatch(
            coords=Coords(**coords), 
            time=time
        )

In [7]:
train_loader = torch.utils.data.DataLoader(
    dataset=dataset, 
    batch_size=BATCH_SIZE, 
    collate_fn=BaseCollator()
)

In [8]:
sample = next(iter(train_loader))

In [9]:
from src.models.components.encoders.main_sequence_encoder import MainEncoderLayer
from src.models.components.linear_blocks.linear_down_up_block import LinearDownUpBlock
from src.models.components.pde_nn import SimplePINN

In [10]:
from src.utils.partial_checkpoint import PartialCheckpoint

In [11]:
CKPT_PATH = "logs/train/runs/2024-08-07_00-22-42/checkpoints/epoch_009.ckpt"
EMB_DIM = 32

In [12]:
partial_enc = PartialCheckpoint(
    target_model=MainEncoderLayer(
        embedding_dim=EMB_DIM, 
        dropout_inputs=0.0, 
        num_coords=1
    ),
    ckpt_path=CKPT_PATH,
    ckpt_prefix="net._orig_mod.layers.0"
)
d(partial_enc)

PartialCheckpoint(
  (partial_model): MainEncoderLayer(
    (dropout): Dropout(p=0.0, inplace=False)
    (branched_linear_block_xyz): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=1, out_features=32, bias=True)
        (1): Tanh()
      )
    )
    (linear_block_t): Sequential(
      (0): Linear(in_features=4, out_features=32, bias=True)
      (1): Tanh()
    )
    (out_linear_block): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (res_bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [13]:
partial_linear = PartialCheckpoint(
    target_model=LinearDownUpBlock(
        in_features=EMB_DIM, 
        out_features=1, 
        activation_type="tanh", 
        reduce=False, 
        down=True, 
        num_layers=2, 
        use_batch_norm=True
    ),
    ckpt_path=CKPT_PATH,
    ckpt_prefix="net._orig_mod.layers.1"
)
d(partial_linear)

PartialCheckpoint(
  (partial_model): LinearDownUpBlock(
    (dropout): Dropout(p=0.0, inplace=False)
    (linear_block): Sequential(
      (0): Sequential(
        (0): Linear(in_features=32, out_features=32, bias=True)
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Tanh()
      )
      (1): Sequential(
        (0): Linear(in_features=32, out_features=32, bias=True)
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Tanh()
      )
    )
    (out_block): Linear(in_features=32, out_features=1, bias=True)
    (cls_layers): Sequential(
      (0): Dropout(p=0.0, inplace=False)
      (1): Sequential(
        (0): Sequential(
          (0): Linear(in_features=32, out_features=32, bias=True)
          (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): Tanh()
        )
        (1): Sequential(
          (0): Linear(in_features=32, ou

In [14]:
class TracedSimplePINN(SimplePINN):
    def __init__(self, **kwargs):
        super(TracedSimplePINN, self).__init__(**kwargs)


    def forward(self, coords: Coords, time: torch.Tensor) -> torch.Tensor:

        inputs = ModelBatch(coords, time)

        return super().forward(inputs)

# Torch

In [15]:
partial_model = TracedSimplePINN(layers=[partial_enc, partial_linear], embedding_dim=32)
d(partial_model)

TracedSimplePINN(
  (layers): Sequential(
    (0): PartialCheckpoint(
      (partial_model): MainEncoderLayer(
        (dropout): Dropout(p=0.0, inplace=False)
        (branched_linear_block_xyz): ModuleList(
          (0): Sequential(
            (0): Linear(in_features=1, out_features=32, bias=True)
            (1): Tanh()
          )
        )
        (linear_block_t): Sequential(
          (0): Linear(in_features=4, out_features=32, bias=True)
          (1): Tanh()
        )
        (out_linear_block): Sequential(
          (0): Linear(in_features=64, out_features=32, bias=True)
          (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (res_bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): PartialCheckpoint(
      (partial_model): LinearDownUpBlock(
        (dropout): Dropout(p=0.0, inplace=False)
        (linear_block): Sequential(
          (0): Sequential(
           

In [16]:
traced_partial_model = torch.jit.trace(partial_model, sample)
d(traced_partial_model)

TracedSimplePINN(
  original_name=TracedSimplePINN
  (layers): Sequential(
    original_name=Sequential
    (0): PartialCheckpoint(
      original_name=PartialCheckpoint
      (partial_model): MainEncoderLayer(
        original_name=MainEncoderLayer
        (dropout): Dropout(original_name=Dropout)
        (branched_linear_block_xyz): ModuleList(
          original_name=ModuleList
          (0): Sequential(
            original_name=Sequential
            (0): Linear(original_name=Linear)
            (1): Tanh(original_name=Tanh)
          )
        )
        (linear_block_t): Sequential(
          original_name=Sequential
          (0): Linear(original_name=Linear)
          (1): Tanh(original_name=Tanh)
        )
        (out_linear_block): Sequential(
          original_name=Sequential
          (0): Linear(original_name=Linear)
          (1): BatchNorm1d(original_name=BatchNorm1d)
        )
        (res_bn): BatchNorm1d(original_name=BatchNorm1d)
      )
    )
    (1): PartialCheck

In [17]:
compiled_traced_partial_model = torch.compile(traced_partial_model)
d(compiled_traced_partial_model)

OptimizedModule(
  (_orig_mod): TracedSimplePINN(
    original_name=TracedSimplePINN
    (layers): Sequential(
      original_name=Sequential
      (0): PartialCheckpoint(
        original_name=PartialCheckpoint
        (partial_model): MainEncoderLayer(
          original_name=MainEncoderLayer
          (dropout): Dropout(original_name=Dropout)
          (branched_linear_block_xyz): ModuleList(
            original_name=ModuleList
            (0): Sequential(
              original_name=Sequential
              (0): Linear(original_name=Linear)
              (1): Tanh(original_name=Tanh)
            )
          )
          (linear_block_t): Sequential(
            original_name=Sequential
            (0): Linear(original_name=Linear)
            (1): Tanh(original_name=Tanh)
          )
          (out_linear_block): Sequential(
            original_name=Sequential
            (0): Linear(original_name=Linear)
            (1): BatchNorm1d(original_name=BatchNorm1d)
          )
        

In [18]:
torch.jit.save(compiled_traced_partial_model, "./ckpts/burger_1d_unord_v2/burgers_1d_unord_v2_emb32_extreme_time.pt")

In [21]:
from time import time

In [24]:
_time = 0.0
n_iters = 20_000

for _ in range(n_iters):
    start_time = time()
    partial_model(*sample);
    # traced_partial_model(*sample);
    _time += time() - start_time

print(_time / n_iters * 1000)

1.4746709585189821


In [25]:
%timeit partial_model(*sample)
%timeit traced_partial_model(*sample)
%timeit compiled_traced_partial_model(*sample)

1.38 ms ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
824 µs ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
928 µs ± 18.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Lightning & ONNX

In [171]:
torch.onnx.export(
    model=partial_model,
    args=tuple(sample),
    f="./ckpts/burgers_1d_v1.onnx",
    input_names=["coords", "time"],
    output_names=["solution"],
    # dynamic_axes={
    #     "coords": {0: "batch_size", 1: "N"},
    #     "time": {0: "batch_size", 1: "N"}
    # }
)

RuntimeError: Dynamic shape axis should be no more than the shape dimension for N

In [152]:
import onnx
import onnxruntime as ort
import numpy as np

In [153]:
model = onnx.load("./ckpts/burgers_1d_v1.onnx")
d(model)

ir_version: 8
opset_import {
  version: 17
}
producer_name: "pytorch"
producer_version: "2.2.2"
graph {
  node {
    input: "model_batch"
    input: "onnx::MatMul_53"
    output: "/layers/layers.0/partial_model/branched_linear_block_xyz.0/branched_linear_block_xyz.0.0/MatMul_output_0"
    name: "/layers/layers.0/partial_model/branched_linear_block_xyz.0/branched_linear_block_xyz.0.0/MatMul"
    op_type: "MatMul"
  }
  node {
    input: "layers.0.partial_model.branched_linear_block_xyz.0.0.bias"
    input: "/layers/layers.0/partial_model/branched_linear_block_xyz.0/branched_linear_block_xyz.0.0/MatMul_output_0"
    output: "/layers/layers.0/partial_model/branched_linear_block_xyz.0/branched_linear_block_xyz.0.0/Add_output_0"
    name: "/layers/layers.0/partial_model/branched_linear_block_xyz.0/branched_linear_block_xyz.0.0/Add"
    op_type: "Add"
  }
  node {
    input: "/layers/layers.0/partial_model/branched_linear_block_xyz.0/branched_linear_block_xyz.0.0/Add_output_0"
    output: "/

In [154]:
print(onnx.helper.printable_graph(model.graph))

graph main_graph (
  %model_batch[FLOAT, 1x10000x1]
  %onnx::MatMul_3[FLOAT, 1x10000x1]
) initializers (
  %layers.0.partial_model.branched_linear_block_xyz.0.0.bias[FLOAT, 16]
  %layers.0.partial_model.linear_block_t.0.bias[FLOAT, 16]
  %layers.0.partial_model.out_linear_block.bias[FLOAT, 16]
  %layers.1.partial_model.linear_block.0.0.bias[FLOAT, 16]
  %layers.1.partial_model.linear_block.1.0.bias[FLOAT, 16]
  %layers.1.partial_model.linear_block.2.0.bias[FLOAT, 16]
  %layers.1.partial_model.linear_block.3.0.bias[FLOAT, 16]
  %layers.1.partial_model.out_block.bias[FLOAT, 1]
  %onnx::MatMul_53[FLOAT, 1x16]
  %onnx::MatMul_54[FLOAT, 1x16]
  %onnx::MatMul_55[FLOAT, 32x16]
  %onnx::MatMul_56[FLOAT, 16x16]
  %onnx::MatMul_57[FLOAT, 16x16]
  %onnx::MatMul_58[FLOAT, 16x16]
  %onnx::MatMul_59[FLOAT, 16x16]
  %onnx::MatMul_60[FLOAT, 16x1]
) {
  %/layers/layers.0/partial_model/branched_linear_block_xyz.0/branched_linear_block_xyz.0.0/MatMul_output_0 = MatMul(%model_batch, %onnx::MatMul_53)
  %/

In [155]:
ort_session = ort.InferenceSession("./ckpts/burgers_1d_v1.onnx")

In [156]:
d(ort_session.get_inputs()[0].name)

'model_batch'

In [157]:
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    return tensor.detach().cpu().numpy()

In [159]:
outputs = ort_session.run(
    output_names=None, 
    input_feed={
        ort_session.get_inputs()[0].name: Coords(*[to_numpy(coord) for coord in sample.coords]), 
        ort_session.get_inputs()[1].name: to_numpy(sample.time)
    }
)

RuntimeError: Iterable of <class 'numpy.ndarray'> should be given as array for input 'model_batch'.

In [74]:
%timeit model(*sample)

TypeError: 'ModelProto' object is not callable