# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from sampler_utils import retrieve_autoencoder_components_state_dicts, latent_ids_to_onehot3d, get_latent_loaders
from models3d import VQAutoEncoder, Generator
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import torch.distributions as dists
from tqdm import tqdm
import gc
from models3d import BiomeClassifier

In [2]:
%matplotlib widget
#matplotlib.use('TkAgg')

# Block Converter

In [3]:
class BlockBiomeConverter:
    def __init__(self, block_mappings=None, biome_mappings=None):
        """
        Initialize with pre-computed mappings for both blocks and biomes
        
        Args:
            block_mappings: dict containing 'index_to_block' and 'block_to_index'
            biome_mappings: dict containing 'index_to_biome' and 'biome_to_index'
        """
        self.index_to_block = block_mappings['index_to_block'] if block_mappings else None
        self.block_to_index = block_mappings['block_to_index'] if block_mappings else None
        self.index_to_biome = biome_mappings['index_to_biome'] if biome_mappings else None
        self.biome_to_index = biome_mappings['biome_to_index'] if biome_mappings else None
    
    @classmethod
    def from_dataset(cls, data_path):
        """Create mappings from a dataset file"""
        data = np.load(data_path, allow_pickle=True)
        voxels = data['voxels']
        biomes = data['biomes']
        
        # Create block mappings (blocks are integers)
        unique_blocks = np.unique(voxels)
        block_to_index = {int(block): idx for idx, block in enumerate(unique_blocks)}
        index_to_block = {idx: int(block) for idx, block in enumerate(unique_blocks)}
        
        # Create biome mappings (biomes are strings)
        unique_biomes = np.unique(biomes)
        biome_to_index = {str(biome): idx for idx, biome in enumerate(unique_biomes)}
        index_to_biome = {idx: str(biome) for idx, biome in enumerate(unique_biomes)}
        
        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}
        
        return cls(block_mappings, biome_mappings)
    
    @classmethod
    def from_arrays(cls, voxels, biomes):
        """Create mappings directly from numpy arrays"""
        # Create block mappings (blocks are integers)
        unique_blocks = np.unique(voxels)
        block_to_index = {int(block): idx for idx, block in enumerate(unique_blocks)}
        index_to_block = {idx: int(block) for idx, block in enumerate(unique_blocks)}
        
        # Create biome mappings (biomes are strings)
        unique_biomes = np.unique(biomes)
        biome_to_index = {str(biome): idx for idx, biome in enumerate(unique_biomes)}
        index_to_biome = {idx: str(biome) for idx, biome in enumerate(unique_biomes)}
        
        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}
        
        return cls(block_mappings, biome_mappings)
    
    @classmethod
    def load_mappings(cls, path):
        """Load pre-saved mappings"""
        mappings = torch.load(path)
        return cls(mappings['block_mappings'], mappings['biome_mappings'])
    
    def save_mappings(self, path):
        """Save mappings for later use"""
        torch.save({
            'block_mappings': {
                'index_to_block': self.index_to_block,
                'block_to_index': self.block_to_index
            },
            'biome_mappings': {
                'index_to_biome': self.index_to_biome,
                'biome_to_index': self.biome_to_index
            }
        }, path)
    
    def convert_to_original_blocks(self, data):
        """
        Convert from indices back to original block IDs.
        Handles both one-hot encoded and already-indexed data.
        
        Args:
            data: torch.Tensor of either:
                - one-hot encoded blocks [B, C, H, W, D] or [C, H, W, D]
                - indexed blocks [B, H, W, D] or [H, W, D]
        Returns:
            torch.Tensor of original block IDs with shape [B, H, W, D] or [H, W, D]
        """
        # If one-hot encoded (dim == 5 or first dim == num_blocks), convert to indices first
        if len(data.shape) == 5 or (len(data.shape) == 4 and data.shape[0] == len(self.block_to_index)):
            data = torch.argmax(data, dim=1 if len(data.shape) == 5 else 0)
        
        # Now convert indices to original blocks
        if len(data.shape) == 4:  # Batch dimension present
            return torch.tensor([[[[self.index_to_block[int(b)] 
                                for b in row]
                                for row in layer]
                                for layer in slice_]
                                for slice_ in data])
        else:  # No batch dimension
            return torch.tensor([[[self.index_to_block[int(b)] 
                                for b in row]
                                for row in layer]
                                for layer in data])

    def convert_to_original_biomes(self, data):
        """
        Convert from indices back to original biome strings.
        Handles both one-hot encoded and already-indexed data.
        
        Args:
            data: torch.Tensor of either:
                - one-hot encoded biomes [B, C, H, W, D] or [C, H, W, D]
                - indexed biomes [B, H, W, D] or [H, W, D]
        Returns:
            numpy array of original biome strings with shape [B, H, W, D] or [H, W, D]
        """
        # If one-hot encoded (dim == 5 or first dim == num_biomes), convert to indices first
        if len(data.shape) == 5 or (len(data.shape) == 4 and data.shape[0] == len(self.biome_to_index)):
            data = torch.argmax(data, dim=1 if len(data.shape) == 5 else 0)
        
        # Now convert indices to original biomes
        if len(data.shape) == 4:  # Batch dimension present
            return np.array([[[[self.index_to_biome[int(b)] 
                            for b in row]
                            for row in layer]
                            for layer in slice_]
                            for slice_ in data])
        else:  # No batch dimension
            return np.array([[[self.index_to_biome[int(b)] 
                            for b in row]
                            for row in layer]
                            for layer in data])
        
    def get_air_block_index(self):
        """
        Find the one-hot index corresponding to the air block (ID 5).
        Returns:
            int: The index where air blocks are encoded in one-hot format
        """
        # Find the index that maps to block ID 5 (air) in our index_to_block mapping
        for idx, block_id in self.index_to_block.items():
            if block_id == 5:  # Air block ID
                return idx
        raise ValueError("Air block (ID 5) not found in block mappings!")
    
    def get_blockid_indices(self, block_ids):
        """
        Find the one-hot index corresponding to the air block (ID 5).
        Returns:
            int: The index where air blocks are encoded in one-hot format
        """
        # Find the index that maps to block ID 5 (air) in our index_to_block mapping
        idxs = []
        for idx, block_id in self.index_to_block.items():
            if block_id in block_ids:  # Air block ID
                idxs.append(idx)
        if len(idxs) == 0:
            raise ValueError("Air block (ID 5) not found in block mappings!")
        return idxs

# Minecraft Chunks Dataset

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MinecraftDataset(Dataset):
    def __init__(self, data_path):
        data_path = Path(data_path)

        # Try to load processed data first
        # assert processed_data_path.exists() and mappings_path.exists()

        print("Loading pre-processed data...")
        processed_data = torch.load(data_path)
        # Only keep the chunks, discard biome data
        self.processed_chunks = processed_data['chunks']
        # self.processed_biomes = processed_data['biomes']
        # Delete the biomes to free memory
        # del processed_data['biomes']
        # del processed_data['chunks']
        del processed_data
        
        
        print(f"Loaded {len(self.processed_chunks)} chunks of size {self.processed_chunks.shape[1:]}")
        print(f"Number of unique block types: {self.processed_chunks.shape[1]}")
        print(f'Unique blocks: {torch.unique(torch.argmax(self.processed_chunks, dim=1)).tolist()}')
        #print(f"Loaded {len(self.processed_biomes)} chunks of size {self.processed_biomes.shape[1:]}")
        #print(f"Number of unique biome types: {self.processed_biomes.shape[1]}")
        #print(f'Unique biomes: {torch.unique(torch.argmax(self.processed_biomes, dim=1)).tolist()}')
        

    def __getitem__(self, idx):
        # return self.processed_chunks[idx]
        return self.processed_chunks[idx] #, self.processed_biomes[idx]

    def __len__(self):
        # return len(self.processed_chunks), len(self.processed_biomes)
        return len(self.processed_chunks)

def get_minecraft_dataloaders(data_path, batch_size=32, val_split=0.1, num_workers=4):
    """
    Creates training and validation dataloaders for Minecraft chunks.
    """
    # Create dataset
    dataset = MinecraftDataset(data_path)
    
    # Split into train and validation sets
    val_size = int(val_split * len(dataset))
    train_size = len(dataset) - val_size
    
    # Use a fixed seed for reproducibility
    generator = torch.Generator().manual_seed(42)
    
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, 
        [train_size, val_size],
        generator=generator
    )
    
    # Create dataloaders with memory pinning
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    
    print(f"\nDataloader details:")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Batch size: {batch_size}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    
    return train_loader, val_loader


# Helper Functions
Loads hyperparameters and model from checkpoint, gives easy function to encode a structure into latent codes, as well as decode those latents back into the reconstruction

# FQGAN Helper Functions

In [5]:
import os
from log_utils import log, load_stats, load_model
import copy
from fq_models import FQModel, HparamsFQGAN


# Loads hparams from hparams.json file in saved model directory
def load_hparams_from_json(log_dir):
    import json
    import os
    json_path = os.path.join(log_dir, 'hparams.json')
    
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"No hparams.json file found in {log_dir}")
    
    with open(json_path, 'r') as f:
        hparams = json.load(f)

    return hparams

# turns loaded hparams json into propery hyperparams object
def dict_to_vcqgan_hparams(hparams_dict, dataset=None):
    # Determine which hyperparameter class to use based on the dataset
    if dataset == None:
        dataset = hparams_dict.get('dataset', 'MNIST')  # Default to MNIST if not specified
    
    vq_hyper = HparamsFQGAN(dataset)
    # Set attributes from the dictionary
    for key, value in hparams_dict.items():
        setattr(vq_hyper, key, value)
    
    return vq_hyper

# Loads fqgan model weights from a given checkpoint file
def load_fqgan_from_checkpoint(H, fqgan):
    fqgan = load_model(fqgan, "fqgan", H.load_step, H.load_dir).cuda()
    fqgan.eval()
    return fqgan

# Takes a chunk or batch of chunks from the dataset, returns the encoded style and structure indices matrices
def encode_and_quantize(fqgan, terrain_chunks, device='cuda'):
    """Memory-efficient encoding function"""
    fqgan.eval()
    with torch.no_grad():
        # Move input to device
        terrain_chunks = terrain_chunks.to(device)
        
        # Get encodings
        h_style, h_struct = fqgan.encoder(terrain_chunks)
        
        # Process style path
        h_style = fqgan.quant_conv_style(h_style)
        quant_style, _, style_stats = fqgan.quantize_style(h_style)
        style_indices = style_stats[2]  # Get indices from tuple
        style_indices = style_indices.view(
            (h_style.size()[0], h_style.size()[2], h_style.size()[3], h_style.size()[4])
        )
        
        # Clear intermediate tensors
        del h_style, quant_style, style_stats
        
        # Process structure path
        h_struct = fqgan.quant_conv_struct(h_struct)
        quant_struct, _, struct_stats = fqgan.quantize_struct(h_struct)
        struct_indices = struct_stats[2]  # Get indices from tuple
        struct_indices = struct_indices.view(
            (h_struct.size()[0], h_struct.size()[2], h_struct.size()[3], h_struct.size()[4])
        )
        
        # Clear intermediate tensors
        del h_struct, quant_struct, struct_stats
        
        # Move indices to CPU to save GPU memory
        style_indices = style_indices.cpu()
        struct_indices = struct_indices.cpu()
        
        torch.cuda.empty_cache()
        
        return style_indices, struct_indices

# Takes style and structure indices, returns the reconstructed map
def decode_from_indices(style_indices, struct_indices, fqgan, device='cuda', two_stage=False):
    """Memory-efficient decoding function"""
    with torch.no_grad():
        # Move indices to device only when needed
        style_indices = style_indices.to(device)
        struct_indices = struct_indices.to(device)
        
        # Get quantized vectors
        quant_style = fqgan.quantize_style.get_codebook_entry(
            style_indices.view(-1),
            shape=[1, fqgan.embed_dim, *style_indices.shape[1:]]
        )
        quant_struct = fqgan.quantize_struct.get_codebook_entry(
            struct_indices.view(-1),
            shape=[1, fqgan.embed_dim, *struct_indices.shape[1:]]
        )
        
        # Clear indices from GPU
        del style_indices, struct_indices
        
        # Combine and decode
        quant = torch.cat([quant_struct, quant_style], dim=1)
        # quant = quant_style + quant_struct
        del quant_style, quant_struct
        
        if two_stage:
            decoded, binary_decoded = fqgan.decoder(quant)
        else:
            decoded = fqgan.decoder(quant)
        
        del quant
        
        # Convert to block IDs if one-hot encoded
        if decoded.shape[1] > 1:
            decoded = torch.argmax(decoded, dim=1)
        
        # Move result to CPU and clear GPU memory
        result = decoded.squeeze(0).cpu()
        if two_stage:
            binary_result = binary_decoded.squeeze(0).cpu()
            del decoded
            del binary_decoded
            torch.cuda.empty_cache()
            return result, binary_result
        
        del decoded
        torch.cuda.empty_cache()
        
        return result

# Load FQ model

In [6]:
# Important if you reload the model, can run into memory issues. There's a memory leak somewhere I haven't been able to fix
import gc
gc.collect()
torch.cuda.empty_cache()

In [7]:
# replace with directory for whatever model you want
# model_path = './FQGAN_2stagedecoder_nobiomemodel_16bothcbook_EMA3'
model_path = "../model_logs/FQGAN_2stagedecoder_logweighted3_32codes_cycleconsistency_STE_postquant_detachcycle_gumbel"

In [8]:
# # I'm manually setting the load step here, if it errors out take a look at what the actual .th file number is
# fqgan_hparams =  dict_to_vcqgan_hparams(load_hparams_from_json(f"{model_path}"), 'minecraft')
# fqgan_hparams.load_step = 10000
# fqgan = FQModel(fqgan_hparams)
# fqgan = load_fqgan_from_checkpoint(fqgan_hparams, fqgan)
# print(f'loaded from: {fqgan_hparams.log_dir}')

# Load dataset

In [9]:
# loads a preprocesed dataset, which already has a mappings file created and everything one-hot encoded nicely. For more memory efficient, could try just loading the validation set file
# data_path = '../../text2env/data/minecraft_biome_newworld_10k_processed_cleaned.pt'
data_path = '../../text2env/data/24_newdataset_processed_cleaned3.pt'
train_loader, val_loader = get_minecraft_dataloaders(
    data_path,
    batch_size=4,
    num_workers=0, # Must be 0 if you're on windows, otherwise it errors
    val_split=0.1
)


Loading pre-processed data...
Loaded 11119 chunks of size torch.Size([42, 24, 24, 24])
Number of unique block types: 42
Unique blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]

Dataloader details:
Training samples: 10008
Validation samples: 1111
Batch size: 4
Training batches: 2502
Validation batches: 278


In [None]:
# mappings_path = '../../text2env/data/minecraft_biome_newworld_10k_mappings.pt'
mappings_path = '../../text2env/data/24_newdataset_mappings3.pt'

block_converter = BlockBiomeConverter.load_mappings(mappings_path)

# Test the model

In [11]:
# Create an instnance of our visualizer
from visualization_utils import MinecraftVisualizerPyVista
visualizer = MinecraftVisualizerPyVista()


# Structure Exploration

## Get corresponding chunks for each structure code

In [12]:
# # store codes and its corresponding chunks.
# # We don't have a guarantee that air will be index 0, so I have a function to retrieve the correct index corresponding to air blocks
# air_idx = block_converter.get_air_block_index()

# structure_dict = {}
# structure_dict_bin = {}
# style_dict = {}
# for batch in train_loader:
#     for sample_idx in range(len(batch)):
#         sample= batch[sample_idx].unsqueeze(0).cuda()
#         with torch.no_grad():
#             style_indices, struct_indices = encode_and_quantize(fqgan, sample)
#         reconstructed, binary_reconstructed = decode_from_indices(style_indices, struct_indices, fqgan, two_stage=True)
#         for i in range(struct_indices.shape[1]):  
#             for j in range(struct_indices.shape[2]):  
#                 for k in range(struct_indices.shape[3]): 
#                     style_code = style_indices[0, i, j, k].item()  
#                     struct_code = struct_indices[0, i, j, k].item()  
#                     x_start, y_start, z_start = i * 4, j * 4, k * 4
#                     x_end, y_end, z_end = x_start + 4, y_start + 4, z_start + 4
#                     block = reconstructed[x_start:x_end, y_start:y_end, z_start:z_end]
#                     bin_block = binary_reconstructed[x_start:x_end, y_start:y_end, z_start:z_end]
#                     if struct_code not in structure_dict:
#                         structure_dict[struct_code] = []  # Initialize list if not present
#                     if struct_code not in structure_dict_bin:
#                         structure_dict_bin[struct_code] = []
#                     if style_code not in style_dict:
#                         style_dict[style_code] = []  # Initialize list if not present
#                     structure_dict[struct_code].append(block)
#                     style_dict[style_code].append(block)
#                     structure_dict_bin[struct_code].append(bin_block)

# # convert chunks into binary chunks
# binary_structure_dict = {
#     struct_code: [(b != air_idx).to(dtype=torch.int) for b in block_list]  # Convert each block tensor
#     for struct_code, block_list in structure_dict.items()
# }


In [13]:


# # Assuming latent space shape is fixed (e.g., 6x6x6) - get this dynamically if possible
# # Example: Infer from a dummy forward pass or model config if necessary
# latent_depth, latent_height, latent_width = 6, 6, 6 # *** Adjust if different ***
# latent_shape = (latent_depth, latent_height, latent_width)
# num_structure_codes = 20
# # Initialize count tensors for each code
# position_counts = {
#     code: torch.zeros(latent_shape, dtype=torch.long, device='cpu')
#     for code in range(num_structure_codes)
# }
# total_samples_processed = 0

# print("Collecting positional frequencies from train_loader...")
# for batch_idx, batch in enumerate(train_loader):
#     # Limit batches for testing?
#     # if batch_idx > 20: break

#     if torch.cuda.is_available():
#         batch = batch.cuda()

#     # Only need the encoding part
#     style_indices, struct_indices = encode_and_quantize(fqgan, batch) # Get indices for the whole batch

#     # Process each sample in the batch
#     for sample_idx in range(struct_indices.shape[0]): # Iterate through batch dimension
#         struct_indices_sample = struct_indices[sample_idx].cpu() # Get indices for one sample, move to CPU

#         # Iterate through the latent grid dimensions (D, H, W)
#         for i in range(latent_depth):
#             for j in range(latent_height):
#                 for k in range(latent_width):
#                     struct_code = struct_indices_sample[i, j, k].item()
#                     if 0 <= struct_code < num_structure_codes:
#                         position_counts[struct_code][i, j, k] += 1
#                     else:
#                         print(f"Warning: Encountered out-of-bounds code {struct_code} at ({i},{j},{k})")
#         total_samples_processed += 1

#     # Optional: Print progress
#     if (batch_idx + 1) % 50 == 0:
#             print(f"Processed {batch_idx + 1} batches...")



In [14]:
def plot_positional_frequency(position_counts, struct_code, save_dir):
    """
    Generates and saves a 3D scatter plot showing the positional frequency
    of a structure code within the latent grid, using a Viridis colormap and
    Minecraft coordinate conventions (Y=Height is vertical, Y=0 is bottom).

    Args:
        position_counts (dict): Maps struct_code to 6x6x6 count tensor.
        struct_code (int): The structure code to visualize.
        save_dir (str): The directory to save the plot image.
    """
    if struct_code not in position_counts:
        print(f"Error: Code {struct_code} not found in position_counts dictionary.")
        return

    counts_tensor = position_counts[struct_code].cpu() # Ensure it's on CPU
    latent_shape = counts_tensor.shape
    if len(latent_shape) != 3:
        print(f"Error: Count tensor for code {struct_code} is not 3D (shape: {latent_shape}).")
        return

    counts_numpy = counts_tensor.numpy()
    max_count = np.max(counts_numpy)

    if max_count == 0:
        print(f"Info: Code {struct_code} never appeared. Skipping visualization.")
        # Optional: Create an empty plot placeholder if desired
        return # Stop here if code never appeared

    # --- Plotting Setup ---
    # os.makedirs(save_dir, exist_ok=True)
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Generate grid indices (I, J, K) corresponding to tensor dimensions
    # Assume I=Depth(Z_mc), J=Height(Y_mc), K=Width(X_mc)
    i_indices, j_indices, k_indices = np.indices(latent_shape)

    # Flatten coordinates and the raw counts (for color mapping)
    i_coords_mc = i_indices.flatten() # Minecraft Z coordinates
    j_coords_mc = j_indices.flatten() # Minecraft Y (Height) coordinates
    k_coords_mc = k_indices.flatten() # Minecraft X coordinates
    frequencies = counts_numpy.flatten()

    # --- Create Scatter Plot (Viridis Colormap, Correct Axis Mapping) ---
    # cmap = plt.get_cmap('viridis')
    cmap = plt.get_cmap('Greys')
    # Plot mapping:
    # Plot X-axis <- Minecraft X data (k_coords_mc)
    # Plot Y-axis <- Minecraft Z data (i_coords_mc)
    # Plot Z-axis <- Minecraft Y (Height) data (j_coords_mc) <<< VERTICAL AXIS
    scatter = ax.scatter(k_coords_mc, i_coords_mc, j_coords_mc, # Correct mapping
                         c=frequencies, cmap=cmap, # Color based on frequency counts
                         s=150, # Adjust size as needed
                         alpha=0.8, # Add some transparency
                         # vmin=0, vmax=max_count, # Optional: Explicitly set color limits
                         edgecolors='grey', linewidth=0.5)

    # --- Add Colorbar ---
    cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, aspect=20, pad=0.1)
    cbar.set_label("Code Occurrence Count")

    # --- Set Labels, Ticks, Limits, Title, and Invert Z-axis ---
    ax.set_title(f"Code {struct_code} Positional Frequency")
    # Label plot axes according to the *Minecraft dimension* plotted on them
    ax.set_xlabel("Z (Latent Dim K)")
    ax.set_ylabel("X (Latent Dim I)")
    ax.set_zlabel("Y (Height, Latent Dim J)") # Vertical axis

    # Set ticks based on the dimension size
    ax.set_xticks(np.arange(latent_shape[2])) # K dimension
    ax.set_yticks(np.arange(latent_shape[0])) # I dimension
    ax.set_zticks(np.arange(latent_shape[1])) # J dimension

    # Set limits for plot axes
    ax.set_xlim(-0.5, latent_shape[2] - 0.5)
    ax.set_ylim(-0.5, latent_shape[0] - 0.5)
    ax.set_zlim(-0.5, latent_shape[1] - 0.5)

    # --- Invert the Z-axis (which represents Y-Height) ---
    # ax.invert_zaxis() # Ensures Y=0 is at the bottom
    ax.invert_yaxis() # Ensures Y=0 is at the bottom

    # Adjust view angle
    ax.view_init(elev=20., azim=-75)
    # plt.show()
    # --- Save and Close ---
    image_path = os.path.join(save_dir, f"pos_freq_code_{struct_code}.png")
    try:
        plt.savefig(image_path, bbox_inches='tight', dpi=150)
    except Exception as e:
        print(f"Error saving positional frequency plot for code {struct_code}: {e}")
    finally:
        plt.close(fig)


def generate_all_positional_frequency_plots(position_counts, output_dir):
    print(f"Generating positional frequency plots for {len(position_counts)} codes...")
    os.makedirs(output_dir, exist_ok=True) # Ensure directory exists

    # Iterate through the codes present in the dictionary
    codes_to_plot = sorted(position_counts.keys())

    for i, code in enumerate(codes_to_plot):
        # Call the plotting function for the current code
        plot_positional_frequency(position_counts, code, output_dir)

        # Optional: Print progress
        if (i + 1) % 10 == 0 or (i + 1) == len(codes_to_plot):
            print(f"Generated plot {i+1}/{len(codes_to_plot)} (Code {code})")

    print(f"Finished generating positional frequency plots. Saved to: {output_dir}")

## Visualize some of the chunks in a grid

In [15]:
import random
def visualize_structure_grid(structure_dict, struct_code, num_chunks=64, grid_size=8, spacing=2, save_path=None, code_type="Struct"):
    """
    Visualize a grid of 4x4x4 chunks for a specific structure code.
    
    Parameters:
    - structure_dict: Dictionary containing chunks for each structure code
    - struct_code: The structure code to visualize
    - num_chunks: Number of chunks to display (default: 64)
    - grid_size: Number of chunks per row/column (default: 8)
    - spacing: Number of blocks spacing between chunks (default: 2)
    
    Returns:
    - PyVista plotter with interactive visualization
    """
    import pyvista as pv
    import numpy as np
    import torch
    
    # Check if the structure code exists
    if struct_code not in structure_dict:
        print(f"Structure code {struct_code} not found in dictionary")
        return None
    
    # Randomly select num_chunks from the list
    available_chunks = structure_dict[struct_code]
    actual_chunks = len(available_chunks)
    
    if actual_chunks == 0:
        print(f"No chunks found for structure code {struct_code}")
        return None
    
    if actual_chunks < num_chunks:
        print(f"Only {actual_chunks} chunks available for structure code {struct_code}")
        chunks = available_chunks
    else:
        # Randomly select num_chunks from the list
        chunks = random.sample(available_chunks, num_chunks)
    
    # Setup block colors mapping (simplified from the visualizer)
    blocks_to_cols = {
            0: (0.5, 0.25, 0.0),    # light brown
            10: 'black', # bedrock
            29: "#006400", # cacutus
            38: "#B8860B",  # clay
            60: "brown",  # dirt
            92: "gold",  # gold ore
            93: "green",  # grass
            115: "brown",  # ladder...?
            119: (.02, .28, .16, 0.9),  # transparent forest green (RGBA) for leaves
            120: (.02, .28, .16, 0.9),  # leaves2
            194: "yellow",  # sand
            217: "gray",  # stone
            240: (0.0, 0.0, 1.0, 0.4),  # water
            227: (0.0, 1.0, 0.0, .3), # tall grass
            237: (0.33, 0.7, 0.33, 0.3), # vine
            40: "#2F4F4F",  # coal ore
            62: "#228B22",  # double plant
            108: "#BEBEBE",  # iron ore
            131: "saddlebrown",  # log1
            132: "saddlebrown",  #log2
            95: "lightgray",  # gravel
            243: "wheat",  # wheat
            197: "limegreen",  # sapling
            166: "orange",  #pumpkin
            167: "#FF8C00",  # pumpkin stem
            184: "#FFA07A",  # red flower
            195: "tan",  # sandstone
            250: "white",  #wool 
            251: "gold",   #yellow flower
        }
    
    # Calculate grid dimensions
    chunk_size = 4
    grid_dim = grid_size * chunk_size + (grid_size - 1) * spacing
    
    # Create plotter
    plotter = pv.Plotter(notebook=True)
    
    # Remove existing lights and add custom lighting
    plotter.remove_all_lights()
    plotter.add_light(pv.Light(position=(1, -1, 1), intensity=1.0, color='white'))
    plotter.add_light(pv.Light(position=(-1, 1, 0.5), intensity=0.5, color='white'))
    plotter.add_light(pv.Light(position=(-0.5, -0.5, -1), intensity=0.3, color='white'))
    plotter.add_title(f"{code_type} Code {struct_code} - Pattern Visualization", font_size=16)
    # Place each chunk in the grid
    for i in range(min(actual_chunks, num_chunks)):
        # Calculate grid position
        row = i // grid_size
        col = i % grid_size
        
        # Calculate offset in the grid
        x_offset = col * (chunk_size + spacing)
        z_offset = row * (chunk_size + spacing)
        
        # Get current chunk
        chunk = chunks[i]
        chunk = block_converter.convert_to_original_blocks(chunk)
        # Convert to numpy if needed
        if isinstance(chunk, torch.Tensor):
            chunk = chunk.detach().cpu().numpy()
        
        # Apply the same transformations as original visualizer
        chunk = chunk.transpose(2, 0, 1)
        chunk = np.rot90(chunk, 1, (0, 1))
        
        # Convert encoded blocks to original block IDs
        
        # Create grid for this chunk
        grid = pv.ImageData()
        grid.dimensions = np.array(chunk.shape) + 1
        grid.origin = (x_offset, z_offset, 0)  # Position in the overall grid
        grid.spacing = (1, 1, 1)  # Unit spacing
        grid.cell_data["values"] = chunk.flatten(order="F")
        
        # Plot each block type in the chunk
        mask = (chunk != 5) & (chunk != -1)
        unique_blocks = np.unique(chunk[mask])
        
        for block_id in unique_blocks:
            # Skip air blocks (0)
            if block_id == 0:
                continue
                
            threshold = grid.threshold([block_id-0.5, block_id+0.5])
            
            # Get color for this block type
            if block_id in blocks_to_cols:
                color = blocks_to_cols[int(block_id)]
                opacity = 1.0 if isinstance(color, str) or len(color) == 3 else color[3]
            else:
                # Default for unknown blocks
                color = (0.7, 0.7, 0.7)  # Gray
                opacity = 1.0
            
            # Add mesh for this block type
            plotter.add_mesh(threshold, 
                           color=color,
                           opacity=opacity,
                           show_edges=True,
                           edge_color='black',
                           line_width=0.5,
                           edge_opacity=0.3,
                           lighting=True)
    
    # Add a dummy cube to set overall bounds
    total_size = grid_size * chunk_size + (grid_size - 1) * spacing
    # outline = pv.Cube(bounds=(0, total_size, 0, chunk_size, 0, total_size))
    # plotter.add_mesh(outline, opacity=0.0)
    
    # Set camera position and bounds
    plotter.camera_position = 'iso'
    plotter.camera.zoom(1)  # Zoom out to see the whole grid
    
    # Add grid lines or axes
    # plotter.show_bounds(
    #     grid='back',
    #     location='back', 
    #     font_size=10,
    #     bounds=[0, total_size, 0,total_size, 0, chunk_size],
    #     axes_ranges=[0, total_size, 0, total_size, 0, chunk_size]
    # )
    # Save image if path provided
    if save_path:
        plotter.screenshot(save_path)
        plotter.close()
        print(f"Saved visualization for structure code {struct_code} to {save_path}")
    return plotter

# Usage example:
# plotter = visualize_structure_grid(binary_structure_dict, 4)
# plotter.show()

In [16]:
# Loop to visualize and save all structure codes
def visualize_all_structure_codes(structure_dict, output_dir="structure_visualizations", num_chunks=64, code_type="Struct"):
    import os
    import math
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    # vis_out_dir = os.path.join(output_dir, f"{code_type}_visualizations")
    # os.makedirs(vis_out_dir, exist_ok=True)
    
    # Get all unique structure codes
    struct_codes = sorted(structure_dict.keys())
    
    print(f"Found {len(struct_codes)} {code_type} codes to visualize")
    
    # Loop through each structure code
    for struct_code in struct_codes:
        # Define output file path
        output_path = os.path.join(output_dir, f"{code_type}_code_{struct_code}.png")
        
        # Skip if file already exists (optional, can be removed)
        if os.path.exists(output_path):
            print(f"Skipping {code_type} code {struct_code} - file already exists")
            continue
        
        print(f"Visualizing {code_type} code {struct_code}...")
        
        # Generate and save visualization
        try:
            visualize_structure_grid(
                structure_dict, 
                struct_code, 
                num_chunks=num_chunks, 
                grid_size=int(math.sqrt(num_chunks)),
                save_path=output_path,
                code_type=code_type
            )
        except Exception as e:
            print(f"Error visualizing {code_type} code {struct_code}: {str(e)}")
    
    print(f"Visualization complete. Images saved to {output_dir}/")


In [17]:
import pyvista as pv
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import get_cmap

# Ensure pyvista is set up for notebook display
# pv.set_jupyter_backend('trame') # or 'ipygany', 'panel', etc. depending on your setup

def visualize_probability_volume_pyvista(prob_matrix, title="Probability Volume", cmap_name='viridis', fixed_opacity=0.5):
    """
    Visualizes a 4x4x4 probability matrix using PyVista, where voxel color
    maps to probability and opacity is fixed for visibility.

    Args:
        prob_matrix (np.ndarray): A 4x4x4 numpy array of probabilities (0.0 to 1.0).
        title (str): The title for the plot.
        cmap_name (str): Name of the matplotlib colormap to use (e.g., 'viridis', 'coolwarm', 'jet').
        fixed_opacity (float): The fixed opacity value (0.0 to 1.0) for visible blocks.
    """
    if not isinstance(prob_matrix, np.ndarray) or prob_matrix.shape != (4, 4, 4):
        print(f"Error: Input must be a 4x4x4 NumPy array. Got shape {prob_matrix.shape}")
        return

    plotter = pv.Plotter(notebook=True)
    plotter.add_title(title, font_size=12)

    # Set up colormap and normalization
    cmap = get_cmap(cmap_name)
    norm = Normalize(vmin=0.0, vmax=1.0) # Probabilities range from 0 to 1

    min_probability_threshold = 0.01 # Don't render blocks with near-zero probability

    # Keep track if any blocks were added
    blocks_added = False
    for x in range(4):
        for y in range(4):
            for z in range(4):
                probability = prob_matrix[x, y, z]

                # Only add a cube if the probability is above the threshold
                if probability >= min_probability_threshold:
                    # Get the corresponding color from the colormap
                    # Use matplotlib cmap directly to get RGBA, then take RGB
                    color = cmap(norm(probability))

                    # Create a cube for this voxel position
                    cube = pv.Cube(bounds=(x, x + 1, y, y + 1, z, z + 1))

                    # Add the cube mesh to the plotter
                    plotter.add_mesh(
                        cube,
                        color=color[:3], # Pass RGB tuple
                        opacity=fixed_opacity,
                        show_edges=True,
                        edge_color='grey',
                        line_width=1
                    )
                    blocks_added = True

    # Only add scalar bar if we actually plotted something
    if blocks_added:
        # Add a scalar bar manually configured
        plotter.add_scalar_bar(
            title="Probability",
            # cmap=cmap_name, # Pass the colormap name
            # Define the limits for the scalar bar manually
            # This requires creating a dummy actor or setting clim directly
            # For simplicity, we'll rely on cmap name and default limits (0-1)
            # which works since our data is normalized 0-1.
            n_labels=6, # Number of labels on the colorbar (e.g., 0.0, 0.2, ..., 1.0)
            fmt="%.2f" # Format labels to 2 decimal places
        )
        # Note: If the range wasn't 0-1, you might need plotter.update_scalar_bar_range([min, max])
        # or pass clim=[min, max] if supported by your PyVista version.

    # Add axes bounds for context
    plotter.show_bounds(
        bounds=[0, 4, 0, 4, 0, 4],
        grid='front',
        location='outer',
        xlabel='X',
        ylabel='Y',
        zlabel='Z',
        ticks='inside',
        minor_ticks=False,
        n_xlabels=5,
        n_ylabels=5,
        n_zlabels=5,
        fmt='%0.0f'
    )

    # Set camera view
    plotter.camera_position = 'iso'
    # plotter.camera.zoom(1.5)

    # Show the interactive plot
    plotter.show()

# Style Exploration

In [1]:
# air_idx = block_converter.get_air_block_index()

# style_dict = {}
# for batch in train_loader:
#     for sample_idx in range(len(batch)):
#         sample= batch[sample_idx].unsqueeze(0).cuda()
#         with torch.no_grad():
#             style_indices, struct_indices = encode_and_quantize(fqgan, sample)
#         reconstructed, binary_reconstructed = decode_from_indices(style_indices, struct_indices, fqgan, two_stage=True)
#         for i in range(style_indices.shape[1]):  
#             for j in range(style_indices.shape[2]):  
#                 for k in range(style_indices.shape[3]): 
#                     style_code = style_indices[0, i, j, k].item()  
#                     x_start, y_start, z_start = i * 4, j * 4, k * 4
#                     x_end, y_end, z_end = x_start + 4, y_start + 4, z_start + 4
#                     block = reconstructed[x_start:x_end, y_start:y_end, z_start:z_end]
#                     if style_code not in style_dict:
#                         style_dict[style_code] = []  # Initialize list if not present
#                     style_dict[style_code].append(block)



In [2]:
# # Store the style codes and corresponding chunks
# style_dict = {}
# for batch in train_loader:
#     for sample_idx in range(len(batch)):
#         sample= batch[sample_idx].unsqueeze(0).cuda()
#         with torch.no_grad():
#             style_indices, struct_indices = encode_and_quantize(fqgan, sample)
#         reconstructed, binary_reconstructed = decode_from_indices(style_indices, struct_indices, fqgan, two_stage=True)
#         for i in range(style_indices.shape[1]):  
#             for j in range(style_indices.shape[2]):  
#                 for k in range(style_indices.shape[3]): 
#                     style_code = style_indices[0, i, j, k].item()  
#                     x_start, y_start, z_start = i * 4, j * 4, k * 4
#                     x_end, y_end, z_end = x_start + 4, y_start + 4, z_start + 4
#                     block = reconstructed[x_start:x_end, y_start:y_end, z_start:z_end]
#                     if style_code not in style_dict:
#                         style_dict[style_code] = []  # Initialize list if not present
#                     style_dict[style_code].append(block)


In [4]:
# # Count the frequency of different block types
# from collections import defaultdict
# block_type_frequencies = defaultdict(lambda: defaultdict(int))
# block_total_counts = defaultdict(int)
# for style_code, block_list in style_dict.items():
#     for block_matrix in block_list:
#         flattened_blocks = block_matrix.flatten()
#         block_total_counts[style_code] += len(flattened_blocks)

#         # Count the frequency of each block type
#         for block_type in flattened_blocks:
#             block_type_frequencies[style_code][block_type.item()] += 1
# sorted_block_frequency_dict = {}
# for style_code in sorted(block_type_frequencies.keys()):  # Sort by style code
#     block_counts = block_type_frequencies[style_code]
#     total_blocks = block_total_counts[style_code] if block_total_counts[style_code] > 0 else 1  # Avoid division by zero

#     sorted_block_frequency_dict[style_code] = {
#         block_type: block_counts[block_type] / total_blocks
#         for block_type in sorted(block_counts.keys())  # Sort by block type
#     }
# print(sorted_block_frequency_dict)

In [6]:
# # check blocks that are most commonly used
# threshold = 0.001
# filtered_block_frequency_dict = {
#     style_code: {
#         block_type: freq for block_type, freq in block_counts.items() if freq >= threshold
#     }
#     for style_code, block_counts in sorted_block_frequency_dict.items()
# }
# common_block_index = {
#     style_code: [block_type for block_type, freq in block_counts.items() if freq >= threshold]
#     for style_code, block_counts in sorted_block_frequency_dict.items()
# }

# common_block_list = []
# for i in common_block_index.values():
#     for j in i:
#         if j not in common_block_list:
#             common_block_list.append(j)
# common_block_list.sort()
# print(common_block_list)

In [7]:
block_index_to_name = {
    0:  "AIR",
    1:  "BONE_BLOCK",
    2:  "BROWN_MUSHROOM",
    3:  "BROWN_MUSHROOM_BLOCK",
    4:  "CACTUS",
    5:  "CHEST",
    6:  "CLAY",
    7:  "COAL_ORE",
    8:  "COBBLESTONE",
    9:  "DEADBUSH",
   10:  "DIRT",
   11:  "DOUBLE_PLANT",
   12:  "EMERALD_ORE",
   13:  "FLOWING_LAVA",
   14:  "FLOWING_WATER",
   15:  "GOLD_ORE",
   16:  "GRASS",
   17:  "GRAVEL",
   18:  "IRON_ORE",
   19:  "LAPIS_ORE",
   20:  "LAVA",
   21:  "LEAVES",
   22:  "LEAVES2",
   23:  "LOG",
   24:  "LOG2",
   25:  "MOB_SPAWNER",
   26:  "MONSTER_EGG",
   27:  "MOSSY_COBBLESTONE",
   28:  "PUMPKIN",
   29:  "RED_FLOWER",
   30:  "RED_MUSHROOM_BLOCK",
   31:  "REEDS",
   32:  "SAND",
   33:  "SANDSTONE",
   34:  "SNOW_LAYER",
   35:  "STONE",
   36:  "STONE_SLAB",
   37:  "TALLGRASS",
   38:  "VINE",
   39:  "WATER",
   40:  "WATERLILY",
   41:  "YELLOW_FLOWER",
}

# Map indices to block names

In [None]:
block_index_to_name = {}
for index in sorted(block_converter.index_to_block.keys()):
    block_id = block_converter.get_block_id_from_index(index)
    block_name = block_converter.get_block_name_from_index(index)
    print(f"{index:5d} | {block_id:7d} | {block_name}")
    block_index_to_name[index] = block_name

# Codebook metrics

In [25]:
import torch
import itertools

def calculate_average_frequency_maps(binary_structure_dict):
    """
    Calculates the average frequency map for each structure code.

    Args:
        binary_structure_dict (dict): Dictionary mapping structure codes (int)
                                      to lists of 4x4x4 binary tensors (torch.Tensor).

    Returns:
        dict: Dictionary mapping structure codes (int) to their 4x4x4
              average frequency map (float tensor). Returns empty dict if input is empty.
              Codes with no associated chunks are skipped.
    """
    if not binary_structure_dict:
        return {}

    avg_freq_maps = {}
    for code, chunk_list in binary_structure_dict.items():
        if not chunk_list:
            print(f"Warning: Code {code} has no associated chunks. Skipping.")
            continue
        # Stack tensors along a new dimension (dim=0) and calculate the mean
        stacked_chunks = torch.stack(chunk_list).float() # Ensure float for mean calculation
        avg_freq_maps[code] = torch.mean(stacked_chunks, dim=0)
    return avg_freq_maps


def calculate_average_block_frequency_maps(
    style_chunks_dict: dict,
    num_block_types: int
) -> dict:
    """
    Calculates the average block‐type frequency map for each style code,
    starting from raw 3D chunks of block‐ID ints.

    Args:
        style_chunks_dict (dict):
            Mapping style_code (int) → list of torch.Tensor of shape (H, W, D),
            where each entry is an integer in [0 .. num_block_types-1].
        num_block_types (int):
            Total number C of distinct block types / channels.

    Returns:
        dict:
            Mapping style_code (int) → 1D torch.float tensor of length C,
            giving the average frequency of each block type (0…1).
    """
    if not style_chunks_dict:
        return {}

    avg_block_freq_maps = {}
    for code, chunks in style_chunks_dict.items():
        if not chunks:
            print(f"Warning: Code {code} has no associated chunks. Skipping.")
            continue

        # Stack raw index tensors → shape (N, H, W, D)
        stacked_idx = torch.stack(chunks).long()
        N, H, W, D = stacked_idx.shape

        # One‐hot encode → shape (N, H, W, D, C)
        onehot = F.one_hot(stacked_idx, num_classes=num_block_types).float()

        # Sum over batch + spatial dims → (C,)
        counts = onehot.sum(dim=(0, 1, 2, 3))

        # Normalize by total voxels (N * H*W*D)
        avg_block_freq_maps[code] = counts / (N * H * W * D)

    return avg_block_freq_maps

def calculate_sharpness_mad(avg_freq_maps):
    """
    Calculates sharpness using Mean Absolute Deviation from 0.5 for each code's avg freq map.

    Args:
        avg_freq_maps (dict): Dictionary mapping codes to avg freq maps (4x4x4 float tensors).

    Returns:
        tuple: (dict mapping code to sharpness score, float overall average sharpness)
               Returns ({}, 0.0) if input is empty.
    """
    if not avg_freq_maps:
        return {}, 0.0

    sharpness_scores = {}
    total_sharpness = 0.0
    for code, freq_map in avg_freq_maps.items():
        # Calculate |F_i[x,y,z] - 0.5| for all voxels and average
        mad = torch.mean(torch.abs(freq_map - 0.5))
        sharpness_scores[code] = mad.item()
        total_sharpness += sharpness_scores[code]

    average_sharpness = total_sharpness / len(sharpness_scores) if sharpness_scores else 0.0
    return sharpness_scores, average_sharpness

def calculate_sharpness_entropy(avg_freq_maps, epsilon=1e-9):
    """
    Calculates sharpness using binary entropy for each voxel in the avg freq map.

    Args:
        avg_freq_maps (dict): Dictionary mapping codes to avg freq maps (4x4x4 float tensors).
        epsilon (float): Small value to avoid log(0).

    Returns:
        tuple: (dict mapping code to sharpness entropy score, float overall average sharpness entropy)
               Returns ({}, 0.0) if input is empty.
    """
    if not avg_freq_maps:
        return {}, 0.0

    entropy_scores = {}
    total_entropy = 0.0
    for code, freq_map in avg_freq_maps.items():
        # Clamp values to avoid log(0)
        p = torch.clamp(freq_map, epsilon, 1.0 - epsilon)
        # Calculate binary entropy: -p*log2(p) - (1-p)*log2(1-p)
        voxel_entropies = -p * torch.log2(p) - (1.0 - p) * torch.log2(1.0 - p)
        avg_entropy = torch.mean(voxel_entropies)
        entropy_scores[code] = avg_entropy.item()
        total_entropy += entropy_scores[code]

    average_entropy = total_entropy / len(entropy_scores) if entropy_scores else 0.0
    return entropy_scores, average_entropy

def calculate_consistency_variance(binary_structure_dict):
    """
    Calculates consistency using the average voxel variance across chunks for each code.

    Args:
        binary_structure_dict (dict): Dictionary mapping codes to lists of 4x4x4 binary tensors.

    Returns:
        tuple: (dict mapping code to consistency variance score, float overall average consistency variance)
               Returns ({}, 0.0) if input is empty. Codes with < 2 chunks have variance 0.
    """
    if not binary_structure_dict:
        return {}, 0.0

    variance_scores = {}
    total_variance = 0.0
    num_valid_codes = 0
    for code, chunk_list in binary_structure_dict.items():
        if len(chunk_list) < 2: # Variance requires at least 2 samples
             # Assign 0 variance, but maybe handle this differently? (e.g. skip?)
            variance_scores[code] = 0.0
            # print(f"Warning: Code {code} has < 2 chunks. Assigning variance 0.") # Optional warning
            continue # Skip adding to total_variance if we want average over codes with enough data

        stacked_chunks = torch.stack(chunk_list).float()
        # Calculate variance across the chunks (dim=0) for each voxel
        voxel_variances = torch.var(stacked_chunks, dim=0, unbiased=False) # Use population variance
        avg_variance = torch.mean(voxel_variances)
        variance_scores[code] = avg_variance.item()
        total_variance += variance_scores[code]
        num_valid_codes += 1 # Only count codes where variance could be computed

    average_variance = total_variance / num_valid_codes if num_valid_codes > 0 else 0.0
    return variance_scores, average_variance


def calculate_consistency_entropy(binary_structure_dict, epsilon=1e-9):
    """
    Calculates consistency using the average entropy of voxel distributions across chunks for each code.
    This relies on calculating the probability p(voxel=1) at each position first.

    Args:
        binary_structure_dict (dict): Dictionary mapping codes to lists of 4x4x4 binary tensors.
        epsilon (float): Small value to avoid log(0).

    Returns:
        tuple: (dict mapping code to consistency entropy score, float overall average consistency entropy)
               Returns ({}, 0.0) if input is empty. Codes with no chunks are skipped.
    """
    # This metric is essentially the same as sharpness entropy applied to the avg freq maps
    # because the entropy H(p) depends only on the mean frequency p at that voxel.
    # H(p) = -p*log2(p) - (1-p)*log2(1-p), where p = mean(voxel_values_for_code)
    # So we can reuse calculate_sharpness_entropy with the average frequency maps.
    avg_freq_maps = calculate_average_frequency_maps(binary_structure_dict)
    return calculate_sharpness_entropy(avg_freq_maps, epsilon)


def calculate_uniqueness_metrics(avg_freq_maps, distance_metric='mae'):
    """
    Calculates uniqueness metrics (average and minimum pairwise distance) between avg freq maps.

    Args:
        avg_freq_maps (dict): Dictionary mapping codes to avg freq maps (4x4x4 float tensors).
        distance_metric (str): 'mae' (Mean Absolute Error) or 'mse' (Mean Squared Error).

    Returns:
        tuple: (float average pairwise distance, float minimum pairwise distance)
               Returns (0.0, 0.0) if less than 2 codes exist.
    """
    codes = list(avg_freq_maps.keys())
    if len(codes) < 2:
        return 0.0, 0.0 # Cannot compare pairs if less than 2 codes

    total_distance = 0.0
    min_distance = float('inf')
    num_pairs = 0

    for code1, code2 in itertools.combinations(codes, 2):
        map1 = avg_freq_maps[code1]
        map2 = avg_freq_maps[code2]

        if distance_metric == 'mae':
            distance = torch.mean(torch.abs(map1 - map2)).item()
        elif distance_metric == 'mse':
            distance = torch.mean((map1 - map2)**2).item()
        else:
            raise ValueError("Unsupported distance_metric. Choose 'mae' or 'mse'.")

        total_distance += distance
        min_distance = min(min_distance, distance)
        num_pairs += 1

    average_distance = total_distance / num_pairs if num_pairs > 0 else 0.0
    # Handle case where min_distance wasn't updated (e.g., only one code pair and it was identical)
    if min_distance == float('inf'):
        min_distance = 0.0


    return average_distance, min_distance



## Code level metrics

In [26]:
def find_most_similar_codes(avg_freq_maps, distance_metric='mae', tolerance=1e-5):
    """
    Finds the pair(s) of codes with the minimum pairwise distance between their avg freq maps.

    Args:
        avg_freq_maps (dict): Dictionary mapping codes to avg freq maps (4x4x4 float tensors).
        distance_metric (str): 'mae' (Mean Absolute Error) or 'mse' (Mean Squared Error).
        tolerance (float): Tolerance for considering distances equal to the minimum.

    Returns:
        tuple: (float minimum_distance, list of tuples containing the most similar code pairs [(code1, code2), ...])
               Returns (inf, []) if less than 2 codes exist.
    """
    codes = list(avg_freq_maps.keys())
    if len(codes) < 2:
        return float('inf'), []

    min_distance = float('inf')
    similar_pairs = []

    # First pass to find the minimum distance
    for code1, code2 in itertools.combinations(codes, 2):
        map1 = avg_freq_maps[code1]
        map2 = avg_freq_maps[code2]

        if distance_metric == 'mae':
            distance = torch.mean(torch.abs(map1 - map2)).item()
        elif distance_metric == 'mse':
            distance = torch.mean((map1 - map2)**2).item()
        else:
            raise ValueError("Unsupported distance_metric. Choose 'mae' or 'mse'.")
        min_distance = min(min_distance, distance)

    # Handle case where min_distance wasn't updated (e.g., only one code pair and it was identical)
    if min_distance == float('inf'):
        # This case should ideally not happen if len(codes) >= 2,
        # unless maybe all maps are identical? Check one pair.
         if len(codes) >= 2:
             code1, code2 = codes[0], codes[1]
             map1 = avg_freq_maps[code1]
             map2 = avg_freq_maps[code2]
             if distance_metric == 'mae': distance = torch.mean(torch.abs(map1 - map2)).item()
             else: distance = torch.mean((map1 - map2)**2).item()
             min_distance = distance
         else: # Should not be reachable due to initial check
             return float('inf'), []


    # Second pass to collect all pairs at (or very close to) the minimum distance
    for code1, code2 in itertools.combinations(codes, 2):
        map1 = avg_freq_maps[code1]
        map2 = avg_freq_maps[code2]

        if distance_metric == 'mae':
            distance = torch.mean(torch.abs(map1 - map2)).item()
        elif distance_metric == 'mse':
            distance = torch.mean((map1 - map2)**2).item()
        else: # Should not happen if first pass succeeded
             raise ValueError("Unsupported distance_metric.")


        if abs(distance - min_distance) < tolerance:
            similar_pairs.append(tuple(sorted((code1, code2)))) # Store sorted pairs

    # Deduplicate pairs (if any floating point issues caused near duplicates)
    similar_pairs = sorted(list(set(similar_pairs)))


    return min_distance, similar_pairs

def analyze_per_code_metrics(
    code_sharpness_mad,
    code_sharpness_entropy,
    code_consistency_var,
    avg_freq_maps, # Needed to get the list of codes that actually appeared
    num_chunks_per_code, # Add a dict mapping code -> number of chunks
    sharp_mad_threshold=0.3, # Example threshold: Lower than this might be "blurry"
    sharp_entropy_threshold=0.5, # Example threshold: Higher than this might be "blurry"
    cons_var_threshold=0.15, # Example threshold: Higher than this might be "inconsistent"
    min_chunks_threshold=10 # Example threshold: Codes used less than this might be unreliable
    ):
    """
    Prints sharpness and consistency metrics for each code and flags potential issues.

    Args:
        code_sharpness_mad (dict): Code -> Sharpness (MAD) score.
        code_sharpness_entropy (dict): Code -> Sharpness (Entropy) score.
                                       (Also used for Consistency Entropy).
        code_consistency_var (dict): Code -> Consistency (Variance) score.
        avg_freq_maps (dict): Code -> Average Frequency Map. Used to get active codes.
        num_chunks_per_code(dict): Code -> integer count of chunks assigned.
        sharp_mad_threshold (float): Threshold below which MAD is flagged low.
        sharp_entropy_threshold (float): Threshold above which Entropy is flagged high.
        cons_var_threshold (float): Threshold above which Variance is flagged high.
        min_chunks_threshold (int): Threshold below which code usage is flagged low.
    """
    print("\n--- Per-Code Analysis ---")
    codes = sorted(list(avg_freq_maps.keys()))

    if not codes:
        print("No codes found in avg_freq_maps.")
        return

    print(f"{'Code':<6} {'NumChunks':<10} {'Sharp(MAD)':<12} {'Sharp(Entr)':<12} {'Cons(Var)':<12} {'Flags':<20}")
    print("-" * 70)

    for code in codes:
        sharp_mad = code_sharpness_mad.get(code, float('nan'))
        sharp_entropy = code_sharpness_entropy.get(code, float('nan'))
        cons_var = code_consistency_var.get(code, float('nan')) # Variance might be missing if < 2 chunks
        n_chunks = num_chunks_per_code.get(code, 0)


        flags = []
        if n_chunks < min_chunks_threshold:
             flags.append(f"LOW_USAGE({n_chunks})")
        if sharp_mad < sharp_mad_threshold:
            flags.append("LOW_SHARPNESS_MAD")
        if sharp_entropy > sharp_entropy_threshold:
            flags.append("HIGH_SHARPNESS_ENT")
        if cons_var > cons_var_threshold:
            flags.append("HIGH_CONSIST_VAR")
        # Note: Consistency Entropy is same as Sharpness Entropy here

        # Handle cases where variance wasn't computed (e.g., < 2 chunks)
        cons_var_str = f"{cons_var:<12.4f}" if not torch.isnan(torch.tensor(cons_var)) else f"{'N/A':<12}" # Use torch.isnan for tensor check compatibility

        print(f"{code:<6} {n_chunks:<10} {sharp_mad:<12.4f} {sharp_entropy:<12.4f} {cons_var_str} {', '.join(flags)}")

    print("-" * 70)

## Code visualizations

## 4x4 frequency heatmap

In [27]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import matplotlib.pyplot as plt
import numpy as np
import os
import torch # Assuming input is torch tensor

def plot_frequency_heatmap(prob_matrix_tensor, struct_code, save_dir):
    """
    Generates and saves a 3D frequency heatmap visualization for a structure code,
    applying Minecraft-specific orientation adjustments. Uses the 'viridis' colormap
    where color represents probability and ensures Y (Height) axis is vertical and correctly oriented.

    Args:
        prob_matrix_tensor (torch.Tensor): The 4x4x4 frequency map (PyTorch tensor).
        struct_code (int): The structure code ID.
        save_dir (str): The directory to save the heatmap image.
    """
    # --- Input Validation and Conversion ---
    if not isinstance(prob_matrix_tensor, torch.Tensor) or prob_matrix_tensor.shape != (4, 4, 4):
        print(f"Error: Input for code {struct_code} is not a 4x4x4 torch tensor.")
        return

    prob_matrix_np = prob_matrix_tensor.cpu().numpy()

    # --- Orientation Adjustment ---
    # Apply the same sequence: Transpose Z <-> X, then rotate around Z.
    prob_matrix_transposed = prob_matrix_np.transpose(2, 0, 1) # (X,Y,Z) -> (Z,X,Y)
    prob_matrix_oriented = np.rot90(prob_matrix_transposed, k=1, axes=(1, 2)) # (Z,X,Y) -> (Z,Y_mc,X_mc)

    # --- Plotting Setup ---
    os.makedirs(save_dir, exist_ok=True)
    fig = plt.figure(figsize=(9, 8)) # Adjusted size slightly for colorbar
    ax = fig.add_subplot(111, projection='3d')

    # Generate grid indices corresponding to the oriented matrix axes (Z_mc, Y_mc, X_mc)
    z_indices, y_indices, x_indices = np.indices(prob_matrix_oriented.shape)

    # Flatten coordinates and probability values
    x_coords_mc = x_indices.flatten() # Minecraft X coordinates
    y_coords_mc = y_indices.flatten() # Minecraft Y (Height) coordinates
    z_coords_mc = z_indices.flatten() # Minecraft Z coordinates
    probabilities = prob_matrix_oriented.flatten()

    # --- Create Scatter Plot (Viridis Colormap) ---
    # Plot mapping:
    # Plot X-axis <- Minecraft X data (x_coords_mc)
    # Plot Y-axis <- Minecraft Z data (z_coords_mc)
    # Plot Z-axis <- Minecraft Y (Height) data (y_coords_mc) <<< VERTICAL AXIS
    scatter = ax.scatter(x_coords_mc, z_coords_mc, y_coords_mc,
                         c=probabilities, # Color based on probability values
                         cmap='viridis',  # Use the viridis colormap
                         vmax=1,
                         vmin=0,
                         s=300,
                         edgecolors='k', linewidth=0.5)

    # --- Add Colorbar ---
    cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, aspect=10)
    cbar.set_label('Probability')

    # --- Set Labels, Ticks, Limits, Title, and INVERT Z-AXIS ---
    ax.set_title(f"Structure Code {struct_code} Frequency Map (Color = Probability)") # Updated title
    ax.set_xlabel("X")
    ax.set_ylabel("Z")
    ax.set_zlabel("Y (Height)")

    ax.set_xticks(np.arange(4))
    ax.set_yticks(np.arange(4)) # Corresponds to Z data range
    ax.set_zticks(np.arange(4)) # Corresponds to Y (Height) data range

    ax.set_xlim(-0.5, 3.5)
    ax.set_ylim(-0.5, 3.5)
    ax.set_zlim(-0.5, 3.5)
    # --- Invert the Z-axis (which represents Y-Height) ---
    ax.invert_zaxis()
    # --- Optionally invert X too depending on preferred view ---
    # ax.invert_xaxis()

    # Adjust view angle if desired
    ax.view_init(elev=25., azim=-125)

    # --- Save and Close ---
    image_path = os.path.join(save_dir, f"heatmap_code_{struct_code}.png")
    try:
        plt.savefig(image_path, bbox_inches='tight', dpi=150)
    except Exception as e:
        print(f"Error saving heatmap for code {struct_code}: {e}")
    finally:
        plt.close(fig)

## block count frequency heatmap (style codes)

In [28]:
def plot_block_frequency_heatmap(
    style_code: int,
    block_freqs,                    # torch.Tensor or array of shape (C,)
    save_dir: str,
    block_converter,                # object with index_to_block(idx) -> blockID
    block_index_to_name: dict,      # maps Minecraft blockID → block name
    cmap: str = "Greys"
):
    """
    Plot and save a 1C heatmap of blocktype frequencies for one style code.

    Args:
        style_code (int): Identifier of the style code.
        block_freqs (Tensor or ndarray): 1D of length C, values in [0,1].
        save_dir (str): Directory where the PNG will be written.
        block_converter: Object with method index_to_block(idx) returning
                         the Minecraft blockID for channel idx.
        block_index_to_name (dict): Mapping from blockID to block name.
        cmap (str): Matplotlib colormap name.
    """
    # ensure numpy array on CPU
    if isinstance(block_freqs, torch.Tensor):
        # print('converting to numpy')
        freqs = block_freqs.detach().cpu().numpy()
    else:
        print('already numpy')
        freqs = np.array(block_freqs)

    # print(freqs)
    # print(freqs.shape)
    C = freqs.shape[0]
    heatmap = freqs.reshape(C, 1)   # shape (1, C)

    # Map each channel idx → Minecraft blockID → block name
    block_ids = [i for i in range(C)]
    print(block_ids)
    names = [block_index_to_name.get(bid, str(bid)) for bid in block_ids]

    # ensure output dir exists
    os.makedirs(save_dir, exist_ok=True)
    image_path = os.path.join(save_dir, f"block_heatmap_code_{style_code}.png")

    # plot
    plt.figure(figsize=(3, max(8, C * 0.2)))
    im = plt.imshow(
        heatmap,
        aspect='auto',
        cmap=cmap,
        vmin=0.0, vmax=1.0         # fixed scale across all codes
    )
     # ---- Bigger, labeled colorbar ----
    cbar = plt.colorbar(
        im,
        orientation='vertical',
        fraction=0.12,    # thicker bar
        pad=0.05          # more space from the heatmap
    )
    cbar.ax.tick_params(labelsize=8)               # larger tick labels
    cbar.set_label('Frequency', rotation=270, labelpad=15)

    plt.yticks(np.arange(C), names, fontsize=6)
    # only one column, no X-ticks needed
    plt.xticks([])

    plt.title(f"Style Code {style_code} Block Frequencies", pad=10)
    plt.tight_layout()
    plt.savefig(image_path, dpi=300)
    plt.close()

    return image_path

In [29]:
def calculate_positional_frequencies(fqgan):
    latent_depth, latent_height, latent_width = 6, 6, 6 # *** Adjust if different ***
    latent_shape = (latent_depth, latent_height, latent_width)
    num_structure_codes = fqgan.struct_codebook_size
    num_style_codes = fqgan.style_codebook_size
    # Initialize count tensors for each code
    struct_position_counts = {
        code: torch.zeros(latent_shape, dtype=torch.long, device='cpu')
        for code in range(num_structure_codes)
    }
    style_position_counts = {
        code: torch.zeros(latent_shape, dtype=torch.long, device='cpu')
        for code in range(num_style_codes)
    }
    total_samples_processed = 0

    print("Collecting positional frequencies from train_loader...")
    for batch_idx, batch in enumerate(train_loader):
        # Limit batches for testing?
        # if batch_idx > 20: break

        if torch.cuda.is_available():
            batch = batch.cuda()

        # Only need the encoding part
        style_indices, struct_indices = encode_and_quantize(fqgan, batch) # Get indices for the whole batch

        # Process each sample in the batch
        for sample_idx in range(struct_indices.shape[0]): # Iterate through batch dimension
            struct_indices_sample = struct_indices[sample_idx].cpu() # Get indices for one sample, move to CPU

            # Iterate through the latent grid dimensions (D, H, W)
            for i in range(latent_depth):
                for j in range(latent_height):
                    for k in range(latent_width):
                        struct_code = struct_indices_sample[i, j, k].item()
                        if 0 <= struct_code < num_structure_codes:
                            position_counts[struct_code][i, j, k] += 1
                        else:
                            print(f"Warning: Encountered out-of-bounds code {struct_code} at ({i},{j},{k})")
            total_samples_processed += 1

        # Optional: Print progress
        if (batch_idx + 1) % 50 == 0:
                print(f"Processed {batch_idx + 1} batches...")
    return position_counts

def plot_positional_frequency(position_counts, struct_code, save_dir):
    """
    Generates and saves a 3D scatter plot showing the positional frequency
    of a structure code within the latent grid, using a Viridis colormap and
    Minecraft coordinate conventions (Y=Height is vertical, Y=0 is bottom).

    Args:
        position_counts (dict): Maps struct_code to 6x6x6 count tensor.
        struct_code (int): The structure code to visualize.
        save_dir (str): The directory to save the plot image.
    """
    if struct_code not in position_counts:
        print(f"Error: Code {struct_code} not found in position_counts dictionary.")
        return

    counts_tensor = position_counts[struct_code].cpu() # Ensure it's on CPU
    latent_shape = counts_tensor.shape
    if len(latent_shape) != 3:
        print(f"Error: Count tensor for code {struct_code} is not 3D (shape: {latent_shape}).")
        return

    counts_numpy = counts_tensor.numpy()
    max_count = np.max(counts_numpy)

    if max_count == 0:
        print(f"Info: Code {struct_code} never appeared. Skipping visualization.")
        # Optional: Create an empty plot placeholder if desired
        return # Stop here if code never appeared

    # --- Plotting Setup ---
    # os.makedirs(save_dir, exist_ok=True)
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Generate grid indices (I, J, K) corresponding to tensor dimensions
    # Assume I=Depth(Z_mc), J=Height(Y_mc), K=Width(X_mc)
    i_indices, j_indices, k_indices = np.indices(latent_shape)

    # Flatten coordinates and the raw counts (for color mapping)
    i_coords_mc = i_indices.flatten() # Minecraft Z coordinates
    j_coords_mc = j_indices.flatten() # Minecraft Y (Height) coordinates
    k_coords_mc = k_indices.flatten() # Minecraft X coordinates
    frequencies = counts_numpy.flatten()

    # --- Create Scatter Plot (Viridis Colormap, Correct Axis Mapping) ---
    # cmap = plt.get_cmap('viridis')
    cmap = plt.get_cmap('Greys')
    # Plot mapping:
    # Plot X-axis <- Minecraft X data (k_coords_mc)
    # Plot Y-axis <- Minecraft Z data (i_coords_mc)
    # Plot Z-axis <- Minecraft Y (Height) data (j_coords_mc) <<< VERTICAL AXIS
    scatter = ax.scatter(k_coords_mc, i_coords_mc, j_coords_mc, # Correct mapping
                         c=frequencies, cmap=cmap, # Color based on frequency counts
                         s=150, # Adjust size as needed
                         alpha=0.8, # Add some transparency
                         # vmin=0, vmax=max_count, # Optional: Explicitly set color limits
                         edgecolors='grey', linewidth=0.5)

    # --- Add Colorbar ---
    cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, aspect=20, pad=0.1)
    cbar.set_label("Code Occurrence Count")

    # --- Set Labels, Ticks, Limits, Title, and Invert Z-axis ---
    ax.set_title(f"Code {struct_code} Positional Frequency")
    # Label plot axes according to the *Minecraft dimension* plotted on them
    ax.set_xlabel("Z (Latent Dim K)")
    ax.set_ylabel("X (Latent Dim I)")
    ax.set_zlabel("Y (Height, Latent Dim J)") # Vertical axis

    # Set ticks based on the dimension size
    ax.set_xticks(np.arange(latent_shape[2])) # K dimension
    ax.set_yticks(np.arange(latent_shape[0])) # I dimension
    ax.set_zticks(np.arange(latent_shape[1])) # J dimension

    # Set limits for plot axes
    ax.set_xlim(-0.5, latent_shape[2] - 0.5)
    ax.set_ylim(-0.5, latent_shape[0] - 0.5)
    ax.set_zlim(-0.5, latent_shape[1] - 0.5)

    # --- Invert the Z-axis (which represents Y-Height) ---
    # ax.invert_zaxis() # Ensures Y=0 is at the bottom
    ax.invert_yaxis() # Ensures Y=0 is at the bottom

    # Adjust view angle
    ax.view_init(elev=20., azim=-75)
    # plt.show()
    # --- Save and Close ---
    image_path = os.path.join(save_dir, f"pos_freq_code_{struct_code}.png")
    try:
        plt.savefig(image_path, bbox_inches='tight', dpi=150)
    except Exception as e:
        print(f"Error saving positional frequency plot for code {struct_code}: {e}")
    finally:
        plt.close(fig)


def generate_all_positional_frequency_plots(position_counts, output_dir):
    print(f"Generating positional frequency plots for {len(position_counts)} codes...")
    os.makedirs(output_dir, exist_ok=True) # Ensure directory exists

    # Iterate through the codes present in the dictionary
    codes_to_plot = sorted(position_counts.keys())

    for i, code in enumerate(codes_to_plot):
        # Call the plotting function for the current code
        plot_positional_frequency(position_counts, code, output_dir)

        # Optional: Print progress
        if (i + 1) % 10 == 0 or (i + 1) == len(codes_to_plot):
            print(f"Generated plot {i+1}/{len(codes_to_plot)} (Code {code})")

    print(f"Finished generating positional frequency plots. Saved to: {output_dir}")

In [30]:
# # --- Example Usage ---
# # Assuming binary_structure_dict is populated as shown in your notebook snippet

# # 1. Calculate average frequency maps (needed for several metrics)
# avg_freq_maps = calculate_average_frequency_maps(binary_structure_dict)

# # 2. Calculate Sharpness (MAD)
# code_sharpness_mad, model_avg_sharpness_mad = calculate_sharpness_mad(avg_freq_maps)
# print(f"Model Average Sharpness (MAD): {model_avg_sharpness_mad}")

# # 3. Calculate Sharpness (Entropy)
# code_sharpness_entropy, model_avg_sharpness_entropy = calculate_sharpness_entropy(avg_freq_maps)
# print(f"Model Average Sharpness (Entropy): {model_avg_sharpness_entropy}") # Lower is better

# # 4. Calculate Consistency (Variance)
# code_consistency_var, model_avg_consistency_var = calculate_consistency_variance(binary_structure_dict)
# print(f"Model Average Consistency (Variance): {model_avg_consistency_var}") # Lower is better

# # 5. Calculate Consistency (Entropy) - Note: This reuses sharpness entropy calculation
# code_consistency_entropy, model_avg_consistency_entropy = calculate_consistency_entropy(binary_structure_dict)
# print(f"Model Average Consistency (Entropy): {model_avg_consistency_entropy}") # Lower is better

# # 6. Calculate Uniqueness
# model_avg_pairwise_dist, model_min_pairwise_dist = calculate_uniqueness_metrics(avg_freq_maps, distance_metric='mae')
# print(f"Model Average Pairwise Distance (MAE): {model_avg_pairwise_dist}") # Higher is better
# print(f"Model Minimum Pairwise Distance (MAE): {model_min_pairwise_dist}") # Higher is better

In [31]:
# # avg_freq_maps = calculate_average_frequency_maps(binary_structure_dict)
# # code_sharpness_mad, model_avg_sharpness_mad = calculate_sharpness_mad(avg_freq_maps)
# # code_sharpness_entropy, model_avg_sharpness_entropy = calculate_sharpness_entropy(avg_freq_maps)
# # code_consistency_var, model_avg_consistency_var = calculate_consistency_variance(binary_structure_dict)
# # model_avg_pairwise_dist, model_min_pairwise_dist = calculate_uniqueness_metrics(avg_freq_maps) # Needed for context below

# # You also need the number of chunks per code:
# num_chunks_per_code = {code: len(chunks) for code, chunks in binary_structure_dict.items()}

# # 1. Print the per-code metrics and flags
# analyze_per_code_metrics(
#     code_sharpness_mad,
#     code_sharpness_entropy,
#     code_consistency_var,
#     avg_freq_maps,
#     num_chunks_per_code
# )

# # 2. Find and print the most similar code pairs (redundant codes)
# min_dist, similar_pairs = find_most_similar_codes(avg_freq_maps, distance_metric='mae')
# print(f"\n--- Code Uniqueness Analysis ---")
# print(f"Minimum Pairwise Distance (MAE): {min_dist:.6f}")
# if similar_pairs:
#     print("Most Similar Code Pairs (Potential Redundancy):")
#     for pair in similar_pairs:
#         print(f"  - Codes {pair[0]} and {pair[1]}")
# else:
#     print("No highly similar code pairs found.")


# Run evaluation sweep

## Sweep params

In [32]:
# --- Configuration ---
CHECKPOINT_STEPS = [5000, 10000, 15000, 20000, 24999] # Example list of checkpoint steps to evaluate
MODEL_BASE_PATH = "../model_logs/FQGAN_2stagedecoder_logweighted3_32codes_cycleconsistency_STE_postquant_detachcycle_gumbel" # *** CHANGE THIS ***
# MODEL_HPARAMS_FILE = "/path/to/your/hparams.json" # *** CHANGE THIS ***
OUTPUT_DIR = "results/interpretability_analysis17"
# Thresholds for analyze_per_code_metrics (can be adjusted)
SHARP_MAD_THRESHOLD = 0.3
SHARP_ENTROPY_THRESHOLD = 0.5
CONS_VAR_THRESHOLD = 0.15
MIN_CHUNKS_THRESHOLD = 10
DISTANCE_METRIC = 'mae' # For uniqueness metrics
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [33]:
all_metrics_data = {} # Store metrics for all checkpoints: {step: metrics_dict}

### Load Model

In [34]:
def load_model_for_step(step, base_path):
    """Loads the FQGAN model for a specific checkpoint step."""
    print(f"\n--- Loading Checkpoint Step: {step} ---")
    try:
        fqgan_hparams = dict_to_vcqgan_hparams(load_hparams_from_json(f"{base_path}"), 'minecraft')
        fqgan_hparams.load_step = step # Set the step to load
        fqgan = FQModel(fqgan_hparams)
        fqgan = load_fqgan_from_checkpoint(fqgan_hparams, fqgan) 
        print(f'Loaded model for step {step}')
        fqgan.eval() # Set to evaluation mode
        if torch.cuda.is_available():
             fqgan.cuda()
        return fqgan
    except Exception as e:
        print(f"Error loading model for step {step}: {e}")
        # Depending on the error, you might want to skip this step or stop
        return None # Indicate failure

### Run dataset through model, save dictionary of codes and corresponding chunks
TODO: add handling for the binary reconstruction & style codes

In [35]:
@torch.no_grad() # Ensure no gradients are computed
def collect_data_for_model(model, loader, block_converter):
    """
    Collects structure dictionary data (original and binary) for a given model.
    Returns original structure dict, binary structure dict, and chunk counts per code.
    """
    print("Collecting data from train_loader...")
    latent_depth, latent_height, latent_width = 6, 6, 6 # *** Adjust if different ***
    latent_shape = (latent_depth, latent_height, latent_width)
    num_structure_codes = fqgan.struct_codebook_size
    num_style_codes = fqgan.style_codebook_size

     # Initialize count tensors for each code
    struct_position_counts = {
        code: torch.zeros(latent_shape, dtype=torch.long, device='cpu')
        for code in range(num_structure_codes)
    }
    style_position_counts = {
        code: torch.zeros(latent_shape, dtype=torch.long, device='cpu')
        for code in range(num_style_codes)
    }
    total_samples_processed = 0

    air_idx = block_converter.get_air_block_index()
    structure_dict_original = {} # Store original chunks here
    style_dict_original = {}
    # style_dict = {} # Keep this if you plan to analyze style later
    i = 0
    # --- (Loop through loader, encode, decode as before) ---
    for batch_idx, batch in enumerate(loader):
        # Limit number of batches?
        # if batch_idx > 20: break

        if torch.cuda.is_available():
            batch = batch.cuda()

        for sample_idx in range(len(batch)):
            sample = batch[sample_idx].unsqueeze(0)
            try:
                style_indices, struct_indices = encode_and_quantize(model, sample)
                reconstructed, _ = decode_from_indices(style_indices, struct_indices, model, two_stage=True)

                struct_indices_cpu = struct_indices.squeeze(0).cpu()
                style_indices_cpu = style_indices.squeeze(0).cpu() # If needed for style
                reconstructed_cpu = reconstructed.cpu()

                for i in range(struct_indices_cpu.shape[0]): # D
                    for j in range(struct_indices_cpu.shape[1]): # H
                        for k in range(struct_indices_cpu.shape[2]): # W
                            # style_code = style_indices_cpu[i, j, k].item() # If needed
                            struct_code = struct_indices_cpu[i, j, k].item()
                            style_code  = style_indices_cpu[i, j, k].item()
                            x_start, y_start, z_start = i * 4, j * 4, k * 4
                            x_end, y_end, z_end = x_start + 4, y_start + 4, z_start + 4
                            block_chunk = reconstructed_cpu[x_start:x_end, y_start:y_end, z_start:z_end]

                            # Store the original chunk tensor
                            if struct_code not in structure_dict_original:
                                structure_dict_original[struct_code] = []
                            if style_code not in style_dict_original:
                                style_dict_original[style_code] = []

                            structure_dict_original[struct_code].append(block_chunk)
                            style_dict_original[style_code].append(block_chunk)
                            struct_position_counts[struct_code][i, j, k] += 1
                            style_position_counts[style_code][i, j, k] += 1
                            # --- (Style dictionary population if needed) ---

            except Exception as e:
                 print(f"Error processing sample {sample_idx} in batch {batch_idx}: {e}")
                 continue

    print(f"Finished data collection. Found {len(structure_dict_original)} unique structure codes and {len(style_dict_original)} unique style codes.")

    # Convert structure chunks to binary
    print("Converting structure chunks to binary...")
    binary_structure_dict = {}
    num_chunks_per_code = {}
    for struct_code, block_list in structure_dict_original.items(): # Iterate original dict
         num_chunks_per_code[struct_code] = len(block_list) # Get count first
         if block_list:
             try:
                 # Convert to binary for metrics
                 binary_chunks = [(b.cpu() != air_idx).to(dtype=torch.int) for b in block_list]
                 binary_structure_dict[struct_code] = binary_chunks
             except Exception as e:
                 print(f"Error converting chunks to binary for code {struct_code}: {e}")
                 # Ensure binary dict entry exists even if conversion failed? Or skip?
                 # binary_structure_dict[struct_code] = [] # Or handle differently
         # else: # No need for else, count is already 0 from len() above

    # Clean up GPU memory if possible
    # del style_dict # Keep structure_dict_original
    if torch.cuda.is_available():
       torch.cuda.empty_cache()

    print(f"Finished binary conversion. Processed {len(binary_structure_dict)} codes.")
    # Return the original dictionary as well
    return structure_dict_original, style_dict_original, binary_structure_dict, num_chunks_per_code, struct_position_counts, style_position_counts



In [36]:
def save_checkpoint_report(step, metrics, num_chunks_per_code, similar_pairs, min_dist, output_dir):
    """Saves a text report summarizing metrics for a single checkpoint."""
    report_path = os.path.join(output_dir, f"checkpoint_{step}_report.txt")
    print(f"Saving report for step {step} to {report_path}")

    with open(report_path, 'w') as f:
        f.write(f"--- Analysis Report for Checkpoint Step: {step} ---\n\n")

        f.write("--- Model-Level Metrics ---\n")
        f.write(f"Avg Sharpness (MAD):      {metrics['avg_sharpness_mad']:.4f} (Higher is better)\n")
        f.write(f"Avg Sharpness (Entropy):  {metrics['avg_sharpness_entropy']:.4f} (Lower is better)\n")
        f.write(f"Avg Consistency (Var):    {metrics['avg_consistency_var']:.4f} (Lower is better)\n")
        # Consistency Entropy is same as Sharpness Entropy in our current impl.
        # f.write(f"Avg Consistency (Entropy):{metrics['avg_consistency_entropy']:.4f} (Lower is better)\n")
        f.write(f"Avg Pairwise Dist ({DISTANCE_METRIC.upper()}): {metrics['avg_pairwise_dist']:.4f} (Higher is better)\n")
        f.write(f"Min Pairwise Dist ({DISTANCE_METRIC.upper()}): {min_dist:.6f} (Higher is better)\n\n")


        f.write("--- Code Uniqueness Analysis ---\n")
        f.write(f"Minimum Pairwise Distance ({DISTANCE_METRIC.upper()}): {min_dist:.6f}\n")
        if similar_pairs:
            f.write("Most Similar Code Pairs (Potential Redundancy):\n")
            for pair in similar_pairs:
                f.write(f"  - Codes {pair[0]} and {pair[1]}\n")
        else:
            f.write("No highly similar code pairs found.\n")
        f.write("\n")


        f.write("--- Per-Code Analysis ---\n")
        codes = sorted(list(metrics['code_sharpness_mad'].keys())) # Use keys from one of the dicts
        if not codes:
            f.write("No codes found with sufficient data for per-code analysis.\n")
        else:
            header = f"{'Code':<6} {'NumChunks':<10} {'Sharp(MAD)':<12} {'Sharp(Entr)':<12} {'Cons(Var)':<12} {'Flags':<20}\n"
            f.write(header)
            f.write("-" * (len(header) - 1) + "\n") # Adjust separator length

            for code in codes:
                sharp_mad = metrics['code_sharpness_mad'].get(code, float('nan'))
                sharp_entropy = metrics['code_sharpness_entropy'].get(code, float('nan'))
                cons_var = metrics['code_consistency_var'].get(code, float('nan'))
                n_chunks = num_chunks_per_code.get(code, 0)

                flags = []
                if n_chunks < MIN_CHUNKS_THRESHOLD:
                     flags.append(f"LOW_USAGE({n_chunks})")
                # Use np.isnan for checking because values are Python floats now
                if not np.isnan(sharp_mad) and sharp_mad < SHARP_MAD_THRESHOLD:
                    flags.append("LOW_SHARPNESS_MAD")
                if not np.isnan(sharp_entropy) and sharp_entropy > SHARP_ENTROPY_THRESHOLD:
                    flags.append("HIGH_SHARPNESS_ENT")
                if not np.isnan(cons_var) and cons_var > CONS_VAR_THRESHOLD:
                    flags.append("HIGH_CONSIST_VAR")

                cons_var_str = f"{cons_var:<12.4f}" if not np.isnan(cons_var) else f"{'N/A':<12}"

                f.write(f"{code:<6} {n_chunks:<10} {sharp_mad:<12.4f} {sharp_entropy:<12.4f} {cons_var_str} {', '.join(flags)}\n")
            f.write("-" * (len(header) -1 ) + "\n")

## Main Evaluation Loop

In [37]:
model_name_file = os.path.join(OUTPUT_DIR, "model_name.txt")
with open(model_name_file, "w", encoding="utf-8") as f:
    f.write(MODEL_BASE_PATH)
for step in CHECKPOINT_STEPS:
    # 1. Load Model
    fqgan = load_model_for_step(step, MODEL_BASE_PATH)
    fqgan.to('cuda')
    if fqgan is None:
        print(f"Skipping step {step} due to loading error.")
        continue

    # 2. Collect Data
    # Ensure train_loader is correctly defined/passed
    structure_dict_original, style_dict_original, binary_structure_dict, num_chunks_per_code, struct_position_counts, style_position_counts = collect_data_for_model(fqgan, train_loader, block_converter)

    if not binary_structure_dict: # Check binary dict as it's needed for metrics
         print(f"Skipping metrics/plotting for step {step} as no data was collected/converted.")
         del fqgan, structure_dict_original # Clean up original dict too
         if torch.cuda.is_available(): torch.cuda.empty_cache()
         continue

    # 3. Calculate Metrics
    print(f"Calculating metrics for step {step}...")
    metrics = {}
    avg_freq_maps = calculate_average_frequency_maps(binary_structure_dict)
    block_freq_maps = calculate_average_block_frequency_maps(style_dict_original, fqgan.in_channels)
    # pos_freq_maps = calculate_positional_frequencies(fqgan)

    if not avg_freq_maps:
         print(f"Skipping metrics for step {step} as no average frequency maps could be calculated (no valid codes?).")
         del fqgan, binary_structure_dict, num_chunks_per_code
         if torch.cuda.is_available(): torch.cuda.empty_cache()
         continue
    
    
    # --- Visualization Section ---
    print(f"Generating visualizations for step {step}...")
    struct_vis_dir = os.path.join(OUTPUT_DIR, f"checkpoint_{step}_visualizations/Struct_visualizations") # Combined dir
    style_vis_dir = os.path.join(OUTPUT_DIR, f"checkpoint_{step}_visualizations/Style_visualizations")
    os.makedirs(struct_vis_dir, exist_ok=True)
    os.makedirs(style_vis_dir, exist_ok=True)
    visualize_all_structure_codes(structure_dict_original, struct_vis_dir)
    visualize_all_structure_codes(style_dict_original, style_vis_dir, code_type="Style")
    generate_all_positional_frequency_plots(struct_position_counts, struct_vis_dir)
    generate_all_positional_frequency_plots(style_position_counts, style_vis_dir)
    for code, freq_map_tensor in avg_freq_maps.items():
        # --- Heatmap Generation ---
        plot_frequency_heatmap(freq_map_tensor, code, struct_vis_dir) # Save to combined dir
            
    for code, block_freq_map_tensor in block_freq_maps.items():
        # --- Heatmap Generation ---
        plot_block_frequency_heatmap(code, block_freq_map_tensor, style_vis_dir, block_converter, block_index_to_name) # Save to combined dir
        # --- Grid Visualization ---
        # if code in structure_dict_original and structure_dict_original[code]: # Check if code exists and has chunks
        #     grid_save_path = os.path.join(vis_dir, f"grid_vis_code_{code}.png")
        #     try:
        #         # Call the grid visualization function
        #         # Make sure block_converter is accessible here
        #         visualize_structure_grid(
        #             structure_dict=structure_dict_original,
        #             struct_code=code,
        #             save_path=grid_save_path,
        #             # block_converter=block_converter, # Pass if not global
        #             # You might need to adjust num_chunks, grid_size if defaults aren't desired
        #             num_chunks=16, # Example: visualize fewer chunks?
        #             grid_size=4  # Example: smaller grid?
        #         )
        #     except ImportError as e:
        #          print(f"PyVista import error, skipping grid visualization: {e}")
        #          # Optional: break the inner loop if PyVista isn't installed
        #          # break
        #     except Exception as e:
        #          print(f"Error generating grid visualization for code {code} at step {step}: {e}")
        #          # Continue to next code
        #          continue
        # else:
        #     print(f"Skipping grid visualization for code {code}: No original chunks found.")


    # print(f"Visualizations saved to: {vis_dir}")
    # --- End Visualization Section ---
    
    metrics['code_sharpness_mad'], metrics['avg_sharpness_mad'] = calculate_sharpness_mad(avg_freq_maps)
    metrics['code_sharpness_entropy'], metrics['avg_sharpness_entropy'] = calculate_sharpness_entropy(avg_freq_maps)
    metrics['code_consistency_var'], metrics['avg_consistency_var'] = calculate_consistency_variance(binary_structure_dict)
    # Note: Cons. Entropy = Sharp. Entropy
    # metrics['code_consistency_entropy'], metrics['avg_consistency_entropy'] = calculate_consistency_entropy(binary_structure_dict)
    metrics['avg_pairwise_dist'], metrics['min_pairwise_dist'] = calculate_uniqueness_metrics(avg_freq_maps, distance_metric=DISTANCE_METRIC)

    # Find most similar pairs (for reporting) - use the calculated min distance
    min_dist_actual, similar_pairs = find_most_similar_codes(avg_freq_maps, distance_metric=DISTANCE_METRIC)
    # Sanity check: min_dist_actual should be very close to metrics['min_pairwise_dist']
    if not np.isclose(min_dist_actual, metrics['min_pairwise_dist']):
         print(f"Warning: Mismatch in minimum distance calculation for step {step}: {min_dist_actual} vs {metrics['min_pairwise_dist']}")
         # Use the one from find_most_similar_codes for consistency in the report
         min_dist_report = min_dist_actual
    else:
         min_dist_report = metrics['min_pairwise_dist'] # Use the one stored in metrics


    # Add num_chunks for reference if needed later, but it's mainly for the report
    metrics['num_chunks_per_code'] = num_chunks_per_code
    metrics['most_similar_pairs'] = similar_pairs
    # Add avg_freq_maps if you want to visualize them later
    # metrics['avg_freq_maps'] = avg_freq_maps # Can consume a lot of memory!

    all_metrics_data[step] = metrics
    print("Metrics calculated.")

    # 4. Save Checkpoint Report
    save_checkpoint_report(step, metrics, num_chunks_per_code, similar_pairs, min_dist_report, OUTPUT_DIR)

    # 5. Clean up GPU memory before next loop iteration
    del fqgan, binary_structure_dict, num_chunks_per_code, avg_freq_maps, metrics
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n--- Finished evaluating all checkpoints ---")

# --- Generate Overall Summary Plots ---
if all_metrics_data:
    print("Generating overall summary plots...")
    steps = sorted(all_metrics_data.keys())
    num_plots = 5 # Adjust if you add/remove metrics
    fig, axes = plt.subplots(num_plots, 1, figsize=(10, 5 * num_plots), sharex=True)
    fig.suptitle('Model Interpretability Metrics vs. Checkpoint Step', fontsize=16)

    plot_configs = [
        {'key': 'avg_sharpness_mad', 'label': 'Avg Sharpness (MAD)', 'goal': 'max'},
        {'key': 'avg_sharpness_entropy', 'label': 'Avg Sharpness (Entropy)', 'goal': 'min'},
        {'key': 'avg_consistency_var', 'label': 'Avg Consistency (Var)', 'goal': 'min'},
        {'key': 'avg_pairwise_dist', 'label': f'Avg Pairwise Dist ({DISTANCE_METRIC.upper()})', 'goal': 'max'},
        {'key': 'min_pairwise_dist', 'label': f'Min Pairwise Dist ({DISTANCE_METRIC.upper()})', 'goal': 'max'}
    ]

    for i, config in enumerate(plot_configs):
        metric_key = config['key']
        metric_label = config['label']
        goal = config['goal']

        values = [all_metrics_data[step][metric_key] for step in steps]
        axes[i].plot(steps, values, marker='o', linestyle='-')
        axes[i].set_ylabel(metric_label)
        axes[i].grid(True)

        # Find and mark the best value
        if values:
            if goal == 'max':
                best_val = np.max(values)
                best_idx = np.argmax(values)
            else: # goal == 'min'
                best_val = np.min(values)
                best_idx = np.argmin(values)
            best_step = steps[best_idx]
            axes[i].axhline(best_val, color='r', linestyle='--', label=f'Best: {best_val:.4f} at step {best_step}')
            axes[i].plot(best_step, best_val, 'r*', markersize=10) # Mark best point
            axes[i].legend()
            axes[i].set_title(f"{metric_label} (Best is {'Higher' if goal == 'max' else 'Lower'})")
        else:
             axes[i].set_title(metric_label)


    axes[-1].set_xlabel("Checkpoint Step")
    plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout to prevent title overlap
    plot_path = os.path.join(OUTPUT_DIR, "overall_metrics_evolution.png")
    plt.savefig(plot_path)
    print(f"Overall summary plot saved to {plot_path}")
    plt.close(fig) # Close the figure to free memory
else:
    print("No metrics data collected, skipping overall plots.")

print("Analysis complete.")


--- Loading Checkpoint Step: 5000 ---
using padding mode: reflect
With cycle consistency: True type: post_quant_conv, using gumbel: True
Using EMA quantizer
Using SlightlyLessDumbTwoStageGenerator
Detaching binary reconstruction from comp graph for final loss
NO biome supervision
Disentangle Ratio:  0.5
Loading fqgan_5000.th
Loaded model for step 5000
Collecting data from train_loader...


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Finished data collection. Found 32 unique structure codes and 32 unique style codes.
Converting structure chunks to binary...
Finished binary conversion. Processed 32 codes.
Calculating metrics for step 5000...
Generating visualizations for step 5000...
Found 32 Struct codes to visualize
Visualizing Struct code 0...
Saved visualization for structure code 0 to results/interpretability_analysis17\checkpoint_5000_visualizations/Struct_visualizations\Struct_code_0.png
Visualizing Struct code 1...
Saved visualization for structure code 1 to results/interpretability_analysis17\checkpoint_5000_visualizations/Struct_visualizations\Struct_code_1.png
Visualizing Struct code 2...
Saved visualization for structure code 2 to results/interpretability_analysis17\checkpoint_5000_visualizations/Struct_visualizations\Struct_code_2.png
Visualizing Struct code 3...
Saved visualization for structure code 3 to results/interpretability_analysis17\checkpoint_5000_visualizations/Struct_visualizations\Struct_co

In [38]:
block_freq_maps[0]

tensor([5.7483e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.2005e-07, 2.6068e-03, 0.0000e+00, 0.0000e+00, 1.2660e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0216e-01, 2.6398e-02,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 3.2571e-03, 9.5694e-05, 1.4940e-03,
        5.6123e-03, 0.0000e+00, 0.0000e+00, 3.2005e-07, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.3116e-03, 2.2051e-04, 2.3875e-04, 7.4630e-02,
        0.0000e+00, 2.2051e-03, 6.4009e-07, 7.8341e-02, 0.0000e+00, 0.0000e+00])

In [39]:
plot_block_frequency_heatmap(0, block_freq_maps[0], style_vis_dir, block_converter, block_index_to_name) # Save to combined dir

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]


'results/interpretability_analysis17\\checkpoint_10000_visualizations/Style_visualizations\\block_heatmap_code_0.png'



In [40]:
plot_block_frequency_heatmap(block_freq_map_tensor, code, struct_vis_dir, block_converter, block_index_to_name) # Save to combined dir

already numpy


IndexError: tuple index out of range

# Biome Exploration

In [18]:
# check for biome data
for batch in train_loader:
    voxels, biomes = batch
    biome_value = block_converter.convert_to_original_biomes(biomes)
    print(biome_value[2])
    break

[[['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ...
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']]

 [['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ...
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']]

 [['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ['forest' 'forest' 'forest' ... 'plains' 'plains' 'plains']
  ...
  ['forest' 'forest' 'forest' ... 'plains' 'plai

In [19]:
# store style code and its corresponding biome chunks and block chunks 
biome_dict = {}

for batch in train_loader:
    voxels, biomes = batch  # Now unpack both
    for sample_idx in range(len(voxels)):
        sample = voxels[sample_idx].unsqueeze(0).cuda()
        
        with torch.no_grad():
            style_indices, struct_indices = encode_and_quantize(fqgan, sample)
            reconstructed, binary_reconstructed = decode_from_indices(
                style_indices, struct_indices, fqgan, two_stage=True
            )

        biome_tensor = biomes[sample_idx]
        biome_names = block_converter.convert_to_original_biomes(biome_tensor)

        for i in range(style_indices.shape[1]):  
            for j in range(style_indices.shape[2]):  
                for k in range(style_indices.shape[3]):  
                    style_code = style_indices[0, i, j, k].item()

                    x_start, y_start, z_start = i * 4, j * 4, k * 4
                    x_end, y_end, z_end = x_start + 4, y_start + 4, z_start + 4

                    block = reconstructed[x_start:x_end, y_start:y_end, z_start:z_end]
                    biome_chunk = biome_names[x_start:x_end, y_start:y_end, z_start:z_end]

                    # Store both block and corresponding biome values
                    if style_code not in biome_dict:
                        biome_dict[style_code] = []

                    biome_dict[style_code].append({
                        "block": block,
                        "biome": biome_chunk
                    })

In [28]:
# example for the first sample in index 0.
print(biome_dict[0][0]["biome"])

[[['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]

 [['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]

 [['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]

 [['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']
  ['swampland' 'swampland' 'swampland' 'swampland']]]


In [30]:
# find the frequency of biome for each code index
from collections import defaultdict

biome_frequencies = defaultdict(lambda: defaultdict(int))
biome_total_counts = defaultdict(int)


for style_code, entries in biome_dict.items():
    for entry in entries:
        biome_matrix = entry["biome"]  
        flattened_biomes = biome_matrix.flatten()

        # Update total biome count for the style code
        biome_total_counts[style_code] += len(flattened_biomes)

        # Count occurrences of each biome type
        for biome in flattened_biomes:
            biome_frequencies[style_code][biome] += 1

sorted_biome_frequency_dict = {}
for style_code in sorted(biome_frequencies.keys()):
    biome_counts = biome_frequencies[style_code]
    total_biomes = biome_total_counts[style_code] if biome_total_counts[style_code] > 0 else 1

    sorted_biome_frequency_dict[style_code] = {
        biome: biome_counts[biome] / total_biomes
        for biome in sorted(biome_counts.keys())  # Sort by biome name
    }

print(sorted_biome_frequency_dict[0])


{np.str_('beaches'): 0.012664664858581944, np.str_('birch_forest'): 0.04071086977991507, np.str_('cave'): 0.3224620340193419, np.str_('desert'): 0.042630535360211504, np.str_('extreme_hills'): 0.15213599000235506, np.str_('forest'): 0.14629722481786206, np.str_('mutated_extreme_hills'): 0.00047291291565056863, np.str_('ocean'): 0.02534025039694297, np.str_('plains'): 0.12334860329253747, np.str_('river'): 0.02441863998602153, np.str_('savanna'): 0.029020994294657033, np.str_('swampland'): 0.04432703922328327, np.str_('taiga'): 0.03617024105263958}


In [34]:
print(sorted_biome_frequency_dict)

{0: {np.str_('beaches'): 0.012664664858581944, np.str_('birch_forest'): 0.04071086977991507, np.str_('cave'): 0.3224620340193419, np.str_('desert'): 0.042630535360211504, np.str_('extreme_hills'): 0.15213599000235506, np.str_('forest'): 0.14629722481786206, np.str_('mutated_extreme_hills'): 0.00047291291565056863, np.str_('ocean'): 0.02534025039694297, np.str_('plains'): 0.12334860329253747, np.str_('river'): 0.02441863998602153, np.str_('savanna'): 0.029020994294657033, np.str_('swampland'): 0.04432703922328327, np.str_('taiga'): 0.03617024105263958}, 1: {np.str_('beaches'): 0.029120701688594434, np.str_('birch_forest'): 0.032540730439492904, np.str_('cave'): 0.1486817296373845, np.str_('desert'): 0.014762180068127129, np.str_('extreme_hills'): 0.10217897434294822, np.str_('forest'): 0.12550913049366125, np.str_('mutated_extreme_hills'): 0.0002714928591518485, np.str_('ocean'): 0.2171942873214788, np.str_('plains'): 0.1211893340416888, np.str_('river'): 0.07409541444003459, np.str_('s

In [18]:
# Biome Mapping
print(block_converter.index_to_biome)
Biome_mapping = {0: 'beaches', 1: 'birch_forest', 2: 'cave', 3: 'desert', 4: 'extreme_hills', 5: 'forest', 6: 'mutated_extreme_hills', 7: 'ocean', 8: 'plains', 9: 'river', 10: 'savanna', 11: 'swampland', 12: 'taiga'}

{0: 'beaches', 1: 'birch_forest', 2: 'cave', 3: 'desert', 4: 'extreme_hills', 5: 'forest', 6: 'mutated_extreme_hills', 7: 'ocean', 8: 'plains', 9: 'river', 10: 'savanna', 11: 'swampland', 12: 'taiga'}


In [37]:
# plot histogram of biome frequency. 
import matplotlib.pyplot as plt
import os

histogram_dir = os.path.join(os.getcwd(), "Biome_histograms_V1")
os.makedirs(histogram_dir, exist_ok=True)

# Assume your biome frequency dictionary is called `sorted_biome_frequency_dict`
for style_code, biome_data in sorted_biome_frequency_dict.items():
    biome_names = list(biome_data.keys())         # Biome name labels (strings)
    frequencies = list(biome_data.values())       # Corresponding frequency ratios

    # Sort biome names and frequencies alphabetically
    sorted_indices = sorted(range(len(biome_names)), key=lambda k: biome_names[k])
    biome_names = [biome_names[i] for i in sorted_indices]
    frequencies = [frequencies[i] for i in sorted_indices]

    # Plot histogram
    plt.figure(figsize=(12, 6))
    plt.bar(biome_names, frequencies, color="seagreen", alpha=0.75)
    plt.xlabel("Biome Name")
    plt.ylabel("Frequency Ratio")
    plt.title(f"Biome Frequency for Style Code {style_code}")
    plt.xticks(rotation=90, fontsize=7)
    plt.grid(axis="y", linestyle="--", alpha=0.6)

    # Save the histogram image
    save_path = os.path.join(histogram_dir, f"biome_histogram_style_code_{style_code}.png")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

print(f"Biome histograms saved in directory: {histogram_dir}")

Biome histograms saved in directory: /root/autodl-tmp/minecraft_diffusion-master/Dual_codebook_2.0/Biome_histograms_V1
