In [None]:

import torch
import numpy as np
import random

# Set random seed for reproducibility

def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

set_seed(42)



import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.widgets import Slider  # Add this import
import numpy as np
from diffusion_models3d import Transformer, AbsorbingDiffusion, Block, CausalSelfAttention
from sampler_utils import retrieve_autoencoder_components_state_dicts, latent_ids_to_onehot3d, get_latent_loaders
from models3d import VQAutoEncoder, Generator
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed
# from visualization_utils import MinecraftVisualizer
from data_utils import MinecraftVisualizer, get_minecraft_dataloaders, BlockConverter, MinecraftDataset
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import torch.distributions as dists
from tqdm import tqdm
import gc
from scipy.stats import entropy
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.colors as mcolors

get_ipython().run_line_magic('matplotlib', 'widget')

In [None]:
# Colour used by visualize_chunk for each voxel value. Any matplotlib colour
# spec is accepted: a named colour, a hex string, or an RGB / RGBA tuple.
# NOTE(review): the keys look like original Minecraft block IDs (the notebook
# passes convert_to_original_blocks output to visualize_chunk) - confirm.
blocks_to_cols = {
            0: (0.5, 0.25, 0.0),    # light brown
            10: 'black', # bedrock
            29: "#006400", # cactus
            38: "#B8860B",  # clay
            60: "brown",  # dirt
            92: "gold",  # gold ore
            93: "green",  # grass
            115: "brown",  # ladder (unconfirmed)
            119: (.02, .28, .16, 0.8),  # transparent forest green (RGBA) for leaves
            120: (.02, .28, .16, 0.8),  # leaves2
            194: "yellow",  # sand
            217: "gray",  # stone
            240: (0.0, 0.0, 1.0, 0.4),  # water
            227: (0.0, 1.0, 0.0, .3), # tall grass
            237: (0.33, 0.7, 0.33, 0.3), # vine
            40: "#2F4F4F",  # coal ore
            62: "#228B22",  # double plant
            108: "#BEBEBE",  # iron ore
            131: "saddlebrown",  # log1
            132: "saddlebrown",  # log2
            95: "lightgray",  # gravel
            243: "wheat",  # wheat
            197: "limegreen",  # sapling
            166: "orange",  # pumpkin
            167: "#FF8C00",  # pumpkin stem
            184: "#FFA07A",  # red flower
            195: "tan",  # sandstone
            250: "white",  # wool
            251: "gold",   # yellow flower
        }




def draw_latent_cuboid(fig, latent_coords, size=4):
    """
    Draw a transparent red cuboid around the bounding box of the given latent cells.

    Args:
        fig: matplotlib figure to draw on; the cuboid is added to fig.gca()
        latent_coords: list of tuples, each containing (d,h,w) coordinates
        size: size of each latent cell in final space (default 4 for 6->24 upscaling)

    Returns:
        The same figure that was passed in, with the cuboid surface and edges added.
    """
    def cuboid_data(o, sizes):
        # Build four closed 5-point outlines of the box with corner `o`:
        # row 0 = bottom face (z = o[2]), row 1 = top face (z = o[2] + h),
        # row 2 = side face at y = o[1], row 3 = side face at y = o[1] + w.
        l, w, h = sizes
        x = [[o[0], o[0] + l, o[0] + l, o[0], o[0]],
             [o[0], o[0] + l, o[0] + l, o[0], o[0]],
             [o[0], o[0] + l, o[0] + l, o[0], o[0]],
             [o[0], o[0] + l, o[0] + l, o[0], o[0]]]
        y = [[o[1], o[1], o[1] + w, o[1] + w, o[1]],
             [o[1], o[1], o[1] + w, o[1] + w, o[1]],
             [o[1], o[1], o[1], o[1], o[1]],
             [o[1] + w, o[1] + w, o[1] + w, o[1] + w, o[1] + w]]
        z = [[o[2], o[2], o[2], o[2], o[2]],
             [o[2] + h, o[2] + h, o[2] + h, o[2] + h, o[2] + h],
             [o[2], o[2], o[2] + h, o[2] + h, o[2]],
             [o[2], o[2], o[2] + h, o[2] + h, o[2]]]
        return np.array(x), np.array(y), np.array(z)

    ax = fig.gca()

    # Convert coordinates to numpy array for easier manipulation
    coords = np.array(latent_coords)

    # Find min and max for each dimension
    d_min, h_min, w_min = coords.min(axis=0)
    d_max, h_max, w_max = coords.max(axis=0)

    # Calculate origin and sizes: plot x comes from d (mirrored), y from w, z from h.
    # NOTE(review): the literal 5 looks like (latent grid extent - 1) for a
    # 6-cell axis, mirroring d to follow the rotation applied in
    # visualize_chunk - confirm before using with another grid size.
    origin = np.array([abs(5 - d_max)*size, w_min*size, h_min*size])
    sizes = (
        abs(d_max - d_min + 1) * size,  # length
        (w_max - w_min + 1) * size,     # width
        (h_max - h_min + 1) * size      # height
    )

    # Create and draw single cuboid
    X, Y, Z = cuboid_data(origin, sizes)
    ax.plot_surface(X, Y, Z, color='red', alpha=0.1)

    # Plot edges: first the four face outlines ...
    for i in range(4):
        ax.plot(X[i], Y[i], Z[i], color='red', linewidth=1)
    # ... then the four vertical edges joining bottom-face and top-face corners.
    for i in range(4):
        ax.plot([X[0][i], X[1][i]], [Y[0][i], Y[1][i]], [Z[0][i], Z[1][i]],
               color='red', linewidth=2)

    return fig

def visualize_chunk(voxels, figsize=(10, 10), elev=20, azim=45, highlight_latents=None):
    """
    Optimized version of the 3D visualization of a Minecraft chunk.

    Args:
        voxels: 3D array of block values, or a torch.Tensor. A 4-dim tensor is
            treated as one-hot [C,H,W,D] and reduced with argmax over dim 0.
        figsize: size passed to plt.figure.
        elev: elevation angle passed to ax.view_init.
        azim: azimuth angle passed to ax.view_init.
        highlight_latents: optional list of (d,h,w) latent coordinates; when
            given, draw_latent_cuboid outlines their bounding box.

    Returns:
        The matplotlib figure containing the voxel plot.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    # Convert one-hot to block IDs if needed
    if isinstance(voxels, torch.Tensor):
        if voxels.dim() == 4:  # One-hot encoded [C,H,W,D]
            voxels = voxels.detach().cpu()
            voxels = torch.argmax(voxels, dim=0).numpy()
        else:
            voxels = voxels.detach().cpu().numpy()
    # Apply the same transformations as original
    voxels = voxels.transpose(2, 0, 1) # Moves axes from [D,H,W] to [W,D,H]
    voxels = np.rot90(voxels, 1, (0, 1))  # Rotate 90 degrees around height axis
    # Create figure and 3D axis
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')
    # Generate a single boolean mask for each block type that has a colour in blocks_to_cols
    block_masks = {int(block_id): (voxels == block_id) for block_id in np.unique(voxels) if int(block_id) in blocks_to_cols}
    # Plot all block types with their respective colors
    for block_id, mask in block_masks.items():
        ax.voxels(mask, facecolors=blocks_to_cols[int(block_id)])

    # Plot value 5 and every value without a colour entry in translucent red with black edges.
    # NOTE(review): MinecraftDataset treats block ID 5 as air, so this draws air
    # voxels too; an earlier version of this line excluded 5 and -1
    # ((voxels != 5) & (voxels != -1) & not-in-any-mask). Confirm which is intended.
    other_vox = (voxels == 5) | (~np.isin(voxels, list(blocks_to_cols.keys())))
    ax.voxels(other_vox, edgecolor="k", facecolors=(1, 0, 0, 0.5))

    # Set default view angle
    ax.view_init(elev=elev, azim=azim)

    if highlight_latents is not None:
        fig = draw_latent_cuboid(fig, highlight_latents)

    return fig





In [None]:
# ## Block Converter: Converts between block IDs and indices, needed for visualization
block_converter = BlockConverter.load_mappings('block_mappings.pt')


In [None]:
# # Minecraft Chunks Dataset


from torch.utils.data import DataLoader, random_split, Dataset

class MinecraftDataset(Dataset):
    """In-memory dataset of Minecraft chunks served as one-hot float tensors.

    The chunk array is loaded from ``data_path`` with ``np.load``, block IDs are
    remapped to contiguous indices using ``converter.block_to_index``, volumes
    smaller than 24 along the last axis are padded with air, and everything is
    one-hot encoded once up front.

    NOTE: this definition rebinds the ``MinecraftDataset`` name that the
    notebook also imports from ``data_utils``.

    Args:
        data_path: path of the ``.npy`` file holding the chunk array.
        converter: object exposing a ``block_to_index`` dict (which must contain
            the air block ID 5) and ``convert_to_original_blocks``.
    """

    def __init__(self, data_path, converter):
        # Load data and convert to int16 to save memory
        raw_chunks = torch.from_numpy(np.load(data_path)).to(torch.int16)

        # Load pre-saved mappings
        self.converter = converter
        self.num_block_types = len(self.converter.block_to_index)

        # Convert blocks to indices once at initialization. The masks are taken
        # from the untouched raw IDs: remapping in place would re-map a voxel a
        # second time whenever a freshly written index equals a block ID that is
        # processed later (e.g. {1: 2, 2: 1} would collapse everything to 1).
        self.chunks = raw_chunks.clone()
        for old_block, new_idx in self.converter.block_to_index.items():
            self.chunks[raw_chunks == old_block] = new_idx
        del raw_chunks

        # Store air block index
        self.air_idx = self.converter.block_to_index[5]
        self.target_size = 24

        # Pad if needed (the same amount is appended to the last three axes)
        pad_size = self.target_size - self.chunks.size(-1)
        if pad_size > 0:
            self.chunks = F.pad(self.chunks,
                              (0, pad_size, 0, pad_size, 0, pad_size),
                              value=self.air_idx)

        # Convert to one-hot [N, C, H, W, D]
        self.processed_chunks = F.one_hot(
            self.chunks.long(),
            num_classes=self.num_block_types
        ).permute(0, 4, 1, 2, 3).float()

        # Free up memory by deleting the index volume
        del self.chunks

        print(f"Loaded {len(self.processed_chunks)} chunks of size {self.processed_chunks.shape[1:]}")
        print(f"Number of unique block types: {self.num_block_types}")

    def __getitem__(self, idx):
        """Return the one-hot tensor [C, H, W, D] of chunk ``idx``."""
        return self.processed_chunks[idx]

    def convert_to_original_blocks(self, data):
        """Convert from indices back to original block IDs"""
        return self.converter.convert_to_original_blocks(data)

    def __len__(self):
        """Return the number of chunks."""
        return len(self.processed_chunks)



def get_minecraft_dataloader(data_path, converter, batch_size=32, num_workers=0):
    """
    Build one unshuffled DataLoader over the whole Minecraft chunks dataset.

    The dataset is constructed from ``data_path`` and ``converter``; a short
    summary (sample count, batch size, batch count) is printed before the
    loader is returned.
    """
    full_dataset = MinecraftDataset(data_path, converter)

    # Sequential order is kept on purpose: the loader is used for exploration.
    loader_options = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    loader = DataLoader(full_dataset, **loader_options)

    summary_lines = (
        "\nDataloader details:",
        f"Total samples: {len(full_dataset)}",
        f"Batch size: {batch_size}",
        f"Number of batches: {len(loader)}",
    )
    for line in summary_lines:
        print(line)

    return loader




In [None]:
# # Load VQGAN Model


import os
from log_utils import log, load_stats, load_model
import copy
from hyperparams import HparamsVQGAN

# Loads hparams from hparams.json file in saved model directory
def load_hparams_from_json(log_dir):
    """Read ``hparams.json`` from a saved-model directory.

    Args:
        log_dir: directory expected to contain ``hparams.json``.

    Returns:
        The parsed JSON content (a dict for a regular hparams file).

    Raises:
        FileNotFoundError: if ``hparams.json`` does not exist in ``log_dir``.
    """
    import json
    import os

    json_path = os.path.join(log_dir, 'hparams.json')

    if not os.path.exists(json_path):
        raise FileNotFoundError(f"No hparams.json file found in {log_dir}")

    # JSON text is UTF-8; be explicit so decoding does not depend on the
    # platform's locale default.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# turns loaded hparams json into propery hyperparams object
def dict_to_vcqgan_hparams(hparams_dict, dataset=None):
    """Turn a loaded hparams dictionary into an ``HparamsVQGAN`` object.

    Args:
        hparams_dict: mapping of hyperparameter names to values; every entry is
            copied onto the returned object with ``setattr``.
        dataset: dataset name passed to ``HparamsVQGAN``. When omitted, the
            ``'dataset'`` entry of ``hparams_dict`` is used, falling back to
            ``'MNIST'``.

    Returns:
        The populated ``HparamsVQGAN`` instance.
    """
    # Identity check: None is a singleton, and `==` could be overridden.
    if dataset is None:
        dataset = hparams_dict.get('dataset', 'MNIST')  # Default to MNIST if not specified

    vq_hyper = HparamsVQGAN(dataset)
    # Values from the dictionary override whatever defaults the class set.
    for key, value in hparams_dict.items():
        setattr(vq_hyper, key, value)

    return vq_hyper


def load_vqgan_from_checkpoint(H, vqgan):
    """Load the "vqgan" checkpoint at ``H.load_step`` from ``H.load_dir``, move it to CUDA and switch to eval mode."""
    restored = load_model(vqgan, "vqgan", H.load_step, H.load_dir)
    restored = restored.cuda()
    restored.eval()
    return restored


def encode_and_quantize(vqgan, terrain_chunks):
    """Encode chunks with the VQGAN encoder and quantize them to codebook entries.

    Args:
        vqgan: model exposing ``ae.encoder`` and ``ae.quantize``; ``quantize``
            must return ``(quantized, _, stats)`` where
            ``stats["min_encoding_indices"]`` holds one index per latent cell.
        terrain_chunks: input batch fed to the encoder.

    Returns:
        ``(quantized, latent_indices)`` where ``latent_indices`` is reshaped to
        ``[batch, *spatial]``, using the spatial dimensions of the encoder output.
    """
    vqgan.eval()
    with torch.no_grad():
        encoded = vqgan.ae.encoder(terrain_chunks)
        quantized, _, quant_stats = vqgan.ae.quantize(encoded)
        print(f'zq shape: {quantized.size()}')
        latent_indices = quant_stats["min_encoding_indices"]
        print(f'latent_indices size: {latent_indices.size()}')
        # Use every spatial dim of the encoder output ([B, C, *spatial]) so the
        # reshape works for 3D latents (B, D, H, W) as well as 2D ones (B, H, W);
        # taking only dims 2 and 3 fails on a 5-D encoder output.
        latent_indices = latent_indices.view(encoded.size(0), *encoded.shape[2:])
        print(f'latent_indices viewed size: {latent_indices.size()}')

    return quantized, latent_indices

In [None]:
# Restore the trained VQGAN: hparams.json in the log directory provides the
# hyperparameters (dataset forced to 'maps'), then the checkpoint is loaded.
model_path = '../model_logs/minecraft39ch_ce_3'
vqgan_hparams =  dict_to_vcqgan_hparams(load_hparams_from_json(f"{model_path}"), 'maps')
vqgan = VQAutoEncoder(vqgan_hparams)
vqgan = load_vqgan_from_checkpoint(vqgan_hparams, vqgan)
print(f'loaded from: {vqgan_hparams.log_dir}')
# This takes a while

# Unshuffled loader over the full chunk dataset, batched like the VQGAN run.
train_loader= get_minecraft_dataloader(
        '../datasets/minecraft_chunks.npy',
        block_converter,
        batch_size=vqgan_hparams.batch_size,
        num_workers=0,
    )

# Collapse the one-hot batches (class axis = dim 1) back to per-voxel block
# indices and stack them into one tensor covering the whole dataset.
all_blocks_index = []
for batch in train_loader:
    batch_indices = torch.argmax(batch, dim=1)
    all_blocks_index.append(batch_indices)
all_blocks_index = torch.cat(all_blocks_index, dim=0)


# Latent codes for the same loader.
# NOTE(review): the later cell indexes latents and all_blocks_index with the
# same sample index n, so it presumably relies on both following loader order - confirm.
from sampler_utils import generate_latents_from_loader3d
latents = generate_latents_from_loader3d(vqgan_hparams, vqgan, train_loader)

**Extract the block chunks corresponding to each code**

In [None]:
latents.shape

# One list of block-index chunks per codebook entry.
code_chunks_dict = {i: [] for i in range(vqgan_hparams.codebook_size)}

# Each latent cell is paired with a chunk_size^3 block region.
# NOTE(review): the 6x6x6 latent grid and chunk_size 4 are hard-coded
# (6 * 4 = 24, the padded chunk size) - confirm they match the loaded model.
chunk_size = 4
reshaped_latents = latents.view(-1, 6, 6, 6)
N, D, H, W = reshaped_latents.shape

# extract all corresponding chunks from each code.
# This is a pure-Python loop over every latent cell with one .item() call each,
# so it is slow on large datasets.
for n in range(N):
    for d in range(D):
        for h in range(H):
            for w in range(W):
                code_idx = reshaped_latents[n, d, h, w].item()  # Get the code index

                # Compute the starting position in the block index tensor
                d_start, h_start, w_start = d * chunk_size, h * chunk_size, w * chunk_size

                # Ensure indices do not exceed dimensions of all_blocks_index
                if (d_start + chunk_size <= all_blocks_index.shape[1] and
                    h_start + chunk_size <= all_blocks_index.shape[2] and
                    w_start + chunk_size <= all_blocks_index.shape[3]):

                    # Extract the 4x4x4 chunk from block indices
                    block_chunk = all_blocks_index[n,
                                                   d_start:d_start + chunk_size,
                                                   h_start:h_start + chunk_size,
                                                   w_start:w_start + chunk_size]

                    # Store the extracted chunk in the corresponding code index entry
                    code_chunks_dict[code_idx].append(block_chunk.cpu().numpy())

print("Code_chunks_dict is already finished. Now turn to have binary chunks...")

**Convert into Binary Code**

In [None]:
def create_binary_chunk_code_dict(code_chunks_dict):
    """Binarise every chunk of every code: non-zero block index -> 1, zero -> 0.

    Returns a dict keyed by code index (0 .. codebook_size - 1) whose values are
    lists of nested Python lists, one per chunk; codes without chunks map to [].
    """
    binary_code_chunks_dict = {}
    for code_idx in range(vqgan_hparams.codebook_size):
        binary_code_chunks_dict[code_idx] = []
    for code_idx in range(vqgan_hparams.codebook_size):
        source_chunks = code_chunks_dict[code_idx]
        if not source_chunks:
            continue
        binarised = binary_code_chunks_dict[code_idx]
        for chunk in source_chunks:
            # Binarise one slice along the first axis at a time.
            binarised.append([np.where(plane != 0, 1, 0).tolist() for plane in chunk])
    print("Binary_code_chunks_dict is already finished. Now turn to have average_code_chunks_dict...")
    return binary_code_chunks_dict


**Calculate Entropy**

In [None]:
def calculate_structural_score(binary_code_chunks_dict):
    """Compute the structural score (mean per-voxel binary entropy) of each code.

    For every code, the binary chunks are stacked and the probability of a 1 is
    estimated independently at each voxel position. The base-2 entropy of that
    Bernoulli distribution is averaged over all positions. A score of 0 means
    every chunk of the code has exactly the same occupancy pattern; 1 means
    every voxel is occupied in half of the chunks.

    Args:
        binary_code_chunks_dict: maps code index -> list of equally shaped
            binary chunks (nested lists or arrays of 0/1).

    Returns:
        dict mapping each code index of the input to its score as a float, or
        to None when the code has no chunks.
    """
    structural_score_value_dict = {}
    for idx, chunks in binary_code_chunks_dict.items():
        if not chunks:
            structural_score_value_dict[idx] = None
            continue

        stacked = np.asarray(chunks, dtype=float)
        p_1 = stacked.mean(axis=0)
        p_0 = 1.0 - p_1

        # Entropy is defined as 0 where the voxel is constant (p == 0 or 1);
        # only evaluate the logarithms where both outcomes occur.
        mixed = (p_1 > 0) & (p_0 > 0)
        entropy_values = np.zeros_like(p_1)
        entropy_values[mixed] = -(
            p_1[mixed] * np.log2(p_1[mixed]) + p_0[mixed] * np.log2(p_0[mixed])
        )

        structural_score_value_dict[idx] = float(entropy_values.mean())
    print("Finish calculating structural scores...")
    # Return the scores themselves, not the function object.
    return structural_score_value_dict

**Draw Structural Heatmaps**


In [None]:
from matplotlib.colors import BoundaryNorm
# Discrete colour scale shared by the structural heatmaps: five equal-width
# frequency bins between 0 and 1, one colour per bin.
# NOTE: `counter` is only referenced by commented-out code in the heatmap function.
counter = 0
levels = [0, 0.2, 0.4, 0.6, 0.8, 1]
colors = ['#FFFFFF',  # White for 0-0.2
          '#FFDD00',  # Bright Yellow for 0.2-0.4
          '#FF6600',  # Bright Orange for 0.4-0.6
          '#CC0000',  # Red for 0.6-0.8
          '#2c2c2c']  # Near-black for 0.8-1.0
cmap = mcolors.ListedColormap(colors)
boundaries = levels
# BoundaryNorm maps each value to the index of the bin it falls in.
norm = BoundaryNorm(boundaries, cmap.N)

# draw structural code heatmap
def print_structural_code_chunks_heatmap(score_value_dict, binary_code_chunks_dict):
    """Save a 3D occupancy-frequency heatmap for every code with a score below 0.5.

    For each selected code, the fraction of its binary chunks that contain a 1
    is computed per voxel position and drawn as a coloured 3D scatter plot,
    using the module-level ``cmap`` / ``norm``. Each figure is written to
    ``heatmap_<code index>.png`` in the current working directory and closed.

    Args:
        score_value_dict: code index -> structural score, or None for codes
            without chunks.
        binary_code_chunks_dict: code index -> list of equally shaped binary chunks.
    """
    for i in range(vqgan_hparams.codebook_size):
        score = score_value_dict[i]
        if score is None or not (score < 0.5):
            continue

        # Per-voxel frequency of 1s across all chunks assigned to this code.
        freq_values = np.mean(np.array(binary_code_chunks_dict[i]), axis=0)

        # Grid coordinates come from the data's own shape; coordinates and
        # colours are flattened in the same (C) order so they stay aligned.
        x, y, z = (axis.ravel() for axis in np.indices(freq_values.shape))

        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111, projection='3d')
        sc = ax.scatter(x, y, z, c=freq_values.ravel(), cmap=cmap, norm=norm,
                        s=500, edgecolors='k')

        cbar = plt.colorbar(sc, ax=ax, shrink=0.5)
        cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
        cbar.set_ticklabels(['0', '0.2', '0.4', '0.6', '0.8', '1.0'])
        cbar.set_label("Frequency of 1s")
        ax.set_xlabel('X')
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title(f"3D Frequency Heatmap of 1s in 4x4x4 Vectors (Index: {i}, Structural Score: {score_value_dict[i]})")
        filename = f"heatmap_{i}.png"
        fig.savefig(filename, dpi=300)
        print(f"The heatmap of code index {i} is saved")
        # Close explicitly so figures do not accumulate across codes.
        plt.close(fig)

**Low Entropy Codes (High Structural Score)**

In [None]:
# find out all code indices with a score smaller than 0.5
def low_entropy_structural_codes(score_value_dict):
    """Return, in ascending order, the code indices whose score is set and below 0.5."""
    return [
        code_idx
        for code_idx in range(vqgan_hparams.codebook_size)
        if score_value_dict[code_idx] is not None and score_value_dict[code_idx] < 0.5
    ]

**Visualization Testing**

In [None]:
import os
# Render up to 20 chunks (all of them, or a random sample) for every
# low-entropy code and save each rendering as a PNG in a per-code directory.
# NOTE(review): `code_indices_with_low_entropy` is not assigned at module level
# anywhere in the visible notebook; it presumably has to be produced first via
# low_entropy_structural_codes(...) - confirm.
for code_idx in code_indices_with_low_entropy:
    if code_chunks_dict[code_idx]:
        num_chunks = len(code_chunks_dict[code_idx])
        if num_chunks <= 20:
            selected_indices = range(num_chunks)
        else:
            selected_indices = random.sample(range(num_chunks), 20)
        save_dir = f'visualize_chunks_code_{code_idx}'
        os.makedirs(save_dir, exist_ok = True)

        for i, idx in enumerate(selected_indices):
            first_chunk = code_chunks_dict[code_idx][idx]
            if isinstance(first_chunk, torch.Tensor):
                first_chunk = first_chunk.detach().cpu().numpy()

            # convert to original ids; a 3-dim chunk gets a leading batch
            # axis for the converter call, which is squeezed away again afterwards
            first_chunk_tensor = torch.tensor(first_chunk, dtype=torch.int64)
            if first_chunk_tensor.dim() == 3:  # Shape: (4, 4, 4)
                first_chunk_tensor = first_chunk_tensor.unsqueeze(0)
            original_blocks = block_converter.convert_to_original_blocks(first_chunk_tensor)
            original_blocks = original_blocks.squeeze(0).cpu().numpy()

            fig = visualize_chunk(
                original_blocks,
                figsize=(10, 10),
                elev=20,  # Elevation angle for the 3D plot
                azim=45   # Azimuth angle for the 3D plot
            )

            filename = os.path.join(save_dir, f"visualized_chunk_code_{code_idx}_{idx}.png")
            fig.savefig(filename, dpi=300)
            print(f"Saved: {filename}")
            # Close each figure so they do not pile up in memory.
            plt.close(fig)
        print(f"All visualizations saved in: {save_dir}")
    else:
        print(f"No chunks available for code index {code_idx}")

**Structural Dictionary**

In [None]:
# a result dictionary where the first element is structural score and the second is the number of chunks
def get_strucutural_result(binary_code_chunks_dict, score_value_dict, code_indices_with_low_entropy):
    """Map each selected code index to ``[structural score as float, number of chunks]``.

    Indices missing from either dictionary are skipped.
    """
    structure_result_dict = {}
    for code_idx in code_indices_with_low_entropy:
        if code_idx not in binary_code_chunks_dict:
            continue
        if code_idx not in score_value_dict:
            continue
        score = float(score_value_dict[code_idx])
        chunk_total = len(binary_code_chunks_dict[code_idx])
        structure_result_dict[code_idx] = [score, chunk_total]
    return structure_result_dict

**Visualize Structural Histogram**

In [None]:
# Bar chart of structural scores, with codes ordered by how many chunks they
# cover (most frequent first); each bar is annotated with its chunk count.
# NOTE(review): `structure_result_dict` is not assigned at module level in the
# visible notebook; it presumably comes from get_strucutural_result(...) - confirm.
sorted_items = sorted(structure_result_dict.items(), key=lambda x: x[1][1], reverse=True)
sorted_indices = [item[0] for item in sorted_items]  # Sorted code indices by number of chunks
sorted_scores = [item[1][0] for item in sorted_items]  # Corresponding structural scores
sorted_chunk_counts = [item[1][1] for item in sorted_items]

bar_width = 0.8
x_positions = np.arange(len(sorted_indices))
plt.figure(figsize=(18, 10))
plt.bar(x_positions, sorted_scores, color='navy', alpha=0.9, edgecolor='black', linewidth = 0.6, width = bar_width)
# Write the chunk count just above each bar.
for x, y, chunk_count in zip(x_positions, sorted_scores, sorted_chunk_counts):
    plt.text(x, y + 0.02, str(chunk_count), ha='center', fontsize=10, fontweight='bold', rotation=90)
plt.xlabel("Code Index (Sorted by Number of Chunks)", fontsize = 12)
plt.ylabel("Structural Score", fontsize = 12)
plt.title("Histogram of Structural Scores for VQGAN Codebook Entries", fontsize = 14)
plt.xticks(x_positions, sorted_indices, rotation=90, fontsize=10)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.savefig("structural_scores_histogram.png", dpi=300, bbox_inches="tight")
plt.close()


**Explore Style Information**

In [5]:
from collections import Counter
# explore the style information
# count the frequency of each type of blocks
def find_num_blocks():
    """Return, per code index, the relative frequency of each of the 39 block indices.

    Reads the module-level ``code_chunks_dict``. Codes without chunks keep an
    all-zero table.
    """
    num_blocks = dict.fromkeys(range(vqgan_hparams.codebook_size))
    for code_idx, chunks in code_chunks_dict.items():
        tallies = dict.fromkeys(range(39), 0)
        grand_total = 0
        if chunks:
            for chunk in chunks:
                per_chunk = {int(k): v for k, v in Counter(chunk.flatten()).items()}
                for block_idx, occurrences in per_chunk.items():
                    tallies[block_idx] += occurrences
                    grand_total += occurrences
        if grand_total > 0:
            # Turn absolute counts into fractions of all voxels seen for this code.
            for block_idx in tallies:
                tallies[block_idx] /= grand_total
        num_blocks[code_idx] = tallies
    return num_blocks

**Calculate Style Score**

In [None]:
# calculate the style score based on 
def calculate_style_score():
    """Return the average per-chunk style score of every code index.

    For one chunk, the score is the root of the mean squared gap between the
    count of the most frequent block index and the count of every block index
    (39 types, divided by 39 - 1). A chunk made of a single block type scores
    highest. Reads the module-level ``code_chunks_dict``; codes without chunks
    keep a score of 0.
    """
    total_types_of_blocks = 39
    final_style_score = {i: 0 for i in range(vqgan_hparams.codebook_size)}
    for code_idx, chunks in code_chunks_dict.items():
        if not chunks:
            continue
        accumulated = 0
        for chunk in chunks:
            observed = {int(k): v for k, v in Counter(chunk.flatten()).items()}

            # Start from zero for every block type, then fill in what was seen.
            per_type = dict.fromkeys(range(39), 0)
            for block_idx in observed:
                per_type[block_idx] = observed[block_idx]

            counts = list(per_type.values())
            peak = max(counts)
            squared_gaps = sum((peak - c) ** 2 for c in counts)
            accumulated += (squared_gaps / (total_types_of_blocks - 1)) ** 0.5
        final_style_score[code_idx] = accumulated / len(chunks)
    return final_style_score

def find_high_style_codes(final_style_score, threshold=45):
    """Select the codes whose style score exceeds ``threshold``.

    Args:
        final_style_score: dict mapping code index -> style score.
        threshold: exclusive lower bound on the score (default 45, the value
            previously hard-coded here).

    Returns:
        dict mapping code index -> ``[style score, number of chunks]``, ordered
        by number of chunks, largest first. Chunk counts are read from the
        module-level ``code_chunks_dict``.
    """
    selected = {
        code_idx: [score, len(code_chunks_dict[code_idx])]
        for code_idx, score in final_style_score.items()
        if score > threshold
    }
    # Sort by chunk count (second element of each value), descending.
    return dict(sorted(selected.items(), key=lambda item: item[1][1], reverse=True))
# print("High Style Score Codes: ", high_style_score_code_index)

**Visualize Style Chunks**

In [6]:
def print_style_code_chunks(sorted_high_style_score_code_index, code_chunks_dict):
    """Render and save up to 10 chunks for every code in the given index.

    For each code, all chunks (or a random sample of 10 when there are more)
    are converted back to original block IDs, drawn with ``visualize_chunk``
    and written as PNGs into ``visualize_chunks_code_<code index>/``.
    """
    import os

    for code_idx in sorted_high_style_score_code_index.keys():
        chunks = code_chunks_dict[code_idx]
        if not chunks:
            print(f"No chunks available for code index {code_idx}")
            continue

        total = len(chunks)
        picks = range(total) if total <= 10 else random.sample(range(total), 10)
        save_dir = f'visualize_chunks_code_{code_idx}'
        os.makedirs(save_dir, exist_ok=True)

        for idx in picks:
            chunk = chunks[idx]
            if isinstance(chunk, torch.Tensor):
                chunk = chunk.detach().cpu().numpy()

            # The converter is called on a batch, so add a leading axis to a
            # 3-dim chunk and drop it again afterwards.
            index_tensor = torch.tensor(chunk, dtype=torch.int64)
            if index_tensor.dim() == 3:
                index_tensor = index_tensor.unsqueeze(0)
            restored = block_converter.convert_to_original_blocks(index_tensor)
            original_blocks = restored.squeeze(0).cpu().numpy()

            fig = visualize_chunk(original_blocks, figsize=(10, 10), elev=20, azim=45)

            filename = os.path.join(save_dir, f"visualized_chunk_code_{code_idx}_{idx}.png")
            fig.savefig(filename, dpi=300)
            print(f"Saved: {filename}")
            plt.close(fig)
        print(f"All visualizations saved in: {save_dir}")


def plot_style_histogram(sorted_high_style_score_code_index):
    """Save a bar chart of style scores to ``style_scores_histogram.png``.

    Bars follow the iteration order of the given dict (code index ->
    ``[style score, chunk count]``); each bar is annotated with its chunk count.
    """
    code_labels = list(sorted_high_style_score_code_index.keys())
    entries = list(sorted_high_style_score_code_index.values())
    scores = [entry[0] for entry in entries]
    chunk_counts = [entry[1] for entry in entries]

    positions = np.arange(len(code_labels))
    plt.figure(figsize=(60, 20))
    plt.bar(
        positions,
        scores,
        width=1.5,
        color='navy',
        alpha=0.9,
        edgecolor='black',
        linewidth=0.6,
    )
    # Annotate every bar with the number of chunks behind it.
    for pos, score, n_chunks in zip(positions, scores, chunk_counts):
        plt.text(pos, score + 2.5, str(n_chunks), ha='center', fontsize=16, fontweight='bold', rotation=90)
    plt.xlabel("Code Index (Sorted by Number of Chunks)", fontsize=20)
    plt.ylabel("Style Score", fontsize=20)
    plt.title("Histogram of Style Scores for VQGAN Codebook Entries", fontsize=22)
    plt.xticks(positions, code_labels, rotation=90, fontsize=16)
    plt.grid(axis='y', linestyle='--', alpha=0.6)
    plt.savefig("style_scores_histogram.png", dpi=300, bbox_inches="tight")
    plt.close()

**Find Most Common Block Types**

In [None]:
# Human-readable name for each of the 39 block indices (0..38) found in the
# extracted chunks; the string "None" marks indices without a known name.
# NOTE(review): keys are presumably the converter's contiguous indices, not
# original Minecraft block IDs - confirm against block_mappings.pt.
blocks_to_types = {
    0: "air",
    1: "None", 2: "None", 3: "cacutus", 4: "None", 5: "clay", 6: "coal ore", 7: "None", 8: "None", 9: "dirt", 10: "double plant", 11: "None", 12: "None", 13: "gold ore", 14: "grass", 15: "gravel", 16: "iron ore", 17: "None", 18: "leaves", 19: "leaves", 20: "log1", 21: "log2", 22: "None", 23: "None", 24: "None", 25: "pumpkin stem", 26: "red flower", 27: "None", 28: "None", 29: "sand", 30: "sandstone", 31: "None", 32: "stone", 33: "None", 34: "tall grass", 35: "vine", 36: "water", 37: "None", 38: "yellow flower"
}

def find_most_common_block_types(sorted_high_style_score_code_index, code_chunks_dict):
    """Compute the block-type distribution of every selected code.

    Args:
        sorted_high_style_score_code_index: mapping whose keys are the code
            indices to analyse (its values are not used here).
        code_chunks_dict: code index -> list of numpy chunks of block indices.

    Returns:
        dict mapping code index -> {block index: fraction of all voxels},
        ordered by descending fraction. A code with no voxels maps to an empty
        dict.
    """
    total_block_counter = {}
    for code_idx in sorted_high_style_score_code_index.keys():
        code_block_count = Counter()
        for chunk in code_chunks_dict[code_idx]:
            # tolist() yields plain Python ints, so keys are not numpy scalars.
            code_block_count.update(chunk.flatten().tolist())
        total_number_of_blocks = sum(code_block_count.values())

        # Start from an empty result for every code so that a code without
        # voxels neither raises on an unbound name nor inherits the previous
        # code's distribution.
        code_block_percentage = {}
        if total_number_of_blocks > 0:
            code_block_percentage = {
                block_idx: count / total_number_of_blocks
                for block_idx, count in sorted(
                    code_block_count.items(), key=lambda item: item[1], reverse=True
                )
            }
        total_block_counter[code_idx] = code_block_percentage
    return total_block_counter

def print_most_common_type_and_frequency(total_block_counter):
    """Print the three most frequent block types (with names when known) for every code."""
    for code_idx, distribution in total_block_counter.items():
        ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
        print(f"Code index: {code_idx}")
        for block_idx, share in ranked[:3]:
            if block_idx in blocks_to_types.keys():
                label = f"{block_idx} ({blocks_to_types[block_idx]})"
            else:
                label = f"{block_idx}"
            print(f"{label}: {share}")
        # Blank line between codes.
        print()

**Visualization of Blocks Types in Histogram**

In [None]:
def block_freq_histogram(blocks_to_code_indices, num_blocks):
    """Save one bar chart per block ID showing its frequency across code indices.

    Args:
        blocks_to_code_indices: Mapping ``{block_id: [code_index, ...]}``.
            Block IDs with an empty list are skipped.
        num_blocks: Nested mapping read as ``num_blocks[code_index][block_id]``
            to obtain the height of each bar.

    Side effects:
        Creates the ``Block_Freq_Histogram`` directory if needed, writes one
        PNG per plotted block into it and prints a confirmation line.
    """
    os.makedirs('Block_Freq_Histogram', exist_ok=True)
    for block_id, code_indices in blocks_to_code_indices.items():
        if not code_indices:  # Skip if no code indices for this block
            continue
        block_name = blocks_to_types.get(block_id, "Unknown")
        frequencies = [num_blocks[code_index][block_id] for code_index in code_indices]

        # Draw the bars at evenly spaced slots and label each slot with its code
        # index. Plotting at the raw code index values while placing the ticks
        # at 0..n-1 would put the labels under the wrong bars.
        positions = list(range(len(code_indices)))

        plt.figure(figsize=(12, 6))
        plt.bar(positions, frequencies, color='skyblue', edgecolor='black')
        plt.xlabel('Code Index', fontsize=12)
        plt.ylabel('Frequency', fontsize=12)
        plt.xticks(ticks=positions, labels=code_indices, rotation=90, fontsize=8)
        plt.title(f'Histogram of Block {block_id}: "{block_name}" Frequencies Across Code Indices', fontsize=14)
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Save the plot
        filename = f'Block_Freq_Histogram/block_{block_id}_{block_name}.png'
        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close()
        print(f'Histogram saved for block "{block_name}" (Block ID: {block_id}).')




def block_type_histogram(total_block_counter):
    """Save one bar chart per code index showing its block-type frequencies.

    Args:
        total_block_counter: Mapping ``{code_idx: {block_id: frequency}}``.
            Entries that are ``None`` or empty, or that have no strictly
            positive frequency, are skipped.

    Side effects:
        Creates the ``Block_Type_Histogram`` directory if needed, writes one
        PNG per plotted code index into it and prints a confirmation line.
    """
    os.makedirs("Block_Type_Histogram", exist_ok=True)
    for code_idx, block_freq in total_block_counter.items():
        if not block_freq:
            # Covers both an empty dict and a None placeholder.
            continue
        filtered_blocks = {block_id: freq for block_id, freq in block_freq.items() if freq > 0}
        if not filtered_blocks:
            continue

        # Several block IDs share one display name (e.g. "None", "leaves").
        # Using the names directly as x values lets matplotlib treat them as one
        # category and stack their bars on the same spot, so every block gets
        # its own slot and a label that includes its ID.
        block_names = [
            f"{blocks_to_types[block_id]} ({block_id})"
            if block_id in blocks_to_types
            else f"Unknown({block_id})"
            for block_id in filtered_blocks
        ]
        frequencies = list(filtered_blocks.values())
        positions = list(range(len(block_names)))

        plt.figure(figsize=(12, 6))
        plt.bar(positions, frequencies, color='skyblue', edgecolor='black')
        plt.xticks(ticks=positions, labels=block_names, rotation=90, fontsize=8)
        plt.xlabel('Block Type', fontsize=12)
        plt.ylabel('Frequency', fontsize=12)
        plt.title(f'Block Types Histogram for Code Index {code_idx}', fontsize=14)

        # Add grid for better readability
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Save the plot to the directory
        filename = f'Block_Type_Histogram/code_index_{code_idx}.png'
        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close()

        print(f'Histogram saved for code index {code_idx}.')

**Find The Max Frequency Block of Each Code**

In [None]:
def find_max_frequency_code_index(num_blocks):
    """Group code indices by the block ID that is most frequent within them.

    Args:
        num_blocks: Mapping ``{code_index: {block_id: frequency}}``. Codes
            whose inner mapping is empty or ``None`` are ignored.

    Returns:
        dict: ``{block_id: [code_index, ...]}``. Block IDs 0-38 are always
        present (possibly with an empty list). Any other dominant block ID is
        added on demand instead of raising ``KeyError``. When several blocks
        tie for the maximum, the first one in iteration order wins.
    """
    blocks_to_code_indices = {block_id: [] for block_id in range(39)}
    for code_index, block_freqs in num_blocks.items():
        if not block_freqs:
            continue
        # max() keeps the first maximal item, matching a strict ">" scan.
        dominant_block, _ = max(block_freqs.items(), key=lambda item: item[1])
        blocks_to_code_indices.setdefault(dominant_block, []).append(code_index)
    return blocks_to_code_indices



# Run the analysis pipeline for this notebook section.
# NOTE(review): calculate_style_score, find_high_style_codes, find_num_blocks
# and code_chunks_dict are defined outside this cell; the remarks about them
# below are inferred from their names and usage -- TODO confirm.
final_style_score = calculate_style_score()
# Presumably keeps the codes with a high style score; only its keys are read
# by find_most_common_block_types.
sorted_high_style_score_code_index = find_high_style_codes(final_style_score)
# Consumed by find_max_frequency_code_index as {code_index: {block_id: frequency}}.
num_blocks = find_num_blocks()
# block ID -> code indices in which that block has the highest frequency.
blocks_to_code_indices = find_max_frequency_code_index(num_blocks)
# Per-code block fractions for the keys of sorted_high_style_score_code_index.
total_block_counter = find_most_common_block_types(sorted_high_style_score_code_index, code_chunks_dict)

**Explore Spatial Relationship**

In [None]:
def find_the_spatial_relation(reshaped_latents):
    """Count how often each code occurs at each latent grid position.

    Args:
        reshaped_latents: 4-D array-like indexed as ``[n, d, h, w]`` whose
            elements expose ``.item()`` returning a code index.

    Returns:
        dict: ``{code_idx: counts}`` for every code in
        ``range(vqgan_hparams.codebook_size)``, where ``counts`` is a 6x6x6
        integer array and ``counts[d, h, w]`` is the number of samples holding
        that code at position ``(d, h, w)``.
    """
    num_samples, depth, height, width = reshaped_latents.shape
    code_spatial_frequency = {}
    for code in range(vqgan_hparams.codebook_size):
        code_spatial_frequency[code] = np.zeros((6, 6, 6), dtype=int)

    # np.ndindex walks the indices in the same row-major order as four nested
    # loops over n, d, h and w.
    for n, d, h, w in np.ndindex(num_samples, depth, height, width):
        code = reshaped_latents[n, d, h, w].item()
        code_spatial_frequency[code][d, h, w] += 1
    return code_spatial_frequency

os.makedirs('Spatial_Heatmaps', exist_ok=True)
def plot_spatial_frequency(code_spatial_frequency):
    """Save a 3-D scatter heatmap of positional frequency for every code.

    Args:
        code_spatial_frequency: Mapping ``{code_index: count_array}`` where
            ``count_array`` is a 3-D NumPy array of occurrence counts per
            grid position.

    Side effects:
        Writes ``Spatial_Heatmaps/spatial_heatmap_<code_index>.png`` for each
        code and prints a confirmation line.

    Notes:
        ``cmap`` and ``norm`` are read from the enclosing module scope; they
        are not defined in this cell.
    """
    for code_index, count_array in code_spatial_frequency.items():
        total_num_chunks = np.sum(count_array)
        # Normalise to fractions; an all-zero array is plotted as is.
        freq_values = count_array / total_num_chunks if total_num_chunks > 0 else count_array

        # Build the coordinate grid from the array's own shape so the points
        # always line up with the values, whatever the grid size is.
        x, y, z = (axis.flatten() for axis in np.indices(count_array.shape))
        freq_values_flat = freq_values.flatten()

        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')
        sc = ax.scatter(x, y, z, c=freq_values_flat, cmap=cmap, norm=norm, s=500, edgecolors='k')

        cbar = plt.colorbar(sc, ax=ax, shrink=0.5)
        cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
        cbar.set_ticklabels(['0', '0.2', '0.4', '0.6', '0.8', '1.0'])
        cbar.set_label("Frequency in Each Position")
        ax.set_xlabel('X')
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title(f"3D Heatmap of Frequency in Each Position (Index: {code_index}). Total Number of Chunks: {total_num_chunks}")
        filename = f"Spatial_Heatmaps/spatial_heatmap_{code_index}.png"
        # Save through the figure handle so the right figure is written even
        # if another one has become current in the meantime.
        fig.savefig(filename, dpi=300)
        print(f"The spatial heatmap of code index {code_index} is saved")
        plt.close(fig)
        
def find_spatial_dominance(code_spatial_frequency):
    """Measure how concentrated each code is at its single most common position.

    Args:
        code_spatial_frequency: Mapping ``{code_index: count_array}`` of
            per-position occurrence counts.

    Returns:
        dict: ``{code_index: [max_share, total]}`` where ``total`` is the sum
        of ``count_array`` and ``max_share`` is the largest count divided by
        ``total`` (0 when ``total`` is 0). Every code in
        ``range(vqgan_hparams.codebook_size)`` is present; codes missing from
        the input keep the placeholder ``[0, 0]``.
    """
    # Use a two-element placeholder so every value has the same
    # [max_share, total] layout; a bare 0 would break indexing downstream.
    code_spatial_dominance = {i: [0, 0] for i in range(vqgan_hparams.codebook_size)}
    for code_index, count_array in code_spatial_frequency.items():
        total_num_chunks = np.sum(count_array)
        # Dividing the maximum gives the same value as taking the maximum of
        # the divided array, without allocating a temporary array.
        max_share = np.max(count_array) / total_num_chunks if total_num_chunks > 0 else 0
        code_spatial_dominance[code_index] = [max_share, total_num_chunks]
    return code_spatial_dominance

**Plot Spatial Dominance Histogram**

In [None]:
def plot_spatial_dominance_histogram(code_spatial_dominance):
    """Save a bar chart of each code's maximum positional frequency.

    Args:
        code_spatial_dominance: Mapping ``{code_index: [max_freq, total]}``.
            Only codes with ``total > 10`` are plotted, ordered by ``total``
            in descending order.

    Side effects:
        Writes ``spatial_histogram.png`` to the working directory. If no code
        passes the ``total > 10`` filter, nothing is plotted and a notice is
        printed instead.
    """
    # Keep codes with enough occurrences, most frequent first.
    sorted_data = sorted(
        (
            (code_index, values[0], values[1])
            for code_index, values in code_spatial_dominance.items()
            if values[1] > 10
        ),
        key=lambda entry: entry[2],
        reverse=True,
    )
    if not sorted_data:
        # zip(*[]) yields nothing, so unpacking into three names would raise.
        print("No code index has more than 10 occurrences; histogram not generated.")
        return

    # Unzip sorted data
    sorted_code_indices, sorted_freq_values, _ = zip(*sorted_data)
    positions = list(range(len(sorted_code_indices)))

    # Plotting the histogram
    plt.figure(figsize=(30, 10))
    plt.bar(positions, sorted_freq_values, color='skyblue', edgecolor='black', width=0.6)

    plt.xticks(ticks=positions,  # Positions of the bars
               labels=sorted_code_indices,  # Actual code indices as labels
               rotation=90, fontsize=8)
    # Axis labels and title
    plt.xlabel('Code Index (Sorted by Total Chunks)', fontsize=12)
    plt.ylabel('Maximum Frequency', fontsize=12)
    plt.title('Histogram of Maximum Frequency per Code Index (Sorted by Total Chunks)', fontsize=14)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Save the plot, then actually close the figure. A bare ``plt.close``
    # without parentheses is a no-op and leaks the figure.
    plt.tight_layout()
    plt.savefig("spatial_histogram.png", dpi=300, bbox_inches="tight")
    plt.close()


**Block ids to Original Block Types**

In [None]:
# Original block IDs, listed in the order of the compact IDs 0-38 used by this
# notebook: the value at position k is the original ID of compact block k.
original_block_ids = [5, 26, 27, 29, 35, 38, 40, 41, 56, 60, 62, 83, 84, 92, 93, 95, 108, 118, 119, 120, 131, 132, 138, 139, 140, 166, 184, 187, 192, 194, 195, 204, 217, 222, 227, 237, 240, 241, 251]
# Lookup table: compact block ID -> original block ID.
block_converter_to_original = dict(enumerate(original_block_ids))

**Block Frequency Dictionary**

In [None]:
def find_blocks_frequency(code_chunks_dict):
    """Compute the fraction of each block ID among all blocks of every code.

    Args:
        code_chunks_dict: Mapping from code index to an iterable of chunks;
            it is indexed with every value in
            ``range(vqgan_hparams.codebook_size)``. Each chunk must provide
            ``flatten()`` yielding block IDs.

    Returns:
        dict: ``{code_idx: {block_id: fraction}}`` with the inner dict ordered
        from the most to the least frequent block, or ``{code_idx: None}``
        for a code whose chunks contain no blocks.
    """
    frequency_by_code = {}
    for code_idx in range(vqgan_hparams.codebook_size):
        counts = Counter()
        n_blocks = 0
        for chunk in code_chunks_dict[code_idx]:
            per_chunk = {
                int(block_id): n
                for block_id, n in Counter(chunk.flatten()).items()
            }
            n_blocks += sum(per_chunk.values())
            counts.update(per_chunk)

        if n_blocks <= 0:
            frequency_by_code[code_idx] = None
            continue

        shares = [(block_id, n / n_blocks) for block_id, n in counts.items()]
        # list.sort is stable, so tied blocks keep their first-seen order.
        shares.sort(key=lambda pair: pair[1], reverse=True)
        frequency_by_code[code_idx] = dict(shares)
    return frequency_by_code

**Identify Specific Biome**

In [None]:
def find_styles(temp_arr):
    """Guess which biome(s) a code represents from its most common block types.

    Each biome has one key block and a list of typical member blocks (original
    block IDs). A biome is reported when its key block appears in ``temp_arr``
    and either it is the only block listed, or at least two of the listed
    blocks belong to the biome. The key block does not need to be the most
    frequent one.

    Args:
        temp_arr: Sequence of ``[block_id, frequency]`` pairs; only the block
            ID (element 0) of each pair is used.

    Returns:
        list[str]: Matching biome names, in the fixed order
        ``"desert"``, ``"ocean"``, ``"grass_land"``, ``"ore"``.
    """
    # (biome name, key block ID, member block IDs)
    biomes = (
        ("desert", 194, {29, 194, 195}),  # key = 194 sand
        ("ocean", 240, {240, 194, 60, 95}),  # key = 240 water
        (
            "grass_land",
            60,  # key = 60 dirt
            {10, 60, 93, 119, 120, 217, 227, 237, 131, 132, 95, 243, 197,
             166, 167, 184, 250, 251},
        ),
        ("ore", 217, {10, 92, 217, 40, 108, 95, 195}),  # key = 217 stone
    )

    block_ids = [entry[0] for entry in temp_arr]
    result = []
    for biome, key_block, members in biomes:
        if key_block not in block_ids:
            continue
        # Each biome is scored against its own member set. Scoring grass_land
        # against the ocean members would give wrong matches.
        matches = sum(1 for block_id in block_ids if block_id in members)
        if len(block_ids) == 1 or matches >= 2:
            result.append(biome)
    return result


def find_five_most_common_block_types(total_block_counter):
    """Return the five highest-frequency blocks of each code, with block 0 zeroed.

    Block ID 0 is ranked with frequency 0, so it only appears (last) when
    fewer than five other blocks are present.

    Args:
        total_block_counter: Mapping ``{code_idx: {block_id: frequency}}``;
            a value may be ``None``. The mapping is not modified.

    Returns:
        dict: ``{code_idx: [(block_id, frequency), ...]}`` with at most five
        pairs sorted by descending frequency, or ``{code_idx: None}`` where
        the input value is ``None``.
    """
    top_blocks_by_code = {}
    for key, sub_dict in total_block_counter.items():
        if sub_dict is None:
            top_blocks_by_code[key] = None
            continue
        # Zero out block 0 on a copy: writing into sub_dict would corrupt the
        # caller's frequency table.
        ranked_input = {**sub_dict, 0: 0}
        top_blocks_by_code[key] = sorted(
            ranked_input.items(), key=lambda item: item[1], reverse=True
        )[:5]
    return top_blocks_by_code

# Write the inferred biome and the top-block breakdown of every code to
# Style_Information.txt.
# NOTE(review): `copy` is imported here by the original cell; the import is
# kept in case later cells rely on it.
import copy

block_frequency = find_blocks_frequency(code_chunks_dict)
top_5_items = find_five_most_common_block_types(block_frequency)

with open('Style_Information.txt', 'w') as file:
    for code_idx, ranked_blocks in top_5_items.items():
        if ranked_blocks is None:
            file.write(f"Code idx {code_idx}: None \n\n")
            continue

        # find_styles matches on original block IDs, so translate the compact
        # IDs before classifying.
        by_original_id = [
            [block_converter_to_original[block_id], freq]
            for block_id, freq in ranked_blocks
        ]
        file.write(f"Code idx {code_idx}: {find_styles(by_original_id)} \n\n")

        # Same ranking again, this time with human-readable block names.
        by_name = [
            [blocks_to_types[block_id], freq]
            for block_id, freq in ranked_blocks
        ]
        file.write(f"Frequencies = {by_name} \n\n")
        file.write(f"Length = {len(code_chunks_dict[code_idx])} \n\n")