In [1]:
import jax
from jax import lax,random,numpy as jnp

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax.training import train_state

# import haiku as hk

import optax


from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import functools
from typing import Any,Callable,Sequence,Optional

import numpy as np
import matplotlib.pyplot as plt

In [4]:
#Define constants:

T = 1000
beta_start = 1e-4
beta_end = 0.02

beta_schedule = jax.numpy.linspace(beta_start, beta_end, T)

alpha_schedule = 1-beta_schedule
alpha_prefix_product = np.ones(T+1)
for i in range(1,1001):
    alpha_prefix_product[i] = alpha_prefix_product[i-1]*alpha_schedule[i]

print(alpha_prefix_product)

[1.00000000e+00 9.99880075e-01 9.99740243e-01 ... 4.11860819e-05
 4.03623599e-05 3.95551142e-05]


In [None]:
#input is CxHxW latent image, transform into patches of size Nx(P^2*C)

class Patchify(nn.Module):
    p: int

    @nn.compact
    def __call__(self, lat_img):
        B, C, H, W = lat_img.shape
        p = self.p
        assert (H % p == 0 and W % p == 0)
        
        lat_img = lat_img.reshape(B, C, H//p, p, W//p, p)
    
        lat_img = lat_img.transpose(0, 2, 4, 3, 5, 1)
    
        patches = lat_img.reshape(B,-1, p*p*C)
        return patches

# inp = random.normal(random.PRNGKey(23), (3,4,32,32))

# pat = jax.vmap(patchify, in_axes=(0,None))(inp, 4)
# pat = patchify(inp,4)
# print(pat.shape)

# def batch_patchify(batch_lat_img, p):
#     return jax.vmap(patchify, in_axes=(0,None))(batch_lat_img,p)



#TODO: Positional encoding
class EmbedPatch(nn.Module):
    # patchdim: int
    embed_dim: int

    def setup(self):
        self.layer = nn.Dense(self.embed_dim)

    # @nn.compact
    def __call__(self, x_t, t):
        return self.layer(x_t)

#CHATGPT
class sinusoidal_embedding(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, timesteps):
        dim = self.dim
        half_dim = dim // 2
        freqs = jnp.exp(-jnp.arange(half_dim) * (jnp.log(10000.0) / (half_dim - 1)))
        args = timesteps[:, None] * freqs[None]  # [batch, half_dim]
        embedding = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
        return embedding

sin = sinusoidal_embedding(dim = 32)
key = random.PRNGKey(23)
params = sin.init(key, jnp.array([1,2]))
output = sin.apply(params, jnp.array([1,2]))
print(output)
    
# embed = EmbedPatch(embed_dim=32)
# key = random.PRNGKey(23)
# params = embed.init(key, pat)
# output = embed.apply(params, pat)
# print(output.shape)

class MHA(nn.Module):
    num_heads: int
    embed_dim: int

    def setup(self):
        assert self.embed_dim%self.num_heads == 0, "embed_dim not divisible by num_heads"
        self.W = nn.Dense(self.embed_dim)
        self.K = nn.Dense(self.embed_dim)
        self.Q = nn.Dense(self.embed_dim)
        self.W0 = nn.Dense(self.embed_dim)

    def __call__(self, x):
        #Assume x has shape (Batches, Seq_len, embed_dim)
        B,S,_ = x.shape
        q = self.Q(x)
        w = self.W(x)
        k = self.K(x)

        head_dim = self.embed_dim//self.num_heads
        multi_q = q.reshape(B,S,self.num_heads,head_dim).transpose(0,2,1,3)
        multi_w = w.reshape(B,S,self.num_heads,head_dim).transpose(0,2,1,3)
        multi_k = k.reshape(B,S,self.num_heads,head_dim).transpose(0,2,1,3)
        
        attention = jnp.matmul(multi_q, multi_k.transpose(0,1,3,2))/jnp.sqrt(head_dim)
        attention = nn.softmax(attention,-1)
        z = jnp.matmul(attention,multi_w)
        multi_z = self.W0(z.transpose(0,2,1,3).reshape(B,S,self.embed_dim))

        return multi_z

class DiT_block(nn.Module):
    num_heads: int
    embed_dim: int

    def setup(self):
        self.layernorm = nn.LayerNorm()
        self.mha = MHA(num_heads = self.num_heads,embed_dim = self.embed_dim)
        self.ffd = nn.Sequential([
            nn.Dense(self.embed_dim * 4),
            nn.relu,
            nn.Dense(self.embed_dim),
        ])
        self.mlp = nn.Dense(6*self.embed_dim)
        

    def __call__(self, x_t, t_emb):
        #x_t.shape: (B,Seq_len, Seq_size)
        activation = x_t
        alpha1,beta1,gamma1,alpha2,beta2,gamma2 = jnp.split(self.mlp(t_emb), 6, axis=1)
        means = jnp.mean(activation, axis=-1)
        variances = jnp.var(activation,axis=-1)
        res = activation
        # activation = self.layernorm(x_t)
        #scale,shift
        activation = (activation-means[:, :, None])/variances[:, :, None]
        print(gamma1.shape)
        activation = (activation*gamma1[:,None,:])+beta1[:,None,:]

        activation = self.mha(activation)
        activation = activation*alpha1[:,None,:]
        activation = activation + res
        res2 = activation
        # activation = self.layernorm(activation)
        means2 = jnp.mean(activation, axis=-1)
        variances2 = jnp.var(activation,axis=-1)
        #scale,shift
        activation = (activation-means2[:, :, None])/variances2[:,:,None]
        activation = (activation*gamma2[:,None,:])+beta2[:,None,:]
        
        activation = self.ffd(activation)
        #scale
        activation = activation*alpha2[:,None,:]
        activation = activation+res2
        return activation
    

# mha = MHA(4,32)
# key, key1 = random.split(key)
# params = mha.init(key1, output) 

In [None]:
#THIS SECTION IS WRITTEN BY CHATGPT

import torch
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms

# Transforms (standard ImageNet preprocessing)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)
])

data_dir = "/kaggle/input/imagenet100"

# Collect train shards
train_folders = [f"{data_dir}/train.X{i}" for i in range(1, 5)]

train_datasets = [
    datasets.ImageFolder(root=folder, transform=transform) 
    for folder in train_folders
]

# Merge into one dataset
train_dataset = ConcatDataset(train_datasets)

# Validation dataset
val_dataset = datasets.ImageFolder(root=f"{data_dir}/val.X", transform=transform)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

# Quick check
images, labels = next(iter(train_loader))
print(f"Train batch images: {images.shape}, labels: {labels.shape}")


In [None]:
print(images[1][0])
patchify = Patchify(p=4)
key = random.PRNGKey(23)
images = jnp.array(images)
params = patchify.init(key, images)
output = patchify.apply(params, images)
print(images.shape)
print(output.shape)



In [None]:
class Model(nn.Module):
    num_heads: int
    embed_dim: int
    p: int
    n: int

    def setup(self):
        self.patchify = Patchify(self.p)
        self.layernorm = nn.LayerNorm()
        self.dit = DiT_block(num_heads = self.num_heads, embed_dim = self.embed_dim)
        self.sin_embed = sinusoidal_embedding(self.embed_dim)
        
    def __call__(self, x_t, t):
        time_embedding = self.sin_embed(t)
        print(time_embedding.shape)
        activation = self.patchify(x_t)  #shape: BxSeq_lenxSeq_size
        for i in range(self.n):
            activation = self.dit(activation, time_embedding)
        activation = self.layernorm(activation)


model = Model(12,48,4,1)
print(jnp.array(labels)[:,None].shape)
key = random.PRNGKey(23)
params = model.init(key, jnp.array(images), jnp.array(labels))
output = model.apply(params, jnp.array(images), jnp.array(labels))
print(output.shape)