# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.widgets import Slider  # Add this import
import numpy as np
from diffusion_models3d import Transformer, AbsorbingDiffusion, Block, CausalSelfAttention
from sampler_utils import retrieve_autoencoder_components_state_dicts, latent_ids_to_onehot3d, get_latent_loaders
from models3d import VQAutoEncoder, Generator
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed
# from visualization_utils import MinecraftVisualizer
from data_utils import MinecraftVisualizer, get_minecraft_dataloaders, BlockConverter, MinecraftDataset
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import torch.distributions as dists
from tqdm import tqdm
import gc



In [2]:
%matplotlib widget

# Minecraft Chunk Visualizer

In [3]:
# Colour table used by visualize_chunk: maps a block ID to the matplotlib colour
# its voxels are drawn with. Values are colour names, hex strings, or RGB / RGBA
# tuples (an alpha below 1 draws the block semi-transparent).
# Block ID 5 is deliberately absent: visualize_chunk never draws it (the dataset
# code in this notebook treats 5 as the air block). Any other ID missing from
# this table is drawn by visualize_chunk in red with black edges.
blocks_to_cols = {
            0: (0.5, 0.25, 0.0),    # light brown
            10: 'black', # bedrock
            29: "#006400", # cactus
            38: "#B8860B",  # clay
            60: "brown",  # dirt
            92: "gold",  # gold ore
            93: "green",  # grass
            115: "brown",  # ladder? (ID not confirmed)
            119: (.02, .28, .16, 0.8),  # leaves: semi-transparent forest green (RGBA)
            120: (.02, .28, .16, 0.8),  # leaves2 (same colour as leaves)
            194: "yellow",  # sand
            217: "gray",  # stone
            240: (0.0, 0.0, 1.0, 0.4),  # water (semi-transparent blue)
            227: (0.0, 1.0, 0.0, .3), # tall grass (semi-transparent)
            237: (0.33, 0.7, 0.33, 0.3), # vine (semi-transparent)
            40: "#2F4F4F",  # coal ore
            62: "#228B22",  # double plant
            108: "#BEBEBE",  # iron ore
            131: "saddlebrown",  # log1
            132: "saddlebrown",  # log2
            95: "lightgray",  # gravel
            243: "wheat",  # wheat
            197: "limegreen",  # sapling
            166: "orange",  # pumpkin
            167: "#FF8C00",  # pumpkin stem
            184: "#FFA07A",  # red flower
            195: "tan",  # sandstone
            250: "white",  # wool
            251: "gold",   # yellow flower
        }


def draw_latent_cuboid(fig, latent_coords, size=4, latent_grid_size=6):
    """
    Draw a translucent red cuboid around the specified latent coordinates.

    One axis-aligned box is drawn that covers the bounding box of all given
    latent cells, scaled by ``size`` into the plotted voxel space.

    Args:
        fig: matplotlib figure to draw on; the box is added to ``fig.gca()``.
        latent_coords: list of tuples, each containing (d,h,w) coordinates.
        size: size of each latent cell in final space (default 4 for 6->24 upscaling)
        latent_grid_size: number of latent cells along the ``d`` axis. The box is
            placed on a mirrored ``d`` axis, i.e. it starts at
            ``(latent_grid_size - 1 - d_max) * size``. The default of 6 reproduces
            the previously hard-coded ``5 - d_max``.

    Returns:
        The same ``fig`` that was passed in. When ``latent_coords`` is empty
        nothing is drawn and ``fig`` is returned untouched.
    """
    def cuboid_data(o, sizes):
        # Build four 5-point closed loops (bottom, top and two side faces) for a
        # box with corner ``o`` and extents ``sizes`` = (x, y, z lengths).
        l, w, h = sizes
        x = [[o[0], o[0] + l, o[0] + l, o[0], o[0]],
             [o[0], o[0] + l, o[0] + l, o[0], o[0]],
             [o[0], o[0] + l, o[0] + l, o[0], o[0]],
             [o[0], o[0] + l, o[0] + l, o[0], o[0]]]
        y = [[o[1], o[1], o[1] + w, o[1] + w, o[1]],
             [o[1], o[1], o[1] + w, o[1] + w, o[1]],
             [o[1], o[1], o[1], o[1], o[1]],
             [o[1] + w, o[1] + w, o[1] + w, o[1] + w, o[1] + w]]
        z = [[o[2], o[2], o[2], o[2], o[2]],
             [o[2] + h, o[2] + h, o[2] + h, o[2] + h, o[2] + h],
             [o[2], o[2], o[2] + h, o[2] + h, o[2]],
             [o[2], o[2], o[2] + h, o[2] + h, o[2]]]
        return np.array(x), np.array(y), np.array(z)

    # Convert coordinates to numpy array for easier manipulation
    coords = np.array(latent_coords)

    # Nothing to highlight: min()/max() on an empty array would raise, so bail out.
    if coords.size == 0:
        return fig

    ax = fig.gca()

    # Find min and max for each dimension
    d_min, h_min, w_min = coords.min(axis=0)
    d_max, h_max, w_max = coords.max(axis=0)

    # Plot axes are (x, y, z) = (mirrored d, w, h).
    origin = np.array([
        abs(latent_grid_size - 1 - d_max) * size,
        w_min * size,
        h_min * size,
    ])
    sizes = (
        abs(d_max - d_min + 1) * size,  # length (x)
        (w_max - w_min + 1) * size,     # width  (y)
        (h_max - h_min + 1) * size      # height (z)
    )

    # Create and draw single cuboid
    X, Y, Z = cuboid_data(origin, sizes)
    ax.plot_surface(X, Y, Z, color='red', alpha=0.1)

    # Outline each of the four loops returned by cuboid_data
    for i in range(4):
        ax.plot(X[i], Y[i], Z[i], color='red', linewidth=1)
    # Connect the bottom loop (row 0) to the top loop (row 1) at the four corners
    for i in range(4):
        ax.plot([X[0][i], X[1][i]], [Y[0][i], Y[1][i]], [Z[0][i], Z[1][i]],
               color='red', linewidth=2)

    return fig

def visualize_chunk(voxels, figsize=(10, 10), elev=20, azim=45, highlight_latents=None):
    """
    Render a 3D voxel plot of a Minecraft chunk and return the new figure.

    Args:
        voxels: numpy array of block IDs, or a torch tensor. A 4-D tensor is
            treated as one-hot with the class axis first and reduced with
            ``argmax`` over that axis; any other tensor is converted to numpy
            unchanged.
            NOTE(review): ``argmax`` yields channel indices, which are then
            looked up in ``blocks_to_cols`` as block IDs - confirm callers
            convert indices back to block IDs where the two differ.
        figsize: size of the created matplotlib figure.
        elev, azim: initial camera elevation / azimuth passed to ``view_init``.
        highlight_latents: optional list of latent (d,h,w) coordinates; when not
            None it is forwarded to ``draw_latent_cuboid``.

    Returns:
        The matplotlib figure containing the 3D axes.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Bring tensors to a plain numpy array of per-voxel values
    if isinstance(voxels, torch.Tensor):
        voxels = voxels.detach().cpu()
        if voxels.dim() == 4:  # One-hot encoded [C,H,W,D]
            voxels = torch.argmax(voxels, dim=0)
        voxels = voxels.numpy()

    # Re-orient the volume for plotting
    voxels = voxels.transpose(2, 0, 1)  # Moves axes from [D,H,W] to [W,D,H]
    voxels = np.rot90(voxels, 1, (0, 1))  # Rotate 90 degrees around height axis

    # Create figure and 3D axis
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    # Draw every block type that has a colour, remembering which voxels were
    # covered. Accumulating the mask here (instead of stacking all masks
    # afterwards) also works when the chunk contains no known block type at all.
    known_vox = np.zeros(voxels.shape, dtype=bool)
    for block_id in np.unique(voxels):
        if block_id not in blocks_to_cols:
            continue
        mask = voxels == block_id
        ax.voxels(mask, facecolors=blocks_to_cols[int(block_id)])
        known_vox |= mask

    # Plot remaining blocks in red with black edges; 5 and -1 are never drawn
    other_vox = (voxels != 5) & (voxels != -1) & ~known_vox
    ax.voxels(other_vox, edgecolor="k", facecolor="red")

    # Set default view angle
    ax.view_init(elev=elev, azim=azim)

    if highlight_latents is not None:
        fig = draw_latent_cuboid(fig, highlight_latents)

    return fig



## Block Converter: Converts between block IDs and indices, needed for visualization

In [4]:
# Load the saved block-ID <-> index mappings from 'block_mappings.pt'. The resulting
# converter is passed to get_minecraft_dataloader when the dataset is built.
block_converter = BlockConverter.load_mappings('block_mappings.pt')

# Minecraft Chunks Dataset

In [5]:
from torch.utils.data import DataLoader, random_split, Dataset

class MinecraftDataset(Dataset):
    """One-hot encoded Minecraft chunks loaded from a ``.npy`` file.

    NOTE: this definition shadows the ``MinecraftDataset`` imported from
    ``data_utils`` in the notebook's import cell.

    The whole dataset is converted to a float one-hot tensor at construction
    time, so memory use grows with ``N * num_block_types * 24**3``.

    Attributes set by ``__init__``:
        converter: the block converter passed in.
        num_block_types: ``len(converter.block_to_index)``.
        air_idx: index that block ID 5 maps to; used as the padding value.
        target_size: spatial size every chunk is padded up to (24).
        processed_chunks: float tensor of shape [N, C, H, W, D].
    """

    def __init__(self, data_path, converter):
        """Load, remap, pad and one-hot encode the chunks.

        Args:
            data_path: path of a ``.npy`` array of block IDs, shape [N, H, W, D].
            converter: object exposing ``block_to_index`` (dict: block ID -> index)
                and ``convert_to_original_blocks``.
        """
        # Load data and convert to int16 to save memory
        raw_blocks = torch.from_numpy(np.load(data_path)).to(torch.int16)

        # Load pre-saved mappings
        self.converter = converter
        self.num_block_types = len(self.converter.block_to_index)

        # Convert blocks to indices once at initialization. The masks are taken
        # from the untouched raw IDs: remapping in place could otherwise remap a
        # voxel twice whenever a freshly written index equals a block ID that is
        # processed later.
        chunks = raw_blocks.clone()
        for old_block, new_idx in self.converter.block_to_index.items():
            chunks[raw_blocks == old_block] = new_idx
        del raw_blocks

        # Store air block index
        self.air_idx = self.converter.block_to_index[5]
        self.target_size = 24

        # Pad each spatial dimension up to target_size with air, if needed.
        # F.pad expects (before, after) pairs starting from the LAST dimension.
        padding = []
        for dim in (-1, -2, -3):
            padding += [0, max(self.target_size - chunks.size(dim), 0)]
        if any(padding):
            chunks = F.pad(chunks, tuple(padding), value=self.air_idx)

        # Convert to one-hot [N, C, H, W, D]
        self.processed_chunks = F.one_hot(
            chunks.long(),
            num_classes=self.num_block_types
        ).permute(0, 4, 1, 2, 3).float()

        # Free up memory by dropping the index volume
        del chunks

        print(f"Loaded {len(self.processed_chunks)} chunks of size {self.processed_chunks.shape[1:]}")
        print(f"Number of unique block types: {self.num_block_types}")

    def __getitem__(self, idx):
        """Return the one-hot chunk at ``idx`` (shape [C, H, W, D])."""
        return self.processed_chunks[idx]

    def convert_to_original_blocks(self, data):
        """Convert from indices back to original block IDs"""
        return self.converter.convert_to_original_blocks(data)

    def __len__(self):
        """Number of chunks in the dataset."""
        return len(self.processed_chunks)



def get_minecraft_dataloader(data_path, converter, batch_size=32, num_workers=0):
    """
    Build one un-shuffled DataLoader over the complete Minecraft chunks dataset.

    Args:
        data_path: path handed to ``MinecraftDataset``.
        converter: block converter handed to ``MinecraftDataset``.
        batch_size: number of chunks per batch.
        num_workers: worker process count for the DataLoader.

    Returns:
        The DataLoader. A short summary of it is printed as a side effect.
    """
    full_dataset = MinecraftDataset(data_path, converter)

    # Sequential order is enough for exploration; pinned memory is requested.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    loader = DataLoader(full_dataset, **loader_options)

    summary_lines = (
        f"\nDataloader details:",
        f"Total samples: {len(full_dataset)}",
        f"Batch size: {batch_size}",
        f"Number of batches: {len(loader)}",
    )
    for summary_line in summary_lines:
        print(summary_line)

    return loader

# Load VQGAN Model

In [6]:
import os
from log_utils import log, load_stats, load_model
import copy
from hyperparams import HparamsVQGAN

def load_hparams_from_json(log_dir):
    """Read ``hparams.json`` from a saved-model directory.

    Args:
        log_dir: directory expected to contain ``hparams.json``.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: if ``hparams.json`` does not exist in ``log_dir``.
    """
    import json
    import os

    hparams_file = os.path.join(log_dir, 'hparams.json')

    if os.path.exists(hparams_file):
        with open(hparams_file, 'r') as handle:
            return json.load(handle)

    raise FileNotFoundError(f"No hparams.json file found in {log_dir}")

def dict_to_vcqgan_hparams(hparams_dict, dataset=None):
    """Turn a loaded hparams dictionary into a proper ``HparamsVQGAN`` object.

    Args:
        hparams_dict: dictionary of hyperparameters (e.g. parsed hparams.json).
        dataset: dataset name used to construct ``HparamsVQGAN``. When None, the
            dictionary's 'dataset' entry is used, falling back to 'MNIST'.

    Returns:
        An ``HparamsVQGAN`` instance with every key of ``hparams_dict`` set as
        an attribute. This includes a 'dataset' key if present, which overwrites
        whatever the constructor stored under that attribute.
    """
    # Identity comparison is the correct test for None.
    if dataset is None:
        dataset = hparams_dict.get('dataset', 'MNIST')  # Default to MNIST if not specified

    vq_hyper = HparamsVQGAN(dataset)
    # Set attributes from the dictionary
    for key, value in hparams_dict.items():
        setattr(vq_hyper, key, value)

    return vq_hyper


def load_vqgan_from_checkpoint(H, vqgan):
    """Restore VQGAN weights, move the model to CUDA and put it in eval mode.

    Args:
        H: hyperparameter object providing ``load_step`` and ``load_dir``.
        vqgan: model instance passed to ``load_model`` under the name "vqgan".

    Returns:
        The model returned by ``load_model``, on the CUDA device, in eval mode.
    """
    restored = load_model(vqgan, "vqgan", H.load_step, H.load_dir)
    restored = restored.cuda()
    restored.eval()
    return restored


def encode_and_quantize(vqgan, terrain_chunks):
    """Encode a batch of chunks and quantise the encodings with the VQGAN.

    Runs in eval mode and without gradient tracking, and prints the
    intermediate tensor sizes.

    Args:
        vqgan: model exposing ``ae.encoder`` and ``ae.quantize``; ``quantize``
            must return a 3-tuple whose last element is a dict containing
            "min_encoding_indices".
        terrain_chunks: input batch passed to the encoder.

    Returns:
        (quantized, latent_indices) where ``latent_indices`` is reshaped to
        ``(batch, *spatial)``, with the spatial sizes taken from the encoder
        output (all dimensions after batch and channel).
    """
    vqgan.eval()
    with torch.no_grad():
        encoded = vqgan.ae.encoder(terrain_chunks)
        quantized, _, quant_stats = vqgan.ae.quantize(encoded)
        print(f'zq shape: {quantized.size()}')
        latent_indices = quant_stats["min_encoding_indices"]
        print(f'latent_indices size: {latent_indices.size()}')
        # One index per latent position. Use every spatial dimension of the
        # encoding so that both 2D ([B,C,H,W]) and 3D ([B,C,D,H,W]) encoders work.
        latent_indices = latent_indices.view(encoded.size(0), *encoded.shape[2:])
        print(f'latent_indices viewed size: {latent_indices.size()}')

    return quantized, latent_indices

In [7]:
# Directory of the saved VQGAN run; its hparams.json is read when the model is rebuilt.
model_path = 'saved_models/minecraft39ch_ce_3'

In [8]:
# Rebuild the hyperparameter object from the saved hparams.json (dataset forced to 'maps'),
# construct the autoencoder from it, then restore the checkpoint weights. The loader
# helper also moves the model to CUDA and switches it to eval mode.
vqgan_hparams =  dict_to_vcqgan_hparams(load_hparams_from_json(f"{model_path}"), 'maps')
vqgan = VQAutoEncoder(vqgan_hparams)
vqgan = load_vqgan_from_checkpoint(vqgan_hparams, vqgan)
print(f'loaded from: {vqgan_hparams.log_dir}')

resolution: 24, num_resolutions: 3, num_res_blocks: 2, attn_resolutions: [6], in_channels: 256, out_channels: 39, block_in_ch: 256, curr_res: 6
Loading vqgan_95000.th
loaded from: minecraft39ch_ce_3


In [9]:
# This takes a while: the dataset constructor loads the whole .npy file and
# one-hot encodes every chunk up front.
# Despite its name, train_loader iterates the complete dataset in order (no shuffling, no split).

train_loader= get_minecraft_dataloader(
        '../datasets/minecraft_chunks.npy',
        block_converter,
        batch_size=vqgan_hparams.batch_size,
        num_workers=0,
    )


Loaded 10205 chunks of size torch.Size([39, 24, 24, 24])
Number of unique block types: 39

Dataloader details:
Total samples: 10205
Batch size: 8
Number of batches: 1276


In [18]:
# Run every batch of train_loader through the VQGAN and collect the resulting latents
# (presumably the quantised codebook indices per chunk - confirm in sampler_utils).
from sampler_utils import generate_latents_from_loader3d
latents = generate_latents_from_loader3d(vqgan_hparams, vqgan, train_loader)


100%|██████████| 1276/1276 [00:29<00:00, 42.74it/s]


In [19]:
# Sanity check: the recorded output shows one row per chunk with 216 entries each
# (216 = 6*6*6, consistent with the `curr_res: 6` printed when the model was built).
latents.shape

torch.Size([10205, 216])