In [1]:
import os
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from maze_dataset.plotting import MazePlot
import random
from torch.nn import functional as F

In [2]:
# Switch the kernel's working directory to the generator package root.
# NOTE(review): presumably required so the project-local `maze` imports and
# the relative `./data/` paths used later resolve — confirm.
PROJECT_DIR = "/home/atul/diffusion-based-environment-generator/generator"
os.chdir(PROJECT_DIR)
print(f"Current working directory: {os.getcwd()}")

Current working directory: /home/atul/diffusion-based-environment-generator/generator


In [3]:
from maze.grid_world_generator import generate_multiple_grid_worlds
from maze.solvers.a_star_l1 import main

# How many grid worlds to generate (this is the total shown by the progress bar).
NUM_GRID_WORLDS = 100_000

# NOTE(review): the second positional argument (10) is presumably the grid
# side length — confirm against maze.grid_world_generator.
generate_multiple_grid_worlds(NUM_GRID_WORLDS, 10)

# `main` here is the entry point imported from maze.solvers.a_star_l1.
main()

./data/grid/


Generating Grid Worlds:   7%|▋         | 6524/100000 [00:01<00:20, 4670.84it/s]


KeyboardInterrupt: 

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset

class MazeDataset(Dataset):
    """Mazes paired with a scalar difficulty score, read from ``.npy`` files.

    Expected layout under ``root``::

        grid/maze_<idx>.npy
        path_length/path_length_<idx>.npy
        a_star_l1_results/a_star_<idx>.npy

    Only indices for which all three files exist are kept. The kept indices
    are sorted (numerically when they are decimal strings) so that the
    mapping from dataset position to sample is the same on every run;
    ``os.listdir`` alone returns entries in arbitrary order.

    Args:
        root: Directory containing the three sub-directories listed above.
    """

    def __init__(self, root):
        self.grid_dir = os.path.join(root, 'grid')
        self.path_length_dir = os.path.join(root, 'path_length')
        self.astar_dir = os.path.join(root, 'a_star_l1_results')

        found = []
        for fname in os.listdir(self.grid_dir):
            if not fname.endswith('.npy'):
                continue

            idx = fname.split('.')[0].split('_')[-1]

            # __getitem__ always loads `maze_<idx>.npy`, so any other file
            # that merely ends in `_<idx>.npy` would register an index that
            # is either a duplicate or cannot be loaded later.
            if fname != f'maze_{idx}.npy':
                continue

            path_file = os.path.join(self.path_length_dir, f'path_length_{idx}.npy')
            astar_file = os.path.join(self.astar_dir, f'a_star_{idx}.npy')
            if os.path.exists(path_file) and os.path.exists(astar_file):
                found.append(idx)

        self.indices = sorted(found, key=self._index_sort_key)

    @staticmethod
    def _index_sort_key(idx):
        """Sort decimal indices numerically, anything else lexicographically after them."""
        if idx.isdecimal():
            return (0, int(idx), idx)
        return (1, 0, idx)

    def __len__(self):
        """Number of samples that have all three files present."""
        return len(self.indices)

    def __getitem__(self, i):
        """Return ``(maze, difficulty)`` for the i-th kept index.

        ``maze`` is the stored grid as a float32 tensor (assumed [H, W, 3] —
        the training loop permutes it to channels-first). ``difficulty`` is a
        float32 scalar tensor equal to ``a_star / max(path_length, 1)``.
        """
        idx = self.indices[i]

        # Load grid (maze)
        grid_path = os.path.join(self.grid_dir, f'maze_{idx}.npy')
        maze = np.load(grid_path)
        maze = torch.tensor(maze, dtype=torch.float32)  # assuming [H, W, 3]

        # Load scalar values; `.item()` requires each file to hold one element.
        path_length = np.load(os.path.join(self.path_length_dir, f'path_length_{idx}.npy')).item()
        a_star = np.load(os.path.join(self.astar_dir, f'a_star_{idx}.npy')).item()

        # Compute difficulty; the max(..., 1) guards against division by zero.
        difficulty = a_star / (max(path_length, 1))
        difficulty = torch.tensor(difficulty, dtype=torch.float32)

        return maze, difficulty


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffusionMLP(nn.Module):
    """MLP noise predictor conditioned on a small conditioning vector.

    The input is flattened, concatenated with a 32-dimensional embedding of
    the condition, passed through a three-layer MLP, and reshaped back to the
    input's shape.

    NOTE(review): ``forward`` receives no diffusion timestep, so the network
    cannot tell noise levels apart — consider adding a timestep embedding.

    Args:
        input_dim: Number of elements in one flattened sample
            (300 for a 3 x 10 x 10 maze).
        cond_dim: Number of conditioning features per sample.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, input_dim=300, cond_dim=1, hidden_dim=256):
        super().__init__()
        # Condition embedding
        self.cond_net = nn.Sequential(
            nn.Linear(cond_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU()
        )

        # Diffusion MLP
        self.mlp = nn.Sequential(
            nn.Linear(input_dim + 32, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)  # Output is noise of same shape as input
        )

    def forward(self, x, cond):
        """Predict noise for ``x`` given ``cond``.

        Args:
            x: Tensor of shape (B, ...) whose trailing dimensions multiply to
                ``input_dim``.
            cond: Tensor with B * ``cond_dim`` elements.

        Returns:
            Tensor with exactly the same shape as ``x``.
        """
        in_shape = x.shape
        B = in_shape[0]
        x = x.reshape(B, -1)  # Flatten to (B, input_dim)
        # reshape (not view) so a non-contiguous condition tensor is accepted.
        cond_embed = self.cond_net(cond.reshape(B, -1))  # Shape (B, 32)
        x = torch.cat([x, cond_embed], dim=-1)  # Shape (B, input_dim + 32)
        # Restore the caller's shape instead of a hard-coded (B, 3, 10, 10),
        # so the module works for any `input_dim`.
        return self.mlp(x).view(in_shape)


In [None]:
def linear_beta_schedule(timesteps):
    """Return a 1-D tensor of `timesteps` values evenly spaced from 1e-4 to 0.02."""
    first_beta, last_beta = 1e-4, 0.02
    return torch.linspace(start=first_beta, end=last_beta, steps=timesteps)


In [None]:
def diffusion_loss(model, x_0, difficulty, timesteps, betas):
    """Noise-prediction MSE for one batch at randomly drawn timesteps.

    Args:
        model: Callable taking ``(x_t, cond)`` and returning a tensor shaped
            like ``x_t``.
        x_0: Clean batch; must be 4-D because the per-sample coefficient is
            broadcast as (B, 1, 1, 1).
        difficulty: Per-sample condition; a trailing axis is appended before
            it is handed to the model.
        timesteps: Exclusive upper bound for the sampled timestep indices.
        betas: 1-D noise schedule indexed by the sampled timesteps.

    Returns:
        Scalar tensor: mean squared error between predicted and true noise.
    """
    batch_size = x_0.shape[0]

    # Draw order matters for reproducibility: timestep indices, then noise.
    step_idx = torch.randint(0, timesteps, (batch_size,), device=x_0.device).long()
    eps = torch.randn_like(x_0)

    # Cumulative product of (1 - beta), gathered per sample.
    alpha_bar_all = torch.cumprod(1. - betas, dim=0)
    alpha_bar_t = alpha_bar_all[step_idx].view(-1, 1, 1, 1)

    # Forward diffusion: blend the clean sample with noise.
    signal_scale = torch.sqrt(alpha_bar_t)
    noise_scale = torch.sqrt(1 - alpha_bar_t)
    x_t = signal_scale * x_0 + noise_scale * eps

    eps_hat = model(x_t, difficulty.unsqueeze(-1))
    return F.mse_loss(eps_hat, eps)


In [None]:
from torch.utils.data import DataLoader

def train(model, dataloader, epochs=100, lr=1e-4, timesteps=1000, device=None):
    """Train the noise-prediction model with Adam on the diffusion loss.

    Args:
        model: Module called as ``model(x_t, cond)``; moved to ``device``.
        dataloader: Yields ``(maze, difficulty)`` batches with ``maze`` laid
            out channels-last ([B, H, W, C]); it is permuted to [B, C, H, W].
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        timesteps: Length of the linear beta schedule.
        device: Device to train on. ``None`` (default) selects CUDA when it
            is available and falls back to CPU otherwise.

    Raises:
        ValueError: If ``dataloader`` yields no batches.

    Prints one line per epoch with the mean loss over that epoch's batches.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    model = model.to(device)
    model.train()  # sampling switches the model to eval mode; undo that here
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    betas = linear_beta_schedule(timesteps).to(device)

    for epoch in range(epochs):
        # Accumulate on-device so the loop does not synchronise every batch.
        loss_total = torch.zeros((), device=device)
        num_batches = 0

        for maze, difficulty in dataloader:
            maze = maze.permute(0, 3, 1, 2).to(device)  # [B, C, H, W]
            difficulty = difficulty.to(device)

            loss = diffusion_loss(model, maze, difficulty, timesteps, betas)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_total += loss.detach()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader produced no batches; nothing to train on")

        epoch_loss = loss_total.item() / num_batches
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}")


In [None]:
# Build the dataset and a shuffled loader, then train a denoiser that is
# conditioned on the single difficulty scalar.
DATA_ROOT = "/home/atul/diffusion-based-environment-generator/generator/data/"
BATCH_SIZE = 64

maze_dataset = MazeDataset(DATA_ROOT)
diffusion_dataloader = torch.utils.data.DataLoader(
    maze_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
model = DiffusionMLP(cond_dim=1)
train(model, diffusion_dataloader)


In [None]:
def convert_generated_maze(recon_batch):
    """Discretise raw model output into a maze grid.

    Args:
        recon_batch: Tensor of shape [B, H, W, C] with at least three
            channels in the last axis (wall, source, destination).

    Returns:
        Int tensor of shape [B, H, W, 3] where
        - channel 0 is 1 wherever the input's channel 0 is >= 0.5,
        - channel 1 has exactly one 1 per sample, at the argmax of the
          input's channel 1,
        - channel 2 has exactly one 1 per sample, at the argmax of the
          input's channel 2.
    """
    wall = (recon_batch[..., 0] >= 0.5).int()
    source_onehot = _argmax_onehot(recon_batch[..., 1])
    dest_onehot = _argmax_onehot(recon_batch[..., 2])

    return torch.stack([wall, source_onehot, dest_onehot], dim=-1).int()  # [B, H, W, 3]


def _argmax_onehot(channel):
    """Return an int map shaped like `channel` ([B, H, W]) with a single 1 per sample at its maximum."""
    batch = channel.shape[0]
    flat = channel.reshape(batch, -1)
    onehot = torch.zeros_like(flat, dtype=torch.int)
    # scatter_ keeps the index tensor on the same device as the data.
    onehot.scatter_(1, flat.argmax(dim=1, keepdim=True), 1)
    return onehot.view(channel.shape)


In [None]:
@torch.no_grad()
def sample(model, difficulty, shape=(3, 64, 64), timesteps=1000):
    """Generate one maze by running the reverse diffusion process.

    Args:
        model: Trained noise predictor called as ``model(x, cond)``. It is
            switched to eval mode, and the device of its first parameter is
            used for all tensors.
        difficulty: Scalar condition (int or float).
        shape: Channels-first shape ``(C, H, W)`` of one sample.
            NOTE(review): the default (3, 64, 64) does not match the
            300-element input that ``DiffusionMLP`` defaults to; callers in
            this notebook pass ``shape=(3, 10, 10)``.
        timesteps: Number of reverse steps; should match the schedule length
            used in training.

    Returns:
        The discretised maze from ``convert_generated_maze`` with the batch
        axis removed, i.e. an int tensor laid out [H, W, 3].
    """
    model.eval()
    device = next(model.parameters()).device

    betas = linear_beta_schedule(timesteps).to(device)
    alpha = 1. - betas
    alpha_cumprod = torch.cumprod(alpha, dim=0)
    alpha_cumprod_prev = torch.cat([torch.tensor([1.], device=device), alpha_cumprod[:-1]])

    # Per-step coefficients of the reverse update, computed once up front.
    sqrt_recip_alpha = torch.sqrt(1.0 / alpha)
    noise_coef = betas / torch.sqrt(1 - alpha_cumprod)
    posterior_variance = betas * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)
    sigma = torch.sqrt(posterior_variance)

    x = torch.randn((1, *shape), device=device)
    # Force float32: without it an integer difficulty produces an int64
    # tensor, which the model's Linear layers reject.
    difficulty = torch.tensor([difficulty], dtype=torch.float32, device=device).unsqueeze(0)

    for t in reversed(range(timesteps)):
        noise_pred = model(x, difficulty)

        # No fresh noise is injected on the final step.
        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = sqrt_recip_alpha[t] * (x - noise_coef[t] * noise_pred) + sigma[t] * noise

    # Clamp to [0, 1] and move channels last, as convert_generated_maze expects.
    x = x.clamp(0, 1).permute(0, 2, 3, 1)
    return convert_generated_maze(x).squeeze(0)


In [None]:
def visualize_grid_world(grid):
    """
    Render a 3-channel [H, W, 3] grid world as a float32 RGB image.

    Colours are painted in this order, later ones overriding earlier ones:
    - white background everywhere
    - black where the first channel is below 0.5 (treated here as walls)
    - green where the third channel (destination) is >= 0.5
    - blue where the second channel (source) is >= 0.5

    NOTE(review): convert_generated_maze sets the first channel to 1 where
    its input is >= 0.5 and calls that "wall", whereas here values below 0.5
    are drawn as walls — confirm which polarity the data uses.
    """
    black = np.array([0, 0, 0])
    green = np.array([0, 1, 0])
    blue = np.array([0, 0, 1])

    first_channel_low = grid[:, :, 0] < 0.5
    is_source = grid[:, :, 1] >= 0.5
    is_destination = grid[:, :, 2] >= 0.5

    # Start from an all-white canvas.
    canvas = np.ones((*first_channel_low.shape, 3), dtype=np.float32)

    # Paint order is significant: source wins over destination, which wins
    # over the black cells.
    for mask, colour in (
        (first_channel_low, black),
        (is_destination, green),
        (is_source, blue),
    ):
        canvas[mask] = colour

    return canvas

In [None]:
def visualize_sample(sample, difficulty):
    """Plot one generated maze with its target difficulty in the title.

    Args:
        sample: Grid accepted by ``visualize_grid_world`` (indexed as
            ``[row, col, channel]`` with three channels).
        difficulty: Number shown in the title, formatted to two decimals.
    """
    img = visualize_grid_world(sample)

    # Display
    plt.imshow(img)
    plt.title(f"Sampled Maze — Difficulty {difficulty:.2f}")
    plt.axis('off')
    plt.show()

In [None]:
# Draw one maze for the requested difficulty and plot it.
difficulty = 1.2

# The model works on channels-first (3, 10, 10) samples; `sample` returns the
# discretised maze channels-last, i.e. shaped (10, 10, 3).
generated = sample(model, difficulty=difficulty, shape=(3, 10, 10))
visualize_sample(generated.cpu().numpy(), difficulty=difficulty)
