In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import pygame
import time

pygame 2.6.1 (SDL 2.32.54, Python 3.12.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


  from pkg_resources import resource_stream, resource_exists


In [None]:
def normalize(x, mask=None, eps=1e-6):
    x_ = x if mask is None else x[mask]
    xMax = x_.max()
    xMin = x_.min()
    return (x - xMin).clamp(eps) / (xMax - xMin).clamp(eps)

class Fluid:
    def __init__(self, size, density=1000, gravity=-9.81, numIters=100, overRelaxation=1, boundaries=False, device=torch.device('cuda')):
        self.n = size+2 if boundaries else size
        self.numCells = size**2
        self.density = density
        self.gravity = gravity
        self.numIters = numIters
        self.overRelaxation = overRelaxation
        self.boundaries = boundaries
        self.device = device

        self.v = torch.zeros(self.n, self.n).to(device)
        self.u = torch.zeros(self.n, self.n).to(device)
        self.m = torch.ones(self.n, self.n).to(device)

        self.p = torch.zeros(self.n, self.n).to(device)
        w = 2
        self.s = F.pad(torch.ones(self.n-2*w, self.n-2*w).bool(), (w,w,w,w), value=False).to(device)

        centre = self.n / 2
        x = (torch.arange(self.n) - centre).pow(2)
        circle = normalize(torch.sqrt(x + x[:,None])).to(device)
        self.s &= circle > 0.25
        self.s[[0,1,-2,-1], round(centre)-10:round(centre)+10] = True
        self.s[round(centre)-10:round(centre)+10,[0,1,-2,-1]] = True


        self.x = torch.arange(self.n)[:,None].expand(-1,self.n).to(device)
        self.y = torch.arange(self.n)[None,:].expand(self.n,-1).to(device)
        self.iteration = 0


    def render(self):
        def hsv_to_rgb_torch(h, s, v):
            # h, s, v: shape (n,)
            i = torch.floor(h * 6)
            f = h * 6 - i
            i = i.long() % 6

            p = v * (1 - s)
            q = v * (1 - f * s)
            t = v * (1 - (1 - f) * s)

            r = torch.zeros_like(h)
            g = torch.zeros_like(h)
            b = torch.zeros_like(h)

            idx = i == 0
            r[idx], g[idx], b[idx] = v[idx], t[idx], p[idx]

            idx = i == 1
            r[idx], g[idx], b[idx] = q[idx], v[idx], p[idx]

            idx = i == 2
            r[idx], g[idx], b[idx] = p[idx], v[idx], t[idx]

            idx = i == 3
            r[idx], g[idx], b[idx] = p[idx], q[idx], v[idx]

            idx = i == 4
            r[idx], g[idx], b[idx] = t[idx], p[idx], v[idx]

            idx = i == 5
            r[idx], g[idx], b[idx] = v[idx], p[idx], q[idx]

            return torch.stack([r, g, b], dim=-1)
        
        angles = torch.atan2(self.u, self.v)
        hues = (angles + torch.pi) / (2 * torch.pi)

        k = 0.75
        s = (1-k) + k * normalize(self.m, self.s)
        v = (1-k) + k * normalize(self.p, self.s)

        rgb = 255 * hsv_to_rgb_torch(hues, s, v)
        rgb = torch.where(self.s[:,:,None], rgb, 0)
        rgb = rgb.view(self.n, self.n, 3)
        if self.boundaries:
            rgb = rgb[1:-1, 1:-1]
        rgb = torch.rot90(rgb, k=1, dims=(0, 1)).flip(dims=[0])
        return rgb

    def integrate(self, dt):
        # self.s = gen_box(self.n, self.iteration).to(self.device)
        mask = self.s & self.s.roll(1, 1)
        if self.boundaries:
            mask = F.pad(mask[1:, 1:-1], (1,1,1,0), value=False)
        self.v[mask] -= self.gravity * dt
        self.iteration += 1

    def solveIncompressibility(self, dt):

        ss = torch.stack([self.s.roll((-i,-j),(0,1)).float() for i,j in [[-1,0],[1,0],[0,-1],[0,1]]], 0)
        s = ss.sum(0)
        mask = self.s & (s > 0)
        if self.boundaries:
            mask = F.pad(mask[1:-1,1:-1], (1,1,1,1), value=False)
        maskX = mask.roll(1,0)
        maskY = mask.roll(1,1)


        for iter in range(self.numIters):
            div = self.u.roll(-1,0) - self.u + self.v.roll(-1,1) - self.v
            p = (self.overRelaxation * -div / s.clamp(1))

            sx0, sx1, sy0, sy1 = ss * p
            self.u[mask] -= sx0[mask]
            self.v[mask] -= sy0[mask]
            self.u[maskX] += sx1.roll(1,0)[maskX]
            self.v[maskY] += sy1.roll(1,1)[maskY]
        self.p = p * self.density / dt

        # Extrapolate
        validX = mask & maskX
        validY = mask & maskY
        for iter in range(self.numIters):
            neighboursX = torch.stack([validX.roll((-i,-j),(0,1))for i,j in [[-1,0],[1,0],[0,-1],[0,1]]], 0).any(0)
            extrapolateX = ~validX & neighboursX
            updateX = extrapolateX.any()
            if updateX:
                uu = torch.stack([self.u.roll((-i,-j),(0,1)).float() for i,j in [[-1,0],[1,0],[0,-1],[0,1]]], -1)[extrapolateX]
                ss = torch.stack([validX.roll((-i,-j),(0,1)).float() for i,j in [[-1,0],[1,0],[0,-1],[0,1]]], -1)[extrapolateX]
                self.u[extrapolateX] = (uu * ss).sum(-1) / ss.sum(-1)
                validX |= extrapolateX

            neighboursY = torch.stack([validY.roll((-i,-j),(0,1)) for i,j in [[-1,0],[1,0],[0,-1],[0,1]]], 0).any(0)
            extrapolateY = ~validY & neighboursY
            updateY = extrapolateY.any()
            if updateY:
                vv = torch.stack([self.v.roll((-i,-j),(0,1)).float() for i,j in [[-1,0],[1,0],[0,-1],[0,1]]], -1)[extrapolateY]
                ss = torch.stack([validY.roll((-i,-j),(0,1)).float() for i,j in [[-1,0],[1,0],[0,-1],[0,1]]], -1)[extrapolateY]
                self.v[extrapolateY] = (vv * ss).sum(-1) / ss.sum(-1)
                validY |= extrapolateY
            
            if not (updateX | updateY):
                break

    def sampleField(self, dt, field):
        n = self.n
        u, v = self.u, self.v
        dx, dy = 0, 0
        match field:
            case 'u':
                f = self.u
                v = torch.stack([self.v.roll((-i,-j), (0,1)) for i, j in [[-1,0],[0,0],[-1,1],[0,1]]],0).mean(0)
                dy = 0.5
            case 'v':
                f = self.v
                u = torch.stack([self.u.roll((-i,-j), (0,1)) for i, j in [[0,-1],[0,0],[1,-1],[1,0]]],0).mean(0)
                dx = 0.5
            case 'm':
                f = self.m
                u = torch.stack([self.u.roll(-i, 0) for i in range(2)],0).mean(0)
                v = torch.stack([self.v.roll(-j, 1) for j in range(2)],0).mean(0)
                dx, dy = 0.5, 0.5

        x = (self.x+dx) - dt * u
        y = (self.y+dy) - dt * v


        if self.boundaries:
            x = x.clamp(1,n)
            y = y.clamp(1,n)

        x0 = (x-dx).floor().int()
        tx = (x-dx)-x0
        x1 = x0+1

        y0 = (y-dy).floor().int()
        ty = (y-dy)-y0
        y1 = y0+1

        sx, sy = 1-tx, 1-ty
        if self.boundaries:
            x0, x1 = x0.clamp(0, self.n-1), x1.clamp(0, self.n-1)
            y0, y1 = y0.clamp(0, self.n-1), y1.clamp(0, self.n-1)
        else:
            x0, x1 = x0 % self.n, x1 % self.n
            y0, y1 = y0 % self.n, y1 % self.n


        val = sx*sy * f[x0, y0] + \
        tx*sy * f[x1, y0] + \
        tx*ty * f[x1, y1] + \
        sx*ty * f[x0, y1]

        return val

    def advectVel(self, dt):
        maskU = self.s & self.s.roll(1, 0)
        maskV = self.s & self.s.roll(1, 1)
        if self.boundaries:
            maskU = F.pad(maskU[1:, 1:-1], (1,1,1,0), value=False)
            maskV = F.pad(maskV[1:-1, 1:], (1,0,1,1), value=False)
        self.u[maskU], self.v[maskV] = self.sampleField(dt, 'u')[maskU], self.sampleField(dt, 'v')[maskV]

    def advectSmoke(self, dt):
        mask = self.s
        if self.boundaries:
            mask = F.pad(mask[1:-1,1:-1], (1,1,1,1), value=False)
        self.m[mask] = self.sampleField(dt, 'm')[mask]


In [None]:
import pygame
import numpy as np
import time
import torch

def run_interactive_fluid_simulation(fluid, window_size=512, target_fps=60):
    pygame.init()
    
    # Create display
    screen = pygame.display.set_mode((window_size, window_size))
    pygame.display.set_caption('Interactive Fluid Simulation - Click and drag!')
    clock = pygame.time.Clock()
    font = pygame.font.Font(None, 36)
    
    running = True
    last_time = time.time()
    mouse_pressed = False
    last_mouse_pos = None
    
    # For adding forces
    force_strength = 5000.0
    
    while running:
        # Handle events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            if event.key == pygame.K_r:
                # Reset simulation
                fluid.u.zero_()
                fluid.v.zero_()
                fluid.p.zero_()
                fluid.m.fill_(1.0)
            elif event.type == pygame.MOUSEBUTTONDOWN:
                if event.button == 1:  # Left mouse button
                    mouse_pressed = True
                    last_mouse_pos = pygame.mouse.get_pos()
            elif event.type == pygame.MOUSEBUTTONUP:
                if event.button == 1:
                    mouse_pressed = False
                    last_mouse_pos = None
            elif event.type == pygame.MOUSEMOTION:
                if mouse_pressed and last_mouse_pos:
                    current_pos = pygame.mouse.get_pos()
                    
                    # Convert screen coordinates to fluid grid coordinates
                    grid_x = int((current_pos[0] / window_size) * fluid.n)
                    grid_y = int((current_pos[1] / window_size) * fluid.n)
                    last_grid_x = int((last_mouse_pos[0] / window_size) * fluid.n)
                    last_grid_y = int((last_mouse_pos[1] / window_size) * fluid.n)
                    
                    # Calculate mouse velocity
                    dx = grid_x - last_grid_x
                    dy = grid_y - last_grid_y
                    
                    # Add force to fluid
                    if 0 <= grid_x < fluid.n and 0 <= grid_y < fluid.n:
                        radius = 3  # Radius of influence
                        for i in range(max(0, grid_x - radius), min(fluid.n, grid_x + radius + 1)):
                            for j in range(max(0, grid_y - radius), min(fluid.n, grid_y + radius + 1)):
                                distance = ((i - grid_x)**2 + (j - grid_y)**2)**0.5
                                if distance <= radius:
                                    strength = (1 - distance / radius) * force_strength
                                    fluid.u[i, j] += dx * strength
                                    fluid.v[i, j] += dy * strength
                    
                    last_mouse_pos = current_pos
        
        # Calculate dt
        current_time = time.time()
        dt = current_time - last_time
        last_time = current_time
        
        # Limit dt to prevent instability
        dt = min(dt, 1.0/target_fps)
        
        # Fluid simulation step
        fluid.integrate(dt)
        fluid.solveIncompressibility(dt)
        fluid.advectVel(dt)
        fluid.advectSmoke(dt)
        

        # Render
        rgb = fluid.render()
        rgb_np = rgb.detach().cpu().numpy().astype(np.uint8)
        
        # Convert to pygame surface
        rgb_transposed = np.transpose(rgb_np, (1, 0, 2))
        surface = pygame.surfarray.make_surface(rgb_transposed)
        
        # Scale to window size
        scaled_surface = pygame.transform.scale(surface, (window_size, window_size))
        
        # Blit to screen
        screen.blit(scaled_surface, (0, 0))
        
        # Add instructions
        fps_text = font.render(f'FPS: {clock.get_fps():.1f}', True, (255, 255, 255))
        instruction_text = font.render('Click and drag to add forces, R to reset', True, (255, 255, 255))
        
        screen.blit(fps_text, (10, 10))
        screen.blit(instruction_text, (10, window_size - 30))
        
        pygame.display.flip()
        
        # Maintain target framerate
        clock.tick(target_fps)
    
    pygame.quit()

# Usage
n = 128
fluid = Fluid(n, numIters=25, boundaries=False)
run_interactive_fluid_simulation(fluid, window_size=512, target_fps=60)

  rgb_np = rgb.detach().cpu().numpy().astype(np.uint8)
