In [11]:
import torch
from torch import nn
import numpy as np
import tqdm
from PIL import Image

device = torch.device("mps")
res = (256, 256)

In [12]:
images = ["turing.png", "feng.png", ]
save_dir = "./outputs"
export_image = False
export_gaussians = True

In [13]:
class Rasterizer(nn.Module):
    def __init__(self, N: int = 128, res=(512, 512)) -> None:
        super().__init__()
        self.N, self.res = N, res
        self.gPosition = nn.Parameter(torch.normal(0.0, 0.15, (N, 2)))
        self.gRotation = nn.Parameter(torch.zeros(N))
        self.gScale = nn.Parameter(torch.normal(-4, 0.5, (N, )).repeat(2).reshape(2, -1).T)
        self.gColor = nn.Parameter(torch.zeros((N, 3)))
        self.gOpacity = nn.Parameter(torch.zeros(N))
        self.fBuffer = None 

    def forward(self):
        device = self.gPosition.device
        res_tensor = torch.tensor(self.res, device=device)
        
        fBuffer = torch.zeros((self.res[0], self.res[1], 3), device=device)
        fTransmissive = torch.ones(self.res, device=device)

        x = torch.arange(self.res[0], device=device)
        y = torch.arange(self.res[1], device=device)
        grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
        pixel_coords = torch.stack([grid_x, grid_y], dim=-1).float()
        normalized_coords = (pixel_coords / res_tensor - 0.5) * 2.0

        for i in range(self.N):
            position = torch.sigmoid(self.gPosition[i]) * 2.0 - 1.0
            rotation = self.gRotation[i]
            scale = torch.exp(self.gScale[i])
            color = torch.sigmoid(self.gColor[i])
            opacity = torch.sigmoid(self.gOpacity[i])

            c = torch.cos(rotation)
            s = torch.sin(rotation)
            rot_mat = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
            transform = torch.diag(1.0 / scale) @ rot_mat
            X = (normalized_coords - position) @ transform.T
            gaussian_val = torch.exp(-0.5 * (X**2).sum(dim=-1))
            
            alpha = gaussian_val * opacity
            contribution = alpha.unsqueeze(-1) * fTransmissive.unsqueeze(-1) * color
            fBuffer = fBuffer + contribution
            fTransmissive = fTransmissive * (1.0 - alpha)
            
        self.fBuffer = fBuffer
        return fBuffer

In [14]:
model = Rasterizer(N=256, res=res).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-1)
epochs = 60
loss = nn.MSELoss()

In [15]:
iter = 0
for image_path in images:
    image = torch.tensor(
        np.array(Image.open(f"img/{image_path}").convert("RGB").resize(res)) / 255.0,
        dtype=torch.float32,
    )

    target = image.to(device)
    for i in tqdm.tqdm(range(epochs)):
        optimizer.zero_grad()
        output = model()
        L = loss(output, target)
        L.backward()
        optimizer.step()

        if export_image:
            output_image = model.fBuffer.detach().cpu().numpy()
            output_image = np.clip(output_image, 0, 1)
            output_image = (output_image * 255).astype(np.uint8)
            Image.fromarray(output_image).save(f"{save_dir}/{iter}.png")
        iter += 1
    
    if export_gaussians:
        with torch.no_grad():
            positions = model.gPosition.detach().cpu().numpy()
            rotations = model.gRotation.detach().cpu().numpy()
            scales = model.gScale.detach().cpu().numpy()
            colors = model.gColor.detach().cpu().numpy()
            opacities = model.gOpacity.detach().cpu().numpy()
        np.savez(
            f"./gs_npz/{image_path}.npz",
            positions=positions,
            rotations=rotations,
            scales=scales,
            colors=colors,
            opacities=opacities,
            res=res,
        )

100%|██████████| 60/60 [00:13<00:00,  4.45it/s]
100%|██████████| 60/60 [00:13<00:00,  4.35it/s]
