# Import Modules

In [1]:
import sys
from dataclasses import dataclass
import matplotlib.pyplot as plt

import jax
from jax import (
    Array,
    numpy as jnp,
    random as jrand
)

# flash_attn_jax is optional: when it is missing, the plain-JAX attention path is used instead.
# Only ImportError is caught so that KeyboardInterrupt and unrelated failures are not silently hidden.
try:
    from flash_attn_jax import flash_mha
    USE_FLASH_ATT = True
except ImportError:
    USE_FLASH_ATT = False

import keras as nn
nn.utils.set_random_seed(42)
# Keras layers compute in float16 while keeping their variables in float32.
nn.mixed_precision.set_dtype_policy("mixed_float16")

import tensorflow as tf
# NOTE(review): presumably hides GPUs from TensorFlow so that TF (used here only for the tfds
# input pipeline) does not compete with JAX for GPU memory — confirm.
tf.config.set_visible_devices([], 'GPU')
import tensorflow_datasets as tfds
tf.random.set_seed(42)

print("Python Version", sys.version)
del sys  # sys is only needed for the version string above
print(f"Keras Version {nn.__version__} with {nn.backend.backend()} backend \tJax Version {jax.__version__}")
print("Jax backend device", jax.default_backend())
print("Tensorflow & TFDS version:", tf.__version__, tfds.__version__)

Python Version 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
Keras Version 3.0.5 with jax backend 	Jax Version 0.4.23
Jax backend device gpu
Tensorflow & TFDS version: 2.16.1 4.9.4+nightly


# Dataset

In [2]:
DATASET_NAME = "places365_small"
# Given a list of splits, tfds.load returns one tf.data.Dataset per split plus a single
# DatasetInfo that describes the whole dataset.
# NOTE(review): the last run failed with NonMatchingChecksumError on test_256.tar
# (1.01 GiB received vs 4.41 GiB expected), which looks like a truncated download — see the
# TFDS "fixing NonMatchingChecksumError" guide referenced in that error.
(train_data, val_data, test_data), ds_info = tfds.load(
    DATASET_NAME, split=['train', 'validation', 'test'], with_info=True
)
# The preview cells below take a (split, info) pair; every split shares the same DatasetInfo.
train_info = val_info = test_info = ds_info
# Only the last expression of a cell is displayed, so report all three split sizes together.
len(train_data), len(val_data), len(test_data)

2024-04-15 21:11:38.559175: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


[1mDownloading and preparing dataset 29.27 GiB (download: 29.27 GiB, generated: 27.85 GiB, total: 57.13 GiB) to /home/vvy/tensorflow_datasets/places365_small/2.1.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

NonMatchingChecksumError: Artifact http://data.csail.mit.edu/places/places365/test_256.tar, downloaded to /home/vvy/tensorflow_datasets/downloads/data.csail.mit.edu_places_places3_test_2563V2WJWGkuR51a6abtdCb08ZF5b9Pbsjn4JNCV38EOqE.tar.tmp.d57fef6dd09e42c2bd1af58b4c9845fa/test_256.tar, has wrong checksum:
* Expected: UrlInfo(size=4.41 GiB, checksum='037ee8180369bdde46636341b92900d4bcb8ea000c026a1fd3e0e9827a8702a1', filename='test_256.tar')
* Got: UrlInfo(size=1.01 GiB, checksum='b92f11e21cc30b923b0275ec18f7f7d3a7d1f1fb00c55fe1afb3fb1bb81bc021', filename='test_256.tar')
To debug, see: https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror

In [None]:
# show_examples draws a grid of training images and also returns the matplotlib Figure; the trailing
# `;` suppresses that return value so the inline backend does not render the grid a second time.
tfds.show_examples(train_data, train_info);

In [None]:
# as_dataframe loads every example it is given into memory; preview only a handful of validation images.
tfds.as_dataframe(val_data.take(10), val_info)

In [None]:
# as_dataframe loads every example it is given into memory; preview only a handful of test images.
tfds.as_dataframe(test_data.take(10), test_info)

# Config

In [None]:
@dataclass
class config:
    """Hyper-parameters shared by the Attention, TransformerBlock and VIT classes.

    Positional field order: maxlen, d_model, num_heads, num_layers, num_classes,
    dropout_rate, use_flash_att, patch_size, H, W.
    """
    # Transformer specifics
    maxlen:int          # side length of the causal mask built by Attention (non-flash, causal only)

    d_model:int         # model width; Attention asserts it is divisible by num_heads
    num_heads:int
    num_layers:int      # number of TransformerBlocks stacked in VIT
    num_classes:int     # output width of VIT's classification head
    dropout_rate:float

    use_flash_att:bool  # route attention through flash_attn_jax.flash_mha instead of plain JAX

    # ViT specifics
    patch_size:int      # side length P of each square patch
    H:int               # input image height
    W:int               # input image width

    # Training Args

# ViT Architecture

In [None]:
class PositionalEmbedding:
    """```
    Sinusoidal Fixed Positional Embeddings
    Args:
        maxlen:int  number of positions
        dim:int     embedding width (odd widths are allowed: the final cos channel is dropped)
    sinusoidal_embeddings:
        pos_emb: (1, maxlen, dim) with
            pos_emb[0, p, 2j]   = sin(p / 1e4**(2j/dim))
            pos_emb[0, p, 2j+1] = cos(p / 1e4**(2j/dim))
    ```"""
    def __init__(self, maxlen:int, dim:int):
        pos = jnp.arange(maxlen, dtype=float)[:, None]        # (maxlen, 1)
        two_i = jnp.arange(0, dim, 2, dtype=float)[None, :]   # (1, ceil(dim/2)) even channel indices 2j
        theta = pos / 1e4**(two_i / dim)                      # (maxlen, ceil(dim/2))

        # interleave sin/cos so channel 2j is sin and channel 2j+1 is cos
        pos_emb = jnp.stack([jnp.sin(theta), jnp.cos(theta)], axis=-1)  # (maxlen, ceil(dim/2), 2)
        # explicit width (not -1) so maxlen == 0 still reshapes; slicing handles odd dim
        pos_emb = pos_emb.reshape((maxlen, 2 * theta.shape[-1]))[:, :dim]  # (maxlen, dim)
        self.pos_emb = pos_emb[None] # (1, maxlen, dim)

    def sinusoidal_embeddings(self):
        return self.pos_emb # (1, maxlen, dim)

In [None]:
class Attention(nn.Layer):
    """```
    Multi-head Attention
    Args:
        causal:bool  if True, each position only attends to itself and earlier positions
        config       reads d_model, num_heads, dropout_rate, use_flash_att (and maxlen when causal)
    Input:
        x: shape(B, T, d_model)
        training: bool (enables the output dropout)
    Output:
        linear_att_out: shape(B, T, d_model)
    ```"""
    def __init__(
            self,
            causal:bool,
            config:config,
            **kwargs
    ):
        super().__init__(**kwargs)
        assert config.d_model % config.num_heads == 0
        self.flash = config.use_flash_att
        self.causal = causal
        self.num_heads = config.num_heads
        self.dim = config.d_model//config.num_heads # per-head width

        self.wq = nn.layers.Dense(config.d_model, use_bias=False)
        self.wk = nn.layers.Dense(config.d_model, use_bias=False)
        self.wv = nn.layers.Dense(config.d_model, use_bias=False)
        self.dropout = nn.layers.Dropout(config.dropout_rate)

        self.wo = nn.layers.Dense(config.d_model)
        if causal and (not config.use_flash_att): # flash_mha applies causality itself via is_causal
            # additive mask: 0 on/below the diagonal, -inf above it
            self.causal_mask = jnp.triu(jnp.full(shape=(1, 1, config.maxlen, config.maxlen), fill_value=-jnp.inf), k=1)

    def call(
            self,
            x:Array, # (B, T, d_model)
            training:bool=False
    ):
        B, T, d_model = x.shape

        # compute q, k, v
        q = self.wq(x) # (B, T, d_model)
        k = self.wk(x) # (B, T, d_model)
        v = self.wv(x) # (B, T, d_model)

        # split the feature axis into heads: (B, T, d_model) -> (B, T, h, dim)
        split_shape = (B, T, self.num_heads, self.dim)
        q, k, v = q.reshape(split_shape), k.reshape(split_shape), v.reshape(split_shape)

        if self.flash:
            att_out = flash_mha(q, k, v, softmax_scale=None, is_causal=self.causal) # (B, T, h, dim)
        else:
            # Move heads in front of time: (B, T, h, dim) -> (B, h, T, dim).
            # Reshaping (B, T, d_model) straight to (B, h, T, dim) would instead pack features of
            # several different tokens into each "head" row.
            q, k, v = (jnp.swapaxes(t, 1, 2) for t in (q, k, v))
            att_wei = (q @ jnp.swapaxes(k, -1, -2))/self.dim**0.5 # (B, h, T, T)
            # Softmax in float32: under the mixed_float16 policy set in the import cell, the scores
            # come out of float16 Dense layers.
            att_wei = att_wei.astype(jnp.float32)
            if self.causal:
                att_wei = att_wei + self.causal_mask[:, :, :T, :T] # (B, h, T, T)
            att_wei = jax.nn.softmax(att_wei, axis=-1).astype(v.dtype) # (B, h, T, T)
            # apply attention weights to v
            att_out = att_wei @ v # (B, h, T, T) @ (B, h, T, dim) => (B, h, T, dim)
            att_out = jnp.swapaxes(att_out, 1, 2) # (B, T, h, dim)

        # combine heads: (B, T, h, dim) -> (B, T, h*dim) == (B, T, d_model)
        att_out = att_out.reshape((B, T, d_model))

        # linear of att_out
        linear_att_out = self.wo(att_out)
        linear_att_out = self.dropout(linear_att_out, training=training) # (B, T, d_model)
        return linear_att_out

In [None]:
class TransformerBlock(nn.Model):
    """```
    Pre-LayerNorm Transformer block:
        z = x + MHA(LayerNorm(x))
        y = z + FFN(LayerNorm(z)),  FFN = Dense(4*d_model) -> tanh-approx GELU -> Dense(d_model) -> Dropout
    Args:
        causal:bool
        config
    Inputs:
        inputs: shape(B, T, d_model)
        training: bool (optional; enables dropout)
    Outputs:
        outputs: shape(B, T, d_model)
    ```"""
    def __init__(
            self,
            causal:bool,
            config:config,
            **kwargs
    ):
        super().__init__(**kwargs)
        dff_in = 4*config.d_model
        self.norm1 = nn.layers.LayerNormalization(epsilon=1e-5)
        self.mha = Attention(causal, config)

        self.ffn = nn.Sequential([
            nn.layers.Dense(int(dff_in)),
            nn.layers.Activation(lambda x: nn.activations.gelu(x, approximate=True)),
            nn.layers.Dense(config.d_model),
            nn.layers.Dropout(config.dropout_rate)
        ])
        self.norm2 = nn.layers.LayerNormalization(epsilon=1e-5)

    # `training` has a default so the block can be called as `block(x)` (as VIT does). When it is not
    # passed explicitly, Keras takes it from the enclosing call context (e.g. fit) before this default.
    def call(self, x:Array, training:bool=False):
        z = x + self.mha(self.norm1(x), training=training)
        y = z + self.ffn(self.norm2(z), training=training)
        return y # (B, T, d_model)

In [None]:
class VIT(nn.Model):
    """```
    Vision Transformer classifier
    Args:
        config  reads patch_size, H, W, d_model, num_layers, num_classes
                (plus the fields TransformerBlock/Attention read)
    Input:
        x: shape(B, H, W, C); H and W must be divisible by patch_size
        training: bool (optional; enables dropout)
    Output:
        logits: shape(B, num_classes)
    ```"""
    def __init__(self, config:config, **kwargs):
        # Keras must initialise the Model before any layer or weight is attached to self.
        super().__init__(**kwargs)
        assert config.H % config.patch_size == 0 and config.W % config.patch_size == 0, \
            "H and W must be divisible by patch_size"
        self.P = config.patch_size
        # number of non-overlapping P x P patches
        self.N = (config.H//config.patch_size)*(config.W//config.patch_size)

        self.pos_embed = PositionalEmbedding(
            maxlen=1+self.N,
            dim=config.d_model
        ).sinusoidal_embeddings() # (1, 1+N, d_model)

        self.proj_flattened_patches = nn.layers.Dense(
            config.d_model
        )
        # learned [class] token; rank 3 so it can be broadcast over the batch and concatenated on axis 1
        self.class_emb = self.add_weight(shape=(1, 1, config.d_model)) # (1, 1, d_model)

        self.encoder_layers = [
            TransformerBlock(causal=False, config=config)
            for _ in range(config.num_layers)
        ]
        self.norm = nn.layers.LayerNormalization(epsilon=1e-5)
        # float32 logits: the output layer should not run in float16 under the mixed_float16 policy
        self.mlp_head = nn.layers.Dense(config.num_classes, dtype="float32")

    def patchify(self, x:Array): # (B, H, W, C)
        """Cut images into non-overlapping P x P patches, each flattened row-major to P*P*C values."""
        _, H, W, C = x.shape
        P = self.P
        x = x.reshape((-1, H//P, P, W//P, P, C))      # (B, H/P, P, W/P, P, C)
        x = x.transpose((0, 1, 3, 2, 4, 5))           # (B, H/P, W/P, P, P, C)
        return x.reshape((-1, (H//P)*(W//P), P*P*C))  # (B, N, (P**2)*C)

    def call(self, x:Array, training:bool=False): # (B, H, W, C)
        x = self.patchify(x) # (B, N, (P**2)*C)
        x = self.proj_flattened_patches(x) # (B, N, d_model)

        B, _, d_model = x.shape
        class_tok = nn.ops.cast(self.class_emb, x.dtype) # (1, 1, d_model)
        class_tok = jnp.broadcast_to(class_tok, (B, 1, d_model)) # (B, 1, d_model)
        x = jnp.concatenate((class_tok, x), axis=1) # (B, 1+N, d_model)

        x = x + self.pos_embed.astype(x.dtype) # (B, 1+N, d_model)

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, training=training) # (B, 1+N, d_model)

        class_rep = x[:, 0, :] # (B, d_model)
        x = self.norm(class_rep) # (B, d_model)
        x = self.mlp_head(x) # (B, num_classes)
        return x