<a href="https://colab.research.google.com/github/WilliamAshbee/3d-synth-data/blob/main/vit_alternative_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Models in this notebook come from vit-pytorch: https://github.com/lucidrains/vit-pytorch

In [2]:
!pip install vit_pytorch



In [26]:
import torch
from vit_pytorch import ViT

# Baseline ViT smoke test: build the model and check the logits shape
# for a single random 32x32 RGB image.
v = ViT(
    image_size = 32,
    patch_size = 8,
    num_classes = 2000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)

preds = v(img) # (1, 2000) - (batch x num_classes)
# Fail loudly if the head size ever drifts away from num_classes.
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])


In [4]:
# import torch
# from torchvision.models import resnet50

# from vit_pytorch.distill import DistillableViT, DistillWrapper

# teacher = resnet50(pretrained = True)

# v = DistillableViT(
#     image_size = 32,
#     patch_size = 8,
#     num_classes = 2000,
#     dim = 1024,
#     depth = 8,
#     heads = 16,
#     mlp_dim = 2048,
#     dropout = 0.1,
#     emb_dropout = 0.1
# )

# distiller = DistillWrapper(
#     student = v,
#     teacher = teacher,
#     temperature = 3,           # temperature of distillation
#     alpha = 0.5,               # trade between main loss and distillation loss
#     hard = False               # whether to use soft or hard distillation
# )

# img = torch.randn(2, 3, 32, 32)
# labels = torch.randint(0, 2000, (2,))

# loss = distiller(img, labels)
# loss.backward()

# # after lots of training above ...

# preds = v(img) # (2, 2000)
# print('preds',preds.shape)


In [5]:
import torch
from vit_pytorch.deepvit import DeepViT

# DeepViT smoke test: 32x32 input, 8x8 patches, 1000 output classes.
v = DeepViT(
    image_size=32,
    patch_size=8,
    num_classes=1000,
    dim=1024, depth=6, heads=16, mlp_dim=2048,
    dropout=0.1, emb_dropout=0.1,
)

# A single random RGB image stands in for real data.
img = torch.randn((1, 3, 32, 32))

# Logits come back as (batch, num_classes) == (1, 1000).
preds = v(img)
print('preds', preds.shape)


preds torch.Size([1, 1000])


In [6]:
import torch
from vit_pytorch.cait import CaiT

# CaiT smoke test: build the model and check the logits shape for a
# single random 32x32 RGB image.
v = CaiT(
    image_size = 32,
    patch_size = 8,
    num_classes = 2000,
    dim = 1024,
    depth = 12,             # depth of transformer for patch to patch attention only
    cls_depth = 2,          # depth of cross attention of CLS tokens to patch
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05    # randomly dropout 5% of the layers
)

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)

preds = v(img) # (1, 2000) - (batch x num_classes)
# Fail loudly if the head size ever drifts away from num_classes.
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])


In [7]:
import torch
from vit_pytorch.t2t import T2TViT

# T2T-ViT smoke test: build the model and check the logits shape for a
# single random 32x32 RGB image.
v = T2TViT(
    dim = 512,
    image_size = 32,
    depth = 8,
    heads = 16,
    mlp_dim = 512,
    num_classes = 2000,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module
)

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)

preds = v(img) # (1, 2000) - (batch x num_classes)
# Fail loudly if the head size ever drifts away from num_classes.
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])


In [8]:
# import torch
# from vit_pytorch.cct import CCT

# model = CCT(
#         img_size=32,
#         embedding_dim=384,
#         n_conv_layers=2,
#         kernel_size=7,
#         stride=2,
#         padding=3,
#         pooling_kernel_size=3,
#         pooling_stride=2,
#         pooling_padding=1,
#         num_layers=14,
#         num_heads=6,
#         mlp_radio=3.,
#         num_classes=2000,
#         positional_embedding='learnable', # ['sine', 'learnable', 'none']
#         )


In [9]:
# import torch
# from vit_pytorch.cct import cct_14

# model = cct_14(
#         img_size=32,
#         n_conv_layers=1,
#         kernel_size=7,
#         stride=2,
#         padding=3,
#         pooling_kernel_size=3,
#         pooling_stride=2,
#         pooling_padding=1,
#         num_classes=2000,
#         positional_embedding='learnable', # ['sine', 'learnable', 'none']  
#         )

In [10]:
import torch
from vit_pytorch.cross_vit import CrossViT

# CrossViT smoke test: build the two-branch model and check the logits
# shape for a single random 32x32 RGB image.
v = CrossViT(
    image_size = 32,
    num_classes = 2000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    # NOTE(review): sm_patch_size equals lg_patch_size here, although the
    # guidance below says the high-res patch should be smaller - confirm intended.
    sm_patch_size = 8,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 8,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)

# Bind the result to `preds`, the name that is printed below. The cell
# previously assigned to `pred`, so the print reported a stale tensor left
# over from an earlier cell (or raised NameError on a fresh kernel).
preds = v(img) # (1, 2000) - (batch x num_classes)
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])


In [11]:
import torch
from vit_pytorch.pit import PiT

# PiT smoke test: build the model and check the logits shape for a
# single random 32x32 RGB image.
v = PiT(
    image_size = 32,
    patch_size = 8,
    dim = 256,
    num_classes = 2000,
    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# the forward pass result is used as a single logits tensor here
# (no attention maps are unpacked in this cell)

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)

preds = v(img) # (1, 2000) - (batch x num_classes)
# Fail loudly if the head size ever drifts away from num_classes.
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])


In [12]:
# import torch
# from vit_pytorch.levit import LeViT

# levit = LeViT(
#     image_size = 32,
#     num_classes = 2000,
#     stages = 3,             # number of stages
#     dim = (32, 64, 128),  # dimensions at each stage
#     depth = 4,              # transformer of depth 4 at each stage
#     heads = (4, 6, 8),      # heads at each stage
#     mlp_mult = 2,
#     dropout = 0.1
# )

# img = torch.randn(1, 3, 32, 32)

# preds = levit(img) # (1, 2000)
# print('preds',preds.shape)


In [13]:
# import torch
# from vit_pytorch.cvt import CvT

# v = CvT(
#     num_classes = 2000,
#     s1_emb_dim = 64,        # stage 1 - dimension
#     s1_emb_kernel = 7,      # stage 1 - conv kernel
#     s1_emb_stride = 4,      # stage 1 - conv stride
#     s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
#     s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
#     s1_heads = 1,           # stage 1 - heads
#     s1_depth = 1,           # stage 1 - depth
#     s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
#     s2_emb_dim = 192,       # stage 2 - (same as above)
#     s2_emb_kernel = 3,
#     s2_emb_stride = 2,
#     s2_proj_kernel = 3,
#     s2_kv_proj_stride = 2,
#     s2_heads = 3,
#     s2_depth = 2,
#     s2_mlp_mult = 4,
#     s3_emb_dim = 384,       # stage 3 - (same as above)
#     s3_emb_kernel = 3,
#     s3_emb_stride = 2,
#     s3_proj_kernel = 3,
#     s3_kv_proj_stride = 2,
#     s3_heads = 4,
#     s3_depth = 10,
#     s3_mlp_mult = 4,
#     dropout = 0
# )

# img = torch.randn(1, 3, 224, 224)

# preds = v(img) # (1, 2000)
# print('preds',preds.shape)


In [14]:
# import torch
# from vit_pytorch.twins_svt import TwinsSVT

# model = TwinsSVT(
#     num_classes = 2000,       # number of output classes
#     s1_emb_dim = 64,          # stage 1 - patch embedding projected dimension
#     s1_patch_size = 4,        # stage 1 - patch size for patch embedding
#     s1_local_patch_size = 7,  # stage 1 - patch size for local attention
#     s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
#     s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
#     s2_emb_dim = 128,         # stage 2 (same as above)
#     s2_patch_size = 2,
#     s2_local_patch_size = 7,
#     s2_global_k = 7,
#     s2_depth = 1,
#     s3_emb_dim = 256,         # stage 3 (same as above)
#     s3_patch_size = 2,
#     s3_local_patch_size = 7,
#     s3_global_k = 7,
#     s3_depth = 5,
#     s4_emb_dim = 512,         # stage 4 (same as above)
#     s4_patch_size = 2,
#     s4_local_patch_size = 7,
#     s4_global_k = 7,
#     s4_depth = 4,
#     peg_kernel_size = 3,      # positional encoding generator kernel size
#     dropout = 0.              # dropout
# )

# img = torch.randn(1, 3, 224, 224)

# preds = model(img) # (1, 2000)
# print('preds',preds.shape)


In [15]:
# import torch
# from vit_pytorch.regionvit import RegionViT

# model = RegionViT(
#     dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
#     depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
#     window_size = 7,                # window size, which should be either 7 or 14
#     num_classes = 2000,             # number of output classes
#     tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
#     use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
# )

# img = torch.randn(1, 3, 224, 224)

# preds = model(img) # (1, 2000)
# print('preds',preds.shape)


In [16]:
import torch
from vit_pytorch.nest import NesT

# NesT smoke test: build the hierarchical model and check the logits
# shape for a single random 32x32 RGB image.
nest = NesT(
    image_size = 32,
    patch_size = 8,
    dim = 96,
    heads = 9,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (8, 4, 1),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 2000
)

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)

# Bind the result to `preds`, the name that is printed below. The cell
# previously assigned to `pred`, so the print reported a stale tensor left
# over from an earlier cell (or raised NameError on a fresh kernel).
preds = nest(img) # (1, 2000) - (batch x num_classes)
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])


In [17]:
# import torch
# from vit_pytorch import ViT
# from vit_pytorch.mpp import MPP

# model = ViT(
#     image_size=32,
#     patch_size=8,
#     num_classes=1000,
#     dim=1024,
#     depth=8,
#     heads=16,
#     mlp_dim=2048,
#     dropout=0.1,
#     emb_dropout=0.1
# )

# mpp_trainer = MPP(
#     transformer=model,
#     patch_size=8,
#     dim=1024,
#     mask_prob=0.15,          # probability of using token in masked prediction task
#     random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
#     replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
# )

# opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

# def sample_unlabelled_images():
#     return torch.FloatTensor(20, 3, 32, 32).uniform_(0., 1.)

# for _ in range(1):
#     images = sample_unlabelled_images()
#     loss = mpp_trainer(images)
#     opt.zero_grad()
#     loss.backward()
#     opt.step()

# # save your improved network
# torch.save(model.state_dict(), './pretrained-net.pt')


In [18]:
# import torch
# from vit_pytorch import ViT, Dino

# model = ViT(
#     image_size = 256,
#     patch_size = 32,
#     num_classes = 1000,
#     dim = 1024,
#     depth = 6,
#     heads = 8,
#     mlp_dim = 2048
# )

# learner = Dino(
#     model,
#     image_size = 256,
#     hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
#     projection_hidden_size = 256,      # projector network hidden dimension
#     projection_layers = 4,             # number of layers in projection network
#     num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
#     student_temp = 0.9,                # student temperature
#     teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
#     local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper 
#     global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
#     moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
#     center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
# )

# opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

# def sample_unlabelled_images():
#     return torch.randn(20, 3, 256, 256)

# for _ in range(1):
#     images = sample_unlabelled_images()
#     loss = learner(images)
#     opt.zero_grad()
#     loss.backward()
#     opt.step()
#     learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# # save your improved network
# torch.save(model.state_dict(), './pretrained-net.pt')

In [19]:
import torch
from vit_pytorch.vit import ViT
from vit_pytorch.recorder import Recorder

# ViT + Recorder smoke test: wrap the model so the forward pass also
# returns the attention maps, then check both output shapes.
v = ViT(
    image_size = 32,
    patch_size = 8,
    num_classes = 2000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# wrap the ViT with the Recorder (note: `v` now refers to the wrapper)

v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)
preds, attns = v(img)

assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)

# 32 / 8 = 4 patches per side -> 16 patches, plus one extra token for CLS = 17

assert attns.shape == (1, 6, 16, 17, 17), attns.shape
print('attns',attns.shape) # (1, 6, 16, 17, 17) - (batch x layers x heads x patch x patch)

preds torch.Size([1, 2000])
attns torch.Size([1, 6, 16, 17, 17])


In [20]:
#### Research: ViT wrappers around externally supplied transformer backbones (cells below)


In [21]:
pip install nystrom-attention




In [22]:
import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer

# Efficient-ViT smoke test: plug a Nystromformer in as the transformer
# backbone and check the logits shape for a single random 32x32 RGB image.
efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 16,
    num_landmarks = 256
)

v = ViT(
    dim = 512,                  # matches the dim given to the backbone above
    image_size = 32,
    patch_size = 8,
    num_classes = 2000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 32, 32) # random 32x32 stand-in image (batch, channels, height, width)
preds = v(img) # (1, 2000) - (batch x num_classes)
# Fail loudly if the head size ever drifts away from num_classes.
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])


In [23]:
!pip install x-transformers



In [24]:
import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder

# Efficient-ViT smoke test: plug an x-transformers Encoder in as the
# backbone and check the logits shape for a single random 32x32 RGB image.
v = ViT(
    dim = 512,
    image_size = 32,
    patch_size = 8,
    num_classes = 2000,
    transformer = Encoder(
        dim = 512,                  # set to be the same as the wrapper
        depth = 12,
        heads = 16,
        ff_glu = True,              # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
        residual_attn = True        # ex. residual attention https://arxiv.org/abs/2012.11747
    )
)

img = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)
preds = v(img) # (1, 2000) - (batch x num_classes)
# Fail loudly if the head size ever drifts away from num_classes.
assert preds.shape == (1, 2000), preds.shape
print('preds',preds.shape)


preds torch.Size([1, 2000])
