In [1]:
import os
from collections.abc import Sequence
from absl import app
from absl import flags
from absl import logging
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from flax.training import orbax_utils
import orbax.checkpoint
from recurrentgemma import jax as recurrentgemma
from datasets import DatasetDict
from safe.tokenizer import SAFETokenizer
from time import perf_counter
import json
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define flags
# DEBUG_MODE = flags.DEFINE_bool("debug", False, "Debug mode.")
# SAVE_PATH = flags.DEFINE_string("save_path", "saves/recurrent_gemma_model.ckpt", "Path to save model checkpoints.")
# KEY = flags.DEFINE_integer("key", 1241312, "Key to use for randomization.")
DEBUG_MODE = False
SAVE_PATH = "saves/recurrent_gemma_model.ckpt"
KEY = 1241312

In [3]:
# Training parameters
# EPOCHS = flags.DEFINE_integer("epochs", 10, "Number of training epochs.")
# STEPS_PER_EPOCH = flags.DEFINE_integer("steps_per_epoch", 2500, "Number of steps per epoch.")
# BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size for training.")
# SEQ_LENGTH = flags.DEFINE_integer("seq_length", 69, "Sequence length for training.")
# LEARNING_RATE = flags.DEFINE_float("learning_rate", 1e-3, "Initial learning rate.")
EPOCHS = 10
STEPS_PER_EPOCH = 2500
BATCH_SIZE = 64
SEQ_LENGTH = 69
LEARNING_RATE = 1e-3

In [4]:
# def create_train_state(rng, config):
#     """Creates initial `TrainState`."""
#     model = recurrentgemma.Griffin(config)
#     params = model.init(rng, jnp.ones((1, SEQ_LENGTH.value), dtype=jnp.int32))
#     tx = optax.adam(learning_rate=LEARNING_RATE.value)
#     return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    model = recurrentgemma.Griffin(config)
    dummy_input = jnp.ones((1, SEQ_LENGTH), dtype=jnp.int32)
    dummy_segment_pos = jnp.zeros((1, SEQ_LENGTH), dtype=jnp.int32)
    params = model.init(rng, dummy_input, segment_pos=dummy_segment_pos)
    tx = optax.adam(learning_rate=LEARNING_RATE)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)



In [5]:
# def count_parameters(params):
#     """Count the number of trainable parameters in the model."""
#     return sum(jnp.prod(p.shape).item() for p in jax.tree_util.tree_leaves(params))

# def count_parameters(params):
#     """Count the number of trainable parameters in the model."""
#     return sum(jnp.prod(p.shape).item() for p in jax.tree_util.tree_leaves(params))

def count_parameters(params):
    """Count the number of trainable parameters in the model."""
    total_params = 0
    for param in jax.tree_util.tree_leaves(params):
        if isinstance(param, jnp.ndarray):
            total_params += jnp.prod(param.shape).item()
        elif isinstance(param, tuple):
            total_params += sum(jnp.prod(p.shape).item() for p in param if isinstance(p, jnp.ndarray))
    return total_params


In [6]:
# def load_dataset():
#     """Load and preprocess the dataset."""
#     dataset = DatasetDict.load_from_disk('../../Datasets/MOSES/datasets')
#     tokenizer = SAFETokenizer.from_pretrained("./tokenizer.json")

#     def tokenize_function(examples):
#         return {"input_ids": tokenizer.encode(examples["SAFE"], ids_only=True)}

#     tokenized_dataset = dataset.map(tokenize_function, batched=False, remove_columns=dataset['train'].column_names)
#     return tokenized_dataset, tokenizer

def load_dataset():
    """Load and preprocess the dataset."""
    dataset = DatasetDict.load_from_disk('../../Datasets/MOSES/datasets')
    tokenizer = SAFETokenizer.from_pretrained("./tokenizer.json")

    def tokenize_function(examples):
        return {"input_ids": tokenizer.encode(examples["SAFE"], ids_only=True)}

    tokenized_dataset = dataset.map(tokenize_function, batched=False, remove_columns=dataset['train'].column_names)
    return tokenized_dataset, tokenizer

In [7]:
def pad_sequences(sequences, max_len):
    """Pad sequences to the same length."""
    return [seq + [0] * (max_len - len(seq)) for seq in sequences]

In [8]:
def get_batch(dataset, batch_size, seq_length):
    """Get a random batch from the dataset."""
    idx = np.random.randint(0, len(dataset), batch_size)
    batch = [dataset[i] for i in idx]
    
    # Truncate or pad sequences
    batch = [seq[:seq_length] for seq in batch]
    batch = pad_sequences(batch, seq_length)
    
    return jnp.array(batch)

In [9]:
@jax.jit
def train_step(state, batch):
    """Perform a single training step."""
    def loss_fn(params):
        logits = state.apply_fn(params, batch)
        targets = jnp.roll(batch, -1, axis=-1)
        mask = jnp.where(targets != 0, 1, 0)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets) * mask
        return jnp.sum(loss) / jnp.sum(mask)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss

In [10]:
# def train_epoch(state, dataset, steps):
#     """Train for a single epoch."""
#     batch_size, seq_length = BATCH_SIZE.value, SEQ_LENGTH.value
    
#     epoch_loss = 0
#     for _ in range(steps):
#         batch = get_batch(dataset, batch_size, seq_length)
#         state, loss = train_step(state, batch)
#         epoch_loss += loss
    
#     return state, epoch_loss / steps

def train_epoch(state, dataset, steps):
    """Train for a single epoch."""
    epoch_loss = 0
    for _ in range(steps):
        batch = get_batch(dataset, BATCH_SIZE, SEQ_LENGTH)
        state, loss = train_step(state, batch)
        epoch_loss += loss
    
    return state, epoch_loss / steps

In [11]:
def generate_molecule(state, tokenizer, max_length=100, temperature=0.8, top_k=None):
    """Generate a single molecule."""
    rng = jax.random.PRNGKey(0)
    start_token = jnp.array(tokenizer.encode("[START]", ids_only=True))
    input_ids = start_token[None, :]
    
    for _ in range(max_length):
        logits = state.apply_fn(state.params, input_ids)
        next_token_logits = logits[:, -1, :] / temperature
        
        if top_k is not None:
            top_k_logits, top_k_indices = jax.lax.top_k(next_token_logits, top_k)
            next_token_logits = jnp.where(
                jnp.expand_dims(jnp.arange(next_token_logits.shape[-1]), 0) == top_k_indices,
                top_k_logits,
                float('-inf')
            )
        
        next_token = jax.random.categorical(rng, next_token_logits)
        input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=-1)
        
        if next_token.item() == tokenizer.encode("[END]", ids_only=True)[0]:
            break
    
    return tokenizer.decode(input_ids[0])

In [12]:
def de_novo_generation(state, tokenizer, num_molecules=10):
    """Perform de novo generation of molecules."""
    print(f"Generating {num_molecules} new molecules:")
    for i in range(num_molecules):
        molecule = generate_molecule(state, tokenizer)
        print(f"Molecule {i+1}: {molecule}")

main

In [13]:
config = recurrentgemma.GriffinConfig(
    vocab_size=1180,
    width=128,
    mlp_expanded_width=3 * 128,
    lru_width=256,
    num_heads=2,
    block_types=(
        recurrentgemma.TemporalBlockType.RECURRENT,
        recurrentgemma.TemporalBlockType.ATTENTION,
    ),
    embeddings_scale_by_sqrt_dim=True,
    attention_window_size=2048,
    logits_soft_cap=30.0,
)

In [14]:
# Load dataset and tokenizer
dataset, tokenizer = load_dataset()
train_data = dataset['train']

In [15]:
# # Initialize model and training state
# rng = jax.random.PRNGKey(KEY.value)
# state = create_train_state(rng, config)

# Initialize model and training state
rng = jax.random.PRNGKey(KEY)
state = create_train_state(rng, config)

In [17]:
# # Count and print the number of trainable parameters
# num_params = count_parameters(state.params)
# logging.info(f"Number of trainable parameters: {num_params:,}")

In [18]:
# # Training loop
# losses = []
# t0_start = perf_counter()
# for epoch in range(EPOCHS.value):
#     logging.info(f"Starting epoch {epoch + 1}/{EPOCHS.value}")
#     t1_start = perf_counter()
    
#     state, epoch_loss = train_epoch(state, train_data, STEPS_PER_EPOCH.value)
#     losses.append(float(epoch_loss))
    
#     t1_stop = perf_counter()
#     logging.info(f"Epoch {epoch + 1} completed. Loss: {epoch_loss:.5f}. Time: {t1_stop - t1_start:.2f} sec")

#     # Save checkpoint
#     ckpt = {
#         "model": state.params,
#         "config": config,
#         "epoch": epoch + 1,
#     }
#     orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
#     save_args = orbax_utils.save_args_from_target(ckpt)
#     orbax_checkpointer.save(f"{SAVE_PATH.value}_{epoch + 1}", ckpt, save_args=save_args)

# t0_stop = perf_counter()
# logging.info(f"Training completed in {t0_stop - t0_start:.2f} seconds")

# Training loop
losses = []
t0_start = perf_counter()
for epoch in range(EPOCHS):
    logging.info(f"Starting epoch {epoch + 1}/{EPOCHS}")
    t1_start = perf_counter()
    
    state, epoch_loss = train_epoch(state, train_data, STEPS_PER_EPOCH)
    losses.append(float(epoch_loss))
    
    t1_stop = perf_counter()
    logging.info(f"Epoch {epoch + 1} completed. Loss: {epoch_loss:.5f}. Time: {t1_stop - t1_start:.2f} sec")

    # Save checkpoint
    ckpt = {
        "model": state.params,
        "config": config,
        "epoch": epoch + 1,
    }
    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(ckpt)
    orbax_checkpointer.save(f"{SAVE_PATH}_{epoch + 1}", ckpt, save_args=save_args)

t0_stop = perf_counter()
logging.info(f"Training completed in {t0_stop - t0_start:.2f} seconds")

TypeError: Wrong key type: '535065' of type '<class 'numpy.int64'>'. Expected one of int, slice, range, str or Iterable.

In [None]:
# Save loss history
with open('recurrent_gemma_loss_history.json', 'w') as f:
    json.dump(losses, f)