In [1]:
%cd ~/my_repo/big_vision
import jax
import importlib
import numpy as np
from absl import logging
import jax.numpy as jnp
from jax.experimental import mesh_utils

import big_vision.utils as u
import big_vision.optax as bv_optax
import big_vision.sharding as bv_sharding
import big_vision.input_pipeline as input_pipeline
from big_vision.configs.proj.image_text.siglip_replication import get_config

config = get_config()
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): importlib.import_module(f"big_vision.pp.{m}")
def bytes_in_use_devices(): return [device.memory_stats()['bytes_in_use'] for device in jax.devices()]
def info(s, *a): logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
def write_note(note): 
	if jax.process_index() == 0: info("%s", note)

  bkms = self.shell.db.get('bookmarks', {})


/home/austinwang/my_repo/big_vision




In [2]:
config.model.image['scan'] = True
config.model.text['scan'] = True
config.model.image['dtype_mm'] = "float32"
model = model_mod.Model(**config.get("model", {}))
train_ds, ntrain_img = input_pipeline.training(config.input)
batch_size = config.input.batch_size
total_steps = u.steps("total", config, ntrain_img, batch_size)

  from .autonotebook import tqdm as notebook_tqdm


Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


In [3]:
def init(rng):
	batch = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),train_ds.element_spec)
	params = model.init(rng, batch["image"], batch["labels"])["params"]
	# Set bias in the head to a low value, such that loss is small initially.
	if "init_head_bias" in config: params["head"]["bias"] = jnp.full_like(params["head"]["bias"],config["init_head_bias"])
	return params

write_note("Inferring parameter shapes...")
rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0)))
rng, rng_init = jax.random.split(rng)
params_shape = jax.eval_shape(init, rng_init)

write_note("Inferring optimizer state shapes...")
tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict(total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))
opt_shape = jax.eval_shape(tx.init, params_shape)
sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns]

  batch = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),train_ds.element_spec)


In [4]:
config.mesh = [("data",-1)]
# config.mesh = [("data", 2),('fsdp', 2)]
# config.sharding_strategy = [('.*', 'replicate')]
config.sharding_strategy = [('.*', 'fsdp(axis="data", min_size_to_shard_mb=1)')]

write_note("Setting up mesh...")
config_mesh = config.get("mesh", [("data", jax.device_count())])
sharding_rules = config.get("sharding_rules", [("act_batch", "data")])
mesh_axes, mesh_size = tuple(zip(*config_mesh))
mesh_size = np.array(jax.devices()).reshape(mesh_size).shape
device_mesh = mesh_utils.create_device_mesh(mesh_size)
devices_flat = device_mesh.flatten()

write_note("Creating device mesh...")
mesh = jax.sharding.Mesh(device_mesh, mesh_axes)
print(f"mesh: {mesh}")
repl_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
# print(f"repl_sharding: {repl_sharding}")

write_note("Inferring shardings...")
train_state_shape = {"params": params_shape, "opt": opt_shape}
strategy = config.get("sharding_strategy", [(".*", "replicate")])
print(f"sharding_strategy: {strategy}")
train_state_sharding = bv_sharding.infer_sharding(train_state_shape, strategy=strategy, mesh=mesh)
print(f"train_state_sharding: {train_state_sharding}")

mesh: Mesh('data': 4)
sharding_strategy: [('.*', 'fsdp(axis="data", min_size_to_shard_mb=1)')]
train_state_sharding: {'opt': (MaskedState(inner_state=EmptyState()), MaskedState(inner_state=ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec()), mu={'b': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,)), 'img': {'MAPHead_0': {'LayerNorm_0': {'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,)), 'scale': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,))}, 'MlpBlock_0': {'Dense_0': {'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,)), 'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, 'data'))}, 'Dense_1': {'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,)), 'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None))}}, 'MultiHeadDotProductAttention_0': {'key': {'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None)), 'ker

In [5]:
jax.debug.visualize_array_sharding(rng_init)

In [6]:
write_note("Transferring train_state to devices...")

print(f"bytes_in_use_devices() before rng_init reshard: {bytes_in_use_devices()}")
# RNG is always replicated
rng_init = u.reshard(rng_init, repl_sharding)
print(f"bytes_in_use_devices() after rng_init reshard: {bytes_in_use_devices()}")
jax.debug.visualize_array_sharding(rng_init)

bytes_in_use_devices() before rng_init reshard: [12800, 12800, 12800, 12800]
bytes_in_use_devices() after rng_init reshard: [13312, 13312, 13312, 13312]


In [7]:
params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init)
print(f"bytes_in_use_devices() after init reshard: {bytes_in_use_devices()}")

opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params)
print(f"bytes_in_use_devices() after tx.init reshard: {bytes_in_use_devices()}")

rng, rng_loop = jax.random.split(rng, 2)
rng_loop = u.reshard(rng_loop, repl_sharding)
print(f"bytes_in_use_devices() after rng_loop reshard: {bytes_in_use_devices()}")
del rng  # not used anymore, so delete it.

train_state = {"params": params, "opt": opt}
del params, opt  # Delete to avoid memory leak or accidental reuse.

bytes_in_use_devices() after init reshard: [236578816, 236578816, 236578816, 236578816]
bytes_in_use_devices() after tx.init reshard: [688075264, 688075264, 688075264, 688075264]
bytes_in_use_devices() after rng_loop reshard: [688075776, 688075776, 688075776, 688075776]


In [8]:
# get tree of keys in train_state['params'] with values becoming shapes
print(f"img keys: {train_state['params']['img']['Transformer'].keys()}")
print(f"txt keys: {train_state['params']['txt']['Encoder_0'].keys()}")
img_shapes = jax.tree_map(lambda x: x.shape, train_state['params']['img'])
txt_shapes = jax.tree_map(lambda x: x.shape, train_state['params']['txt'])
# print(f"img_shapes: {img_shapes}")
# print(f"txt_shapes: {txt_shapes}")
# calculate total number of elements in train_state['params']
total_img_elements = sum(np.prod(shape) for shape in jax.tree_leaves(img_shapes))
total_txt_elements = sum(np.prod(shape) for shape in jax.tree_leaves(txt_shapes))
print(f"total_img_elements: {total_img_elements}")
print(f"total_txt_elements: {total_txt_elements}")

img keys: dict_keys(['encoder_norm', 'encoderblock_0', 'encoderblock_1', 'encoderblock_10', 'encoderblock_11', 'encoderblock_2', 'encoderblock_3', 'encoderblock_4', 'encoderblock_5', 'encoderblock_6', 'encoderblock_7', 'encoderblock_8', 'encoderblock_9'])
txt keys: dict_keys(['encoder_norm', 'encoderblock_0', 'encoderblock_1', 'encoderblock_10', 'encoderblock_11', 'encoderblock_2', 'encoderblock_3', 'encoderblock_4', 'encoderblock_5', 'encoderblock_6', 'encoderblock_7', 'encoderblock_8', 'encoderblock_9'])
total_img_elements: 249838
total_txt_elements: 265009


  img_shapes = jax.tree_map(lambda x: x.shape, train_state['params']['img'])
  txt_shapes = jax.tree_map(lambda x: x.shape, train_state['params']['txt'])
  total_img_elements = sum(np.prod(shape) for shape in jax.tree_leaves(img_shapes))
  total_txt_elements = sum(np.prod(shape) for shape in jax.tree_leaves(txt_shapes))


In [9]:
n_prefetch = config.get("prefetch_to_device", 1)
train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch)

In [10]:
# for batch in train_iter:
#     # print which process has which batch
#     logging.info(f"process {jax.process_index()} has batch {batch['labels']}")

# Result Section

### default: replicate sharding strategy, without gradient checkpointing, float32

In [11]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()), [('.*', 'replicate')]:")
print(f"param memory: {840481792-13312}")
print(f"opt memory: {2477426688-840481792}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()), [('.*', 'replicate')]:
param memory: 840468480
opt memory: 1636944896


### bfloat16

In [None]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()),[('.*', 'replicate')]: ")
print(f"model scan = False")
print(f"img dtype_mm = bfloat16")
print(f"param memory: {840226304-13312}")
print(f"opt memory: {2476526592-840226304}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()),[('.*', 'replicate')]: 
model scan = False
dtype_mm = bfloat16
param memory: 840212992
opt memory: 1636300288


### with FSDP

In [12]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()), [('.*', 'fsdp(axis='data', min_size_to_shard_mb=4)')]: ")
print(f"param memory: {416379904-13312}")
print(f"opt memory: {1193653760-416379904}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()), [('.*', 'fsdp(axis='data', min_size_to_shard_mb=4)')]: 
param memory: 416366592
opt memory: 777273856


In [13]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=2)')]: ")
print(f"param memory: {273303040-13312}")
print(f"opt memory: {727310336-273303040}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=2)')]: 
param memory: 273289728
opt memory: 454007296


In [14]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()), [('.*', 'fsdp(axis='data', min_size_to_shard_mb=1)')]: ")
print(f"param memory: {273303040-13312}")
print(f"opt memory: {727310336-273303040}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()), [('.*', 'fsdp(axis='data', min_size_to_shard_mb=1)')]: 
param memory: 273289728
opt memory: 454007296


### with scan = True

In [15]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()),[('.*', 'replicate')]: ")
print(f"model scan = True")
print(f"param memory: {236578816-13312}")
print(f"opt memory: {688075264-236578816}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()),[('.*', 'replicate')]: 
model scan = True
param memory: 236565504
opt memory: 451496448


### with FSDP & scan=True

In [8]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=4)')]: ")
print(f"model scan = True")
print(f"param memory: {245345792-13312}")
print(f"opt memory: {716230144-245345792}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=4)')]: 
model scan = True
param memory: 245332480
opt memory: 470884352


In [8]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=2)')]: ")
print(f"model scan = True")
print(f"param memory: {236578816-13312}")
print(f"opt memory: {688075264-236578816}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=2)')]: 
model scan = True
param memory: 236565504
opt memory: 451496448


In [8]:
print(f"Mesh('data': 4), \nNamedSharding(mesh=Mesh('data': 4), \nspec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=1)')]: ")
print(f"model scan = True")
print(f"param memory: {236578816-13312}")
print(f"opt memory: {688075264-236578816}")

Mesh('data': 4), 
NamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec()),[('.*', 'fsdp(axis='data', min_size_to_shard_mb=1)')]: 
model scan = True
param memory: 236565504
opt memory: 451496448
