In [1]:
import torch.nn as nn

num_channels = 3
hidden_size = 768
patch_size = 16
conv = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

In [110]:
import torch

fake_image = torch.rand((1, 3, 224, 168))

print(conv(fake_image).reshape(1, -1, 768).shape)

torch.Size([1, 140, 768])


In [111]:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

In [112]:
from transformers.models.vit.modeling_vit import ViTPatchEmbeddings

embs = ViTPatchEmbeddings(model.config)

In [113]:
embs(fake_image).shape

ValueError: Input image size (224*168) doesn't match model (224*224).

In [158]:
# How to compute the positional embeddings

# BS, hidden dim, x patches, y patches
patches = conv(fake_image)
# print(patches.shape)

# To project to the right hidden embedding dim
hidden_dim = 768
x_projection = torch.nn.Linear(1, hidden_dim)
y_projection = torch.nn.Linear(1, hidden_dim)

batch_size, x_size, y_size = patches.shape[0], patches.shape[2], patches.shape[3]

# x_embeddings
patches_x_embeddings = torch.arange(x_size).view(batch_size, 1, -1) / x_size
patches_x_embeddings = x_projection(patches_x_embeddings.T).view(batch_size, x_size, hidden_dim)
patches_x_embeddings = patches_x_embeddings.expand(y_size, batch_size, x_size, hidden_dim)
patches_x_embeddings = patches_x_embeddings.reshape(batch_size, -1, hidden_dim)

# y_embeddings
patches_y_embeddings = torch.arange(y_size).view(batch_size, 1, -1) / y_size
patches_y_embeddings = y_projection(patches_y_embeddings.T).view(batch_size, y_size, hidden_dim)
patches_y_embeddings = patches_y_embeddings.expand(x_size, batch_size, y_size, hidden_dim)
patches_y_embeddings = patches_y_embeddings.reshape(batch_size, -1, hidden_dim)

patches_positional_embeddings = patches_x_embeddings + patches_y_embeddings

In [165]:
# Retaining the aspect ratio for a given effective resolution

img1 = torch.rand((1, 3, 1024, 512))
img2 = torch.rand((1, 3, 256, 272))

patches_1 = conv(img1)
patches_2 = conv(img2)
print(f'num_patches={patches_1.shape[-1]*patches_1.shape[-2]}')
print(f'num_patches={patches_2.shape[-1]*patches_2.shape[-2]}')

num_patches=2048
num_patches=272


In [35]:
1024 / 16

64.0

In [38]:
img1 = torch.rand((1, 3, 1024, 512))
.shape

torch.Size([1, 768, 2048])

In [45]:
# Retaining the aspect ratio for a given effective resolution
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import math
import cv2 

num_channels = 3
hidden_size = 768
patch_size = 16
conv = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

effective_resolution = 512

for x_size, y_size in [(1024, 512), (1024, 1024), (256, 256), (1121, 456)]:
    imarray = torch.rand(x_size, y_size,3) * 255
    # im = Image.fromarray(imarray.astype('uint8'))

    aspect_ratio = x_size / y_size

    new_y = np.sqrt(effective_resolution**2/aspect_ratio)
    new_x = new_y * aspect_ratio
    
    
    rounded_y = math.floor(new_y)
    rounded_x = math.floor(new_x)
    # resized_img = im.resize(())
    res = cv2.resize(imarray, dsize=(rounded_y, rounded_x), interpolation=cv2.INTER_CUBIC)
    print(rounded_y, rounded_x)
    print(rounded_x / rounded_y, aspect_ratio)
    print(effective_resolution**2, rounded_x*rounded_y)
    img = torch.tensor(res, dtype=torch.float).view(3, rounded_x, rounded_y).unsqueeze(0)

    patches = conv(img)
    print(f'num_patches={patches.shape[-1]*patches.shape[-2]}')
    
    # Then we need to pad to the max

error: OpenCV(4.5.4) :-1: error: (-5:Bad argument) in function 'resize'
> Overload resolution failed:
>  - src is not a numpy array, neither a scalar
>  - Expected Ptr<cv::UMat> for argument 'src'


In [31]:
patches = patches.view(1, hidden_size, -1)
seq_length = patches.shape[-1]
max_patches = int((effective_resolution**2)/(patch_size**2))
torch.nn.functional.pad(patches, (0,max_patches-seq_length)).shape

torch.Size([1, 768, 1024])

In [46]:
imarray.numpy()

array([[[ 83.43414  ,   4.775781 , 217.92708  ],
        [ 66.16865  ,  37.061966 ,  31.308632 ],
        [ 25.140652 , 208.80331  , 107.35911  ],
        ...,
        [161.07523  , 149.33807  ,  30.532988 ],
        [150.54028  , 213.85788  ,  67.94908  ],
        [240.86948  ,   8.646132 ,  15.622634 ]],

       [[169.03717  , 221.36116  , 165.24219  ],
        [147.3549   , 162.00345  ,  56.119705 ],
        [242.51111  , 123.3208   , 119.07459  ],
        ...,
        [109.28696  ,  72.45307  ,  21.083914 ],
        [183.39742  ,  15.4071245, 145.85396  ],
        [226.925    ,  51.414204 , 104.07337  ]],

       [[ 41.385983 , 230.80757  , 224.42665  ],
        [162.5485   , 186.38916  ,  86.466805 ],
        [ 34.31891  , 153.07722  ,  56.078804 ],
        ...,
        [185.42879  , 202.80441  , 192.44173  ],
        [ 90.75602  , 245.34912  , 226.29155  ],
        [139.18805  ,  78.44048  ,  29.52753  ]],

       ...,

       [[248.85818  , 125.35989  ,  37.49262  ],
        [25

In [21]:
724/16

45.25

In [55]:
import math

image_height, image_width = 1024, 1024

# This will set the target resolution for resizing
max_patches = 2048
patch_height=patch_width=16

scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1)
num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1)
resized_height = max(num_feasible_rows * patch_height, 1)
resized_width = max(num_feasible_cols * patch_width, 1)

print(resized_height, resized_width)

720 720


In [44]:
patch_height

16

In [45]:
target_resolution

524288

In [51]:
1024*512

524288

In [None]:
256**2

1024

In [3]:
from transformers import XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained(
            "hyunwoongko/asian-bart-ecjk"
        )

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MBartTokenizer'. 
The class this function is called from is 'XLMRobertaTokenizer'.


In [6]:
tokenizer.decode(tokenizer('A1')['input_ids'])

'<s> A<unk></s>'