<a href="https://colab.research.google.com/github/NaxAlpha/docnet/blob/master/training/docnet_classifier_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# @title Setup variables and configuration
import os
from google.colab import auth

ROOT_DIR='gs://nax-temp/docnet/' #@param {type:"string"}
DATA_DIR = os.path.join(ROOT_DIR, 'data')
TEMP_DIR = '/tmp/docnet'

if 'gs://' in ROOT_DIR:
    auth.authenticate_user()
!pip install transformers



---

# Part 1 - Data Setup

In [0]:
# @title Load dataset
# https://github.com/Quicksign/ocrized-text-dataset/blob/master/tobacco3482.sh
!echo "Setting up Tesseract..."
!mkdir -p '$TEMP_DIR'
!apt install -y tesseract-ocr
!pip install pytesseract

!echo "Downloading files from UMIACS server..."
!wget -O $TEMP_DIR/1.zip -nv --show-progress -c http://lampsrv02.umiacs.umd.edu/projdb/edit/userfiles/datasets/Tobacco3482_1.zip
!wget -O $TEMP_DIR/2.zip -nv --show-progress -c http://lampsrv02.umiacs.umd.edu/projdb/edit/userfiles/datasets/Tobacco3482_2.zip

!echo "Decompressing .zip archives..."
!unzip -d '$TEMP_DIR' -q -n '$TEMP_DIR'/*.zip

!echo "Moving content to data directory..."
!rm '$TEMP_DIR'/*.zip
!gsutil -m cp -r '$TEMP_DIR' '$DATA_DIR'

In [0]:
# @title Perform OCR on document images
# https://github.com/Quicksign/ocrized-text-dataset/blob/master/to_text.py
import argparse
import os
import pytesseract

import tensorflow as tf
from builtins import str
from joblib import Parallel, delayed
from PIL import Image
from tqdm import tqdm_notebook as tqdm
from threading import Thread


def to_text(filename, lang="eng", format_="txt", ignore_error=False):
    try:
        base, fn = os.path.split(filename)
        temp_filename = os.path.join(TEMP_DIR, fn)
        
        basename, ext = os.path.splitext(temp_filename)
        target = basename + "." + format_
        _, fn = os.path.split(target)
        new_filename = os.path.join(base, fn)
        # implement caching mechanism
        # Check if file exists - skip it
        if tf.io.gfile.exists(new_filename):
            return

        tf.io.gfile.copy(filename, temp_filename, True)
        im = Image.open(temp_filename)
        if format_ == "txt":
            tess_output = pytesseract.image_to_string(im, lang=lang, config="--psm 3 --oem 1")
        elif format_ == "hocr":
            tess_output = pytesseract.image_to_pdf_or_hocr(
                im, lang=lang, config="--psm 3 --oem 1", extension=format_
            )
        with open(target, "w") as fp:
            fp.write(str(tess_output))
        tf.io.gfile.copy(target, new_filename, True)
    except Exception as e:
        if ignore_error:
            print("Error: {}".format(e))
        else:
            raise e

def worker(files):
    for f in files:
        to_text(f)
        prog.update(1)

filenames = tf.io.gfile.glob(os.path.join(DATA_DIR, '**', '*.tif'))
# Only run on first 100 files
# because it can take very long to run on colab
# ideally would run on 16 core machine
filenames = filenames[:100]  
prog = tqdm(filenames)

th = []
n = 4
m = len(filenames)//n
for i in range(n+1):
    chunk = filenames[i*m:(i+1)*m]
    t = Thread(target=worker, args=(chunk, ))
    t.start()
    th.append(t)
for t in th:
    t.join()

HBox(children=(IntProgress(value=0, max=3482), HTML(value='')))

---

# Part 2 - Data analysis and Pre Processing

In [0]:
!mkdir -p data
!gsutil -m cp -r $DATA_DIR/* data/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Copying gs://nax-temp/docnet/data/Form/2054647558.tif...
Copying gs://nax-temp/docnet/data/Form/2054911394.tif...
Copying gs://nax-temp/docnet/data/Form/2054911394.txt...
Copying gs://nax-temp/docnet/data/Form/2054944258.tif...
Copying gs://nax-temp/docnet/data/Form/2054944258.txt...
Copying gs://nax-temp/docnet/data/Form/2055400325.tif...
Copying gs://nax-temp/docnet/data/Form/2055400325.txt...
Copying gs://nax-temp/docnet/data/Form/2056288404.tif...
Copying gs://nax-temp/docnet/data/Form/2056288404.txt...
Copying gs://nax-temp/docnet/data/Form/2056598962.tif...
Copying gs://nax-temp/docnet/data/Form/2056598962.txt...
Copying gs://nax-temp/docnet/data/Form/2056599150.tif...
Copying gs://nax-temp/docnet/data/Form/2056599150.txt...
Copying gs://nax-temp/docnet/data/Form/2057065020.tif...
Copying gs://nax-temp/docnet/data/Form/2057065020.txt...
Copying gs://nax-temp/docnet/data/Form/2057332433.tif...
Copying gs://nax-temp/d

In [2]:
# @title Get file names and labels
import tensorflow as tf

image_files = tf.io.gfile.glob(os.path.join('data/**', '*.tif'))
text_files = [os.path.splitext(fn)[0] + '.txt' for fn in image_files]
labels = [fn.split('/')[-2] for fn in image_files]
classes = list(set(labels))
labels = [classes.index(lbl) for lbl in labels]

In [3]:
# @title Visualize class distribution
from collections import Counter

class_dist = Counter(labels)

print('Number of Files:', len(image_files))
print('Number of Classes:', len(classes))

print()
print('Class distribution:')
print()

def print_class_dist(image_files):
    labels = [fn.split('/')[-2] for fn in image_files]
    class_dist = Counter(labels)
    print('{:>12}  {:>12}  {:>13}'.format('-'*12, '-'*12, '-'*12))
    print('{:>12}  {:>12}  {:>13}'.format('Class', 'Count', 'Percentage'))
    print('{:>12}  {:>12}  {:>13}'.format('-'*12, '-'*12, '-'*12))
    for lbl, cnt in class_dist.items():
        perc = round(100*cnt/len(labels), 2)
        print('{:>12}  {:>12}  {:>12}%'.format(lbl, cnt, perc))

print_class_dist(image_files)

Number of Files: 3482
Number of Classes: 10

Class distribution:

------------  ------------   ------------
       Class         Count     Percentage
------------  ------------   ------------
       Email           599          17.2%
  Scientific           261           7.5%
      Resume           120          3.45%
        Note           201          5.77%
        ADVE           230          6.61%
        Memo           620         17.81%
        Form           431         12.38%
      Report           265          7.61%
      Letter           567         16.28%
        News           188           5.4%


In [4]:
# @title Split Data into train/validation
import numpy as np

np.random.seed(10)
np.random.shuffle(image_files)

N = int(len(image_files)*0.85)  # 15% for validation
train, valid = image_files[:N], image_files[N:]

print('Train Distribution:')
print_class_dist(train)

print()
print('Validation Distribution:')
print_class_dist(valid)

Train Distribution:
------------  ------------   ------------
       Class         Count     Percentage
------------  ------------   ------------
        Note           176          5.95%
        Memo           533         18.01%
        ADVE           188          6.35%
       Email           500          16.9%
        News           164          5.54%
      Letter           483         16.32%
        Form           360         12.17%
      Report           224          7.57%
  Scientific           230          7.77%
      Resume           101          3.41%

Validation Distribution:
------------  ------------   ------------
       Class         Count     Percentage
------------  ------------   ------------
        Memo            87         16.63%
      Letter            84         16.06%
      Report            41          7.84%
        Form            71         13.58%
       Email            99         18.93%
        News            24          4.59%
  Scientific            31    

In [0]:
# @title Create Dataset loading classes
import torch
from PIL import Image
import tensorflow as tf
import torchvision.transforms as T
from torch.utils.data import Dataset
from tqdm import tqdm_notebook as tqdm
from threading import Thread


class DocDataset(Dataset):

    def __init__(self, images, tokenizer, image_size=(299, 299), text_max_len=256, transform=None):
        self.images = images
        self.tokenizer = tokenizer
        self.transform = transform
        self.image_size = image_size
        self.text_max_len = text_max_len
        self.cache = dict()
        self._setup()

    def _setup(self):
        self.texts = [os.path.splitext(fn)[0] + '.txt' for fn in image_files]
        labels = [fn.split('/')[-2] for fn in image_files]
        self.classes = list(set(labels))
        self.labels = [self.classes.index(lbl) for lbl in labels]
        self.cache.update(
            img=dict(),
            txt=dict(),
        )
        # self._prog = tqdm(range(len(self)))
        # self._start_caching(16)

    def __len__(self):
        return len(self.images)
    
    def _start_caching(self, N):
        th = []
        for i in range(N):
            t = Thread(target=self._cache_worker, args=(i, N))
            t.start()
            th.append(t)
        for t in th:
            t.join()
    
    def _cache_worker(self, idx, N):
        for _ in range(idx, len(self), N):
            _ = self[_]
            self._prog.update(1)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img = self.images[idx]
        txt = self.texts[idx]
        lbl = self.labels[idx]

        if idx not in self.cache['img']:
            with tf.io.gfile.GFile(img, 'rb') as f:
                img = Image.open(f)
                img = T.Resize(self.image_size)(img)
            self.cache['img'][idx] = img

        img = self.cache['img'][idx]
        img = T.Grayscale(3)(img)
        img = T.ToTensor()(img)
        if self.transform:
            img = self.transform(img)

        if idx not in self.cache['txt']:
            with tf.io.gfile.GFile(txt) as f:
                txt = f.read()
                
            txt = self.tokenizer.encode(txt, max_length=self.text_max_len)
            txt = txt[:self.text_max_len]
            txt = txt + [0] * (self.text_max_len - len(txt))
            self.cache['txt'][idx] = txt

        txt = self.cache['txt'][idx]
        txt = torch.tensor(txt)
        lbl = torch.tensor(lbl)

        return img, txt, lbl

In [0]:
# @title Test dataset and caching
from transformers import *
transform = T.Compose([
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

tok = BertTokenizer.from_pretrained('bert-base-uncased')
data_test = DocDataset(train, tok, transform=transform)

i, t, l = data_test[0]
print(i.shape, t.shape, l)
print(len(data_test[0]))

torch.Size([3, 299, 299]) torch.Size([256]) tensor(4)
3


---

# Modelling


In [0]:
import torch.nn as nn
from transformers import *
from torchvision.models import *

class DocClassifier(nn.Module):
    
    def __init__(self, n_classes):
        super().__init__()
        inception = resnext101_32x8d(True)
        self.vision = nn.Sequential(*list(inception.children())[:-1])  # output 2048
        self.lang = BertModel.from_pretrained('bert-base-uncased')  # output 768
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(2816, n_classes)
        
    def forward(self, img, txt):
        f1 = self.vision(img).flatten(1)
        _, f2 = self.lang(txt)
        f2 = f2.flatten(1)
        ff = torch.cat([f1, f2], dim=1)
        ff = self.dropout(ff)
        out = self.classifier(ff)
        return out

In [0]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [0]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(classes)).to(device)

In [0]:
from torch.utils.data import DataLoader

transform = T.Compose([
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_data = DocDataset(train, tokenizer, (224, 224), 256, transform)
valid_data = DocDataset(valid, tokenizer, (224, 224), 256, transform)

In [0]:
train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_data, batch_size=16, shuffle=False, num_workers=8)

In [0]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00003)

In [24]:
from sklearn.metrics import classification_report
from tqdm import tqdm_notebook as tqdm


def calc_report():
    model.eval()
    preds = []
    lblsx = []

    for i, data in enumerate(tqdm(valid_loader), 0):
        imgs, txts, lbls = data
        imgs, txts, lbls = map(lambda d: d.to(device), [imgs, txts, lbls])

        with torch.no_grad():
            pred = model(txts)[0]
            # pred = temp2(model.lang(txts)[1].flatten(1))
            # pred = temp(model.vision(imgs).flatten(1))
            pred = pred.argmax(-1)
        
        lblsx += lbls.tolist()
        preds += pred.tolist()
        
    print(classification_report(lblsx, preds))
    model.train()

calc_report()

HBox(children=(IntProgress(value=0, max=33), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      1.00      1.00        88
           1       0.95      1.00      0.97        53
           2       1.00      0.76      0.86        45
           3       0.93      0.95      0.94        40
           4       0.91      1.00      0.95        31
           5       0.98      0.98      0.98        45
           6       0.98      1.00      0.99        86
           7       1.00      1.00      1.00        19
           8       0.99      0.99      0.99        77
           9       0.90      0.95      0.92        39

    accuracy                           0.97       523
   macro avg       0.96      0.96      0.96       523
weighted avg       0.97      0.97      0.97       523



In [23]:
for epoch in range(5): 

    valid_iter = iter(valid_loader)
    running_loss = 0.0

    temp1 = torch.nn.Linear(2048, len(classes)).to(device)
    temp2 = torch.nn.Linear(768, len(classes)).to(device)

    for i, data in enumerate(train_loader, 0):
        model.train()
        
        imgs, txts, lbls = data
        imgs, txts, lbls = map(lambda d: d.to(device), [imgs, txts, lbls])
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(txts)[0]
        # outputs = temp1(model.vision(imgs).flatten(1))
        # outputs = temp2(model.lang(txts)[1].flatten(1))
        loss = criterion(outputs, lbls)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 20 == 19:
            model.eval()

            imgs, txts, lbls = map(lambda d: d.to(device), next(valid_iter))
            with torch.no_grad():
                outputs = model(txts)[0]
                # outputs = temp(model.vision(imgs).flatten(1))
                # outputs = temp2(model.lang(txts)[1].flatten(1))
                val_loss = criterion(outputs, lbls)

            model.train()
            print('[%d, %5d]  Train Loss: %.3f  Valid Loss: %.3f' % (epoch + 1, i + 1, running_loss / 20, val_loss))
            running_loss = 0

        if i % 500 == 499:
            valid_iter = iter(valid_loader)
    calc_report()



[1,    20]  Train Loss: 2.234  Valid Loss: 2.106
[1,    40]  Train Loss: 2.151  Valid Loss: 2.047
[1,    60]  Train Loss: 2.061  Valid Loss: 2.147
[1,    80]  Train Loss: 2.011  Valid Loss: 1.594
[1,   100]  Train Loss: 1.842  Valid Loss: 1.933
[1,   120]  Train Loss: 1.924  Valid Loss: 1.941
[1,   140]  Train Loss: 1.687  Valid Loss: 1.587
[1,   160]  Train Loss: 1.676  Valid Loss: 1.440
[1,   180]  Train Loss: 1.342  Valid Loss: 1.668
[1,   200]  Train Loss: 1.183  Valid Loss: 0.809
[1,   220]  Train Loss: 1.155  Valid Loss: 0.973
[1,   240]  Train Loss: 1.199  Valid Loss: 1.011
[1,   260]  Train Loss: 0.988  Valid Loss: 0.983
[1,   280]  Train Loss: 1.046  Valid Loss: 0.729
[1,   300]  Train Loss: 1.102  Valid Loss: 1.359
[1,   320]  Train Loss: 1.095  Valid Loss: 0.946
[1,   340]  Train Loss: 1.218  Valid Loss: 1.057
[1,   360]  Train Loss: 1.128  Valid Loss: 0.896
[1,   380]  Train Loss: 1.039  Valid Loss: 0.772
[1,   400]  Train Loss: 0.877  Valid Loss: 0.693
[1,   420]  Train Lo

HBox(children=(IntProgress(value=0, max=33), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      0.98      0.99        88
           1       0.70      0.89      0.78        53
           2       0.65      0.58      0.61        45
           3       0.81      0.55      0.66        40
           4       0.69      0.71      0.70        31
           5       0.68      0.87      0.76        45
           6       0.81      0.85      0.83        86
           7       0.90      1.00      0.95        19
           8       0.91      0.79      0.85        77
           9       0.81      0.74      0.77        39

    accuracy                           0.81       523
   macro avg       0.80      0.80      0.79       523
weighted avg       0.82      0.81      0.81       523

[2,    20]  Train Loss: 0.662  Valid Loss: 0.670
[2,    40]  Train Loss: 0.693  Valid Loss: 0.452
[2,    60]  Train Loss: 0.618  Valid Loss: 0.728
[2,    80]  Train Loss: 0.719  Valid Loss: 0.236
[2,   100]  Train Loss: 0.699  Valid Loss: 0

HBox(children=(IntProgress(value=0, max=33), HTML(value='')))


              precision    recall  f1-score   support

           0       0.98      0.99      0.98        88
           1       0.83      0.94      0.88        53
           2       1.00      0.62      0.77        45
           3       0.86      0.80      0.83        40
           4       1.00      0.81      0.89        31
           5       0.85      0.87      0.86        45
           6       0.86      0.88      0.87        86
           7       1.00      1.00      1.00        19
           8       0.84      0.99      0.91        77
           9       0.85      0.90      0.88        39

    accuracy                           0.89       523
   macro avg       0.91      0.88      0.89       523
weighted avg       0.90      0.89      0.89       523

[3,    20]  Train Loss: 0.319  Valid Loss: 0.221
[3,    40]  Train Loss: 0.350  Valid Loss: 0.191
[3,    60]  Train Loss: 0.462  Valid Loss: 0.132
[3,    80]  Train Loss: 0.408  Valid Loss: 0.087
[3,   100]  Train Loss: 0.422  Valid Loss: 0

HBox(children=(IntProgress(value=0, max=33), HTML(value='')))


              precision    recall  f1-score   support

           0       0.93      1.00      0.96        88
           1       0.88      0.98      0.93        53
           2       0.93      0.87      0.90        45
           3       0.97      0.85      0.91        40
           4       0.97      1.00      0.98        31
           5       1.00      0.93      0.97        45
           6       0.96      0.95      0.96        86
           7       1.00      1.00      1.00        19
           8       0.93      0.97      0.95        77
           9       0.91      0.77      0.83        39

    accuracy                           0.94       523
   macro avg       0.95      0.93      0.94       523
weighted avg       0.94      0.94      0.94       523

[4,    20]  Train Loss: 0.302  Valid Loss: 0.147
[4,    40]  Train Loss: 0.272  Valid Loss: 0.120
[4,    60]  Train Loss: 0.273  Valid Loss: 0.094
[4,    80]  Train Loss: 0.196  Valid Loss: 0.039
[4,   100]  Train Loss: 0.327  Valid Loss: 0

HBox(children=(IntProgress(value=0, max=33), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      1.00      1.00        88
           1       0.96      0.98      0.97        53
           2       0.85      0.89      0.87        45
           3       0.86      0.95      0.90        40
           4       1.00      0.87      0.93        31
           5       0.98      0.93      0.95        45
           6       0.91      1.00      0.96        86
           7       1.00      1.00      1.00        19
           8       1.00      0.94      0.97        77
           9       0.94      0.85      0.89        39

    accuracy                           0.95       523
   macro avg       0.95      0.94      0.94       523
weighted avg       0.95      0.95      0.95       523

[5,    20]  Train Loss: 0.128  Valid Loss: 0.221
[5,    40]  Train Loss: 0.133  Valid Loss: 0.075
[5,    60]  Train Loss: 0.100  Valid Loss: 0.054
[5,    80]  Train Loss: 0.096  Valid Loss: 0.018


KeyboardInterrupt: ignored

In [0]:
!mkdir -p doc-class
model.save_pretrained('doc-class')
!zip -r model.zip doc-class
model.zip