# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from sampler_utils import retrieve_autoencoder_components_state_dicts, latent_ids_to_onehot3d, get_latent_loaders
from models3d import VQAutoEncoder, Generator
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import torch.distributions as dists
from tqdm import tqdm
import gc
from models3d import BiomeClassifier

In [2]:
%matplotlib widget
#matplotlib.use('TkAgg')

# Block Converter

In [3]:
class BlockBiomeConverter:
    def __init__(self, block_mappings=None, biome_mappings=None):
        """
        Initialize with pre-computed mappings for both blocks and biomes
        
        Args:
            block_mappings: dict containing 'index_to_block' and 'block_to_index'
            biome_mappings: dict containing 'index_to_biome' and 'biome_to_index'
        """
        self.index_to_block = block_mappings['index_to_block'] if block_mappings else None
        self.block_to_index = block_mappings['block_to_index'] if block_mappings else None
        self.index_to_biome = biome_mappings['index_to_biome'] if biome_mappings else None
        self.biome_to_index = biome_mappings['biome_to_index'] if biome_mappings else None
    
    @classmethod
    def from_dataset(cls, data_path):
        """Create mappings from a dataset file"""
        data = np.load(data_path, allow_pickle=True)
        voxels = data['voxels']
        biomes = data['biomes']
        
        # Create block mappings (blocks are integers)
        unique_blocks = np.unique(voxels)
        block_to_index = {int(block): idx for idx, block in enumerate(unique_blocks)}
        index_to_block = {idx: int(block) for idx, block in enumerate(unique_blocks)}
        
        # Create biome mappings (biomes are strings)
        unique_biomes = np.unique(biomes)
        biome_to_index = {str(biome): idx for idx, biome in enumerate(unique_biomes)}
        index_to_biome = {idx: str(biome) for idx, biome in enumerate(unique_biomes)}
        
        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}
        
        return cls(block_mappings, biome_mappings)
    
    @classmethod
    def from_arrays(cls, voxels, biomes):
        """Create mappings directly from numpy arrays"""
        # Create block mappings (blocks are integers)
        unique_blocks = np.unique(voxels)
        block_to_index = {int(block): idx for idx, block in enumerate(unique_blocks)}
        index_to_block = {idx: int(block) for idx, block in enumerate(unique_blocks)}
        
        # Create biome mappings (biomes are strings)
        unique_biomes = np.unique(biomes)
        biome_to_index = {str(biome): idx for idx, biome in enumerate(unique_biomes)}
        index_to_biome = {idx: str(biome) for idx, biome in enumerate(unique_biomes)}
        
        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}
        
        return cls(block_mappings, biome_mappings)
    
    @classmethod
    def load_mappings(cls, path):
        """Load pre-saved mappings"""
        mappings = torch.load(path)
        return cls(mappings['block_mappings'], mappings['biome_mappings'])
    
    def save_mappings(self, path):
        """Save mappings for later use"""
        torch.save({
            'block_mappings': {
                'index_to_block': self.index_to_block,
                'block_to_index': self.block_to_index
            },
            'biome_mappings': {
                'index_to_biome': self.index_to_biome,
                'biome_to_index': self.biome_to_index
            }
        }, path)
    
    def convert_to_original_blocks(self, data):
        """
        Convert from indices back to original block IDs.
        Handles both one-hot encoded and already-indexed data.
        
        Args:
            data: torch.Tensor of either:
                - one-hot encoded blocks [B, C, H, W, D] or [C, H, W, D]
                - indexed blocks [B, H, W, D] or [H, W, D]
        Returns:
            torch.Tensor of original block IDs with shape [B, H, W, D] or [H, W, D]
        """
        # If one-hot encoded (dim == 5 or first dim == num_blocks), convert to indices first
        if len(data.shape) == 5 or (len(data.shape) == 4 and data.shape[0] == len(self.block_to_index)):
            data = torch.argmax(data, dim=1 if len(data.shape) == 5 else 0)
        
        # Now convert indices to original blocks
        if len(data.shape) == 4:  # Batch dimension present
            return torch.tensor([[[[self.index_to_block[int(b)] 
                                for b in row]
                                for row in layer]
                                for layer in slice_]
                                for slice_ in data])
        else:  # No batch dimension
            return torch.tensor([[[self.index_to_block[int(b)] 
                                for b in row]
                                for row in layer]
                                for layer in data])

    def convert_to_original_biomes(self, data):
        """
        Convert from indices back to original biome strings.
        Handles both one-hot encoded and already-indexed data.
        
        Args:
            data: torch.Tensor of either:
                - one-hot encoded biomes [B, C, H, W, D] or [C, H, W, D]
                - indexed biomes [B, H, W, D] or [H, W, D]
        Returns:
            numpy array of original biome strings with shape [B, H, W, D] or [H, W, D]
        """
        # If one-hot encoded (dim == 5 or first dim == num_biomes), convert to indices first
        if len(data.shape) == 5 or (len(data.shape) == 4 and data.shape[0] == len(self.biome_to_index)):
            data = torch.argmax(data, dim=1 if len(data.shape) == 5 else 0)
        
        # Now convert indices to original biomes
        if len(data.shape) == 4:  # Batch dimension present
            return np.array([[[[self.index_to_biome[int(b)] 
                            for b in row]
                            for row in layer]
                            for layer in slice_]
                            for slice_ in data])
        else:  # No batch dimension
            return np.array([[[self.index_to_biome[int(b)] 
                            for b in row]
                            for row in layer]
                            for layer in data])
        
    def get_air_block_index(self):
        """
        Find the one-hot index corresponding to the air block (ID 5).
        Returns:
            int: The index where air blocks are encoded in one-hot format
        """
        # Find the index that maps to block ID 5 (air) in our index_to_block mapping
        for idx, block_id in self.index_to_block.items():
            if block_id == 5:  # Air block ID
                return idx
        raise ValueError("Air block (ID 5) not found in block mappings!")
    
    def get_blockid_indices(self, block_ids):
        """
        Find the one-hot index corresponding to the air block (ID 5).
        Returns:
            int: The index where air blocks are encoded in one-hot format
        """
        # Find the index that maps to block ID 5 (air) in our index_to_block mapping
        idxs = []
        for idx, block_id in self.index_to_block.items():
            if block_id in block_ids:  # Air block ID
                idxs.append(idx)
        if len(idxs) == 0:
            raise ValueError("Air block (ID 5) not found in block mappings!")
        return idxs

# Minecraft Chunks Dataset

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MinecraftDataset(Dataset):
    def __init__(self, data_path):
        data_path = Path(data_path)

        # Try to load processed data first
        # assert processed_data_path.exists() and mappings_path.exists()

        print("Loading pre-processed data...")
        processed_data = torch.load(data_path)
        # Only keep the chunks, discard biome data
        self.processed_chunks = processed_data['chunks']
        # self.processed_biomes = processed_data['biomes']
        # Delete the biomes to free memory
        # del processed_data['biomes']
        # del processed_data['chunks']
        del processed_data
        
        
        print(f"Loaded {len(self.processed_chunks)} chunks of size {self.processed_chunks.shape[1:]}")
        print(f"Number of unique block types: {self.processed_chunks.shape[1]}")
        print(f'Unique blocks: {torch.unique(torch.argmax(self.processed_chunks, dim=1)).tolist()}')
        #print(f"Loaded {len(self.processed_biomes)} chunks of size {self.processed_biomes.shape[1:]}")
        #print(f"Number of unique biome types: {self.processed_biomes.shape[1]}")
        #print(f'Unique biomes: {torch.unique(torch.argmax(self.processed_biomes, dim=1)).tolist()}')
        

    def __getitem__(self, idx):
        # return self.processed_chunks[idx]
        return self.processed_chunks[idx] #, self.processed_biomes[idx]

    def __len__(self):
        # return len(self.processed_chunks), len(self.processed_biomes)
        return len(self.processed_chunks)

def get_minecraft_dataloaders(data_path, batch_size=32, val_split=0.1, num_workers=4):
    """
    Creates training and validation dataloaders for Minecraft chunks.
    """
    # Create dataset
    dataset = MinecraftDataset(data_path)
    
    # Split into train and validation sets
    val_size = int(val_split * len(dataset))
    train_size = len(dataset) - val_size
    
    # Use a fixed seed for reproducibility
    generator = torch.Generator().manual_seed(42)
    
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, 
        [train_size, val_size],
        generator=generator
    )
    
    # Create dataloaders with memory pinning
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    
    print(f"\nDataloader details:")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Batch size: {batch_size}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    
    return train_loader, val_loader


# Helper Functions
Loads hyperparameters and model from checkpoint, gives easy function to encode a structure into latent codes, as well as decode those latents back into the reconstruction

## FQGAN Helper Functions

In [5]:
import os
from log_utils import log, load_stats, load_model
import copy
from fq_models import FQModel, HparamsFQGAN


# Loads hparams from hparams.json file in saved model directory
def load_hparams_from_json(log_dir):
    import json
    import os
    json_path = os.path.join(log_dir, 'hparams.json')
    
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"No hparams.json file found in {log_dir}")
    
    with open(json_path, 'r') as f:
        hparams = json.load(f)

    return hparams

# turns loaded hparams json into propery hyperparams object
def dict_to_vcqgan_hparams(hparams_dict, dataset=None):
    # Determine which hyperparameter class to use based on the dataset
    if dataset == None:
        dataset = hparams_dict.get('dataset', 'MNIST')  # Default to MNIST if not specified
    
    vq_hyper = HparamsFQGAN(dataset)
    # Set attributes from the dictionary
    for key, value in hparams_dict.items():
        setattr(vq_hyper, key, value)
    
    return vq_hyper

# Loads fqgan model weights from a given checkpoint file
def load_fqgan_from_checkpoint(H, fqgan):
    fqgan = load_model(fqgan, "fqgan", H.load_step, H.load_dir).cuda()
    fqgan.eval()
    return fqgan

# Takes a chunk or batch of chunks from the dataset, returns the encoded style and structure indices matrices
def encode_and_quantize(fqgan, terrain_chunks, device='cuda'):
    """Memory-efficient encoding function"""
    fqgan.eval()
    with torch.no_grad():
        # Move input to device
        terrain_chunks = terrain_chunks.to(device)
        
        # Get encodings
        h_style, h_struct = fqgan.encoder(terrain_chunks)
        
        # Process style path
        h_style = fqgan.quant_conv_style(h_style)
        quant_style, _, style_stats = fqgan.quantize_style(h_style)
        style_indices = style_stats[2]  # Get indices from tuple
        style_indices = style_indices.view(
            (h_style.size()[0], h_style.size()[2], h_style.size()[3], h_style.size()[4])
        )
        
        # Clear intermediate tensors
        del h_style, quant_style, style_stats
        
        # Process structure path
        h_struct = fqgan.quant_conv_struct(h_struct)
        quant_struct, _, struct_stats = fqgan.quantize_struct(h_struct)
        struct_indices = struct_stats[2]  # Get indices from tuple
        struct_indices = struct_indices.view(
            (h_struct.size()[0], h_struct.size()[2], h_struct.size()[3], h_struct.size()[4])
        )
        
        # Clear intermediate tensors
        del h_struct, quant_struct, struct_stats
        
        # Move indices to CPU to save GPU memory
        style_indices = style_indices.cpu()
        struct_indices = struct_indices.cpu()
        
        torch.cuda.empty_cache()
        
        return style_indices, struct_indices

# Takes style and structure indices, returns the reconstructed map
def decode_from_indices(style_indices, struct_indices, fqgan, device='cuda', two_stage=False):
    """Memory-efficient decoding function"""
    with torch.no_grad():
        # Move indices to device only when needed
        style_indices = style_indices.to(device)
        struct_indices = struct_indices.to(device)
        
        # Get quantized vectors
        quant_style = fqgan.quantize_style.get_codebook_entry(
            style_indices.view(-1),
            shape=[1, fqgan.embed_dim, *style_indices.shape[1:]]
        )
        quant_struct = fqgan.quantize_struct.get_codebook_entry(
            struct_indices.view(-1),
            shape=[1, fqgan.embed_dim, *struct_indices.shape[1:]]
        )
        
        # Clear indices from GPU
        del style_indices, struct_indices
        
        # Combine and decode
        quant = torch.cat([quant_struct, quant_style], dim=1)
        # quant = quant_style + quant_struct
        del quant_style, quant_struct
        
        if two_stage:
            decoded, binary_decoded = fqgan.decoder(quant)
        else:
            decoded = fqgan.decoder(quant)
        
        del quant
        
        # Convert to block IDs if one-hot encoded
        if decoded.shape[1] > 1:
            decoded = torch.argmax(decoded, dim=1)
        
        # Move result to CPU and clear GPU memory
        result = decoded.squeeze(0).cpu()
        if two_stage:
            binary_result = binary_decoded.squeeze(0).cpu()
            del decoded
            del binary_decoded
            torch.cuda.empty_cache()
            return result, binary_result
        
        del decoded
        torch.cuda.empty_cache()
        
        return result

# Load FQ model

In [6]:
# Important if you reload the model, can run into memory issues. There's a memory leak somewhere I haven't been able to fix
import gc
gc.collect()
torch.cuda.empty_cache()

model_path = "../model_logs/FQGAN_2stagedecoder_nobiomemodel_EMA_newdata_logweighted3_40both"
load_step = 10000

fqgan_hparams =  dict_to_vcqgan_hparams(load_hparams_from_json(f"{model_path}"), 'minecraft')
fqgan_hparams.load_step = load_step
fqgan_hparams.padding_mode = 'zeros'
fqgan = FQModel(fqgan_hparams)
fqgan = load_fqgan_from_checkpoint(fqgan_hparams, fqgan)
print(f'loaded from: {fqgan_hparams.log_dir}')

using padding mode: zeros
Using EMA quantizer
Using TwoStageGenerator
NO biome supervision
Disentangle Ratio:  0.5
Loading fqgan_10000.th
loaded from: FQGAN_2stagedecoder_nobiomemodel_EMA_newdata_logweighted3_40both


# Load dataset

In [7]:
# loads a preprocesed dataset, which already has a mappings file created and everything one-hot encoded nicely. For more memory efficient, could try just loading the validation set file
# data_path = '../../text2env/data/minecraft_biome_newworld_10k_processed_cleaned.pt'
data_path = '../../text2env/data/24_newdataset_processed_cleaned3.pt'
train_loader, val_loader = get_minecraft_dataloaders(
    data_path,
    batch_size=4,
    num_workers=0, # Must be 0 if you're on windows, otherwise it errors
    val_split=0.1
)

Loading pre-processed data...
Loaded 11119 chunks of size torch.Size([42, 24, 24, 24])
Number of unique block types: 42
Unique blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]

Dataloader details:
Training samples: 10008
Validation samples: 1111
Batch size: 4
Training batches: 2502
Validation batches: 278


In [8]:
# mappings_path = '../../text2env/data/minecraft_biome_newworld_10k_mappings.pt'
mappings_path = '../../text2env/data/24_newdataset_mappings3.pt'
block_converter = BlockBiomeConverter.load_mappings(mappings_path)

In [9]:
# Create an instnance of our visualizer
from visualization_utils import MinecraftVisualizerPyVista
visualizer = MinecraftVisualizerPyVista()


# Gradcam

In [11]:
from pytorch_grad_cam import GradCAM
# from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# from pytorch_grad_cam.utils.image import show_cam_on_image

In [13]:
batch = next(iter(train_loader))
sample = batch[0].unsqueeze(0).cuda()

In [53]:
# torch.save(sample, 'gradcam_emptywater.pt')
sample = torch.load('gradcam_emptywater.pt')

In [54]:
with torch.no_grad():
    style_indices, struct_indices = encode_and_quantize(fqgan, sample)
reconstructed, binary_reconstructed = decode_from_indices(style_indices, struct_indices, fqgan, two_stage=True)

In [55]:
converted_orig = block_converter.convert_to_original_blocks(sample.squeeze())
plotter = visualizer.visualize_chunk(converted_orig, interactive=True, show_axis=False)
plotter.show()

Widget(value='<iframe src="http://localhost:53416/index.html?ui=P_0x2079de27fa0_19&reconnect=auto" class="pyvi…

In [56]:
binary_reconstructed = (binary_reconstructed>0.5).float()
binary_reconstructed = block_converter.convert_to_original_blocks(binary_reconstructed).squeeze()
reconstructed = block_converter.convert_to_original_blocks(reconstructed)
fig = visualizer.visualize_chunk(reconstructed, interactive=True, show_axis=False)
fig.show()

Widget(value='<iframe src="http://localhost:53416/index.html?ui=P_0x2078da67820_20&reconnect=auto" class="pyvi…

In [None]:
target_layers = [fqgan.encoder.conv_out_struct]


In [57]:
plotter = visualizer.visualize_both_codes_with_blocks(
    style_indices,
    struct_indices,
    reconstructed
)
plotter.show(interactive=True)

In [153]:
np.unique(struct_indices, return_counts=True)

(array([ 0,  2,  3,  4,  5,  7,  8, 10, 16, 21, 22, 24, 25, 29, 31, 32, 33,
        34, 35, 37, 38, 39], dtype=int64),
 array([ 4,  1, 27,  4,  1,  2,  4, 15, 32,  3, 14,  5, 13,  2,  1,  1,  1,
        30,  7,  3, 27, 19], dtype=int64))

# o3 implementation

In [18]:
class EncoderWrapper(torch.nn.Module):
    """
    Returns pre-quantisation latents z (B,C,D,H,W) so CAM can back-prop
    through the encoder only.  Feel free to add EMA swap logic if needed.
    """
    def __init__(self, fqgan):
        super().__init__()
        self.enc = fqgan.encoder  # adapt to your attr names
    def forward(self, x):
        z = self.enc(x)           # no quantisation here!
        return z

In [19]:
class CodeTarget:
    def __init__(self, code_idx, pos, codebook):
        self.k  = code_idx        # integer index in the codebook
        self.pos = pos            # (d,h,w) tuple in latent grid
        self.E = codebook         # nn.Embedding weight or tensor

    def __call__(self, z_pre_q):
        d,h,w = self.pos
        zvec  = z_pre_q[0, :, d, h, w]      # (C,)
        return -((zvec - self.E[self.k]) ** 2).sum()  # maximise similarity

In [20]:
from pytorch_grad_cam.grad_cam import GradCAM
class GradCAM3D(GradCAM):
    def get_cam_weights(self, input_tensor, target_layers, targets):
        grads = self.activations_and_grads.gradients[-1]  # B,C,D,H,W
        return grads.mean(dim=(2,3,4), keepdim=True)      # ⟨D,H,W⟩ avg
    def get_loss(self, output, targets):
        return sum([t(output) for t in targets])

In [24]:
k = 19
pos = (0, 2, 0)
struct_embedding = fqgan.quantize_style.embedding
with GradCAM3D(model=EncoderWrapper(fqgan), target_layers=[fqgan.encoder.conv_out_struct]) as cam:
    saliency = cam(input_tensor=sample, targets=[CodeTarget(k, pos, struct_embedding)])[0]
    # saliency is (D,H,W) – align to original 24³ with interpolate
    saliency = F.interpolate(saliency[None,None], size=(24,24,24),
                             mode="trilinear", align_corners=False)[0,0]


TypeError: get_cam_weights() takes 4 positional arguments but 6 were given

# o3, recreate rl paper

In [18]:
import torch, torch.nn.functional as F
from pytorch_grad_cam import GradCAM
# from pytorch_grad_cam.utils.model_targets import BaseTarget

# ---- 0.  thin wrapper so GradCAM sees just the decoder ----
class DecoderWrapper(torch.nn.Module):
    def __init__(self, fqgan):                # your full model
        super().__init__()
        self.dec = fqgan.decoder              # reference to decoder
    def forward(self, masked_codes):          # (B,C,D,H,W)
        return self.dec(masked_codes)

# ---- 1. custom Grad-CAM target identical to paper’s scalar ----
class DecoderL2Target():
    def __call__(self, decoder_feature):
        # score = sum of squares over everything
        return (decoder_feature**2).sum()

# ---- 2. run one explanation for code k on one chunk ----
def explain_struct_code(fqgan, chunk, k, position_indices=None):
    fqgan.eval()
    with torch.no_grad():
        h_style, h_struct  = fqgan.encoder(chunk)          # (1,C,D,H,W)
        h_struct = fqgan.quant_conv_struct(h_struct)
        quant_struct, _, struct_stats = fqgan.quantize_struct(h_struct)
        struct_indices = struct_stats[2]  # Get indices from tuple
        struct_indices = struct_indices.view(
            (h_struct.size()[0], h_struct.size()[2], h_struct.size()[3], h_struct.size()[4])
        )

        # Process style path
        h_style = fqgan.quant_conv_style(h_style)
        quant_style, _, style_stats = fqgan.quantize_style(h_style)
        style_indices = style_stats[2]  # Get indices from tuple
        style_indices = style_indices.view(
            (h_style.size()[0], h_style.size()[2], h_style.size()[3], h_style.size()[4])
        )
        print(f'h_struct size: {h_struct.shape}')
        print(f'quant_struct size: {quant_struct.shape}')
        print(f'struct_indices size: {struct_indices.shape}')

    # --- NEW: mask logic can target a single position or k-codes ---
    if position_indices is not None:
        # position_indices is a (d, h, w) tuple
        d_idx, h_idx, w_idx = position_indices
        mask = torch.zeros_like(struct_indices, dtype=torch.float32)
        # assume batch dim first; here batch==1 typically
        mask[:, d_idx, h_idx, w_idx] = 1.0
    else:
        mask = (struct_indices == k).float()
    z_q_masked = quant_struct * mask.unsqueeze(1)           # keep dims (C,D,H,W)

    quant = torch.cat([z_q_masked, quant_style], dim=1)

    print(f'z_q_masked size: {z_q_masked.shape}')
    print(f'z_q_masked[None] size: {z_q_masked.shape}')
    print(f'quant shape: {quant.shape}')

    # --- Grad-CAM on the decoder's *first* conv ---
    cam = GradCAM(model=DecoderWrapper(fqgan),
                  target_layers=[fqgan.decoder.struct_conv_in])
    heat = cam(quant, targets=[DecoderL2Target()])[0]  # (d',h',w')

    print(f'heat shape: {torch.from_numpy(heat[None, None]).shape}')
    # print(f'heat type: {heat}')
    heat = F.interpolate(torch.from_numpy(heat[None, None]), size=chunk.shape[-3:],
                         mode="trilinear", align_corners=False)[0,0]
    return heat, mask    # heat-map + where code k actually occurred

# ---- 2. run one explanation for code k on one chunk ----
def explain_style_code(fqgan, chunk, k, position_indices=None):
    fqgan.eval()
    with torch.no_grad():
        h_style, h_struct  = fqgan.encoder(chunk)          # (1,C,D,H,W)
        h_struct = fqgan.quant_conv_struct(h_struct)
        quant_struct, _, struct_stats = fqgan.quantize_struct(h_struct)
        struct_indices = struct_stats[2]  # Get indices from tuple
        struct_indices = struct_indices.view(
            (h_struct.size()[0], h_struct.size()[2], h_struct.size()[3], h_struct.size()[4])
        )

        # Process style path
        h_style = fqgan.quant_conv_style(h_style)
        quant_style, _, style_stats = fqgan.quantize_style(h_style)
        style_indices = style_stats[2]  # Get indices from tuple
        style_indices = style_indices.view(
            (h_style.size()[0], h_style.size()[2], h_style.size()[3], h_style.size()[4])
        )
        print(f'h_style size: {h_style.shape}')
        print(f'quant_style size: {quant_style.shape}')
        print(f'struct_indices size: {style_indices.shape}')

    # --- NEW: mask logic can target a single position or k-codes ---
    if position_indices is not None:
        d_idx, h_idx, w_idx = position_indices
        mask = torch.zeros_like(style_indices, dtype=torch.float32)
        mask[:, d_idx, h_idx, w_idx] = 1.0
    else:
        mask = (style_indices == k).float()
    z_q_masked = quant_style * mask.unsqueeze(1)           # keep dims (C,D,H,W)

    quant = torch.cat([quant_struct, z_q_masked], dim=1)

    print(f'z_q_masked size: {z_q_masked.shape}')
    print(f'z_q_masked[None] size: {z_q_masked.shape}')
    print(f'quant shape: {quant.shape}')

    # --- Grad-CAM on the decoder's *first* conv ---
    cam = GradCAM(model=DecoderWrapper(fqgan),
                  target_layers=[fqgan.decoder.initial_conv])
    heat = cam(quant, targets=[DecoderL2Target()])[0]  # (d',h',w')

    print(f'heat shape: {torch.from_numpy(heat[None, None]).shape}')
    # print(f'heat type: {heat}')
    heat = F.interpolate(torch.from_numpy(heat[None, None]), size=chunk.shape[-3:],
                         mode="trilinear", align_corners=False)[0,0]
    return heat, mask    # heat-map + where code k actually occurred

In [59]:
np.unique(struct_indices, return_counts=True)

(array([ 0,  2,  3,  7, 10, 14, 19, 22, 23, 29, 32, 34, 35, 37, 38, 39],
       dtype=int64),
 array([ 5,  3, 30, 16, 18,  1,  4, 32,  5, 10, 13, 28,  6, 19, 24,  2],
       dtype=int64))

In [20]:
np.unique(style_indices, return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  7,  8,  9, 10, 12, 13, 14, 16, 17, 18, 19,
        20, 21, 23, 24, 25, 26, 27, 29, 30, 31, 32, 34, 36, 37, 38, 39],
       dtype=int64),
 array([ 1,  9,  1, 16,  2,  9, 21,  2, 13,  8,  2,  4,  5,  1,  5,  1, 13,
         3,  4,  7,  6,  9,  9, 13,  1,  5, 24,  5,  6,  4,  4,  2,  1],
       dtype=int64))

In [190]:
heat, mask = explain_struct_code(fqgan, sample, 33)
np.unique(heat)

h_struct size: torch.Size([1, 32, 6, 6, 6])
quant_struct size: torch.Size([1, 32, 6, 6, 6])
struct_indices size: torch.Size([1, 6, 6, 6])
z_q_masked size: torch.Size([1, 32, 6, 6, 6])
z_q_masked[None] size: torch.Size([1, 32, 6, 6, 6])
quant shape: torch.Size([1, 64, 6, 6, 6])
heat shape: torch.Size([1, 1, 6, 6, 6])


array([0.], dtype=float32)

In [133]:
style_heat, _ = explain_style_code(fqgan, sample, 18)
np.unique(style_heat)

h_style size: torch.Size([1, 32, 6, 6, 6])
quant_style size: torch.Size([1, 32, 6, 6, 6])
struct_indices size: torch.Size([1, 6, 6, 6])
z_q_masked size: torch.Size([1, 32, 6, 6, 6])
z_q_masked[None] size: torch.Size([1, 32, 6, 6, 6])
quant shape: torch.Size([1, 64, 6, 6, 6])
heat shape: torch.Size([1, 1, 6, 6, 6])


array([0.0000000e+00, 1.9840769e-28, 5.9522309e-28, ..., 9.2149466e-01,
       9.3437725e-01, 9.4256920e-01], dtype=float32)

In [21]:
import pyvista as pv
def visualize_heatmap(heatmap, plotter=None, interactive=False, show_axis=True):
    """
    Visualize a 3D heatmap in a 24×24×24 grid. 
    Every “voxel” is drawn black, with its opacity equal to the heatmap value (0→transparent, 1→opaque).
    """
    # Convert to numpy if needed
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()
    # If there's an extra leading channel dimension, drop it
    if heatmap.ndim == 4 and heatmap.shape[0] == 1:
        heatmap = heatmap[0]

    # Re‐orient to match original plotting conventions
    # (depth, height, width) → (x, y, z) ordering, then rotate
    heatmap = heatmap.transpose(2, 0, 1)
    heatmap = np.rot90(heatmap, 1, (0, 1))

    # Build a uniform vtkImageData volume
    grid = pv.ImageData()
    grid.dimensions = np.array(heatmap.shape) + 1
    grid.cell_data["values"] = heatmap.flatten(order="F")

    # Create or reuse a PyVista plotter
    if plotter is None:
        plotter = pv.Plotter(notebook=True) if interactive else pv.Plotter(off_screen=True)

    # Volume‐render: all voxels black, opacity from 0→1
    plotter.add_volume(
        grid,
        # scalars="values",
        # color="black",
        opacity="linear",  # linearly map min→0, max→1
        shade=False        # disable additional shading
    )

    # Optionally show axes/bounds
    if show_axis:
        plotter.show_bounds(
            grid='back',
            location='back',
            font_size=8,
            bold=False,
            font_family='arial',
            use_2d=False,
            bounds=[0, 24, 0, 24, 0, 24],
            axes_ranges=[0, 24, 0, 24, 0, 24],
            padding=0.0,
            n_xlabels=2,
            n_ylabels=2,
            n_zlabels=2
        )

    # Standard camera setup
    plotter.camera_position = 'iso'
    plotter.camera.zoom(1)

    return plotter

In [65]:
heat, mask = explain_struct_code(fqgan, sample, 3)
np.unique(heat)

plotter = visualize_heatmap(heat, interactive=True)
plotter.show()

h_struct size: torch.Size([1, 32, 6, 6, 6])
quant_struct size: torch.Size([1, 32, 6, 6, 6])
struct_indices size: torch.Size([1, 6, 6, 6])
z_q_masked size: torch.Size([1, 32, 6, 6, 6])
z_q_masked[None] size: torch.Size([1, 32, 6, 6, 6])
quant shape: torch.Size([1, 64, 6, 6, 6])
heat shape: torch.Size([1, 1, 6, 6, 6])


Widget(value='<iframe src="http://localhost:53416/index.html?ui=P_0x2080c241fd0_27&reconnect=auto" class="pyvi…

In [40]:
locs = torch.where(style_indices == 27)

In [41]:
locs

(tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 tensor([0, 0, 1, 1, 1, 2, 2, 3, 4, 4, 4, 4, 5]),
 tensor([1, 1, 1, 2, 5, 1, 4, 2, 2, 2, 2, 3, 2]),
 tensor([4, 5, 5, 4, 1, 5, 2, 4, 3, 4, 5, 1, 4]))

In [None]:
style_indices[0, ]

In [51]:
style_heat, _ = explain_style_code(fqgan, sample, 27, (2, 1, 5))
np.unique(style_heat)

plotter = visualize_heatmap(style_heat, interactive=True)
plotter.show()

h_style size: torch.Size([1, 32, 6, 6, 6])
quant_style size: torch.Size([1, 32, 6, 6, 6])
struct_indices size: torch.Size([1, 6, 6, 6])
z_q_masked size: torch.Size([1, 32, 6, 6, 6])
z_q_masked[None] size: torch.Size([1, 32, 6, 6, 6])
quant shape: torch.Size([1, 64, 6, 6, 6])
heat shape: torch.Size([1, 1, 6, 6, 6])


Widget(value='<iframe src="http://localhost:53416/index.html?ui=P_0x20798922be0_18&reconnect=auto" class="pyvi…

In [47]:
plotter = visualizer.visualize_both_codes_with_blocks(
    style_indices,
    struct_indices,
    reconstructed
)
plotter.show(interactive=True)

In [None]:
# Get encodings
        h_style, h_struct = fqgan.encoder(terrain_chunks)
        
        # Process style path
        h_style = fqgan.quant_conv_style(h_style)
        quant_style, _, style_stats = fqgan.quantize_style(h_style)
        style_indices = style_stats[2]  # Get indices from tuple
        style_indices = style_indices.view(
            (h_style.size()[0], h_style.size()[2], h_style.size()[3], h_style.size()[4])
        )

In [None]:
# Takes a chunk or batch of chunks from the dataset, returns the encoded style and structure indices matrices
def encode_and_quantize(fqgan, terrain_chunks, device='cuda'):
    """Memory-efficient encoding function"""
    fqgan.eval()
    with torch.no_grad():
        # Move input to device
        terrain_chunks = terrain_chunks.to(device)
        
        # Get encodings
        h_style, h_struct = fqgan.encoder(terrain_chunks)
        
        # Process style path
        h_style = fqgan.quant_conv_style(h_style)
        quant_style, _, style_stats = fqgan.quantize_style(h_style)
        style_indices = style_stats[2]  # Get indices from tuple
        style_indices = style_indices.view(
            (h_style.size()[0], h_style.size()[2], h_style.size()[3], h_style.size()[4])
        )
        
        # Clear intermediate tensors
        del h_style, quant_style, style_stats
        
        # Process structure path
        h_struct = fqgan.quant_conv_struct(h_struct)
        quant_struct, _, struct_stats = fqgan.quantize_struct(h_struct)
        struct_indices = struct_stats[2]  # Get indices from tuple
        struct_indices = struct_indices.view(
            (h_struct.size()[0], h_struct.size()[2], h_struct.size()[3], h_struct.size()[4])
        )
        
        # Clear intermediate tensors
        del h_struct, quant_struct, struct_stats
        
        # Move indices to CPU to save GPU memory
        style_indices = style_indices.cpu()
        struct_indices = struct_indices.cpu()
        
        torch.cuda.empty_cache()
        
        return style_indices, struct_indices

# Takes style and structure indices, returns the reconstructed map
def decode_from_indices(style_indices, struct_indices, fqgan, device='cuda', two_stage=False):
    """Memory-efficient decoding function"""
    with torch.no_grad():
        # Move indices to device only when needed
        style_indices = style_indices.to(device)
        struct_indices = struct_indices.to(device)
        
        # Get quantized vectors
        quant_style = fqgan.quantize_style.get_codebook_entry(
            style_indices.view(-1),
            shape=[1, fqgan.embed_dim, *style_indices.shape[1:]]
        )
        quant_struct = fqgan.quantize_struct.get_codebook_entry(
            struct_indices.view(-1),
            shape=[1, fqgan.embed_dim, *struct_indices.shape[1:]]
        )
        
        # Clear indices from GPU
        del style_indices, struct_indices
        
        # Combine and decode
        quant = torch.cat([quant_struct, quant_style], dim=1)
        # quant = quant_style + quant_struct
        del quant_style, quant_struct
        
        if two_stage:
            decoded, binary_decoded = fqgan.decoder(quant)
        else:
            decoded = fqgan.decoder(quant)
        
        del quant
        
        # Convert to block IDs if one-hot encoded
        if decoded.shape[1] > 1:
            decoded = torch.argmax(decoded, dim=1)
        
        # Move result to CPU and clear GPU memory
        result = decoded.squeeze(0).cpu()
        if two_stage:
            binary_result = binary_decoded.squeeze(0).cpu()
            del decoded
            del binary_decoded
            torch.cuda.empty_cache()
            return result, binary_result
        
        del decoded
        torch.cuda.empty_cache()
        
        return result

## docs ass example

In [None]:
model = resnet50(pretrained=True)
target_layers = [model.layer4[-1]]
input_tensor = # Create an input tensor image for your model..
# Note: input_tensor can be a batch tensor with several images!

# We have to specify the target we want to generate the CAM for.
targets = [ClassifierOutputTarget(281)]

# Construct the CAM object once, and then re-use it on many images.
with GradCAM(model=model, target_layers=target_layers) as cam:
  # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
  grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
  # In this example grayscale_cam has only one image in the batch:
  grayscale_cam = grayscale_cam[0, :]
  visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
  # You can also get the model outputs without having to redo inference
  model_outputs = cam.outputs

## chatgpt ass response

In [None]:
input_volume = ...  # shape (C, 24, 24, 24), one-hot or multi-channel
# Hook to capture the last conv layer output and gradient
feature_maps = None
grads = None
def forward_hook(module, inp, out):
    nonlocal feature_maps
    feature_maps = out.detach()
def backward_hook(module, grad_in, grad_out):
    # grad_out is a tuple (because output could be tuple), take first if needed
    nonlocal grads
    grads = grad_out[0].detach()

layer = model.encoder.last_conv  # for example
handle_f = layer.register_forward_hook(forward_hook)
handle_b = layer.register_backward_hook(backward_hook)

# Forward pass through encoder & quantizer
z = model.encoder(input_volume.unsqueeze(0))        # shape (1, latent_C, D_lat, H_lat, W_lat)
quantized, code_indices = model.quantize(z)         # quantized is shape (1, latent_dim, D_lat, H_lat, W_lat)
# Suppose we want to analyze code at position (d0,h0,w0) in the latent
d0,h0,w0 =  ...  # indices of the latent position of interest
code_k = code_indices[0, d0, h0, w0].item()         # the index of the codebook selected

# Define target score = negative distance between z and its selected code vector
z_vec = z[0,:, d0, h0, w0]                          # the pre-quantization vector at that position
codebook_vec = model.codebook.embedding[code_k]     # the corresponding codebook embedding vector
score = -torch.norm(z_vec - codebook_vec, p=2)**2

# Backpropagate to get gradients at last conv layer
model.zero_grad()
score.backward()

# Compute weights: average grad over spatial dims for each channel
# grads shape: (1, channels, D_lat_feat, H_lat_feat, W_lat_feat)
weights = grads.mean(dim=[2,3,4], keepdim=True)  # average gradient over depth, height, width

# Weight the captured feature maps
# feature_maps shape: (1, channels, D_lat_feat, H_lat_feat, W_lat_feat)
cam = (weights * feature_maps).sum(dim=1)         # sum over channels -> shape (1, D_lat_feat, H_lat_feat, W_lat_feat)
cam = torch.relu(cam)                             # apply ReLU
cam = torch.nn.functional.interpolate(cam, size=(24,24,24), mode='trilinear', align_corners=False)
saliency_volume = cam.squeeze(0).cpu().numpy()    # now a 24x24x24 numpy array of importance values