In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from torchvision import transforms
from PIL import Image
import math
import matplotlib.pyplot as plt
import tqdm

# Part 1

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Fourier-feature encoding: the raw input followed by sin/cos bands.

    For every frequency index i in [0, L) the features sin(2^i * x) and
    cos(2^i * x) are appended, so the last dimension grows from D to
    D * (2 * L + 1).
    """

    def __init__(self, L=10):
        super().__init__()
        self.L = L  # number of frequency bands

    def forward(self, x):
        """Return x concatenated with its sinusoidal features along the last dim."""
        features = [x]
        for band in range(self.L):
            scaled = 2.0 ** band * x
            features += [torch.sin(scaled), torch.cos(scaled)]
        return torch.cat(features, dim=-1)

class MLP(nn.Module):
    """Fully connected ReLU network whose output is squashed by a sigmoid.

    Layout: one input projection, ``num_layers - 2`` hidden Linear+ReLU
    blocks (none when ``num_layers <= 2``), then a Linear head followed by
    ``nn.Sigmoid``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=256, num_layers=3):
        super().__init__()
        extra_hidden = max(num_layers - 2, 0)
        # Widths of the ReLU-activated part: in_dim -> hidden -> ... -> hidden.
        widths = [in_dim] + [hidden_dim] * (extra_hidden + 1)

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        modules += [nn.Linear(hidden_dim, out_dim), nn.Sigmoid()]
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to x."""
        return self.layers(x)

class NeuralField(nn.Module):
    """Positional encoding followed by an MLP.

    The MLP input width is ``in_dim * (2 * L + 1)``: the raw coordinates plus
    a sin and a cos feature per coordinate for each of the L frequency bands.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=256, num_layers=3, L=10):
        super().__init__()
        encoded_dim = in_dim * (2 * L + 1)
        self.pe = SinusoidalPositionalEncoding(L)
        self.mlp = MLP(encoded_dim, out_dim, hidden_dim, num_layers)

    def forward(self, x):
        """Encode x, then run the encoded features through the MLP."""
        return self.mlp(self.pe(x))

class ImageDataset(Dataset):
    """Random sample of N pixels from one image as (normalised xy, RGB) pairs.

    Pixel positions are drawn uniformly with replacement at construction time.
    ``coordinates`` has shape (N, 2) holding (x / width, y / height) and
    ``colors`` has shape (N, 3) holding the ``ToTensor`` output at those pixels.
    """

    def __init__(self, image_path, N):
        """Load ``image_path`` as RGB and draw ``N`` random pixel samples."""
        self.image = Image.open(image_path).convert('RGB')
        self.transform = transforms.ToTensor()
        self.N = N
        self.w, self.h = self.image.size  # PIL reports (width, height)

        # Generate random integer pixel coordinates, then normalise them.
        x_coords = np.random.randint(0, self.w, N)
        y_coords = np.random.randint(0, self.h, N)
        self.coordinates = torch.tensor(np.stack((x_coords / self.w, y_coords / self.h), axis=1), dtype=torch.float32)

        # ToTensor yields (C, H, W); index rows with y and columns with x,
        # then move the channel axis last to get one RGB row per sample.
        image_tensor = self.transform(self.image)
        self.colors = image_tensor[:, y_coords, x_coords].permute(1, 0).to(torch.float32)

    def __len__(self):
        """Number of sampled pixels."""
        return self.N

    def __getitem__(self, idx):
        """Return the (coordinate, color) pair of sample ``idx``.

        The coordinate is indexed with ``idx`` as well; returning the full
        coordinate tensor here would pair every color with all N coordinates.
        """
        return self.coordinates[idx], self.colors[idx]

    def get_all_data(self):
        """Return the full (N, 2) coordinate and (N, 3) color tensors."""
        return self.coordinates, self.colors

def psnr(mse):
    """Peak signal-to-noise ratio in dB for signals with peak value 1.

    Returns ``float('inf')`` for a zero error instead of dividing by zero.
    """
    if mse == 0:
        return float('inf')
    return 10 * math.log10(1 / mse)

def generate_image(model, image_path, device):
    """Evaluate ``model`` on the full pixel grid of the image at ``image_path``.

    Only the image size is read from the file. The network is queried at
    (x / width, y / height) for every pixel, rows outermost, and the result is
    returned as a NumPy array of shape (height, width, 3).
    """
    # The pixel data is never needed here, so the file is only opened for its size.
    with Image.open(image_path) as image:
        w, h = image.size

    # Row-major grid with x varying fastest. The division is done in float64
    # and cast afterwards, matching plain Python `x / w` arithmetic.
    xs = (torch.arange(w, dtype=torch.float64) / w).repeat(h)
    ys = (torch.arange(h, dtype=torch.float64) / h).repeat_interleave(w)
    coords = torch.stack((xs, ys), dim=1).float().to(device)

    # Inference only: skip building an autograd graph for every pixel.
    with torch.no_grad():
        colors = model(coords)
    return colors.reshape(h, w, 3).cpu().numpy()


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeuralField(2, 3).to(device)
images = []  # reconstructions: one before training, then one every 250 iterations
optimizer = optim.Adam(model.parameters(), lr=2e-3)
criterion = nn.MSELoss()
losses = []  # training MSE of every iteration, as Python floats

image_path = './panda.jpg'

# Snapshot of the untrained network (iteration 0).
image = generate_image(model, image_path, device)
images.append(image)

for iteration in range(1, 2001):
    # Constructing the dataset draws a fresh random batch of 10000 pixels.
    dataset = ImageDataset(image_path, N=10000)
    model.train()
    coords, colors = dataset.get_all_data()
    coords, colors = coords.to(device), colors.to(device)

    optimizer.zero_grad()
    outputs = model(coords)
    loss = criterion(outputs, colors)
    loss.backward()
    optimizer.step()

    # nn.MSELoss already averages over the batch. Dividing by the batch size
    # again would shrink the value by 1e4 and inflate every PSNR by 40 dB.
    loss = loss.item()
    losses.append(loss)

    if iteration % 250 == 0:
        print(f'Iteration {iteration}/2000, Loss: {loss}')
        image = generate_image(model, image_path, device)
        images.append(image)

In [None]:
num_rows = 3
num_cols = 3
fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 10))

# One panel per snapshot, filled row by row; snapshot i is titled with
# iteration 250 * i.
for i, img in enumerate(images):
    row, col = divmod(i, num_cols)
    panel = axes[row, col]
    panel.imshow(img)
    panel.axis('off')  # hide ticks and labels
    panel.set_title(f"Iteration {0+250*i}")

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# `losses` holds one training MSE per iteration; the x-axis is derived from
# its length so the plot stays valid if the number of iterations changes.
iterations = range(1, len(losses) + 1)
psnr_values = [psnr(mse) if mse != 0 else float('inf') for mse in losses]

plt.figure(figsize=(10, 6))
# The curve is PSNR, so the legend entry says so.
plt.plot(iterations, psnr_values, label='Training PSNR')
plt.xlabel('Iteration')
plt.ylabel('PSNR')
plt.title('PSNR During Training')
plt.legend()
plt.show()


# Part 2.2: Sampling


In [None]:
import numpy as np
# The archive is used as a context manager so its file handle is closed once
# every array has been read into memory.
with np.load("lego_200x200.npz") as data:
    # Training images: [100, 200, 200, 3], rescaled from 0..255 to 0..1
    images_train = data["images_train"] / 255.0

    # Cameras for the training images
    # (camera-to-world transformation matrix): [100, 4, 4]
    c2ws_train = data["c2ws_train"]

    # Validation images: [10, 200, 200, 3], rescaled from 0..255 to 0..1
    images_val = data["images_val"] / 255.0

    # Cameras for the validation images
    # (camera-to-world transformation matrix): [10, 4, 4]
    c2ws_val = data["c2ws_val"]

    # Test cameras for novel-view video rendering
    # (camera-to-world transformation matrix): [60, 4, 4]
    c2ws_test = data["c2ws_test"]

    # Camera focal length (scalar)
    focal = data["focal"]

In [None]:
def transform(c2w, X_c):
    '''
    Map a batch of camera-space points into world space.

    c2w is a 4x4 camera-to-world matrix and X_c is an (N, 3) batch of points;
    the result is the (N, 3) batch of transformed points.
    '''
    n_points = X_c.shape[0]
    # Append w = 1 so the matrix product also applies the translation.
    homogeneous = torch.cat([X_c, torch.ones(n_points, 1, device=X_c.device)], dim=1)

    # Row-vector form of the product: (c2w @ p^T)^T == p @ c2w^T.
    world = torch.matmul(homogeneous, c2w.T)

    # Drop the homogeneous coordinate again.
    return world[:, :3]

def pixel_to_camera(K, uv, s):
    """Back-project pixel coordinates to camera-space points at depth ``s``.

    K is the 3x3 intrinsic matrix, uv an (N, 2) batch of pixel coordinates and
    s the depth each homogeneous pixel [u, v, 1] is scaled by. Returns the
    (N, 3) points ``s * K^-1 [u, v, 1]^T``.
    """
    # Invert the intrinsic matrix
    K_inv = torch.inverse(K)

    # Homogeneous pixels scaled by depth. The ones column is created on uv's
    # device so that the concatenation also works for CUDA inputs.
    ones = torch.ones(uv.shape[0], 1, device=uv.device)
    uv_homog = torch.cat([uv, ones], dim=1) * s

    # Row-vector form of K^-1 @ p
    x_c = torch.mm(uv_homog, K_inv.T)

    return x_c

def pixel_to_ray(K, c2w, uv, s=1):
    """Turn pixel coordinates into world-space rays.

    K is the 3x3 intrinsic matrix, c2w the 4x4 camera-to-world matrix and uv an
    (N, 2) batch of pixel coordinates. Returns ``(ray_o, ray_d)``, both (N, 3):
    the ray origin repeated for every pixel and the unit-length direction from
    that origin through the pixel's point at depth ``s``.
    """
    # World-space point that each pixel sees at depth s.
    x_c = pixel_to_camera(K, uv, s)
    x_w = transform(c2w, x_c)

    # The camera centre is the translation column of c2w. Inverting c2w to get
    # (R, t) and evaluating -R^-1 t yields the same vector, at the cost of two
    # matrix inversions and extra round-off.
    ray_o = c2w[:3, 3].expand_as(x_w)
    ray_d = torch.nn.functional.normalize(x_w - ray_o, p=2, dim=1)

    return ray_o, ray_d


In [None]:
class RaysData:
    """Multi-view images with their cameras, exposed as rays and pixel colors.

    ``uvs`` lists the integer (x, y) coordinate of every pixel of every image
    (x varying fastest, images concatenated in order) and ``pixels`` lists the
    matching colors, so row k of both refers to the same pixel.
    """

    def __init__(self, images, K, c2ws):
        """
        Initialize the RaysData object.

        :param images: Array of images, indexed as [image, row, column, channel].
        :param K: 3x3 intrinsic matrix shared by all cameras.
        :param c2ws: One 4x4 camera-to-world matrix per image.
        """
        self.images = images
        self.K = K
        self.c2ws = c2ws
        self.H, self.W = images[0].shape[:2]

        self.uvs = self.generate_uvs()
        self.pixels = self.generate_pixels()

    def sample_rays(self, N, M):
        """
        Samples rays from the images.

        M distinct images are drawn and ``N // M`` random pixels are taken from
        each, so fewer than N rays are returned when M does not divide N.

        :param N: Total number of rays to sample.
        :param M: Number of images to sample from.
        :return: Ray origins, ray directions, and pixel colors.
        """
        num_images = self.images.shape[0]

        # Sample M images without repetition
        selected_indices = np.random.choice(num_images, M, replace=False)
        selected_images = self.images[selected_indices]
        selected_c2ws = self.c2ws[selected_indices]

        # Calculate number of rays per image
        rays_per_image = N // M

        # Per-image results, concatenated at the end
        ray_origins = []
        ray_directions = []
        pixel_colors = []

        for i in range(M):
            image = selected_images[i]
            c2w = torch.tensor(selected_c2ws[i], dtype=torch.float32)

            height, width, _ = image.shape

            # Random pixels, shifted by 0.5 so rays pass through pixel centres
            uv = np.column_stack((
                np.random.randint(0, width, rays_per_image) + 0.5,
                np.random.randint(0, height, rays_per_image) + 0.5
            ))
            uv = torch.tensor(uv, dtype=torch.float32)

            # Get ray origins and directions
            r_o, r_d = pixel_to_ray(self.K, c2w, uv)
            ray_origins.append(r_o)
            ray_directions.append(r_d)

            # Truncating the centre coordinates recovers the integer pixel
            # indices; images are indexed [row (v), column (u)].
            uv_pixel = uv.long()
            colors = image[uv_pixel[:, 1], uv_pixel[:, 0]]
            pixel_colors.append(torch.tensor(colors, dtype=torch.float32))

        # Concatenate the results from all images
        ray_origins = torch.cat(ray_origins, dim=0)
        ray_directions = torch.cat(ray_directions, dim=0)
        pixel_colors = torch.cat(pixel_colors, dim=0)

        return ray_origins, ray_directions, pixel_colors

    def generate_uvs(self):
        """Return the (x, y) pixel grid repeated once per image.

        The grid size and repeat count come from the stored images, so the
        result always has exactly as many rows as ``generate_pixels``.
        """
        n_images = len(self.images)
        x = torch.arange(0, self.W).tile(self.H)
        y = torch.arange(0, self.H).repeat_interleave(self.W)
        return torch.stack((x, y), dim=1).repeat(n_images, 1)

    def generate_pixels(self):
        """Return every pixel color as a float32 tensor of shape (-1, 3)."""
        images = torch.tensor(self.images, dtype=torch.float32)
        color_grid = images.reshape(-1, 3)  # [images * height * width, channels]
        return color_grid

In [None]:
# Return all r_o, r_d, colors, return_uv of a single image
# This is for generating the input of Val images
class TestData:
    """Iterator that yields, image by image, the rays through all pixels.

    Each step returns ``(r_o, r_d, colors, return_uv)`` for one image. The
    grid comes from ``np.mgrid[0:height, 0:width]``, so column 0 of
    ``return_uv`` is the row index and column 1 the column index (column index
    varying fastest), and that pair is what is handed to ``pixel_to_ray`` as
    (u, v).
    """

    def __init__(self, images, K, c2ws):
        self.images = images
        self.K = K
        self.c2ws = c2ws
        self.current_image_index = 0

    def __iter__(self):
        # Restart from the first image on every new iteration.
        self.current_image_index = 0
        return self

    def __next__(self):
        idx = self.current_image_index
        if idx >= len(self.images):
            raise StopIteration

        image = self.images[idx]
        pose = torch.tensor(self.c2ws[idx], dtype=torch.float32)
        height, width, _ = image.shape

        # Integer grid for indexing, plus a copy shifted by 0.5 for the rays.
        grid = np.mgrid[0:height, 0:width].reshape(2, -1).T
        return_uv = grid.copy()
        uv = torch.from_numpy(grid + 0.5).float()

        r_o, r_d = pixel_to_ray(self.K, pose, uv)

        # Look up the color at image[second column, first column] of the grid.
        colors = torch.tensor(image[return_uv[:, 1], return_uv[:, 0]], dtype=torch.float32)

        # Advance only after the current image was processed successfully.
        self.current_image_index = idx + 1
        return r_o, r_d, colors, return_uv

In [None]:
def sample_along_rays(ray_o, ray_d, near=2.0, far=6.0, n_samples=33, t_width=0.05, perturb=True):
    """Sample 3D points along each ray between ``near`` and ``far``.

    ``n_samples`` evenly spaced depths are generated per ray; with ``perturb``
    each depth is shifted by an independent uniform offset in [0, t_width).
    The first depth only serves as the left edge of the first interval, so the
    function returns ``n_samples - 1`` points per ray, shape (N, n_samples - 1, 3),
    and the distance of each point to the previous depth, shape
    (N, n_samples - 1, 1).
    """
    n_rays = ray_o.shape[0]
    depths = torch.linspace(near, far, n_samples, device=ray_o.device)
    depths = depths.expand(n_rays, n_samples).clone()  # clone: expand() is a read-only view

    if perturb:
        depths += torch.rand(depths.shape, device=ray_o.device) * t_width

    # o + t * d, broadcast over the sample axis
    origins = ray_o.unsqueeze(1)
    directions = ray_d.unsqueeze(1)
    samples = origins + directions * depths.unsqueeze(-1)

    deltas = (depths[:, 1:] - depths[:, :-1]).unsqueeze(-1)
    return samples[:, 1:], deltas

## Define K

In [None]:
# Pinhole intrinsics: the same focal length on both axes (square pixels) and
# the principal point at the centre of a 200x200 image.
f_x = f_y = focal
image_width = image_height = 200
o_x, o_y = image_width / 2, image_height / 2

# Intrinsic matrix K as a float32 tensor
K = torch.from_numpy(
    np.array(
        [[f_x, 0, o_x],
         [0, f_y, o_y],
         [0, 0, 1]],
        dtype=np.float32,
    )
)

# Smoke test of the data pipeline: sample rays, sample points along them,
# and print the resulting shapes.
dataset = RaysData(images_train, K, c2ws_train)
rays_o, rays_d, pixels = dataset.sample_rays(200, 100)
points, t = sample_along_rays(rays_o, rays_d, perturb=True)

print(rays_o.shape, rays_d.shape, pixels.shape, points.shape, dataset.uvs.shape, dataset.pixels.shape)

# Part 2.3: Putting the Dataloading All Together

In [None]:
import viser, time  # pip install viser
import numpy as np

# Sample 200 rays spread over 2 training images, then the points along them,
# and convert everything to NumPy for the viewer.
dataset = RaysData(images_train, K, c2ws_train)
rays_o, rays_d, pixels = dataset.sample_rays(200,2)
points, t = sample_along_rays(rays_o, rays_d, perturb=True)
rays_o = rays_o.numpy()
rays_d = rays_d.numpy()
points = points.numpy()
H, W = images_train.shape[1:3]

# NOTE(review): share=True presumably requests a publicly reachable viewer
# URL - confirm against the viser documentation before running on a shared host.
server = viser.ViserServer(share=True)
# One frustum per training camera, textured with its image. The field of view
# is derived from the image height and the focal length stored in K[0, 0].
for i, (image, c2w) in enumerate(zip(images_train, c2ws_train)):
    server.add_camera_frustum(
        f"/cameras/{i}",
        fov=2 * np.arctan2(H / 2, K[0, 0]),
        aspect=W / H,
        scale=0.15,
        wxyz=viser.transforms.SO3.from_matrix(c2w[:3, :3]).wxyz,
        position=c2w[:3, 3],
        image=image
    )
# Each sampled ray is drawn as a segment from its origin to depth 6.0.
for i, (o, d) in enumerate(zip(rays_o, rays_d)):
    server.add_spline_catmull_rom(
        f"/rays/{i}", positions=np.stack((o, o + d * 6.0)),
    )
# All sample points as one black point cloud.
server.add_point_cloud(
    f"/samples",
    colors=np.zeros_like(points).reshape(-1, 3),
    points=points.reshape(-1, 3),
    point_size=0.02,
)
# Keeps the cell (and with it the server) alive for 1000 seconds.
# NOTE(review): this blocks "Run All" for that long; interrupt the kernel to continue.
time.sleep(1000)

# 2.4 Neural Radiance Field


In [None]:
import torch.nn.functional as F


def SinusoidalPositionalEncoding(x, L=4):
    """Functional positional encoding: x followed by sin/cos of 2^i * x, i < L.

    The last dimension grows from D to D * (2 * L + 1).

    NOTE: this function rebinds the name of the ``nn.Module`` class defined in
    Part 1; once this cell has run, ``NeuralField`` can no longer be
    constructed without re-running the Part 1 definitions.
    """
    bands = [fn((2.0 ** i) * x) for i in range(L) for fn in (torch.sin, torch.cos)]
    return torch.cat([x, *bands], dim=-1)


# Define the neural network model as per the architecture provided in the image
class NeRFMLP(nn.Module):
    """NeRF network mapping sample positions and ray directions to (rgb, density).

    Positions pass through two 4-layer stacks of width 256, with the encoded
    position re-injected by concatenation between them. Density is read from
    the shared feature through a ReLU; color additionally receives the encoded
    ray direction and ends in a sigmoid.
    """

    def __init__(self, L=4):
        super(NeRFMLP, self).__init__()

        # Remembered so that forward() encodes with the same number of
        # frequency bands the layer widths below were sized for.
        self.L = L

        # 3 raw coordinates + (sin, cos) per coordinate and frequency band
        pe_output_dim_x = 3 + L*6
        pe_output_dim_rd = 3 + L*6

        # First part of the MLP for the 'x' stream
        self.mlp_first = nn.Sequential(
            nn.Linear(pe_output_dim_x, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        # Second part of the MLP for the 'x' stream; its input is the first
        # part's output concatenated with the encoded position.
        self.mlp_second = nn.Sequential(
            nn.Linear(256 + pe_output_dim_x, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256)
        )

        # Density head
        self.fc_density_1 = nn.Linear(256,1)

        # Color head
        self.fc_rgb_1 = nn.Linear(256,256)
        self.fc_rgb_2 = nn.Linear(256 + pe_output_dim_rd,128)
        self.fc_rgb_3 = nn.Linear(128,3)

    def forward(self, x, rd):
        """Predict color and density for every sample.

        ``x`` is indexed as (ray, sample, 3) and ``rd`` as (ray, 3); the
        encoded direction is repeated along the sample axis of ``x``.
        Returns ``(rgb, density)`` with last dimensions 3 and 1.
        """
        # Apply positional encoding with the band count used to size the layers
        x_start = SinusoidalPositionalEncoding(x, self.L)
        n_samples = x.shape[1]  # taken from the input instead of assuming 32
        rd = SinusoidalPositionalEncoding(rd, self.L).unsqueeze(1).expand(-1, n_samples, -1)

        # First part of the 'x' stream
        x = self.mlp_first(x_start)
        # Inject the encoded input into the middle of the MLP by concatenation
        x = torch.cat((x, x_start), dim=-1)
        # Second part of the 'x' stream
        share =  self.mlp_second(x)

        # Density: non-negative through ReLU
        density = self.fc_density_1(share)
        density = F.relu(density)

        # Color: view-dependent, squashed to (0, 1)
        rgb = self.fc_rgb_1(share)
        rgb = torch.cat((rgb,rd), dim = -1)
        rgb = self.fc_rgb_2(rgb)
        rgb = F.relu(rgb)
        rgb = self.fc_rgb_3(rgb)
        rgb = torch.sigmoid(rgb)

        return rgb, density

# 2.5 Color Rendering

In [None]:
import torch

def volrend(sigmas, rgbs, step_size):
    """
    Compute the volume rendering equation for a batch of samples along a ray.

    All inputs are indexed as (ray, sample, channel) and broadcast against
    each other. ``sigmas`` are densities, ``rgbs`` colors and ``step_size``
    the length of the interval each sample covers. Returns the rendered
    color per ray with the sample axis summed out:
    sum_i T_i * alpha_i * rgb_i, where alpha_i = 1 - exp(-sigma_i * step_i)
    and T_i = prod_{j < i} (1 - alpha_j).
    """
    alpha = 1 - torch.exp(-sigmas * step_size)

    # Nothing lies in front of the first sample, so its transmittance is 1.
    # Slicing with :1 keeps the sample axis in place, so T_0 lines up with
    # T_rest for any channel count, not only for a trailing dimension of 1.
    T_0 = torch.ones_like(alpha[:, :1, :])
    T_rest = torch.cumprod(1 - alpha, dim=1)[:, :-1, :]
    T = torch.cat((T_0, T_rest), dim=1)

    return torch.sum(T * alpha * rgbs, dim=1)


In [None]:
# Preparation for visualizing validation image 0

f_x = focal  # from the loaded data
f_y = focal  # assuming square pixels
image_width = 200
image_height = 200
o_x = image_width / 2
o_y = image_height / 2

# Intrinsic matrix K
K = np.array([[f_x, 0, o_x],
              [0, f_y, o_y],
              [0, 0, 1]], dtype=np.float32)

K = torch.from_numpy(K)


# Rows [start_pt, end_pt) of valset.uvs / valset.pixels are exactly the pixels
# of the first validation image.
start_pt = 0
end_pt = image_width * image_height
valset = RaysData(images_val, K, c2ws_val)
uvs = valset.uvs[start_pt:end_pt]
pixels_val = valset.pixels[start_pt:end_pt]
uvs = uvs + 0.5  # rays go through pixel centres
c2w = torch.from_numpy(valset.c2ws[0]).float()
rays_o_val, rays_d_val = pixel_to_ray(valset.K, c2w, uvs)
# Evaluation uses the unperturbed sample depths, like the rendering cells
# further down; random jitter is only meant for training batches.
points_val, t_val = sample_along_rays(rays_o_val, rays_d_val, perturb=False)

In [None]:
# Ground-truth validation image 0, i.e. the view whose rays were prepared above.
plt.imshow(valset.images[0])

# Training

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeRFMLP()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

images = []  # validation renders, one per validation step (kept on the CPU)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
criterion = nn.MSELoss()
losses = []  # training MSE per iteration, as Python floats
val_losses = []  # validation MSE per validation step, as Python floats
dataset = RaysData(images_train, K, c2ws_train)
H, W = images_train.shape[1:3]

num_iterations = 3000
val_every = 100

# The validation rays never change, so they are moved to the device once
# instead of at every validation step.
rays_d_val = rays_d_val.to(device)
points_val = points_val.to(device)
pixels_val = pixels_val.to(device)
t_val = t_val.to(device)

for iteration in range(1, num_iterations + 1):
    model.train()
    # 10000 rays per step, spread evenly over 100 randomly chosen images
    rays_o, rays_d, pixels = dataset.sample_rays(10000, 100)
    points, t = sample_along_rays(rays_o, rays_d, perturb=True)
    rays_o, rays_d, pixels, points, t = rays_o.to(device), rays_d.to(device), pixels.to(device), points.to(device), t.to(device)

    optimizer.zero_grad()
    # predict and render color
    rgbs, density = model(points, rays_d)
    color_pred = volrend(density, rgbs, t)

    loss = criterion(color_pred, pixels)
    loss.backward()
    optimizer.step()

    # Store a plain float: appending the tensor would keep a device tensor
    # with autograd history alive for every iteration.
    train_loss = loss.item()
    losses.append(train_loss)

    if iteration % val_every == 0:
        model.eval()
        with torch.no_grad():  # No gradient computation during evaluation
            rgb, density = model(points_val, rays_d_val)
            color_pred = volrend(density, rgb, t_val)
            val_loss = criterion(color_pred, pixels_val).item()
        val_losses.append(val_loss)
        images.append(color_pred.cpu())
        print(f'Iteration {iteration}/{num_iterations}, Training Loss: {train_loss}, Val Loss: {val_loss}')
        model.train()

In [None]:
import math
import matplotlib.pyplot as plt

def psnr(mse):
    """PSNR in dB for a peak value of 1; a zero error maps to infinity."""
    return float('inf') if mse == 0 else 10 * math.log10(1 / mse)

# x-axes are derived from the recorded histories, so the plot stays consistent
# if the number of training iterations changes.
val_every = 100  # the training loop validates every 100 iterations
iterations = range(1, len(losses) + 1)
validation_iterations = range(val_every, val_every * len(val_losses) + 1, val_every)

# float() accepts plain numbers as well as single-element tensors.
psnr_values = [psnr(float(mse)) for mse in losses]
psnr_values_val = [psnr(float(mse)) for mse in val_losses]

plt.figure(figsize=(10, 6))
plt.plot(iterations, psnr_values, label='Training PSNR')
plt.plot(validation_iterations, psnr_values_val, label='Validation PSNR', marker='o')  # markers: validation points are sparse

plt.xlabel('Iteration')
plt.ylabel('PSNR')
plt.title('PSNR During Training')
plt.legend()
plt.grid(True)
plt.show()


In [None]:
import matplotlib.pyplot as plt

def plot_every_other_image(images, interval, eval_every=100):
    """Show every ``interval``-th validation render in a 2x3 grid.

    ``images[k]`` is taken to be the render produced at iteration
    ``(k + 1) * eval_every``. Selection starts at index ``interval - 1`` so
    that the i-th panel shows iteration ``(i + 1) * interval * eval_every``,
    which is also what its title states. At most six images are shown.

    :param images: Sequence of flat render tensors, each reshaped to (200, 200, 3).
    :param interval: Keep one render out of every ``interval``.
    :param eval_every: Number of training iterations between two renders.
    """
    # images[::interval] would pick iterations 100, 600, ... while the titles
    # count 500, 1000, ...; starting at interval - 1 makes both agree.
    selected_images = images[interval - 1::interval]

    fig, axes = plt.subplots(2, 3, figsize=(15, 7))
    axes = axes.flatten()

    for i, (ax, image) in enumerate(zip(axes, selected_images)):
        ax.imshow(image.cpu().reshape(200, 200, 3))
        ax.axis('off')
        ax.set_title(f"Iteration {(i + 1) * interval * eval_every}")

    # Hide panels that received no image.
    for ax in axes[len(selected_images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show one validation render out of every 5 collected during training.
plot_every_other_image(images, 5)


In [None]:
# Render every validation camera with the trained model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

rays_data_val = TestData(images_val, K, c2ws_val)

model.eval()
model.to(device)
rendered_images = []  # one flat (pixels, 3) color tensor per validation image
with torch.no_grad():
    for rays_o, rays_d, pixels, uv in rays_data_val:
        rays_o, rays_d, pixels = rays_o.to(device), rays_d.to(device), pixels.to(device)

        # Deterministic depths for evaluation
        points_3d, t_values = sample_along_rays(rays_o, rays_d, perturb=False)
        points_3d, t_values = points_3d.to(device), t_values.to(device)

        # Query the radiance field, then composite along each ray
        rgbs, density = model(points_3d, rays_d)
        rendered_colors = volrend(density, rgbs, t_values)
        rendered_images.append(rendered_colors)

In [None]:
# Predicted validation images 6-10 (entries 5..9 of rendered_images)
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for i, ax in enumerate(axes):
    # TestData feeds the grid's row index as u, so the flat render comes out
    # transposed; swapping the first two axes restores the image orientation.
    prediction = rendered_images[i+5].cpu().numpy().reshape(200,200,3).transpose(1,0,2)
    ax.imshow(prediction)
    ax.set_title(f'Predicted Val image {(i+6)}')
    ax.axis('off')

plt.show()

# Generate GIF using test c2w

In [None]:
class GenerateData:
    """Iterator that yields the rays of a full 200x200 view for each camera pose.

    Each step returns ``(r_o, r_d, uv)``, where ``uv`` is the integer pixel
    grid the rays were generated for. The grid comes from
    ``np.mgrid[0:height, 0:width]``, so its first column is the row index and
    the second the column index (column index varying fastest).
    """

    def __init__(self, K, c2ws):
        self.K = K
        self.c2ws = c2ws
        self.current_pose_index = 0

    def __iter__(self):
        # Restart from the first pose on every new iteration.
        self.current_pose_index = 0
        return self

    def __next__(self):
        idx = self.current_pose_index
        if idx >= len(self.c2ws):
            raise StopIteration

        pose = torch.tensor(self.c2ws[idx], dtype=torch.float32)
        height, width = 200, 200

        # Pixel-centre coordinates for the whole view
        centres = np.mgrid[0:height, 0:width].reshape(2, -1).T + 0.5
        uv = torch.from_numpy(centres).float()

        r_o, r_d = pixel_to_ray(self.K, pose, uv)

        self.current_pose_index = idx + 1
        # Truncation turns the centres back into integer pixel indices.
        return r_o, r_d, uv.long()

In [None]:
# Render every test camera pose. Results go to test_rendered_images only;
# the validation renders in rendered_images are left untouched so the
# validation plot cell can still be re-run afterwards.
test_rendered_images = []
rays_data_test = GenerateData(K, c2ws_test)

model.eval()
for (rays_o, rays_d, uv) in rays_data_test:

    with torch.no_grad():

        rays_o = rays_o.to(device)
        rays_d = rays_d.to(device)

        # Deterministic depths for rendering
        points_3d, t_values = sample_along_rays(rays_o, rays_d, perturb=False)
        points_3d = points_3d.to(device)
        t_values = t_values.to(device)

        # Forward pass through the NeRF model to get the radiance field
        rgbs, density = model(points_3d, rays_d)

        # Perform volume rendering
        rendered_colors = volrend(density, rgbs, t_values)
        test_rendered_images.append(rendered_colors)

In [None]:
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

# Show the first five rendered test views
for i, ax in enumerate(axes):
    # GenerateData feeds the grid's row index as u, so the flat render comes
    # out transposed; swapping the first two axes restores the orientation.
    frame = test_rendered_images[i].cpu().numpy().reshape(200,200,3).transpose(1,0,2)
    ax.imshow(frame)
    ax.set_title(f'Predicted Test image {(i+1)}')
    ax.axis('off')

plt.show()

In [None]:
# Save as a gif
import imageio
from PIL import Image

output_gif_path = 'Rendered_Images_Animation.gif'

# Frames converted to uint8 and upscaled, in pose order
images_for_gif = []

for img in test_rendered_images:
    # Flat render -> image, with the same transpose as in the preview plots
    frame = img.cpu().numpy().reshape(200,200,3).transpose(1,0,2)
    if frame.dtype != np.uint8:
        frame = (frame * 255).astype(np.uint8)

    # Upscale to 400x400 via PIL, then back to an array for imageio
    upscaled = Image.fromarray(frame).resize((400, 400))
    images_for_gif.append(np.array(upscaled))

# Save the frames as an animated GIF
imageio.mimsave(output_gif_path, images_for_gif, fps=10)
