In [1]:
import jax
from jax import lax,random,numpy as jnp

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax.training import train_state

# import haiku as hk

import optax


from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import functools
from typing import Any,Callable,Sequence,Optional

import numpy as np
import matplotlib.pyplot as plt

In [None]:
#input is CxHxW latent image, transform into patches of size Nx(P^2*C)

def patchify(lat_img, p):
    C, H, W = lat_img.shape
    assert (H % p == 0 and W % p == 0)
    
    lat_img = lat_img.reshape(C, H//p, p, W//p, p)

    lat_img = lat_img.transpose(1, 3, 2, 4, 0)

    patches = lat_img.reshape(-1, p*p*C)
    return patches

inp = random.normal(random.PRNGKey(23), (3,4,32,32))

pat = jax.vmap(patchify, in_axes=(0,None))(inp, 4)
# pat = patchify(inp,4)
# print(pat.shape)

def batch_patchify(batch_lat_img, p):
    return jax.vmap(patchify, in_axes=(0,None))(batch_lat_img,p)

#TODO: Positional encoding
class EmbedPatch(nn.Module):
    # patchdim: int
    embed_dim: int

    def setup(self):
        self.layer = nn.Dense(self.embed_dim)

    # @nn.compact
    def __call__(self, x):
        return self.layer(x)
    
embed = EmbedPatch(embed_dim=32)
key = random.PRNGKey(23)
params = embed.init(key, pat)
output = embed.apply(params, pat)
print(output.shape)

class MHA(nn.Module):
    num_heads: int
    embed_dim: int

    def setup(self):
        assert self.embed_dim%self.num_heads == 0, "embed_dim not divisible by num_heads"
        self.W = nn.Dense(self.embed_dim)
        self.K = nn.Dense(self.embed_dim)
        self.Q = nn.Dense(self.embed_dim)
        self.W0 = nn.Dense(self.embed_dim)

    def __call__(self, x):
        #Assume x has shape (Batches, Seq_len, embed_dim)
        B,S,_ = x.shape
        q = self.Q(x)
        w = self.W(x)
        # jnp.variance(x)
        k = self.K(x)

        head_dim = self.embed_dim//self.num_heads
        multi_q = q.reshape(B,S,self.num_heads,head_dim).transpose(0,2,1,3)
        multi_w = w.reshape(B,S,self.num_heads,head_dim).transpose(0,2,1,3)
        multi_k = k.reshape(B,S,self.num_heads,head_dim).transpose(0,2,1,3)
        
        attention = jnp.matmul(multi_q, multi_k.transpose(0,1,3,2))/jnp.sqrt(head_dim)
        attention = nn.softmax(attention,-1)
        z = jnp.matmul(attention,multi_w)
        multi_z = self.W0(z.transpose(0,2,1,3).reshape(B,S,self.embed_dim))

        return multi_z
    

mha = MHA(4,32)
key, key1 = random.split(key)
params = mha.init(key1, output) 
print(output)
output = mha.apply(params, output)
print(output.shape)



(3, 64, 32)


AttributeError: module 'jax.numpy' has no attribute 'variance'

In [None]:
#Test with MNIST

def custom_collate(batch):
    transposed_data = list(zip(*batch))
    # print((transposed_data))

    imgs = np.array(transposed_data[0])
    imgs = imgs.reshape(imgs.shape[0],1,imgs.shape[1],imgs.shape[2])
    labels = np.array(transposed_data[1])

    # print(len(imgs))

    return imgs, labels


train_dataset = MNIST(root='./train_mnist',train=True, download=True,transform=lambda x:(np.array(x, dtype=np.float32)))
test_dataset = MNIST(root='./test_mnist',train=False, download=True,transform=lambda x: np.ravel(np.array(x, dtype=np.float32)))
# print(type(train_dataset))
# print((train_dataset[0][0].shape))

BATCH_SIZE = 128

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate, drop_last=True)

batch_data = (next(iter(train_loader)))
# batch_data = next(iter(train_loader))
# batch_data = next(iter(train_loader))
print((batch_data[0].shape))

(128, 1, 28, 28)


In [None]:
class Model(nn.Module):
    embed_dim: int
    num_heads: int

    def setup(self):
        self.embedder = EmbedPatch(self.embed_dim)
        self.mha = MHA(self.num_heads,self.embed_dim)
        self.linear = nn.Dense(10)
        self.mlp = nn.Dense(self.embed_dim)
        self.layernorm = nn.LayerNorm()
        # self.relu = nn.relu()

    def __call__(self,x):
        activation = x
        activation = batch_patchify(activation, 7)
        activation = self.embedder(activation)
        activation = self.layernorm(activation)
        activation = self.mha(activation)
        activation = self.layernorm(activation)
        activation = self.mlp(activation)
        activation = nn.relu(activation)
        activation = self.mha(activation)
        # activation = self.layernorm(activation)
        activation = activation.reshape(activation.shape[0], -1)
        activation = self.linear(activation)
        return activation

In [None]:
from jax import grad, value_and_grad


NUM_EPOCHS = 10
model = Model(64,4)
dummy = random.normal(key, (1,1,28,28))
params = model.init(key, dummy)

def cross_entropy_loss(params,imgs, labels):
    # logits: (batch_size, num_classes)
    # labels: (batch_size,) with class indices
    # patches = patchify(imgs, 7)
    # embed = EmbedPatch(embed_dim=28)
    # key = random.PRNGKey(23)
    # params = embed.init(key, patches)
    # output = embed.apply(params, patches)
    logits = model.apply(params,imgs)
    # print(logits.shape)
    log_probs = jax.nn.log_softmax(logits)  # (batch_size, num_classes)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    loss = -jnp.sum(one_hot_labels * log_probs, axis=-1)  # (batch_size,)
    return loss.mean()

def loss(params, imgs, labels):
    output = model.apply(params, imgs)
    log_probs = jax.nn.log_softmax(output)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(log_probs*one_hot_labels)

def update(params, imgs, gt_labels, lr=0.5):
    l, grads = value_and_grad(loss)(params,imgs,gt_labels)
    return l, jax.tree.map(lambda p, g: p - lr*g, params, grads)

for epoch in range((NUM_EPOCHS)):

    for cnt, (imgs, labels) in enumerate(train_loader):
        # gt_labels = jax.nn.one_hot(labels,len(MNIST.classes))
        # print(imgs.shape)
        l, params = update(params, imgs, labels)

        # if cnt % 50 == 0:
        print(l)
    break



0.2491689
0.2794252
0.24887462
0.23497777
0.22202174
0.21747606
0.21461372
0.21179025
0.21887426
0.20959555
0.21154885
0.20637286
0.20592622
0.21477
0.21296501
0.21111849
0.2066931
0.1921023
0.19493876
0.18349828
0.20488504
0.18718891
0.18495299
0.20155847
0.22873612
0.22798553
0.20073119
0.18396893
0.17839997
0.17764737
0.16039579
0.17898393
0.18765707
0.19703157
0.17930885
0.16762269
0.17732841
0.19357608
0.20906101
0.1821375
0.16943093
0.17747033
0.16382714
0.17245486
0.18947083
0.18907173
0.21488492
0.20388447
0.1827116
0.16281569
0.17374888
0.15885599
0.1670336
0.1507827
0.14224575
0.14994673
0.15892018
0.15349555
0.14422388
0.15266816
0.17002921
0.17464897
0.18210593
0.1579622
0.14776433
0.14705795
0.13789193
0.1412973
0.16482149
0.20081027
0.18203034
0.19205974
0.16150193
0.14283237
0.143156
0.13999854
0.13022624
0.1294763
0.15009367
0.13433163
0.13568218
0.13955028
0.1428421
0.12009948
0.10765703
0.1302225
0.113560036
0.12559192
0.14373104
0.16420786
0.13334249
0.1374217
0.1349