In [1]:
from typing import Any

import numpy as np
import torch
from torchvision import datasets, transforms
from torchvision.datasets import CIFAR10

import jax 
import jax.numpy as jnp
import optax 
from flax import linen as nn
from flax.training import train_state


In [2]:
SEED = 42

num_workers = 4
batch_size = 128

In [3]:
# gpu確認
print(jax.devices())

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]


### 以下flaxによるCIFAR10実装

In [4]:
# random key
key = jax.random.PRNGKey(SEED)

In [5]:
# モデル
# flaxではモデルはモデルの定義のみで実際のパラメータは別で持っている
# -> @nn.compactのせい
# 多分普通にpytorchっぽくも書けるっぽい？
class SimpleCNN(nn.Module):
    @nn.compact
    def __call__(self,x):
        x = nn.Conv(features=32, kernel_size=(3,3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2,2), strides=(2,2))
        x = nn.Conv(features=64, kernel_size=(3,3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2,2), strides=(2,2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

# モデルのチェックなど
model = SimpleCNN()
x = jnp.ones((3, 32, 32, 3))
params = model.init(key, jnp.ones([1,32,32,3]))['params']
out = model.apply({'params': params}, x)
print(out.shape)

(3, 10)


In [6]:
# データセット
# torchvisionからCIFAR10を用いる
DATASET_PATH = "./cache"
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0,1,2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0,1,2))
print("Data mean", DATA_MEANS)
print("Data std", DATA_STD)

Files already downloaded and verified
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]


In [7]:
# image util functions
def image_to_numpy(img):
    """
    PIL形式の画像をnumpy形式に変換して、正規化する関数
    """
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

In [8]:
test_transform = image_to_numpy
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
    image_to_numpy
])
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))

test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, drop_last=True,
    collate_fn=numpy_collate, num_workers=num_workers, persistent_workers=True)
val_loader   = torch.utils.data.DataLoader(
    val_set, batch_size=batch_size, shuffle=False, drop_last=False,
    collate_fn=numpy_collate, num_workers=num_workers, persistent_workers=True)
test_loader  = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False, drop_last=False,
    collate_fn=numpy_collate, num_workers=num_workers, persistent_workers=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [9]:
class TrainState(train_state.TrainState):
    # A simple extension of TrainState to also include batch statistics
    batch_stats: Any

def create_train_state(key, learning_rate, momentum):
    model = SimpleCNN()
    input_shape = (1, 32, 32, 3)
    params = model.init(key, jnp.ones(input_shape))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [10]:
# モデルの生成
key, init_key = jax.random.split(key, 2)
model_state = create_train_state(init_key, 0.1, 0.9)
#params = model.init(init_key, jnp.ones(input_shape))['params']

In [11]:
def calculate_loss_acc(state,params,batch):
    """
    Loss: CrossEntropyLoss
    """
    data_inputs, labels = batch
    #labels = jnp.array(labels)
    logits = state.apply_fn({"params": params}, data_inputs)

    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    acc = (logits.argmax(axis=-1) == labels).mean()

    return loss, acc

batch = next(iter(train_loader))
loss,acc = calculate_loss_acc(model_state, model_state.params, batch)
print("loss", loss)
print("acc", acc)

loss 2.3220553
acc 0.0546875


In [12]:
# train function
@jax.jit  # Jit the function for efficiency
def train_step(state, batch):
    # Gradient function
    grad_fn = jax.value_and_grad(calculate_loss_acc,  # Function to calculate the loss
                                 argnums=1,  # Parameters are second argument of the function
                                 has_aux=True  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, batch)
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc

# eval function
@jax.jit  # Jit the function for efficiency
def eval_step(state, batch):
    # Determine the accuracy
    _, acc = calculate_loss_acc(state, state.params, batch)
    return acc

# train test
model_state, loss, acc = train_step(model_state, batch)
print("loss", loss)
print("acc", acc)

# eval test
acc = eval_step(model_state, batch)
print("acc", acc)

loss 2.3220487
acc 0.0546875
acc 0.234375


In [13]:
def train_loop(state,train_loader,val_loader,num_epochs=10):
    for epoch in range(num_epochs):
        # train 1 epoch
        train_losses = []
        train_accs = []
        for batch in train_loader:
            state, loss, train_acc = train_step(state, batch)
            train_accs.append(train_acc)
            train_losses.append(loss)
        train_acc = np.mean(train_accs)
        loss = np.mean(train_losses)
        
        # eval
        val_accs = []
        for batch in val_loader:
            val_acc = eval_step(state, batch)
            val_accs.append(val_acc)
        val_acc = np.mean(val_accs)
                
        # log
        print(f"Epoch: {epoch}, Train_Loss: {loss}, Train Acc: {train_acc}, Val Acc: {val_acc}")

train_loop(model_state, train_loader, val_loader, num_epochs=10)

Epoch: 0, Train_Loss: 1.5358785390853882, Train Acc: 0.4463363587856293, Val Acc: 0.5376952886581421
Epoch: 1, Train_Loss: 1.2157222032546997, Train Acc: 0.5678864121437073, Val Acc: 0.613085925579071
Epoch: 2, Train_Loss: 1.0992904901504517, Train Acc: 0.6159188151359558, Val Acc: 0.6373046636581421
Epoch: 3, Train_Loss: 1.0180217027664185, Train Acc: 0.64507657289505, Val Acc: 0.6519531011581421
Epoch: 4, Train_Loss: 0.9694800972938538, Train Acc: 0.6638621687889099, Val Acc: 0.665234386920929
Epoch: 5, Train_Loss: 0.8968766927719116, Train Acc: 0.6904380321502686, Val Acc: 0.669726550579071
Epoch: 6, Train_Loss: 0.8641999363899231, Train Acc: 0.7015447020530701, Val Acc: 0.662890613079071
Epoch: 7, Train_Loss: 0.8321110010147095, Train Acc: 0.7086894512176514, Val Acc: 0.67578125
Epoch: 8, Train_Loss: 0.8040343523025513, Train Acc: 0.7223334908485413, Val Acc: 0.6832031011581421
Epoch: 9, Train_Loss: 0.7632994055747986, Train Acc: 0.7381588220596313, Val Acc: 0.6962890625
Epoch: 10,

### 細々としたモジュールども

In [28]:
# Dropout
# https://flax.readthedocs.io/en/latest/guides/training_techniques/dropout.html

# model
# 別にcompactでもかける
class DropoutMLP(nn.Module):
    input_dim: int = 3
    hidden_dim: int = 30
    output_dim: int = 5
    dropout_prob: float = 0.5
    
    def setup(self):
        self.dense1 = nn.Dense(self.hidden_dim)
        self.dense2 = nn.Dense(self.output_dim)
        self.dropout = nn.Dropout(rate=self.dropout_prob)

    def __call__(self, x, training=False):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dropout(x, deterministic=not training)
        x = self.dense2(x)
        return x

In [29]:
# init model
key, params_key = jax.random.split(key)
model = DropoutMLP(dropout_prob=0.3)
input_shape = jnp.empty((1, 3))
variables = model.init(params_key, x, training=False)
params = variables['params']

# forward with dropout
key = jax.random.PRNGKey(0)
key, dropout_key = jax.random.split(key)
x = jnp.ones((3, 3))
y = model.apply({"params": params}, x, training=True, rngs={"dropout": dropout_key})
print(y)

# forward without dropout
key = jax.random.PRNGKey(0)
key, dropout_key = jax.random.split(key)
x = jnp.ones((3, 3))
#y = model.apply({"params": params}, x, training=False, rngs={"dropout": dropout_key})
y = model.apply({"params": params}, x) # trainingをデフォルトでFalseにしておくと、これで良くなる
print(y)

[[ 0.22760291  0.10780797 -0.01367897  0.9634542  -0.0916012 ]
 [ 0.5429692   0.5138054  -0.06723279  0.87348104  0.7162861 ]
 [ 0.6461147   0.00951485  0.43068302  1.9464691  -0.17464459]]
[[0.896132   0.28386933 0.0984112  1.3529255  0.20996922]
 [0.896132   0.28386933 0.0984112  1.3529255  0.20996922]
 [0.896132   0.28386933 0.0984112  1.3529255  0.20996922]]


In [48]:
# flax training 
class TrainState(train_state.TrainState):
    dropout_key: jax.Array

dropout_key = jax.random.PRNGKey(0)
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    dropout_key=dropout_key,
    tx=optax.adam(1e-3)
)

# train step
def loss_fn(
    state: TrainState,
    params: dict,
    batch: jnp.ndarray,
    training: bool,
    dropout_key: jnp.ndarray,
) -> jnp.ndarray:
    """Computes the loss for a single batch."""
    data, labels = batch
    logits = state.apply_fn({"params": params}, data, training=True, rngs={"dropout": dropout_key})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss

#@jax.jit
def train_step(state, batch):
    dropout_key = jax.random.fold_in(state.dropout_key, state.step)
    grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)
    #(loss, acc), grads = jax.value_and_grad(loss_fn, has_aux=True)(state, state.params, batch)
    loss, grads = grad_fn(state, state.params, batch, True, dropout_key)
    state = state.apply_gradients(grads=grads)
    return state, loss

# test one step
data = jnp.ones((3, 3))
labels = jnp.array([0, 1, 2])
batch = (data, labels)
state, loss = train_step(state, batch)


TypeError: expected function with aux output to return a two-element tuple, but got type <class 'jax._src.interpreters.ad.JVPTracer'> with value Traced<ConcreteArray(1.5074005126953125, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array(1.5074005, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f007e950130>, in_tracers=(Traced<ShapedArray(float32[3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7eff30620b80; to 'JaxprTracer' at 0x7eff3075c0e0>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[3]. let
    b:f32[] = reduce_sum[axes=(0,)] a
    c:f32[] = div b 3.0
  in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'resource_env': None, 'donated_invars': (False,), 'name': '_mean', 'keep_unused': False, 'inline': True}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f007e92f8f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))