In [None]:
from glob import glob
from config import CFG

In [None]:
# glob() returns matches in arbitrary, filesystem-dependent order, so sort both
# listings: every step that consumes them then sees the same file order on
# every run and on every machine.
IMG_FILES = sorted(glob(CFG.img_path + '/*.jpg'))
ANN_FILES = sorted(glob(CFG.ann_path + '/*.xml'))

# Cell output: (number of image files, number of annotation files).
len(IMG_FILES), len(ANN_FILES)

In [None]:
from data import build_df

In [None]:
# build_df yields the annotation table (`df`) and the collection of class names.
df, classes = build_df(ANN_FILES)

# Build the index -> name lookup first, then derive its inverse (name -> index).
id2cls = dict(enumerate(classes))
cls2id = {name: idx for idx, name in id2cls.items()}

# Print the class count, then show the first rows of the table as cell output.
print(len(classes))
df.head()

In [None]:
from data import split_df

In [None]:
# Split the annotation table into a training frame and a validation frame.
train_df, valid_df = split_df(df)

# Report how many distinct `id` values landed in each split.
for label, frame in (("Train size: ", train_df), ("Valid size: ", valid_df)):
    print(label, frame['id'].nunique())

In [None]:
from tokenizer import Tokenizer

In [None]:
# The tokenizer is configured from the class count and the CFG settings; the
# same CFG.img_size is passed as both width and height.
tokenizer = Tokenizer(
    num_classes=len(classes),
    num_bins=CFG.num_bins,
    width=CFG.img_size,
    height=CFG.img_size,
    max_len=CFG.max_len,
)

# Publish the tokenizer's padding code on the shared config object so later
# cells (e.g. the loss setup) can read it as CFG.pad_idx.
CFG.pad_idx = tokenizer.PAD_code

In [None]:
from data import get_loaders

In [None]:
# Build the training and validation loaders. Arguments are positional, so their
# order must match get_loaders' signature (defined in data.py, not shown here).
train_loader, valid_loader = get_loaders(
    train_df,
    valid_df,
    tokenizer,
    CFG.img_size,
    CFG.batch_size,
    CFG.max_len,
    tokenizer.PAD_code,
)

In [None]:
from models import Encoder, Decoder, EncoderDecoder

In [None]:
# One width is used for both the encoder output and the decoder, so name it once.
hidden_dim = 256

encoder = Encoder(
    model_name=CFG.model_name,
    pretrained=True,
    out_dim=hidden_dim,
)
decoder = Decoder(
    vocab_size=tokenizer.vocab_size,
    encoder_length=CFG.num_patches,
    dim=hidden_dim,
    num_heads=8,
    num_layers=6,
)

# Wrap both halves in a single model and move it to the configured device.
# The trailing semicolon suppresses the notebook's echo of the returned value.
model = EncoderDecoder(encoder, decoder)
model.to(CFG.device);

In [None]:
import torch
from torch import nn
from transformers import get_linear_schedule_with_warmup
from train_eval import train_eval

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

# train_eval is asked to step the scheduler per batch (step='batch'), so the
# schedule horizon must equal the number of batches actually iterated.
# len(train_loader) is that count whatever the loader's drop_last setting is;
# the previous len(dataset) // batch_size ignored the final partial batch of
# each epoch, which would let the linear decay hit zero before training ends.
# NOTE(review): assumes train_eval runs for CFG.epochs epochs — confirm.
steps_per_epoch = len(train_loader)
num_training_steps = CFG.epochs * steps_per_epoch
num_warmup_steps = int(0.05 * num_training_steps)  # warm up over the first 5% of steps
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                               num_training_steps=num_training_steps,
                                               num_warmup_steps=num_warmup_steps)

# Targets equal to CFG.pad_idx are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=CFG.pad_idx)

train_eval(model,
           train_loader,
           valid_loader,
           criterion,
           optimizer,
           lr_scheduler=lr_scheduler,
           step='batch',
           logger=None)