In [55]:
import jax
import collections
import jax.numpy as jnp

from typing import Iterator
from dataclasses import dataclass

class Dataset:
    """Abstract map-style dataset with a PyTorch-like interface.

    Subclasses must implement ``__len__`` and ``__getitem__``.
    """

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError


class ArrayDataset(Dataset):
    """Dataset over one or more arrays that share the same leading-axis length.

    Indexing returns a tuple holding ``arr[index]`` for every wrapped array,
    in the order the arrays were passed to the constructor.

    Raises:
        ValueError: if no array is given.
        AssertionError: if the arrays disagree on ``shape[0]``.
    """

    def __init__(
        self,
        *arrays: jnp.ndarray
    ):
        # With no arrays there is nothing to measure; previously construction
        # succeeded and ``len()`` failed later with an IndexError.
        if not arrays:
            raise ValueError("At least one array is required.")
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; exception type and message are unchanged.
        if any(arr.shape[0] != arrays[0].shape[0] for arr in arrays):
            raise AssertionError("All arrays must have the same dimension.")
        self.arrays = arrays

    def __len__(self):
        return self.arrays[0].shape[0]

    def __getitem__(self, index):
        return tuple(arr[index] for arr in self.arrays)


# The task-named datasets below used to be verbatim copies of ``ArrayDataset``
# (and ``ConditionalTextDataset`` was defined twice, the second silently
# shadowing the first). They remain distinct classes so existing names keep
# working, but now share one implementation.

class Text2TextDataset(ArrayDataset):
    """``ArrayDataset`` under a task-specific name; adds no behavior."""


class CausalLMDataset(ArrayDataset):
    """``ArrayDataset`` under a task-specific name; adds no behavior."""


class ConditionalTextDataset(ArrayDataset):
    """``ArrayDataset`` under a task-specific name; adds no behavior."""


class ImageDataset(ArrayDataset):
    """``ArrayDataset`` under a task-specific name; adds no behavior."""


class ImageToImageDataset(ArrayDataset):
    """``ArrayDataset`` under a task-specific name; adds no behavior."""


class ImageToTextDataset(ArrayDataset):
    """``ArrayDataset`` under a task-specific name; adds no behavior."""


class TextToImageDataset(ArrayDataset):
    """``ArrayDataset`` under a task-specific name; adds no behavior."""


class DataLoader:
    """Minimal batching data loader in plain JAX.

    The loader is its own iterator: ``__iter__`` returns ``self`` and the
    read position is only reset (and the indices reshuffled, if enabled) once
    the data is exhausted. Breaking out of a loop early therefore resumes from
    the same position on the next iteration.

    Args:
        dataset: object supporting ``len()`` and indexing with an index array.
        batch_size: number of samples per batch; must be at least 1.
        shuffle: permute the indices before every pass over the data.
        drop_last: drop the final batch when it is smaller than ``batch_size``.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    def __init__(
        self,
        dataset: "Dataset",
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        **kwargs
    ):
        # A non-positive batch size would yield empty batches forever and
        # divide by zero in ``__len__``.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.keys = PRNGSequence(seed=Config.default().global_seed)
        self.data_len = len(dataset)  # Length of the dataset
        self.indices = jnp.arange(self.data_len)  # available indices in the dataset
        self.pose = 0  # record the current position in the dataset
        self._shuffle()

    def _shuffle(self):
        """Permute ``self.indices`` with a fresh key when shuffling is enabled."""
        if self.shuffle:
            self.indices = jax.random.permutation(next(self.keys), self.indices)

    def _stop_iteration(self):
        """Rewind to the start, reshuffle for the next pass and end iteration."""
        self.pose = 0
        self._shuffle()
        raise StopIteration

    def __len__(self):
        """Number of batches per pass (floor if ``drop_last`` else ceil)."""
        if self.drop_last:
            batches = len(self.dataset) // self.batch_size  # get the floor of division
        else:
            batches = -(len(self.dataset) // -self.batch_size)  # get the ceil of division
        return batches

    def __next__(self):
        remaining = self.data_len - self.pose
        full_batch = remaining >= self.batch_size
        partial_batch = remaining > 0 and not self.drop_last
        if not (full_batch or partial_batch):
            self._stop_iteration()
        # Slicing clamps at the end, so the same slice also serves the final
        # short batch.
        batch_indices = self.indices[self.pose: self.pose + self.batch_size]
        batch_data = self.dataset[batch_indices]
        self.pose += self.batch_size
        return batch_data

    def __iter__(self):
        return self
    

@dataclass
class Config:
    """Library-wide settings shared by the helpers in this notebook."""
    # How many PRNG sub-keys ``PRNGSequence`` splits off when it runs dry.
    rng_reserve_size: int
    # Seed used to build the ``PRNGSequence`` inside ``DataLoader``.
    global_seed: int

    @classmethod
    def default(cls):
        """Return the stock configuration (one reserved key, seed 42)."""
        return cls(global_seed=42, rng_reserve_size=1)

class PRNGSequence(Iterator[jax.random.PRNGKey]):
    """An iterator of JAX PRNG keys (minimal version of `haiku.PRNGSequence`)."""

    def __init__(self, seed: int):
        self._subkeys = collections.deque()
        self._key = jax.random.PRNGKey(seed)

    def reserve(self, num):
        """Split off ``num`` more keys and queue them; no-op when ``num <= 0``."""
        if num <= 0:
            return
        # The first split result becomes the new carry key, the rest are queued.
        carry, *fresh = jax.random.split(self._key, num + 1)
        self._key = carry
        self._subkeys.extend(fresh)

    def __next__(self):
        # Refill lazily using the configured reserve size.
        if len(self._subkeys) == 0:
            self.reserve(Config.default().rng_reserve_size)
        return self._subkeys.popleft()
    

# Demo: 1001 samples of two all-ones arrays, batched in tens with shuffling.
dataset = ArrayDataset(
    jnp.ones((1001, 256, 256)),
    jnp.ones((1001, 256, 256)),
)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=False)

In [38]:
import jax
import pickle
import jax.numpy as jnp
from flax.training import train_state


class DataParallelTrainer:
    """Data-parallel training loop built on ``jax.pmap`` and a Flax ``TrainState``.

    ``train_step(state, batch) -> (state, loss)`` is wrapped in ``jax.pmap``
    with ``axis_name='devices'``; ``batch`` is a dict with keys ``'inputs'``
    and ``'targets'`` whose leading axis is the device axis.

    NOTE(review): this class does not all-reduce gradients itself. A step
    function that should keep replicas in sync presumably needs something
    like ``jax.lax.pmean(grads, 'devices')`` - confirm against the step used.

    Args:
        model: Flax module providing ``init`` and ``apply``.
        input_shape: shape of the dummy input used to initialise parameters.
        train_step: per-device step function (see above).
        optax_optimizer: factory called as ``optax_optimizer(learning_rate)``.
        learning_rate: passed to ``optax_optimizer``.
        weights_filename: path ``save_params`` pickles the parameters to.
    """

    def __init__(self,
                 model,
                 input_shape,
                 train_step,
                 optax_optimizer,
                 learning_rate,
                 weights_filename):

        self.model = model
        self.num_parameters = None
        self.optimizer = optax_optimizer
        self.best_val_loss = float("inf")
        self.weights_filename = weights_filename
        self.num_devices = jax.local_device_count()
        self.train_step = jax.pmap(train_step, axis_name='devices')
        self.state = self.create_train_state(learning_rate, input_shape)


    def create_train_state(self, learning_rate, input_shape):
        """Initialise parameters and return a ``TrainState`` replicated on all local devices."""
        rng = jax.random.PRNGKey(0)
        params = self.model.init(rng, jnp.ones(input_shape))['params']
        self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params))
        state = train_state.TrainState.create(apply_fn=self.model.apply,
                                              params=params,
                                              tx=self.optimizer(learning_rate))
        return jax.device_put_replicated(state, jax.local_devices())


    def _shard(self, array):
        """Reshape ``array`` to ``(num_devices, per_device, -1)``.

        Trailing samples that do not divide evenly across devices are dropped
        (the original reshape raised on such batches). Returns ``None`` when
        the batch holds fewer samples than there are devices.
        """
        per_device = array.shape[0] // self.num_devices
        if per_device == 0:
            return None
        usable = per_device * self.num_devices
        return array[:usable].reshape((self.num_devices, per_device, -1))


    def _run_epoch(self, loader, update_state):
        """Run ``train_step`` over ``loader``; return ``(summed loss, batch count)``.

        The new state is stored only when ``update_state`` is true.
        """
        total_loss = 0.0
        num_batches = 0
        for inputs, targets in loader:
            inputs = self._shard(inputs)
            targets = self._shard(targets)
            if inputs is None or targets is None:
                continue  # batch too small to give every device a sample
            new_state, loss = self.train_step(self.state, {'inputs': inputs, 'targets': targets})
            if update_state:
                self.state = new_state
            total_loss += jnp.mean(loss)
            num_batches += 1
        return total_loss, num_batches


    def train(self, train_loader, num_epochs, val_loader=None):
        """Train for ``num_epochs`` passes, validating after each epoch if a loader is given."""
        for epoch in range(num_epochs):
            total_loss, num_batches = self._run_epoch(train_loader, update_state=True)
            # Average over batches (was divided by num_epochs); guard empty loaders.
            mean_loss = total_loss / max(num_batches, 1)
            print(f'Epoch {epoch+1}, Train Loss: {mean_loss}')

            # Validate per epoch; previously this ran once after the loop and
            # raised NameError on ``epoch`` when ``num_epochs`` was 0.
            if val_loader is not None:
                self.validate(val_loader, epoch, num_epochs)
        return


    def validate(self, val_loader, epoch, num_epochs):
        """Report the mean validation loss and save parameters on improvement.

        ``num_epochs`` is kept for interface compatibility and is not used.
        The loss comes from ``self.train_step`` (no separate eval step is
        available); the updated state it returns is discarded.
        """
        total_loss, num_batches = self._run_epoch(val_loader, update_state=False)
        if num_batches == 0:
            # Nothing was evaluated: do not treat a zero sum as a new best.
            print(f'Epoch {epoch+1}, Val Loss: skipped (no usable validation batches)')
            return

        mean_loss = total_loss / num_batches
        print(f'Epoch {epoch+1}, Val Loss: {mean_loss}')
        # Save only on improvement (the save used to run unconditionally).
        if mean_loss < self.best_val_loss:
            self.best_val_loss = mean_loss
            print("New best validation score achieved, saving model...")
            self.save_params()
        return


    def save_params(self):
        """Pickle ``self.state.params`` to ``self.weights_filename``."""
        with open(self.weights_filename, 'wb') as f:
            pickle.dump(self.state.params, f)
        return


    def load_params(self, filename):
        """Unpickle and return parameters from ``filename`` (state is not modified).

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        from a trusted source.
        """
        with open(filename, 'rb') as f:
            params = pickle.load(f)
        return params

In [29]:
import jax
import jax.numpy as jnp
from jax.nn import softmax

def mse_train_step(state, batch):
    """Take one gradient step on the mean-squared error.

    Reads ``batch['inputs']`` and ``batch['targets']``; returns the updated
    state and the scalar loss evaluated before the update.
    """
    def mean_squared_error(params):
        preds = state.apply_fn({'params': params}, batch['inputs'])
        residual = preds - batch['targets']
        return jnp.mean(jnp.square(residual))

    loss, grads = jax.value_and_grad(mean_squared_error)(state.params)
    return state.apply_gradients(grads=grads), loss


def binary_cross_entropy_train_step(state, batch):
    """Take one gradient step on binary cross-entropy computed from logits.

    ``batch['targets']`` weights the positive term and ``1 - targets`` the
    negative term; both use ``log_sigmoid`` of the model output.
    """
    def bce(params):
        z = state.apply_fn({'params': params}, batch['inputs'])
        positive_term = batch['targets'] * jax.nn.log_sigmoid(z)
        negative_term = (1 - batch['targets']) * jax.nn.log_sigmoid(-z)
        return -jnp.mean(positive_term + negative_term)

    loss, grads = jax.value_and_grad(bce)(state.params)
    return state.apply_gradients(grads=grads), loss


def categorical_cross_entropy_train_step(state, batch):
    """Take one gradient step on categorical cross-entropy.

    ``batch['targets']`` is multiplied elementwise with the log-softmax of the
    logits and summed over the last axis, so it must match the logits' shape
    (e.g. one-hot or soft labels).
    """
    def cross_entropy(params):
        log_probs = jax.nn.log_softmax(state.apply_fn({'params': params}, batch['inputs']))
        per_example = jnp.sum(log_probs * batch['targets'], axis=-1)
        return -jnp.mean(per_example)

    loss, grads = jax.value_and_grad(cross_entropy)(state.params)
    return state.apply_gradients(grads=grads), loss


def sparse_categorical_cross_entropy_train_step(state, batch):
    """Take one gradient step on cross-entropy with integer class labels.

    ``batch['targets']`` holds integer class indices with the same shape as
    the logits minus the last (class) axis.

    ``jax.nn`` has no ``sparse_softmax_cross_entropy`` function, so the loss is
    computed as the negative log-softmax value at the target index.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        labels = jnp.asarray(batch['targets'])
        # Pick the log-probability of the labelled class for every example.
        picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)
        return -jnp.mean(picked)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def cross_entropy_train_step(state, batch):
    """Take one gradient step on softmax cross-entropy.

    ``batch['targets']`` must match the logits' shape (one-hot or soft
    labels). The loss is ``-sum(targets * log_softmax(logits))`` averaged over
    the batch; the previous version multiplied by ``softmax`` instead of its
    logarithm, which is not cross-entropy.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.mean(jnp.sum(log_probs * batch['targets'], axis=-1))
        return loss
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def l1_regularization_train_step(state, batch, l1_penalty):
    """Take one gradient step on softmax cross-entropy plus an L1 weight penalty.

    Args:
        state: train state exposing ``apply_fn``, ``params``, ``apply_gradients``.
        batch: dict with ``'inputs'`` and ``'targets'`` (targets shaped like the logits).
        l1_penalty: multiplier on the sum of absolute parameter values.

    Returns:
        ``(new_state, loss)`` where ``loss`` includes the penalty term.

    Fixes: the penalty is now summed over every leaf of the parameter pytree
    (``jnp.abs`` cannot be applied to the pytree itself), and the data term
    uses ``log_softmax`` so that it is an actual cross-entropy.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.mean(jnp.sum(log_probs * batch['targets'], axis=-1))
        l1_loss = l1_penalty * sum(
            jnp.sum(jnp.abs(leaf)) for leaf in jax.tree_util.tree_leaves(params)
        )
        return loss + l1_loss
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def l2_regularization_train_step(state, batch, l2_penalty):
    """Take one gradient step on softmax cross-entropy plus an L2 weight penalty.

    Args:
        state: train state exposing ``apply_fn``, ``params``, ``apply_gradients``.
        batch: dict with ``'inputs'`` and ``'targets'`` (targets shaped like the logits).
        l2_penalty: multiplier on the sum of squared parameter values.

    Returns:
        ``(new_state, loss)`` where ``loss`` includes the penalty term.

    Fixes: the penalty is now summed over every leaf of the parameter pytree
    (``jnp.square`` cannot be applied to the pytree itself), and the data term
    uses ``log_softmax`` so that it is an actual cross-entropy.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.mean(jnp.sum(log_probs * batch['targets'], axis=-1))
        l2_loss = l2_penalty * sum(
            jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(params)
        )
        return loss + l2_loss
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def hinge_loss_train_step(state, batch):
    """Take one gradient step on the hinge loss.

    ``batch['targets']`` is mapped from {0, 1} to {-1, +1} before the margin
    ``1 - target * logit`` is clamped at zero and averaged.
    """
    def hinge(params):
        scores = state.apply_fn({'params': params}, batch['inputs'])
        signed = 2 * batch['targets'] - 1  # {0, 1} labels -> {-1, 1}
        margins = 1 - signed * scores
        return jnp.mean(jnp.maximum(0, margins))

    loss, grads = jax.value_and_grad(hinge)(state.params)
    return state.apply_gradients(grads=grads), loss


def triplet_loss_train_step(state, batch, margin):
    """Take one gradient step on the triplet margin loss.

    Embeds ``batch['anchor']``, ``batch['positive']`` and ``batch['negative']``
    with the same parameters and penalises
    ``margin + d(anchor, positive) - d(anchor, negative)`` when positive, where
    ``d`` is the squared Euclidean distance over the last axis.
    """
    def triplet(params):
        def embed(x):
            return state.apply_fn({'params': params}, x)

        def sq_dist(u, v):
            return jnp.sum(jnp.square(u - v), axis=-1)

        anchor_emb = embed(batch['anchor'])
        positive_emb = embed(batch['positive'])
        negative_emb = embed(batch['negative'])
        gap = margin + sq_dist(anchor_emb, positive_emb) - sq_dist(anchor_emb, negative_emb)
        return jnp.mean(jnp.maximum(0, gap))

    loss, grads = jax.value_and_grad(triplet)(state.params)
    return state.apply_gradients(grads=grads), loss


def contrastive_loss_train_step(state, batch, margin):
    """Take one gradient step on the pairwise contrastive loss.

    For each pair ``(x1, x2)`` with label ``y`` (1 = similar, 0 = dissimilar)
    the loss is ``y * d^2 + (1 - y) * max(margin - d, 0)^2`` where ``d`` is the
    Euclidean distance between the two embeddings.

    The similar-pair term uses the squared distance directly and the square
    root is masked where the distance is exactly zero: differentiating
    ``sqrt`` at zero produced NaN gradients for identical embeddings.
    """
    def loss_fn(params):
        x1, x2, y = batch['x1'], batch['x2'], batch['y']
        logits1 = state.apply_fn({'params': params}, x1)
        logits2 = state.apply_fn({'params': params}, x2)

        squared_distance = jnp.sum(jnp.square(logits1 - logits2), axis=-1)
        nonzero = squared_distance > 0
        # Double ``where`` keeps both the value and the gradient finite at 0.
        safe_squared = jnp.where(nonzero, squared_distance, 1.0)
        euclidean_distance = jnp.where(nonzero, jnp.sqrt(safe_squared), 0.0)
        loss = y * squared_distance + (1 - y) * jnp.square(jnp.maximum(margin - euclidean_distance, 0))
        return jnp.mean(loss)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def cosine_similarity_loss_train_step(state, batch):
    """Take one gradient step on ``1 - cosine_similarity`` between paired embeddings.

    ``batch['x1']`` and ``batch['x2']`` are embedded with the same parameters,
    scaled to unit L2 norm along the last axis and compared by dot product.
    """
    def unit(vectors):
        # Divide by the L2 norm along the feature axis.
        return vectors / jnp.linalg.norm(vectors, axis=-1, keepdims=True)

    def cosine_loss(params):
        first = unit(state.apply_fn({'params': params}, batch['x1']))
        second = unit(state.apply_fn({'params': params}, batch['x2']))
        similarity = jnp.sum(first * second, axis=-1)
        return jnp.mean(1 - similarity)

    loss, grads = jax.value_and_grad(cosine_loss)(state.params)
    return state.apply_gradients(grads=grads), loss


def clip_loss_train_step(state, batch, temperature=0.07):
    """Take one gradient step on the symmetric CLIP contrastive loss.

    ``state.apply_fn`` is called with ``batch['images']`` and ``batch['text']``
    and must return ``(image_features, text_features)``. Row ``i`` of each is
    treated as a matching pair.

    ``jax.nn`` has no ``softmax_cross_entropy`` function, so the per-row
    cross-entropy against the diagonal is computed from ``log_softmax``.
    """
    def loss_fn(params):
        # Get image and text features from the CLIP model
        image_features, text_features = state.apply_fn({'params': params}, batch['images'], batch['text'])

        # Normalize the features
        image_features = image_features / jnp.linalg.norm(image_features, axis=-1, keepdims=True)
        text_features = text_features / jnp.linalg.norm(text_features, axis=-1, keepdims=True)

        # Calculate the similarity
        similarity = jnp.dot(image_features, text_features.T) / temperature

        # The matching pair for row i sits in column i.
        labels = jnp.arange(similarity.shape[0])

        def diagonal_cross_entropy(logits):
            log_probs = jax.nn.log_softmax(logits, axis=-1)
            return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).squeeze(-1)

        # CLIP loss: average of image->text and text->image cross-entropy
        image_loss = diagonal_cross_entropy(similarity)
        text_loss = diagonal_cross_entropy(similarity.T)
        loss = (image_loss + text_loss) / 2
        return jnp.mean(loss)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def nli_train_step(state, batch):
    """Take one gradient step on softmax cross-entropy for NLI classification.

    ``state.apply_fn`` is called with ``batch['premise']`` and
    ``batch['hypothesis']`` and returns class logits. ``batch['label']`` may be
    either integer class indices (one fewer dimension than the logits) or
    one-hot / soft labels with the logits' shape.

    ``jax.nn`` has no ``softmax_cross_entropy`` function, so the loss is
    computed from ``log_softmax``.
    """
    def loss_fn(params):
        # Obtain the logits for the three classes: entailment, contradiction, neutral
        logits = state.apply_fn({'params': params}, batch['premise'], batch['hypothesis'])
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        labels = jnp.asarray(batch['label'])
        if labels.ndim == log_probs.ndim:
            # Dense (one-hot or soft) labels.
            loss = -jnp.sum(labels * log_probs, axis=-1)
        else:
            # Integer class indices.
            loss = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)
        return jnp.mean(loss)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def simclr_train_step(state, batch, temperature=0.1):
    """Take one gradient step on a SimCLR-style contrastive loss.

    ``batch['augmented_images']`` holds two augmented views stacked along
    axis 0 (first half = view 1, second half = view 2), so its leading size
    must be even. Row ``i`` of each half is treated as a positive pair; the
    loss is the cross-entropy of view-1 rows against view-2 rows only.

    ``jax.nn`` has no ``softmax_cross_entropy`` function, so the loss is
    computed from ``log_softmax``.
    """
    def loss_fn(params):
        # Obtain representations for two sets of augmented images
        representations = state.apply_fn({'params': params}, batch['augmented_images'])
        # Split the representations into two halves, one for each augmentation
        h1, h2 = jnp.split(representations, 2, axis=0)

        # Normalize the representations
        h1 = h1 / jnp.linalg.norm(h1, axis=-1, keepdims=True)
        h2 = h2 / jnp.linalg.norm(h2, axis=-1, keepdims=True)

        # Compute the similarity between all pairs
        similarity_matrix = jnp.matmul(h1, h2.T) / temperature

        # Positive for row i is column i.
        batch_size = h1.shape[0]
        contrastive_labels = jnp.arange(batch_size)
        log_probs = jax.nn.log_softmax(similarity_matrix, axis=-1)
        loss = -jnp.take_along_axis(log_probs, contrastive_labels[:, None], axis=-1).squeeze(-1)
        return jnp.mean(loss)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def vae_train_step(vae_state, batch, optimizer):
    """Take one gradient step on the VAE objective (MSE reconstruction + KL).

    Args:
        vae_state: train state whose ``apply_fn`` returns
            ``(reconstructions, mean, logvar)`` for ``batch['inputs']``.
        batch: dict with key ``'inputs'``.
        optimizer: unused; kept so existing call sites keep working.

    Returns:
        ``(new_vae_state, loss)``.

    The raw gradients are handed to ``apply_gradients``, as every other step
    in this file does. The previous version first ran ``optimizer.update`` and
    then passed the result plus ``opt_state=`` to ``apply_gradients``; with a
    Flax ``TrainState`` that applies the optimizer twice and collides with the
    ``opt_state`` keyword ``apply_gradients`` already sets.
    """
    def vae_loss_fn(params):
        reconstructions, mean, logvar = vae_state.apply_fn({'params': params}, batch['inputs'])
        # Reconstruction loss (mean-squared error)
        recon_loss = jnp.mean(jnp.square(reconstructions - batch['inputs']))
        # KL divergence between N(mean, exp(logvar)) and N(0, 1), averaged over elements
        kl_loss = -0.5 * jnp.mean(1 + logvar - jnp.square(mean) - jnp.exp(logvar))
        # Total loss is the sum of reconstruction loss and KL divergence
        return recon_loss + kl_loss

    grad_fn = jax.value_and_grad(vae_loss_fn)
    loss, grads = grad_fn(vae_state.params)
    new_vae_state = vae_state.apply_gradients(grads=grads)
    return new_vae_state, loss


def dice_loss_train_step(state, batch):
    """Take one gradient step on a smoothed Dice loss.

    Softmax probabilities are compared with ``batch['targets']`` over the whole
    batch at once; a constant 1 is added to numerator and denominator.
    """
    def dice(params):
        probs = jax.nn.softmax(state.apply_fn({'params': params}, batch['inputs']))
        overlap = jnp.sum(probs * batch['targets'])
        numerator = 2. * overlap + 1.
        denominator = jnp.sum(probs) + jnp.sum(batch['targets']) + 1.
        return 1 - numerator / denominator

    loss, grads = jax.value_and_grad(dice)(state.params)
    return state.apply_gradients(grads=grads), loss


def focal_loss_train_step(state, batch, gamma=2.0, alpha=0.25):
    """Take one gradient step on the multi-class focal loss.

    Per example the loss is ``alpha * (1 - p_t) ** gamma * CE`` where ``p_t``
    is the probability assigned to the target class and ``CE`` is the
    cross-entropy. ``batch['targets']`` must match the logits' shape (one-hot).

    Fixes relative to the previous version: the cross-entropy uses
    log-probabilities rather than probabilities; ``p_t`` is reduced over the
    class axis so it has the same per-example shape as the cross-entropy
    (the old ``(batch, classes) * (batch,)`` product did not broadcast
    correctly); and the stray double negation that made the loss negative is
    removed.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        probs = jnp.exp(log_probs)
        targets = batch['targets']
        ce_loss = -jnp.sum(targets * log_probs, axis=-1)
        # Probability of the true class for each example.
        p_t = jnp.sum(targets * probs, axis=-1)
        loss = jnp.mean(alpha * jnp.power(1 - p_t, gamma) * ce_loss)
        return loss
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


def kl_divergence_train_step(state, batch):
    """Take one gradient step on ``KL(q || p)``.

    ``q`` is the softmax of the model output and ``p`` is ``batch['targets']``;
    the divergence is summed over the last axis and averaged over the rest.
    """
    def kl(params):
        q_probs = softmax(state.apply_fn({'params': params}, batch['inputs']))
        log_ratio = jnp.log(q_probs) - jnp.log(batch['targets'])
        per_example = jnp.sum(q_probs * log_ratio, axis=-1)
        return jnp.mean(per_example)

    loss, grads = jax.value_and_grad(kl)(state.params)
    return state.apply_gradients(grads=grads), loss


def iou_loss_train_step(state, batch):
    """Take one gradient step on a soft IoU (Jaccard) loss.

    Intersection and union are reduced over axes 1 and 2 of the softmax
    output, so the model output needs at least three dimensions. The union is
    floored at 1e-6 before dividing.
    """
    def soft_iou(params):
        probs = jax.nn.softmax(state.apply_fn({'params': params}, batch['inputs']))
        reduce_axes = (1, 2)
        overlap = jnp.sum(probs * batch['targets'], axis=reduce_axes)
        combined = jnp.sum(probs + batch['targets'], axis=reduce_axes) - overlap
        ratio = overlap / jnp.maximum(combined, 1e-6)
        return jnp.mean(1 - ratio)

    loss, grads = jax.value_and_grad(soft_iou)(state.params)
    return state.apply_gradients(grads=grads), loss


def tversky_loss_train_step(state, batch, alpha=0.5, beta=0.5):
    """Take one gradient step on the Tversky loss.

    ``alpha`` weights the false-negative mass and ``beta`` the false-positive
    mass; all three sums run over the entire batch.
    """
    def tversky(params):
        probs = jax.nn.softmax(state.apply_fn({'params': params}, batch['inputs']))
        labels = batch['targets']
        tp = jnp.sum(labels * probs)
        fn = jnp.sum(labels * (1 - probs))
        fp = jnp.sum((1 - labels) * probs)
        index = tp / (tp + alpha * fn + beta * fp)
        return 1 - index

    loss, grads = jax.value_and_grad(tversky)(state.params)
    return state.apply_gradients(grads=grads), loss


def gan_train_step(generator_state, discriminator_state, batch, z_dim, rng_key):
    """Run one discriminator update followed by one generator update.

    Args:
        generator_state: train state of the generator (maps noise to images).
        discriminator_state: train state of the discriminator (maps images to logits).
        batch: dict with key ``'real_images'``; its leading size sets the noise batch size.
        z_dim: dimensionality of the noise vector.
        rng_key: PRNG key used to draw the noise.

    Returns:
        ``(generator_state, discriminator_state, g_loss, d_loss)``. The
        generator loss is evaluated against the already-updated discriminator.

    ``jax.nn`` has no ``sigmoid_cross_entropy_with_logits`` function, so the
    binary cross-entropy is computed from ``log_sigmoid``.
    """
    def bce_with_logits(logits, labels):
        # Binary cross-entropy on logits.
        return -(labels * jax.nn.log_sigmoid(logits) + (1 - labels) * jax.nn.log_sigmoid(-logits))

    # Sample random noise
    z = jax.random.normal(rng_key, (batch['real_images'].shape[0], z_dim))

    # Discriminator loss
    def discriminator_loss_fn(d_params):
        fake_images = generator_state.apply_fn({'params': generator_state.params}, z)
        real_logits = discriminator_state.apply_fn({'params': d_params}, batch['real_images'])
        fake_logits = discriminator_state.apply_fn({'params': d_params}, fake_images)
        real_loss = bce_with_logits(real_logits, jnp.ones_like(real_logits))
        fake_loss = bce_with_logits(fake_logits, jnp.zeros_like(fake_logits))
        return jnp.mean(real_loss + fake_loss)

    # Generator loss
    def generator_loss_fn(g_params):
        fake_images = generator_state.apply_fn({'params': g_params}, z)
        fake_logits = discriminator_state.apply_fn({'params': discriminator_state.params}, fake_images)
        return jnp.mean(bce_with_logits(fake_logits, jnp.ones_like(fake_logits)))

    # Update discriminator
    d_grad_fn = jax.value_and_grad(discriminator_loss_fn)
    d_loss, d_grads = d_grad_fn(discriminator_state.params)
    discriminator_state = discriminator_state.apply_gradients(grads=d_grads)

    # Update generator (sees the discriminator updated just above)
    g_grad_fn = jax.value_and_grad(generator_loss_fn)
    g_loss, g_grads = g_grad_fn(generator_state.params)
    generator_state = generator_state.apply_gradients(grads=g_grads)

    return generator_state, discriminator_state, g_loss, d_loss


def wgan_train_step(generator_state, discriminator_state, batch, z_dim, rng_key):
    """Run one critic update followed by one generator update (Wasserstein losses).

    The critic minimises ``mean(fake - real)`` scores; the generator then
    minimises ``-mean(fake)`` against the critic that was just updated.
    Returns ``(generator_state, discriminator_state, g_loss, d_loss)``.
    """
    real_images = batch['real_images']
    # Sample random noise
    noise = jax.random.normal(rng_key, (real_images.shape[0], z_dim))

    # --- critic update (generator parameters held fixed) ---
    def critic_objective(critic_params):
        generated = generator_state.apply_fn({'params': generator_state.params}, noise)
        real_scores = discriminator_state.apply_fn({'params': critic_params}, real_images)
        fake_scores = discriminator_state.apply_fn({'params': critic_params}, generated)
        return jnp.mean(fake_scores - real_scores)

    d_loss, critic_grads = jax.value_and_grad(critic_objective)(discriminator_state.params)
    updated_critic = discriminator_state.apply_gradients(grads=critic_grads)

    # --- generator update (scored by the updated critic) ---
    def generator_objective(gen_params):
        generated = generator_state.apply_fn({'params': gen_params}, noise)
        fake_scores = updated_critic.apply_fn({'params': updated_critic.params}, generated)
        return -jnp.mean(fake_scores)

    g_loss, gen_grads = jax.value_and_grad(generator_objective)(generator_state.params)
    updated_generator = generator_state.apply_gradients(grads=gen_grads)

    return updated_generator, updated_critic, g_loss, d_loss


def lsgan_train_step(generator_state, discriminator_state, batch, z_dim, rng_key):
    """Run one discriminator update followed by one generator update (least-squares GAN).

    The discriminator regresses real outputs towards 1 and fake outputs
    towards 0; the generator then pushes fake outputs towards 1 against the
    discriminator that was just updated.
    Returns ``(generator_state, discriminator_state, g_loss, d_loss)``.
    """
    real_images = batch['real_images']
    # Sample random noise
    noise = jax.random.normal(rng_key, (real_images.shape[0], z_dim))

    # --- discriminator update (generator parameters held fixed) ---
    def discriminator_objective(disc_params):
        generated = generator_state.apply_fn({'params': generator_state.params}, noise)
        real_out = discriminator_state.apply_fn({'params': disc_params}, real_images)
        fake_out = discriminator_state.apply_fn({'params': disc_params}, generated)
        real_term = jnp.square(real_out - 1)
        fake_term = jnp.square(fake_out)
        return jnp.mean(0.5 * (real_term + fake_term))

    d_loss, disc_grads = jax.value_and_grad(discriminator_objective)(discriminator_state.params)
    updated_discriminator = discriminator_state.apply_gradients(grads=disc_grads)

    # --- generator update (scored by the updated discriminator) ---
    def generator_objective(gen_params):
        generated = generator_state.apply_fn({'params': gen_params}, noise)
        fake_out = updated_discriminator.apply_fn(
            {'params': updated_discriminator.params}, generated)
        return jnp.mean(0.5 * jnp.square(fake_out - 1))

    g_loss, gen_grads = jax.value_and_grad(generator_objective)(generator_state.params)
    updated_generator = generator_state.apply_gradients(grads=gen_grads)

    return updated_generator, updated_discriminator, g_loss, d_loss


def reinforce_train_step(policy_state, trajectory, optimizer):
    """Take one REINFORCE policy-gradient step.

    Args:
        policy_state: train state whose ``apply_fn`` returns per-action
            log-probabilities for the observations.
        trajectory: tuple ``(observations, actions, rewards)``; ``actions``
            holds integer indices into the last axis of the log-probabilities.
        optimizer: unused; kept so existing call sites keep working.

    Returns:
        ``(new_policy_state, loss)``.

    The raw gradients are handed to ``apply_gradients``, as every other step
    in this file does. The previous version first ran ``optimizer.update`` and
    then passed the result plus ``opt_state=`` to ``apply_gradients``; with a
    Flax ``TrainState`` that applies the optimizer twice and collides with the
    ``opt_state`` keyword ``apply_gradients`` already sets.
    """
    def policy_loss_fn(policy_params):
        # Extract observations, actions and rewards from the trajectory
        observations, actions, rewards = trajectory
        # Log probabilities of the actions under the policy
        log_probs = policy_state.apply_fn({'params': policy_params}, observations)
        selected_log_probs = jnp.take_along_axis(log_probs, actions[..., None], axis=-1).squeeze(-1)
        # REINFORCE loss: reward-weighted log-likelihood, summed over the last axis
        loss = -jnp.mean(jnp.sum(selected_log_probs * rewards, axis=-1))
        return loss

    grad_fn = jax.value_and_grad(policy_loss_fn)
    loss, grads = grad_fn(policy_state.params)
    new_policy_state = policy_state.apply_gradients(grads=grads)

    return new_policy_state, loss


def dqn_train_step(q_network_state, batch, optimizer, gamma=0.99):
    """Take one DQN temporal-difference step.

    Args:
        q_network_state: train state whose ``apply_fn`` returns Q-values per action.
        batch: tuple ``(observations, actions, next_observations, rewards, dones)``.
        optimizer: unused; kept so existing call sites keep working.
        gamma: discount factor.

    Returns:
        ``(new_q_network_state, loss)``.

    Fixes:
      * The TD target is wrapped in ``stop_gradient`` so it is treated as a
        constant; previously gradients also flowed through the bootstrap
        term because the same parameters produce the next-state Q-values.
      * The raw gradients are handed to ``apply_gradients``. Running
        ``optimizer.update`` first and passing ``opt_state=`` applies the
        optimizer twice and collides with the keyword a Flax ``TrainState``
        sets itself.
    """
    def q_loss_fn(q_params):
        # Extract observations, actions, next observations, rewards and done flags
        observations, actions, next_observations, rewards, dones = batch
        # Compute Q values for current observations
        q_values = q_network_state.apply_fn({'params': q_params}, observations)
        # Select the Q value for the action taken
        q_values = jnp.take_along_axis(q_values, actions[..., None], axis=-1).squeeze(-1)
        # Compute Q values for next observations
        next_q_values = q_network_state.apply_fn({'params': q_params}, next_observations)
        # Take max over next Q values for the TD target
        max_next_q_values = jnp.max(next_q_values, axis=-1)
        # Compute the target Q values; no gradient flows through the target
        target_q_values = jax.lax.stop_gradient(
            rewards + gamma * max_next_q_values * (1 - dones)
        )
        # DQN loss
        loss = jnp.mean(jnp.square(q_values - target_q_values))
        return loss

    grad_fn = jax.value_and_grad(q_loss_fn)
    loss, grads = grad_fn(q_network_state.params)
    new_q_network_state = q_network_state.apply_gradients(grads=grads)

    return new_q_network_state, loss


def ppo_train_step(policy_state, value_state, batch, policy_optimizer, value_optimizer, clip_ratio=0.2):
    """Take one PPO step: clipped-surrogate policy update plus value regression.

    Args:
        policy_state: train state whose ``apply_fn(variables, observations, actions)``
            returns the log-probabilities of the given actions.
        value_state: train state whose ``apply_fn(variables, observations)``
            returns value estimates shaped like ``returns``.
        batch: tuple ``(observations, actions, advantages, log_probs_old, returns)``.
        policy_optimizer: unused; kept so existing call sites keep working.
        value_optimizer: unused; kept so existing call sites keep working.
        clip_ratio: half-width of the probability-ratio clipping interval.

    Returns:
        ``(new_policy_state, new_value_state, policy_loss, value_loss)``.

    The raw gradients are handed to ``apply_gradients`` for both networks.
    The previous version first ran the external optimizers and then passed
    the result plus ``opt_state=`` to ``apply_gradients``; with a Flax
    ``TrainState`` that applies the optimizer twice and collides with the
    ``opt_state`` keyword ``apply_gradients`` already sets.
    """
    observations, actions, advantages, log_probs_old, returns = batch

    def policy_loss_fn(policy_params):
        log_probs = policy_state.apply_fn({'params': policy_params}, observations, actions)
        ratio = jnp.exp(log_probs - log_probs_old)
        clipped_advantages = jnp.clip(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages
        # Pessimistic (minimum) of the unclipped and clipped surrogate
        loss = -jnp.mean(jnp.minimum(ratio * advantages, clipped_advantages))
        return loss

    def value_loss_fn(value_params):
        values = value_state.apply_fn({'params': value_params}, observations)
        loss = jnp.mean(jnp.square(values - returns))
        return loss

    # Update policy
    policy_grad_fn = jax.value_and_grad(policy_loss_fn)
    policy_loss, policy_grads = policy_grad_fn(policy_state.params)
    new_policy_state = policy_state.apply_gradients(grads=policy_grads)

    # Update value function
    value_grad_fn = jax.value_and_grad(value_loss_fn)
    value_loss, value_grads = value_grad_fn(value_state.params)
    new_value_state = value_state.apply_gradients(grads=value_grads)
    return new_policy_state, new_value_state, policy_loss, value_loss

In [39]:
import jax
import optax
import jax.numpy as jnp
from jax.nn import softmax
from flax import linen as nn
from jax.nn.initializers import lecun_normal

class SimpleNN(nn.Module):
    """Two-layer perceptron: Dense(512) -> ReLU -> Dense(10), LeCun-normal kernels."""
    @nn.compact
    def __call__(self, x):
        # Layers are created in the same order as before, so Flax assigns the
        # same automatic submodule names.
        hidden = nn.relu(nn.Dense(features=512, kernel_init=lecun_normal())(x))
        return nn.Dense(features=10, kernel_init=lecun_normal())(hidden)
    
def train_step(state, batch):
    """Take one gradient step on softmax cross-entropy.

    ``batch['targets']`` must match the logits' shape (one-hot or soft
    labels). The loss is ``-sum(targets * log_softmax(logits))`` averaged over
    the batch; the previous version multiplied by ``softmax`` instead of its
    logarithm, which is not cross-entropy.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.mean(jnp.sum(log_probs * batch['targets'], axis=-1))
        return loss
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# NOTE(review): `batch_size_per_device`, `input_dim`, `learning_rate`,
# `train_loader` and `val_loader` are not defined in the cells shown here;
# they must already exist in the kernel for this cell to run.
trainer = DataParallelTrainer(model=SimpleNN(),
                              input_shape=(batch_size_per_device, input_dim),
                              train_step=train_step,
                              optax_optimizer=optax.adam,
                              learning_rate=learning_rate,
                              weights_filename="params.pkl")

# `train` is declared as train(train_loader, num_epochs, val_loader=None); the
# previous call passed the epoch count where the loader belongs.
trainer.train(train_loader, num_epochs=10, val_loader=val_loader)