In [58]:
import importlib

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils

import big_vision.input_pipeline as input_pipeline
# Short names for the sharding types used when building shardings below.
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


In [53]:
# Switch the kernel's working directory to the big_vision checkout.
# NOTE(review): the path is relative to this user's home directory; adjust it
# when running elsewhere.
%cd ~/big_vision
# Build the experiment config from the `siglip_lit_laion400m` config module.
# `config` is read by the data cell and the model cell further down.
from big_vision.configs.proj.image_text.siglip_lit_laion400m import get_config
config = get_config()

/home/austinwang/big_vision


# Data

In [None]:
# Batch size from the input config; it is later forwarded to the optimizer
# setup in the model cell as `batch_size`.
batch_size = config.input.batch_size
# Build the training input pipeline. `train_ds.element_spec` is used to create
# a dummy batch for model init, and `ntrain_img` is forwarded to the optimizer
# setup as `data_size` (presumably the number of training examples — confirm).
train_ds, ntrain_img = input_pipeline.training(config.input)

# Device mesh

In [35]:
# Mesh layout as (axis name, axis size) pairs.
# Alternatives that were tried: a single data axis spanning all devices,
#   [("data", jax.device_count())]
# or the same with a trivial fsdp axis,
#   [("data", jax.device_count()), ("fsdp", 1)]
config_mesh = [("data", 2), ("fsdp", 2)]

# Split the pairs into parallel tuples: axis names and axis sizes.
mesh_axes = tuple(axis_name for axis_name, _ in config_mesh)
mesh_size = tuple(axis_size for _, axis_size in config_mesh)
mesh_axes, mesh_size

(('data', 'fsdp'), (2, 2))

In [36]:
# Sharding rules as pairs of names. "data" is one of the mesh axes declared in
# `config_mesh`; presumably each pair is (logical axis, mesh axis) — confirm
# against the code that consumes these rules.
sharding_rules = [
    ("act_batch", "data"),
]
sharding_rules

[('act_batch', 'data')]

In [37]:
# Round-trip the requested shape through a NumPy reshape of the device list.
# The reshape raises if the axis sizes do not multiply to the number of
# available devices, and a -1 entry is resolved to a concrete size.
mesh_size = np.reshape(np.array(jax.devices()), mesh_size).shape
mesh_size

(2, 2)

In [38]:
# Arrange the devices into an array with the requested mesh shape.
device_mesh = mesh_utils.create_device_mesh(mesh_shape=mesh_size)
device_mesh

array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
        TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0)],
       [TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
        TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]],
      dtype=object)

In [39]:
# 1-D copy of the mesh's devices in row-major order.
devices_flat = device_mesh.flatten(order="C")
devices_flat

array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
       TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
       TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)],
      dtype=object)

# Model

In [None]:

# NOTE(review): `u`, `bv_optax`, `write_note` and `total_steps` are used below
# but are not defined in any cell visible in this notebook — presumably
# `big_vision.utils`, `big_vision.optax` and helpers taken from the big_vision
# training script. Define/import them before running this cell — TODO confirm.

# Resolve the model module named in the config and instantiate the model with
# the config's "model" kwargs (empty when the key is absent).
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
model = model_mod.Model(**config.get("model", {}))


def init(rng):
    """Initialise model parameters from an all-zeros dummy batch.

    Args:
        rng: JAX PRNG key forwarded to `model.init`.

    Returns:
        The "params" collection produced by `model.init`. When the config has
        an "init_head_bias" entry, `params["head"]["bias"]` is replaced by an
        array of the same shape filled with that value.
    """
    # Zero-filled stand-in for one batch, taking shapes and dtypes from the
    # dataset's element spec. `jax.tree_util.tree_map` is used instead of the
    # deprecated top-level `jax.tree_map` alias.
    batch = jax.tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),
        train_ds.element_spec)
    params = model.init(rng, batch["image"], batch["labels"])["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    if "init_head_bias" in config:
        params["head"]["bias"] = jnp.full_like(params["head"]["bias"],
                                               config["init_head_bias"])

    return params


# This seed makes the Jax part of things (like model init) deterministic.
# However, full training still won't be deterministic, for example due to the
# tf.data pipeline not being deterministic even if we would set TF seed.
# See (internal link) for a fun read on what it takes.
rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0)))

write_note("Inferring parameter shapes...")
rng, rng_init = jax.random.split(rng)
# `jax.eval_shape` evaluates `init` abstractly, so only the shapes/dtypes of
# the parameters are produced here, not the parameter values.
params_shape = jax.eval_shape(init, rng_init)

write_note("Inferring optimizer state shapes...")
tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict(
    total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))

# Sharding

In [40]:
# Attach the axis names to the device array to form the named mesh.
mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=mesh_axes)
mesh

Mesh(device_ids=array([[0, 1],
       [2, 3]]), axis_names=('data', 'fsdp'))

In [60]:
# An empty PartitionSpec names no mesh axes, i.e. the value is fully
# replicated across the mesh.
repl_sharding = NamedSharding(mesh, P())
repl_sharding

NamedSharding(mesh=Mesh('data': 2, 'fsdp': 2), spec=PartitionSpec())