## [Introduction] Scoring Mechanism for Document Classification

In this script, we perform document classification to categorize images into "handwritten" and "typed" documents. The scoring is done in multiple stages:

1. **Object Detection using DETR**: The initial step uses the DETR model to identify bounding boxes of labels in the images.

2. **Text Detection using CRAFT**: CRAFT is then used to further refine the bounding areas, focusing on text regions.

3. **Feature Extraction using TrOCR**: Each of the refined bounding areas is processed using TrOCR to extract image features.

4. **Classification**: A pre-trained custom classifier (decoder) takes these features to score the likelihood that the document is handwritten or typed.

Step-by-step explanation of the scoring mechanism:

1. **Initialize Dictionaries**: 
    - `score_sum_dict` keeps track of the cumulative scores for each type (handwritten and typed) for each document. 
    - `score_len_dict` keeps track of the number of bounding boxes considered for each type for each document.

2. **Processing Loop**:
    - For each bounding box in each document, we extract features using TrOCR.
    - These features are then input to the classifier which outputs a sigmoid activated score between 0 and 1. 
    - If the score rounds to 0, it is considered "handwritten," and if it rounds to 1, it is considered "typed."
    - The score is added to the corresponding category (handwritten or typed) in `score_sum_dict`, and the count is incremented in `score_len_dict`.

3. **Averaging Scores**:
    - Once all bounding boxes for all documents are processed, the script calculates the average score for each category (handwritten and typed) for each document.

4. **Final Classification**:
    - The average scores are compared, and the document is classified into the category with the higher average score. The images are then moved to their respective folders ("handwritten" or "typed").

This scoring mechanism allows us to evaluate the likelihood that a document is handwritten or typed based on multiple bounding boxes, providing a more robust classification.

## Import Modules

In [9]:
import os
import shutil
from collections import defaultdict
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn

from craft_text_detector import Craft

from transformers import (TrOCRProcessor, 
                        VisionEncoderDecoderModel)

# add parent directory to path so that we can import our python scripts from all subdirectories
cwd_prefix = "/projectnb/sparkgrp/ml-herbarium-grp/summer2023/kabilanm/ml-herbarium/trocr/evaluation-dataset/handwritten-typed-text-classification/"
import sys
sys.path.append(cwd_prefix)

import detr
from utils.utils import *

## Initialize DETR and CRAFT-Related Directories

In [6]:
detr_inputdir = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/TROCR_Training/goodfiles/'
detr_outputdir = cwd_prefix+'data/Doc_Classification/intermediate_files/'
output_dir_craft = cwd_prefix+'data/Doc_Classification/input/'
cache_dir = cwd_prefix+'data/'

## DETR Inference

In [None]:
# Use the DETR model for inference (adopted from Freddie (https://github.com/freddiev4/comp-vision-scripts/blob/main/object-detection/detr.py))
detr_model = 'spark-ds549/detr-label-detection'
# The DETR model returns the bounding boxes of the lables indentified from the images
label_bboxes = detr.run(detr_inputdir, detr_outputdir, detr_model)

## Initialize CRAFT Model and Get Bounding Boxes

In [5]:
# initialize the CRAFT model
craft = Craft(output_dir = output_dir_craft, 
              export_extra = False, 
              text_threshold = .7, 
              link_threshold = .4, 
              crop_type="poly", 
              low_text = .3, 
              cuda = True)

# CRAFT on images to get bounding boxes
images = []
corrupted_images = []
no_segmentations = []
boxes = {}
count= 0
img_name = []
box = []
file_types = (".jpg", ".jpeg",".png")
    
for filename in tqdm(sorted(os.listdir(detr_outputdir))):
    if filename.endswith(file_types):
        image = detr_outputdir+filename
        try:
            img = Image.open(image) 
            img.verify() # Check that the image is valid
            bounding_areas = craft.detect_text(image)
            if len(bounding_areas['boxes']): #check that a segmentation was found
                images.append(image)
                boxes[image] = bounding_areas['boxes']
                
            else:
                no_segmentations.append(image)
        except (IOError, SyntaxError) as e:
            corrupted_images.append(image)

  polys = np.array(polys)
  polys_as_ratio = np.array(polys_as_ratio)
100%|██████████| 251/251 [04:00<00:00,  1.04it/s]


## Initialize Device

In [10]:
# Move the model to the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Initialize Processor and Models

In [11]:
# Define model and processor
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-stage1', cache_dir=cache_dir)
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-stage1', cache_dir=cache_dir)

# Freeze TrOCR layers
for param in model.parameters():
    param.requires_grad = False

# Define our custom classifier (also decoder)
classifier = nn.Sequential(
    
    nn.Conv2d(1, 16, kernel_size=1, stride=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=1, stride=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 32, kernel_size=1, stride=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(32 * (577 // 8) * (1024 // 8), 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    nn.Linear(512, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 1)
)

Downloading (…)rocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Downloading (…)okenizer_config.json:   0%|          | 0.00/1.35k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/957 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.24k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/2.43G [00:00<?, ?B/s]

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-large-stage1 and are newly initialized: ['encoder.pooler.dense.weight', 'encoder.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading (…)neration_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

## Load Pretrained Classifier

In [14]:
classifier = torch.nn.DataParallel(classifier, [0]) # list(range(torch.cuda.device_count()))
classifier.load_state_dict(torch.load(cwd_prefix+"model/TrOCR_L_enc_feature_extraction_w_classifier_retrained.pth"))

<All keys matched successfully>

In [15]:
# Move Models to Device
model.to(device)
classifier.to(device)

DataParallel(
  (module): Sequential(
    (0): Conv2d(1, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
    (7): ReLU(inplace=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Flatten(start_dim=1, end_dim=-1)
    (10): Linear(in_features=294912, out_features=512, bias=True)
    (11): ReLU(inplace=True)
    (12): Linear(in_features=512, out_features=512, bias=True)
    (13): ReLU(inplace=True)
    (14): Dropout(p=0.2, inplace=False)
    (15): Linear(in_features=512, out_features=256, bias=True)
    (16): ReLU(inplace=True)
    (17): Linear(in_features=256, out_features=1, bias=True)
  )
)

## Initialize Scoring Dictionaries

In [16]:
score_sum_dict = defaultdict(lambda: [0, 0]) # file_name: (hw_confidence, typed_confidence)
score_len_dict = defaultdict(lambda: [0, 0]) # file_name: (hw_count, typed_count)

## Process Each Image and Compute Scores

In [17]:
for dir_ in os.listdir(output_dir_craft):
    for file in os.listdir(os.path.join(output_dir_craft, dir_)):
        
        key = dir_.split("_")[0]
        
        img = Image.open(output_dir_craft+dir_+"/"+file)
        
        pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)
        encoder_outputs = model.encoder(pixel_values)
        
        image_representation = encoder_outputs.last_hidden_state

        classifier.eval()
        with torch.no_grad():
            classifier_output = classifier(image_representation.unsqueeze(1))
            
            pred_confidence = torch.sigmoid(classifier_output)
            predicted = torch.round(pred_confidence)
            
            if(predicted == 0):
                score_sum_dict[key][0] += 1-pred_confidence
                score_len_dict[key][0] += 1
            if(predicted == 1):
                score_sum_dict[key][1] += pred_confidence
                score_len_dict[key][1] += 1

In [18]:
score_sum_dict = dict(score_sum_dict)
score_len_dict = dict(score_len_dict)

In [19]:
score_avg_dict = defaultdict(lambda: [0, 0])

## Final Scoring

In [20]:
# aggregating and computing final scores
hw_score, typed_score = 0, 0

for sum_, len_ in zip(score_sum_dict.items(), score_len_dict.items()):
    if(len_[1][0] == 0):
        hw_score = 0
    elif(len_[1][1] == 0):
        typed_score = 0
    else:
        hw_score = sum_[1][0]/len_[1][0]
        typed_score = sum_[1][1]/len_[1][1]
    score_avg_dict[sum_[0]] = [hw_score, typed_score]

In [None]:
score_avg_dict = dict(score_avg_dict)
score_avg_dict

## Classify Files Based on Scores

Here, we copy the images to the respective directories based on the average confidence scores computed for each image.

In [None]:
output_dir = "data/Doc_Classification/output/"

for file_name, avg_scores in score_avg_dict.items():
    # print(detr_inputdir+file_name+".jpg", scores)
    # print(file_name, avg_scores)
    
    source_file = detr_inputdir+file_name+".jpg"
    
    # Copy the file using shutil.copy2 to the corresponding directory
    # based on the average prediction score
    if(avg_scores[0] >= avg_scores[1]):
        # print("handwritten")
        shutil.copy2(source_file, os.path.join(output_dir, "handwritten"))
        
    # add some bias here
    if(avg_scores[0] < avg_scores[1]):
        # print("typed")
        shutil.copy2(source_file, os.path.join(output_dir, "typed"))

## Count Output Files

In [23]:
! ls ../data/Doc_Classification/output/handwritten/ | wc -l
! ls ../data/Doc_Classification/output/typed/ | wc -l

117
125


## Cleanup (Optional) but recommended before every new run

In [23]:
# ! rm -rf ../data/Doc_Classification/output/handwritten/*
# ! rm -rf ../data/Doc_Classification/output/typed/*