# Imports

In [1]:
from pathlib import Path
from functools import partial
from collections import defaultdict
from collections import OrderedDict

import math
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vision_utils
import lpips
from torchinfo import summary
from torchvision.datasets import LSUN
from torch.utils.data import DataLoader, random_split
import os
from einops import rearrange

from models3d import BiomeClassifier


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
# (uses the torch.device API, available since PyTorch v0.4.0).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
device

device(type='cuda')

Based on the paper "Factorized Visual Tokenization and Generation" (FQGAN). Reference implementation:
https://github.com/showlab/FQGAN/tree/main/tokenizer/models

# Datasets

## Random Minecraft Chunks

In [4]:
class BlockBiomeConverter:
    """
    Two-way lookup between raw block IDs / biome names and dense class indices.

    The dense indices are the positions of the sorted unique values found in a
    dataset (see `from_arrays`), i.e. the channel positions used when the data
    is one-hot encoded.
    """

    def __init__(self, block_mappings=None, biome_mappings=None):
        """
        Initialize with pre-computed mappings for both blocks and biomes.

        Args:
            block_mappings: dict containing 'index_to_block' and 'block_to_index'
            biome_mappings: dict containing 'index_to_biome' and 'biome_to_index'

        A missing (None / empty) mapping leaves the corresponding attributes as None.
        """
        self.index_to_block = block_mappings['index_to_block'] if block_mappings else None
        self.block_to_index = block_mappings['block_to_index'] if block_mappings else None
        self.index_to_biome = biome_mappings['index_to_biome'] if biome_mappings else None
        self.biome_to_index = biome_mappings['biome_to_index'] if biome_mappings else None

    @classmethod
    def from_dataset(cls, data_path):
        """
        Create mappings from a dataset file readable by `np.load` that exposes
        'voxels' and 'biomes' entries.
        """
        # NOTE(security): allow_pickle=True can execute arbitrary code embedded in a
        # crafted file -- only load dataset files from a trusted source.
        data = np.load(data_path, allow_pickle=True)
        return cls.from_arrays(data['voxels'], data['biomes'])

    @classmethod
    def from_arrays(cls, voxels, biomes):
        """
        Create mappings directly from numpy arrays.

        Args:
            voxels: array of integer block IDs
            biomes: array of biome names (converted with `str`)
        """
        # np.unique returns sorted values, so index i is the i-th smallest block ID.
        unique_blocks = np.unique(voxels)
        block_to_index = {int(block): idx for idx, block in enumerate(unique_blocks)}
        index_to_block = {idx: int(block) for idx, block in enumerate(unique_blocks)}

        unique_biomes = np.unique(biomes)
        biome_to_index = {str(biome): idx for idx, biome in enumerate(unique_biomes)}
        index_to_biome = {idx: str(biome) for idx, biome in enumerate(unique_biomes)}

        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}

        return cls(block_mappings, biome_mappings)

    @classmethod
    def load_mappings(cls, path):
        """Load mappings previously written by `save_mappings`."""
        mappings = torch.load(path)
        return cls(mappings['block_mappings'], mappings['biome_mappings'])

    def save_mappings(self, path):
        """Save both mapping pairs to `path` with `torch.save` for later use."""
        torch.save({
            'block_mappings': {
                'index_to_block': self.index_to_block,
                'block_to_index': self.block_to_index
            },
            'biome_mappings': {
                'index_to_biome': self.index_to_biome,
                'biome_to_index': self.biome_to_index
            }
        }, path)

    @staticmethod
    def _to_indices(data, num_classes):
        """Return `data` as a CPU int64 tensor of class indices, undoing one-hot encoding if present."""
        data = torch.as_tensor(data)
        if data.dim() == 5:
            # [B, C, ...] one-hot -> [B, ...] indices
            data = torch.argmax(data, dim=1)
        elif data.dim() == 4 and data.shape[0] == num_classes:
            # [C, ...] one-hot -> [...] indices.
            # Caveat: an indexed batch whose batch size happens to equal the number
            # of classes is indistinguishable from this case and is treated as one-hot.
            data = torch.argmax(data, dim=0)
        return data.detach().cpu().long()

    @staticmethod
    def _lookup_unique(indices, table):
        """Map each distinct index through `table` once; return (mapped values, inverse positions)."""
        unique_idx, inverse = torch.unique(indices.reshape(-1), return_inverse=True)
        # One dict lookup per distinct class instead of one per voxel.
        # An index missing from `table` raises KeyError.
        values = [table[int(i)] for i in unique_idx]
        return values, inverse

    def convert_to_original_blocks(self, data):
        """
        Convert from indices back to original block IDs.
        Handles both one-hot encoded and already-indexed data.

        Args:
            data: torch.Tensor of either:
                - one-hot encoded blocks [B, C, H, W, D] or [C, H, W, D]
                - indexed blocks [B, H, W, D] or [H, W, D]
        Returns:
            CPU torch.Tensor of original block IDs with the same shape as the
            index data ([B, H, W, D] or [H, W, D]).
        Raises:
            KeyError: if an index is not present in `index_to_block`.
        """
        indices = self._to_indices(data, len(self.block_to_index))
        values, inverse = self._lookup_unique(indices, self.index_to_block)
        return torch.tensor(values)[inverse].reshape(indices.shape)

    def convert_to_original_biomes(self, data):
        """
        Convert from indices back to original biome strings.
        Handles both one-hot encoded and already-indexed data.

        Args:
            data: torch.Tensor of either:
                - one-hot encoded biomes [B, C, H, W, D] or [C, H, W, D]
                - indexed biomes [B, H, W, D] or [H, W, D]
        Returns:
            numpy array of original biome strings with the same shape as the
            index data ([B, H, W, D] or [H, W, D]).
        Raises:
            KeyError: if an index is not present in `index_to_biome`.
        """
        indices = self._to_indices(data, len(self.biome_to_index))
        values, inverse = self._lookup_unique(indices, self.index_to_biome)
        return np.array(values)[inverse.numpy()].reshape(tuple(indices.shape))

    def _index_of_block(self, wanted_id):
        """Return the dense index mapped to `wanted_id`, or None when it is absent."""
        for idx, block_id in self.index_to_block.items():
            if block_id == wanted_id:
                return idx
        return None

    def get_air_block_index(self):
        """
        Find the one-hot index corresponding to the air block (ID 5).

        Returns:
            int: The index where air blocks are encoded in one-hot format
        Raises:
            ValueError: if block ID 5 is not part of the block mappings
        """
        idx = self._index_of_block(5)
        if idx is None:
            raise ValueError("Air block (ID 5) not found in block mappings!")
        return idx

    def get_water_block_index(self):
        """
        Find the one-hot index corresponding to the water block (ID 240).

        Returns:
            int: The index where water blocks are encoded in one-hot format
        Raises:
            ValueError: if block ID 240 is not part of the block mappings
        """
        idx = self._index_of_block(240)
        if idx is None:
            raise ValueError("water block (ID 240) not found in block mappings!")
        return idx

    def get_blockid_indices(self, block_ids):
        """
        Find the one-hot indices corresponding to a collection of block IDs.

        Args:
            block_ids: container of original block IDs (anything supporting `in`)
        Returns:
            list[int]: indices of every requested block ID that exists in the
            mappings, in `index_to_block` iteration order. IDs that are not
            present are silently skipped.
        Raises:
            ValueError: if none of the requested block IDs are in the mappings
        """
        idxs = [idx for idx, block_id in self.index_to_block.items() if block_id in block_ids]
        if len(idxs) == 0:
            raise ValueError(f"None of the block IDs {block_ids!r} were found in block mappings!")
        return idxs
        

## Dataset and loaders

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random

def rotate_voxels_90(voxels, k=1):
    """
    Rotate a voxel volume by k quarter turns.

    The rotation plane is spanned by the first and the last spatial axis, so
    the middle spatial axis is the axis of rotation.

    NOTE(review): with the [.., H, W, D] labelling used here the rotated plane
    is (H, D) rather than (W, D) -- confirm which stored axis is really "up".

    Args:
        voxels: tensor of shape [B, C, H, W, D], [C, H, W, D] or [H, W, D]
        k: number of 90 degree rotations (1 = 90°, 2 = 180°, 3 = 270°)
    Returns:
        Rotated voxels
    Raises:
        ValueError: if the input is not 3-, 4- or 5-dimensional
    """
    # Rotation plane for each supported rank (leading batch/channel axes are skipped).
    plane_for_rank = {
        5: (2, 4),  # batched   [B, C, H, W, D]
        4: (1, 3),  # unbatched [C, H, W, D]
        3: (0, 2),  # unbatched [H, W, D]
    }
    plane = plane_for_rank.get(len(voxels.shape))
    if plane is None:
        raise ValueError(f"Unexpected voxel shape: {voxels.shape}")
    return torch.rot90(voxels, k=k, dims=plane)
    
class MinecraftRotationAugmentation:
    """Callable transform that randomly rotates a voxel volume by 90°, 180° or 270°."""

    def __init__(self, p=0.75):
        """
        Args:
            p: probability (0-1) that a call rotates its input
        """
        self.p = p

    def __call__(self, x):
        """Return `x` rotated by a random number of quarter turns, or `x` itself."""
        # Guard clause: with probability 1 - p the sample is passed through untouched.
        if not (random.random() < self.p):
            return x
        quarter_turns = random.randint(1, 3)  # inclusive on both ends
        return rotate_voxels_90(x, quarter_turns)

class MinecraftDataset(Dataset):
    """
    Dataset over pre-processed Minecraft chunks stored in a `torch.save` file.

    The file must contain a dict with a 'chunks' entry; the class dimension is
    expected at axis 1 (the code takes `argmax(dim=1)` to list block types).
    Any other entries (e.g. 'biomes') are ignored.
    """

    def __init__(self, data_path, augment=True, rotation_prob=0.75):
        """
        Args:
            data_path: path to the file written with `torch.save`
            augment: if True, `__getitem__` randomly rotates samples
            rotation_prob: probability (0-1) of rotating a sample when augmenting
        Raises:
            KeyError: if the file has no 'chunks' entry
        """
        data_path = Path(data_path)
        self.augment = augment
        self.rotation_prob = rotation_prob

        print("Loading pre-processed data...")
        processed_data = torch.load(data_path)
        # Only the chunks are kept. Dropping the reference to the loaded dict lets
        # everything else in it (biomes, if present) be garbage collected; files
        # without a 'biomes' entry are therefore accepted too.
        self.processed_chunks = processed_data['chunks']
        del processed_data

        print(f"Loaded {len(self.processed_chunks)} chunks of size {self.processed_chunks.shape[1:]}")
        print(f"Number of unique block types: {self.processed_chunks.shape[1]}")
        print(f'Unique blocks: {torch.unique(torch.argmax(self.processed_chunks, dim=1)).tolist()}')

    def __getitem__(self, idx):
        """Return chunk `idx`, randomly rotated by 90°/180°/270° when augmentation is on."""
        chunk = self.processed_chunks[idx]

        # `self.augment` is checked first so no random number is drawn when augmentation is off.
        if self.augment and random.random() < self.rotation_prob:
            k = random.randint(1, 3)
            chunk = rotate_voxels_90(chunk, k)

        return chunk

    def __len__(self):
        """Number of chunks in the dataset."""
        return len(self.processed_chunks)

def get_minecraft_dataloaders(data_path, batch_size=32, val_split=0.1, num_workers=0, save_val_path=None, augment=True):
    """
    Creates training and validation dataloaders for Minecraft chunks.

    Args:
        data_path: path passed to `MinecraftDataset`
        batch_size: batch size of both loaders
        val_split: fraction (0-1) of the chunks held out for validation
        num_workers: worker processes per loader
        save_val_path: if given, the validation chunks and their dataset indices
            are written there with `torch.save` as {'data', 'indices'}
        augment: whether the *training* samples get random rotation augmentation;
            validation samples are never augmented
    Returns:
        (train_loader, val_loader)
    """
    import copy  # local import: only needed to build the un-augmented validation view

    dataset = MinecraftDataset(data_path, augment=augment)

    # Split into train and validation sets
    val_size = int(val_split * len(dataset))
    train_size = len(dataset) - val_size

    # Use a fixed seed so the split is reproducible between runs
    generator = torch.Generator().manual_seed(42)

    train_dataset, val_subset = torch.utils.data.random_split(
        dataset,
        [train_size, val_size],
        generator=generator
    )
    val_indices = val_subset.indices

    # Both subsets returned by random_split wrap the SAME dataset object, so
    # flipping `augment` on it would also switch augmentation off for training.
    # A shallow copy shares the chunk tensor (no extra memory) but carries its
    # own `augment` flag.
    eval_view = copy.copy(dataset)
    eval_view.augment = False
    val_dataset = torch.utils.data.Subset(eval_view, val_indices)

    # Save validation data if path provided
    if save_val_path:
        print(f'saving validation dataset to file: {save_val_path}')
        # Index-based gather also works for an empty validation split.
        val_samples = dataset.processed_chunks[torch.as_tensor(val_indices, dtype=torch.long)]

        # os.makedirs('') raises, so only create a directory when the path has one.
        save_dir = os.path.dirname(save_val_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        torch.save({
            'data': val_samples,
            'indices': val_indices
        }, save_val_path)
        print(f"Saved validation data to {save_val_path}")

    # Create dataloaders with memory pinning
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    print(f"\nDataloader details:")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Batch size: {batch_size}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    return train_loader, val_loader


### Render Minecraft data

In [6]:
class MinecraftVisualizer:
    """Matplotlib voxel renderer for a single Minecraft chunk."""

    def __init__(self):
        """Set up the block ID -> matplotlib colour table."""
        self.blocks_to_cols = {
            0: (0.5, 0.25, 0.0),           # light brown
            10: 'black',                   # bedrock
            29: "#006400",                 # cactus
            38: "#B8860B",                 # clay
            40: "#2F4F4F",                 # coal ore
            60: "brown",                   # dirt
            62: "#228B22",                 # double plant
            92: "gold",                    # gold ore
            93: "green",                   # grass
            95: "lightgray",               # gravel
            108: "#BEBEBE",                # iron ore
            115: "brown",                  # ladder (unconfirmed)
            119: (.02, .28, .16, 0.8),     # leaves: semi-transparent forest green (RGBA)
            120: (.02, .28, .16, 0.8),     # leaves2
            131: "saddlebrown",            # log1
            132: "saddlebrown",            # log2
            166: "orange",                 # pumpkin
            167: "#FF8C00",                # pumpkin stem
            184: "#FFA07A",                # red flower
            194: "yellow",                 # sand
            195: "tan",                    # sandstone
            197: "limegreen",              # sapling
            217: "gray",                   # stone
            227: (0.0, 1.0, 0.0, .3),      # tall grass
            237: (0.33, 0.7, 0.33, 0.3),   # vine
            240: (0.0, 0.0, 1.0, 0.4),     # water
            243: "wheat",                  # wheat
            250: "white",                  # wool
            251: "gold",                   # yellow flower
        }

    def visualize_chunk(self, voxels, ax=None):
        """
        Draw one chunk as coloured voxels on a 3D matplotlib axis.

        Blocks with ID 5 or -1 are not drawn. Blocks with a known colour are
        drawn in that colour; all remaining blocks are drawn red with black edges.

        Args:
            voxels: torch.Tensor [C,H,W,D] (one-hot) or numpy.ndarray [H,W,D] (block IDs)
            ax: Optional matplotlib 3D axis; a new figure is created when omitted
        Returns:
            The axis that was drawn on
        """
        # Tensors are moved to the CPU; a 4-D tensor is treated as one-hot and
        # collapsed to block IDs along the channel axis.
        if isinstance(voxels, torch.Tensor):
            on_cpu = voxels.detach().cpu()
            if on_cpu.dim() == 4:
                on_cpu = torch.argmax(on_cpu, dim=0)
            voxels = on_cpu.numpy()

        # Bring the last axis to the front, then quarter-turn in the plane of the first two axes.
        voxels = np.rot90(voxels.transpose(2, 0, 1), 1, (0, 1))

        if ax is None:
            ax = plt.figure().add_subplot(111, projection='3d')

        # Mask of drawable blocks that have not been plotted yet.
        remaining = (voxels != 5) & (voxels != -1)

        for block_id in np.unique(voxels[remaining]):
            if block_id in self.blocks_to_cols:
                ax.voxels(voxels == block_id, facecolors=self.blocks_to_cols[int(block_id)])
                remaining = remaining & (voxels != block_id)

        # Whatever is left has no colour entry: red faces, black edges.
        ax.voxels(remaining, edgecolor="k", facecolor="red")

        return ax

In [7]:
def display_minecraft(vis, mc_visualizer, data, win_name="minecraft_display", title="Minecraft Chunks", nrow=4, save_path=None):
    """
    Display or save multiple minecraft chunks.

    Args:
        vis: Visdom instance (can be None if only saving)
        mc_visualizer: object whose `visualize_chunk(chunk, ax)` draws one chunk on a 3D axis
        data: Tensor of shape [B, 256, 20, 20, 20] or [B, 20, 20, 20]
        win_name: Window name for visdom
        title: Title for the plot
        nrow: Number of images per row
        save_path: If provided, saves the figure to this path
    """
    # Indexed input is expanded to one-hot so every chunk reaches the visualizer as [C, H, W, D].
    if len(data.shape) == 4:  # [B, 20, 20, 20]
        data = F.one_hot(data.long(), num_classes=256).permute(0, 4, 1, 2, 3).float()

    batch_size = min(data.shape[0], 16)  # Display up to 16 chunks
    ncols = nrow
    nrows = (batch_size + ncols - 1) // ncols  # ceil(batch_size / ncols)

    fig = plt.figure(figsize=(4*ncols, 4*nrows))
    try:
        fig.suptitle(title)

        for i in range(batch_size):
            ax = fig.add_subplot(nrows, ncols, i+1, projection='3d')
            mc_visualizer.visualize_chunk(data[i], ax)
            ax.set_title(f'Chunk {i}')

        fig.tight_layout()

        # Save if path provided
        if save_path:
            # os.makedirs('') raises, so only create a directory when the path has one.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, bbox_inches='tight', dpi=150)

        # Display in visdom if instance provided
        if vis is not None:
            canvas = fig.canvas
            canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was deprecated in
            # Matplotlib 3.8 and later removed. The buffer is already shaped
            # (height, width, 4); the alpha channel is dropped.
            img_array = np.asarray(canvas.buffer_rgba())[..., :3]

            vis.image(
                img_array.transpose(2, 0, 1),  # Convert to CHW format
                win=win_name,
                opts=dict(
                    title=title,
                    caption=f'Batch of {batch_size} chunks'
                )
            )
    finally:
        # Always release the figure, even when rendering or saving fails.
        plt.close(fig)

def save_minecraft(data, mc_visualizer, mc_dataset, save_path, nrow=4, title="Minecraft Chunks"):
    """
    Save multiple minecraft chunks to a file.

    Args:
        data: Tensor of shape [B, 256, 20, 20, 20] or [B, 20, 20, 20]
        mc_visualizer: object whose `visualize_chunk(chunk, ax)` draws one chunk on a 3D axis
        mc_dataset: object providing `convert_to_original_blocks(data)`, used to map
            class indices back to original block IDs
        save_path: Path to save the image
        nrow: Number of images per row
        title: Title for the plot
    """
    # Convert to original block IDs for visualization
    data = mc_dataset.convert_to_original_blocks(data)
    # Re-encode the IDs as 256-way one-hot so the visualizer's argmax recovers the original IDs.
    if len(data.shape) == 4:  # [B, 20, 20, 20]
        data = F.one_hot(data.long(), num_classes=256).permute(0, 4, 1, 2, 3).float()

    batch_size = min(data.shape[0], 16)  # Save up to 16 chunks
    ncols = nrow
    nrows = (batch_size + ncols - 1) // ncols  # ceil(batch_size / ncols)

    fig = plt.figure(figsize=(4*ncols, 4*nrows))
    try:
        fig.suptitle(title)

        for i in range(batch_size):
            ax = fig.add_subplot(nrows, ncols, i+1, projection='3d')
            mc_visualizer.visualize_chunk(data[i], ax)
            ax.set_title(f'Chunk {i}')

        fig.tight_layout()

        # os.makedirs('') raises, so only create a directory when the path has one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight', dpi=150)
    finally:
        # Always release the figure, even when rendering or saving fails.
        plt.close(fig)

In [8]:
def display_minecraft_pyvista(vis, mc_visualizer, data, win_name="minecraft_display", title="Minecraft Chunks", nrow=4, save_path=None):
    """
    Render up to 16 chunks with a plotter-based visualizer, tile the screenshots
    into one grid image, and save it and/or send it to visdom.

    Args:
        vis: Visdom instance (skipped when None)
        mc_visualizer: object whose `visualize_chunk(chunk)` returns a plotter
            offering `screenshot(...)` and `close()`
        data: Tensor, one-hot [B, C, H, W, D] or indexed [B, H, W, D]
        win_name: Window name for visdom
        title: Title shown in visdom
        nrow: Number of images per grid row
        save_path: If provided, the grid image is written there (directories are not created)
    """
    # Indexed input is expanded to 256-way one-hot first.
    if len(data.shape) == 4:  # [B, 20, 20, 20]
        data = F.one_hot(data.long(), num_classes=256).permute(0, 4, 1, 2, 3).float()

    batch_size = min(data.shape[0], 16)  # at most 16 chunks are shown
    ncols = nrow
    nrows = (batch_size + ncols - 1) // ncols
    single_size = 400  # edge length of one tile, in pixels

    def _screenshot(chunk):
        # One plotter per chunk: render, grab the image, release the plotter.
        plotter = mc_visualizer.visualize_chunk(chunk)
        image = plotter.screenshot(window_size=(single_size, single_size),
                                   transparent_background=True,
                                   return_img=True)
        plotter.close()
        return image

    chunk_images = [_screenshot(data[i]) for i in range(batch_size)]

    # Assemble the grid row by row; a short last row is padded with blank tiles.
    grid_rows = []
    for row in range(nrows):
        first = row * ncols
        tiles = chunk_images[first:first + ncols]
        tiles.extend(np.zeros_like(chunk_images[0]) for _ in range(ncols - len(tiles)))
        grid_rows.append(np.concatenate(tiles, axis=1))

    combined_img = np.concatenate(grid_rows, axis=0)

    if save_path:
        # Written as-is; the target directory must already exist.
        plt.imsave(save_path, combined_img)

    if vis is not None:
        vis.image(
            combined_img.transpose(2, 0, 1),  # HWC -> CHW
            win=win_name,
            opts=dict(
                title=title,
                caption=f'Batch of {batch_size} chunks'
            )
        )

In [9]:
import pyvista as pv

class MinecraftVisualizerPyVista:
    """
    PyVista renderer for a single Minecraft chunk.

    `visualize_chunk` fills an off-screen (or caller-supplied) plotter and
    `visualize_interactive` fills a notebook plotter; both build exactly the
    same scene through the shared private helpers below.
    """

    # Extent of the invisible bounding cube and of the labelled axes,
    # as (xmin, xmax, ymin, ymax, zmin, zmax).
    _SCENE_BOUNDS = (0, 24, 0, 24, 0, 24)

    # (position, intensity) of the white lights added to every plotter.
    _LIGHTS = (
        ((1, -1, 1), 1.0),        # key light (main light)
        ((-1, 1, 0.5), 0.5),      # fill light (softer, from the opposite side)
        ((-0.5, -0.5, -1), 0.3),  # back light (rim lighting from behind)
    )

    def __init__(self):
        """Set up the block ID -> colour table and try to enable the notebook backend."""
        self.blocks_to_cols = {
            0: (0.5, 0.25, 0.0),    # light brown
            10: 'black', # bedrock
            29: "#006400", # cactus
            38: "#B8860B",  # clay
            60: "brown",  # dirt
            92: "gold",  # gold ore
            93: "green",  # grass
            115: "brown",  # ladder (unconfirmed)
            119: (.02, .28, .16, 0.9),  # leaves: semi-transparent forest green (RGBA)
            120: (.02, .28, .16, 0.9),  # leaves2
            194: "yellow",  # sand
            217: "gray",  # stone
            240: (0.0, 0.0, 1.0, 0.4),  # water
            227: (0.0, 1.0, 0.0, .3), # tall grass
            237: (0.33, 0.7, 0.33, 0.3), # vine
            40: "#2F4F4F",  # coal ore
            62: "#228B22",  # double plant
            108: "#BEBEBE",  # iron ore
            131: "saddlebrown",  # log1
            132: "saddlebrown",  # log2
            95: "lightgray",  # gravel
            243: "wheat",  # wheat
            197: "limegreen",  # sapling
            166: "orange",  # pumpkin
            167: "#FF8C00",  # pumpkin stem
            184: "#FFA07A",  # red flower
            195: "tan",  # sandstone
            250: "white",  # wool
            251: "gold",   # yellow flower
        }
        # Best effort: a missing optional dependency only prints a hint.
        try:
            import panel as pn
            pn.extension('vtk')
            pv.set_jupyter_backend('trame')
        except ImportError:
            print("Please install panel with: pip install panel")

    @staticmethod
    def _prepare(voxels):
        """Return (block-ID array in display orientation, PyVista grid holding it as cell data)."""
        # Tensors are moved to the CPU; a 4-D tensor is treated as one-hot
        # [C,H,W,D] and collapsed to block IDs along the channel axis.
        if isinstance(voxels, torch.Tensor):
            if voxels.dim() == 4:
                voxels = voxels.detach().cpu()
                voxels = torch.argmax(voxels, dim=0).numpy()
            else:
                voxels = voxels.detach().cpu().numpy()

        # Same orientation as the matplotlib visualizer: last axis first, then
        # a quarter turn in the plane of the first two axes.
        voxels = voxels.transpose(2, 0, 1)
        voxels = np.rot90(voxels, 1, (0, 1))

        # One grid cell per voxel: the point dimensions are the cell counts + 1.
        grid = pv.ImageData()
        grid.dimensions = np.array(voxels.shape) + 1
        grid.cell_data["values"] = voxels.flatten(order="F")
        return voxels, grid

    def _add_lights(self, plotter):
        """Replace the plotter's lights with the fixed three-light setup."""
        plotter.remove_all_lights()
        for position, intensity in self._LIGHTS:
            plotter.add_light(pv.Light(
                position=position,
                intensity=intensity,
                color='white'
            ))

    def _add_blocks(self, plotter, voxels, grid):
        """Add one mesh per block type present; IDs 5 and -1 are not drawn."""
        mask = (voxels != 5) & (voxels != -1)
        unique_blocks = np.unique(voxels[mask])

        for block_id in unique_blocks:
            # Select the cells whose value is exactly this block ID.
            threshold = grid.threshold([block_id-0.5, block_id+0.5])
            if block_id in self.blocks_to_cols:
                color = self.blocks_to_cols[int(block_id)]
                # Named and RGB colours are opaque; RGBA colours carry their own alpha.
                opacity = 1.0 if isinstance(color, str) or len(color) == 3 else color[3]
            else:
                # Blocks without a colour entry: faint red.
                color = (1.0, 0.0, 0.0)
                opacity = 0.2

            plotter.add_mesh(threshold,
                        color=color,
                        opacity=opacity,
                        show_edges=True,
                        edge_color='black',
                        line_width=.2,
                        edge_opacity=0.2,
                        lighting=True)

    def _add_bounds_and_camera(self, plotter):
        """Force a fixed scene extent, draw the axes and set the isometric camera."""
        # Invisible cube so every chunk is framed identically.
        outline = pv.Cube(bounds=self._SCENE_BOUNDS)
        plotter.add_mesh(outline, opacity=0.0)

        plotter.show_bounds(
            grid='back',
            location='back',
            font_size=8,
            bold=False,
            font_family='arial',
            use_2d=False,
            bounds=list(self._SCENE_BOUNDS),
            axes_ranges=list(self._SCENE_BOUNDS),
            padding=0.0,
            n_xlabels=2,
            n_ylabels=2,
            n_zlabels=2
        )

        plotter.camera_position = 'iso'
        plotter.camera.zoom(1)

    def _populate(self, plotter, voxels, grid):
        """Build the full scene (lights, blocks, bounds, camera) on `plotter` and return it."""
        self._add_lights(plotter)
        self._add_blocks(plotter, voxels, grid)
        self._add_bounds_and_camera(plotter)
        return plotter

    def visualize_chunk(self, voxels, plotter=None):
        """
        Visualize a single chunk with consistent styling.

        Args:
            voxels: torch.Tensor [C,H,W,D] (one-hot) or block IDs [H,W,D]
                (tensor or numpy array)
            plotter: optional existing plotter; an off-screen one is created when omitted
        Returns:
            The populated plotter (not shown; the caller renders or closes it)
        """
        voxels, grid = self._prepare(voxels)
        if plotter is None:
            plotter = pv.Plotter(off_screen=True)
        return self._populate(plotter, voxels, grid)

    def visualize_interactive(self, voxels):
        """
        Build the same scene as `visualize_chunk` on a new notebook plotter.

        Args:
            voxels: torch.Tensor [C,H,W,D] (one-hot) or block IDs [H,W,D]
                (tensor or numpy array)
        Returns:
            The populated notebook plotter (not shown; the caller displays it)
        """
        voxels, grid = self._prepare(voxels)
        plotter = pv.Plotter(notebook=True)
        return self._populate(plotter, voxels, grid)

# Model

## Vector Quantizer

### Factorized Nearest Neighbor Quantizer

In [10]:
# def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
#     flat_affinity = affinity.reshape(-1, affinity.shape[-1])
#     flat_affinity /= temperature
#     probs = F.softmax(flat_affinity, dim=-1)
#     log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
#     if loss_type == "softmax":
#         target_probs = probs
#     else:
#         raise ValueError("Entropy loss {} not supported".format(loss_type))
#     avg_probs = torch.mean(target_probs, dim=0)
#     avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
#     sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
#     loss = sample_entropy - avg_entropy
#     return loss

# class FQVectorQuantizer(nn.Module):
#     def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
#         super().__init__()
#         # Same initialization as original
#         self.n_e = n_e
#         self.e_dim = e_dim
#         self.beta = beta
#         self.entropy_loss_ratio = entropy_loss_ratio
#         self.l2_norm = l2_norm
#         self.show_usage = show_usage

#         self.embedding = nn.Embedding(self.n_e, self.e_dim)
#         self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
#         if self.l2_norm:
#             self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
#         if self.show_usage:
#             self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))

#     def forward(self, z):
#         # reshape z -> (batch, height, width, depth, channel) and flatten
#         z = torch.einsum('b c h w d -> b h w d c', z).contiguous()  # Changed permute to handle 3D
#         z_flattened = z.view(-1, self.e_dim)

#         if self.l2_norm:
#             z = F.normalize(z, p=2, dim=-1)
#             z_flattened = F.normalize(z_flattened, p=2, dim=-1)
#             embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
#         else:
#             embedding = self.embedding.weight

#         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
#             torch.sum(embedding**2, dim=1) - 2 * \
#             torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))

#         min_encoding_indices = torch.argmin(d, dim=1)
#         z_q = embedding[min_encoding_indices].view(z.shape)

#         # Rest of the function remains the same
#         perplexity = None
#         min_encodings = None
#         vq_loss = None
#         commit_loss = None
#         entropy_loss = None
#         codebook_usage = 0

#         # calculate the losses even if we aren't training, otherwise we get Nones when trying to eval on validation
#         vq_loss = torch.mean((z_q - z.detach()) ** 2)
#         commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
#         entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

#         if self.show_usage and self.training:
#             cur_len = min_encoding_indices.shape[0]
#             self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
#             self.codebook_used[-cur_len:] = min_encoding_indices
#             codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
#         else:
#             codebook_usage = 0

#         # if self.training:
#         #     vq_loss = torch.mean((z_q - z.detach()) ** 2)
#         #     commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
#         #     entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

#         # preserve gradients
#         z_q = z + (z_q - z).detach()

#         # reshape back to match original input shape
#         z_q = torch.einsum('b h w d c -> b c h w d', z_q)  # Changed permute to handle 3D

#         return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)

#     def get_codebook_entry(self, indices, shape=None, channel_first=True):
#         # shape = (batch, channel, height, width, depth) if channel_first else (batch, height, width, depth, channel)
#         if self.l2_norm:
#             embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
#         else:
#             embedding = self.embedding.weight
#         z_q = embedding[indices]

#         if shape is not None:
#             if channel_first:
#                 z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[4], shape[1])
#                 # reshape back to match original input shape
#                 z_q = z_q.permute(0, 4, 1, 2, 3).contiguous()
#             else:
#                 z_q = z_q.view(shape)
#         return z_q

### Factorized EMA Quantizer

In [11]:
# class FactorizedEMAQuantizer(nn.Module):
#     def __init__(self, structure_codebook_size, style_codebook_size, emb_dim, decay=0.99):
#         super().__init__()
#         self.structure_codebook_size = structure_codebook_size
#         self.style_codebook_size = style_codebook_size
#         self.emb_dim = emb_dim
#         self.decay = decay

#         # Structure codebook
#         self.register_buffer('structure_cluster_size', torch.zeros(structure_codebook_size))
#         self.register_buffer('structure_embedding_avg', torch.zeros(structure_codebook_size, emb_dim))
#         self.register_buffer('structure_embedding', torch.randn(structure_codebook_size, emb_dim))
        
#         # Style codebook
#         self.register_buffer('style_cluster_size', torch.zeros(style_codebook_size))
#         self.register_buffer('style_embedding_avg', torch.zeros(style_codebook_size, emb_dim))
#         self.register_buffer('style_embedding', torch.randn(style_codebook_size, emb_dim))

#     def forward(self, z):
#         # Quantize with structure codebook
#         struct_encoding_indices, struct_encodings = self.quantize(
#             z, 
#             self.structure_embedding,
#             self.structure_cluster_size,
#             self.structure_embedding_avg,
#             self.structure_codebook_size
#         )
        
#         # Quantize with style codebook
#         style_encoding_indices, style_encodings = self.quantize(
#             z,
#             self.style_embedding,
#             self.style_cluster_size,
#             self.style_embedding_avg,
#             self.style_codebook_size
#         )

#         # Calculate disentanglement loss
#         struct_norm = F.normalize(struct_encodings, dim=-1)
#         style_norm = F.normalize(style_encodings, dim=-1)
#         disentangle_loss = torch.mean((struct_norm * style_norm).sum(-1) ** 2)

#         # Average the encodings since they're in the same space
#         z_q = (struct_encodings + style_encodings) / 2

#         return z_q, disentangle_loss, {
#             "struct_indices": struct_encoding_indices,
#             "style_indices": style_encoding_indices
#         }

#     def quantize(self, z, embedding, cluster_size, embedding_avg, n_codes):
#         # Reshape z -> (batch, height, width, depth, channel)
#         z = torch.einsum('b c h w d -> b h w d c', z)
#         z_flattened = z.reshape(-1, self.emb_dim)
        
#         # Distances to embeddings
#         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
#             torch.sum(embedding ** 2, dim=1) - \
#             2 * torch.einsum('bd,nd->bn', z_flattened, embedding)
        
#         # Find nearest codebook entries
#         encoding_indices = torch.argmin(d, dim=1)
#         encodings = F.one_hot(encoding_indices, n_codes).type_as(z_flattened)
        
#         # EMA update of embeddings
#         if self.training:
#             n_total = encodings.sum(0)
#             cluster_size.data.mul_(self.decay).add_(n_total, alpha=1 - self.decay)
            
#             dw = torch.einsum('bn,bd->nd', encodings, z_flattened)
#             embedding_avg.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)

#             n = cluster_size.sum()
#             cluster_size_balanced = (cluster_size + 1e-5) / (n + n_codes * 1e-5) * n
            
#             embedding.data = embedding_avg / cluster_size_balanced.unsqueeze(1)
        
#         # Quantize z
#         z_q = torch.matmul(encodings, embedding)
#         z_q = z_q.view(z.shape)
        
#         # Reshape back
#         z_q = torch.einsum('b h w d c -> b c h w d', z_q)
        
#         return encoding_indices, z_q

#     def get_codebook_entry(self, indices, shape, codebook="structure"):
#         # Select appropriate codebook
#         embedding = self.structure_embedding if codebook == "structure" else self.style_embedding
        
#         # Get quantized latents
#         z_q = embedding[indices]

#         if shape is not None:
#             z_q = z_q.view(shape)

#         return z_q

In [12]:
# class EMAQuantizer(nn.Module):
#     def __init__(self, codebook_size, emb_dim, decay=0.99, eps=1e-5):
#         super().__init__()
        
#         self.codebook_size = codebook_size
#         self.emb_dim = emb_dim
#         self.decay = decay
#         self.eps = eps

#         # Initialize embeddings with randn, transposed from original
#         embed = torch.randn(emb_dim, codebook_size).t()
#         self.register_buffer("embedding", embed)
#         self.register_buffer("cluster_size", torch.zeros(codebook_size))
#         self.register_buffer("embed_avg", embed.clone())

#     def forward(self, z):
#         # Save input shape and flatten
#         b, c, h, w, d = z.shape  # Now includes depth dimension
#         z_flattened = z.permute(0, 2, 3, 4, 1).reshape(-1, self.emb_dim)
        
#         # Calculate distances
#         dist = (
#             z_flattened.pow(2).sum(1, keepdim=True)
#             - 2 * z_flattened @ self.embedding.t()
#             + self.embedding.pow(2).sum(1, keepdim=True).t()
#         )
        
#         # Get closest encodings
#         _, min_encoding_indices = (-dist).max(1)
#         min_encodings = F.one_hot(min_encoding_indices, self.codebook_size).type(z_flattened.dtype)
        
#         # Get quantized latent vectors
#         z_q = torch.matmul(min_encodings, self.embedding)
        
#         # EMA updates during training
#         if self.training:
#             embed_onehot_sum = min_encodings.sum(0)
#             embed_sum = z_flattened.transpose(0, 1) @ min_encodings
            
#             self.cluster_size.data.mul_(self.decay).add_(
#                 embed_onehot_sum, alpha=1 - self.decay
#             )
#             self.embed_avg.data.mul_(self.decay).add_(
#                 embed_sum.t(), alpha=1 - self.decay
#             )

#             n = self.cluster_size.sum()
#             cluster_size = (
#                 (self.cluster_size + self.eps) / (n + self.codebook_size * self.eps) * n
#             )
            
#             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
#             self.embedding.data.copy_(embed_normalized)
            
#         # Reshape z_q and apply straight-through estimator
#         z_q = z_q.view(b, h, w, d, c)  # Added depth dimension
#         z_q = z_q.permute(0, 4, 1, 2, 3).contiguous()  # [B, C, H, W, D]
        
#         # Straight-through estimator
#         z_q = z + (z_q - z).detach()
        
#         # Calculate perplexity
#         e_mean = torch.mean(min_encodings, dim=0)
#         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

#         return z_q, torch.tensor(0.0, device=z.device), {
#             "perplexity": perplexity,
#             "min_encodings": min_encodings,
#             "min_encoding_indices": min_encoding_indices.view(b, h, w, d),
#             "mean_distance": dist.mean()
#         }

#     def get_codebook_entry(self, indices, shape):
#         min_encodings = F.one_hot(indices, self.codebook_size).type(torch.float)
#         z_q = torch.matmul(min_encodings, self.embedding)

#         if shape is not None:
#             z_q = z_q.view(shape).permute(0, 4, 1, 2, 3).contiguous()

#         return z_q

## Factorized Adapter Head
Processes encoder features before they go into each sub-codebook

In [13]:
# class ResidualAttentionBlock(nn.Module):
#     def __init__(
#             self,
#             d_model,
#             n_head,
#             mlp_ratio=4.0,
#             act_layer=nn.GELU,
#             norm_layer=nn.LayerNorm
#     ):
#         super().__init__()

#         self.ln_1 = norm_layer(d_model)
#         self.attn = nn.MultiheadAttention(d_model, n_head)
#         self.mlp_ratio = mlp_ratio
#         if mlp_ratio > 0:
#             self.ln_2 = norm_layer(d_model)
#             mlp_width = int(d_model * mlp_ratio)
#             self.mlp = nn.Sequential(OrderedDict([
#                 ("c_fc", nn.Linear(d_model, mlp_width)),
#                 ("gelu", act_layer()),
#                 ("c_proj", nn.Linear(mlp_width, d_model))
#             ]))

#     def attention(self, x: torch.Tensor):
#         return self.attn(x, x, x, need_weights=False)[0]

#     def forward(self, x: torch.Tensor):
#         attn_output = self.attention(x=self.ln_1(x))
#         x = x + attn_output
#         if self.mlp_ratio > 0:
#             x = x + self.mlp(self.ln_2(x))
#         return x

In [14]:
# class FactorizedAdapter(nn.Module):
#     def __init__(self, down_factor):
#         super().__init__()

#         # Modified for 3D: grid_size now represents volume size
#         self.grid_size = 24 // down_factor  # volume size // down-sample ratio
#         self.width = 256  # same dim as VQ encoder output
#         self.num_layers = 6
#         self.num_heads = 8

#         scale = self.width ** -0.5
#         # Modified for 3D: positional embedding now handles cubic volume
#         self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size ** 3, self.width))
#         self.ln_pre = nn.LayerNorm(self.width)
#         self.transformer = nn.ModuleList([
#             ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
#             for _ in range(self.num_layers)
#         ])
#         self.ln_post = nn.LayerNorm(self.width)

#     def forward(self, x):
#         # Modified for 3D: reshape from 5D to sequence
#         h = x.shape[-1]  # depth dimension
#         x = rearrange(x, 'b c h w d -> b (h w d) c')  # flatten 3D volume to sequence

#         x = x + self.positional_embedding.to(x.dtype)
#         x = self.ln_pre(x)
#         x = x.permute(1, 0, 2)  # NLD -> LND
#         for transformer in self.transformer:
#             x = transformer(x)
#         x = x.permute(1, 0, 2)  # LND -> NLD
#         x = self.ln_post(x)

#         # Modified for 3D: reshape back to 5D
#         x = rearrange(
#             x, 
#             'b (h w d) c -> b c h w d', 
#             h=self.grid_size, 
#             w=self.grid_size, 
#             d=self.grid_size
#         )

#         return x

## Down/Up Sample Layers

In [15]:
# class Downsample(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

#     def forward(self, x):
#         pad = (0, 1, 0, 1, 0, 1)  # Padding for all 3 dimensions
#         x = torch.nn.functional.pad(x, pad, mode="constant", value=0) #   zero-pad one voxel at the trailing end of each of the three spatial dims
#         x = self.conv(x)
#         return x


# class Upsample(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

#     def forward(self, x):
#         x = F.interpolate(x, scale_factor=2.0, mode="nearest")
#         x = self.conv(x)

#         return x

## Res Block

In [16]:
# def normalize(in_channels):
#     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)   #   divides the channels into 32 groups, and normalizes each group. More effective for smaller batch size than batch norm

# @torch.jit.script
# def swish(x):
#     return x*torch.sigmoid(x)   #  swish activation function, compiled using torch.jit.script. Smooth, non-linear activation function, works better than ReLu in some cases. swish (x) = x * sigmoid(x)

# class ResBlock(nn.Module):
#     def __init__(self, in_channels, out_channels=None):
#         super(ResBlock, self).__init__()
#         self.in_channels = in_channels
#         self.out_channels = in_channels if out_channels is None else out_channels
#         self.norm1 = normalize(in_channels)
#         self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
#         self.norm2 = normalize(out_channels)
#         self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
#         self.conv_out = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

#     def forward(self, x_in):
#         x = x_in
#         x = self.norm1(x)
#         x = swish(x)
#         x = self.conv1(x)
#         x = self.norm2(x)
#         x = swish(x)
#         x = self.conv2(x)
#         if self.in_channels != self.out_channels:
#             x_in = self.conv_out(x_in)

#         return x + x_in

## Attention Block

In [17]:
# class AttnBlock(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.in_channels = in_channels

#         self.norm = normalize(in_channels)
#         # Convert all 2D convolutions to 3D
#         self.q = torch.nn.Conv3d(
#             in_channels,
#             in_channels,
#             kernel_size=1,
#             stride=1,
#             padding=0
#         )
#         self.k = torch.nn.Conv3d(
#             in_channels,
#             in_channels,
#             kernel_size=1,
#             stride=1,
#             padding=0
#         )
#         self.v = torch.nn.Conv3d(
#             in_channels,
#             in_channels,
#             kernel_size=1,
#             stride=1,
#             padding=0
#         )
#         self.proj_out = torch.nn.Conv3d(
#             in_channels,
#             in_channels,
#             kernel_size=1,
#             stride=1,
#             padding=0
#         )

#     def forward(self, x):
#         h_ = x
#         h_ = self.norm(h_)
#         q = self.q(h_)
#         k = self.k(h_)
#         v = self.v(h_)

#         # compute attention
#         b, c, h, w, d = q.shape
#         q = q.reshape(b, c, h*w*d)    # Flatten all spatial dimensions
#         q = q.permute(0, 2, 1)        # b, hwd, c
#         k = k.reshape(b, c, h*w*d)    # b, c, hwd
#         w_ = torch.bmm(q, k)          # b, hwd, hwd    
#         w_ = w_ * (int(c)**(-0.5))    # Scale dot products
#         w_ = F.softmax(w_, dim=2)     # Softmax over spatial positions

#         # attend to values
#         v = v.reshape(b, c, h*w*d)
#         w_ = w_.permute(0, 2, 1)      # b, hwd, hwd (first hwd of k, second of q)
#         h_ = torch.bmm(v, w_)         # b, c, hwd
#         h_ = h_.reshape(b, c, h, w, d) # Restore spatial structure

#         h_ = self.proj_out(h_)

#         return x + h_

## Encoder class 

In [18]:
# def Normalize(in_channels, norm_type='group'):
#     assert norm_type in ['group', 'batch']
#     if norm_type == 'group':
#         return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
#     elif norm_type == 'batch':
#         return nn.SyncBatchNorm(in_channels)
    
# def nonlinearity(x):
#     # swish
#     return x*torch.sigmoid(x)

In [19]:
# class Encoder(nn.Module):
#     def __init__(self, in_channels, nf, out_channels, ch_mult, num_res_blocks, resolution):
#         super().__init__()
#         self.nf = nf
#         self.num_resolutions = len(ch_mult)
#         self.num_res_blocks = num_res_blocks
#         self.resolution = resolution

#         self.conv_in = nn.Conv3d(in_channels, nf, kernel_size=3, stride=1, padding=1)

#         in_ch_mult = (1,) + tuple(ch_mult)

#         self.conv_blocks = nn.ModuleList()
#         for i_level in range(self.num_resolutions):
#             conv_block = nn.Module()
#             # res & attn
#             res_block = nn.ModuleList()
#             attn_block = nn.ModuleList()
#             block_in = nf * in_ch_mult[i_level]
#             block_out = nf * ch_mult[i_level]
#             for _ in range(self.num_res_blocks):
#                 res_block.append(ResBlock(block_in, block_out))
#                 block_in = block_out
#                 if i_level == self.num_resolutions - 1:
#                     attn_block.append(AttnBlock(block_in))
#             conv_block.res = res_block
#             conv_block.attn = attn_block
#             # downsample
#             if i_level != self.num_resolutions-1:
#                 conv_block.downsample = Downsample(block_in)
#             self.conv_blocks.append(conv_block)

#         # middle
#         self.mid = nn.ModuleList()
#         self.mid.append(ResBlock(block_in, block_in))
#         self.mid.append(AttnBlock(block_in))
#         self.mid.append(ResBlock(block_in, block_in))


#         if self.num_resolutions == 5:
#             down_factor = 16
#         elif self.num_resolutions == 4:
#             down_factor = 8
#         elif self.num_resolutions == 3:
#             down_factor = 4
#         else:
#             raise NotImplementedError
        
#         # semantic head
#         self.style_head = nn.ModuleList()
#         self.style_head.append(FactorizedAdapter(down_factor))

#         # structural details head
#         self.structure_head = nn.ModuleList()
#         self.structure_head.append(FactorizedAdapter(down_factor))

#         # end
#         self.norm_out_style = Normalize(block_in)
#         self.conv_out_style = nn.Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

#         self.norm_out_struct = Normalize(block_in)
#         self.conv_out_struct = nn.Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
#         # blocks = []
#         # # Initial convolution - now 3D
#         # blocks.append(
#         #     nn.Conv3d(
#         #         in_channels, 
#         #         nf, 
#         #         kernel_size=3, 
#         #         stride=1, 
#         #         padding=1
#         #     )
#         # )

#         # # Residual and downsampling blocks, with attention on specified resolutions
#         # for i in range(self.num_resolutions):
#         #     block_in_ch = nf * in_ch_mult[i]
#         #     block_out_ch = nf * ch_mult[i]
            
#         #     # Add ResBlocks
#         #     for _ in range(self.num_res_blocks):
#         #         blocks.append(ResBlock(block_in_ch, block_out_ch))
#         #         block_in_ch = block_out_ch
                
#         #         # Add attention if we're at the right resolution
#         #         if curr_res in attn_resolutions:
#         #             blocks.append(AttnBlock(block_in_ch))

#         #     # Add downsampling block if not the last resolution
#         #     if i != self.num_resolutions - 1:
#         #         blocks.append(Downsample(block_in_ch))
#         #         curr_res = curr_res // 2

#         # # Final blocks
#         # blocks.append(ResBlock(block_in_ch, block_in_ch))
#         # blocks.append(AttnBlock(block_in_ch))
#         # blocks.append(ResBlock(block_in_ch, block_in_ch))

#         # # Normalize and convert to latent size
#         # blocks.append(normalize(block_in_ch))
#         # blocks.append(
#         #     nn.Conv3d(
#         #         block_in_ch, 
#         #         out_channels, 
#         #         kernel_size=3, 
#         #         stride=1, 
#         #         padding=1
#         #     )
#         # )


#     def forward(self, x):
#         h = self.conv_in(x)
#         # downsampling
#         for i_level, block in enumerate(self.conv_blocks):
#             for i_block in range(self.num_res_blocks):
#                 h = block.res[i_block](h)
#                 if len(block.attn) > 0:
#                     h = block.attn[i_block](h)
#             if i_level != self.num_resolutions - 1:
#                 h = block.downsample(h)
        
#         # middle
#         for mid_block in self.mid:
#             h = mid_block(h)
#         h_style = h
#         h_struct = h

#         # style head
#         for blk in self.style_head:
#             h_style = blk(h_style)

#         h_style = self.norm_out_style(h_style)
#         h_style = nonlinearity(h_style)
#         h_style = self.conv_out_style(h_style)

#         # structure head
#         for blk in self.structure_head:
#             h_struct = blk(h_struct)

#         h_struct = self.norm_out_struct(h_struct)
#         h_struct = nonlinearity(h_struct)
#         h_struct = self.conv_out_struct(h_struct)

#         return h_style, h_struct
#         # for block in self.blocks:
#         #     x = block(x)
#         # return x

## Generator / Decoder class

In [20]:
# class Generator(nn.Module):
#     def __init__(self, H, z_channels=256):
#         super().__init__()
#         self.nf = H.nf
#         self.ch_mult = H.ch_mult
#         self.num_resolutions = len(self.ch_mult)
#         self.num_res_blocks = H.res_blocks
#         self.resolution = H.img_size
#         self.attn_resolutions = H.attn_resolutions
#         self.in_channels = H.emb_dim
#         self.out_channels = H.n_channels

#         block_in = self.nf * self.ch_mult[self.num_resolutions-1]


#         # z to block_in
#         # self.conv_in = nn.Conv3d(z_channels * 2, block_in, kernel_size=3, stride=1, padding=1)
#         #TODO: trying addition instead of concat
#         self.conv_in = nn.Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
#         # middle
#         self.mid = nn.ModuleList()
#         self.mid.append(ResBlock(block_in, block_in))
#         # self.mid.append(AttnBlock(block_in))
#         self.mid.append(ResBlock(block_in, block_in))

#         # upsampling
#         self.conv_blocks = nn.ModuleList()
#         for i_level in reversed(range(self.num_resolutions)):
#             conv_block = nn.Module()
#             # res & attn
#             res_block = nn.ModuleList()
#             attn_block = nn.ModuleList()
#             block_out = self.nf * self.ch_mult[i_level]
#             for _ in range(self.num_res_blocks + 1):
#                 res_block.append(ResBlock(block_in, block_out))
#                 block_in = block_out
#                 if i_level == self.num_resolutions - 1:
#                     attn_block.append(AttnBlock(block_in))
#             conv_block.res = res_block
#             conv_block.attn = attn_block
#             # upsample
#             if i_level != 0:
#                 conv_block.upsample = Upsample(block_in)
#             self.conv_blocks.append(conv_block)

#         # end
#         self.norm_out = Normalize(block_in)
#         self.conv_out = nn.Conv3d(block_in, H.n_channels, kernel_size=3, stride=1, padding=1)

#         # block_in_ch = self.nf * self.ch_mult[-1]
#         # curr_res = self.resolution // 2 ** (self.num_resolutions-1)

#         # print(f'resolution: {self.resolution}, num_resolutions: {self.num_resolutions}, '
#         #       f'num_res_blocks: {self.num_res_blocks}, attn_resolutions: {self.attn_resolutions}, '
#         #       f'in_channels: {self.in_channels}, out_channels: {self.out_channels}, '
#         #       f'block_in_ch: {block_in_ch}, curr_res: {curr_res}')

#         # blocks = []
#         # # Initial conv - now 3D
#         # blocks.append(nn.Conv3d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

#         # # Non-local attention block
#         # blocks.append(ResBlock(block_in_ch, block_in_ch))
#         # blocks.append(AttnBlock(block_in_ch))
#         # blocks.append(ResBlock(block_in_ch, block_in_ch))

#         # # Upsampling blocks
#         # for i in reversed(range(self.num_resolutions)):
#         #     block_out_ch = self.nf * self.ch_mult[i]

#         #     for _ in range(self.num_res_blocks):
#         #         blocks.append(ResBlock(block_in_ch, block_out_ch))
#         #         block_in_ch = block_out_ch

#         #         if curr_res in self.attn_resolutions:
#         #             blocks.append(AttnBlock(block_in_ch))

#         #     if i != 0:
#         #         blocks.append(Upsample(block_in_ch))
#         #         curr_res = curr_res * 2

#         # # Final processing
#         # blocks.append(normalize(block_in_ch))
#         # blocks.append(nn.Conv3d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

#         # self.blocks = nn.ModuleList(blocks)

#         # # Used for calculating ELBO - fine tuned after training
#         # self.logsigma = nn.Sequential(
#         #     nn.Conv3d(block_in_ch, block_in_ch, kernel_size=3, stride=1, padding=1),
#         #     nn.ReLU(),
#         #     nn.Conv3d(block_in_ch, H.n_channels, kernel_size=1, stride=1, padding=0)
#         # ).cuda()
#     @property
#     def last_layer(self):
#         return self.conv_out.weight
    
#     def forward(self, z):
#         #TODO: trying addition instead of concat
#         B, C, H, W, D = z.shape
#         z_style = z[:, :C//2]  # First half of channels
#         z_struct = z[:, C//2:]  # Second half of channels
        
#         # Add the vectors
#         z_add = z_style + z_struct

#         # z to block_in
#         h = self.conv_in(z_add)

#         # middle
#         for mid_block in self.mid:
#             h = mid_block(h)
        
#         # upsampling
#         for i_level, block in enumerate(self.conv_blocks):
#             for i_block in range(self.num_res_blocks + 1):
#                 h = block.res[i_block](h)
#                 if len(block.attn) > 0:
#                     h = block.attn[i_block](h)
#             if i_level != self.num_resolutions - 1:
#                 h = block.upsample(h)

#         # end
#         h = self.norm_out(h)
#         h = nonlinearity(h)
#         h = self.conv_out(h)

#         return h
#         # for block in self.blocks:
#         #     x = block(x)
#         # return x

#     # def probabilistic(self, x):
#     #     with torch.no_grad():
#     #         for block in self.blocks[:-1]:
#     #             x = block(x)
#     #         mu = self.blocks[-1](x)
#     #     logsigma = self.logsigma(x)
#     #     return mu, logsigma

## Discriminator class (patch-based; gives the likelihood that each patch is real or fake)

In [21]:
# class PatchGAN3DDiscriminator(nn.Module):
#     """3D PatchGAN discriminator adapted for Minecraft voxel data"""
#     def __init__(self, input_nc, ndf=64, n_layers=3):
#         """
#         Parameters:
#             input_nc (int)  -- number of input channels (block types)
#             ndf (int)       -- number of filters in first conv layer
#             n_layers (int)  -- number of conv layers
#         """
#         super().__init__()
#         norm_layer = nn.BatchNorm3d

        
#         use_bias = norm_layer != nn.BatchNorm3d

#         kw = 4  # kernel size
#         padw = 1
#         sequence = [
#             nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
#             nn.LeakyReLU(0.2, True)
#         ]
        
#         nf_mult = 1
#         nf_mult_prev = 1
#         for n in range(1, n_layers):
#             nf_mult_prev = nf_mult
#             nf_mult = min(2 ** n, 8)
#             sequence += [
#                 nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, 
#                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
#                 norm_layer(ndf * nf_mult),
#                 nn.LeakyReLU(0.2, True)
#             ]

#         nf_mult_prev = nf_mult
#         nf_mult = min(2 ** n_layers, 8)
#         sequence += [
#             nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult,
#                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
#             norm_layer(ndf * nf_mult),
#             nn.LeakyReLU(0.2, True)
#         ]

#         sequence += [
#             nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
#         ]

#         self.main = nn.Sequential(*sequence)
#         self.apply(self._init_weights)
    
#     def _init_weights(self, module):    
#         if isinstance(module, nn.Conv3d):
#             nn.init.normal_(module.weight.data, 0.0, 0.02)
#         elif isinstance(module, nn.BatchNorm3d):
#             nn.init.normal_(module.weight.data, 1.0, 0.02)
#             nn.init.constant_(module.bias.data, 0)

#     def forward(self, input):
#         return self.main(input)

In [22]:
# class Discriminator3D(nn.Module):
#     def __init__(self, input_nc=43, ndf=64, n_layers=3):
#         """Simple 3D convolutional discriminator
        
#         Args:
#             input_nc (int): Number of input channels (number of block types)
#             ndf (int): Number of filters in first conv layer
#             n_layers (int): Number of conv layers
#         """
#         super().__init__()
        
#         # Initial convolution
#         layers = [
#             nn.Conv3d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
#             nn.LeakyReLU(0.2, inplace=True)
#         ]
        
#         # Increasing number of filters with each layer
#         current_channels = ndf
#         for i in range(n_layers - 1):
#             next_channels = min(current_channels * 2, 512)
#             layers.extend([
#                 nn.Conv3d(current_channels, next_channels, 
#                          kernel_size=4, stride=2, padding=1),
#                 nn.BatchNorm3d(next_channels),
#                 nn.LeakyReLU(0.2, inplace=True)
#             ])
#             current_channels = next_channels
        
#         # Final layers
#         layers.extend([
#             nn.Conv3d(current_channels, current_channels,
#                      kernel_size=4, stride=1, padding=1),
#             nn.BatchNorm3d(current_channels),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv3d(current_channels, 1, kernel_size=4, stride=1, padding=1)
#         ])
        
#         self.model = nn.Sequential(*layers)
        
#     def forward(self, x):
#         """
#         Args:
#             x: Input tensor of shape [B, C, D, H, W]
#                 where C is number of block types (one-hot encoded)
#         Returns:
#             Tensor of shape [B, 1, D', H', W'] containing realness scores
#         """
#         return self.model(x)

## Biome "Tower"
Plays the same role as the CLIP vision tower used in the FQGAN paper.

In [23]:
# class BiomeFeatureModel(nn.Module):
#     def __init__(self, biome_classifier_path):
#         super().__init__()
#         self.biome_classifier = BiomeClassifier(
#             num_block_types=43,
#             num_biomes=13,
#             feature_dim=256
#         ).cuda()
        
#         # Load pretrained weights for biome classifier
#         self.biome_classifier.load_state_dict(torch.load(biome_classifier_path))
#         self.biome_classifier.eval()
#         for param in self.biome_classifier.parameters():
#             param.requires_grad = False


#     # @torch.no_grad()
#     def forward(self, inputs, style_features):
#         # Get biome features from real input
#         biome_features = self.biome_classifier.get_intermediate_features(inputs)  # [B, C, 6, 6, 6]
#         B, C, H, W, D = biome_features.shape
#         biome_features = biome_features.permute(0, 2, 3, 4, 1)  # [B, 6, 6, 6, C]
#         biome_features = biome_features.reshape(B, H*W*D, C)    # [B, 216, C]
        
#         # Normalize features
#         # biome_features = F.normalize(biome_features, dim=-1).detach()
#         # style_features = F.normalize(style_features, dim=-1)

#         biome_features = biome_features.detach()

#         # Print feature statistics
#         # print("Biome features mean/std:", biome_features.mean().item(), biome_features.std().item())
#         # print("Style features mean/std:", style_features.mean().item(), style_features.std().item())
        
#         # # Simple cosine similarity loss
#         # loss = 1 - F.cosine_similarity(style_features, biome_features, dim=-1).mean()
#         # or MSE loss
#         loss = F.mse_loss(style_features, biome_features)
#         return loss
#     # @torch.no_grad()
#     # def forward(self, inputs, style_features):
#     #     # Extract biome features from real input using pretrained classifier
#     #     biome_features = self.biome_classifier.get_intermediate_features(inputs)  # [B, C, 6, 6, 6]
#     #     B, C, H, W, D = biome_features.shape
#     #     biome_features = biome_features.permute(0, 2, 3, 4, 1)  # [B, 6, 6, 6, C]
#     #     biome_features = biome_features.reshape(B, H*W*D, C)    # [B, 216, C]
        
#     #     # Normalize both features
#     #     biome_features = F.normalize(biome_features, p=2, dim=-1).detach()  # Fixed target
#     #     style_features = F.normalize(style_features, p=2, dim=-1)  # These will be optimized
        
#     #     # Compute similarity matrix and InfoNCE loss
#     #     loss_mat = torch.bmm(style_features, biome_features.transpose(1, 2))
#     #     loss_mat = loss_mat.exp()
        
#     #     loss_diag = torch.diagonal(loss_mat, dim1=1, dim2=2)  # [B, 216]
#     #     loss_denom = loss_mat.sum(dim=2)  # [B, 216]
        
#     #     loss_InfoNCE = -(loss_diag / loss_denom).log().mean()
        
#     #     return loss_InfoNCE
#     # @torch.no_grad()
#     # def forward(self, inputs, semantic_feat):
#     #     # Get features from biome classifier (B, C, 6, 6, 6)
#     #     real_features = self.biome_classifier.get_intermediate_features(inputs)
        
#     #     # Print shapes for debugging
#     #     print("Before reshape:")
#     #     print("real_features:", real_features.shape)
#     #     print("semantic_feat:", semantic_feat.shape)
        
#     #     # Reshape to (B, N, C) where N = 6*6*6
#     #     B, C, H, W, D = real_features.shape
#     #     real_features = real_features.view(B, C, -1).permute(0, 2, 1)  # B, 216, C
        
#     #     print("After reshape:")
#     #     print("real_features:", real_features.shape)
#     #     print("semantic_feat:", semantic_feat.shape)
        
#     #     # Normalize both feature sets
#     #     real_features = F.normalize(real_features, p=2, dim=-1)
#     #     semantic_feat = F.normalize(semantic_feat, p=2, dim=-1)
        
#     #     # Matrix multiplication should be:
#     #     # (B, 216, C) @ (B, 216, C).transpose(-2, -1) -> (B, 216, 216)
#     #     loss_mat = (semantic_feat @ real_features.detach().mT)  # Use mT and proper dimensions
#     #     loss_diag = loss_mat.diag()
#     #     loss_denom = loss_mat.sum(1)
#     #     loss_InfoNCE = -(loss_diag / loss_denom).log().mean()
        
#     #     return loss_InfoNCE

## Dual Codebook Loss

In [24]:
# class VQLossDualCodebook(nn.Module):
#     def __init__(self, H):
#         super().__init__()
#         # Discriminator parameters
#         if H.disc_type == 'conv':
#             # Use simpler 3D discriminator
#             self.discriminator = Discriminator3D(
#                 input_nc=H.n_channels,  # number of block types
#                 ndf=H.ndf  # base number of filters
#             ).cuda()
#         elif H.disc_type == 'patch':
#             self.discriminator = PatchGAN3DDiscriminator(
#                 input_nc=H.n_channels,
#                 ndf=H.ndf,
#                 n_layers=H.disc_layers
#             ).cuda()
        
#         # Initialize BiomeClassifier for feature extraction
#         self.biome_feature_model = BiomeFeatureModel(H.biome_classifier_path)
        
#         # Loss weights and parameters
#         self.disc_start_step = H.disc_start_step
#         self.disc_weight_max = H.disc_weight_max
#         self.disc_weight_min = 0.0
#         self.disc_adaptive_weight = H.disc_adaptive_weight
#         self.reconstruction_weight = H.reconstruction_weight
#         self.codebook_weight = H.codebook_weight
#         self.biome_weight = H.biome_weight
#         self.disentanglement_ratio = H.disentanglement_ratio
        
#         # Loss functions
#         self.disc_loss = non_saturating_d_loss
#         self.gen_loss = hinge_gen_loss

#     def adopt_weight(self, weight, global_step, threshold=0, value=0.):
#         """Gradually adopt weight after threshold step"""
#         if global_step < threshold:
#             weight = value
#         return weight
    
#     def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
#         """Dynamically adjust discriminator weight to balance with other losses"""
#         if last_layer is not None:
#             nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
#             g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
            
#             d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
#             d_weight = torch.clamp(d_weight, self.disc_weight_min, self.disc_weight_max).detach()
#             return d_weight
#         return 1.0

    
#     def compute_biome_feature_loss(self, inputs, reconstructions):
#         """Compute feature matching loss using BiomeClassifier's intermediate features"""
#         with torch.no_grad():
#             real_features = self.biome_classifier.get_intermediate_features(inputs)
#         fake_features = self.biome_classifier.get_intermediate_features(reconstructions)
        
#         # Normalize features
#         real_features = F.normalize(real_features, p=2, dim=1)
#         fake_features = F.normalize(fake_features, p=2, dim=1)
        
#         # Compute feature matching loss
#         feature_loss = F.mse_loss(fake_features, real_features)
        
#         return feature_loss

#     def forward(self, codebook_loss_style, codebook_loss_struct,
#                 inputs, reconstructions, disentangle_loss, biome_feat,
#                 optimizer_idx, global_step, last_layer=None):
        
#         if optimizer_idx == 0:
#             rec_loss = F.cross_entropy(
#                 reconstructions.contiguous(), 
#                 torch.argmax(inputs, dim=1).contiguous()
#             ) * self.reconstruction_weight
                        
#             # Codebook losses
#             style_loss = sum(codebook_loss_style[:3]) * self.codebook_weight
#             struct_loss = sum(codebook_loss_struct[:3]) * self.codebook_weight
            
#             # Biome feature loss using InfoNCE
#             if biome_feat is not None:
#                 biome_feat_loss = self.biome_feature_model(inputs.contiguous(), biome_feat)
#                 biome_feat_loss = self.biome_weight * biome_feat_loss
#             else:
#                 biome_feat_loss = 0.0

#             # Check if biome loss is being scaled properly
#             # print("Raw biome loss:", biome_feat_loss.item())
#             # print("Scaled biome loss:", (self.biome_weight * biome_feat_loss).item())
            
#             # Disentanglement loss
#             disent_loss = self.disentanglement_ratio * disentangle_loss if disentangle_loss is not None else 0.0
            
#             # Generator adversarial loss with adaptive weight
#             disc_weight = self.adopt_weight(self.disc_weight_max, global_step, 
#                                           threshold=self.disc_start_step, value=0.0)
            
#             logits_fake = self.discriminator(reconstructions.contiguous())
#             g_loss = self.gen_loss(logits_fake)
            
#             if self.disc_adaptive_weight:
#                 null_loss = rec_loss + biome_feat_loss
#                 disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, g_loss, last_layer)
#                 g_loss = g_loss * disc_weight * disc_adaptive_weight
#             else:
#                 g_loss = g_loss * disc_weight
            
#             # Total loss
#             loss = rec_loss + style_loss + struct_loss + biome_feat_loss + disent_loss + g_loss
            
#             return {
#                 'loss': loss,
#                 'rec_loss': rec_loss,
#                 'style_loss': style_loss,
#                 'struct_loss': struct_loss,
#                 'biome_feat_loss': biome_feat_loss,
#                 'disent_loss': disent_loss,
#                 'g_loss': g_loss,
#                 'disc_weight': disc_weight,
#                 'disc_adaptive_weight': disc_adaptive_weight if self.disc_adaptive_weight else 1.0,
#                 'codebook_usage_style': codebook_loss_style[3],
#                 'codebook_usage_struct': codebook_loss_struct[3]
#             }
            
#         # Discriminator update
#         elif optimizer_idx == 1:
#             # Get discriminator predictions
#             logits_real = self.discriminator(inputs.contiguous().detach())
#             logits_fake = self.discriminator(reconstructions.contiguous().detach())

#             # Calculate discriminator loss with weight adoption
#             disc_weight = self.adopt_weight(self.disc_weight_max, global_step, 
#                                           threshold=self.disc_start_step, value=0.0)
#             d_loss = self.disc_loss(logits_real, logits_fake) * disc_weight
            
#             return {
#                 'd_loss': d_loss,
#                 'logits_real': logits_real.mean(),
#                 'logits_fake': logits_fake.mean(),
#                 'disc_weight': disc_weight
#             }
        

# def hinge_d_loss(logits_real, logits_fake):
#     loss_real = torch.mean(F.relu(1. - logits_real))
#     loss_fake = torch.mean(F.relu(1. + logits_fake))

#     d_loss = 0.5 * (loss_real + loss_fake)
#     return d_loss


# def vanilla_d_loss(logits_real, logits_fake):
#     loss_real = torch.mean(F.softplus(-logits_real))
#     loss_fake = torch.mean(F.softplus(logits_fake))
#     d_loss = 0.5 * (loss_real + loss_fake)
#     return d_loss


# def non_saturating_d_loss(logits_real, logits_fake):
#     loss_real = torch.mean(F.binary_cross_entropy_with_logits(torch.ones_like(logits_real),  logits_real))
#     loss_fake = torch.mean(F.binary_cross_entropy_with_logits(torch.zeros_like(logits_fake), logits_fake))
#     d_loss = 0.5 * (loss_real + loss_fake)
#     return d_loss

# def hinge_gen_loss(logits_fake):
#     return -torch.mean(logits_fake)

# def vanilla_d_loss(logits_real, logits_fake):
#     loss_real = torch.mean(F.softplus(-logits_real))
#     loss_fake = torch.mean(F.softplus(logits_fake))
#     d_loss = 0.5 * (loss_real + loss_fake)
#     return d_loss

## FQVAE class - composed of an encoder, a quantizer, and a decoder (aka generator)

In [25]:

# def _expand_token(token, batch_size: int):
#     return token.unsqueeze(0).expand(batch_size, -1, -1)

# class FeatPredHead(nn.Module):
#     def __init__(self, input_dim=256, down_factor=16):
#         super().__init__()
#         self.grid_size = 24 // down_factor
#         self.width = 256
#         self.num_layers = 3
#         self.num_heads = 8

#         self.upscale = nn.Sequential(
#             nn.Linear(input_dim, self.width),
#             nn.ReLU(),
#             nn.Linear(self.width, self.width)
#         )

#         scale = self.width ** -0.5
#         # Remove class embedding
#         self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size ** 3, self.width))
#         self.ln_pre = nn.LayerNorm(self.width)
#         self.transformer = nn.ModuleList([
#             ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
#             for _ in range(self.num_layers)
#         ])
#         self.ln_post = nn.LayerNorm(self.width)

#     def forward(self, x):
#         x = rearrange(x, 'b c h w d -> b (h w d) c')
#         x = self.upscale(x)

#         # No class token addition
#         x = x + self.positional_embedding.to(x.dtype)
#         x = self.ln_pre(x)
#         x = x.permute(1, 0, 2)
#         for layer in self.transformer:
#             x = layer(x)
#         x = x.permute(1, 0, 2)
#         x = self.ln_post(x)
#         return x  # Shape will be [B, 216, C]
    
# class FQModel(nn.Module):
#     def __init__(self, H):
#         super().__init__()
#         # Basic parameters
#         self.in_channels = H.n_channels
#         self.nf = H.nf
#         self.n_blocks = H.res_blocks
#         self.codebook_size = H.codebook_size
#         self.embed_dim = H.emb_dim
#         self.ch_mult = H.ch_mult
#         self.num_resolutions = len(self.ch_mult)
#         self.resolution = H.img_size
#         self.z_channels = H.z_channels
#         self.with_biome_supervision = H.with_biome_supervision
#         self.with_disentanglement = H.with_disentanglement
#         self.disentanglement_ratio = H.disentanglement_ratio
        
#         # Two head encoder
#         self.encoder = Encoder(
#             self.in_channels,
#             self.nf,
#             self.embed_dim,
#             self.ch_mult,
#             self.n_blocks,
#             self.resolution
#         )

#         # Quantizer for style head (semantic)
#         self.quantize_style = FQVectorQuantizer(
#             # self.codebook_size, 
#             20, 
#             self.embed_dim,
#             H.beta, 
#             H.entropy_loss_ratio,
#             H.codebook_l2_norm, 
#             H.codebook_show_usage
#         )
#         self.quant_conv_style = nn.Conv3d(self.z_channels, self.embed_dim, 1)

#         # Quantizer for structural head (visual)
#         self.quantize_struct = FQVectorQuantizer(
#             self.codebook_size, 
#             self.embed_dim,
#             H.beta, 
#             H.entropy_loss_ratio,
#             H.codebook_l2_norm, 
#             H.codebook_show_usage
#         )
#         self.quant_conv_struct = nn.Conv3d(self.z_channels, self.embed_dim, 1)

#         # Pixel decoder
#         input_dim = self.embed_dim * 2  # Combined dimension from both codebooks
#         self.post_quant_conv = nn.Conv3d(input_dim, self.z_channels, 1)
#         self.decoder = Generator(H)

#         # Determine downsampling factor
#         if self.num_resolutions == 5:
#             down_factor = 16
#         elif self.num_resolutions == 4:
#             down_factor = 8
#         elif self.num_resolutions == 3:
#             down_factor = 4
#         else:
#             raise NotImplementedError

#         # Biome prediction head for style representation learning
#         if H.with_biome_supervision:
#             print("Include feature prediction head for biome supervision")
#             self.feat_pred_head = FeatPredHead(input_dim=self.embed_dim, down_factor=down_factor)
#         else:
#             print("NO biome supervision")

#         if H.with_disentanglement:
#             print("Disentangle Ratio: ", H.disentanglement_ratio)
#         else:
#             print("No Disentangle Regularization")

#     def compute_disentangle_loss(self, quant_struct, quant_style):
#         # Reshape from 5D to 2D
#         quant_struct = rearrange(quant_struct, 'b c h w d -> (b h w d) c')
#         quant_style = rearrange(quant_style, 'b c h w d -> (b h w d) c')

#         # Normalize the vectors
#         quant_struct = F.normalize(quant_struct, p=2, dim=-1)
#         quant_style = F.normalize(quant_style, p=2, dim=-1)

#         # Compute dot product and loss
#         dot_product = torch.sum(quant_struct * quant_style, dim=1)
#         loss = torch.mean(dot_product ** 2) * self.disentanglement_ratio

#         return loss

#     def forward(self, input):
#         # Get both style and structure encodings
#         h_style, h_struct = self.encoder(input)
#         h_style = self.quant_conv_style(h_style)
#         h_struct = self.quant_conv_struct(h_struct)

#         # Quantize both paths
#         quant_style, emb_loss_style, _ = self.quantize_style(h_style)
#         # print("quant_style requires grad:", quant_style.requires_grad)

#         quant_struct, emb_loss_struct, _ = self.quantize_struct(h_struct)
        

#         # Biome feature prediction if enabled
#         if self.with_biome_supervision:
#             style_feat = self.feat_pred_head(quant_style)
#             # print("style_feat requires grad:", style_feat.requires_grad)
#             style_feat.retain_grad()  # Add this line

#         else:
#             style_feat = None

#         # Compute disentanglement loss if enabled
#         if self.with_disentanglement:
#             disentangle_loss = self.compute_disentangle_loss(quant_struct, quant_style)
#         else:
#             disentangle_loss = 0

#         # Combine quantized representations and decode
#         quant = torch.cat([quant_struct, quant_style], dim=1)
#         dec = self.decoder(quant)

#         return dec, emb_loss_style, emb_loss_struct, disentangle_loss, style_feat

# Hyperparams

In [26]:
class HparamsBase(dict):
    """Dictionary of hyperparameters whose entries double as attributes.

    ``H.lr`` and ``H['lr']`` address the same entry. Reading an ordinary
    attribute that was never set yields ``None`` instead of raising, so
    optional settings can be probed with ``if H.some_flag:``.
    """

    def __init__(self, dataset):
        """Create the container and record the dataset name under 'dataset'."""
        self.dataset = dataset

    def __getattr__(self, attr):
        """Return the stored value for ``attr``, or ``None`` if it is unset.

        Python only calls this after normal attribute lookup has failed.

        Raises:
            AttributeError: for double-underscore names. Those lookups are
                protocol probes (``__deepcopy__``, ``__getstate__``,
                ``__setstate__``, ...), not hyperparameter reads; answering
                them with ``None`` makes ``hasattr`` lie and lets callers try
                to invoke ``None``.
        """
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, attr, value):
        """Store ``value`` as the dict item ``attr`` (nothing goes in ``__dict__``)."""
        self[attr] = value

class HparamsFQGAN(HparamsBase):
    """Default hyperparameters for the dual-codebook FQGAN.

    Defaults exist only for the 'minecraft' dataset; any other dataset name
    raises ``KeyError``.
    """

    def __init__(self, dataset):
        """Fill in the defaults for ``dataset``.

        Raises:
            KeyError: if ``dataset`` is anything other than 'minecraft'.
        """
        super().__init__(dataset)

        # Guard clause: unknown datasets get no defaults at all.
        if self.dataset != 'minecraft':
            raise KeyError(f'Defaults not defined for dataset: {self.dataset}')

        # Values that other entries are derived from.
        embedding_dim = 8
        channel_multipliers = [1, 2, 4]

        # HparamsBase keeps attributes as dict items, so each update() below is
        # equivalent to assigning the listed names as attributes, in order.

        # --- Data and base network ---
        self.update({
            'batch_size': 8,
            'img_size': 24,
            'n_channels': 42,
            'nf': 64,
            'ndf': 64,
            'res_blocks': 2,
            'latent_shape': [1, 6, 6, 6],
        })

        # --- Structure consistency regularization ---
        self.update({
            'with_struct_consistency': False,        # on/off switch
            'struct_consistency_weight': 2.0,        # loss weight
            'struct_consistency_threshold': 0.72,    # similarity threshold
        })

        # --- Cycle consistency ---
        self.update({
            'with_cycle_consistency': True,
            'cycle_consistency_type': 'post_quant_conv',
            'cycle_consistency_weight': 0.0,
            'disc_gumbel_for_cycle_input': True,
            'cycle_start_step': 0,
        })

        # --- Dual codebook architecture ---
        self.update({
            'struct_codebook_size': 32,              # structure codebook size
            'style_codebook_size': 16,               # style codebook size
            'emb_dim': embedding_dim,                # embedding dimension
            'z_channels': 8,                         # bottleneck channels
            'ch_mult': channel_multipliers,          # channel multipliers for progressive downsampling
            'num_resolutions': len(channel_multipliers),
            'attn_resolutions': [6],                 # resolutions at which to apply attention
        })

        # --- Loss weights ---
        self.update({
            'disc_type': 'conv',
            'disc_weight_max': 1.0,
            'disc_weight': 0.5,
            'disc_weight_min': 0.0,
            'disc_adaptive_weight': False,           # enable adaptive weighting
            'reconstruction_weight': 1.0,
            'codebook_weight': 1.0,
            'biome_weight': 1.0,                     # biome feature prediction
            'disentanglement_ratio': 0.5,            # disentanglement loss
        })

        # --- Quantizer / codebook ---
        self.update({
            'quantizer_type': 'ema',
            'beta': 0.5,                             # commitment loss coefficient
            'entropy_loss_ratio': 0.2,               # codebook entropy regularization (alternative kept by the author: 0.05)
            'codebook_l2_norm': True,                # L2-normalize codebook entries
            'codebook_show_usage': True,             # track codebook usage statistics
            'ema_decay': 0.99,
        })

        # --- Optimization schedule ---
        self.update({
            'lr': 1e-4,
            'beta1': 0.9,                            # Adam beta1
            'beta2': 0.95,                           # Adam beta2
            'disc_layers': 1,                        # number of discriminator layers
            'train_steps': 15000,
            'disc_start_step': 15000,                # step at which discriminator training starts
            'start_step': 0,
            'transformer_dim': embedding_dim,        # kept equal to emb_dim
            'num_heads': 8,                          # attention heads
        })

        # --- Auxiliary objectives ---
        self.update({
            'with_biome_supervision': False,
            'with_disentanglement': True,
        })

        # --- Logging / validation cadence ---
        self.update({
            'steps_per_log': 150,
            'steps_per_checkpoint': 1000,
            'steps_per_display_output': 500,
            'steps_per_save_output': 500,
            'steps_per_validation': 150,
            'val_samples_to_save': 16,
            'val_samples_to_display': 4,
            'visdom_port': 8097,
        })

        # --- Two-stage decoder ---
        self.update({
            'binary_reconstruction_weight': 3,
            'two_stage_decoder': True,
            'use_dumb_decoder': False,
            'combine_method': 'concat',
            'detach_binary_recon': True,
        })

        # --- Weighted reconstruction loss ---
        self.update({
            'block_weighting': True,
            'weighted_block_amount': 3.0,
            'weighted_block_indices': [],
        })

        # --- Biome classifier ---
        self.update({
            'num_biomes': 11,                        # number of biome classes
            'biome_feat_dim': 256,                   # biome feature dimension
            'biome_classifier_path': 'best_biome_classifier_airprocessed.pt',
        })

        # --- Discriminator input sampling ---
        self.update({
            'disc_gumbel': True,
            'gumbel_tau': 1,
            'gumbel_hard': True,
            'disc_argmax_ste': False,
        })

        # --- Miscellaneous ---
        self.update({
            'padding_mode': 'reflect',
            'weight_decay': 0.01,
        })

# Validation helper

In [27]:
def encode_and_quantize(fqgan, terrain_chunks, device='cuda'):
    """Encode terrain chunks into style and structure codebook index grids.

    Args:
        fqgan: Model exposing ``encoder``, ``quant_conv_style``,
            ``quant_conv_struct``, ``quantize_style`` and ``quantize_struct``.
        terrain_chunks: Input tensor handed to ``fqgan.encoder`` after being
            moved to ``device``.
        device: Device on which encoding runs.

    Returns:
        ``(style_indices, struct_indices)``, both on the CPU. Each grid is
        shaped from dimensions 0, 2, 3 and 4 of the corresponding projected
        feature map.

    Side effects:
        Switches ``fqgan`` to eval mode (and leaves it there) and calls
        ``torch.cuda.empty_cache()``.
    """

    def _index_grid(raw_features, projection, quantizer):
        # Project the encoder head, quantize it, and lay the flat code indices
        # back out on the latent grid. The indices are element 2 of the
        # quantizer's third return value.
        projected = projection(raw_features)
        _, _, stats = quantizer(projected)
        dims = projected.size()
        return stats[2].view((dims[0], dims[2], dims[3], dims[4]))

    fqgan.eval()
    with torch.no_grad():
        style_features, struct_features = fqgan.encoder(terrain_chunks.to(device))

        # Style branch first; drop its features before starting the next branch.
        style_grid = _index_grid(style_features, fqgan.quant_conv_style, fqgan.quantize_style)
        del style_features

        struct_grid = _index_grid(struct_features, fqgan.quant_conv_struct, fqgan.quantize_struct)
        del struct_features

        # Only the small index grids leave the device.
        style_grid, struct_grid = style_grid.cpu(), struct_grid.cpu()

        torch.cuda.empty_cache()

    return style_grid, struct_grid

# Takes style and structure indices, returns the reconstructed map
def decode_from_indices(style_indices, struct_indices, fqgan, device='cuda', two_stage=False, return_raw=False):
    """Decode style/structure codebook indices back into a terrain chunk.

    Args:
        style_indices: Tensor of style codebook indices with a leading batch
            dimension followed by the latent grid dimensions.
        struct_indices: Structure codebook indices in the same layout.
        fqgan: Model exposing ``quantize_style``, ``quantize_struct``,
            ``embed_dim`` and ``decoder``.
        device: Device on which decoding runs.
        two_stage: Pass True when ``fqgan.decoder`` returns a
            ``(decoded, binary_decoded)`` pair instead of a single tensor.
        return_raw: If True, return the decoder's main output unchanged (still
            on ``device``); the binary output is not returned in that case.

    Returns:
        With ``return_raw=False``: the decoded output on the CPU, arg-maxed over
        dimension 1 when that dimension is larger than 1 and with the batch
        dimension squeezed when it has size 1. With ``two_stage=True`` the
        result is a ``(result, binary_result)`` pair.

    Note:
        The model's train/eval mode is not changed here.
    """
    with torch.no_grad():
        # Move indices to device only when needed
        style_indices = style_indices.to(device)
        struct_indices = struct_indices.to(device)

        # Every index in the batch is flattened, so the target shape has to carry
        # the real batch size; a hard-coded 1 only matches single-chunk input.
        # reshape (not view) also accepts non-contiguous index tensors.
        quant_style = fqgan.quantize_style.get_codebook_entry(
            style_indices.reshape(-1),
            shape=[style_indices.shape[0], fqgan.embed_dim, *style_indices.shape[1:]]
        )
        quant_struct = fqgan.quantize_struct.get_codebook_entry(
            struct_indices.reshape(-1),
            shape=[struct_indices.shape[0], fqgan.embed_dim, *struct_indices.shape[1:]]
        )

        # Clear indices from GPU
        del style_indices, struct_indices

        # Channel order is structure first, then style.
        quant = torch.cat([quant_struct, quant_style], dim=1)
        del quant_style, quant_struct

        if two_stage:
            decoded, binary_decoded = fqgan.decoder(quant)
        else:
            decoded = fqgan.decoder(quant)

        del quant

        if return_raw:
            return decoded

        # Convert to block IDs if one-hot encoded
        if decoded.shape[1] > 1:
            decoded = torch.argmax(decoded, dim=1)

        # squeeze(0) drops the batch dimension only when it has size 1
        result = decoded.squeeze(0).cpu()
        if two_stage:
            binary_result = binary_decoded.squeeze(0).cpu()
            del decoded
            del binary_decoded
            torch.cuda.empty_cache()
            return result, binary_result

        del decoded
        torch.cuda.empty_cache()

        return result

# def validate_and_save(model, plotted_loss_terms, vq_loss, val_loader, block_converter, visualizer, vis, step, H):
#     """Run validation and save results"""
#     model.eval()
#     val_losses = {term: [] for term in plotted_loss_terms}
#     val_reconstructions = []
#     val_originals = []
    
#     with torch.no_grad():
#         for batch in val_loader:
#             if isinstance(batch, list):
#                 x = batch[0]
#             else:
#                 x = batch
#             x = x.cuda()
            
#             # Forward pass
#             recons, codebook_loss_style, codebook_loss_struct, disentangle_loss, biome_feat = model(x)
            
#             # Calculate losses
#             loss_dict = vq_loss(
#                 codebook_loss_style, codebook_loss_struct,
#                 x, recons, disentangle_loss, biome_feat,
#                 optimizer_idx=0,
#                 global_step=step
#             )
            
#             # Store losses
#             for term in plotted_loss_terms:
#                 if term in loss_dict:
#                     val_losses[term].append(loss_dict[term].item() if torch.is_tensor(loss_dict[term]) else loss_dict[term])
            
#             # Store reconstructions and originals
#             val_reconstructions.append(recons.cpu())
#             val_originals.append(x.cpu())
            
#             # Break only after we have enough samples
#             total_samples = sum(r.size(0) for r in val_reconstructions)
#             if total_samples >= H.val_samples_to_save:
#                 break
    
#     # Concatenate all samples
#     val_reconstructions = torch.cat(val_reconstructions, dim=0)
#     val_originals = torch.cat(val_originals, dim=0)
    
#     # Trim to exact number of samples needed
#     val_reconstructions = val_reconstructions[:H.val_samples_to_save]
#     val_originals = val_originals[:H.val_samples_to_save]
    
#     val_data = {
#         'reconstructions': val_reconstructions,
#         'originals': val_originals,
#         'losses': {k: np.mean(v) if v else 0 for k, v in val_losses.items()},
#         'step': step
#     }
    
#     # Display samples
#     display_samples = min(H.val_samples_to_display, val_reconstructions.size(0))
#     orig_blocks = block_converter.convert_to_original_blocks(val_data['originals'][:display_samples])
#     recon_blocks = block_converter.convert_to_original_blocks(val_data['reconstructions'][:display_samples])
    
#     log_dir = f"../model_logs/{H.log_dir}/val_images"
#     os.makedirs(log_dir, exist_ok=True)
    
#     # Display and save validation visualizations
#     display_minecraft_pyvista(
#         vis, visualizer, orig_blocks,
#         win_name='Val Original Maps',
#         title=f'Val Original Maps step {step}',
#         save_path=f"{log_dir}/val_orig_{step}.png"
#     )
#     display_minecraft_pyvista(
#         vis, visualizer, recon_blocks,
#         win_name='Val Reconstructed Maps',
#         title=f'Val Reconstructed Maps step {step}',
#         save_path=f"{log_dir}/val_recon_{step}.png"
#     )
    
#     # print(f"Batch size: {H.batch_size}")
#     # print(f"Val reconstructions shape: {val_reconstructions.shape}")
#     # print(f"Val originals shape: {val_originals.shape}")

#     # Plot validation losses in Visdom
#     for term in plotted_loss_terms:
#         if term in val_losses and val_losses[term]:  # Only plot if we have values
#             vis.line(
#                 Y=np.array([np.mean(val_losses[term])]),  # Take mean of batch losses
#                 X=np.array([step]),
#                 win=f'val_{term}_plot',
#                 update='append',
#                 opts=dict(title=f'Validation {term} Loss')
#             )
    
#     model.train()
#     return val_losses

def _to_float(value):
    """Return `value` as a Python float, unwrapping tensors via ``.item()``."""
    return value.item() if torch.is_tensor(value) else float(value)


def _module_device(module):
    """Return the device of `module`'s first parameter.

    Falls back to CUDA when available (CPU otherwise) if the object exposes no
    parameters.
    """
    try:
        return next(module.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def validate_and_save(model, plotted_loss_terms, vq_loss, val_loader, block_converter, visualizer, vis, step, H, visual=False):
    """Run validation, plot mean losses to Visdom and optionally render samples.

    Batches are encoded, quantized with both codebooks, decoded, and scored with
    the weighted cross-entropy reconstruction loss (plus the binary air/non-air
    loss when the decoder is two-stage). Iteration stops once at least
    ``H.val_samples_to_save`` samples have been processed.

    Args:
        model: model exposing ``encoder``, ``quant_conv_style``,
            ``quant_conv_struct``, ``quantize_style``, ``quantize_struct``,
            ``decoder`` and the ``two_stage_decoder`` flag.
        plotted_loss_terms: loss names to track; each one with at least one
            value gets a point appended to the ``val_<term>_plot`` window.
        vq_loss: loss object providing ``weights``, ``reconstruction_weight``
            and ``binary_recon_weight``.
        val_loader: iterable of batches (a tensor, or a list whose first item
            is the tensor).
        block_converter: converter used to look up the air block index and to
            map tensors back to original block ids for rendering.
        visualizer: forwarded to ``display_minecraft_pyvista``.
        vis: Visdom client.
        step: current training step (x value of the plotted points).
        H: hyper-parameters (``val_samples_to_save``,
            ``val_samples_to_display``, ``log_dir``).
        visual: when True, render and save original/reconstructed samples.

    Returns:
        dict mapping each loss term to the list of per-batch values (floats).
        ``'rec_loss'`` and ``'binary_rec_loss'`` are always present.
    """
    val_losses = {term: [] for term in plotted_loss_terms}
    # These two are always recorded, even if the caller does not plot them.
    for required_term in ('rec_loss', 'binary_rec_loss'):
        val_losses.setdefault(required_term, [])

    device = _module_device(model)
    class_weights = getattr(vq_loss, 'weights', None)
    if class_weights is not None:
        class_weights = class_weights.to(device)

    # Taken from the converter that is passed in, not from a notebook global.
    air_idx = block_converter.get_air_block_index() if model.two_stage_decoder else None

    reconstructions, originals = [], []
    samples_seen = 0

    model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                x = batch[0] if isinstance(batch, list) else batch
                x = x.to(device)
                target_idx = torch.argmax(x, dim=1)

                # Forward pass
                h_style_raw, h_struct_raw = model.encoder(x)
                h_style = model.quant_conv_style(h_style_raw)
                h_struct = model.quant_conv_struct(h_struct_raw)
                quant_style, _, _ = model.quantize_style(h_style)
                quant_struct, _, _ = model.quantize_struct(h_struct)
                quant = torch.cat([quant_struct, quant_style], dim=1)

                if model.two_stage_decoder:
                    dec_logits, binary_out = model.decoder(quant)
                else:
                    # Single-stage decoders have no binary head.
                    dec_logits, binary_out = model.decoder(quant), None

                rec_loss = F.cross_entropy(
                    dec_logits.contiguous(),
                    target_idx.contiguous().long(),
                    weight=class_weights
                ) * vq_loss.reconstruction_weight

                # Binary reconstruction loss (two-stage mode only)
                binary_recon_loss = 0.0
                if binary_out is not None:
                    binary_target = (target_idx != air_idx).float()
                    binary_recon_loss = F.binary_cross_entropy(
                        binary_out.squeeze(1),
                        binary_target
                    ) * vq_loss.binary_recon_weight

                val_losses['rec_loss'].append(_to_float(rec_loss))
                val_losses['binary_rec_loss'].append(_to_float(binary_recon_loss))

                # Samples are only needed for rendering.
                if visual:
                    reconstructions.append(dec_logits.cpu())
                    originals.append(x.cpu())

                # Break only after we have enough samples
                samples_seen += x.size(0)
                if samples_seen >= H.val_samples_to_save:
                    break

        if visual and reconstructions:
            # Concatenate and trim to the exact number of samples needed
            val_reconstructions = torch.cat(reconstructions, dim=0)[:H.val_samples_to_save]
            val_originals = torch.cat(originals, dim=0)[:H.val_samples_to_save]

            display_samples = min(H.val_samples_to_display, val_reconstructions.size(0))
            orig_blocks = block_converter.convert_to_original_blocks(val_originals[:display_samples])
            recon_blocks = block_converter.convert_to_original_blocks(val_reconstructions[:display_samples])

            log_dir = f"../model_logs/{H.log_dir}/val_images"
            os.makedirs(log_dir, exist_ok=True)

            # Display and save validation visualizations
            display_minecraft_pyvista(
                vis, visualizer, orig_blocks,
                win_name='Val Original Maps',
                title=f'Val Original Maps step {step}',
                save_path=f"{log_dir}/val_orig_{step}.png"
            )
            display_minecraft_pyvista(
                vis, visualizer, recon_blocks,
                win_name='Val Reconstructed Maps',
                title=f'Val Reconstructed Maps step {step}',
                save_path=f"{log_dir}/val_recon_{step}.png"
            )

        # Plot validation losses in Visdom (only terms that have values)
        for term in plotted_loss_terms:
            if val_losses[term]:
                vis.line(
                    Y=np.array([np.mean(val_losses[term])]),  # mean of batch losses
                    X=np.array([step]),
                    win=f'val_{term}_plot',
                    update='append',
                    opts=dict(title=f'Validation {term} Loss')
                )
    finally:
        # Always hand the model back in training mode, even if validation failed.
        model.train()
    return val_losses

# Train

In [28]:
# file for running the training of the VQGAN
import torch
import numpy as np
import copy
import time
import random
from log_utils import log, log_stats, save_model, save_stats, save_images, save_maps, \
                            display_images, set_up_visdom, config_log, start_training_log, log_hparams_to_json
import visdom


In [29]:
import gc

# Run a Python garbage-collection pass first, then ask PyTorch to empty its
# CUDA cache, so memory left over from earlier cells is released before the
# model is built.
gc.collect()

torch.cuda.empty_cache()

In [30]:
# Hyper-parameters for this run.
H = HparamsFQGAN(dataset='minecraft')
# The same run name is used for both directories.
# NOTE(review): load_dir is presumably where checkpoints are resumed from - confirm.
H.log_dir = H.load_dir = 'FQGAN_2stagedecoder_logweighted3_32stru16sty_notdumb_001wd2'

# Visdom client used for all live plots in this notebook.
vis = visdom.Visdom(port=H.visdom_port)

# File logging: configure it, write the banner, then record the hyper-parameters.
config_log(H.log_dir)
log('---------------------------------')
log(f'Setting up training for VQGAN on {H.dataset}')
start_training_log(H)
log_hparams_to_json(H, H.log_dir)



Setting up a new session...


---------------------------------
Setting up training for VQGAN on minecraft
Using following hparams:
> attn_resolutions: [6]
> batch_size: 8
> beta: 0.5
> beta1: 0.9
> beta2: 0.95
> binary_reconstruction_weight: 3
> biome_classifier_path: best_biome_classifier_airprocessed.pt
> biome_feat_dim: 256
> biome_weight: 1.0
> block_weighting: True
> ch_mult: [1, 2, 4]
> codebook_l2_norm: True
> codebook_show_usage: True
> codebook_weight: 1.0
> combine_method: concat
> cycle_consistency_type: post_quant_conv
> cycle_consistency_weight: 0.0
> cycle_start_step: 0
> dataset: minecraft
> detach_binary_recon: True
> disc_adaptive_weight: False
> disc_argmax_ste: False
> disc_gumbel: True
> disc_gumbel_for_cycle_input: True
> disc_layers: 1
> disc_start_step: 15000
> disc_type: conv
> disc_weight: 0.5
> disc_weight_max: 1.0
> disc_weight_min: 0.0
> disentanglement_ratio: 0.5
> ema_decay: 0.99
> emb_dim: 8
> entropy_loss_ratio: 0.2
> gumbel_hard: True
> gumbel_tau: 1
> img_size: 24
> latent_shape: 

## Load dataset

## Train loop

In [31]:
# Paths to the pre-processed chunk tensor and to the saved block/biome mappings.
data_path = '../../text2env/data/24_newdataset_processed_cleaned3.pt'
mappings_path = '../../text2env/data/24_newdataset_mappings3.pt'
# visualizer = MinecraftVisualizer()
visualizer = MinecraftVisualizerPyVista()
# 90/10 train/validation split; num_workers=0 keeps loading in the main process.
train_loader, val_loader = get_minecraft_dataloaders(
    data_path,
    batch_size=H.batch_size,
    num_workers=0,
    val_split=0.1,
    # save_val_path=f'../../text2env/data/{H.log_dir}_valset.pt'
)
# The converter maps between model indices and original block ids; the air and
# water indices are module-level values that later cells read (air_idx is used
# to build the binary air/non-air target).
block_converter = BlockBiomeConverter.load_mappings(mappings_path)
air_idx = block_converter.get_air_block_index()
water_idx = block_converter.get_water_block_index()

Loading pre-processed data...
Loaded 11119 chunks of size torch.Size([42, 24, 24, 24])
Number of unique block types: 42
Unique blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]

Dataloader details:
Training samples: 10008
Validation samples: 1111
Batch size: 8
Training batches: 1251
Validation batches: 139


### Block weighting for cross entropy

In [32]:
# Trying just logs first
# NOTE(review): 131 and 132 are presumably the original block ids of the log
# blocks - confirm against the mappings file. They are translated to model
# indices and stored on H so the loss can up-weight those classes.
log_idxs = block_converter.get_blockid_indices([131, 132])
print(f'log indices: {log_idxs}')
H.weighted_block_indices = log_idxs

log indices: [23, 24]


In [33]:
from fq_models import FQModel, VQLossDualCodebook


In [34]:
# Debug aid: makes CUDA kernel launches synchronous so errors surface at the
# offending call. It slows training, so remove it for real runs.
# NOTE(review): the variable is read when CUDA initialises; if an earlier cell
# has already touched the GPU this assignment likely has no effect - confirm,
# and move it before the first CUDA use if it is still needed.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
def cycle(iterable):
    """Yield items from `iterable` forever, restarting it after every pass.

    Unlike ``itertools.cycle`` nothing is cached: the iterable is iterated
    afresh on each pass.
    """
    while True:
        yield from iterable

# Initialize model and loss
# Both modules are built from the same hyper-parameters and moved to the GPU.
vqgan = FQModel(H).cuda()
vq_loss = VQLossDualCodebook(H).cuda()

# Endless iterators, so the step-based loop below never exhausts the loaders.
# NOTE(review): val_iterator is created here but is not used by the training
# loop in this cell.
train_iterator = cycle(train_loader)
if val_loader is not None:
    val_iterator = cycle(val_loader)

# Setup optimizers - one for generator/encoder, one for discriminator
# The discriminator lives inside the loss module, so it gets its own optimizer;
# both share learning rate, betas and weight decay.
optimizer_g = torch.optim.Adam(
    vqgan.parameters(), 
    lr=H.lr,
    betas=(H.beta1, H.beta2),
    weight_decay=H.weight_decay
)
optimizer_d = torch.optim.Adam(
    vq_loss.discriminator.parameters(),
    lr=H.lr,
    betas=(H.beta1, H.beta2),
    weight_decay=H.weight_decay
)

print(f'Using weight decay {H.weight_decay}')
# Initialize loss tracking
# One growing array per plotted term; a value is appended every
# H.steps_per_log steps by the training loop.
plotted_loss_terms = [
    'rec_loss', 'binary_rec_loss', 'style_loss', 'struct_loss', 
    'biome_feat_loss', 'disent_loss', 'g_loss', 'd_loss', 'struct_consistency_loss', 'cycle_consistency_loss'
]
loss_arrays = {term: np.array([]) for term in plotted_loss_terms}
# Codebook usage history for the two codebooks, appended alongside the losses.
codebook_usage = {
    'style': np.array([]),
    'struct': np.array([])
}

def _loss_value(value):
    """Return `value` as a plain Python number, unwrapping tensors via ``.item()``."""
    return value.item() if torch.is_tensor(value) else value


# Training loop
# Each step: forward pass, generator/encoder update, discriminator update
# (once step >= H.disc_start_step), then periodic logging, rendering and
# checkpointing.
for step in range(0, H.train_steps):
    step_start_time = time.time()
    batch = next(train_iterator)

    # The loader may yield either a bare tensor or a list whose first item is it.
    x = batch[0] if isinstance(batch, list) else batch
    x = x.cuda()

    # Binary target: 1 where the block is anything other than air.
    binary_target = (torch.argmax(x, dim=1) != air_idx).float()

    # Forward pass through model
    if H.two_stage_decoder:
        (recons, binary_out, codebook_loss_style, codebook_loss_struct,
         disentangle_loss, biome_feat, struct_consistency, cycle_consistency) = vqgan(x)
    else:
        (recons, codebook_loss_style, codebook_loss_struct,
         disentangle_loss, biome_feat, struct_consistency, cycle_consistency) = vqgan(x)
        binary_out = None

    # Generator/Encoder update
    optimizer_g.zero_grad()
    # Calculate generator losses
    loss_dict = vq_loss(
        codebook_loss_style, codebook_loss_struct,
        x, recons, disentangle_loss, biome_feat,
        optimizer_idx=0,
        global_step=step,
        binary_out=binary_out,
        binary_target=binary_target,
        struct_consistency_loss=struct_consistency,
        cycle_consistency_loss=cycle_consistency
    )

    loss_dict['loss'].backward()
    torch.nn.utils.clip_grad_norm_(vqgan.parameters(), max_norm=1.0)
    optimizer_g.step()

    # if step % 100 == 0:  # Every 100 steps
    #     # Check if style features require grad
    #     print("Style features require grad:", biome_feat.requires_grad)
    #
    #     # Check gradient magnitudes after loss.backward()
    #     style_grads = biome_feat.grad
    #     if style_grads is not None:
    #         print("Style features grad magnitude:", style_grads.abs().mean().item())
    #     else:
    #         print("No gradients for style features!")

    # Stays empty until the discriminator starts training.
    d_loss_dict = {}

    # Discriminator update
    if step >= H.disc_start_step:
        optimizer_d.zero_grad()

        # Calculate discriminator losses
        d_loss_dict = vq_loss(
            codebook_loss_style, codebook_loss_struct,
            x, recons, disentangle_loss, biome_feat,
            optimizer_idx=1,
            global_step=step,
            binary_out=binary_out,
            binary_target=binary_target
        )

        # print(f"Final d_loss: {d_loss_dict}")

        d_loss_dict['d_loss'].backward()
        torch.nn.utils.clip_grad_norm_(vq_loss.discriminator.parameters(), max_norm=1.0)
        optimizer_d.step()

    # Logging
    if step % H.steps_per_log == 0:
        print(f"\nStep {step}:")
        # Update loss arrays - handle both tensor and float losses
        for term in plotted_loss_terms:
            if term in loss_dict:
                value = _loss_value(loss_dict[term])
            elif term == 'd_loss':
                # Logged as 0.0 before the discriminator starts training
                value = _loss_value(d_loss_dict.get('d_loss', 0.0))
            else:
                continue
            loss_arrays[term] = np.append(loss_arrays[term], value)
            print(f"{term}: {value:.4f}")

        # Update codebook usage tracking (unwrap tensors so NumPy can store them)
        style_usage = _loss_value(loss_dict['codebook_usage_style'])
        struct_usage = _loss_value(loss_dict['codebook_usage_struct'])

        codebook_usage['style'] = np.append(codebook_usage['style'], style_usage)
        codebook_usage['struct'] = np.append(codebook_usage['struct'], struct_usage)
        print(f"Codebook Usage - Style: {style_usage:.2f}%, Structure: {struct_usage:.2f}%")

        # Plot in Visdom; one x value per logged step so far
        x_axis = np.arange(0, step + 1, H.steps_per_log)
        active_terms = [t for t in plotted_loss_terms if len(loss_arrays[t]) > 0]

        # Individual loss plots
        for term in active_terms:
            vis.line(
                Y=loss_arrays[term],
                X=x_axis,
                win=f'{term}_plot',
                opts=dict(title=f'{term} Loss')
            )

        # Combined loss plot
        if len(x_axis) > 1:
            vis.line(
                Y=np.column_stack([loss_arrays[t] for t in active_terms]),
                X=np.column_stack([x_axis for _ in active_terms]),
                win='all_losses',
                opts=dict(title='All Losses', legend=active_terms)
            )

        # Codebook usage plot
        vis.line(
            Y=np.column_stack([codebook_usage['style'], codebook_usage['struct']]),
            X=np.column_stack([x_axis, x_axis]),
            win='codebook_usage',
            opts=dict(title='Codebook Usage', legend=['Style Codebook', 'Structure Codebook'])
        )

    # Visualization of reconstructions
    if step % H.steps_per_display_output == 0 or step == H.train_steps - 1:
        print("Rendering...")
        # Convert indices back to original block IDs for visualization
        orig_blocks = block_converter.convert_to_original_blocks(x)
        recon_blocks = block_converter.convert_to_original_blocks(recons)

        if step % H.steps_per_save_output == 0:
            log_dir = f"../model_logs/{H.log_dir}/images"
            os.makedirs(log_dir, exist_ok=True)

            display_minecraft_pyvista(vis, visualizer, orig_blocks, win_name='Original Maps', title=f'Original Maps step {step}', save_path=f"{log_dir}/orig_{step}.png", nrow=8)
            display_minecraft_pyvista(vis, visualizer, recon_blocks, win_name='Reconstructed Maps', title=f'Reconstructed Maps step {step}', save_path=f"{log_dir}/recon_{step}.png", nrow=8)
        else:
            display_minecraft_pyvista(vis, visualizer, orig_blocks, win_name='Original Maps', title=f'Original Maps step {step}', nrow=8)
            display_minecraft_pyvista(vis, visualizer, recon_blocks, win_name='Reconstructed Maps', title=f'Reconstructed Maps step {step}', nrow=8)
        print("Done Rendering")

    # Run validation
    # if step % H.steps_per_log == 0:
    #     val_losses = validate_and_save(
    #         vqgan, plotted_loss_terms, vq_loss, val_loader, block_converter,
    #         visualizer, vis, step, H, visual=(step % H.steps_per_display_output == 0)
    #     )
        # print(f'Validation losses: {val_losses}')
    # Save checkpoints
    if (step % H.steps_per_checkpoint == 0 and step > 0) or step == H.train_steps - 1:
        save_model(vqgan, 'fqgan', step, H.log_dir)
        save_model(optimizer_g, 'optimizer_g', step, H.log_dir)
        save_model(optimizer_d, 'optimizer_d', step, H.log_dir)

        train_stats = {
            'losses': loss_arrays,
            'codebook_usage': codebook_usage,
            'steps_per_log': H.steps_per_log,
        }
        save_stats(H, train_stats, step)



using padding mode: reflect
With cycle consistency: True type: post_quant_conv, using gumbel: True
Using EMA quantizer
Using TwoStageGenerator
Detaching binary reconstruction from comp graph for final loss
NO biome supervision
Disentangle Ratio:  0.5
With cycle consistency: True weight: 0.0
Using gumbell sampling for disc input
Applying block weighting: amount=3.0 to indices=[23, 24]
Using weight decay 0.01


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


DEBUG: binary_out.squeeze(1) min: 0.06264667212963104, max: 0.9233788847923279

Step 0:
rec_loss: 3.8585
binary_rec_loss: 2.1162
style_loss: 0.0000
struct_loss: 0.0000
biome_feat_loss: 0.0000
disent_loss: 0.0212
g_loss: 0.0619
d_loss: 0.0000
struct_consistency_loss: 0.0000
cycle_consistency_loss: 0.0000
Codebook Usage - Style: 0.12%, Structure: 0.19%
Rendering...
Done Rendering
DEBUG: binary_out.squeeze(1) min: 0.06669817864894867, max: 0.8790794014930725
DEBUG: binary_out.squeeze(1) min: 0.16734473407268524, max: 0.9539996981620789
DEBUG: binary_out.squeeze(1) min: 0.09365185350179672, max: 0.9614930748939514
DEBUG: binary_out.squeeze(1) min: 0.06274798512458801, max: 0.9799028038978577
DEBUG: binary_out.squeeze(1) min: 0.051339272409677505, max: 0.9837033748626709
DEBUG: binary_out.squeeze(1) min: 0.05547556281089783, max: 0.9668311476707458
DEBUG: binary_out.squeeze(1) min: 0.07075825333595276, max: 0.943267822265625
DEBUG: binary_out.squeeze(1) min: 0.06697940826416016, max: 0.9393

In [35]:
import matplotlib.pyplot as plt
import json
import os
import numpy as np

# --- Configuration ---
# H, loss_arrays, codebook_usage, plotted_loss_terms and vis are expected to
# exist in the kernel from the training cell. If the kernel was restarted,
# re-run that cell (or reload the saved stats) and re-create the Visdom client:
#     import visdom
#     vis = visdom.Visdom(port=H.visdom_port)


def _save_line_plot(series, title, ylabel, out_path, figsize=(10, 6)):
    """Draw each (x, y, label, plot_kwargs) entry of `series` on one figure and save it."""
    plt.figure(figsize=figsize)
    for xs, ys, label, plot_kwargs in series:
        plt.plot(xs, ys, label=label, **plot_kwargs)
    plt.xlabel("Training Steps")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(out_path)
    plt.close()


def _logged_steps(values):
    """Training-step x-axis for a series that was logged every H.steps_per_log steps."""
    return np.arange(len(values)) * H.steps_per_log


# --- Directory Setup ---
plots_dir = f"../model_logs/{H.log_dir}/loss_plots"
os.makedirs(plots_dir, exist_ok=True)
print(f"Saving plots to: {os.path.abspath(plots_dir)}")

# --- Collect the training series that actually have data ---
# Every series gets an x-axis sized from its own length, so one term with a
# different number of logged points cannot break the others.
loss_series = []
training_ready = (
    'plotted_loss_terms' in globals() and 'loss_arrays' in globals()
    and H.steps_per_log > 0
    and any(len(loss_arrays.get(term, ())) > 0 for term in plotted_loss_terms)
)
if not training_ready:
    print("Warning: Could not generate x-axis for training plots. Necessary variables (H, loss_arrays, plotted_loss_terms) might be missing or empty, or H.steps_per_log is zero.")
else:
    # --- 1. Save Individual Training Loss Plots ---
    print("\nSaving individual training loss plots...")
    for term in plotted_loss_terms:
        if term in loss_arrays and len(loss_arrays[term]) > 0:
            entry = (_logged_steps(loss_arrays[term]), loss_arrays[term], term, {})
            loss_series.append(entry)
            _save_line_plot(
                [entry],
                title=f"Training Loss: {term}",
                ylabel="Loss",
                out_path=os.path.join(plots_dir, f"{term}_loss.png"),
            )
        else:
            print(f"Skipping individual plot for '{term}', no data found.")
    print("Done.")

    # --- 2. Save Combined Training Loss Plot ---
    print("\nSaving combined training loss plot...")
    if loss_series:
        _save_line_plot(
            loss_series,
            title="All Training Losses",
            ylabel="Loss",
            out_path=os.path.join(plots_dir, "all_training_losses.png"),
            figsize=(12, 7),
        )
        print("Done.")
    else:
        print("Skipping combined training loss plot, no data found.")

# --- 3. Save Codebook Usage Plot ---
if H.steps_per_log > 0 and 'codebook_usage' in globals() and \
   len(codebook_usage.get('style', ())) > 0 and \
   len(codebook_usage.get('struct', ())) > 0:
    print("\nSaving codebook usage plot...")
    _save_line_plot(
        [
            (_logged_steps(codebook_usage['style']), codebook_usage['style'], "Style Codebook Usage (%)", {}),
            (_logged_steps(codebook_usage['struct']), codebook_usage['struct'], "Structure Codebook Usage (%)", {}),
        ],
        title="Codebook Usage Over Training",
        ylabel="Usage (%)",
        out_path=os.path.join(plots_dir, "codebook_usage.png"),
    )
    print("Done.")
else:
    print("Skipping codebook usage plot, no data found.")

# --- 4. Save Validation Loss Plots (fetching from Visdom) ---
if 'vis' in globals() and 'plotted_loss_terms' in globals():
    print("\nAttempting to save validation loss plots from Visdom...")
    print("Note: This requires the 'validate_and_save' function to have been active during training and successfully plotted to Visdom.")
    saved_val_plots_count = 0
    # Use the Visdom environment that was used for plotting ('main' by default).
    env = vis.env if hasattr(vis, 'env') else 'main'
    for term in plotted_loss_terms:
        win_name = f'val_{term}_plot'
        try:
            window_data_json = vis.get_window_data(win=win_name, env=env)
            if not window_data_json:
                continue  # window does not exist: validation never plotted this term
            traces = json.loads(window_data_json).get('content', {}).get('data')
            if not traces:
                continue
            # Validation windows hold a single trace.
            x_val = traces[0].get('x')
            y_val = traces[0].get('y')
            if not (x_val and y_val):
                continue
            _save_line_plot(
                [(x_val, y_val, f"Validation {term}", dict(marker='o', linestyle='-'))],
                title=f"Validation Loss: {term}",
                ylabel="Loss",
                out_path=os.path.join(plots_dir, f"val_{term}_loss.png"),
            )
            saved_val_plots_count += 1
        except Exception as e:
            # Best effort: report the failure and carry on with the next term.
            print(f"Could not retrieve or plot validation data for {term} (window: {win_name}): {e}")

    if saved_val_plots_count > 0:
        print(f"Finished saving {saved_val_plots_count} validation loss plot(s).")
    else:
        print("No validation loss plots were found or saved from Visdom.")
else:
    print("\nSkipping validation plots: Visdom client 'vis' or 'plotted_loss_terms' not available.")

print("\nAll plotting attempts complete.")

Saving plots to: c:\Users\TimBits\Documents\minecraft_new_project\model_logs\FQGAN_2stagedecoder_logweighted3_32codes_cycleconsistency_postquant_detachcycle_gumbel_decay1_3\loss_plots

Saving individual training loss plots...
Done.

Saving combined training loss plot...
Done.

Saving codebook usage plot...
Done.

Attempting to save validation loss plots from Visdom...
Note: This requires the 'validate_and_save' function to have been active during training and successfully plotted to Visdom.
No validation loss plots were found or saved from Visdom.

All plotting attempts complete.
