In [3]:
import os

import numpy as np
from fastai.vision.all import *
from dataloader import BeforeBatchTransform, CreateBatchTransform
from vocab import Vocab
# Imported after the fastai star import so these names take precedence over
# any same-named fastai exports.
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

In [33]:
from transformers import BeitFeatureExtractor, BeitModel, TrOCRProcessor,\
                        TrOCRConfig, TrOCRForCausalLM, VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from datasets import load_metric
from transformers import AdamW
import torch

In [6]:
# Root directory of the BKAI OCR data (train.txt, train_img/, test_img/).
# Set the BKAI_OCR_DATA environment variable to point elsewhere instead of
# editing this machine-specific default.
DATA_PATH = Path(os.environ.get('BKAI_OCR_DATA', '/home/bui.hai.minh.hieu/hieubhm_workspace/competition/ocr'))

In [7]:
# Text vocabulary shared by the datasets (label encoding via vocab(text)) and
# the model config (len(vocab) and the go/pad/eos special-token ids).
vocab = Vocab()

In [8]:
# Build label_dict: image file name -> transcription, from train.txt where each
# line is "<file name>\t<label>".
# Read explicitly as UTF-8 so the labels do not depend on the locale's default encoding.
with open(DATA_PATH / 'train.txt', encoding='utf-8') as f:
    lines = f.readlines()
label_dict = dict()
for line in lines:
    # Drop only the line terminator so a tab-less line does not yield a key ending in '\n'.
    line = line.rstrip('\r\n')
    if not line.strip():
        continue  # skip blank lines (e.g. trailing empty lines at EOF)
    fn, *rest = line.split('\t')
    # Any extra tab-separated fields are concatenated into the label, as before.
    label_dict[fn] = ''.join(rest).strip()

In [9]:
# remove = []
# for k, v in label_dict.items():
#     if len(v) == 0:
#         remove.append(k)
#         print(k)

In [10]:
class BKAIDataset(Dataset):
    """Image/transcription dataset over the images found under ``data_path``.

    Each item is a dict with:
      - ``pixel_values``: the processor's ``pixel_values`` output with size-1
        dimensions squeezed away;
      - ``labels``: a 1-D tensor of ``vocab`` token ids in which the padding id
        is replaced by -100 so the loss ignores those positions.

    Args:
        data_path: directory searched with fastai ``get_image_files``.
        processor: callable used as ``processor(image, return_tensors="pt")``.
        vocab: text encoder; ``vocab(text).input_ids`` gives the ids and
            ``vocab.pad`` the padding id.
        labels: optional mapping from image file name to transcription. When
            omitted, the notebook-level ``label_dict`` is used.
    """

    def __init__(self, data_path, processor, vocab, labels=None):
        self.root_dir = data_path
        self.processor = processor
        self.vocab = vocab
        # None means "use the global label_dict", resolved at access time so a
        # rebuilt label_dict is still picked up.
        self.labels = labels
        self.fns = get_image_files(data_path)

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.fns[idx]
        label_table = label_dict if self.labels is None else self.labels
        # Raises KeyError for an image with no entry in the label table.
        text = label_table[file_name.name]
        # prepare image; the context manager closes the file handle once the
        # pixels have been converted.
        with Image.open(file_name) as img:
            image = img.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # encode the text; PAD ids become -100 so the loss function ignores them
        token_ids = self.vocab(text).input_ids
        token_ids = [tok if tok != self.vocab.pad else -100 for tok in token_ids]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(token_ids)}

In [14]:
# Image preprocessing (resize + normalize) for the BEiT encoder.
# Alternative kept for reference: TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
processor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

# One dataset per image folder, sharing the processor and vocabulary.
# NOTE(review): labels come from label_dict, which is built from train.txt only;
# any image under test_img without an entry there raises KeyError when indexed — confirm.
train_dataset, eval_dataset = (
    BKAIDataset(DATA_PATH / split_dir, processor, vocab)
    for split_dir in ('train_img', 'test_img')
)

In [26]:
# Encoder-decoder OCR model: pretrained BEiT image encoder + a freshly
# initialised TrOCR text decoder sized to our own vocabulary.
cfg = TrOCRConfig(
    vocab_size=len(vocab),
    # Align the decoder's special ids with the vocabulary instead of the
    # TrOCRConfig defaults (these are the same ids set on model.config later).
    pad_token_id=vocab.pad,
    bos_token_id=vocab.go,
    eos_token_id=vocab.eos,
    decoder_start_token_id=vocab.go,
)
encoder = BeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
# Build the decoder from `cfg` so its embedding/output layers match len(vocab);
# a bare TrOCRConfig() would keep the default vocabulary size.
decoder = TrOCRForCausalLM(cfg)
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

Some weights of the model checkpoint at microsoft/beit-base-patch16-224-pt22k were not used when initializing BeitModel: ['lm_head.bias', 'lm_head.weight', 'layernorm.weight', 'layernorm.bias']
- This IS expected if you are initializing BeitModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BeitModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BeitModel were not initialized from the model checkpoint at microsoft/beit-base-patch16-224-pt22k and are newly initialized: ['beit.pooler.layernorm.bias', 'beit.pooler.layernorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
# Special tokens used when creating decoder_input_ids from the labels.
model.config.decoder_start_token_id = vocab.go
model.config.pad_token_id = vocab.pad
# Keep the top-level vocab size in sync with the decoder's.
model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search parameters, applied to model.config in this order.
# NOTE(review): no_repeat_ngram_size=3 forbids repeating any 3 consecutive
# tokens; if Vocab is character-level this can block legitimate text — confirm.
beam_search_settings = {
    "eos_token_id": vocab.eos,
    "max_length": 64,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 4,
}
for setting_name, setting_value in beam_search_settings.items():
    setattr(model.config, setting_name, setting_value)


In [31]:
# training_args = Seq2SeqTrainingArguments(
#     predict_with_generate=True,
#     evaluation_strategy="epoch",
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=8,
#     fp16=True, 
#     output_dir="./",
#     logging_steps=2,
#     save_steps=1000,
#     eval_steps=200,
#     dataloader_num_workers=4,
#     label_smoothing_factor=0.1,
#     learning_rate=5e-5,
# )


In [34]:
# Character error rate metric used by compute_metrics and the evaluation loop.
# NOTE(review): datasets.load_metric is deprecated in newer `datasets` releases
# (replaced by the separate `evaluate` package) — confirm the pinned version.
cer_metric = load_metric("cer")

In [35]:
def compute_metrics(pred):
    """Compute the character error rate for a trainer evaluation step.

    Args:
        pred: object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with -100 marking positions
            ignored by the loss).

    Returns:
        dict: ``{"cer": <value returned by cer_metric.compute>}``.
    """
    pred_ids = pred.predictions
    # Work on a copy: -100 is replaced by the pad id so the ids can be decoded,
    # without mutating the caller's label array.
    labels_ids = np.array(pred.label_ids, copy=True)
    labels_ids[labels_ids == -100] = vocab.pad

    # Both sides are decoded with the text vocabulary; `processor` is an image
    # feature extractor and cannot decode token ids.
    pred_str = vocab.batch_decode(pred_ids)
    label_str = vocab.batch_decode(labels_ids)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

In [37]:
# trainer = Seq2SeqTrainer(
#     model=model,
#     tokenizer=vocab,
#     args=training_args,
#     compute_metrics=compute_metrics,
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
#     data_collator=default_data_collator,
# )


Using amp half precision backend


In [None]:
# trainer.train()

In [None]:
# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def _pad_collate(samples):
    """Batch BKAIDataset items, right-padding ``labels`` to a common length.

    torch's default collate can only stack equal-length tensors, so labels of
    different lengths are padded with -100, the same value the dataset uses for
    positions the loss should ignore. Equal-length labels are stacked unchanged.
    """
    pixel_values = torch.stack([sample["pixel_values"] for sample in samples])
    labels = torch.nn.utils.rnn.pad_sequence(
        [sample["labels"] for sample in samples], batch_first=True, padding_value=-100
    )
    return {"pixel_values": pixel_values, "labels": labels}


train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=_pad_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=4, collate_fn=_pad_collate)

In [None]:
# Move parameters and buffers to `device`; nn.Module.to returns the same module.
model = model.to(device)

In [None]:
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to the old transformers defaults (1e-6, 0.0) because
# torch's own defaults differ (1e-8, 1e-2).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

In [None]:
NUM_EPOCHS = 10  # adjust as appropriate


def _decode_ids(ids):
    """Decode a batch of token-id tensors to strings with ``vocab.batch_decode``.

    -100 (the loss ignore index used in the labels) is mapped back to
    ``vocab.pad`` on a CPU copy; the input tensor is not modified.
    """
    ids = ids.detach().cpu().clone()
    ids[ids == -100] = vocab.pad
    # NOTE(review): assumes batch_decode accepts a 2-D numpy array (as in
    # compute_metrics) and strips special tokens — confirm in vocab.py.
    return vocab.batch_decode(ids.numpy())


for epoch in range(NUM_EPOCHS):
    # ---- train ----
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader):
        # move every tensor in the batch to the training device
        batch = {key: tensor.to(device) for key, tensor in batch.items()}

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()

    print(f"Loss after epoch {epoch}:", train_loss/len(train_dataloader))

    # ---- evaluate ----
    # Decode the whole eval set and compute one CER over it, rather than
    # averaging per-batch CERs (which gives a short last batch extra weight).
    model.eval()
    pred_texts, label_texts = [], []
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            generated_ids = model.generate(batch["pixel_values"].to(device))
            pred_texts.extend(_decode_ids(generated_ids))
            label_texts.extend(_decode_ids(batch["labels"]))

    # cer_metric cannot score an empty set; report NaN instead of crashing.
    valid_cer = (
        cer_metric.compute(predictions=pred_texts, references=label_texts)
        if label_texts
        else float("nan")
    )
    print("Validation CER:", valid_cer)

model.save_pretrained(".")