In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from pathlib import Path
from visualization_utils import MinecraftVisualizerPyVista

# Biome and terrain visualizer

In [2]:
blocks_to_cols = {
            0: (0.5, 0.25, 0.0),    # light brown
            10: 'black', # bedrock
            29: "#006400", # cacutus
            38: "#B8860B",  # clay
            60: "brown",  # dirt
            92: "gold",  # gold ore
            93: "green",  # grass
            115: "brown",  # ladder...?
            119: (.02, .28, .16, 0.8),  # transparent forest green (RGBA) for leaves
            120: (.02, .28, .16, 0.8),  # leaves2
            194: "yellow",  # sand
            217: "gray",  # stone
            240: (0.0, 0.0, 1.0, 0.4),  # water
            227: (0.0, 1.0, 0.0, .3), # tall grass
            237: (0.33, 0.7, 0.33, 0.3), # vine
            40: "#2F4F4F",  # coal ore
            62: "#228B22",  # double plant
            108: "#BEBEBE",  # iron ore
            131: "saddlebrown",  # log1
            132: "saddlebrown",  #log2
            95: "lightgray",  # gravel
            243: "wheat",  # wheat. lmao
            197: "limegreen",  # sapling
            166: "orange",  #pumpkin
            167: "#FF8C00",  # pumpkin stem
            184: "#FFA07A",  # red flower
            195: "tan",  # sandstone
            250: "white",  #wool 
            251: "gold",   #yellow flower
        }

def draw_latent_cuboid(fig, latent_coords, size=4):
    """
    Draw a transparent cuboid around the specified latent coordinates.
    
    Args:
        fig: matplotlib figure to draw on
        latent_coords: list of tuples, each containing (d,h,w) coordinates
        size: size of each latent cell in final space (default 4 for 6->24 upscaling)
    """
    def cuboid_data(o, sizes):
        l, w, h = sizes
        x = [[o[0], o[0] + l, o[0] + l, o[0], o[0]],  
             [o[0], o[0] + l, o[0] + l, o[0], o[0]],  
             [o[0], o[0] + l, o[0] + l, o[0], o[0]],  
             [o[0], o[0] + l, o[0] + l, o[0], o[0]]]  
        y = [[o[1], o[1], o[1] + w, o[1] + w, o[1]],  
             [o[1], o[1], o[1] + w, o[1] + w, o[1]],  
             [o[1], o[1], o[1], o[1], o[1]],          
             [o[1] + w, o[1] + w, o[1] + w, o[1] + w, o[1] + w]]
        z = [[o[2], o[2], o[2], o[2], o[2]],          
             [o[2] + h, o[2] + h, o[2] + h, o[2] + h, o[2] + h],
             [o[2], o[2], o[2] + h, o[2] + h, o[2]],  
             [o[2], o[2], o[2] + h, o[2] + h, o[2]]]  
        return np.array(x), np.array(y), np.array(z)

    ax = fig.gca()
    
    # Convert coordinates to numpy array for easier manipulation
    coords = np.array(latent_coords)
    
    # Find min and max for each dimension
    d_min, h_min, w_min = coords.min(axis=0)
    d_max, h_max, w_max = coords.max(axis=0)
    
    # Calculate origin and sizes
    origin = np.array([abs(5 - d_max)*size, w_min*size, h_min*size])
    sizes = (
        abs(d_max - d_min + 1) * size,  # length
        (w_max - w_min + 1) * size,     # width
        (h_max - h_min + 1) * size      # height
    )
    
    # Create and draw single cuboid
    X, Y, Z = cuboid_data(origin, sizes)
    ax.plot_surface(X, Y, Z, color='red', alpha=0.1)
    
    # Plot edges
    for i in range(4):
        ax.plot(X[i], Y[i], Z[i], color='red', linewidth=1)
    for i in range(4):
        ax.plot([X[0][i], X[1][i]], [Y[0][i], Y[1][i]], [Z[0][i], Z[1][i]], 
               color='red', linewidth=2)
    
    return fig

def visualize_chunk(voxels, figsize=(10, 10), elev=20, azim=45, highlight_latents=None):
    """
    Optimized version of the 3D visualization of a Minecraft chunk.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Convert one-hot to block IDs if needed
    if isinstance(voxels, torch.Tensor):
        if voxels.dim() == 4:  # One-hot encoded [C,H,W,D]
            voxels = voxels.detach().cpu()
            voxels = torch.argmax(voxels, dim=0).numpy()
        else:
            voxels = voxels.detach().cpu().numpy()

    # Apply the same transformations as original
    voxels = voxels.transpose(2, 0, 1) # Moves axes from [D,H,W] to [W,D,H]
    voxels = np.rot90(voxels, 1, (0, 1))  # Rotate 90 degrees around height axis
    # print([block_id for block_id in np.unique(voxels) if block_id not in blocks_to_cols])
    # Create figure and 3D axis
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    # Generate a single boolean mask for each block type
    block_masks = {block_id: (voxels == block_id) for block_id in np.unique(voxels) if block_id in blocks_to_cols}
    
    # Plot all block types with their respective colors
    for block_id, mask in block_masks.items():
        ax.voxels(mask, facecolors=blocks_to_cols[int(block_id)])
    
    # Plot remaining blocks in red with black edges
    other_vox = (voxels != 5) & (voxels != -1) & (~np.any(np.stack(list(block_masks.values())), axis=0))
    ax.voxels(other_vox, edgecolor="k", facecolor="red")
    
    # Set default view angle
    ax.view_init(elev=elev, azim=azim)

    if highlight_latents is not None:
        fig = draw_latent_cuboid(fig, highlight_latents)
    
    return fig

def visualize_chunk_with_biomes(voxels, biomes, figsize=(10, 10), elev=20, azim=45):
    """
    3D visualization of a Minecraft chunk with biome overlay.
    
    Args:
        voxels: numpy array of block IDs
        biomes: numpy array of biome IDs (same shape as voxels)
        figsize: tuple for figure size
        elev, azim: viewing angle parameters
    """
    # Convert tensors to numpy if needed
    if isinstance(voxels, torch.Tensor):
        if voxels.dim() == 4:  # One-hot encoded
            voxels = voxels.detach().cpu()
            voxels = torch.argmax(voxels, dim=0).numpy()
        else:
            voxels = voxels.detach().cpu().numpy()
    if isinstance(biomes, torch.Tensor):
        biomes = biomes.detach().cpu().numpy()

    # Apply the same transformations to both arrays
    voxels = voxels.transpose(2, 0, 1)
    voxels = np.rot90(voxels, 1, (0, 1))
    biomes = biomes.transpose(2, 0, 1)
    biomes = np.rot90(biomes, 1, (0, 1))

    # Create figure and 3D axis
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    # First plot the regular blocks
    block_masks = {block_id: (voxels == block_id) 
                  for block_id in np.unique(voxels) 
                  if block_id in blocks_to_cols}
    
    for block_id, mask in block_masks.items():
        ax.voxels(mask, facecolors=blocks_to_cols[int(block_id)])

    # Plot remaining blocks in red with black edges
    other_vox = (voxels != 5) & (voxels != -1) & (~np.any(np.stack(list(block_masks.values())), axis=0))
    ax.voxels(other_vox, edgecolor="k", facecolor="red")

    # Create a colormap for biomes
    unique_biomes = np.unique(biomes)
    biome_colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_biomes)))
    biome_color_map = dict(zip(unique_biomes, biome_colors))

    # Plot biome overlay with transparency
    for biome in unique_biomes:
        biome_mask = (biomes == biome)
        if biome_mask.any():  # Only plot if biome exists in chunk
            color = biome_color_map[biome]
            color = (*color[:3], 0.2)  # Set alpha to 0.2 for transparency
            ax.voxels(biome_mask, facecolors=color, edgecolor=None)

    # Set default view angle
    ax.view_init(elev=elev, azim=azim)

    # Add a legend for biomes
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor=biome_color_map[biome], 
                           alpha=0.2,
                           label=f'Biome {biome}')
                      for biome in unique_biomes]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')

    return fig

# Biome dataset

## Converters
Save mappings from block types and biome strings to ordinal integers, so that we can retrieve the original values from what we generate

In [3]:
class BlockBiomeConverter:
    def __init__(self, block_mappings=None, biome_mappings=None):
        """
        Initialize with pre-computed mappings for both blocks and biomes
        
        Args:
            block_mappings: dict containing 'index_to_block' and 'block_to_index'
            biome_mappings: dict containing 'index_to_biome' and 'biome_to_index'
        """
        self.index_to_block = block_mappings['index_to_block'] if block_mappings else None
        self.block_to_index = block_mappings['block_to_index'] if block_mappings else None
        self.index_to_biome = biome_mappings['index_to_biome'] if biome_mappings else None
        self.biome_to_index = biome_mappings['biome_to_index'] if biome_mappings else None
    
    @classmethod
    def from_dataset(cls, data_path):
        """Create mappings from a dataset file"""
        data = np.load(data_path, allow_pickle=True)
        voxels = data['voxels']
        biomes = data['biomes']
        
        # Create block mappings (blocks are integers)
        unique_blocks = np.unique(voxels)
        block_to_index = {int(block): idx for idx, block in enumerate(unique_blocks)}
        index_to_block = {idx: int(block) for idx, block in enumerate(unique_blocks)}
        
        # Create biome mappings (biomes are strings)
        unique_biomes = np.unique(biomes)
        biome_to_index = {str(biome): idx for idx, biome in enumerate(unique_biomes)}
        index_to_biome = {idx: str(biome) for idx, biome in enumerate(unique_biomes)}
        
        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}
        
        return cls(block_mappings, biome_mappings)
    
    @classmethod
    def from_arrays(cls, voxels, biomes):
        """Create mappings directly from numpy arrays"""
        # Create block mappings (blocks are integers)
        unique_blocks = np.unique(voxels)
        block_to_index = {int(block): idx for idx, block in enumerate(unique_blocks)}
        index_to_block = {idx: int(block) for idx, block in enumerate(unique_blocks)}
        
        # Create biome mappings (biomes are strings)
        unique_biomes = np.unique(biomes)
        biome_to_index = {str(biome): idx for idx, biome in enumerate(unique_biomes)}
        index_to_biome = {idx: str(biome) for idx, biome in enumerate(unique_biomes)}
        
        block_mappings = {'index_to_block': index_to_block, 'block_to_index': block_to_index}
        biome_mappings = {'index_to_biome': index_to_biome, 'biome_to_index': biome_to_index}
        
        return cls(block_mappings, biome_mappings)
    
    @classmethod
    def load_mappings(cls, path):
        """Load pre-saved mappings"""
        mappings = torch.load(path)
        return cls(mappings['block_mappings'], mappings['biome_mappings'])
    
    def save_mappings(self, path):
        """Save mappings for later use"""
        torch.save({
            'block_mappings': {
                'index_to_block': self.index_to_block,
                'block_to_index': self.block_to_index
            },
            'biome_mappings': {
                'index_to_biome': self.index_to_biome,
                'biome_to_index': self.biome_to_index
            }
        }, path)
    
    def convert_to_original_blocks(self, data):
        """
        Convert from indices back to original block IDs.
        Handles both one-hot encoded and already-indexed data.
        
        Args:
            data: torch.Tensor of either:
                - one-hot encoded blocks [B, C, H, W, D] or [C, H, W, D]
                - indexed blocks [B, H, W, D] or [H, W, D]
        Returns:
            torch.Tensor of original block IDs with shape [B, H, W, D] or [H, W, D]
        """
        # If one-hot encoded (dim == 5 or first dim == num_blocks), convert to indices first
        if len(data.shape) == 5 or (len(data.shape) == 4 and data.shape[0] == len(self.block_to_index)):
            data = torch.argmax(data, dim=1 if len(data.shape) == 5 else 0)
        
        # Now convert indices to original blocks
        if len(data.shape) == 4:  # Batch dimension present
            return torch.tensor([[[[self.index_to_block[int(b)] 
                                for b in row]
                                for row in layer]
                                for layer in slice_]
                                for slice_ in data])
        else:  # No batch dimension
            return torch.tensor([[[self.index_to_block[int(b)] 
                                for b in row]
                                for row in layer]
                                for layer in data])

    def convert_to_original_biomes(self, data):
        """
        Convert from indices back to original biome strings.
        Handles both one-hot encoded and already-indexed data.
        
        Args:
            data: torch.Tensor of either:
                - one-hot encoded biomes [B, C, H, W, D] or [C, H, W, D]
                - indexed biomes [B, H, W, D] or [H, W, D]
        Returns:
            numpy array of original biome strings with shape [B, H, W, D] or [H, W, D]
        """
        # If one-hot encoded (dim == 5 or first dim == num_biomes), convert to indices first
        if len(data.shape) == 5 or (len(data.shape) == 4 and data.shape[0] == len(self.biome_to_index)):
            data = torch.argmax(data, dim=1 if len(data.shape) == 5 else 0)
        
        # Now convert indices to original biomes
        if len(data.shape) == 4:  # Batch dimension present
            return np.array([[[[self.index_to_biome[int(b)] 
                            for b in row]
                            for row in layer]
                            for layer in slice_]
                            for slice_ in data])
        else:  # No batch dimension
            return np.array([[[self.index_to_biome[int(b)] 
                            for b in row]
                            for row in layer]
                            for layer in data])

# Biome Mapping

In [4]:
biome_mapping = {
    'ocean': 'ocean',
    'deep_ocean': 'ocean',
    'deep_warm_ocean': 'ocean',

    'desert': 'desert',
    'desert_hills': 'desert',
    'mutated_desert': 'desert',

    'beaches': 'beaches',
    'stone_beach': 'beaches',

    'cave': 'cave',

    'extreme_hills': 'extreme_hills',
    'extreme_hills_with_trees': 'extreme_hills',
    'mutated_extreme_hills_with_trees': 'extreme_hills',
    'mutated_extreme_hills': 'extreme_hills',

    'forest': 'forest',
    'forest_hills': 'forest',
    'mutated_forest': 'forest',
    'mutated_roofed_forest': 'forest',
    'roofed_forest': 'forest',

    'birch_forest': 'birch_forest',
    'birch_forest_hills': 'birch_forest',
    'mutated_birch_forest': 'birch_forest',
    'mutated_birch_forest_hills': 'birch_forest',

    'plains': 'plains',
    'mutated_plains': 'plains',

    'river': 'river',

    'savanna': 'savanna',
    'savanna_rock': 'savanna',

    'swampland': 'swampland',
    'mutated_swampland': 'swampland',

    'taiga': 'taiga',
    'taiga_hills': 'taiga',
    'mutated_taiga': 'taiga',
    'redwood_taiga': 'taiga',
    # ... add more mappings as needed
}

## Dataset Class

In [5]:
# biome_mapping = {
#     'ocean': 'ocean',
#     'deep_ocean': 'ocean',
#     'deep_warm_ocean': 'ocean',
#     'desert': 'desert',
#     'desert_hills': 'desert',
#     'beaches': 'beaches',
#     'stone_beach': 'beaches',
#     'cave': 'cave',
#     'extreme_hills': 'extreme_hills',
#     'extreme_hills_with_trees': 'extreme_hills',
#     'mutated_extreme_hills_with_trees': 'extreme_hills',
#     'forest': 'forest',
#     'forest_hills': 'forest',
#     'plains': 'plains',
#     'mutated_plains': 'plains',
#     'river': 'river',
#     'savanna': 'savanna',
#     'savanna_rock': 'savanna',
#     'swampland': 'swampland',
#     'taiga': 'taiga',
#     'taiga_hills': 'taiga',
#     # ... add more mappings as needed
# }

class MinecraftTerrainDataset(Dataset):
    def __init__(self, processed_chunks, processed_biomes, indices, converter):
        """
        Internal constructor used by create_train_val_datasets
        
        Args:
            processed_chunks: Pre-processed chunk data
            processed_biomes: Pre-processed biome data
            indices: Indices for this split
            converter: BlockBiomeConverter instance
        """
        self.processed_chunks = processed_chunks
        self.processed_biomes = processed_biomes
        self.indices = indices
        self.converter = converter
        
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        return self.processed_chunks[actual_idx], self.processed_biomes[actual_idx]
    
    @classmethod
    def condense_biomes(cls, voxels, biomes, biome_mapping):
        """
        Combine similar biomes according to mapping.
        
        Args:
            voxels: numpy array of block IDs
            biomes: numpy array of biome strings
            biome_mapping: dict mapping original biome names to new consolidated names
        
        Returns:
            voxels, biomes (unchanged voxels, consolidated biomes)
        """
        print("Condensing biomes...")
        print("Original unique biomes:", np.unique(biomes))
        
        # Create a copy to modify
        new_biomes = biomes.copy()
        
        # Apply the mapping
        for old_biome, new_biome in biome_mapping.items():
            new_biomes[biomes == old_biome] = new_biome
        
        print("Condensed unique biomes:", np.unique(new_biomes))
        return voxels, new_biomes
    
    @classmethod
    def preprocess_air_biomes(cls, voxels, biomes, air_id=5, bubble_size=2):
        """
        Preprocesses biome labels by creating an "air" biome label for air blocks
        beyond a bubble around the terrain.
        
        Args:
            voxels: numpy array of block IDs [H, W, D]
            biomes: numpy array of biome strings [H, W, D]
            air_id: ID representing air blocks in voxels array
            bubble_size: Number of blocks to extend the bubble past terrain
        
        Returns:
            processed_biomes: numpy array with updated biome labels
        """
        from scipy import ndimage
        print("Processing air biomes...")
        print("Original unique biomes:", np.unique(biomes))
        
        # Create output array
        processed_biomes = biomes.copy()
        
        # Skip if this chunk contains cave biome
        if 'cave' in np.unique(biomes):
            return processed_biomes
            
        # Create binary mask of air blocks
        air_mask = (voxels == air_id)
        
        # Label connected components of air
        labeled_array, num_features = ndimage.label(air_mask)
        
        # Find components that touch the top of the chunk
        top_labels = set(np.unique(labeled_array[0, :, :]))  # Labels present in top layer
        if 0 in top_labels:  # Remove background label
            top_labels.remove(0)
            
        if len(top_labels) > 0:
            # Find the largest component that touches the top
            component_sizes = [(label, np.sum(labeled_array == label)) 
                            for label in top_labels]
            sky_label = max(component_sizes, key=lambda x: x[1])[0]
            
            # Create mask for sky component
            sky_mask = (labeled_array == sky_label)
            
            # Create terrain mask (inverse of sky mask)
            terrain_mask = ~sky_mask
            
            # Dilate the terrain mask to create the bubble
            kernel = np.ones((bubble_size*2 + 1, bubble_size*2 + 1, bubble_size*2 + 1))
            bubble = ndimage.binary_dilation(terrain_mask, structure=kernel)
            
            # Set biome to "air" for sky blocks outside the bubble
            processed_biomes[sky_mask & ~bubble] = "air"
        
        print("Processed unique biomes:", np.unique(processed_biomes))
        return processed_biomes
    
    @classmethod
    def remove_underground_chunks(cls, voxels, biomes, min_air_blocks=1152):
        """
        Remove chunks that are underground (have too few air blocks),
        except for cave biomes.
        
        Args:
            voxels: numpy array of block IDs
            biomes: numpy array of biome strings
            min_air_blocks: minimum number of air blocks required (default 24x24x2=1152)
        
        Returns:
            filtered voxels, filtered biomes
        """
        print("Removing underground chunks...")
        print(f"Initial chunks: {len(voxels)}")
        
        # Count air blocks in each chunk
        air_counts = np.sum(voxels == 5, axis=(1, 2, 3))  # 5 is air block ID
        
        # Create mask for valid chunks
        valid_chunks = []
        for i in range(len(voxels)):
            unique_biomes = set(np.unique(biomes[i]))
            if air_counts[i] >= min_air_blocks or 'cave' in unique_biomes:
                valid_chunks.append(i)
        
        valid_chunks = np.array(valid_chunks)
        
        # Filter arrays
        filtered_voxels = voxels[valid_chunks]
        filtered_biomes = biomes[valid_chunks]
        
        print(f"Remaining chunks: {len(filtered_voxels)}")
        return filtered_voxels, filtered_biomes
    
    @classmethod
    def create_train_val_datasets(cls, data_path, train_ratio=0.8, biome_mapping=None):
        """
        Creates both training and validation datasets.
        
        Args:
            data_path: Path to the .npz file containing voxels and biomes
            train_ratio: Fraction of data to use for training
            biome_mapping: Optional dict to condense similar biomes
        """
        data_path = Path(data_path)
        mappings_path = data_path.parent / f"{data_path.stem}_mappings3.pt"
        processed_data_path = data_path.parent / f"{data_path.stem}_processed_cleaned3.pt"

        # Try to load processed data first
        if processed_data_path.exists() and mappings_path.exists():
            print("Loading pre-processed data and mappings...")
            processed_data = torch.load(processed_data_path)
            processed_chunks = processed_data['chunks']
            processed_biomes = processed_data['biomes']
            converter = BlockBiomeConverter.load_mappings(mappings_path)
        else:
            # Load and clean raw data
            print("Processing data for the first time...")
            data = np.load(data_path, allow_pickle=True)
            # voxels = data['voxels']
            voxels = data['chunks']
            biomes = data['biomes']
            
            # Clean the data
            if biome_mapping is not None:
                print("Condensing biomes...")
                voxels, biomes = cls.condense_biomes(voxels, biomes, biome_mapping)
            
            print("Removing underground chunks...")
            voxels, biomes = cls.remove_underground_chunks(voxels, biomes)
            
            # print("Processing air biomes...")
            # # Process each chunk individually
            # processed_biomes = []
            # for i in range(len(voxels)):
            #     if i % 100 == 0:
            #         print(f"Processing chunk {i}/{len(voxels)}")
            #     chunk_biomes = cls.preprocess_air_biomes(voxels[i], biomes[i], bubble_size=1)
            #     processed_biomes.append(chunk_biomes)
            # biomes = np.stack(processed_biomes)
            # Create mappings with cleaned data
            print("Creating new block/biome mappings...")
            converter = BlockBiomeConverter.from_arrays(voxels, biomes)
            converter.save_mappings(mappings_path)
            
            # Get dimensions
            num_blocks = len(converter.block_to_index)
            num_biomes = len(converter.biome_to_index)
            
                ## Process in batches
            num_samples = len(voxels)
            batch_size = 100
            num_batches = (num_samples + batch_size - 1) // batch_size
            
            processed_chunks_list = []
            processed_biomes_list = []
            
            print(f"Processing {num_samples} samples in {num_batches} batches...")
            for i in range(num_batches):
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, num_samples)
                print(f"Processing batch {i+1}/{num_batches}")
                
                # Convert batch to indices
                batch_chunks = torch.tensor([[[[converter.block_to_index[int(b)] 
                                for b in row]
                                for row in layer]
                                for layer in slice_]
                                for slice_ in voxels[start_idx:end_idx]])
                
                batch_biomes = torch.tensor([[[[converter.biome_to_index[str(b)] 
                                for b in row]
                                for row in layer]
                                for layer in slice_]
                                for slice_ in biomes[start_idx:end_idx]])
                
                # One-hot encode batch
                processed_batch_chunks = F.one_hot(
                    batch_chunks.long(), 
                    num_classes=num_blocks
                ).permute(0, 4, 1, 2, 3).float()
                
                processed_batch_biomes = F.one_hot(
                    batch_biomes.long(),
                    num_classes=num_biomes
                ).permute(0, 4, 1, 2, 3).float()
                
                processed_chunks_list.append(processed_batch_chunks)
                processed_biomes_list.append(processed_batch_biomes)
                
                # Clear some memory
                del batch_chunks, batch_biomes
                del processed_batch_chunks, processed_batch_biomes
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
            del voxels
            del biomes
            # Concatenate all batches
            processed_chunks = torch.cat(processed_chunks_list, dim=0)
            processed_biomes = torch.cat(processed_biomes_list, dim=0)
                
            # Save processed data
            print("Saving processed data...")
            torch.save({
                'chunks': processed_chunks,
                'biomes': processed_biomes
            }, processed_data_path)
        
        # Create train/val split
        num_samples = len(processed_chunks)
        indices = torch.randperm(num_samples)
        split_idx = int(num_samples * train_ratio)
        
        train_indices = indices[:split_idx]
        val_indices = indices[split_idx:]
        
        # Create datasets
        train_dataset = cls(processed_chunks, processed_biomes, train_indices, converter)
        val_dataset = cls(processed_chunks, processed_biomes, val_indices, converter)
        
        print(f"\nDataset details:")
        print(f"Total chunks: {len(processed_chunks)}")
        print(f"Number of unique block types: {len(converter.block_to_index)}")
        print(f"Number of unique biome types: {len(converter.biome_to_index)}")
        print(f"Training samples: {len(train_indices)}")
        print(f"Validation samples: {len(val_indices)}")
        print(f"Chunk shape: {processed_chunks.shape}")
        print(f"Biome shape: {processed_biomes.shape}")
        
        return train_dataset, val_dataset

In [6]:
def get_minecraft_dataloaders(data_path, batch_size=32, train_ratio=0.8, num_workers=0, biome_mapping=None):
    """
    Creates training and validation dataloaders for Minecraft chunks.
    """
    # Create both datasets at once
    train_dataset, val_dataset = MinecraftTerrainDataset.create_train_val_datasets(
        data_path, 
        train_ratio=train_ratio,
        biome_mapping=biome_mapping
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    # Store converter in the loaders
    train_loader.converter = train_dataset.converter
    val_loader.converter = val_dataset.converter
    
    print(f"\nDataloader details:")
    print(f"Batch size: {batch_size}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    
    return train_loader, val_loader

# Biome classifier

In [7]:
def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)   #   divides the channels into 32 groups, and normalizes each group. More effective for smaller batch size than batch norm

@torch.jit.script
def swish(x):
    return x*torch.sigmoid(x)   #  swish activation function, compiled using torch.jit.script. Smooth, non-linear activation function, works better than ReLu in some cases. swish (x) = x * sigmoid(x)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv_out = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in
    
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        # Convert all 2D convolutions to 3D
        self.q = torch.nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w, d = q.shape
        q = q.reshape(b, c, h*w*d)    # Flatten all spatial dimensions
        q = q.permute(0, 2, 1)        # b, hwd, c
        k = k.reshape(b, c, h*w*d)    # b, c, hwd
        w_ = torch.bmm(q, k)          # b, hwd, hwd    
        w_ = w_ * (int(c)**(-0.5))    # Scale dot products
        w_ = F.softmax(w_, dim=2)     # Softmax over spatial positions

        # attend to values
        v = v.reshape(b, c, h*w*d)
        w_ = w_.permute(0, 2, 1)      # b, hwd, hwd (first hwd of k, second of q)
        h_ = torch.bmm(v, w_)         # b, c, hwd
        h_ = h_.reshape(b, c, h, w, d) # Restore spatial structure

        h_ = self.proj_out(h_)

        return x + h_
    


In [8]:
class BiomeClassifier(nn.Module):
    def __init__(self, num_block_types, num_biomes, feature_dim=256):
        super(BiomeClassifier, self).__init__()
        
        # Initial embedding layer for one-hot encoded blocks
        self.block_proj = nn.Conv3d(num_block_types, 64, kernel_size=1)
        
        # Encoder layers that will downsample to 6x6x6 spatial dimensions
        self.encoder = nn.Sequential(
            # Layer 1: 24x24x24 -> 12x12x12
            nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            
            # Layer 2: 12x12x12 -> 6x6x6
            nn.Conv3d(128, feature_dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(feature_dim),
            nn.ReLU(),
        )
        
        # Self-attention block at 6x6x6 resolution
        self.attention = AttnBlock(feature_dim)
        
        # Additional convolution after attention for feature extraction
        self.feature_conv = nn.Conv3d(feature_dim, feature_dim, kernel_size=3, padding=1)
        
        # Biome prediction head
        self.biome_head = nn.Conv3d(feature_dim, num_biomes, kernel_size=1)
        
        # Upsampling layer to get back to original resolution
        self.upsample = nn.Upsample(size=(24, 24, 24), mode='trilinear', align_corners=False)
        
    def get_intermediate_features(self, x):
        """Get the features after attention for representation learning"""
        x = self.block_proj(x)
        x = self.encoder(x)
        x = self.attention(x)
        return self.feature_conv(x)
        
    def forward(self, x, return_features=False):
        # x shape: (batch_size, num_blocks, 24, 24, 24)
        
        # Initial projection and encoding
        x = self.block_proj(x)
        x = self.encoder(x)  # (batch_size, feature_dim, 6, 6, 6)
        
        # Apply self-attention
        x = self.attention(x)
        
        # Get features for representation learning
        features = self.feature_conv(x)
        
        # Predict biomes and upsample
        biome_logits = self.biome_head(features)  # (batch_size, num_biomes, 6, 6, 6)
        biome_logits = self.upsample(biome_logits)  # (batch_size, num_biomes, 24, 24, 24)
        
        if return_features:
            return biome_logits, features
        return biome_logits

# Combine raw npz files

In [9]:
# from data_utils import combine_chunk_files
# input_dir = "../../text2env/data/Voxels/voxels"
# output_file = "../../text2env/data/24_newdataset.npz"
# result = combine_chunk_files(input_dir, output_file)

In [10]:
# Create datasets
train_loader, val_loader = get_minecraft_dataloaders(
    data_path='../../text2env/data/24_newdataset.npz',
    batch_size=32,
    num_workers=0,
    biome_mapping=biome_mapping
)

Processing data for the first time...
Condensing biomes...
Condensing biomes...
Original unique biomes: ['beaches' 'birch_forest' 'birch_forest_hills' 'cave' 'deep_ocean'
 'desert' 'desert_hills' 'extreme_hills' 'extreme_hills_with_trees'
 'forest' 'forest_hills' 'mutated_birch_forest'
 'mutated_birch_forest_hills' 'mutated_desert' 'mutated_extreme_hills'
 'mutated_extreme_hills_with_trees' 'mutated_forest' 'mutated_plains'
 'mutated_roofed_forest' 'mutated_swampland' 'mutated_taiga' 'ocean'
 'plains' 'river' 'roofed_forest' 'savanna' 'savanna_rock' 'stone_beach'
 'swampland' 'taiga' 'taiga_hills']
Condensed unique biomes: ['beaches' 'birch_forest' 'cave' 'desert' 'extreme_hills' 'forest'
 'mutated_extreme_hills' 'ocean' 'plains' 'river' 'savanna' 'swampland'
 'taiga']
Removing underground chunks...
Removing underground chunks...
Initial chunks: 11120
Remaining chunks: 11119
Creating new block/biome mappings...
Processing 11119 samples in 112 batches...
Processing batch 1/112
Processin

In [118]:
# # Load the processed data
# processed_data_path = Path('../../text2env/data/minecraft_biome_newjava_2500_processed_cleaned.pt')
# data = torch.load(processed_data_path)
# processed_biomes = data['biomes']

# # Load the mappings to get biome names
# mappings_path = Path('../../text2env/data/minecraft_biome_newjava_2500_mappings.pt')
# converter = BlockBiomeConverter.load_mappings(mappings_path)

# # Print all unique biomes
# print("Current biome types:")
# for idx, biome in converter.index_to_biome.items():
#     print(f"{idx}: {biome}")

Current biome types:
0: beaches
1: cave
2: desert
3: extreme_hills
4: forest
5: ocean
6: plains
7: river
8: savanna
9: taiga


In [11]:
# Initialize model and move to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BiomeClassifier(
    num_block_types=len(train_loader.converter.block_to_index),
    num_biomes=len(train_loader.converter.biome_to_index)
).to(device)

# Training parameters
learning_rate = 1e-3
num_epochs = 125
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

best_val_loss = float('inf')

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0
    train_batches = 0
    
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for terrain, biomes in progress_bar:
        terrain = terrain.to(device)
        biomes = biomes.to(device)
        
        # Forward pass
        logits = model(terrain)  # (batch_size, num_biomes, 24, 24, 24)
        
        # Convert target from one-hot to indices
        target = biomes.argmax(dim=1)  # (batch_size, 24, 24, 24)
        
        # Compute loss - no reshaping needed!
        loss = criterion(logits, target)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        train_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({'train_loss': train_loss / train_batches})
    
    avg_train_loss = train_loss / train_batches
    
    # Validation phase
    model.eval()
    val_loss = 0
    val_batches = 0
    
    with torch.no_grad():
        for terrain, biomes in val_loader:
            terrain = terrain.to(device)
            biomes = biomes.to(device)
            
            logits = model(terrain)
            
            B, C, H, W, D = logits.shape
            loss = criterion(
                logits.reshape(B * H * W * D, C),
                biomes.argmax(dim=1).reshape(B * H * W * D)
            )
            
            val_loss += loss.item()
            val_batches += 1
    
    avg_val_loss = val_loss / val_batches
    
    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'  Train Loss: {avg_train_loss:.4f}')
    print(f'  Val Loss: {avg_val_loss:.4f}')
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_biome_classifier_airprocessed.pt')

Epoch 1/125: 100%|██████████| 278/278 [00:16<00:00, 16.76it/s, train_loss=0.813]


Epoch 1/125:
  Train Loss: 0.8125
  Val Loss: 2.8253


Epoch 2/125: 100%|██████████| 278/278 [00:16<00:00, 16.89it/s, train_loss=0.574]


Epoch 2/125:
  Train Loss: 0.5740
  Val Loss: 2.8011


Epoch 3/125: 100%|██████████| 278/278 [00:16<00:00, 17.31it/s, train_loss=0.472]


Epoch 3/125:
  Train Loss: 0.4721
  Val Loss: 2.8334


Epoch 4/125: 100%|██████████| 278/278 [00:17<00:00, 16.32it/s, train_loss=0.43] 


Epoch 4/125:
  Train Loss: 0.4295
  Val Loss: 2.8581


Epoch 5/125: 100%|██████████| 278/278 [00:16<00:00, 16.35it/s, train_loss=0.4]  


Epoch 5/125:
  Train Loss: 0.3997
  Val Loss: 2.8520


Epoch 6/125: 100%|██████████| 278/278 [00:17<00:00, 15.84it/s, train_loss=0.378]


Epoch 6/125:
  Train Loss: 0.3778
  Val Loss: 2.8300


Epoch 7/125: 100%|██████████| 278/278 [00:17<00:00, 15.63it/s, train_loss=0.37] 


Epoch 7/125:
  Train Loss: 0.3697
  Val Loss: 2.8241


Epoch 8/125: 100%|██████████| 278/278 [00:17<00:00, 15.56it/s, train_loss=0.346]


Epoch 8/125:
  Train Loss: 0.3464
  Val Loss: 2.8632


Epoch 9/125: 100%|██████████| 278/278 [00:17<00:00, 16.15it/s, train_loss=0.388]


Epoch 9/125:
  Train Loss: 0.3879
  Val Loss: 2.8490


Epoch 10/125: 100%|██████████| 278/278 [00:17<00:00, 15.59it/s, train_loss=0.306]


Epoch 10/125:
  Train Loss: 0.3065
  Val Loss: 2.9173


Epoch 11/125: 100%|██████████| 278/278 [00:18<00:00, 15.36it/s, train_loss=0.29] 


Epoch 11/125:
  Train Loss: 0.2899
  Val Loss: 2.8915


Epoch 12/125: 100%|██████████| 278/278 [00:17<00:00, 15.76it/s, train_loss=0.291]


Epoch 12/125:
  Train Loss: 0.2909
  Val Loss: 2.8909


Epoch 13/125: 100%|██████████| 278/278 [00:18<00:00, 15.21it/s, train_loss=0.259]


Epoch 13/125:
  Train Loss: 0.2594
  Val Loss: 2.9301


Epoch 14/125: 100%|██████████| 278/278 [00:17<00:00, 15.49it/s, train_loss=0.245]


Epoch 14/125:
  Train Loss: 0.2450
  Val Loss: 2.9306


Epoch 15/125: 100%|██████████| 278/278 [00:16<00:00, 16.41it/s, train_loss=0.218]


Epoch 15/125:
  Train Loss: 0.2180
  Val Loss: 2.9876


Epoch 16/125: 100%|██████████| 278/278 [00:17<00:00, 16.24it/s, train_loss=0.215]


Epoch 16/125:
  Train Loss: 0.2147
  Val Loss: 2.9934


Epoch 17/125: 100%|██████████| 278/278 [00:17<00:00, 15.50it/s, train_loss=0.425]


Epoch 17/125:
  Train Loss: 0.4255
  Val Loss: 2.9379


Epoch 18/125: 100%|██████████| 278/278 [00:17<00:00, 15.94it/s, train_loss=0.207]


Epoch 18/125:
  Train Loss: 0.2067
  Val Loss: 2.9724


Epoch 19/125: 100%|██████████| 278/278 [00:17<00:00, 15.75it/s, train_loss=0.195]


Epoch 19/125:
  Train Loss: 0.1949
  Val Loss: 3.0028


Epoch 20/125: 100%|██████████| 278/278 [00:18<00:00, 15.12it/s, train_loss=0.185]


Epoch 20/125:
  Train Loss: 0.1851
  Val Loss: 2.9969


Epoch 21/125: 100%|██████████| 278/278 [00:17<00:00, 15.52it/s, train_loss=0.186]


Epoch 21/125:
  Train Loss: 0.1857
  Val Loss: 2.9846


Epoch 22/125: 100%|██████████| 278/278 [00:17<00:00, 15.49it/s, train_loss=0.183]


Epoch 22/125:
  Train Loss: 0.1827
  Val Loss: 2.9920


Epoch 23/125: 100%|██████████| 278/278 [00:17<00:00, 15.92it/s, train_loss=0.155]


Epoch 23/125:
  Train Loss: 0.1553
  Val Loss: 3.0025


Epoch 24/125: 100%|██████████| 278/278 [00:17<00:00, 16.09it/s, train_loss=0.142]


Epoch 24/125:
  Train Loss: 0.1417
  Val Loss: 3.0502


Epoch 25/125: 100%|██████████| 278/278 [00:17<00:00, 16.28it/s, train_loss=0.139]


Epoch 25/125:
  Train Loss: 0.1391
  Val Loss: 3.0282


Epoch 26/125: 100%|██████████| 278/278 [00:17<00:00, 15.88it/s, train_loss=0.14] 


Epoch 26/125:
  Train Loss: 0.1402
  Val Loss: 3.0845


Epoch 27/125: 100%|██████████| 278/278 [00:17<00:00, 16.04it/s, train_loss=0.227]


Epoch 27/125:
  Train Loss: 0.2266
  Val Loss: 3.0213


Epoch 28/125: 100%|██████████| 278/278 [00:17<00:00, 16.33it/s, train_loss=0.163]


Epoch 28/125:
  Train Loss: 0.1630
  Val Loss: 3.0525


Epoch 29/125: 100%|██████████| 278/278 [00:17<00:00, 15.83it/s, train_loss=0.139]


Epoch 29/125:
  Train Loss: 0.1392
  Val Loss: 3.0350


Epoch 30/125: 100%|██████████| 278/278 [00:18<00:00, 15.09it/s, train_loss=0.132]


Epoch 30/125:
  Train Loss: 0.1325
  Val Loss: 3.0634


Epoch 31/125: 100%|██████████| 278/278 [00:17<00:00, 15.61it/s, train_loss=0.125]


Epoch 31/125:
  Train Loss: 0.1246
  Val Loss: 3.0380


Epoch 32/125: 100%|██████████| 278/278 [00:17<00:00, 15.61it/s, train_loss=0.125]


Epoch 32/125:
  Train Loss: 0.1253
  Val Loss: 3.0627


Epoch 33/125: 100%|██████████| 278/278 [00:17<00:00, 15.62it/s, train_loss=0.124]


Epoch 33/125:
  Train Loss: 0.1242
  Val Loss: 3.0736


Epoch 34/125: 100%|██████████| 278/278 [00:17<00:00, 16.09it/s, train_loss=0.128]


Epoch 34/125:
  Train Loss: 0.1279
  Val Loss: 3.0748


Epoch 35/125: 100%|██████████| 278/278 [00:18<00:00, 15.29it/s, train_loss=0.124]


Epoch 35/125:
  Train Loss: 0.1236
  Val Loss: 3.0803


Epoch 36/125: 100%|██████████| 278/278 [00:17<00:00, 15.84it/s, train_loss=0.153]


Epoch 36/125:
  Train Loss: 0.1529
  Val Loss: 3.1406


Epoch 37/125: 100%|██████████| 278/278 [00:17<00:00, 15.50it/s, train_loss=0.128]


Epoch 37/125:
  Train Loss: 0.1277
  Val Loss: 3.0719


Epoch 38/125: 100%|██████████| 278/278 [00:17<00:00, 16.35it/s, train_loss=0.119]


Epoch 38/125:
  Train Loss: 0.1194
  Val Loss: 3.0925


Epoch 39/125: 100%|██████████| 278/278 [00:16<00:00, 16.46it/s, train_loss=0.219]


Epoch 39/125:
  Train Loss: 0.2190
  Val Loss: 3.1054


Epoch 40/125: 100%|██████████| 278/278 [00:18<00:00, 15.26it/s, train_loss=0.147]


Epoch 40/125:
  Train Loss: 0.1472
  Val Loss: 3.1266


Epoch 41/125: 100%|██████████| 278/278 [00:18<00:00, 14.91it/s, train_loss=0.12] 


Epoch 41/125:
  Train Loss: 0.1195
  Val Loss: 3.1585


Epoch 42/125: 100%|██████████| 278/278 [00:18<00:00, 14.75it/s, train_loss=0.107]


Epoch 42/125:
  Train Loss: 0.1067
  Val Loss: 3.1447


Epoch 43/125: 100%|██████████| 278/278 [00:18<00:00, 15.30it/s, train_loss=0.105]


Epoch 43/125:
  Train Loss: 0.1049
  Val Loss: 3.1709


Epoch 44/125: 100%|██████████| 278/278 [00:17<00:00, 15.48it/s, train_loss=0.102]


Epoch 44/125:
  Train Loss: 0.1016
  Val Loss: 3.1620


Epoch 45/125: 100%|██████████| 278/278 [00:18<00:00, 15.41it/s, train_loss=0.0984]


Epoch 45/125:
  Train Loss: 0.0984
  Val Loss: 3.1396


Epoch 46/125: 100%|██████████| 278/278 [00:17<00:00, 15.55it/s, train_loss=0.101] 


Epoch 46/125:
  Train Loss: 0.1005
  Val Loss: 3.1750


Epoch 47/125: 100%|██████████| 278/278 [00:18<00:00, 15.12it/s, train_loss=0.104]


Epoch 47/125:
  Train Loss: 0.1043
  Val Loss: 3.1886


Epoch 48/125: 100%|██████████| 278/278 [00:18<00:00, 14.64it/s, train_loss=0.1]  


Epoch 48/125:
  Train Loss: 0.1001
  Val Loss: 3.1887


Epoch 49/125: 100%|██████████| 278/278 [00:18<00:00, 14.87it/s, train_loss=0.102]


Epoch 49/125:
  Train Loss: 0.1022
  Val Loss: 3.1572


Epoch 50/125: 100%|██████████| 278/278 [00:17<00:00, 16.35it/s, train_loss=0.1]  


Epoch 50/125:
  Train Loss: 0.1005
  Val Loss: 3.1755


Epoch 51/125: 100%|██████████| 278/278 [00:17<00:00, 16.35it/s, train_loss=0.0962]


Epoch 51/125:
  Train Loss: 0.0962
  Val Loss: 3.2158


Epoch 52/125: 100%|██████████| 278/278 [00:17<00:00, 15.97it/s, train_loss=0.222]


Epoch 52/125:
  Train Loss: 0.2222
  Val Loss: 3.3892


Epoch 53/125: 100%|██████████| 278/278 [00:16<00:00, 16.39it/s, train_loss=0.163]


Epoch 53/125:
  Train Loss: 0.1632
  Val Loss: 3.1636


Epoch 54/125: 100%|██████████| 278/278 [00:17<00:00, 16.30it/s, train_loss=0.108]


Epoch 54/125:
  Train Loss: 0.1075
  Val Loss: 3.2136


Epoch 55/125: 100%|██████████| 278/278 [00:17<00:00, 16.13it/s, train_loss=0.0914]


Epoch 55/125:
  Train Loss: 0.0914
  Val Loss: 3.2313


Epoch 56/125: 100%|██████████| 278/278 [00:17<00:00, 16.11it/s, train_loss=0.0858]


Epoch 56/125:
  Train Loss: 0.0858
  Val Loss: 3.2525


Epoch 57/125: 100%|██████████| 278/278 [00:17<00:00, 16.08it/s, train_loss=0.0843]


Epoch 57/125:
  Train Loss: 0.0843
  Val Loss: 3.2549


Epoch 58/125: 100%|██████████| 278/278 [00:17<00:00, 16.30it/s, train_loss=0.081] 


Epoch 58/125:
  Train Loss: 0.0810
  Val Loss: 3.2479


Epoch 59/125: 100%|██████████| 278/278 [00:17<00:00, 15.88it/s, train_loss=0.081] 


Epoch 59/125:
  Train Loss: 0.0810
  Val Loss: 3.2375


Epoch 60/125: 100%|██████████| 278/278 [00:17<00:00, 16.25it/s, train_loss=0.0816]


Epoch 60/125:
  Train Loss: 0.0816
  Val Loss: 3.2468


Epoch 61/125: 100%|██████████| 278/278 [00:18<00:00, 15.07it/s, train_loss=0.0837]


Epoch 61/125:
  Train Loss: 0.0837
  Val Loss: 3.2385


Epoch 62/125: 100%|██████████| 278/278 [00:18<00:00, 15.00it/s, train_loss=0.0851]


Epoch 62/125:
  Train Loss: 0.0851
  Val Loss: 3.2227


Epoch 63/125: 100%|██████████| 278/278 [00:18<00:00, 15.17it/s, train_loss=0.083] 


Epoch 63/125:
  Train Loss: 0.0830
  Val Loss: 3.2456


Epoch 64/125: 100%|██████████| 278/278 [00:18<00:00, 15.34it/s, train_loss=0.0827]


Epoch 64/125:
  Train Loss: 0.0827
  Val Loss: 3.2338


Epoch 65/125: 100%|██████████| 278/278 [00:18<00:00, 14.84it/s, train_loss=0.0811]


Epoch 65/125:
  Train Loss: 0.0811
  Val Loss: 3.2479


Epoch 66/125: 100%|██████████| 278/278 [00:18<00:00, 15.18it/s, train_loss=0.0815]


Epoch 66/125:
  Train Loss: 0.0815
  Val Loss: 3.2512


Epoch 67/125: 100%|██████████| 278/278 [00:18<00:00, 15.31it/s, train_loss=0.0934]


Epoch 67/125:
  Train Loss: 0.0934
  Val Loss: 3.2947


Epoch 68/125: 100%|██████████| 278/278 [00:18<00:00, 15.35it/s, train_loss=0.21] 


Epoch 68/125:
  Train Loss: 0.2102
  Val Loss: 3.4128


Epoch 69/125: 100%|██████████| 278/278 [00:18<00:00, 15.08it/s, train_loss=0.106]


Epoch 69/125:
  Train Loss: 0.1062
  Val Loss: 3.3045


Epoch 70/125: 100%|██████████| 278/278 [00:17<00:00, 15.51it/s, train_loss=0.0855]


Epoch 70/125:
  Train Loss: 0.0855
  Val Loss: 3.3248


Epoch 71/125: 100%|██████████| 278/278 [00:18<00:00, 15.07it/s, train_loss=0.082] 


Epoch 71/125:
  Train Loss: 0.0820
  Val Loss: 3.3062


Epoch 72/125: 100%|██████████| 278/278 [00:18<00:00, 14.85it/s, train_loss=0.0813]


Epoch 72/125:
  Train Loss: 0.0813
  Val Loss: 3.3210


Epoch 73/125: 100%|██████████| 278/278 [00:18<00:00, 15.37it/s, train_loss=0.0731]


Epoch 73/125:
  Train Loss: 0.0731
  Val Loss: 3.3172


Epoch 74/125: 100%|██████████| 278/278 [00:18<00:00, 15.17it/s, train_loss=0.07]  


Epoch 74/125:
  Train Loss: 0.0700
  Val Loss: 3.3396


Epoch 75/125: 100%|██████████| 278/278 [00:18<00:00, 14.75it/s, train_loss=0.0752]


Epoch 75/125:
  Train Loss: 0.0752
  Val Loss: 3.3284


Epoch 76/125: 100%|██████████| 278/278 [00:18<00:00, 15.16it/s, train_loss=0.371]


Epoch 76/125:
  Train Loss: 0.3710
  Val Loss: 3.7334


Epoch 77/125: 100%|██████████| 278/278 [00:18<00:00, 15.43it/s, train_loss=0.122]


Epoch 77/125:
  Train Loss: 0.1219
  Val Loss: 3.5063


Epoch 78/125: 100%|██████████| 278/278 [00:18<00:00, 14.83it/s, train_loss=0.0861]


Epoch 78/125:
  Train Loss: 0.0861
  Val Loss: 3.5344


Epoch 79/125: 100%|██████████| 278/278 [00:18<00:00, 15.24it/s, train_loss=0.0774]


Epoch 79/125:
  Train Loss: 0.0774
  Val Loss: 3.5421


Epoch 80/125: 100%|██████████| 278/278 [00:18<00:00, 15.28it/s, train_loss=0.0733]


Epoch 80/125:
  Train Loss: 0.0733
  Val Loss: 3.5317


Epoch 81/125: 100%|██████████| 278/278 [00:19<00:00, 14.44it/s, train_loss=0.0712]


Epoch 81/125:
  Train Loss: 0.0712
  Val Loss: 3.6108


Epoch 82/125: 100%|██████████| 278/278 [00:18<00:00, 14.84it/s, train_loss=0.0786]


Epoch 82/125:
  Train Loss: 0.0786
  Val Loss: 3.4772


Epoch 83/125: 100%|██████████| 278/278 [00:18<00:00, 15.17it/s, train_loss=0.0724]


Epoch 83/125:
  Train Loss: 0.0724
  Val Loss: 3.4969


Epoch 84/125: 100%|██████████| 278/278 [00:18<00:00, 15.36it/s, train_loss=0.0704]


Epoch 84/125:
  Train Loss: 0.0704
  Val Loss: 3.4982


Epoch 85/125: 100%|██████████| 278/278 [00:18<00:00, 14.91it/s, train_loss=0.0677]


Epoch 85/125:
  Train Loss: 0.0677
  Val Loss: 3.5013


Epoch 86/125: 100%|██████████| 278/278 [00:18<00:00, 14.94it/s, train_loss=0.0677]


Epoch 86/125:
  Train Loss: 0.0677
  Val Loss: 3.4892


Epoch 87/125: 100%|██████████| 278/278 [00:18<00:00, 15.15it/s, train_loss=0.0698]


Epoch 87/125:
  Train Loss: 0.0698
  Val Loss: 3.4872


Epoch 88/125: 100%|██████████| 278/278 [00:18<00:00, 15.38it/s, train_loss=0.0692]


Epoch 88/125:
  Train Loss: 0.0692
  Val Loss: 3.5712


Epoch 89/125: 100%|██████████| 278/278 [00:18<00:00, 14.71it/s, train_loss=0.0852]


Epoch 89/125:
  Train Loss: 0.0852
  Val Loss: 3.8680


Epoch 90/125: 100%|██████████| 278/278 [00:19<00:00, 14.18it/s, train_loss=0.101]


Epoch 90/125:
  Train Loss: 0.1015
  Val Loss: 3.5088


Epoch 91/125: 100%|██████████| 278/278 [00:18<00:00, 15.15it/s, train_loss=0.0981]


Epoch 91/125:
  Train Loss: 0.0981
  Val Loss: 3.4406


Epoch 92/125: 100%|██████████| 278/278 [00:18<00:00, 15.30it/s, train_loss=0.0861]


Epoch 92/125:
  Train Loss: 0.0861
  Val Loss: 3.4431


Epoch 93/125: 100%|██████████| 278/278 [00:18<00:00, 15.16it/s, train_loss=0.0753]


Epoch 93/125:
  Train Loss: 0.0753
  Val Loss: 3.4577


Epoch 94/125: 100%|██████████| 278/278 [00:18<00:00, 14.91it/s, train_loss=0.0672]


Epoch 94/125:
  Train Loss: 0.0672
  Val Loss: 3.3967


Epoch 95/125: 100%|██████████| 278/278 [00:18<00:00, 15.25it/s, train_loss=0.0646]


Epoch 95/125:
  Train Loss: 0.0646
  Val Loss: 3.4263


Epoch 96/125: 100%|██████████| 278/278 [00:18<00:00, 15.30it/s, train_loss=0.0647]


Epoch 96/125:
  Train Loss: 0.0647
  Val Loss: 3.4249


Epoch 97/125: 100%|██████████| 278/278 [00:18<00:00, 15.29it/s, train_loss=0.0664]


Epoch 97/125:
  Train Loss: 0.0664
  Val Loss: 3.3932


Epoch 98/125: 100%|██████████| 278/278 [00:18<00:00, 14.81it/s, train_loss=0.0672]


Epoch 98/125:
  Train Loss: 0.0672
  Val Loss: 3.3708


Epoch 99/125: 100%|██████████| 278/278 [00:18<00:00, 15.00it/s, train_loss=0.0688]


Epoch 99/125:
  Train Loss: 0.0688
  Val Loss: 3.3950


Epoch 100/125: 100%|██████████| 278/278 [00:17<00:00, 15.49it/s, train_loss=0.0694]


Epoch 100/125:
  Train Loss: 0.0694
  Val Loss: 3.3841


Epoch 101/125: 100%|██████████| 278/278 [00:17<00:00, 16.34it/s, train_loss=0.0696]


Epoch 101/125:
  Train Loss: 0.0696
  Val Loss: 3.3583


Epoch 102/125: 100%|██████████| 278/278 [00:17<00:00, 15.75it/s, train_loss=0.072] 


Epoch 102/125:
  Train Loss: 0.0720
  Val Loss: 3.3951


Epoch 103/125: 100%|██████████| 278/278 [00:17<00:00, 16.01it/s, train_loss=0.0753]


Epoch 103/125:
  Train Loss: 0.0753
  Val Loss: 3.3765


Epoch 104/125: 100%|██████████| 278/278 [00:17<00:00, 15.64it/s, train_loss=0.0776]


Epoch 104/125:
  Train Loss: 0.0776
  Val Loss: 3.3700


Epoch 105/125: 100%|██████████| 278/278 [00:17<00:00, 16.04it/s, train_loss=0.0715]


Epoch 105/125:
  Train Loss: 0.0715
  Val Loss: 3.3907


Epoch 106/125: 100%|██████████| 278/278 [00:17<00:00, 16.16it/s, train_loss=0.0737]


Epoch 106/125:
  Train Loss: 0.0737
  Val Loss: 3.3496


Epoch 107/125: 100%|██████████| 278/278 [00:17<00:00, 16.26it/s, train_loss=0.0672]


Epoch 107/125:
  Train Loss: 0.0672
  Val Loss: 3.3446


Epoch 108/125: 100%|██████████| 278/278 [00:16<00:00, 16.37it/s, train_loss=0.0679]


Epoch 108/125:
  Train Loss: 0.0679
  Val Loss: 3.3514


Epoch 109/125: 100%|██████████| 278/278 [00:17<00:00, 15.86it/s, train_loss=0.0637]


Epoch 109/125:
  Train Loss: 0.0637
  Val Loss: 3.3629


Epoch 110/125: 100%|██████████| 278/278 [00:16<00:00, 16.46it/s, train_loss=0.0647]


Epoch 110/125:
  Train Loss: 0.0647
  Val Loss: 3.3633


Epoch 111/125: 100%|██████████| 278/278 [00:18<00:00, 15.36it/s, train_loss=0.0655]


Epoch 111/125:
  Train Loss: 0.0655
  Val Loss: 3.3554


Epoch 112/125: 100%|██████████| 278/278 [00:17<00:00, 15.71it/s, train_loss=0.0655]


Epoch 112/125:
  Train Loss: 0.0655
  Val Loss: 3.3601


Epoch 113/125: 100%|██████████| 278/278 [00:17<00:00, 15.55it/s, train_loss=0.0651]


Epoch 113/125:
  Train Loss: 0.0651
  Val Loss: 3.3587


Epoch 114/125: 100%|██████████| 278/278 [00:18<00:00, 15.02it/s, train_loss=0.17] 


Epoch 114/125:
  Train Loss: 0.1697
  Val Loss: 3.5544


Epoch 115/125: 100%|██████████| 278/278 [00:17<00:00, 15.67it/s, train_loss=0.111]


Epoch 115/125:
  Train Loss: 0.1111
  Val Loss: 3.3820


Epoch 116/125: 100%|██████████| 278/278 [00:17<00:00, 16.33it/s, train_loss=0.0751]


Epoch 116/125:
  Train Loss: 0.0751
  Val Loss: 3.4012


Epoch 117/125: 100%|██████████| 278/278 [00:17<00:00, 16.23it/s, train_loss=0.064] 


Epoch 117/125:
  Train Loss: 0.0640
  Val Loss: 3.4125


Epoch 118/125: 100%|██████████| 278/278 [00:16<00:00, 16.39it/s, train_loss=0.0611]


Epoch 118/125:
  Train Loss: 0.0611
  Val Loss: 3.4004


Epoch 119/125: 100%|██████████| 278/278 [00:17<00:00, 16.01it/s, train_loss=0.0576]


Epoch 119/125:
  Train Loss: 0.0576
  Val Loss: 3.4346


Epoch 120/125: 100%|██████████| 278/278 [00:17<00:00, 15.62it/s, train_loss=0.0564]


Epoch 120/125:
  Train Loss: 0.0564
  Val Loss: 3.4247


Epoch 121/125: 100%|██████████| 278/278 [00:17<00:00, 16.18it/s, train_loss=0.0565]


Epoch 121/125:
  Train Loss: 0.0565
  Val Loss: 3.4636


Epoch 122/125: 100%|██████████| 278/278 [00:17<00:00, 16.15it/s, train_loss=0.0581]


Epoch 122/125:
  Train Loss: 0.0581
  Val Loss: 3.4416


Epoch 123/125: 100%|██████████| 278/278 [00:17<00:00, 16.22it/s, train_loss=0.0638]


Epoch 123/125:
  Train Loss: 0.0638
  Val Loss: 3.4403


Epoch 124/125: 100%|██████████| 278/278 [00:17<00:00, 15.86it/s, train_loss=0.0616]


Epoch 124/125:
  Train Loss: 0.0616
  Val Loss: 3.4118


Epoch 125/125: 100%|██████████| 278/278 [00:18<00:00, 15.35it/s, train_loss=0.0626]


Epoch 125/125:
  Train Loss: 0.0626
  Val Loss: 3.4063


# Visualize predictions


In [10]:
model = BiomeClassifier(
    num_block_types=len(train_loader.converter.block_to_index),
    num_biomes=len(train_loader.converter.biome_to_index)
).to('cuda')
        
        # Load pretrained weights for biome classifier
model.load_state_dict(torch.load('best_biome_classifier_airprocessed.pt'))
model.eval()

BiomeClassifier(
  (block_proj): Conv3d(43, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (encoder): Sequential(
    (0): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv3d(128, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (attention): AttnBlock(
    (norm): GroupNorm(32, 256, eps=1e-06, affine=True)
    (q): Conv3d(256, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (k): Conv3d(256, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (v): Conv3d(256, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (proj_out): Conv3d(256, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  )
  (feature_conv): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (biome_head): Conv3d(256, 14, kernel_size=

In [13]:
# Get 5 random samples
visualizer = MinecraftVisualizerPyVista()
num_samples = 5
all_batches = list(val_loader)
device = 'cuda'

for i in range(num_samples):
    # Get a random batch
    random_idx = np.random.randint(0, len(all_batches))
    terrain_batch, biome_batch = all_batches[random_idx]
    
    # Move to device for prediction
    terrain_batch = terrain_batch.to(device)
    
    # Get model predictions
    model.eval()
    with torch.no_grad():
        pred_biomes = model(terrain_batch)
    
    # Take first sample from batch
    terrain = terrain_batch[0]
    true_biomes = biome_batch[0]
    pred_biomes = pred_biomes[0]
    
    # Convert everything back to original format
    original_terrain = val_loader.converter.convert_to_original_blocks(terrain)
    original_true_biomes = val_loader.converter.convert_to_original_biomes(true_biomes)
    original_pred_biomes = val_loader.converter.convert_to_original_biomes(pred_biomes)

    
    plotter = visualizer.visualize_chunk_with_biomes(original_terrain, original_true_biomes)
    plotter.show()

    plotter = visualizer.visualize_chunk_with_biomes(original_terrain, original_pred_biomes)
    plotter.show()
    
    

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x1943050a520_0&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x194307266a0_1&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19c8cc9eb20_2&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19ca02d6ac0_3&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19cab9e5cd0_4&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19cabdbcaf0_5&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19cabe84cd0_6&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19cb2f4c1f0_7&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19cb33fccd0_8&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58071/index.html?ui=P_0x19cc4de98e0_9&reconnect=auto" class="pyvis…

In [11]:
# Get first batch
terrain_batch, biome_batch = next(iter(train_loader))
visualizer = MinecraftVisualizerPyVista()
num_samples = 5
device = 'cuda'

for i in range(num_samples):
    # Take sample i from batch and move to device
    terrain = terrain_batch[i:i+1].to(device)  # Keep batch dimension
    biomes = biome_batch[i]
    
    # Get model predictions
    model.eval()
    with torch.no_grad():
        pred_biomes = model(terrain)[0]  # Remove batch dimension
    
    # Convert everything back to original format
    original_terrain = train_loader.dataset.converter.convert_to_original_blocks(terrain_batch[i])
    original_true_biomes = train_loader.dataset.converter.convert_to_original_biomes(biomes)
    original_pred_biomes = train_loader.dataset.converter.convert_to_original_biomes(pred_biomes)
    
    # Visualize true and predicted
    print(f"\nSample {i+1}")
    plotter = visualizer.visualize_chunk_with_biomes(original_terrain, original_true_biomes)
    plotter.show()
    
    plotter = visualizer.visualize_chunk_with_biomes(original_terrain, original_pred_biomes)
    plotter.show()


Sample 1


Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bb7f17ffd0_0&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bb83ed0670_1&reconnect=auto" class="pyvis…


Sample 2


Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bb93333940_2&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bb94d63ee0_3&reconnect=auto" class="pyvis…


Sample 3


Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bb94eb21f0_4&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bb9558a490_5&reconnect=auto" class="pyvis…


Sample 4


Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bba7e65a30_6&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bba7f47910_7&reconnect=auto" class="pyvis…


Sample 5


Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bbae3944f0_8&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:58557/index.html?ui=P_0x1bbae6e7fd0_9&reconnect=auto" class="pyvis…