# Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from sampler_utils import retrieve_autoencoder_components_state_dicts, latent_ids_to_onehot3d, get_latent_loaders
from models3d import VQAutoEncoder, Generator
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import torch.distributions as dists
from tqdm import tqdm
import gc
from models3d import BiomeClassifier

In [3]:
%matplotlib widget
#matplotlib.use('TkAgg')

# Block Converter

In [4]:
class BlockBiomeConverter:
    """Translate between dense class indices and original block IDs / biome names.

    ``from_arrays`` assigns indices in ``np.unique`` (sorted) order; block IDs are
    stored as Python ints and biome names as Python strs.
    """

    def __init__(self, block_mappings=None, biome_mappings=None):
        """
        Initialize with pre-computed mappings for both blocks and biomes.

        Args:
            block_mappings: dict containing 'index_to_block' and 'block_to_index',
                or None to leave both block attributes set to None.
            biome_mappings: dict containing 'index_to_biome' and 'biome_to_index',
                or None to leave both biome attributes set to None.
        """
        self.index_to_block = block_mappings['index_to_block'] if block_mappings else None
        self.block_to_index = block_mappings['block_to_index'] if block_mappings else None
        self.index_to_biome = biome_mappings['index_to_biome'] if biome_mappings else None
        self.biome_to_index = biome_mappings['biome_to_index'] if biome_mappings else None

    @staticmethod
    def _index_maps(values, cast):
        """Return ``(value_to_index, index_to_value)`` over the sorted unique *values*."""
        unique_values = np.unique(values)
        value_to_index = {cast(value): idx for idx, value in enumerate(unique_values)}
        index_to_value = {idx: cast(value) for idx, value in enumerate(unique_values)}
        return value_to_index, index_to_value

    @classmethod
    def from_dataset(cls, data_path):
        """Create mappings from a file whose ``np.load`` result has 'voxels' and 'biomes' entries."""
        data = np.load(data_path, allow_pickle=True)
        return cls.from_arrays(data['voxels'], data['biomes'])

    @classmethod
    def from_arrays(cls, voxels, biomes):
        """Create mappings directly from numpy arrays (blocks cast to int, biomes to str)."""
        block_to_index, index_to_block = cls._index_maps(voxels, int)
        biome_to_index, index_to_biome = cls._index_maps(biomes, str)

        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}

        return cls(block_mappings, biome_mappings)

    @classmethod
    def load_mappings(cls, path):
        """Load mappings previously written by ``save_mappings``."""
        # NOTE: torch.load unpickles the file; only load mapping files you trust.
        mappings = torch.load(path)
        return cls(mappings['block_mappings'], mappings['biome_mappings'])

    def save_mappings(self, path):
        """Save mappings with ``torch.save`` so ``load_mappings`` can restore them."""
        torch.save({
            'block_mappings': {
                'index_to_block': self.index_to_block,
                'block_to_index': self.block_to_index
            },
            'biome_mappings': {
                'index_to_biome': self.index_to_biome,
                'biome_to_index': self.biome_to_index
            }
        }, path)

    @staticmethod
    def _to_indices(data, num_classes):
        """Collapse one-hot input to class indices; already-indexed input passes through.

        5-D input is treated as one-hot [B, C, H, W, D]; 4-D input whose first
        dimension equals *num_classes* is treated as one-hot [C, H, W, D]. An
        indexed batch [B, H, W, D] with B == num_classes is therefore ambiguous
        and will be argmax-ed as if it were one-hot.
        """
        if len(data.shape) == 5 or (len(data.shape) == 4 and data.shape[0] == num_classes):
            data = torch.argmax(data, dim=1 if len(data.shape) == 5 else 0)
        return data

    @staticmethod
    def _unique_with_inverse(indices):
        """Return ``(unique_indices, inverse)`` for a CPU int64 copy of *indices*.

        ``inverse`` has the same shape as *indices*, so ``table[inverse]``
        scatters one looked-up value per distinct index back to every position.
        """
        idx = torch.as_tensor(indices).detach().cpu().long()
        return torch.unique(idx, return_inverse=True)

    def convert_to_original_blocks(self, data):
        """
        Convert from indices back to original block IDs.
        Handles both one-hot encoded and already-indexed data.

        Args:
            data: torch.Tensor of either:
                - one-hot encoded blocks [B, C, H, W, D] or [C, H, W, D]
                - indexed blocks [B, H, W, D] or [H, W, D]
        Returns:
            CPU int64 torch.Tensor of original block IDs with shape [B, H, W, D] or [H, W, D]
        Raises:
            KeyError: if an index has no entry in ``index_to_block``.
        """
        indices = self._to_indices(data, len(self.block_to_index))
        unique_idx, inverse = self._unique_with_inverse(indices)
        # One dict lookup per distinct index instead of one per voxel.
        block_ids = torch.tensor(
            [self.index_to_block[i] for i in unique_idx.tolist()], dtype=torch.long
        )
        return block_ids[inverse]

    def convert_to_original_biomes(self, data):
        """
        Convert from indices back to original biome strings.
        Handles both one-hot encoded and already-indexed data.

        Args:
            data: torch.Tensor of either:
                - one-hot encoded biomes [B, C, H, W, D] or [C, H, W, D]
                - indexed biomes [B, H, W, D] or [H, W, D]
        Returns:
            numpy array of original biome strings with shape [B, H, W, D] or [H, W, D]
        Raises:
            KeyError: if an index has no entry in ``index_to_biome``.
        """
        indices = self._to_indices(data, len(self.biome_to_index))
        unique_idx, inverse = self._unique_with_inverse(indices)
        # One dict lookup per distinct index instead of one per voxel.
        biome_names = np.array([self.index_to_biome[i] for i in unique_idx.tolist()])
        return biome_names[inverse.numpy()]

# Minecraft Chunks Dataset

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MinecraftDataset(Dataset):
    """Map-style dataset over pre-processed chunks loaded with ``torch.load``.

    The file at *data_path* must hold a dict with a 'chunks' tensor; every other
    entry (e.g. biome data) is dropped after loading. Items are single chunks,
    indexed along the first dimension.
    """

    def __init__(self, data_path):
        print("Loading pre-processed data...")
        payload = torch.load(Path(data_path))
        # Keep only the chunk tensor; releasing the dict lets the remaining
        # entries be garbage-collected.
        self.processed_chunks = payload['chunks']
        del payload

        chunks = self.processed_chunks
        print(f"Loaded {len(chunks)} chunks of size {chunks.shape[1:]}")
        # shape[1] is the one-hot channel count (argmax below runs over dim 1)
        print(f"Number of unique block types: {chunks.shape[1]}")
        present_blocks = torch.unique(torch.argmax(chunks, dim=1)).tolist()
        print(f'Unique blocks: {present_blocks}')

    def __getitem__(self, idx):
        # Chunk only; no biome label is returned
        return self.processed_chunks[idx]

    def __len__(self):
        return len(self.processed_chunks)

def get_minecraft_dataloaders(data_path, batch_size=32, val_split=0.1, num_workers=4):
    """
    Build train/validation DataLoaders over a ``MinecraftDataset``.

    The split uses ``random_split`` with a generator seeded to 42, so it is the
    same on every call. The validation set gets ``int(val_split * len(dataset))``
    samples and training gets the rest. Training batches are shuffled and
    validation batches are not; both loaders use pinned memory.

    Returns:
        (train_loader, val_loader)
    """
    dataset = MinecraftDataset(data_path)

    n_val = int(val_split * len(dataset))
    n_train = len(dataset) - n_val
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(42),  # fixed seed -> reproducible split
    )

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    summary_lines = (
        "\nDataloader details:",
        f"Training samples: {len(train_dataset)}",
        f"Validation samples: {len(val_dataset)}",
        f"Batch size: {batch_size}",
        f"Training batches: {len(train_loader)}",
        f"Validation batches: {len(val_loader)}",
    )
    for line in summary_lines:
        print(line)

    return train_loader, val_loader


# Helper Functions
Loads hyperparameters and models from checkpoints, and provides helper functions to encode a structure into latent codes and to decode those latents back into a reconstruction.

## VQGAN Helper Functions

In [15]:
import os
from log_utils import log, load_stats, load_model
import copy
from hyperparams import HparamsVQGAN

# Reads the hparams.json that was saved into a VQGAN run's directory
def vq_load_hparams_from_json(log_dir):
    """Return the parsed contents of ``<log_dir>/hparams.json``.

    Raises:
        FileNotFoundError: if ``hparams.json`` does not exist in *log_dir*.
    """
    import json
    import os

    hparams_path = os.path.join(log_dir, 'hparams.json')
    if os.path.exists(hparams_path):
        with open(hparams_path, 'r') as handle:
            return json.load(handle)
    raise FileNotFoundError(f"No hparams.json file found in {log_dir}")

# Turns a loaded hparams dict into a proper HparamsVQGAN object
def vq_dict_to_vcqgan_hparams(hparams_dict, dataset=None):
    """Build an ``HparamsVQGAN`` and overwrite its attributes from *hparams_dict*.

    Args:
        hparams_dict: mapping of attribute name -> value (e.g. parsed hparams.json).
        dataset: dataset name passed to ``HparamsVQGAN``. When None, it is taken
            from ``hparams_dict['dataset']``, falling back to 'MNIST'.

    Returns:
        The populated ``HparamsVQGAN`` instance. Every key of *hparams_dict*
        (including any 'dataset' key) is applied with ``setattr`` after
        construction.
    """
    if dataset is None:
        dataset = hparams_dict.get('dataset', 'MNIST')  # Default to MNIST if not specified

    vq_hyper = HparamsVQGAN(dataset)
    for key, value in hparams_dict.items():
        setattr(vq_hyper, key, value)

    return vq_hyper

def load_model_vq(model, model_load_name, step, log_dir, strict=False):
    """Load 'ae.'-prefixed weights from a saved checkpoint into *model*.

    The checkpoint path is
    ``../model_logs/<log_dir>/saved_models/<model_load_name>_<step>.th``
    (relative to the current working directory). Only entries whose key starts
    with ``'ae.'`` are kept, with that prefix removed.

    Args:
        model: ``torch.nn.Module`` to load into (modified in place).
        model_load_name: checkpoint name prefix, e.g. "vqgan".
        step: training step that appears in the checkpoint filename.
        log_dir: run directory name under ``../model_logs``.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        *model*, with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        ValueError: if the checkpoint has no 'ae.'-prefixed keys, i.e. nothing
            would be loaded.
        RuntimeError: if ``load_state_dict`` fails (e.g. a shape mismatch). The
            error is printed and re-raised rather than returning a model that
            silently kept its initial weights.
    """
    checkpoint_path = os.path.join("../model_logs", log_dir, "saved_models", f"{model_load_name}_{step}.th")
    print(f"\nAttempting to load checkpoint from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")

    # Load onto the CPU; load_state_dict copies values into the model's own
    # parameters, so there is no need for a second copy on the saved device.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    prefix = 'ae.'
    new_state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix)
    }
    if not new_state_dict:
        raise ValueError(f"No '{prefix}'-prefixed keys found in checkpoint: {checkpoint_path}")

    try:
        load_result = model.load_state_dict(new_state_dict, strict=strict)
    except RuntimeError as e:
        print(f"\nError loading state dict: {e}")
        raise

    # With strict=False mismatched key names are not errors, so surface them.
    if load_result.missing_keys:
        print(f"Missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Unexpected keys: {load_result.unexpected_keys}")
    print("\nLoad successful!")

    return model
# Loads VQGAN weights from a saved checkpoint and prepares the model for inference
def vq_load_vqgan_from_checkpoint(H, vqgan):
    """Load "vqgan" weights for step ``H.load_step`` of run ``H.load_dir``, move to CUDA, set eval mode."""
    loaded = load_model_vq(vqgan, "vqgan", H.load_step, H.load_dir)
    loaded = loaded.cuda()
    loaded.eval()
    return loaded


def vq_encode_and_quantize(vqgan, terrain_chunks, device='cuda'):
    """Encode chunks with a single-codebook VQGAN and return the code indices.

    Args:
        vqgan: model exposing ``encoder`` and ``quantize``. ``quantize`` must
            return a 3-tuple whose last item is a dict containing
            "min_encoding_indices".
        terrain_chunks: encoder input; moved to *device* before encoding.
        device: device used for the forward pass.

    Returns:
        The quantizer's "min_encoding_indices", moved to the CPU.
    """
    vqgan.eval()
    with torch.no_grad():
        encoded = vqgan.encoder(terrain_chunks.to(device))
        _, _, stats = vqgan.quantize(encoded)
        codes = stats["min_encoding_indices"].cpu()
        # Hand cached-but-unused CUDA memory back to the driver
        torch.cuda.empty_cache()
        return codes

def vq_decode_from_indices(indices, vqgan, device='cuda', latent_shape=(6, 6, 6)):
    """Reconstruct voxel chunks from single-codebook VQGAN latent indices.

    Args:
        indices: code indices; ``indices.size(0)`` is used as the batch size and
            the indices are flattened before the codebook lookup.
        vqgan: model exposing ``quantize.get_codebook_entry``, ``generator`` and
            ``embed_dim``.
        device: device used for decoding.
        latent_shape: spatial (H, W, D) size of the latent grid. Defaults to
            (6, 6, 6), the value previously hard-coded here. ``get_codebook_entry``
            is asked for shape (batch, *latent_shape, embed_dim).

    Returns:
        Decoder output on the CPU. If it has more than one channel (dim 1) it is
        argmax-ed over that dim into class indices. A leading size-1 dim is
        squeezed away.
    """
    with torch.no_grad():
        indices = indices.to(device)
        batch_size = indices.size(0)

        shape = (batch_size, *latent_shape, vqgan.embed_dim)
        quant = vqgan.quantize.get_codebook_entry(indices.reshape(-1), shape)

        decoded = vqgan.generator(quant)

        # Convert per-class scores to block indices
        if decoded.shape[1] > 1:
            decoded = torch.argmax(decoded, dim=1)

        # Move result to CPU and return cached-but-unused CUDA memory to the driver
        decoded = decoded.squeeze(0).cpu()
        torch.cuda.empty_cache()

        return decoded

## FQGAN helper functions

In [6]:
import os
from log_utils import log, load_stats, load_model
import copy
from fq_models import FQModel, HparamsFQGAN


# Reads the hparams.json stored in a saved model directory
def load_hparams_from_json(log_dir):
    """Return the parsed JSON in ``<log_dir>/hparams.json``.

    Raises:
        FileNotFoundError: if *log_dir* contains no ``hparams.json``.
    """
    import json
    import os

    hparams_file = os.path.join(log_dir, 'hparams.json')
    if not os.path.exists(hparams_file):
        raise FileNotFoundError(f"No hparams.json file found in {log_dir}")

    with open(hparams_file, 'r') as fp:
        parsed = json.load(fp)
    return parsed

# Turns a loaded hparams dict into a proper HparamsFQGAN object
def dict_to_vcqgan_hparams(hparams_dict, dataset=None):
    """Build an ``HparamsFQGAN`` and overwrite its attributes from *hparams_dict*.

    Args:
        hparams_dict: mapping of attribute name -> value (e.g. parsed hparams.json).
        dataset: dataset name passed to ``HparamsFQGAN``. When None, it is taken
            from ``hparams_dict['dataset']``, falling back to 'MNIST'.

    Returns:
        The populated ``HparamsFQGAN`` instance. Every key of *hparams_dict*
        (including any 'dataset' key) is applied with ``setattr`` after
        construction.
    """
    if dataset is None:
        dataset = hparams_dict.get('dataset', 'MNIST')  # Default to MNIST if not specified

    fq_hyper = HparamsFQGAN(dataset)
    for key, value in hparams_dict.items():
        setattr(fq_hyper, key, value)

    return fq_hyper

# Loads FQGAN weights for a saved step and prepares the model for inference
def load_fqgan_from_checkpoint(H, fqgan):
    """Load "fqgan" weights for step ``H.load_step`` from ``H.load_dir`` via ``load_model``; move to CUDA; set eval mode."""
    restored = load_model(fqgan, "fqgan", H.load_step, H.load_dir)
    restored = restored.cuda()
    restored.eval()
    return restored

# Takes a chunk or batch of chunks from the dataset, returns the encoded style and structure indices matrices
def encode_and_quantize(fqgan, terrain_chunks, device='cuda'):
    """Encode chunks with the dual-codebook FQGAN and return both code grids.

    The encoder yields separate style and structure features. Each is passed
    through its own ``quant_conv_*`` and ``quantize_*`` module, and the flat
    code indices (item 2 of the quantizer's stats tuple) are reshaped to
    (batch, H, W, D) taken from dims 0, 2, 3, 4 of the projected features.

    Returns:
        (style_indices, struct_indices), both on the CPU.
    """
    def _codes_for(features, projection, quantizer):
        # Project, quantize, and reshape the flat indices onto the latent grid
        projected = projection(features)
        _, _, stats = quantizer(projected)
        dims = projected.size()
        return stats[2].view((dims[0], dims[2], dims[3], dims[4]))

    fqgan.eval()
    with torch.no_grad():
        h_style, h_struct = fqgan.encoder(terrain_chunks.to(device))

        style_codes = _codes_for(h_style, fqgan.quant_conv_style, fqgan.quantize_style)
        del h_style

        struct_codes = _codes_for(h_struct, fqgan.quant_conv_struct, fqgan.quantize_struct)
        del h_struct

        # Keep the results off the GPU and release cached-but-unused CUDA memory
        style_codes, struct_codes = style_codes.cpu(), struct_codes.cpu()
        torch.cuda.empty_cache()

        return style_codes, struct_codes

# Takes style and structure indices, returns the reconstructed map
def decode_from_indices(style_indices, struct_indices, fqgan, device='cuda', two_stage=False):
    """Decode style and structure code grids back into a voxel reconstruction.

    Args:
        style_indices: style code grid, (batch, H, W, D).
        struct_indices: structure code grid, (batch, H, W, D).
        fqgan: model exposing ``quantize_style``/``quantize_struct``
            (``get_codebook_entry``), ``embed_dim`` and ``decoder``.
        device: device used for decoding.
        two_stage: when True, ``fqgan.decoder`` is expected to return a
            ``(decoded, binary_decoded)`` pair and both are returned.

    Returns:
        The decoded tensor on the CPU (argmax-ed over dim 1 into class indices
        when it has more than one channel), with a leading size-1 dim squeezed
        away. When *two_stage* is True, returns ``(result, binary_result)``;
        the binary output is squeezed the same way but not argmax-ed.
    """
    with torch.no_grad():
        style_indices = style_indices.to(device)
        struct_indices = struct_indices.to(device)

        # Look up codebook vectors. The batch size comes from the indices
        # (it was previously hard-coded to 1, which broke batched decoding).
        quant_style = fqgan.quantize_style.get_codebook_entry(
            style_indices.reshape(-1),
            shape=[style_indices.shape[0], fqgan.embed_dim, *style_indices.shape[1:]]
        )
        quant_struct = fqgan.quantize_struct.get_codebook_entry(
            struct_indices.reshape(-1),
            shape=[struct_indices.shape[0], fqgan.embed_dim, *struct_indices.shape[1:]]
        )
        del style_indices, struct_indices

        # Structure channels first, then style, along the channel dim
        quant = torch.cat([quant_struct, quant_style], dim=1)
        del quant_style, quant_struct

        if two_stage:
            decoded, binary_decoded = fqgan.decoder(quant)
        else:
            decoded, binary_decoded = fqgan.decoder(quant), None
        del quant

        # Convert per-class scores to block IDs
        if decoded.shape[1] > 1:
            decoded = torch.argmax(decoded, dim=1)

        result = decoded.squeeze(0).cpu()
        binary_result = binary_decoded.squeeze(0).cpu() if two_stage else None
        del decoded, binary_decoded
        torch.cuda.empty_cache()

    if two_stage:
        return result, binary_result
    return result

# Load FQ model

In [7]:
# Run this before (re)loading a model: repeated reloads can run out of GPU memory
# (there is a suspected memory leak that has not been tracked down yet).
import gc
gc.collect()  # collect unreachable Python objects, which may release tensors they held
torch.cuda.empty_cache()  # return cached-but-unused CUDA memory to the driver

In [8]:
# Directory of the FQGAN run to load; replace it with whichever trained model you want.
model_path = "./FQGAN_2stagedecoder_nobiomemodel_16bothcbook_EMA3"

In [10]:
# load_step is set by hand: if loading fails, check which fqgan_<step>.th
# checkpoint actually exists for this run and use that step number.
fqgan_hparams = dict_to_vcqgan_hparams(load_hparams_from_json(model_path), 'minecraft')
fqgan_hparams.load_step = 10000
fqgan = load_fqgan_from_checkpoint(fqgan_hparams, FQModel(fqgan_hparams))
print(f'loaded from: {fqgan_hparams.log_dir}')

NO biome supervision
Disentangle Ratio:  0.5
Loading fqgan_10000.th
loaded from: FQGAN_2stagedecoder_nobiomemodel_16bothcbook_EMA3


# Load VQ Model

In [10]:
#from models3d import VQAutoEncoder

# # vqmodel_path = 'saved_models/minecraft39ch_ce_3'
#vqgan_hparams =  vq_dict_to_vcqgan_hparams(vq_load_hparams_from_json(f"../model_logs/minecraft39ch_ce_3_fqdataset"), 'minecraft')
#vqgan_hparams.load_step = 10000

#vqgan = VQAutoEncoder(vqgan_hparams)
#vqgan = vq_load_vqgan_from_checkpoint(vqgan_hparams, vqgan)
#print(f'loaded from: {vqgan_hparams.log_dir}')

In [84]:
from models3d import VQAutoEncoder

# VQGAN baseline. load_model_vq reads
# ../model_logs/<load_dir>/saved_models/vqgan_<load_step>.th, so load_step must
# match an existing checkpoint.
# Alternative run: ../model_logs/minecraft39ch_ce_3_fqdataset with load_step = 10000.
vqgan_hparams = vq_dict_to_vcqgan_hparams(
    vq_load_hparams_from_json("../model_logs/minecraft39ch_ce_3"), 'minecraft'
)
vqgan_hparams.load_step = 95000

vqgan = vq_load_vqgan_from_checkpoint(vqgan_hparams, VQAutoEncoder(vqgan_hparams))
print(f'loaded from: {vqgan_hparams.log_dir}')

resolution: 24, num_resolutions: 3, num_res_blocks: 2, attn_resolutions: [6], in_channels: 256, out_channels: 39, block_in_ch: 256, curr_res: 6

Attempting to load checkpoint from: ../model_logs/minecraft39ch_ce_3/saved_models/vqgan_95000.th

Load successful!
loaded from: minecraft39ch_ce_3


# Load dataset

In [17]:
# Loads a preprocessed dataset in which every chunk is already one-hot encoded
# and a matching mappings file has been created. To use less memory, consider
# loading only the validation-set file.
# data_path = '../../text2env/data/minecraft_biome_newworld_10k_processed_cleaned.pt'
data_path = './minecraft_biome_newworld_10k_processed_cleaned.pt'
train_loader, val_loader = get_minecraft_dataloaders(
    data_path,
    batch_size=4,
    num_workers=0, # Must be 0 on Windows, otherwise it errors
    val_split=0.1
)


Loading pre-processed data...


  processed_data = torch.load(data_path)


Loaded 11082 chunks of size torch.Size([43, 24, 24, 24])
Number of unique block types: 43
Unique blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42]

Dataloader details:
Training samples: 9974
Validation samples: 1108
Batch size: 4
Training batches: 2494
Validation batches: 277


In [18]:
# Index <-> ID mappings for blocks and biomes (see BlockBiomeConverter.load_mappings)
# mappings_path = '../../text2env/data/minecraft_biome_newworld_10k_mappings.pt'
mappings_path = './minecraft_biome_newworld_10k_mappings.pt'

block_converter = BlockBiomeConverter.load_mappings(mappings_path)

  mappings = torch.load(path)


# Test the model

In [12]:
# Create an instance of our visualizer
from visualization_utils import MinecraftVisualizerPyVista
visualizer = MinecraftVisualizerPyVista()


# Load a random sample from the training data

In [24]:
# Get a sample.
# Rerun this cell for a different sample: train_loader was built with shuffle=True,
# and each iter() call starts a new shuffled pass.
batch = next(iter(train_loader))
sample = batch[0].unsqueeze(0).cuda()  # First chunk of the batch, given a batch dim of 1, on the GPU

In [27]:
# If you find an interesting sample, like one with a cool feature or with a clear biome split, you can save it and later load it with the example code below
# torch.save(sample, 'bowlaroundlake.pt')
# sample = torch.load('treesandcorner.pt')

In [14]:
# Get the style and structure code grids for `sample`, then decode them back.
# Rerun this whenever `sample` changes.
with torch.no_grad():
    style_indices, struct_indices = encode_and_quantize(fqgan, sample)
reconstructed, binary_reconstructed = decode_from_indices(style_indices, struct_indices, fqgan, two_stage=True)
print(reconstructed.shape)

torch.Size([24, 24, 24])


## Structure Exploration

In [19]:
# Map every structure code to the decoded voxel patches it produced.
# Latent cell (i, j, k) of the structure-code grid covers one patch of the
# reconstruction. The patch size per axis is derived from the shapes
# (reconstruction size // grid size) rather than hard-coded to 4.
structure_dict = {}
for batch in train_loader:
    for chunk in batch:
        sample = chunk.unsqueeze(0).cuda()
        # encode_and_quantize / decode_from_indices already run under torch.no_grad()
        style_indices, struct_indices = encode_and_quantize(fqgan, sample)
        reconstructed, binary_reconstructed = decode_from_indices(style_indices, struct_indices, fqgan, two_stage=True)
        grid = struct_indices.shape[1:]
        px, py, pz = (size // cells for size, cells in zip(reconstructed.shape, grid))
        for i in range(grid[0]):
            for j in range(grid[1]):
                for k in range(grid[2]):
                    struct_code = struct_indices[0, i, j, k].item()
                    block = reconstructed[i * px:(i + 1) * px,
                                          j * py:(j + 1) * py,
                                          k * pz:(k + 1) * pz]
                    structure_dict.setdefault(struct_code, []).append(block)

# Binary occupancy version of every patch: 1 where the block index is non-zero, else 0
binary_structure_dict = {
    struct_code: [(b != 0).to(dtype=torch.int) for b in block_list]
    for struct_code, block_list in structure_dict.items()
}


In [14]:
# Optional diagnostic (later cells do not need it): print every structure code
# the encoder actually uses on the training set; codes missing from the output
# are never used. Only the codes are inspected, so nothing is decoded, and
# `structure_dict` from the exploration cell is left untouched.
import numpy as np

used_struct_codes = set()
for batch in train_loader:
    for chunk in batch:
        _, struct_indices = encode_and_quantize(fqgan, chunk.unsqueeze(0).cuda())
        used_struct_codes.update(struct_indices.flatten().tolist())
print(np.unique(list(used_struct_codes)))

RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 4, 43, 24, 24, 24]

In [20]:
import numpy as np

# For each structure code (ascending), the fraction of its binary patches in
# which each voxel position is set, rounded to 3 decimals.
sorted_keys = sorted(binary_structure_dict.keys())
probability_dict = {
    code: np.round(np.array(binary_structure_dict[code]).mean(axis=0), 3)
    for code in sorted_keys
}

print(probability_dict)

{0: array([[[0.6  , 0.833, 0.833, 0.8  ],
        [0.6  , 0.8  , 0.8  , 0.8  ],
        [0.567, 0.833, 0.8  , 0.8  ],
        [0.467, 0.733, 0.733, 0.733]],

       [[0.8  , 0.933, 0.933, 0.9  ],
        [0.867, 0.933, 0.9  , 0.867],
        [0.833, 0.933, 0.933, 0.9  ],
        [0.767, 0.867, 0.9  , 0.867]],

       [[0.867, 0.933, 0.933, 0.933],
        [0.867, 0.933, 0.9  , 0.9  ],
        [0.867, 0.933, 0.9  , 0.9  ],
        [0.767, 0.867, 0.867, 0.867]],

       [[0.833, 0.833, 0.867, 0.9  ],
        [0.8  , 0.9  , 0.867, 0.9  ],
        [0.9  , 0.9  , 0.9  , 0.9  ],
        [0.8  , 0.833, 0.8  , 0.833]]]), 1: array([[[0.789, 0.862, 0.821, 0.741],
        [0.794, 0.853, 0.817, 0.746],
        [0.822, 0.853, 0.831, 0.776],
        [0.71 , 0.739, 0.726, 0.692]],

       [[0.907, 0.953, 0.905, 0.851],
        [0.899, 0.938, 0.897, 0.845],
        [0.893, 0.908, 0.887, 0.845],
        [0.769, 0.789, 0.776, 0.751]],

       [[0.878, 0.932, 0.889, 0.837],
        [0.875, 0.917, 0.878, 

In [None]:
# Check for results: print the first 50 binary patches stored for one structure code
key = 0  # The structure code to check

if key not in binary_structure_dict:
    print(f"Key {key} not found in dictionary.")
else:
    print(f"First 50 values for key {key}:")
    for number, block in enumerate(binary_structure_dict[key][:50], start=1):
        print(f"Block {number}:")
        print(block)  # one binary patch tensor
        print("-" * 30)  # Separator

In [33]:
# Plot structural heatmaps: for every structure code, save one PNG that shows the
# per-voxel occupancy probabilities as a table (left) and as a colour-coded
# 3D scatter (right).
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
import matplotlib.colors as mcolors

# Discrete colour bins: [0.0, 0.2) white, [0.2, 0.4) yellow, [0.4, 0.6) orange,
# [0.6, 0.8) red, [0.8, 1.0] black
colors = ["white", "yellow", "orange", "red", "black"]
levels = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
cmap = mcolors.ListedColormap(colors)
norm = mcolors.BoundaryNorm(levels, cmap.N)

# Ensure the output directory exists
save_dir = os.path.join(os.getcwd(), "Structure_Heatmap_V2")
os.makedirs(save_dir, exist_ok=True)

# Paths of the saved images, in plotting order
image_paths = []

for struct_code, prob_matrix in probability_dict.items():
    nx, ny, nz = prob_matrix.shape
    fig = plt.figure(figsize=(12, 6))

    # Left subplot: probability values, one table row per (X, Y) column of voxels
    ax1 = fig.add_subplot(121)
    ax1.set_title(f"Probability Values for Structure Code {struct_code}")
    ax1.axis("tight")
    ax1.axis("off")
    table_data = [
        [f"{prob_matrix[i, j, k]:.2f}" for k in range(nz)]
        for i in range(nx)
        for j in range(ny)
    ]
    ax1.table(cellText=table_data,
              colLabels=[f"Z={k}" for k in range(nz)],
              rowLabels=[f"X={i}, Y={j}" for i in range(nx) for j in range(ny)],
              cellLoc="center", loc="center")

    # Right subplot: 3D scatter coloured through the same cmap/norm as the colour
    # bar, so values on a bin edge (e.g. exactly 0.2) get the colour the bar shows.
    # np.digitize(..., right=True) would have put them one bin lower.
    ax2 = fig.add_subplot(122, projection='3d')
    x, y, z = np.indices(prob_matrix.shape).reshape(3, -1)
    ax2.scatter(x, y, z, c=prob_matrix.flatten(), cmap=cmap, norm=norm,
                s=300, alpha=0.8, edgecolors="k")

    # Colour bar legend for the discrete probability bins
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax2, fraction=0.02, pad=0.1, ticks=levels)
    cbar.set_label("Probability Levels")

    ax2.set_title(f"3D Frequency Heatmap for Structure Code {struct_code}")
    ax2.set_xlabel("X")
    ax2.set_ylabel("Y")
    ax2.set_zlabel("Z")
    ax2.set_xticks(range(nx))
    ax2.set_yticks(range(ny))
    ax2.set_zticks(range(nz))

    # Save the figure instead of displaying it, then free it
    image_path = os.path.join(save_dir, f"heatmap_{struct_code}.png")
    fig.savefig(image_path, bbox_inches='tight')
    plt.close(fig)

    image_paths.append(image_path)

# Show the list of saved image file paths as the cell output
image_paths


['/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_0.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_1.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_2.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_3.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_4.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_5.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_6.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_7.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatmap_8.png',
 '/root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Structure_Heatmap_V2/heatma

## Style Exploration (Block Type Only)

In [13]:
# Map every style code to the decoded voxel patches it produced.
# Latent cell (i, j, k) of the style-code grid covers one patch of the
# reconstruction. The patch size per axis is derived from the shapes
# (reconstruction size // grid size) rather than hard-coded to 4.
style_dict = {}
for batch in train_loader:
    for chunk in batch:
        sample = chunk.unsqueeze(0).cuda()
        # encode_and_quantize / decode_from_indices already run under torch.no_grad()
        style_indices, struct_indices = encode_and_quantize(fqgan, sample)
        reconstructed, binary_reconstructed = decode_from_indices(style_indices, struct_indices, fqgan, two_stage=True)
        grid = style_indices.shape[1:]
        px, py, pz = (size // cells for size, cells in zip(reconstructed.shape, grid))
        for i in range(grid[0]):
            for j in range(grid[1]):
                for k in range(grid[2]):
                    style_code = style_indices[0, i, j, k].item()
                    block = reconstructed[i * px:(i + 1) * px,
                                          j * py:(j + 1) * py,
                                          k * pz:(k + 1) * pz]
                    style_dict.setdefault(style_code, []).append(block)


In [1]:
# Inspect patches 20-29 recorded for style code 19. `style_dict` must have been
# built in this kernel session (the recorded NameError shows it was not).
print(style_dict[19][20:30])

NameError: name 'style_dict' is not defined

In [15]:
# Count how often each block type appears in all decoded patches of every style
# code, then normalise to per-code frequencies (sorted by style code, then block type).
from collections import defaultdict
block_type_frequencies = defaultdict(lambda: defaultdict(int))
block_total_counts = defaultdict(int)
for style_code, block_list in style_dict.items():
    if not block_list:
        continue
    # Count all patches of this code with one torch.unique call instead of
    # calling .item() on every single voxel in Python.
    flattened_blocks = torch.cat([block_matrix.flatten() for block_matrix in block_list])
    block_total_counts[style_code] += flattened_blocks.numel()
    block_types, counts = torch.unique(flattened_blocks, return_counts=True)
    for block_type, count in zip(block_types.tolist(), counts.tolist()):
        block_type_frequencies[style_code][block_type] += count

sorted_block_frequency_dict = {}
for style_code in sorted(block_type_frequencies.keys()):  # Sort by style code
    block_counts = block_type_frequencies[style_code]
    total_blocks = block_total_counts[style_code] if block_total_counts[style_code] > 0 else 1  # Avoid division by zero

    sorted_block_frequency_dict[style_code] = {
        block_type: block_counts[block_type] / total_blocks
        for block_type in sorted(block_counts.keys())  # Sort by block type
    }
print(sorted_block_frequency_dict)

{0: {0: 0.19814343205684462, 2: 1.965038043136515e-07, 3: 0.0002077700224276342, 4: 4.336183948521243e-05, 6: 2.9475570647047725e-06, 8: 6.563227064075961e-05, 10: 0.00050711081766543, 11: 0.08270622419250037, 12: 6.55012681045505e-08, 14: 0.010865874363327675, 15: 9.471483367918003e-05, 16: 0.0021671749565071578, 17: 0.0015244765138653085, 18: 0.0018967857217715735, 19: 0.0021692709970865037, 22: 2.489048187972919e-06, 23: 1.1528223186400888e-05, 24: 3.0785596009138736e-06, 25: 0.0001168542622985181, 27: 7.532645832023308e-06, 28: 1.3100253620910101e-06, 32: 0.010261166656186463, 33: 0.00016853476283300845, 34: 1.6506319562346726e-05, 35: 6.55012681045505e-08, 36: 0.000286633549225513, 37: 0.6668276687836676, 38: 6.091617933723197e-06, 39: 3.93007608627303e-07, 40: 0.020675147770860845, 41: 0.0012035858014211155, 42: 1.6375317026137626e-05}, 1: {0: 0.2187001873866858, 2: 9.506223914921577e-08, 3: 0.0001265278403076062, 4: 1.5019833785576092e-05, 6: 9.506223914921577e-08, 8: 1.34988379

In [26]:
# Check raw voxel counts of STONE (id 37) and WATER (id 40) for each style code.
# .get() avoids inserting zero entries into the (defaultdict) counters when a
# style code never produced one of these block types.
STONE_ID, WATER_ID = 37, 40
for style_code, counts in block_type_frequencies.items():
    print(style_code, counts.get(STONE_ID, 0), counts.get(WATER_ID, 0))

10180378 315645
2737764 138506
3687378 41887
17926060 781558
3945713 43218
3474925 266080
10710129 326415
3739435 190858
15449400 386645
6774780 201948
1761697 15946
4314247 199158
9544919 831776
893160 83745
521919 12177
910905 41099
250311 5798
1028908 43345
167061 15274
809361 7186


In [16]:
# Keep, per style code, only block types that make up at least `threshold` of
# that code's voxels, then list (sorted) every block type that is common for at
# least one style code.
threshold = 0.001
filtered_block_frequency_dict = {}
for style_code, block_counts in sorted_block_frequency_dict.items():
    filtered_block_frequency_dict[style_code] = {
        block_type: freq
        for block_type, freq in block_counts.items()
        if freq >= threshold
    }

# Same per-code selection, as plain lists of block types (in frequency-dict order).
common_block_index = {
    style_code: list(kept) for style_code, kept in filtered_block_frequency_dict.items()
}

common_block_list = sorted({block_type for kept in common_block_index.values() for block_type in kept})
print(common_block_list)

[0, 11, 14, 16, 17, 18, 19, 32, 33, 37, 40, 41]


In [19]:
# Block index -> Minecraft block name. Indices are contiguous from 0, so the
# mapping is built from the position of each name in this list.
block_index_to_name = dict(enumerate([
    "AIR", "BONE_BLOCK", "BROWN_MUSHROOM", "BROWN_MUSHROOM_BLOCK", "CACTUS",       # 0-4
    "CHEST", "CLAY", "COAL_ORE", "COBBLESTONE", "DEADBUSH",                         # 5-9
    "DIAMOND_ORE", "DIRT", "DOUBLE_PLANT", "EMERALD_ORE", "FLOWING_LAVA",           # 10-14
    "FLOWING_WATER", "GOLD_ORE", "GRASS", "GRAVEL", "IRON_ORE",                     # 15-19
    "LAPIS_ORE", "LAVA", "LEAVES", "LEAVES2", "LOG",                                # 20-24
    "LOG2", "MOB_SPAWNER", "MONSTER_EGG", "MOSSY_COBBLESTONE", "PUMPKIN",           # 25-29
    "REDSTONE_ORE", "RED_FLOWER", "RED_MUSHROOM_BLOCK", "REEDS", "SAND",            # 30-34
    "SANDSTONE", "SNOW_LAYER", "STONE", "TALLGRASS", "VINE",                        # 35-39
    "WATER", "WATERLILY", "YELLOW_FLOWER",                                          # 40-42
]))

In [24]:
# Histogram of block types: one bar chart per style code showing what fraction
# of that code's voxels each common block type takes (from
# filtered_block_frequency_dict). Charts are written to Block_type_histograms_V2/.
import matplotlib.pyplot as plt
import json
import os

histogram_dir = os.path.join(os.getcwd(), "Block_type_histograms_V2")
os.makedirs(histogram_dir, exist_ok=True)
for style_code, block_data in filtered_block_frequency_dict.items():
    # Pair each block type (as int) with its frequency, ordered by block type.
    ordered = sorted(
        ((int(bt), freq) for bt, freq in block_data.items()),
        key=lambda pair: pair[0],
    )
    block_types = [bt for bt, _ in ordered]
    frequencies = [freq for _, freq in ordered]
    x_labels = [block_index_to_name.get(bt, 'UNKNOWN') for bt in block_types]

    # Each bar sits at its block id on the x-axis and is labelled with the block name.
    plt.figure(figsize=(10, 5))
    plt.bar(block_types, frequencies, color="royalblue", alpha=0.7)
    plt.xlabel("Block Type")
    plt.ylabel("Frequency Ratio")
    plt.title(f"Block Type Frequency for Style Code {style_code}")
    plt.xticks(block_types, x_labels, rotation=90, fontsize=6)
    plt.grid(axis="y", linestyle="--", alpha=0.7)

    save_path = os.path.join(histogram_dir, f"histogram_style_code_{style_code}.png")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()  # release the figure; nothing is shown inline

print(f"Histograms saved in directory: {histogram_dir}")

Histograms saved in directory: /root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Block_type_histograms_V2


## Biome Exploration

In [18]:
# Check the biome data: decode the biome indices of the first training batch back
# to biome names and show those of the third sample. Does nothing if the loader is empty.
batch = next(iter(train_loader), None)
if batch is not None:
    voxels, biomes = batch
    biome_value = block_converter.convert_to_original_biomes(biomes)
    print(biome_value[2])

[[['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ...
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']]

 [['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ...
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']]

 [['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ...
  ['forest' 'forest' 'forest' ... 'plains' 'plai

In [19]:
# Store, per style code, the reconstructed block region and the biome names of
# every latent cell that received that code:
#   biome_dict[style_code] -> [{"block": region of reconstruction,
#                               "biome": biome names for the same region}, ...]
# Each region is the 4x4x4 voxel cube under one latent cell. `sample`,
# `style_indices`, `struct_indices`, `reconstructed`, `binary_reconstructed`,
# `biome_tensor` and `biome_names` remain globals holding the last sample's values.
biome_dict = {}

for batch in train_loader:
    voxels, biomes = batch  # the loader yields (voxels, biomes) pairs
    for sample_idx, sample_voxels in enumerate(voxels):
        sample = sample_voxels.unsqueeze(0).cuda()

        with torch.no_grad():
            style_indices, struct_indices = encode_and_quantize(fqgan, sample)
            reconstructed, binary_reconstructed = decode_from_indices(
                style_indices, struct_indices, fqgan, two_stage=True
            )

        biome_tensor = biomes[sample_idx]
        biome_names = block_converter.convert_to_original_biomes(biome_tensor)

        n_x, n_y, n_z = style_indices.shape[1], style_indices.shape[2], style_indices.shape[3]
        for i in range(n_x):
            for j in range(n_y):
                for k in range(n_z):
                    region = (
                        slice(i * 4, i * 4 + 4),
                        slice(j * 4, j * 4 + 4),
                        slice(k * 4, k * 4 + 4),
                    )
                    style_code = style_indices[0, i, j, k].item()
                    biome_dict.setdefault(style_code, []).append({
                        "block": reconstructed[region],
                        "biome": biome_names[region],
                    })

In [28]:
# Example: the biome names of the first region recorded for style code 0
# (a 4x4x4 cube, matching the slicing in the biome_dict cell).
print(biome_dict[0][0]["biome"])

[[['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]

 [['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]

 [['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]

 [['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]]


In [30]:
# Find the frequency of each biome for each style code.
#   biome_frequencies[style_code][biome_name] -> number of voxels with that biome
#   biome_total_counts[style_code]             -> number of voxels seen
#   sorted_biome_frequency_dict[style_code][biome_name] -> fraction of voxels
from collections import Counter, defaultdict

biome_frequencies = defaultdict(Counter)
biome_total_counts = defaultdict(int)


for style_code, entries in biome_dict.items():
    for entry in entries:
        flattened_biomes = entry["biome"].flatten()

        # Update total biome count for the style code
        biome_total_counts[style_code] += len(flattened_biomes)

        # Counter.update counts in C instead of one Python `+=` per voxel;
        # elements (and therefore keys) are the same np.str_ values as before.
        biome_frequencies[style_code].update(flattened_biomes)

sorted_biome_frequency_dict = {}
for style_code in sorted(biome_frequencies.keys()):
    biome_counts = biome_frequencies[style_code]
    total_biomes = biome_total_counts[style_code] if biome_total_counts[style_code] > 0 else 1

    sorted_biome_frequency_dict[style_code] = {
        biome: biome_counts[biome] / total_biomes
        for biome in sorted(biome_counts.keys())  # Sort by biome name
    }

print(sorted_biome_frequency_dict[0])


{np.str_('beaches'): 0.012664664858581944, np.str_('birch_forest'): 0.04071086977991507, np.str_('cave'): 0.3224620340193419, np.str_('desert'): 0.042630535360211504, np.str_('extreme_hills'): 0.15213599000235506, np.str_('forest'): 0.14629722481786206, np.str_('mutated_extreme_hills'): 0.00047291291565056863, np.str_('ocean'): 0.02534025039694297, np.str_('plains'): 0.12334860329253747, np.str_('river'): 0.02441863998602153, np.str_('savanna'): 0.029020994294657033, np.str_('swampland'): 0.04432703922328327, np.str_('taiga'): 0.03617024105263958}


In [34]:
# Full per-style-code biome table (fraction of each code's voxels per biome).
print(sorted_biome_frequency_dict)

{0: {np.str_('beaches'): 0.012664664858581944, np.str_('birch_forest'): 0.04071086977991507, np.str_('cave'): 0.3224620340193419, np.str_('desert'): 0.042630535360211504, np.str_('extreme_hills'): 0.15213599000235506, np.str_('forest'): 0.14629722481786206, np.str_('mutated_extreme_hills'): 0.00047291291565056863, np.str_('ocean'): 0.02534025039694297, np.str_('plains'): 0.12334860329253747, np.str_('river'): 0.02441863998602153, np.str_('savanna'): 0.029020994294657033, np.str_('swampland'): 0.04432703922328327, np.str_('taiga'): 0.03617024105263958}, 1: {np.str_('beaches'): 0.029120701688594434, np.str_('birch_forest'): 0.032540730439492904, np.str_('cave'): 0.1486817296373845, np.str_('desert'): 0.014762180068127129, np.str_('extreme_hills'): 0.10217897434294822, np.str_('forest'): 0.12550913049366125, np.str_('mutated_extreme_hills'): 0.0002714928591518485, np.str_('ocean'): 0.2171942873214788, np.str_('plains'): 0.1211893340416888, np.str_('river'): 0.07409541444003459, np.str_('s

In [18]:
# Biome Mapping: index -> biome name, as reported by the converter (printed below).
print(block_converter.index_to_biome)
Biome_mapping = {0: 'beaches', 1: 'birch_forest', 2: 'cave', 3: 'desert', 4: 'extreme_hills', 5: 'forest', 6: 'mutated_extreme_hills', 7: 'ocean', 8: 'plains', 9: 'river', 10: 'savanna', 11: 'swampland', 12: 'taiga'}
# The table above is a hard-coded copy; warn if it has drifted from the
# dataset-derived mapping (e.g. after regenerating the dataset).
if block_converter.index_to_biome is not None and dict(block_converter.index_to_biome) != Biome_mapping:
    print("WARNING: Biome_mapping does not match block_converter.index_to_biome")

{0: 'beaches', 1: 'birch_forest', 2: 'cave', 3: 'desert', 4: 'extreme_hills', 5: 'forest', 6: 'mutated_extreme_hills', 7: 'ocean', 8: 'plains', 9: 'river', 10: 'savanna', 11: 'swampland', 12: 'taiga'}


In [37]:
# Plot one biome-frequency bar chart per style code (from
# sorted_biome_frequency_dict) and save it under Biome_histograms_V1/.
import matplotlib.pyplot as plt
import os

histogram_dir = os.path.join(os.getcwd(), "Biome_histograms_V1")
os.makedirs(histogram_dir, exist_ok=True)

for style_code, biome_data in sorted_biome_frequency_dict.items():
    # Alphabetical order of biome names, with their frequency ratios alongside.
    ordered_pairs = sorted(biome_data.items(), key=lambda pair: pair[0])
    biome_names = [name for name, _ in ordered_pairs]
    frequencies = [ratio for _, ratio in ordered_pairs]

    plt.figure(figsize=(12, 6))
    plt.bar(biome_names, frequencies, color="seagreen", alpha=0.75)
    plt.xlabel("Biome Name")
    plt.ylabel("Frequency Ratio")
    plt.title(f"Biome Frequency for Style Code {style_code}")
    plt.xticks(rotation=90, fontsize=7)
    plt.grid(axis="y", linestyle="--", alpha=0.6)

    save_path = os.path.join(histogram_dir, f"biome_histogram_style_code_{style_code}.png")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()  # release the figure; nothing is shown inline

print(f"Biome histograms saved in directory: {histogram_dir}")

Biome histograms saved in directory: /root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Biome_histograms_V1


# Render the original dataset chunk

In [None]:
# `coords` lists the latent-space (x, y, z) cells to outline with a red box in
# the render. This helps after editing latent codes: it shows where the changes
# are expected to appear. Axis orientation in the rendered view:
#   X: 0 at the far left -> 5 at the center
#   Y: 0 at the bottom   -> 5 at the top
#   Z: 0 at the center   -> 5 at the far right
coords = [(0, 2, 5)]
# coords = coords + [(x, y, 0) for x in [3, 4] for y in range(6)]

# Convert back to the original Minecraft block IDs before rendering
# (the converter's representation is not what the renderer draws).
converted_orig = block_converter.convert_to_original_blocks(sample.squeeze())
plotter = visualizer.visualize_chunk(
    converted_orig,
    interactive=True,
    show_axis=False,
    highlight_latents=coords,
)
plotter.show()

# To save a screenshot instead:
# plotter.reset_camera()
# plotter.show()
# img = plotter.screenshot(window_size=(500, 500),
#                                transparent_background=True,
#                                return_img=True)
# # Make the array C-contiguous and ensure correct format
# img_array = np.ascontiguousarray(img)
# # Save using imageio instead of plt.imsave
# import imageio
# imageio.imwrite('treesandcorner_highlight.png', img_array)
    

In [None]:
# Show the original chunk with the blocks under the listed latent cells marked.
# Coordinates are in the 6x6x6 latent space.
latent_coords = [(0, 2, 5)]
plotter = visualizer.visualize_latent_blocks(
    converted_orig, latent_coords, show_axis=False
)
plotter.show()

In [None]:
# Render only the voxels under the listed latent cell(s) of the original chunk,
# then save a 500x500 transparent screenshot.
latent_coords = [(2, 2, 5)]  # Coordinates in the 6x6x6 latent space
plotter = visualizer.visualize_isolated_latent_blocks(
    converted_orig, latent_coords, show_axis=False
)
plotter.reset_camera()
plotter.show()
screenshot_options = dict(window_size=(500, 500), transparent_background=True, return_img=True)
img = plotter.screenshot(**screenshot_options)
# Make the array C-contiguous before writing it out.
img_array = np.ascontiguousarray(img)

# Written with imageio rather than plt.imsave.
import imageio
imageio.imwrite('treesandcorner_225.png', img_array)
    

# Render the reconstructed dataset chunk

In [None]:
# Decode the (possibly edited) style/structure indices, then convert both
# decoder outputs back to original Minecraft block IDs.
reconstructed, binary_reconstructed = decode_from_indices(
    style_indices, struct_indices, fqgan, two_stage=True
)
# Threshold the binary output at 0.5 into a 0/1 mask before converting it.
binary_mask = (binary_reconstructed > 0.5).float()
binary_reconstructed = block_converter.convert_to_original_blocks(binary_mask).squeeze()
reconstructed = block_converter.convert_to_original_blocks(reconstructed)

coords = [(0, 2, 5)]  # latent cells to outline; set to None for no highlight
fig = visualizer.visualize_chunk(
    reconstructed,
    highlight_latents=coords,
    show_axis=False,
    wireframe_highlight=True,
)
fig.show()

# To save a screenshot instead:
# fig.reset_camera()
# fig.show()
# img = fig.screenshot(window_size=(500, 500),
#                                transparent_background=True,
#                                return_img=True)
# # Make the array C-contiguous and ensure correct format
# img_array = np.ascontiguousarray(img)
# # Save using imageio instead of plt.imsave
# import imageio
# imageio.imwrite('treesandcorner_recon_2codebook_highlight.png', img_array)
    

In [None]:
# Render only the voxels under the listed latent cell(s) of the reconstruction,
# then save a 500x500 transparent screenshot.
latent_coords = [(2, 2, 5)]  # Coordinates in the 6x6x6 latent space
plotter = visualizer.visualize_isolated_latent_blocks(
    reconstructed, latent_coords, show_axis=False
)
plotter.reset_camera()
plotter.show()
screenshot_options = dict(window_size=(500, 500), transparent_background=True, return_img=True)
img = plotter.screenshot(**screenshot_options)
# Make the array C-contiguous before writing it out.
img_array = np.ascontiguousarray(img)

# Written with imageio rather than plt.imsave.
import imageio
imageio.imwrite('treesandcorner_recon_225.png', img_array)

In [None]:
# Recolor the converted binary reconstruction for display: every voxel whose
# block ID is not 5 is overwritten with ID 217, IN PLACE.
# NOTE(review): the meaning of IDs 5 and 217 is not visible here (presumably the
# ID the "empty" mask value converts to, and a highlight block) — confirm.
# Because this mutates the global `binary_reconstructed`, re-run the
# reconstruction cell to get the unmodified version back.
binary_reconstructed[binary_reconstructed != 5] = 217
fig = visualizer.visualize_chunk(binary_reconstructed, highlight_latents=coords)
fig.show()

## VQ reconstruction (single codebook)

In [None]:
# Decode the single-codebook (VQGAN) indices and convert the result back to
# original Minecraft block IDs.
vq_reconstructed = block_converter.convert_to_original_blocks(
    vq_decode_from_indices(vq_latent_indices, vqgan)
)
coords = None  # no latent cells highlighted
fig = visualizer.visualize_chunk(vq_reconstructed, highlight_latents=coords, show_axis=False)
fig.show()

# To save a screenshot instead:
# fig.reset_camera()
# fig.show()
# img = fig.screenshot(window_size=(500, 500),
#                                transparent_background=True,
#                                return_img=True)
# # Make the array C-contiguous and ensure correct format
# img_array = np.ascontiguousarray(img)
# # Save using imageio instead of plt.imsave
# import imageio
# imageio.imwrite('treesandcorner_recon_singlecodebook.png', img_array)

# Plot either style or structure codes next to the chunk
NOTE: The slider does not snap to discrete layer values. If you slide it to layer 1 and nothing changes, keep moving it until you are past roughly 1.5; the view should then update.

In [None]:
# Visualize the single-codebook (VQGAN) latent codes overlaid on its reconstruction.
# Both arguments are passed as-is; no batch dimension is removed in this call.
# NOTE(review): latent_type='style' is reused for the single codebook —
# presumably it only selects labelling/colouring in the visualizer; confirm.
plotter = visualizer.visualize_latent_space_with_blocks(
    vq_latent_indices,
    vq_reconstructed,
    latent_type='style',
)
plotter.show()

In [None]:
# Visualize the style codes overlaid on the reconstructed chunk.
# Both arguments are passed as-is: style_indices keeps its leading batch
# dimension (it is indexed as [0, i, j, k] elsewhere); nothing is squeezed here.
plotter = visualizer.visualize_latent_space_with_blocks(
    style_indices,
    reconstructed,
    latent_type='style',
)
plotter.show()

# Plot both style and structure next to the chunk
NOTE: This runs very slowly in notebook mode, so it is rendered in a separate window. Check whether that window opened.

In [None]:
# Show style codes, structure codes and the reconstructed chunk together.
# This is slow in notebook mode, so it is shown in a separate interactive window.
plotter = visualizer.visualize_both_codes_with_blocks(style_indices, struct_indices, reconstructed)
plotter.show(interactive=True)

# Make changes to style indices, struct indices, or both
Pass `coords` to specify the latent-space (x, y, z) cells you want to modify. For each code value from 0 to `max_code_value` (inclusive), every listed cell is set to that value (or, with `mode='random'`, each cell gets a random code seeded by that value). The decoded chunk becomes one frame of the GIF.

NOTE: Only one `max_code_value` is supported. If your two codebooks have different sizes, pass the smaller codebook's largest index. Code values outside a codebook's range crash the CUDA kernel, and you must then restart the notebook kernel.

In [None]:
def modify_latent_indices(indices, coords, code_value, max_code_value, mode='all'):
    """
    Return a copy of ``indices`` with the codes at ``coords`` replaced.

    Args:
        indices: Original indices tensor [1,6,6,6]; it is not modified.
        coords: List of (x,y,z) latent coordinates to modify.
        code_value: In 'all' mode, the code written to every coordinate. In
            'random' mode, the seed for the random codes, so each code_value
            gives a reproducible set of codes.
        max_code_value: Largest code that may be drawn in 'random' mode (inclusive).
        mode: Either 'all' or 'random'.

    Returns:
        Modified clone of ``indices``.

    Raises:
        ValueError: If ``mode`` is neither 'all' nor 'random'.
    """
    modified = indices.clone()

    if mode == 'all':
        # Set all specified coordinates to the same value
        for coord in coords:
            modified[0, coord[0], coord[1], coord[2]] = code_value
    elif mode == 'random':
        # A private generator seeded with code_value keeps the result reproducible
        # without resetting the global (CPU and CUDA) RNG state, as torch.manual_seed did.
        generator = torch.Generator().manual_seed(int(code_value))
        for coord in coords:
            modified[0, coord[0], coord[1], coord[2]] = torch.randint(
                0, max_code_value + 1, (1,), generator=generator
            ).item()
    else:
        # Previously an unknown mode silently returned the indices unchanged.
        raise ValueError(f"mode must be 'all' or 'random', got {mode!r}")

    return modified



def create_latent_modification_gif(fqgan, style_indices, struct_indices, coords, block_converter, 
                                 latent_type='style', max_code_value=31, duration=5.0, out_path='my_animation.gif', 
                                 show_axis=True, transparent_background=False, fps=3, mode='all', wireframe_highlight=True):
    """
    Create a GIF showing how modifying style or structure codes at specific coordinates affects reconstruction.

    Frame 0 is the reconstruction of the unmodified indices. Frame ``c + 1`` is the
    reconstruction after setting the codes at ``coords`` to ``c`` (``mode='all'``),
    or to random codes seeded by ``c`` (``mode='random'``), for ``c = 0 .. max_code_value``.

    Args:
        fqgan: The trained FQGAN model.
        style_indices: Original style indices tensor [1,6,6,6].
        struct_indices: Original structure indices tensor [1,6,6,6].
        coords: List of (x,y,z) latent coordinates to modify.
        block_converter: BlockBiomeConverter instance for converting to block IDs.
        latent_type: 'style', 'struct' or 'both'; selects which codes to modify.
            With 'both', the same code value is written to both grids, so
            max_code_value must be valid for the smaller codebook.
        max_code_value: Largest code value to try (inclusive). It is not checked
            against the codebook sizes here.
        duration: Duration of output GIF in seconds.
        out_path: Path to save the output GIF.
        show_axis: Forwarded to create_voxel_gif.
        transparent_background: Forwarded to create_voxel_gif.
        fps: Forwarded to create_voxel_gif.
        mode: Either 'all' or 'random'; see modify_latent_indices.
        wireframe_highlight: Forwarded to create_voxel_gif.

    Returns:
        ``out_path``.

    Raises:
        ValueError: If ``latent_type`` is not 'style', 'struct' or 'both'
            (previously any other value was silently treated as 'struct').
    """
    if latent_type not in ('style', 'struct', 'both'):
        raise ValueError(f"latent_type must be 'style', 'struct' or 'both', got {latent_type!r}")
    modify_style = latent_type in ('style', 'both')
    modify_struct = latent_type in ('struct', 'both')

    def _decode_to_blocks(style, struct):
        """Decode one pair of index grids and convert the result to original block IDs."""
        reconstruction, _ = decode_from_indices(style, struct, fqgan, two_stage=True)
        return block_converter.convert_to_original_blocks(reconstruction)

    # Pure inference: no_grad avoids building an autograd graph for every frame.
    with torch.no_grad():
        # First get base reconstruction
        chunks = [_decode_to_blocks(style_indices, struct_indices)]

        # Then one frame per possible code value
        for code in range(max_code_value + 1):
            modified_style = (
                modify_latent_indices(style_indices, coords, code, max_code_value, mode)
                if modify_style else style_indices
            )
            modified_struct = (
                modify_latent_indices(struct_indices, coords, code, max_code_value, mode)
                if modify_struct else struct_indices
            )
            chunks.append(_decode_to_blocks(modified_style, modified_struct))

    # Create visualization
    visualizer = MinecraftVisualizerPyVista()
    visualizer.create_voxel_gif(chunks, duration=duration, output_path=out_path, 
                              highlight_latents=coords, show_axis=show_axis, 
                              transparent_background=transparent_background, fps=fps, wireframe_highlight=wireframe_highlight)
    
    return out_path

In [None]:

def create_latent_modification_gif_single_codebook(vqgan, indices, coords, block_converter, max_code_value=31, duration=5.0, out_path='my_animation.gif', show_axis=True, transparent_background=False, fps=3, wireframe_highlight=True, mode='all'):
    """
    Single-codebook (VQGAN) counterpart of create_latent_modification_gif.

    The first frame decodes ``indices`` unchanged. Then, for each code value
    0..max_code_value, the codes at ``coords`` are replaced via
    modify_latent_indices (``mode`` 'all' or 'random'), decoded, and appended
    as a frame. All frames are converted to original block IDs and written to
    ``out_path`` with create_voxel_gif.

    Returns:
        ``out_path``.
    """
    def to_blocks(latent_indices):
        # Decode one latent grid and map it back to original Minecraft block IDs.
        return block_converter.convert_to_original_blocks(vq_decode_from_indices(latent_indices, vqgan))

    frames = [to_blocks(indices)]
    frames.extend(
        to_blocks(modify_latent_indices(indices, coords, code, max_code_value, mode))
        for code in range(max_code_value + 1)
    )

    MinecraftVisualizerPyVista().create_voxel_gif(
        frames,
        duration=duration,
        output_path=out_path,
        highlight_latents=coords,
        show_axis=show_axis,
        transparent_background=transparent_background,
        fps=fps,
        wireframe_highlight=wireframe_highlight,
    )
    return out_path

# Create style (and structure) modification gifs

In [None]:
# Pick a name for this sample (used as the prefix of the output GIF file names).
sample_name = 'samplename'

# Specify the latent (x, y, z) cells to modify. Alternatives tried earlier:
# coords = [(x, y, z) for x in [0, 1, 2] for y in  range(6) for z in  [0, 1]] 
# coords = [(x, y, z) for x in range(6) for y in range(6) for z in range(6)] 
# coords = [(x, y, z) for x in [4, 5] for y in range(6) for z in [0, 1]] 
# coords = [(x, y, z) for x in [0, 1, 3] for y in [1, 2, 3, 4] for z in [0, 1, 2, 3]]
# coords = [(x, y, z) for x in [3, 4, 5] for y in range(6) for z in [4, 5]] 
# coords = [(x, y, z) for x in [0] for y in range(6) for z in [4, 5]] 
# coords = [(0, 2, 5)]
# coords = [(0, 4, 0)]
coords = [(x, y, z) for x in [0, 1] for y in range(6) for z in [4, 5]] 
# coords = [(x, y, z) for x in [3, 4, 5] for y in range(4) for z in [4, 5]] 
# coords= [(x, y, z) for x in [4, 5] for y in range(6) for z in [0, 1, 2, 3]]

# Largest valid index of each codebook; codes beyond these crash the CUDA kernel.
max_style_code_value = fqgan_hparams.style_codebook_size - 1
max_struct_code_value = fqgan_hparams.struct_codebook_size - 1

# Style gif: sweep every style code at `coords` (structure codes unchanged).
gif_path = create_latent_modification_gif(
    fqgan,
    style_indices,
    struct_indices,
    coords,
    block_converter,
    latent_type='style',
    max_code_value=max_style_code_value,  # every style code: 0..style_codebook_size-1
    duration=100.0,  # GIF duration in seconds
    out_path=f'visualizations/{sample_name}_style_change.gif',
    transparent_background=True,
    show_axis=False,
    fps=1,
    # mode='random'
    mode='all'
)

# Structure gif: sweep every structure code at `coords` (style codes unchanged).
# NOTE: the "structure modification" cell writes to the same out_path; whichever
# cell runs last overwrites this file.
gif_path = create_latent_modification_gif(
    fqgan,
    style_indices,
    struct_indices,
    coords,
    block_converter,
    latent_type='struct',
    max_code_value=max_struct_code_value,  # every structure code: 0..struct_codebook_size-1
    duration=10.0,  # GIF duration in seconds
    out_path=f'visualizations/{sample_name}_struct_change.gif',
    transparent_background=True,
    show_axis=False,
    fps=1,
    # mode='random'
    mode='all'
)

# Create structure modification gif

In [None]:
# Structure gif with default rendering options. max_code_value is derived from the
# structure codebook size: a hard-coded 31 would crash the CUDA kernel if the
# codebook has fewer than 32 entries.
# NOTE: same out_path as the structure gif in the style-modification cell; the
# cell that runs last wins.
gif_path = create_latent_modification_gif(
    fqgan,
    style_indices,
    struct_indices,
    coords,
    block_converter,
    latent_type='struct',
    max_code_value=fqgan_hparams.struct_codebook_size - 1,  # every structure code
    duration=10.0,  # GIF duration in seconds
    out_path=f'visualizations/{sample_name}_struct_change.gif',
)

# Create both modification gif

In [None]:
# 'both' writes the same code value into the style AND structure grids, so the
# sweep must stop at the smaller codebook's last index (larger values crash the
# CUDA kernel).
gif_path = create_latent_modification_gif(
    fqgan,
    style_indices,
    struct_indices,
    coords,
    block_converter,
    latent_type='both',
    max_code_value=min(fqgan_hparams.style_codebook_size, fqgan_hparams.struct_codebook_size) - 1,
    duration=10.0,  # GIF duration in seconds
    out_path=f'visualizations/{sample_name}_both_change.gif',
)

# Create single codebook modification gif

In [None]:
# Single-codebook (VQGAN) modification gif: in 'random' mode each frame puts
# random codes (seeded by the frame's code value) at `coords`.
sample_name = 'treesandcorner_asad_vqgan_rand_1_wireframe'
coords = [(x, y, z) for x in [0, 1] for y in range(6) for z in [4, 5]] 
# coords = [(x, y, z) for x in [3, 4, 5] for y in range(4) for z in [4, 5]] 
gif_path = create_latent_modification_gif_single_codebook(
    vqgan,
    vq_latent_indices,
    coords,
    block_converter,
    # Code values 0..100: 101 modified frames plus the base frame. Codes are drawn
    # up to 100, so this must be below the VQGAN codebook size (TODO confirm).
    max_code_value=100,
    duration=20.0,  # GIF duration (presumably seconds, as for the two-codebook helper)
    out_path=f'visualizations/{sample_name}_vqgan_change.gif',
    transparent_background=True,
    show_axis=False,
    fps=1,
    mode='random'
)