In [1]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

# Dataset locations used throughout the notebook — adjust them to your environment.
# train_path / test_path are the LMDB files passed to LMDBReader in later cells;
# gt_path is the ground-truth file for the test set.
train_path = r'/DATA/ichuviliaeva/ocr_data/train.lmdb'
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
gt_path = './gt.txt'

In [2]:
import random

import cv2
import numpy as np
import torchvision
import wandb

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Open the training LMDB and wrap the opened reader in a dataset helper.
train_reader = LMDBReader(train_path)
train_reader.open()
# This name is rebound to the train part of the split in the next cell.
train_helper = HWDBDatasetHelper(train_reader)

In [4]:
# Split the helper into a train helper and a validation helper.
# NOTE(review): `train_helper` is rebound to the first returned helper, so this
# cell is not idempotent — running it again on its own calls train_val_split()
# on the already-split helper. Re-run the previous cell before re-running this one.
train_helper, val_helper = train_helper.train_val_split()

In [5]:
# Show the sizes reported by the train and validation helpers as a sanity check.
train_helper.size(), val_helper.size()

(2578433, 644609)

In [6]:
# Seed every RNG the pipeline may draw from, not only torch, so a
# Restart & Run All reproduces the same run.
# NOTE(review): this cell runs after train_val_split(); if the split draws from
# one of these RNGs it is not covered by the seed — confirm and move the seeding
# earlier if needed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fb06602f370>

In [7]:
from task2pack.utils.train import show_train_plots, train_with_trainable_loss
from task2pack.utils.data import HWDBDataset 

from task2pack.models.resnet import ResNet12GrayscaleFeatPytorch
from task2pack.models.loss import CenterLoss

In [8]:
"""
model = ResNet34GrayscaleFeat(train_helper.vocabulary.num_classes())
"""
model = ResNet12GrayscaleFeatPytorch(train_helper.vocabulary.num_classes())

In [9]:
model_name = 'ResNet12GrayscaleFeatPytorch'

# Train and validation inputs get the same preprocessing: a resize to 128x128.
train_transfroms = nn.Sequential(transforms.Resize((128, 128)))
val_transfroms = nn.Sequential(transforms.Resize((128, 128)))

# DataLoader keyword arguments for the training and evaluation loaders.
train_dataloader_config = {'batch_size': 512, 'shuffle': True, 'drop_last': True, 'num_workers': 8}
test_dataloader_config = {'batch_size': 2048, 'shuffle': False, 'num_workers': 8}

# Hyper-parameters handed to the training routine.
# NOTE(review): every milestone (40, 50, 75) is >= 'epochs' (40); if these are
# epoch indices of a step LR schedule, no decay happens within this run — confirm
# against train_with_trainable_loss.
# 'weight_criterion' / 'lr_criterion' presumably configure the trainable
# criterion (its loss weight and learning rate) — confirm.
training_config = {
    'lr': 1e-3, 'epochs': 40,
    'milestones': [40, 50, 75], 'gamma': 0.7,
    'weight_criterion': 0.3, 'lr_criterion': 0.5,
}

device = 'cuda:2'

# Two objectives: cross entropy plus a trainable center loss over 512-d features.
num_classes = train_helper.vocabulary.num_classes()
criterion = torch.nn.CrossEntropyLoss()
centerloss = CenterLoss(num_classes=num_classes, feat_dim=512)

# Everything worth tracking about this run, logged to Weights & Biases.
run_name = '{} {} epochs with lr={} no augment'.format(
    model_name, training_config['epochs'], training_config['lr']
)
run_config = {
    'train_dataloader_config': train_dataloader_config,
    'test_dataloader_config': test_dataloader_config,
    'training_config': training_config,
    'train_transforms': train_transfroms,
    'val_transforms': val_transfroms,
    "architecture": model_name,
    "dataset": "CASIA Offline Chinese Handwriting",
    "criterion": "Cross Entropy Loss + Centerloss",
    "optimizer": "Adam + SGD(Centerloss)",
}
wandb.init(project='ocr task 2', name=run_name, config=run_config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvashchilkoav[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Wrap the train/validation helpers in datasets, each with its own transform pipeline.
train_dataset = HWDBDataset(train_helper, transforms=train_transfroms)
val_dataset = HWDBDataset(val_helper, transforms=val_transfroms)

In [None]:
# Run training with the regular criterion plus the trainable center loss.
# The call returns five values: train/test histories for the regular loss,
# train/test histories for the trainable loss, and the trained model.
# The wandb module is passed in as `wandb_instance`, presumably so the routine
# can log metrics to the run started above — confirm in train_with_trainable_loss.
train_losses, test_losses, train_centerloss, test_centerloss, trained_model = train_with_trainable_loss(
    train_dataset=train_dataset,
    test_dataset=val_dataset,
    model=model, 
    criterion=criterion,
    trainable_criterion=centerloss,
    train_dataloader_kwargs=train_dataloader_config,
    test_dataloader_kwargs=test_dataloader_config,
    training_kwargs=training_config,
    device=device,
    wandb_instance=wandb,
    eval_every=2,  # the captured log below shows validation metrics after every 2nd epoch
)

	addmm_(Number beta, Number alpha, Tensor mat1, Tensor mat2)
Consider using one of the following signatures instead:
	addmm_(Tensor mat1, Tensor mat2, *, Number beta, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1420.)
  distmat.addmm_(1, -2, x, self.centers.t())


Initial val: [regular_loss: 0.004350887003099979, trainable_loss: 0.0767397590527928, accuracy: 5.7399136530827216e-05]
Epoch 1:
Train loss: [regular: 0.00746226896521204 trainable: 0.15184186068554675]
Epoch 2:
Train loss: [regular: 0.0007994218671546974 trainable: 0.03814035727705972]
Val : [regular_loss: 0.00023147866785392293, trainable_loss: 0.005389864329863562, accuracy: 0.8907337626375058]
Epoch 3:
Train loss: [regular: 0.000551695716671635 trainable: 0.013713069831213493]
Epoch 4:
Train loss: [regular: 0.00046165847345572644 trainable: 0.006213338089851933]
Val : [regular_loss: 0.00015091135317931196, trainable_loss: 0.0011335323260755431, accuracy: 0.9229501915114434]
Epoch 5:
Train loss: [regular: 0.00040214950120102505 trainable: 0.0032547932806320052]
Epoch 6:
Train loss: [regular: 0.00035551926593102795 trainable: 0.0018976553016466107]
Val : [regular_loss: 0.00012135864119783903, trainable_loss: 0.00038819672168297504, accuracy: 0.9366918550625263]
Epoch 7:
Train loss: [

In [None]:
# Plot the train/validation history of both objectives.
# The labels are built from `model_name` so they always name the architecture
# that was actually trained (they previously said "ResNet34Grayscale").
show_train_plots(train_losses, test_losses, '{} CrossEntropy'.format(model_name))
show_train_plots(train_centerloss, test_centerloss, '{} Centerloss'.format(model_name))

In [None]:
from course_ocr_t2.evaluate import evaluate
from tqdm import tqdm

# Evaluate the trained model on the test set and archive the artifacts in wandb.
# `test_path` and `gt_path` are defined once in the first cell; they are reused
# here instead of repeating the literals so the two cannot drift apart.
pred_path = './pred.txt'
model_path = './model.pth'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

# Same preprocessing as for training/validation: resize to 128x128.
test_transforms = nn.Sequential(
    transforms.Resize((128, 128)),
)

test_dataset = HWDBDataset(test_helper, transforms=test_transforms)
# shuffle=False is required: predictions are matched to test_helper.namelist by position below.
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

preds = []
trained_model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        # The model returns a pair; only the logits are needed for prediction.
        logits, _ = trained_model(X.to(torch.float32).to(device))
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

# One "<sample name> <predicted class>" line per test sample.
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

test_accuracy = evaluate(gt_path, pred_path)
wandb.run.summary['test_accuracy'] = test_accuracy
# Store the transform as its string repr so the summary holds a plain,
# JSON-friendly value rather than an nn.Module object.
wandb.run.summary['test_transforms'] = str(test_transforms)

torch.save(trained_model.state_dict(), model_path)
wandb.save(model_path)
wandb.save(pred_path)

wandb.finish()