In [1]:
from src.data_processing.dataset import LmdbDataset, Database, Collate
from src.net.modules.encoder import ViTSTR
from torch.utils.data import DataLoader
from src.data_processing.vocabulary import Vocabulary
from torchvision import models
from torch import nn
# --- Run configuration: everything a reader might tune lives here ---
NUM_WORKERS = 20  # passed as `num_workers` to every DataLoader
MAX_EPOCHS = 30   # epoch budget; not consumed by any cell visible in this part of the notebook
BATCH_SIZE = 2    # passed as `batch_size` to every DataLoader

# Single root directory for the three LMDB splits. Moving the dataset now means
# editing this one line instead of three near-identical absolute paths.
DATA_ROOT = '/mnt/s/CV/data_lmdb_release/training/MJ'
DATABASE_TRAIN_PATH = f'{DATA_ROOT}/MJ_train'
DATABASE_VALID_PATH = f'{DATA_ROOT}/MJ_valid'
DATABASE_TEST_PATH = f'{DATA_ROOT}/MJ_test'

In [2]:
# Build the vocabulary, the three dataset splits and their DataLoaders.
vocab = Vocabulary()

# One database handle per split, each opened from its configured path.
train_db = Database(DATABASE_TRAIN_PATH)
valid_db = Database(DATABASE_VALID_PATH)
test_db = Database(DATABASE_TEST_PATH)

# Every split shares the same vocabulary object.
dataset_train = LmdbDataset(train_db, vocab)
dataset_valid = LmdbDataset(valid_db, vocab)
dataset_test = LmdbDataset(test_db, vocab)

# Collate is given the index of the '<PAD>' token
# (presumably used to pad label sequences to a common length — TODO confirm).
collate = Collate(pad_idx=vocab.char2idx['<PAD>'])

# Arguments common to all three loaders, kept in one place so the splits cannot drift apart.
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate, num_workers=NUM_WORKERS)

# Only the training split is shuffled. Validation and test keep a fixed order so
# evaluation runs are reproducible and predictions can be matched back to samples
# (the test loader was previously shuffled).
dataloader_train = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
dataloader_valid = DataLoader(dataset_valid, shuffle=False, **loader_kwargs)
dataloader_test = DataLoader(dataset_test, shuffle=False, **loader_kwargs)

print(f'Vocabulary size: {len(vocab)}')
print(f'Train size: {len(dataset_train)}')
print(f'Val/test size: {len(dataset_valid)}/{len(dataset_test)}')

Vocabulary size: 72
Train size: 7224586
Val/test size: 802731/891924


In [3]:
# Fetch the first test sample straight from the dataset (no DataLoader / collate involved).
# The recorded output is a pair: a float image tensor and a 1-D integer tensor
# (presumably the label encoded as vocabulary indices — TODO confirm).
dataset_test[0]

(tensor([[[-0.2235, -0.1945, -0.1489,  ..., -0.1373, -0.1373, -0.1373],
          [-0.2235, -0.1945, -0.1489,  ..., -0.1373, -0.1373, -0.1373],
          [-0.2235, -0.1945, -0.1489,  ..., -0.1373, -0.1373, -0.1373],
          ...,
          [-0.1373, -0.1373, -0.1373,  ..., -0.1373, -0.1373, -0.1373],
          [-0.1373, -0.1373, -0.1373,  ..., -0.1373, -0.1373, -0.1373],
          [-0.1373, -0.1373, -0.1373,  ..., -0.1373, -0.1373, -0.1373]]]),
 tensor([ 1, 31, 24, 21, 26, 23, 21, 26, 19,  2]))

In [4]:
# Pull a single collated batch from the training loader; later cells read `item`.
item = next(iter(dataloader_train))
# Show the shapes of the first two batch elements as a tuple.
tuple(tensor.shape for tensor in (item[0], item[1]))

(torch.Size([2, 1, 224, 224]), torch.Size([2, 11]))

In [5]:
# First row of the batch's second element: an integer tensor
# (presumably the encoded label of the first sample — TODO confirm).
item[1][0]

tensor([ 1, 34, 21, 31, 21, 27, 26, 21, 26, 19,  2])

In [10]:
# Single-channel input with 384-dim embeddings. The classifier width is taken from
# the vocabulary instead of a hard-coded 70: the vocabulary reports 72 entries, so
# a 70-way head could not represent every label index.
model = ViTSTR(in_chans=1, embed_dim=384, num_classes=len(vocab))

In [11]:
# Display the module tree to check the architecture that was built.
model

ViTSTR(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (

In [10]:
# Smoke test: run the sampled image batch through the model and show the output shape.
# The recorded output is (2, 25, 72) — presumably (batch, sequence positions, classes); TODO confirm.
# NOTE(review): this cell's execution count repeats an earlier one, so the recorded
# shape may come from a different model instance than the one constructed above.
model(item[0]).shape

torch.Size([2, 25, 72])