# Test loading the dataset and collating it into batches for the model

## 1. Load a batch

In [1]:
from datasets import DatasetDict

In [2]:
# Load the saved DatasetDict from disk; the path is relative, so it depends on
# the notebook's working directory.
dd = DatasetDict.load_from_disk("../data/dataset/metal_site_dataset")

In [3]:
# Work with the "train" split only.
ds = dd["train"]

In [4]:
# Take just the first batch of two systems. Using next() states the intent
# directly and fails loudly (StopIteration) on an empty dataset instead of
# silently leaving `batch` undefined.
batch = next(iter(ds.iter(batch_size=2)))

## 2. Collate a batch

In [5]:
import logging
logger = logging.getLogger(__name__)
from typing import Dict, List, Optional
import torch

class AtomicSystemBatchCollator:
    """Collate a batch of atomic systems into flat tensors, with optional masking and noise.

    Takes a HuggingFace-style batch (dict of per-system lists), concatenates all
    systems along the atom axis, and records which system each atom came from in
    a ``batch`` index tensor. Optionally replaces a random subset of atom tokens
    with mask tokens and perturbs a random subset of positions with gaussian noise.

    Args:
        tokenizer: Tokenizer providing ``atom_mask_token`` and ``type_mask_token``;
            its ``encode`` method is used when ``already_tokenized`` is False.
        mask_rate: Fraction of atoms to mask; None or 0 disables masking.
        noise_rate: Fraction of atoms whose positions are noised; None or 0
            disables noising.
        zero_noise_in_loss_rate: Fraction of atoms that are left un-noised but
            still included in ``noise_loss_mask``. Only used when noising is enabled.
        noise_scale: Standard deviation of the gaussian displacement applied to
            positions (Angstroms).
        already_tokenized: If True, 'atoms' and 'atom_types' already hold tokens;
            otherwise each system is passed through ``tokenizer.encode``.
        return_original_positions: If True and noising is enabled, the
            pre-noise positions are returned as 'original_positions'.

    Input batch format:
        {
            'atoms': List[List[int]], # Atomic number tokens, one list per system
            'atom_types': List[List[int]], # ATOM/HETATM tokens, one list per system
            'positions': per-system list of xyz coordinates
            'id': List[str] # Optional system identifiers
        }

    Output format:
        {
            'atoms': [n_atoms_total] atom tokens
            'atom_types': [n_atoms_total] record tokens
            'positions': [n_atoms_total, 3] coordinates (noised if noising)
            'batch': [n_atoms_total] index of the system each atom belongs to
            'mask_mask': [n_atoms_total] bool, True where the atom was masked; only if masking
            'noise_mask': [n_atoms_total] bool, True where noise was applied; only if noising
            'denoise_vectors': [n_atoms_total, 3] negated unit-variance noise; only if noising.
                These are NOT multiplied by noise_scale, i.e.
                positions + denoise_vectors * noise_scale recovers the original positions.
            'noise_loss_mask': [n_atoms_total] bool, atoms used for the denoising loss
                (noised atoms plus the zero-noise extras); only if noising
            'original_positions': [n_atoms_total, 3] pre-noise positions; only if
                noising and return_original_positions is True
            'id' and any other input fields: passed through unchanged (not converted to tensors)
        }
    """

    def __init__(
            self,
            tokenizer,
            mask_rate: Optional[float] = None,
            noise_rate: Optional[float] = None, 
            zero_noise_in_loss_rate: Optional[float] = None,
            noise_scale: float = 1.0,
            already_tokenized: bool = True,
            return_original_positions: bool = False
        ):
        # Keep the tokenizer itself: __call__ needs it to encode raw inputs
        # when already_tokenized is False.
        self.tokenizer = tokenizer
        self.atom_mask_token = tokenizer.atom_mask_token
        self.type_mask_token = tokenizer.type_mask_token
        self.mask_rate = mask_rate
        self.noise_rate = noise_rate
        self.zero_noise_in_loss_rate = zero_noise_in_loss_rate
        self.noise_scale = noise_scale
        self.already_tokenized = already_tokenized
        self.return_original_positions = return_original_positions

    def __call__(self, batch: Dict[str, List]) -> Dict[str, torch.Tensor]:
        """Process a batch of atomic systems.

        Args:
            batch: HuggingFace format batch dictionary (dict of per-system lists).
                The dictionary itself is not modified.

        Returns:
            Dictionary of concatenated tensors with optional masking/noise
            outputs, plus any extra input fields passed through as-is.
        """
        atoms = batch['atoms']
        atom_types = batch['atom_types']

        # Per-system atom counts drive the batch index tensor
        batch_sizes = [len(system) for system in atoms]
        total_atoms = sum(batch_sizes)
        logger.debug(
            "Processing batch with %d systems, %d total atoms",
            len(batch_sizes), total_atoms,
        )

        # batch_idx[i] is the index of the system that atom i belongs to
        batch_idx = torch.repeat_interleave(
            torch.arange(len(batch_sizes)),
            torch.tensor(batch_sizes, dtype=torch.long)
        )

        # Tokenize into local lists so the caller's batch is left untouched.
        # NOTE(review): assumes tokenizer.encode accepts one system's raw
        # values and returns its token list -- confirm against AtomTokenizer.
        if not self.already_tokenized:
            atoms = [self.tokenizer.encode(x) for x in atoms]
            atom_types = [self.tokenizer.encode(x) for x in atom_types]

        # Concatenate all systems along the atom axis
        output = {
            'atoms': torch.cat([torch.tensor(x) for x in atoms]),
            'atom_types': torch.cat([torch.tensor(x) for x in atom_types]),
            'positions': torch.cat([torch.tensor(x) for x in batch['positions']]),
            'batch': batch_idx
        }

        # Apply masking: overwrite a random subset of atom/type tokens
        if self.mask_rate and self.mask_rate > 0:
            n_mask = int(total_atoms * self.mask_rate)
            mask_idx = torch.randperm(total_atoms)[:n_mask]

            output['atoms'][mask_idx] = self.atom_mask_token
            output['atom_types'][mask_idx] = self.type_mask_token
            mask_mask = torch.zeros(total_atoms, dtype=torch.bool)
            mask_mask[mask_idx] = True
            output['mask_mask'] = mask_mask

            logger.debug("Masked %d atoms", n_mask)

        # Apply coordinate noise
        if self.noise_rate and self.noise_rate > 0:
            n_noise = int(total_atoms * self.noise_rate)
            # One permutation serves both selections so the noised and
            # zero-noise index sets are disjoint.
            randperm = torch.randperm(total_atoms)
            noise_idx = randperm[:n_noise]

            # Additional positions for loss but no noise
            if self.zero_noise_in_loss_rate:
                n_zero_noise = int(total_atoms * self.zero_noise_in_loss_rate)
                zero_noise_idx = randperm[n_noise:n_noise + n_zero_noise]
                noise_loss_idx = torch.cat([noise_idx, zero_noise_idx])
            else:
                noise_loss_idx = noise_idx

            positions = output['positions']
            if self.return_original_positions:
                output['original_positions'] = positions.clone()

            # Draw the noise in the dtype/device of the positions so the
            # indexed assignment below cannot fail on a dtype mismatch.
            noise_vectors = torch.zeros_like(positions)
            noise_vectors[noise_idx] = torch.randn(
                n_noise, 3, dtype=positions.dtype, device=positions.device
            )
            # move the atoms by noise times scale
            # return vectors will not be scaled so model can be trained with low activations
            output['positions'] = positions + noise_vectors * self.noise_scale
            output['denoise_vectors'] = noise_vectors * -1

            noise_loss_mask = torch.zeros(total_atoms, dtype=torch.bool)
            noise_loss_mask[noise_loss_idx] = True
            output['noise_loss_mask'] = noise_loss_mask

            noise_mask = torch.zeros(total_atoms, dtype=torch.bool)
            noise_mask[noise_idx] = True
            output['noise_mask'] = noise_mask

            logger.debug(
                "Added noise to %d positions, tracking loss on %d positions",
                n_noise, len(noise_loss_idx),
            )

        # pass through other keys (e.g. 'id') without converting them
        for key, value in batch.items():
            if key not in output:
                output[key] = value

        return output

In [6]:
# Emit DEBUG-level records so the collator's debug messages show up in the cell output.
logging.basicConfig(level=logging.DEBUG)

In [7]:
from metalsitenn.atom_vocabulary import AtomTokenizer

In [23]:
# Build the tokenizer that supplies the mask tokens used by the collator.
# NOTE(review): the meaning of these options is defined in
# metalsitenn.atom_vocabulary, which is not shown in this notebook.
tokenizer = AtomTokenizer(
    metal_known=False,
    aggregate_uncommon=True,
    keep_hydrogen=False,
    allow_unknown=True)

In [24]:
# Mask 15% of atoms, noise 15% of positions, and add a further 5% of
# un-noised atoms to the denoising loss; inputs are already tokenized.
collator = AtomicSystemBatchCollator(
    tokenizer=tokenizer,
    mask_rate=0.15,
    noise_rate=0.15,
    zero_noise_in_loss_rate=0.05,
    already_tokenized=True,
)

In [25]:
# Collate the batch loaded above. Masking and noising draw from the torch RNG,
# so the selected atoms differ between runs unless a seed is set.
out_batch = collator(batch)

DEBUG:__main__:Processing batch with 2 systems, 434 total atoms
DEBUG:__main__:Masked 65 atoms
DEBUG:__main__:Added noise to 65 positions, tracking loss on 86 positions


In [26]:
# Pull every collated field into its own variable for inspection below.
(
    atoms_out,
    atom_types_out,
    positions_out,
    batch_out,
    mask_mask_out,
    noise_mask_out,
    denoise_vectors_out,
    noise_loss_mask_out,
    id_out,
) = (
    out_batch['atoms'],
    out_batch['atom_types'],
    out_batch['positions'],
    out_batch['batch'],
    out_batch['mask_mask'],
    out_batch['noise_mask'],
    out_batch['denoise_vectors'],
    out_batch['noise_loss_mask'],
    out_batch['id'],
)

In [27]:
# Denoise targets for the atoms that actually received noise.
denoise_vectors_out[noise_mask_out]

tensor([[-0.6371, -0.3284, -0.4553],
        [ 2.3044,  0.0064,  0.1806],
        [ 1.3374,  0.9880,  0.3346],
        [-0.8094, -0.3219, -0.7245],
        [ 0.2485, -1.5461,  0.8510],
        [ 0.4743, -0.1667,  1.4650],
        [-0.3384,  0.9913, -0.0240],
        [ 0.0853, -2.9400,  0.7277],
        [-1.2542, -0.7127,  0.8286],
        [ 0.5052,  2.0674, -2.2834],
        [ 0.4355,  1.2461, -0.1633],
        [ 0.6142, -2.0832, -2.0023],
        [-1.2316, -0.1770,  0.4541],
        [-1.1371, -0.4642,  0.6588],
        [-1.5451,  0.3310,  1.2291],
        [ 0.3362,  0.6862,  1.7811],
        [-1.3093,  1.2934,  1.8894],
        [ 0.4330,  1.0125,  0.2559],
        [-0.0461, -0.1601, -0.4453],
        [ 1.0123,  1.9930, -2.7784],
        [-1.5821,  0.0900, -2.1105],
        [ 2.3626, -0.3481,  1.5511],
        [-0.5329, -0.3074,  2.0833],
        [-1.6107,  0.2612, -0.1267],
        [-0.4153, -0.6762, -0.7446],
        [-0.2976, -0.8066,  0.5422],
        [ 0.6019,  0.7598,  0.3473],
 

In [28]:
# The loss mask covers the noised atoms plus the extra zero-noise atoms,
# so it selects more rows than noise_mask does.
denoise_vectors_out[noise_loss_mask_out].shape

torch.Size([86, 3])

In [29]:
# Total number of atoms across all systems in the batch.
atoms_out.shape

torch.Size([434])

In [30]:
# Scratch arithmetic.
# NOTE(review): 56 and 286 do not match the 65 masked / 434 total atoms
# reported in the outputs above -- likely left over from an earlier run; confirm or remove.
56/286

0.1958041958041958

In [31]:
# Every masked position should hold the atom mask token.
atoms_out[mask_mask_out]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [32]:
# Unmasked positions keep their original atom tokens.
atoms_out[~mask_mask_out]

tensor([ 9,  8,  8,  3,  8,  8,  3,  8,  3,  7,  7,  3,  3,  7,  3,  7,  3,  9,
         8,  8,  8,  3,  3,  8,  3,  8,  3,  7,  3,  7,  3,  3,  7,  7,  3,  7,
         3,  9,  8,  8,  3,  3,  3,  8,  3,  8,  3,  7,  3,  7,  3,  7,  7,  3,
         7,  9,  8,  8,  8,  3,  3,  8,  3,  8,  3,  3,  7,  3,  7,  3,  7,  7,
         3,  7,  3,  9,  8,  8,  8,  3,  3,  8,  3,  8,  3,  7,  3,  7,  3,  3,
         8,  7,  3,  7,  7,  3,  9,  8,  8,  8,  3,  3,  8,  3,  8,  3,  8,  3,
         7,  3,  7,  3,  3,  7,  7,  3,  7,  3,  8,  8,  3,  3,  8,  8,  3,  8,
         3,  7,  3,  8,  7,  3,  8,  3,  3,  9,  8,  8,  8,  3,  3,  8,  3,  8,
         3,  7,  3,  8,  7,  3,  7,  3,  3,  7,  3,  3,  7,  3,  3,  8,  3,  3,
         3,  7,  3,  3,  8,  3,  3,  3,  7,  3,  7,  7, 12, 12,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7,  3,  3,
         8,  3,  3,  7,  3,  7,  7,  3,  3,  8,  3,  7,  3,  3,  8,  3,  3,  3,
         7,  3,  3,  3,  3,  7,  3,  3, 

In [33]:
# Decode the masked positions back to strings to confirm that both the atom
# and the atom-type tokens map to the mask symbol.
tokenizer.decode(atoms=atoms_out[mask_mask_out], atom_types=atom_types_out[mask_mask_out])

{'atoms': ['<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>'],
 'atom_types': ['<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<MASK>',
  '<

In [34]:
# Full token sequence, with mask tokens scattered among the original tokens.
atoms_out

tensor([ 9,  8,  0,  8,  0,  3,  8,  0,  8,  3,  8,  3,  7,  0,  7,  3,  3,  0,
         7,  3,  7,  3,  9,  8,  8,  8,  3,  3,  8,  3,  8,  3,  0,  0,  7,  3,
         7,  3,  3,  7,  7,  3,  7,  3,  9,  0,  8,  8,  3,  3,  0,  3,  8,  3,
         8,  3,  7,  3,  7,  0,  3,  7,  7,  3,  7,  0,  9,  8,  8,  8,  3,  3,
         8,  3,  8,  3,  0,  3,  7,  3,  7,  0,  3,  7,  7,  3,  7,  3,  9,  8,
         8,  8,  3,  3,  0,  0,  8,  3,  8,  3,  7,  3,  7,  3,  3,  8,  7,  3,
         7,  7,  3,  9,  8,  8,  8,  3,  3,  8,  3,  8,  3,  8,  3,  7,  3,  7,
         3,  3,  7,  7,  3,  7,  3,  0,  8,  0,  8,  3,  3,  8,  0,  8,  3,  8,
         3,  7,  3,  8,  7,  3,  8,  3,  3,  9,  8,  8,  8,  3,  3,  8,  3,  0,
         0,  8,  3,  7,  3,  8,  7,  3,  7,  3,  3,  7,  3,  3,  0,  7,  3,  3,
         8,  0,  3,  3,  3,  0,  7,  3,  3,  8,  3,  3,  3,  7,  3,  7,  7, 12,
        12,  0,  8,  8,  0,  8,  8,  8,  8,  8,  8,  8,  8,  8,  0,  0,  8,  0,
         8,  8,  8,  8,  8,  0,  8,  8, 

In [22]:
# NOTE(review): `outs` is not defined in any cell shown in this notebook, and
# this cell's execution count (22) is lower than the cells above it -- this
# looks like a stale cell from an earlier session; confirm or remove.
tokenizer.decode(**outs)

{'atoms': ['C', 'C', 'CL', 'O'],
 'atom_types': ['ATOM', 'ATOM', 'HETATM', 'ATOM']}