Transformer in Transformer in JAX/Flax

This repository implements Transformer in Transformer (TNT), which pairs pixel-level attention with patch-level attention for image classification. It is heavily inspired by both lucidrains' PyTorch implementation and Google Brain's Vision Transformer repo.
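For intuition on the two attention levels, here is a minimal shape-only sketch (not part of the library; the exact tokenization inside the model may differ): a 224x224 image with patch_size = 16 gives 14 * 14 = 196 outer patch tokens, and with transformed_patch_size = 4 each patch contributes (16 / 4)**2 = 16 inner "pixel" tokens.

from jax import numpy as jnp

img = jnp.ones((1, 224, 224, 3))
patch_size, pixel_size = 16, 4

# outer tokens: one per 16x16 patch -> (1, 196, 768)
patches = img.reshape(1, 14, patch_size, 14, patch_size, 3)
patches = patches.transpose(0, 1, 3, 2, 4, 5).reshape(1, 196, -1)

# inner tokens: each patch split into a 4x4 grid of 4x4 sub-patches -> (1, 196, 16, 48)
pixels = patches.reshape(1, 196, 4, pixel_size, 4, pixel_size, 3)
pixels = pixels.transpose(0, 1, 2, 4, 3, 5, 6).reshape(1, 196, 16, -1)

print(patches.shape, pixels.shape)  # (1, 196, 768) (1, 196, 16, 48)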

A video explanation of the paper is available from AI Coffee Break with Letitia.

Install

$ pip install transformer-in-transformer-flax

Usage

from jax import random
from jax import numpy as jnp
from transformer_in_transformer_flax import TransformerInTransformer, TNTConfig

# example configuration for TNT-B
config = TNTConfig(
    num_classes = 1000,
    depth = 12,
    image_size = 224,
    patch_size = 16,
    transformed_patch_size = 4,
    inner_dim = 40,
    inner_heads = 4,
    inner_dim_head = 64,
    inner_r = 4,
    outer_dim = 640,
    outer_heads = 10,
    outer_dim_head = 64,
    outer_r = 4
)

rng = random.PRNGKey(seed=0)
model = TransformerInTransformer(config=config)
params = model.init(rng, jnp.ones((1, 224, 224, 3), dtype=config.dtype))
img = random.uniform(rng, (2, 224, 224, 3))
logits = model.apply(params, img) # (2, 1000)
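
The logits can be plugged into a standard cross-entropy training step. Below is a minimal sketch, assuming optax is installed and labels are integer class ids; a training loop is not part of this repository.

import optax
from jax import value_and_grad

optimizer = optax.adam(learning_rate=3e-4)
opt_state = optimizer.init(params)

def loss_fn(params, images, labels):
    # forward pass through the model, then mean cross-entropy over the batch
    logits = model.apply(params, images)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

def train_step(params, opt_state, images, labels):
    loss, grads = value_and_grad(loss_fn)(params, images, labels)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss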

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
