In [None]:
import torch
import wandb
from sklearn.model_selection import train_test_split
from torchmetrics import CharErrorRate
from tqdm import tqdm

from dataset import CapchaDataset
from model import CRNN
from train import eval_epoch, train_epoch
from utils import Decoder, get_dataloader

In [None]:
# Run-level settings. The same dict is handed to W&B as the run's config and
# is read back by the training loop for the number of epochs.
config = {
    'n_epochs': 10,
}

wandb.init(config=config, project='ocr')

In [None]:
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Two independently constructed captcha datasets: 10000 samples for training
# and 1000 for evaluation.
# NOTE(review): the (3, 5) argument presumably bounds the sequence length --
# confirm against CapchaDataset.
train_set = CapchaDataset((3, 5), samples=10000)
eval_set = CapchaDataset((3, 5), samples=1000)

# Loader settings shared by both splits; only the dataset and mode differ.
loader_kwargs = {'batch_size': 64, 'num_workers': 8}

train_loader = get_dataloader(dataset=train_set, mode='train', **loader_kwargs)
eval_loader = get_dataloader(dataset=eval_set, mode='eval', **loader_kwargs)

In [None]:
# Recognition network, sized by the dataset's class count and moved to the
# selected device.
model = CRNN(n_classes=train_set.num_classes)
model = model.to(device)

# CTC loss; the blank index is taken from the dataset so both stay in sync.
criterion = torch.nn.CTCLoss(blank=train_set.blank_label)

# AdamW with the library's default hyperparameters.
optimizer = torch.optim.AdamW(params=model.parameters())

# Decoder that maps class indices to characters using the same blank index.
# NOTE(review): the label list is hard-coded here -- verify that its order and
# length agree with the dataset's class indices.
decoder = Decoder(
    blank_idx=train_set.blank_label,
    labels=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'],
)

# Character error rate metric used to score train and eval predictions.
cer = CharErrorRate()

In [None]:
# Lowest evaluation CER seen so far, kept as a plain Python float.
# It starts at +inf rather than 1 so that the first epoch always writes a
# checkpoint: CER is edit distance divided by reference length, so it is not
# capped at 1.0, and an empty prediction scores exactly 1.0. With a strict
# `< 1` comparison such a run would finish without any saved model.
best_eval_score = float('inf')

# File that receives the state dict of the best-scoring model.
checkpoint_filepath = 'best_model.pth'

for epoch in tqdm(range(config['n_epochs'])):
    # One pass over the training data; returns the loss together with the
    # predictions and targets that are scored below.
    train_loss, y_pred, y_true = train_epoch(
        dataloader=train_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        decoder=decoder,
        device=device,
    )

    train_score = cer(y_pred, y_true)

    # Same for the evaluation data (no optimizer is passed here).
    eval_loss, y_pred, y_true = eval_epoch(
        dataloader=eval_loader,
        model=model,
        criterion=criterion,
        decoder=decoder,
        device=device,
    )

    eval_score = cer(y_pred, y_true)

    wandb.log(
        {
            'Loss (train)': train_loss,
            'Loss (eval)': eval_loss,
            'CER (train)': train_score.item(),
            'CER (eval)': eval_score.item(),
        },
    )

    # Lower CER is better: checkpoint whenever the eval score strictly
    # improves on the best value so far. Comparing and storing Python floats
    # keeps `best_eval_score` one type for the whole run.
    current_eval_score = eval_score.item()
    if current_eval_score < best_eval_score:
        best_eval_score = current_eval_score
        torch.save(model.state_dict(), checkpoint_filepath)

In [None]:
# End the W&B run that was started by `wandb.init` earlier in this notebook.
wandb.finish()