In [None]:
import json
import os
from pathlib import Path

import easyocr
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image, ImageDraw, ImageFont
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy
from tqdm import tqdm
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
%matplotlib inline

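Dataset: [Tobacco3482 (JPG version) on Kaggle](https://www.kaggle.com/datasets/patrickaudriaz/tobacco3482jpg?select=Tobacco3482-jpg). The notebook expects it extracted next to this notebook as `Dataset/<class name>/<image>.jpg`; each sub-folder name is used as the class label.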

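Running this notebook on a GPU is strongly recommended: both the EasyOCR pass over every image and the LayoutLMv3 fine-tuning are compute-heavy.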

In [None]:
# Preview one raw document from the dataset (path is relative to the notebook's folder,
# matching the Dataset/<class>/<image>.jpg layout used by the rest of the notebook).
image = Image.open(Path('Dataset') / 'Email' / '80909413.jpg')
image

In [None]:
# A single English EasyOCR reader, shared by every OCR cell below.
reader = easyocr.Reader(lang_list=['en'])

In [None]:
# OCR a single sample image; each entry of `result` is (corner points, text, confidence).
# The path is relative to the notebook folder instead of a machine-specific absolute path.
result = reader.readtext(str(Path('Dataset') / 'Email' / '80909413.jpg'))
result

In [None]:

# Split the OCR output into parallel lists: corner points and recognised strings.
bbox = [entry[0] for entry in result]
word = [entry[1] for entry in result]
print(len(bbox), len(word), sep='\n')

In [None]:
def create_bbox(bbox):
    """Convert an EasyOCR corner-point box into an axis-aligned ``[left, top, right, bottom]``.

    The enclosing rectangle is taken as the min/max over *all* corners rather than
    reading only corners 0 and 2. That stays correct when the quadrilateral is
    rotated or its corners are not ordered top-left first. For axis-aligned boxes
    ordered top-left -> bottom-right it gives the same numbers as reading those
    two corners.

    Args:
        bbox: sequence of ``(x, y)`` corner points (numbers may be floats or numpy scalars).

    Returns:
        ``[left, top, right, bottom]`` as plain Python ``int`` values (truncated toward
        zero), which keeps the result JSON-serialisable.

    Raises:
        ValueError: if ``bbox`` contains no points.
    """
    xs = [point[0] for point in bbox]
    ys = [point[1] for point in bbox]
    if not xs:
        raise ValueError("bbox must contain at least one point")
    return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]

In [None]:
# Side-by-side view: numbered OCR boxes drawn on the sample image (left), and the
# recognised text re-drawn at the same positions on a blank page (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 20))
left_image = Image.open(Path('Dataset') / 'Email' / '80909413.jpg').convert("RGB")
right_image = Image.new("RGB", left_image.size, (255, 255, 255))

left_draw = ImageDraw.Draw(left_image)
right_draw = ImageDraw.Draw(right_image)

# "arial" only resolves where that font is installed; fall back to Pillow's built-in font.
try:
    font = ImageFont.truetype("arial", 30)
except OSError:
    font = ImageFont.load_default()

# Distinct loop names so the `bbox` / `word` lists from the previous cell are not overwritten.
for i, (points, text, _confidence) in enumerate(result):
    left, top, right, bottom = create_bbox(points)
    left_draw.rectangle([left, top, right, bottom], outline="blue", width=2)
    left_draw.text((right + 5, top), text=str(i + 1), font=font, fill="red")
    right_draw.text((left, top), text=text, font=font, fill="black")

ax1.imshow(left_image)
ax1.set_title('Original_Image')
ax1.axis('off')
ax2.imshow(right_image)
ax2.set_title('Extracted_text')
ax2.axis('off')
plt.show()

In [None]:
# Every image one folder below Dataset/ (Dataset/<class>/<file>.jpg). Sorted because glob
# order is filesystem-dependent; sorting makes the image order the same on every machine.
imagepaths = sorted(Path('Dataset').glob('*/*.jpg'))
print("Total images : ", len(imagepaths))
imag = Image.open(imagepaths[1]).convert('RGB')
imag

In [None]:
# OCR every image once and cache its words/boxes next to it as <image>.json.
# Document_classifi and predict_document_class read these JSON files, so every image needs one.
# After the loop, `imagepath` / `ocr_data` refer to the last image; the next cells use both.
for imagepath in tqdm(imagepaths, desc='OCR'):
    json_path = imagepath.with_suffix('.json')
    if json_path.exists():
        # Reuse the cache so re-running the notebook does not repeat the slow OCR pass.
        with json_path.open('r') as f:
            ocr_data = json.load(f)
        continue

    ocr_result = reader.readtext(str(imagepath), batch_size=16)
    # create_bbox returns plain ints, so the records are JSON-serialisable.
    ocr_data = [
        {'word': text, 'bbox': create_bbox(points)}
        for points, text, _confidence in ocr_result
    ]
    with json_path.open('w') as f:
        json.dump(ocr_data, f)

In [None]:
# LayoutLMv3 tokenizer + image preprocessor. apply_ocr=False because words and boxes
# are supplied from the EasyOCR results instead.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
feature_Extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
processor = LayoutLMv3Processor(feature_Extractor, tokenizer)

In [None]:
def scale_bounding_box(box, width_scale, height_scale):
    """Scale a pixel-space ``[left, top, right, bottom]`` box onto LayoutLMv3's 0-1000 grid.

    Args:
        box: ``[left, top, right, bottom]`` in pixels; only the first four items are used.
        width_scale: factor applied to x coordinates (the notebook passes ``1000 / width``).
        height_scale: factor applied to y coordinates (the notebook passes ``1000 / height``).

    Returns:
        Four ``int`` coordinates, truncated toward zero and clamped to ``[0, 1000]``.
        Clamping ensures that a detection reaching past the image border still yields
        a valid normalised box instead of negative or >1000 coordinates.
    """
    def _to_grid(value, scale):
        # int() truncates toward zero; min/max keep the result inside 0..1000.
        return min(1000, max(0, int(value * scale)))

    return [
        _to_grid(box[0], width_scale),
        _to_grid(box[1], height_scale),
        _to_grid(box[2], width_scale),
        _to_grid(box[3], height_scale),
    ]

In [None]:
# `imagepath` is the last image handled by the OCR loop above, so it matches `ocr_data`.
image1 = Image.open(imagepath).convert('RGB')
width, height = image1.size
# Factors that map pixel coordinates onto the 0-1000 grid LayoutLMv3 expects.
width_scale, height_scale = 1000 / width, 1000 / height
image1

In [None]:

# Words and their normalised boxes for the sample image, kept in OCR order.
words = [row['word'] for row in ocr_data]
boxes = [scale_bounding_box(row['bbox'], width_scale, height_scale) for row in ocr_data]
    

In [None]:
# Encode image + words + boxes into fixed-length (512-token) PyTorch model inputs.
encoding = processor(
    image1,
    words,
    boxes=boxes,
    truncation=True,
    padding='max_length',
    max_length=512,
    return_tensors='pt',
)
encoding.keys()

In [None]:
# Inspect every tensor in the encoding produced by the processor above.
encoding

In [None]:
# Map the single encoded sequence's ids back to sub-word tokens for inspection.
tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'][0].tolist())
tokens

In [None]:
# Join the tokens back into one string to check the text the model actually receives.
tokenizer.convert_tokens_to_string(tokens)

In [None]:
# Base checkpoint, loaded only to inspect its config and raw outputs in the next cells.
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path="microsoft/layoutlmv3-base"
)

In [None]:
# Inspect the pretrained model's configuration.
model.config

In [None]:
# Exploratory forward pass only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    output = model(**encoding)

In [None]:
# Raw classification logits from the exploratory pass (the head has not been fine-tuned yet).
output.logits

In [None]:
# Class names are the dataset's sub-folder names; their position is the integer label.
# Only directories count, and the names are sorted because glob/iterdir order is
# filesystem-dependent — otherwise label indices could differ between runs or machines.
classes = sorted(p.name for p in Path('Dataset').iterdir() if p.is_dir())
classes

In [None]:
class Document_classifi(Dataset):
    """Map-style dataset yielding LayoutLMv3 inputs for one document image per item.

    Each item reads the image and the cached OCR file next to it
    (``<image>.json`` holding ``[{"word": ..., "bbox": [l, t, r, b]}, ...]`` in pixel
    coordinates). It encodes both with the processor. The label is the index of
    the image's parent-folder name in the class list.

    Args:
        img_paths: sequence of ``pathlib.Path`` image paths.
        processor: LayoutLMv3 processor built with ``apply_ocr=False``.
        class_names: optional ordered list of class folder names. When omitted,
            the notebook-level ``classes`` list is used.
    """

    def __init__(self, img_paths, processor, class_names=None):
        self.img_paths = img_paths
        self.processor = processor
        self.class_names = list(class_names) if class_names is not None else None

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, item):
        img_path = self.img_paths[item]

        # Context manager closes the file handle; convert() returns an independent copy.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height

        with img_path.with_suffix('.json').open('r') as f:
            ocr_result = json.load(f)

        words = [row['word'] for row in ocr_result]
        boxes = [scale_bounding_box(row['bbox'], width_scale, height_scale) for row in ocr_result]

        # Use the processor this dataset was constructed with, not a notebook global.
        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt')

        class_names = self.class_names if self.class_names is not None else classes
        label = class_names.index(img_path.parent.name)

        # The processor returns a batch of one; flatten() / flatten(end_dim=1) drop that leading dim.
        return dict(
            input_ids=encoding['input_ids'].flatten(),
            attention_mask=encoding['attention_mask'].flatten(),
            bbox=encoding['bbox'].flatten(end_dim=1),
            pixel_values=encoding['pixel_values'].flatten(end_dim=1),
            # int64: class-index targets for cross-entropy must be Long; int32 raises at loss time.
            labels=torch.tensor(label, dtype=torch.long)
        )
        
    

In [None]:
# Fixed seed -> the same split on every run; stratifying on the class folder keeps the
# class balance identical in train and test.
train_img, test_img = train_test_split(
    imagepaths,
    test_size=0.2,
    random_state=42,
    stratify=[p.parent.name for p in imagepaths],
)
len(train_img), len(test_img)

In [None]:
# Datasets read each image and its cached OCR JSON lazily in __getitem__.
train_dataset = Document_classifi(img_paths=train_img, processor=processor)
test_dataset = Document_classifi(img_paths=test_img, processor=processor)

In [None]:
# Sanity-check the first training sample: bbox should be (512, 4) and the label a scalar.
if len(train_dataset) > 0:
    sample = train_dataset[0]
    print(sample['bbox'].shape)
    print(sample['labels'].shape)

In [None]:
# On Windows, DataLoader workers are spawned processes that cannot import classes defined
# interactively in a notebook (Document_classifi), so load in the main process there.
NUM_WORKERS = 0 if os.name == 'nt' else 2
train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=NUM_WORKERS)
test_data_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
class ModelModul(nn.Module):
    """Plain ``nn.Module`` wrapper around a LayoutLMv3 sequence-classification model.

    Args:
        n_classes: number of labels for the classification head.
    """

    def __init__(self, n_classes: int):
        super().__init__()
        checkpoint_name = "microsoft/layoutlmv3-base"
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(checkpoint_name, num_labels=n_classes)

    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None):
        """Pass every input by keyword to the wrapped model and return its output object."""
        model_kwargs = dict(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
        )
        return self.model(**model_kwargs)

# Training model
def _batch_to_device(batch, device):
    """Move one collated batch to ``device``; returns ``(model_inputs, int64 labels)``."""
    inputs = {
        name: batch[name].to(device)
        for name in ('input_ids', 'attention_mask', 'bbox', 'pixel_values')
    }
    # Cross-entropy needs int64 class indices whatever integer dtype the dataset produced.
    labels = batch['labels'].to(device).long()
    return inputs, labels


def train_model(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, n_classes, checkpoint_path):
    """Plain PyTorch fine-tuning loop with per-epoch validation and best-checkpoint saving.

    Args:
        model: module whose forward accepts ``input_ids``, ``attention_mask``, ``bbox``,
            ``pixel_values`` and optional ``labels`` by keyword. It returns an object
            with ``.logits``, plus ``.loss`` when labels are given (e.g. ``ModelModul``).
        train_loader: DataLoader yielding dicts with those keys plus ``labels``.
        val_loader: same format; used for validation accuracy after every epoch.
        loss_fn: criterion applied as ``loss_fn(logits, labels)``. If ``None``, the
            model's own loss computed from ``labels`` is used instead.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        n_classes: number of classes, for the accuracy metrics.
        checkpoint_path: file that ``model.state_dict()`` is written to whenever
            validation accuracy strictly improves.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    train_accuracy = Accuracy(task="multiclass", num_classes=n_classes).to(device)
    val_accuracy = Accuracy(task="multiclass", num_classes=n_classes).to(device)

    best_val_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        train_accuracy.reset()

        for batch in tqdm(train_loader, desc=f'epoch {epoch}'):
            inputs, labels = _batch_to_device(batch, device)
            optimizer.zero_grad()

            if loss_fn is None:
                outputs = model(**inputs, labels=labels)
                loss = outputs.loss
            else:
                # The caller-supplied criterion is applied to the logits.
                outputs = model(**inputs)
                loss = loss_fn(outputs.logits, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_accuracy.update(torch.argmax(outputs.logits, dim=1), labels)

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        epoch_loss = total_loss / max(len(train_loader), 1)
        epoch_train_acc = train_accuracy.compute().item()

        # Validation step
        model.eval()
        val_accuracy.reset()
        with torch.no_grad():
            for batch in val_loader:
                inputs, labels = _batch_to_device(batch, device)
                outputs = model(**inputs)
                val_accuracy.update(torch.argmax(outputs.logits, dim=1), labels)

        epoch_val_acc = val_accuracy.compute().item()

        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f'Checkpoint saved at epoch {epoch} with validation accuracy: {epoch_val_acc:.4f}')

        print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_train_acc:.4f}, Validation Accuracy: {epoch_val_acc:.4f}')

    print('Training complete')
    print(f'Best validation accuracy: {best_val_acc:.4f}')

# Hyper-parameters for the plain PyTorch training run.
NUM_EPOCHS = 5
LEARNING_RATE = 1e-5
checkpoint_path = "best_model.pth"
n_classes = len(classes)

model = ModelModul(n_classes=n_classes)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

train_model(
    model,
    train_data_loader,
    test_data_loader,
    loss_fn,
    optimizer,
    num_epochs=NUM_EPOCHS,
    n_classes=n_classes,
    checkpoint_path=checkpoint_path,
)



In [None]:
class ModelModule(pl.LightningModule):
    """LightningModule that fine-tunes LayoutLMv3 for document classification.

    Logs ``train_loss``/``train_acc`` during training and ``val_loss``/``val_acc`` during
    validation. ``val_loss`` is the metric the notebook's ModelCheckpoint monitors.

    Args:
        n_classes: number of document classes (size of the classification head).
    """

    def __init__(self, n_classes: int):
        super().__init__()
        # Stores n_classes in checkpoints so ModelModule.load_from_checkpoint(path) can rebuild the module.
        self.save_hyperparameters()
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=n_classes
        )
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None):
        """Run the wrapped model; its output has ``.logits`` (and ``.loss`` when labels are given)."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            pixel_values=pixel_values,
            labels=labels
        )

    def _shared_step(self, batch):
        """Forward one batch with its labels; returns ``(loss, logits, labels)``."""
        # Cross-entropy needs int64 class indices; the dataset may yield a narrower int dtype.
        labels = batch["labels"].long()
        outputs = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["bbox"],
            batch["pixel_values"],
            labels
        )
        return outputs.loss, outputs.logits, labels

    def training_step(self, batch, batch_idx):
        # Progress is shown by Lightning's progress bar; no per-batch print (it floods the output).
        loss, logits, labels = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_accuracy(logits, labels)
        self.log("train_acc", self.train_accuracy, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, labels = self._shared_step(batch)
        self.log("val_loss", loss)
        self.val_accuracy(logits, labels)
        self.log("val_acc", self.val_accuracy, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.00001)


In [None]:
# Lightning variant of the classifier, sized to the number of dataset classes.
model_module = ModelModule(n_classes=len(classes))

In [None]:
# Inline TensorBoard pointed at ./lightning_logs (presumably where the Trainer's default logger writes).
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
# Keep the two checkpoints with the lowest validation loss, plus the most recent one.
model_chechpoint = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=2,
    save_last=True,
    filename="{epoch}-{step}-{val_loss:.4f}",
)

In [None]:
# Five epochs, saving checkpoints via the callback defined above.
trainer = pl.Trainer(callbacks=[model_chechpoint], max_epochs=5)

In [None]:
# Fine-tune with Lightning, validating on the held-out split each epoch.
trainer.fit(model_module, train_dataloaders=train_data_loader, val_dataloaders=test_data_loader)

In [None]:

def predict_document_class(im_path, model, processor, class_names=None):
    """Predict the document class of a single image from its cached OCR JSON.

    Args:
        im_path: ``pathlib.Path`` of the image; ``im_path.with_suffix('.json')`` must hold
            the OCR records (``{"word": ..., "bbox": [l, t, r, b]}`` in pixels).
        model: a ``LayoutLMv3ForSequenceClassification`` or a wrapper that exposes it as
            ``.model`` (e.g. ``ModelModul``) and accepts the same keyword arguments.
            It is moved to the available device and switched to eval mode.
        processor: LayoutLMv3 processor built with ``apply_ocr=False``.
        class_names: optional list mapping class index -> name. When omitted, the model
            config's ``id2label`` is used, which holds generic ``LABEL_<i>`` names unless set.

    Returns:
        The predicted class name.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # Disable dropout so the prediction is deterministic.
    model.eval()

    with Image.open(im_path) as raw_image:
        image = raw_image.convert('RGB')
    width, height = image.size
    width_scale = 1000 / width
    height_scale = 1000 / height

    with im_path.with_suffix('.json').open('r') as f:
        ocr_result = json.load(f)

    words = [row['word'] for row in ocr_result]
    boxes = [scale_bounding_box(row['bbox'], width_scale, height_scale) for row in ocr_result]

    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt')

    with torch.inference_mode():
        output = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device),
            bbox=encoding['bbox'].to(device),
            pixel_values=encoding['pixel_values'].to(device)
        )

    predicted_index = output.logits.argmax(dim=-1).item()
    if class_names is not None:
        return class_names[predicted_index]

    # Wrappers such as ModelModul have no `.config`; theirs lives on the inner `.model`.
    config = getattr(model, 'config', None)
    if config is None:
        config = model.model.config
    return config.id2label[predicted_index]

    

In [None]:
#Load the best model
model.load_state_dict(torch.load("best_model.pth"))

im_path=''
print('Actual Label: ',im_path.parent.name)
print('Predicted Label: ',predict_document_class(im_path,model,processor))