In [1]:
%cd ~/big_vision
import jax
import importlib
import numpy as np
from absl import logging
import jax.numpy as jnp
from jax.experimental import mesh_utils

import big_vision.utils as u
import big_vision.optax as bv_optax
import big_vision.sharding as bv_sharding
import big_vision.input_pipeline as input_pipeline
from big_vision.configs.proj.image_text.siglip_replication import get_config

# Load the experiment configuration for this notebook.
config = get_config()

# Resolve the model implementation module named by the config.
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")

# Import every module listed under "pp_modules" (or the three defaults). The
# returned module objects are discarded, so only the import side effects matter.
for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
	importlib.import_module(f"big_vision.pp.{m}")
def bytes_in_use_devices():
	"""Return the `bytes_in_use` memory statistic of every JAX device.

	Returns:
		A list with one entry per device, in `jax.devices()` order: the value
		of that device's `bytes_in_use` statistic, or None when the device
		reports no memory statistics (`memory_stats()` returned None) or the
		statistics lack that key.
	"""
	usage = []
	for device in jax.devices():
		# memory_stats() may be None on backends without allocator statistics;
		# subscripting it directly would raise TypeError.
		stats = device.memory_stats() or {}
		usage.append(stats.get('bytes_in_use'))
	return usage
def info(s, *a):
	"""Log `s` at INFO level behind a yellow "NOTE" tag.

	Args:
		s: Message format string (%-style placeholders).
		*a: Values substituted into `s` by the logging call.
	"""
	# ANSI escape sequence: switch to yellow, print the tag, then reset.
	tag = "\u001b[33mNOTE\u001b[0m: "
	logging.info(tag + s, *a)
def write_note(note):
	"""Log `note` through `info`, but only on the process with index 0.

	Args:
		note: Text to log; it is passed as a "%s" argument, so any "%"
			characters inside it are not interpreted as format directives.
	"""
	# Every other process stays silent so the note is emitted once.
	if jax.process_index() != 0:
		return
	info("%s", note)

  bkms = self.shell.db.get('bookmarks', {})


/home/austinwang/big_vision




In [2]:
# Instantiate the model class from the module selected by the config, passing
# the config's "model" section as keyword arguments (none if it is absent).
model = model_mod.Model(**config.get("model", {}))
# Build the training input pipeline. It returns the dataset plus a second value
# bound to `ntrain_img` (presumably the number of training examples — confirm).
train_ds, ntrain_img = input_pipeline.training(config.input)
batch_size = config.input.batch_size
# Resolve the "total" step count from the config, `ntrain_img` and batch size.
total_steps = u.steps("total", config, ntrain_img, batch_size)

  from .autonotebook import tqdm as notebook_tqdm


Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


In [3]:
def init(rng):
	"""Build the initial model parameters from an all-zeros dummy batch.

	Args:
		rng: PRNG key forwarded to `model.init`.

	Returns:
		The "params" entry of what `model.init` returns, with the head bias
		overwritten by `config["init_head_bias"]` when that key is configured.
	"""
	# Zero-filled stand-in for one batch, mirroring the shape and dtype of every
	# leaf in the dataset's element spec. `jax.tree_util.tree_map` is used
	# because the top-level `jax.tree_map` alias is deprecated.
	batch = jax.tree_util.tree_map(
		lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),
		train_ds.element_spec,
	)
	params = model.init(rng, batch["image"], batch["labels"])["params"]
	# Set bias in the head to a low value, such that loss is small initially.
	if "init_head_bias" in config:
		params["head"]["bias"] = jnp.full_like(params["head"]["bias"], config["init_head_bias"])
	return params

write_note("Inferring parameter shapes...")
# Seed comes from the config (default 0). NOTE(review): `u.put_cpu` presumably
# places the seed value on the host CPU before the key is created — confirm.
rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0)))
# Keep `rng` for later use and dedicate `rng_init` to parameter initialization.
rng, rng_init = jax.random.split(rng)
# eval_shape evaluates `init` abstractly: the result describes shapes and dtypes
# of the parameters without computing their values.
params_shape = jax.eval_shape(init, rng_init)

write_note("Inferring optimizer state shapes...")
# Build the optimizer and its schedule functions from the config; `sched_kw`
# carries the step count, batch size and dataset size.
tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict(total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))
# Same abstract evaluation for the optimizer state, driven by the parameter shapes.
opt_shape = jax.eval_shape(tx.init, params_shape)
# NOTE(review): `u.jit_cpu()` presumably compiles each schedule function for CPU execution — confirm.
sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns]

  batch = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),train_ds.element_spec)


In [4]:
# Alternative mesh / sharding settings, left disabled:
# config.mesh = [("data",-1)]
# config.mesh = [("data", 2),('fsdp', 2)]
# config.sharding_strategy = [('.*', 'fsdp(axis="data", min_size_to_shard_mb=2)')]

write_note("Setting up mesh...")
# Default mesh: a single "data" axis spanning every device.
config_mesh = config.get("mesh", [("data", jax.device_count())])
sharding_rules = config.get("sharding_rules", [("act_batch", "data")])
# Split the (axis_name, size) pairs into a tuple of names and a tuple of sizes.
mesh_axes, mesh_size = tuple(zip(*config_mesh))
# Reshaping the device list lets NumPy resolve a -1 entry into a concrete size
# for the actual device count; only the resulting shape is kept.
mesh_size = np.array(jax.devices()).reshape(mesh_size).shape
device_mesh = mesh_utils.create_device_mesh(mesh_size)
# Flat view of the devices in the order laid out by the device mesh.
devices_flat = device_mesh.flatten()

write_note("Creating device mesh...")
mesh = jax.sharding.Mesh(device_mesh, mesh_axes)
print(f"mesh: {mesh}")
# An empty PartitionSpec names no mesh axis, i.e. the value is replicated on every device.
repl_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
print(f"repl_sharding: {repl_sharding}")

write_note("Inferring shardings...")
# Shape-only description of the full train state (parameters + optimizer state).
train_state_shape = {"params": params_shape, "opt": opt_shape}
# Default strategy: the pattern ".*" mapped to "replicate".
strategy = config.get("sharding_strategy", [(".*", "replicate")])
train_state_sharding = bv_sharding.infer_sharding(train_state_shape, strategy=strategy, mesh=mesh)

mesh: Mesh('data': 4)
repl_sharding: NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec())


In [5]:
write_note("Transferring train_state to devices...")

# Each print below reports the per-device `bytes_in_use`, so the memory cost of
# the step performed just before it can be read off by comparing snapshots.
print(f"bytes_in_use_devices() before rng_init reshard: {bytes_in_use_devices()}")
# RNG is always replicated
rng_init = u.reshard(rng_init, repl_sharding)
print(f"bytes_in_use_devices() after rng_init reshard: {bytes_in_use_devices()}")

# Run `init` under jit with the inferred parameter shardings as output shardings.
params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init)
print(f"bytes_in_use_devices() after init reshard: {bytes_in_use_devices()}")

# Same for the optimizer state, initialized from the materialized parameters.
opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params)
print(f"bytes_in_use_devices() after tx.init reshard: {bytes_in_use_devices()}")

bytes_in_use_devices() before rng_init reshard: [12800, 12800, 12800, 12800]
bytes_in_use_devices() after rng_init reshard: [13312, 13312, 13312, 13312]
bytes_in_use_devices() after init reshard: [840481792, 840481792, 840481792, 840481792]
bytes_in_use_devices() after tx.init reshard: [2477426688, 2477426688, 2477426688, 2477426688]
