In [8]:
import os
import shutil
import pandas as pd
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import BertTokenizer, BertModel
import pytesseract

In [10]:
pytesseract.pytesseract.tesseract_cmd = r"C:\\Program Files\\Tesseract-OCR\\tesseract.exe"

In [14]:
# Sample a smaller dataset from full data
FULL_CSV = "data/labels.csv"
FULL_IMG_DIR = "data/images"
SAMPLE_CSV = "data/labels_sample.csv"
SAMPLE_IMG_DIR = "data/images_sample"

os.makedirs(SAMPLE_IMG_DIR, exist_ok=True)

original_df = pd.read_csv(FULL_CSV)
sampled_data = []
label_counts = {}
SAMPLE_SIZE = 10  # number of images per class

for row in original_df.itertuples(index=False):
    count = label_counts.get(row.label, 0)
    if count < SAMPLE_SIZE:
        src = os.path.join(FULL_IMG_DIR, row.image_name)
        dst = os.path.join(SAMPLE_IMG_DIR, row.image_name)
        if os.path.exists(src):
            shutil.copy(src, dst)
            sampled_data.append({"image_name": row.image_name, "label": row.label})
            label_counts[row.label] = count + 1

sample_df = pd.DataFrame(sampled_data)
sample_df.to_csv(SAMPLE_CSV, index=False)
print(f"Sampled {len(sample_df)} images across {len(set(sample_df.label))} classes.")


Sampled 160 images across 16 classes.


In [16]:
# Apply OCR and Split
ocr_data = []
for row in sample_df.itertuples(index=False):
    img_path = os.path.join(SAMPLE_IMG_DIR, row.image_name)
    try:
        img = Image.open(img_path).convert("RGB")
        text = pytesseract.image_to_string(img).strip()
        ocr_data.append({"image_name": row.image_name, "ocr_text": text, "label": row.label})
    except Exception as e:
        print(f"Failed OCR on {row.image_name}: {e}")

ocr_df = pd.DataFrame(ocr_data)
ocr_df.dropna(subset=["ocr_text"], inplace=True)  # clean
ocr_df.to_csv("data/all_sample_data.csv", index=False)

train_df, test_df = train_test_split(ocr_df, test_size=0.2, stratify=ocr_df["label"], random_state=42)
train_df.to_csv("data/train_sample.csv", index=False)
test_df.to_csv("data/test_sample.csv", index=False)


In [24]:
# Dataset and Model
import re

# Function to clean OCR text
def clean_text(text):
    text = str(text)
    text = re.sub(r"[^a-zA-Z0-9.,!? ]", " ", text)  # Keep basic punctuation and alphanumerics
    return text.strip()
    
class DocumentDataset(Dataset):
    def __init__(self, dataframe, img_dir, tokenizer, transform):
        self.df = dataframe
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.transform = transform

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(os.path.join(self.img_dir, row['image_name'])).convert("RGB")
        img = self.transform(img)

        text = str(row['ocr_text']) if pd.notna(row['ocr_text']) else ""
        tokens = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        return {
            'image': img,
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'label': torch.tensor(row['label'])
        }

    def __len__(self):
        return len(self.df)

class NdLinear(nn.Module):
    def __init__(self, in_shape, out_features):
        super().__init__()
        self.linear = nn.Linear(np.prod(in_shape), out_features)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.linear(x)

class MultiModalClassifier(nn.Module):
    def __init__(self, ndlinear_dim=256, num_classes=16):
        super().__init__()
        self.image_encoder = models.resnet18(pretrained=True)
        self.image_encoder.fc = nn.Identity()
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.ndlinear = NdLinear((2, 768), ndlinear_dim)
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(ndlinear_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, image, input_ids, attention_mask):
        img_feat = self.image_encoder(image)
        txt_feat = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        if img_feat.shape[1] != txt_feat.shape[1]:
            img_feat = F.pad(img_feat, (0, txt_feat.shape[1] - img_feat.shape[1]))
        combined = torch.stack([txt_feat, img_feat], dim=1)
        fused = self.ndlinear(combined)
        return self.classifier(fused)

In [26]:
# Train the Model
SAMPLE_IMG_DIR = "data/images_sample"
train_data = pd.read_csv("data/train_sample.csv")
test_data = pd.read_csv("data/test_sample.csv")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

train_dataset = DocumentDataset(train_data, SAMPLE_IMG_DIR, tokenizer, transform)
test_dataset = DocumentDataset(test_data, SAMPLE_IMG_DIR, tokenizer, transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4)

model = MultiModalClassifier(num_classes=len(sample_df.label.unique())).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

model.train()
for epoch in range(10):
    total_loss, correct = 0, 0
    for batch in train_loader:
        image = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(image, input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}, Accuracy = {correct/len(train_dataset):.4f}")




Epoch 1: Loss = 2.7723, Accuracy = 0.0703
Epoch 2: Loss = 2.7149, Accuracy = 0.1875
Epoch 3: Loss = 2.6640, Accuracy = 0.2578
Epoch 4: Loss = 2.5794, Accuracy = 0.4062
Epoch 5: Loss = 2.4730, Accuracy = 0.5000
Epoch 6: Loss = 2.3705, Accuracy = 0.4922
Epoch 7: Loss = 2.2346, Accuracy = 0.6172
Epoch 8: Loss = 2.1060, Accuracy = 0.6875
Epoch 9: Loss = 1.9486, Accuracy = 0.7656
Epoch 10: Loss = 1.8055, Accuracy = 0.7891


In [28]:
# Evaluate the Model
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        image = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(image, input_ids, attention_mask)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

print(classification_report(all_labels, all_preds))

              precision    recall  f1-score   support

           0       0.40      1.00      0.57         2
           1       0.50      1.00      0.67         2
           2       1.00      1.00      1.00         2
           3       0.00      0.00      0.00         2
           4       0.00      0.00      0.00         2
           5       1.00      0.50      0.67         2
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00         2
           8       0.33      1.00      0.50         2
           9       0.00      0.00      0.00         2
          10       0.00      0.00      0.00         2
          11       0.00      0.00      0.00         2
          12       0.33      0.50      0.40         2
          13       0.33      1.00      0.50         2
          14       0.00      0.00      0.00         2
          15       0.00      0.00      0.00         2

    accuracy                           0.38        32
   macro avg       0.24   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
