In [1]:
# To run on CPU:
#   CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=1 python3 smoke_pytorch.py

import torch
import time
import math
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
import pdb
from imageio import imread, imwrite
from torch import nn

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimisation settings: number of gradient-descent iterations and step size.
num_iterations = 50
learning_rate = 100

# Simulation settings: grid resolution, cell size and the upper bound of the
# time-step loop in forward().
n_grid = 110
dx = 1.0 / n_grid
steps = 100


def roll_col(t, n):
    """Cyclically shift ``t`` along axis 1 (its columns) by ``n`` places.

    A positive ``n`` moves columns towards higher indices, with the last
    ``n`` columns wrapping around to the front; a negative ``n`` shifts the
    other way.  Shifts of any magnitude wrap around modulo the number of
    columns.

    Args:
        t: Tensor with at least two dimensions.
        n: Integer shift amount.

    Returns:
        A new tensor of the same shape as ``t``; ``t`` is not modified.
    """
    # torch.roll wraps correctly even when |n| exceeds the axis length,
    # which a plain slice-and-concatenate does not.
    return torch.roll(t, shifts=n, dims=1)


def roll_row(t, n):
    """Cyclically shift ``t`` along axis 0 (its rows) by ``n`` places.

    A positive ``n`` moves rows towards higher indices, with the last ``n``
    rows wrapping around to the top; a negative ``n`` shifts the other way.
    Shifts of any magnitude wrap around modulo the number of rows.

    Args:
        t: Tensor with at least two dimensions.
        n: Integer shift amount.

    Returns:
        A new tensor of the same shape as ``t``; ``t`` is not modified.
    """
    # torch.roll wraps correctly even when |n| exceeds the axis length,
    # which a plain slice-and-concatenate does not.
    return torch.roll(t, shifts=n, dims=0)


def project(vx, vy):
    """Project the velocity field to be approximately mass-conserving.

    A pressure-like field ``p`` is relaxed for a fixed six sweeps on a
    periodic grid, and its central-difference gradient is then subtracted
    from the velocity.  Each sweep builds the new ``p`` entirely from the
    previous one (a Jacobi-style update).

    Args:
        vx: Velocity component along axis 0 (rows), as a 2-D tensor.
        vy: Velocity component along axis 1 (columns), same shape as ``vx``.

    Returns:
        A ``(vx, vy)`` tuple of new tensors holding the corrected velocity.
        The inputs are left untouched, so leaf tensors that require grad
        can be passed in safely.
    """
    h = 1.0 / vx.shape[0]
    # Scaled central-difference divergence with periodic wrap-around.
    div = -0.5 * h * (roll_row(vx, -1) - roll_row(vx, 1) + roll_col(vy, -1) -
                      roll_col(vy, 1))

    # Allocate the pressure on the same device and dtype as the velocity
    # instead of relying on a module-level device.
    p = torch.zeros_like(vx)
    for _ in range(6):
        p = (div + roll_row(p, 1) + roll_row(p, -1) + roll_col(p, 1) +
             roll_col(p, -1)) / 4.0

    # Out-of-place subtraction: do not mutate the caller's tensors.
    vx = vx - 0.5 * (roll_row(p, -1) - roll_row(p, 1)) / h
    vy = vy - 0.5 * (roll_col(p, -1) - roll_col(p, 1)) / h
    return vx, vy


def advect(f, vx, vy):
    """Move field ``f`` along the velocity field ``(vx, vy)`` for one step.

    Every output cell is traced backwards by its velocity, and the value at
    that source position is a bilinear blend of the four surrounding cells.
    Indices wrap around the edges, so the domain is periodic.

    Args:
        f: 2-D tensor of shape ``(rows, cols)`` to transport.
        vx: Per-cell displacement along axis 0 of ``f``, in cells per step;
            expected to have the same shape as ``f``.
        vy: Per-cell displacement along axis 1 of ``f``, in cells per step;
            expected to have the same shape as ``f``.

    Returns:
        A new tensor with the shape of ``f`` holding the advected field.
    """
    rows, cols = f.shape
    # Index grids of shape (rows, cols): cell_xs[i, j] == i (axis 0) and
    # cell_ys[i, j] == j (axis 1).  Explicit "ij" indexing avoids the
    # torch.meshgrid deprecation warning and works for non-square fields.
    cell_xs, cell_ys = torch.meshgrid(torch.arange(rows, device=f.device),
                                      torch.arange(cols, device=f.device),
                                      indexing="ij")
    center_xs = (cell_xs.float() - vx).flatten()
    center_ys = (cell_ys.float() - vy).flatten()

    # Compute indices of source cells.
    left_ix = torch.floor(center_xs).long()
    top_ix = torch.floor(center_ys).long()
    # Fractional offsets must be taken before the indices are wrapped.
    rw = center_xs - left_ix.float()  # Relative weight of right-hand cells.
    bw = center_ys - top_ix.float()  # Relative weight of bottom cells.
    # Wrap around edges of simulation: left/right index axis 0 (size rows),
    # top/bot index axis 1 (size cols).
    left_ix = torch.remainder(left_ix, rows)
    right_ix = torch.remainder(left_ix + 1, rows)
    top_ix = torch.remainder(top_ix, cols)
    bot_ix = torch.remainder(top_ix + 1, cols)

    # A linearly-weighted sum of the 4 surrounding cells.
    flat_f = (1 - rw) * ((1 - bw) * f[left_ix, top_ix] + bw * f[left_ix, bot_ix]) \
             + rw * ((1 - bw) * f[right_ix, top_ix] + bw * f[right_ix, bot_ix])
    return torch.reshape(flat_f, (rows, cols))


def forward(iteration, smoke, vx, vy, output):
    """Run the smoke simulation and return the final smoke field.

    For every step index from 1 up to (but not including) ``steps``, the
    velocity is advected through itself, projected, and then used to advect
    the smoke.

    Args:
        iteration: Index of the calling optimisation iteration; not used in
            the body.
        smoke: Initial smoke field.
        vx: Initial velocity component handed to ``advect``/``project``.
        vy: Second initial velocity component.
        output: When truthy, each step's smoke field is written to
            ``output_pytorch/stepNNN.png``.

    Returns:
        The smoke field after the last simulated step.
    """
    for step in range(1, steps):
        # Self-advect both velocity components using the *old* velocity,
        # then make the result approximately divergence-free.
        advected_vx = advect(vx, vx, vy)
        advected_vy = advect(vy, vx, vy)
        vx, vy = project(advected_vx, advected_vy)

        # Carry the smoke along the updated velocity field.
        smoke = advect(smoke, vx, vy)

        if not output:
            continue

        frame_path = "output_pytorch/step{0:03d}.png".format(step)
        frame = 255 * smoke.cpu().detach().numpy()
        matplotlib.image.imsave(frame_path, frame)

    return smoke


def main():
    """Optimise an initial velocity field so simulated smoke matches a target.

    Loads the initial smoke density from channel 0 of ``init_smoke.png`` and
    the target from channel 3 of ``peace.png`` (every second pixel in both
    directions), both scaled by 1/255.  Starting from zero velocity, it runs
    ``num_iterations`` plain gradient-descent steps on the mean squared
    difference between the simulated smoke and the target, printing timings
    and the loss each iteration.  Simulation frames are written to
    ``output_pytorch/`` during the final iteration only.
    """
    # Create the output directory without shelling out.
    os.makedirs("output_pytorch", exist_ok=True)
    print("Loading initial and target states...")
    initial_smoke_img = imread("init_smoke.png")[:, :, 0] / 255.0
    target_img = imread("peace.png")[::2, ::2, 3] / 255.0

    # The two velocity components are the parameters being optimised.
    vx = torch.zeros(n_grid,
                     n_grid,
                     requires_grad=True,
                     device=device,
                     dtype=torch.float32)
    vy = torch.zeros(n_grid,
                     n_grid,
                     requires_grad=True,
                     device=device,
                     dtype=torch.float32)
    initial_smoke = torch.tensor(initial_smoke_img,
                                 device=device,
                                 dtype=torch.float32)
    target = torch.tensor(target_img, device=device, dtype=torch.float32)

    for opt in range(num_iterations):
        # NOTE(review): on CUDA, kernels are launched asynchronously, so
        # these wall-clock numbers may under-report the forward pass unless
        # the device is synchronised first -- confirm if timings matter.
        t = time.perf_counter()
        smoke = forward(opt, initial_smoke, vx, vy,
                        opt == (num_iterations - 1))
        loss = ((smoke - target)**2).mean()
        print('forward time', (time.perf_counter() - t) * 1000, 'ms')

        t = time.perf_counter()
        loss.backward()
        print('backward time', (time.perf_counter() - t) * 1000, 'ms')

        # Manual gradient-descent step; no_grad keeps the in-place update of
        # the leaf tensors out of the autograd graph.
        with torch.no_grad():
            vx -= learning_rate * vx.grad
            vy -= learning_rate * vy.grad
            # Gradients accumulate across backward() calls, so reset them.
            vx.grad.zero_()
            vy.grad.zero_()

        print('Iter', opt, ' Loss =', loss.item())


# Only start the optimisation when executed as a script (or notebook cell),
# not when this module is imported.
if __name__ == "__main__":
    main()

Loading initial and target states...


  initial_smoke_img = imread("init_smoke.png")[:, :, 0] / 255.0
  target_img = imread("peace.png")[::2, ::2, 3] / 255.0
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


forward time 560.6887340545654 ms
backward time 974.6057987213135 ms
Iter 0  Loss = 0.39777231216430664
forward time 492.98691749572754 ms
backward time 896.0657119750977 ms
Iter 1  Loss = 0.29252928495407104
forward time 443.5296058654785 ms
backward time 938.2364749908447 ms
Iter 2  Loss = 0.20421172678470612
forward time 420.3324317932129 ms
backward time 858.8697910308838 ms
Iter 3  Loss = 0.15880261361598969
forward time 439.55087661743164 ms
backward time 849.9243259429932 ms
Iter 4  Loss = 0.1365165412425995
forward time 431.171178817749 ms
backward time 805.9675693511963 ms
Iter 5  Loss = 0.12370306998491287
forward time 436.82026863098145 ms
backward time 870.7237243652344 ms
Iter 6  Loss = 0.11406853049993515
forward time 408.07032585144043 ms
backward time 835.9174728393555 ms
Iter 7  Loss = 0.11112909764051437
forward time 428.76625061035156 ms
backward time 836.9486331939697 ms
Iter 8  Loss = 0.09803799539804459
forward time 438.6017322540283 ms
backward time 867.666721343