In [2]:
import jax
import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from jax import pmap
from functools import partial

In [3]:
class ResNetBlock(nn.Module):
    """Residual block: two strided conv -> batch-norm -> ReLU stages plus a skip path.

    Each main-path convolution uses a 7x7 kernel with stride 2, so the block
    reduces the spatial size by a factor of four overall. When that (or a
    change in channel count) makes the input shape differ from the main-path
    output, the skip connection is projected with a single 14x14 / stride-4
    convolution followed by batch-norm so the two can be added.
    """
    # Number of output channels used by every convolution in this block.
    features: int = 16

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Apply the block to ``x``.

        Args:
            x: input feature map with channels on the last axis.
            train: when True (the default, matching the previous behaviour)
                batch-norm normalises with the statistics of the current
                batch and updates the ``batch_stats`` collection, so callers
                must pass ``mutable=['batch_stats']`` to ``apply``. When
                False the stored running averages are used instead.

        Returns:
            The activated sum of the main path and the (possibly projected)
            skip connection.
        """
        use_running_average = not train
        residual = x

        x = nn.Conv(self.features, (7, 7), (2, 2))(x)
        x = nn.BatchNorm(use_running_average=use_running_average)(x)
        x = nn.relu(x)
        x = nn.Conv(self.features, (7, 7), (2, 2))(x)
        x = nn.BatchNorm(use_running_average=use_running_average)(x)
        x = nn.relu(x)
        if residual.shape != x.shape:
            # Project the skip path so that it matches the main path: stride 4
            # mirrors the two stride-2 convolutions above.
            residual = nn.Conv(self.features, (14, 14), (4, 4))(residual)
            residual = nn.BatchNorm(use_running_average=use_running_average)(residual)
        x += residual
        return nn.relu(x)

class ResNet(nn.Module):
    """Small ResNet-style network ending in a 10-unit dense layer.

    Layout: a strided 7x7 stem convolution, a max-pool, four stages of three
    ``ResNetBlock`` modules each (64, 128, 192 and 256 channels), an
    average-pool and a final ``Dense(10)``.
    """
    # The number of stages is fixed at four in ``__call__``; a configurable
    # ``stages`` field was sketched here but is not implemented.

    @nn.compact
    def __call__(self, x):
        """Run the network on ``x`` (channels on the last axis) and return the dense output."""
        x = nn.Conv(features=64, kernel_size=(7, 7), strides=(2, 2), padding=14)(x)
        x = nn.max_pool(x, window_shape=(7, 7), strides=(2, 2), padding='SAME')

        # Stage k (1-based) uses 64 * k channels and stacks three residual blocks.
        for stage in range(1, 5):
            width = 64 * stage
            for _ in range(3):
                x = ResNetBlock(width)(x)

        x = nn.avg_pool(x, window_shape=(7, 7), strides=(2, 2), padding='SAME')
        # NOTE(review): the feature map is not flattened or globally pooled
        # before the dense layer, so Dense acts on the last axis only and the
        # result keeps its spatial axes -- confirm this is intended.
        return nn.Dense(10)(x)
        

In [4]:
@pmap
def apply_model(x):
    """Run a freshly initialised ``ResNet`` on this device's shard of ``x``.

    ``pmap`` maps over the leading axis of the argument, so inside this
    function ``x`` no longer carries the device axis.

    Args:
        x: per-device input; either a batch ``(batch, H, W, C)`` or a single
            image ``(H, W, C)``, in which case a batch axis of size one is
            added.

    Returns:
        The ``(output, updated_variables)`` pair produced by ``apply`` with
        ``mutable=['batch_stats']``.
    """
    # Previously the argument was ignored and a constant all-ones batch was
    # pushed through the network; use the real input instead.
    batch = x if x.ndim == 4 else x[jnp.newaxis]

    model = ResNet()
    # Parameter shapes do not depend on the batch size, so a single dummy
    # example with the input's trailing shape is enough for initialisation.
    dummy = jnp.ones((1,) + batch.shape[1:], batch.dtype)
    variables = model.init(jax.random.PRNGKey(0), dummy)

    # Batch-norm runs in training mode and updates its running statistics,
    # hence the mutable collection.
    return model.apply(variables, batch, mutable=['batch_stats'])


# apply_model is wrapped in pmap, which maps over the leading axis of its
# argument; a leading axis of size 1 therefore uses a single device.
apply_model(jnp.ones((1, 28, 28, 1)))

2024-02-29 17:31:32.244770: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-02-29 17:31:32.244858: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 7471104 bytes free, 85051572224 bytes total.


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [1]:
import jax
# Number of devices JAX can see in this kernel (displayed as the cell output).
jax.device_count()

5

In [1]:
import os
# Alternative for CPU-only experiments: ask XLA for 8 virtual host devices.
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
# Both variables are set before `import jax` below so that they are in place
# when the GPU backend is initialised.
# Number GPUs by PCI bus ID so the indices listed below refer to that ordering.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,4'      # GPU 3 only drives the display, so it is left out.

import jax
from functools import partial
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import numpy as np
import jax.numpy as jnp

def perceptron(x, w, b):
    """Single dense layer: ``x @ w + b``.

    Shapes, writing A, B, C for the axis sizes:
        x: (A, B)
        w: (B, C)
        b: (C,)
        returns: (A, C)
    """
    # jax.lax.broadcast(array, dims) prepends `dims` as new leading axes, e.g.
    # jax.lax.broadcast(np.ones((2, 4)), [8, 16]) has shape (8, 16, 2, 4).
    # Here it expands the bias from (C,) to (A, C).
    n_rows = x.shape[0]
    bias = jax.lax.broadcast(b, [n_rows])

    # The matrix product has shape (A, C), matching the broadcast bias.
    return jnp.dot(x, w) + bias


# Example operands: x is (48, 24), w is (24, 48) and b is (48,).
x = np.ones((48, 24))
w = np.ones((24, 48))
b = np.ones((48,))

# Build a 1 x N device mesh. The columns of `w` are split along mesh axis "y",
# so N must divide w.shape[1]. Rather than hard-coding four devices (which
# fails whenever a different number is visible), use the largest prefix of the
# available devices whose size divides that dimension. With exactly four
# devices this is the same 1 x 4 mesh as before.
devices = jax.devices()
n_shards = max(n for n in range(1, len(devices) + 1) if w.shape[1] % n == 0)
mesh = Mesh(np.array(devices[:n_shards]).reshape(1, n_shards), ("x", "y"))

# x is partitioned over mesh axis "x" along its rows, w over mesh axis "y"
# along its columns, and b is replicated; the result is requested replicated.
y = jax.jit(
    perceptron,
    in_shardings=(
        NamedSharding(mesh, PartitionSpec("x", None)),
        NamedSharding(mesh, PartitionSpec(None, "y")),
        NamedSharding(mesh, PartitionSpec())
    ),
    out_shardings=NamedSharding(mesh, PartitionSpec())
)(x, w, b)

def perceptron_with_sharding_inspection(x, w, b):
    """Compute ``x @ w + b`` while printing the sharding of every intermediate.

    The arithmetic is the same as a plain dense layer; in addition the
    sharding of each input, of the matrix product, of the broadcast bias and
    of the final result is reported through ``jax.debug.inspect_array_sharding``.
    """
    def report(label, array):
        # Print `label` followed by the sharding of `array`.
        jax.debug.inspect_array_sharding(array, callback=partial(print, label))

    report("input x sharding:", x)
    report("input w sharding:", w)
    report("input b sharding:", b)

    product = jnp.dot(x, w)
    report("output of dot sharding:", product)

    # Prepend an axis of size x.shape[0] to the bias.
    bias = jax.lax.broadcast(b, [x.shape[0]])
    report("broadcasted b sharding:", bias)

    result = product + bias
    report("output sharding:", result)
    return result

# Compile and run the inspecting perceptron with x partitioned over mesh axis
# "x" (rows), w over mesh axis "y" (columns), b replicated and a replicated
# output. The trailing semicolon suppresses the notebook's display of the result.
jax.jit(
    perceptron_with_sharding_inspection,
    in_shardings=tuple(
        NamedSharding(mesh, spec)
        for spec in (
            PartitionSpec("x", None),
            PartitionSpec(None, "y"),
            PartitionSpec(),
        )
    ),
    out_shardings=NamedSharding(mesh, PartitionSpec()),
)(x, w, b);

input x sharding: GSPMDSharding({replicated})
input w sharding: GSPMDSharding({devices=[1,4]0,1,2,3})
input b sharding: GSPMDSharding({replicated})
output of dot sharding: GSPMDSharding({devices=[1,4]0,1,2,3})
broadcasted b sharding: GSPMDSharding({replicated})
output sharding: GSPMDSharding({replicated})


In [2]:
# List the devices JAX can see in this kernel (displayed as the cell output).
jax.devices()

[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3), gpu(id=4)]