In [None]:
# setup environment
try:
    # setup colab
    import google.colab
    !pip install kappamodules
    !pip install torch_geometric
    import torch
    print(torch.__version__)
    device = torch.device("cuda")
    torch.cuda.get_device_name(device)
    # might need to torch version to the one installed in colab
    !pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
    # checkout repo
    !git clone https://github.com/BenediktAlkin/upt-tutorial.git
    %cd upt-tutorial
except ImportError:
    # setup with own server
    import torch
    print(torch.__version__)
    device = torch.device("cuda:7")
    print(torch.cuda.get_device_name(device))

# 6 Transient Flow CFD

In [ ]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from upt.collators.simulation_collator import SimulationCollator

In [ ]:
from upt.datasets.simulation_dataset import SimulationDataset

# initialize dataset
# dataset used for training
train_dataset = SimulationDataset(
    mode="train",
    root="./data/simulation",
    # number of inputs to use for training
    num_inputs=8192,
    # number of outputs to use for training
    num_outputs=4096,
)
# dataset used for generating the rollout (reads from the same root as the training dataset)
# NOTE(review): this dataset is also created with mode="train" -- confirm this is intended for rollout
rollout_dataset = SimulationDataset(
    mode="train",
    root="./data/simulation",
    # no limit -> use all inputs for rollout
    num_inputs=float("inf"),
    # no limit -> use all outputs for rollout
    num_outputs=float("inf"),
)

In [ ]:
# hyperparameters
# latent dimension passed to every model component (~6M parameter model)
dim = 192
# number of attention heads passed to encoder/approximator/decoder
num_heads = 3
# number of passes over the training dataloader
epochs = 10
# batch size of the training dataloader
batch_size = 256

In [ ]:
from upt.models.upt import UPT
from upt.models.approximator import Approximator
from upt.models.decoder_perceiver import DecoderPerceiver
from upt.models.encoder_supernodes import EncoderSupernodes
from upt.models.conditioner_timestep import ConditionerTimestep

# initialize model (conditioner + encoder + approximator + decoder) and move it to the device
model = UPT(
    conditioner=ConditionerTimestep(
        num_timesteps=train_dataset.num_timesteps,
        dim=dim,
    ),
    encoder=EncoderSupernodes(
        # 2D dataset
        ndim=2,
        # simulation has 3 inputs: 2D velocity + pressure
        input_dim=3,
        # positions are rescaled to [0, 200]
        radius=5,
        # regions with many mesh cells would give supernodes a lot of connections to nodes.
        # supernodes are sampled uniformly, so dense regions also contain many supernodes, which
        # means the number of connections per supernode can simply be capped to avoid an extreme
        # amount of edges
        max_degree=32,
        # dimension for the supernode pooling -> use same as ViT-T latent dim
        gnn_dim=dim,
        # ViT-T has 12 blocks -> parameters are split evenly among encoder/approximator/decoder
        enc_depth=4,
        # ViT-T latent dimension
        enc_dim=dim,
        enc_num_heads=num_heads,
        # downsample to 128 latent tokens for fast training
        num_latent_tokens=128,
        perc_dim=dim,
        perc_num_heads=num_heads,
    ),
    approximator=Approximator(
        # dimension of the approximator input (perc_dim or enc_dim of encoder)
        input_dim=dim,
        # ViT-T has 12 blocks -> parameters are split evenly among encoder/approximator/decoder
        depth=4,
        # as in ViT-T
        dim=dim,
        num_heads=num_heads,
    ),
    decoder=DecoderPerceiver(
        # simulation is 2D
        ndim=2,
        # dimension of the decoder input (dim of approximator)
        input_dim=dim,
        # 2D velocity + pressure
        output_dim=3,
        # ViT-T has 12 blocks -> parameters are split evenly among encoder/approximator/decoder
        depth=4,
        # as in ViT-T
        dim=dim,
        num_heads=num_heads,
        # num_outputs is assumed to be constant, so the dense result can simply be reshaped
        # into a sparse tensor
        unbatch_mode="dense_to_sparse_unpadded",
    ),
).to(device)
# report the model size in millions of parameters
print(f"parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

In [ ]:
# setup dataloader for training
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=SimulationCollator(num_supernodes=512, deterministic=False),
    # reshuffle the samples every epoch
    shuffle=True,
    # drop the last incomplete batch
    # NOTE(review): if the dataset holds fewer than batch_size samples this yields zero batches
    # per epoch -- confirm the dataset size against batch_size
    drop_last=True,
)

In [ ]:

# initialize optimizer and learning rate schedule (linear warmup for first 10% -> linear decay)
optim = torch.optim.AdamW(params=model.parameters(), weight_decay=0.05, lr=1e-3)
# one update per batch, for every epoch
total_updates = epochs * len(train_dataloader)
warmup_updates = int(total_updates * 0.1)
# one learning rate per update: ramp from 0 up to the optimizer's lr, then back down to 0
lrs = torch.cat(
    (
        # linear warmup
        torch.linspace(start=0, end=optim.defaults["lr"], steps=warmup_updates),
        # linear decay
        torch.linspace(start=optim.defaults["lr"], end=0, steps=total_updates - warmup_updates),
    ),
)

In [ ]:
# train model
# runs `epochs` passes over train_dataloader; every batch is one optimizer update that uses the
# learning rate scheduled for that update. the loss of every update is stored in train_losses.
update = 0
pbar = tqdm(total=total_updates)
pbar.update(0)
pbar.set_description("train_loss: ?????")
train_losses = []
for _ in range(epochs):
    # train for an epoch
    for batch in train_dataloader:
        # schedule learning rate
        # convert the schedule entry to a plain float: tensor learning rates are only supported
        # by some optimizer implementations
        lr = float(lrs[update])
        for param_group in optim.param_groups:
            param_group["lr"] = lr

        # forward pass
        y_hat = model(
            input_feat=batch["input_feat"].to(device),
            input_pos=batch["input_pos"].to(device),
            supernode_idxs=batch["supernode_idxs"].to(device),
            batch_idx=batch["batch_idx"].to(device),
            output_pos=batch["output_pos"].to(device),
        )
        y = batch["output_feat"].to(device)
        loss = F.mse_loss(y_hat, y)

        # backward pass
        loss.backward()

        # update step
        optim.step()
        optim.zero_grad()

        # status update
        # read the loss value once and reuse it for the progress bar and the history
        loss_value = loss.item()
        update += 1
        pbar.update()
        pbar.set_description(f"train_loss: {loss_value:.4f}")
        train_losses.append(loss_value)
pbar.close()

Now that we have a trained model, we can use it to generate a simulation. To simulate new, unseen simulations, one would need to train on many different simulations. Since we only train on a single simulation in this tutorial, we generate the rollout of that same simulation.

In [None]:
# setup dataloader for the rollout: one sample per batch, collated with deterministic=True
rollout_dataloader = DataLoader(
    rollout_dataset,
    collate_fn=SimulationCollator(num_supernodes=512, deterministic=True),
    batch_size=1,
)

In [None]:
# get rollout batch
batch = next(iter(rollout_dataloader))

# rollout
# inference only: switch to eval mode and disable autograd so that no gradient state is tracked
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y_hat = model.rollout(
            input_feat=batch["input_feat"].to(device),
            input_pos=batch["input_pos"].to(device),
            supernode_idxs=batch["supernode_idxs"].to(device),
            batch_idx=batch["batch_idx"].to(device),
            output_pos=batch["output_pos"].to(device),
        )
finally:
    # restore the previous mode so that re-running the training cell behaves as before
    model.train(was_training)
# get ground truth
y = batch["output_feat"]

# extract plot data
output_pos = batch["output_pos"]
# move every rollout entry to the cpu
y_hat = [step.cpu() for step in y_hat]

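# NOTE: the plotting/GIF code below is disabled. It references names that are not defined in
# the cells above (output_gt, output_pred, index, Image, self.out_png, self.out_gif,
# self.update_counter) and it indexes a third position coordinate although the dataset is 2D,
# so it needs to be adapted before it can be uncommented.
# generate pngs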
# for i in range(len(y_hat)):
#     plt.clf()
#     fig, (ax0, ax1, ax2) = plt.subplots(3, 1)
#     # particles
#     if output_pos[i].shape[0] == 1:
#         px = output_pos[i][0, :, 0]
#         pz = output_pos[i][0, :, 2]
#     else:
#         px = output_pos[i][:, 0]
#         pz = output_pos[i][:, 2]
#     c0 = output_gt[i].abs().mean(dim=1)
#     c1 = output_pred[i].abs().mean(dim=1)
#     c2 = (output_gt[i] - output_pred[i]).abs().mean(dim=1)
#     cmax = 5
#     # cmax = max(c0.max(), c1.max())
#     ax0.scatter(px, pz, c=c0, vmin=0, vmax=cmax, s=3)
#     scatter1 = ax1.scatter(px, pz, c=c1, vmin=0, vmax=cmax, s=3)
#     scatter2 = ax2.scatter(px, pz, c=c2, vmin=0, vmax=2, s=3)
#     # save
#     ax0.set_title("target")
#     ax1.set_title("pred")
#     ax2.set_title("delta")
#     ax0.set_xticks(ticks=[])
#     ax0.set_yticks(ticks=[])
#     ax1.set_xticks(ticks=[])
#     ax1.set_yticks(ticks=[])
#     ax2.set_xticks(ticks=[])
#     ax2.set_yticks(ticks=[])
#     plt.colorbar(scatter1, ax=[ax0, ax1], orientation="vertical")
#     plt.colorbar(scatter2, ax=ax2, orientation="vertical")
#     plt.savefig(self.out_png / f"{self.update_counter.cur_checkpoint}_{index.item():04d}_{i:04d}.png")
#     plt.close()
#     if i > 200:
#         break
#
# # generate gifs
# def load_pil(uri):
#     temp = Image.open(uri)
#     img = temp.copy()
#     temp.close()
#     return img
#
# imgs = [
#     load_pil(self.out_png / f"{self.update_counter.cur_checkpoint}_{index.item():04d}_{i:04d}.png")
#     for i in range(min(200, len(output_pred)))
# ]
# imgs[0].save(
#     fp=self.out_gif / f"{self.update_counter.cur_checkpoint}_{index.item():04d}.gif",
#     format="GIF",
#     append_images=imgs[1:],
#     save_all=True,
#     duration=100,
#     loop=0,
# )