In [1]:
from src.data_processing.dataset import LmdbDataset, Database, Collate
from src.net.model import Model
from torch.utils.data import DataLoader
from src.data_processing.vocabulary import Vocabulary
from multiprocessing import cpu_count
# Project constants
# One worker per CPU reported by multiprocessing.cpu_count(); reused below both as the
# DataLoader `num_workers` and as the `max_readers` argument of each Database.
NUM_WORKERS = cpu_count()
# Not referenced in the visible cells — presumably the training-loop epoch budget (TODO confirm).
MAX_EPOCHS = 3
# Batch size shared by the train, validation and test DataLoaders.
BATCH_SIZE = 128
# Dataset roots handed to Database. NOTE(review): these are absolute, machine-specific
# paths; the notebook will not run elsewhere without editing them.
DATABASE_TRAIN_PATH = '/mnt/s/CV/data_lmdb_release/training/ST'
DATABASE_VALID_PATH = '/mnt/s/CV/data_lmdb_release/validation'
# Model hyperparameters — all of the values below are forwarded to Model(...) through `config`.
D_MODEL = 256  # passed as `d_model`; presumably the feature/embedding width — TODO confirm
NUM_HEADS = 4  # passed as `num_heads`; presumably the attention head count — TODO confirm
INPUT_CHANNELS = 1  # passed as `input_channels`; presumably single-channel (grayscale) images
LR_MAX = 1e-4  # passed as `lr_max`; presumably the upper learning-rate bound of the schedule
LR_MIN = 1e-5  # passed as `lr_min`; presumably the lower learning-rate bound of the schedule
# Passed as `t_max`; presumably the scheduler period. It currently equals MAX_EPOCHS but is
# set independently — NOTE(review): confirm whether the two are meant to stay in sync.
T_MAX = 3

In [2]:
# Vocabulary shared by every dataset; it also supplies the padding index for batching.
vocab = Vocabulary()

# Open one database handle per split directory (train first, then validation).
train_db, valid_db = (
    Database(db_path, max_readers=NUM_WORKERS)
    for db_path in (DATABASE_TRAIN_PATH, DATABASE_VALID_PATH)
)

# The validation and test datasets both read from the validation database and differ
# only in the `sample` argument.
dataset_train = LmdbDataset(train_db, vocab, sample='train')
dataset_valid = LmdbDataset(valid_db, vocab, sample='valid')
dataset_test = LmdbDataset(valid_db, vocab, sample='test')

# A single collate callable, built with the '<PAD>' index, is shared by all loaders.
collate = Collate(pad_idx=vocab.char2idx['<PAD>'])

# Settings common to every loader; only the dataset and the shuffle flag vary.
shared_loader_args = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate,
    num_workers=NUM_WORKERS,
)
dataloader_train = DataLoader(dataset_train, shuffle=True, **shared_loader_args)
dataloader_valid = DataLoader(dataset_valid, shuffle=False, **shared_loader_args)
dataloader_test = DataLoader(dataset_test, shuffle=False, **shared_loader_args)

# Quick sanity summary of what was loaded.
print(
    f'Vocabulary size: {len(vocab)}',
    f'Train size: {len(dataset_train)}',
    f'Val/test size: {len(dataset_valid)}/{len(dataset_test)}',
    sep='\n',
)

Vocabulary size: 72
Train size: 2761404
Val/test size: 3496/3496


In [3]:
# Pull a single batch from the (shuffled) training loader. `item` is reused by the
# forward-pass check further down, where item[0] and item[1] are both fed to the model.
item = next(iter(dataloader_train))
# Display the shapes of the first two batch elements — presumably the image tensor and the
# padded target indices (TODO confirm against Collate).
item[0].shape, item[1].shape

(torch.Size([128, 1, 224, 224]), torch.Size([128, 15]))

In [4]:
# Keyword arguments for Model, collected in one mapping so they can be unpacked in one go.
config = {
    'vocab': vocab,
    'd_model': D_MODEL,
    'input_channels': INPUT_CHANNELS,
    'lr_max': LR_MAX,
    'lr_min': LR_MIN,
    't_max': T_MAX,
    'num_heads': NUM_HEADS,
}

# Build the model first, then move it to the GPU.
model = Model(**config)
model = model.cuda()

In [5]:
# Smoke-test the forward pass on the sampled batch and display the output shape.
# The module is invoked directly instead of through `.forward(...)`: calling the instance
# goes through nn.Module.__call__, which runs any registered forward/pre-forward hooks,
# whereas `.forward(...)` silently skips them. (Assumes Model is a torch.nn.Module, as
# its `.cuda()` method suggests.)
model(item[0].cuda(), item[1].cuda()).shape

Visual features shape:  torch.Size([128, 15, 256])


torch.Size([128, 15, 72])