In [3]:
import argparse
import os
import sys
import numpy as np
import math
import pickle
import time
import cv2 as cv
import matplotlib
import matplotlib.pyplot as plt
import random
from cv2 import VideoWriter, VideoWriter_fourcc, imread

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.cuda.amp import autocast, GradScaler

import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision

import warnings

### Visual Token Pipeline Demo

In [4]:
# Visual-token pipeline, shapes and timing only: encode a batch of random 64x64
# RGB images into a grid of discrete token ids (argmax over VOCAB_SIZE conv
# channels per position), then look each id up in a learned embedding table.
torch.manual_seed(0)  # reproducible input and weight initialisation

b = 16
VOCAB_SIZE = 20000  # conv4 output channels == number of embedding rows
EMBED_DIM = 1024

a = torch.randn(b, 3, 64, 64)

maxpool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=2)

conv1 = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=[7, 7], stride=2, padding=3, bias=False)
conv2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=[7, 7], stride=2, padding=3, bias=False)
conv3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=[3, 3], stride=1, padding=1)
conv4 = nn.Conv2d(in_channels=512, out_channels=VOCAB_SIZE, kernel_size=[3, 3], stride=1, padding=1)

tokenizer = nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBED_DIM)

start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
# Nothing in this demo is trained, so skip building the autograd graph.
with torch.no_grad():
    print(a.shape)
    for layer in (conv1, conv2, maxpool1, conv3, conv4):
        a = layer(a)
        print(a.shape)
    embedding = torch.argmax(a, dim=1)          # (b, H, W): one token id per position
    print(embedding.shape)
    embedding = embedding.flatten(start_dim=1)  # (b, H*W): token sequence per image
    print(embedding.shape)
    embedding = tokenizer(embedding)            # (b, H*W, EMBED_DIM)
    print(embedding.shape)
print(time.perf_counter() - start_time)

torch.Size([16, 3, 64, 64])
torch.Size([16, 256, 32, 32])
torch.Size([16, 512, 16, 16])
torch.Size([16, 512, 8, 8])
torch.Size([16, 512, 8, 8])
torch.Size([16, 20000, 8, 8])
torch.Size([16, 8, 8])
torch.Size([16, 64])
torch.Size([16, 64, 1024])
0.7681350708007812


### Pretraining Stage 1 Demo

In [5]:
# Stage-1 pretraining pipeline, shapes and timing only.
# - Encoder: maps each 64x64 RGB image to VOCAB_SIZE channels on a 16x16 grid.
# - Bottleneck: squeezes that to a single 256-channel vector and expands it back.
# - Decoder: rebuilds a 64x64 RGB image.
torch.manual_seed(0)  # reproducible input and weight initialisation

b = 16
VOCAB_SIZE = 20000
a = torch.randn(b, 3, 64, 64)

# Kernel 8 / stride 1 on inter1's 8x8 output collapses it to 1x1.
maxpool2 = nn.MaxPool2d(kernel_size=[8, 8], stride=1, padding=0)
# Documented equivalent of the deprecated nn.UpsamplingBilinear2d(scale_factor=2).
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

# Encoder
conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=[7, 7], stride=2, padding=3, bias=False)
conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=[7, 7], stride=2, padding=3, bias=False)
conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=[3, 3], stride=1, padding=1)
conv4 = nn.Conv2d(in_channels=128, out_channels=VOCAB_SIZE, kernel_size=[3, 3], stride=1, padding=1)

# Bottleneck ("inter") layers
inter1 = nn.Conv2d(in_channels=VOCAB_SIZE, out_channels=256, kernel_size=[3, 3], stride=2, padding=1, bias=False)
inter2 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=[4, 4], padding=0)
inter3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=[4, 4], stride=2, padding=1)

# Decoder
deconv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=[4, 4], stride=2, padding=1)
deconv2 = nn.ConvTranspose2d(in_channels=512, out_channels=3, kernel_size=[4, 4], stride=2, padding=1)

start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
# Nothing is trained here. Without no_grad, every reassigned intermediate
# (e.g. the large conv4 output) would stay alive as a saved tensor.
with torch.no_grad():
    print(a.shape)

    # Encoder: (b, 3, 64, 64) -> (b, VOCAB_SIZE, 16, 16)
    for layer in (conv1, conv2, conv3, conv4):
        a = layer(a)
        print(a.shape)

    # Bottleneck layers: discarded in real use. They exist only to bottleneck
    # the representation during pretraining.
    # (b, VOCAB_SIZE, 16, 16) -> (b, 256, 1, 1) -> (b, 256, 32, 32)
    for layer in (inter1, maxpool2, inter2, upsample, inter3, upsample):
        a = layer(a)
        print(a.shape)

    # Decoder layers: to be finetuned with the reformer.
    # `a` is (b, 256, 32, 32). Each of the 256 channels is read as one token
    # whose 32*32 = 1024 values are its embedding (the layout this view
    # produces on a contiguous tensor). The 256 tokens are laid out on a
    # 16x16 grid with 1024 channels. deconv1/deconv2 then upsample
    # 16x16 -> 32x32 -> 64x64 RGB.
    a = a.view(b, 16, 16, 1024)
    print(a.shape)
    a = a.permute(0, 3, 1, 2)  # (b, 16, 16, 1024) -> (b, 1024, 16, 16)
    print(a.shape)
    for layer in (deconv1, deconv2):
        a = layer(a)
        print(a.shape)

print(time.perf_counter() - start_time)

torch.Size([16, 3, 64, 64])
torch.Size([16, 64, 32, 32])
torch.Size([16, 128, 16, 16])
torch.Size([16, 128, 16, 16])
torch.Size([16, 20000, 16, 16])
torch.Size([16, 256, 8, 8])
torch.Size([16, 256, 1, 1])
torch.Size([16, 256, 4, 4])
torch.Size([16, 256, 8, 8])
torch.Size([16, 256, 16, 16])
torch.Size([16, 256, 32, 32])
torch.Size([16, 16, 16, 1024])
torch.Size([16, 1024, 16, 16])
torch.Size([16, 512, 32, 32])
torch.Size([16, 3, 64, 64])
6.295857667922974


### Pretraining Stage 2 Demo (TODO: decide whether a second pretraining stage is needed)

In [6]:
# Stage-2 pretraining demo, shapes and timing only.
# - Encode random 64x64 RGB images into a 16x16 grid of token ids: argmax over
#   VOCAB_SIZE conv channels, with no pooling after conv2, so 256 tokens per image.
# - Embed each id into EMBED_DIM dimensions.
torch.manual_seed(0)  # reproducible input and weight initialisation

b = 16
VOCAB_SIZE = 20000  # conv4 output channels == number of embedding rows
EMBED_DIM = 512

a = torch.randn(b, 3, 64, 64)

conv1 = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=[7, 7], stride=2, padding=3, bias=False)
conv2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=[7, 7], stride=2, padding=3, bias=False)
conv3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=[3, 3], stride=1, padding=1)
conv4 = nn.Conv2d(in_channels=512, out_channels=VOCAB_SIZE, kernel_size=[3, 3], stride=1, padding=1)

tokenizer = nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBED_DIM)

start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
print(a.shape)
for layer in (conv1, conv2, conv3, conv4):
    a = layer(a)
    print(a.shape)

# NOTE: argmax returns integer indices, so no gradient flows from `embedding`
# back into conv1-conv4; only tokenizer.weight is reachable from it.
embedding = torch.argmax(a, dim=1)          # (b, 16, 16): one token id per position
print('embedding', embedding.shape)
embedding = embedding.flatten(start_dim=1)  # (b, 256): token sequence per image
print('embedding', embedding.shape)
embedding = tokenizer(embedding)            # (b, 256, EMBED_DIM)
print('embedding', embedding.shape)
print(time.perf_counter() - start_time)

torch.Size([16, 3, 64, 64])
torch.Size([16, 256, 32, 32])
torch.Size([16, 512, 16, 16])
torch.Size([16, 512, 16, 16])
torch.Size([16, 20000, 16, 16])
embedding torch.Size([16, 16, 16])
embedding torch.Size([16, 256])
embedding torch.Size([16, 256, 512])
