In [None]:
import os
import pickle
import warnings
from collections import Counter
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import pytesseract
import torch
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import BertForSequenceClassification, BertTokenizer

# Run from the parent of the notebook's starting directory (presumably the project root)
# so relative paths such as DATA resolve. Remember where the kernel started so that
# re-running this cell does not keep walking further up the directory tree.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [None]:
# Notebook configuration.
# Hugging Face checkpoint shared by the tokenizer and the classifier.
MODEL = "bert-base-uncased"
# Dataset root (relative to the working directory), one sub-folder per document class.
DATA = "data/"

In [None]:
def load_images_from_directory(
        directory: str,
        class_folders: Optional[List[str]] = None,
    ) -> Tuple:
    """Load every image under ``directory``, where each sub-directory is one class.

    Args:
        directory: Root directory containing one sub-directory per document class.
        class_folders: Class sub-directory names to load, in label order. If None,
            every non-hidden sub-directory of ``directory`` is used, sorted
            alphabetically so the class-to-label mapping is deterministic.

    Returns:
        Tuple: ``(images, image_paths, labels, class_to_label)``: grayscale PIL
        images, their file paths, the integer label of each image, and the
        mapping from class folder name to integer label.
    """
    if class_folders is None:
        # os.listdir order is arbitrary; sort so labels are reproducible across machines.
        class_folders = sorted(
            entry for entry in os.listdir(directory)
            if not entry.startswith(".") and os.path.isdir(os.path.join(directory, entry))
        )
    class_to_label = {class_folder: index for index, class_folder in enumerate(class_folders)}

    images = []
    image_paths = []
    labels = []

    for class_folder in class_folders:
        class_path = os.path.join(directory, class_folder)
        # Sorted so the sample order (and therefore the seeded split) is reproducible.
        for image_filename in sorted(os.listdir(class_path)):
            image_path = os.path.join(class_path, image_filename)
            # Skip hidden files (e.g. macOS ".DS_Store") and anything that is not a file.
            if image_filename.startswith(".") or not os.path.isfile(image_path):
                continue
            # "L" mode converts to grayscale; convert() returns an independent image,
            # so the source file can be closed right away.
            with Image.open(image_path) as image:
                images.append(image.convert("L"))
            image_paths.append(image_path)
            labels.append(class_to_label[class_folder])
    return images, image_paths, labels, class_to_label

In [None]:
def create_train_val_test_split(
        image_paths: List,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        attention_masks: torch.Tensor,
        val_size: float = 0.2,
        test_size: float = 0.2,
        random_state: int = 42,
    ) -> Tuple:
    """Split image paths, input ids, labels and attention masks into train/val/test sets.

    Both splits are stratified on the labels.

    Args:
        image_paths: Image paths.
        input_ids: The input ids.
        labels: The labels (also used for stratification).
        attention_masks: The attention masks.
        val_size: Fraction of the WHOLE dataset to put in the validation split.
        test_size: Fraction of the whole dataset to put in the test split.
        random_state: Seed passed to both splits, for reproducibility.

    Returns:
        Tuple: The train, val, and test data, each in the form
        ``(image_paths, input_ids, labels, attention_masks)``.

    Raises:
        ValueError: If ``val_size`` or ``test_size`` is not in (0, 1) or their sum is >= 1.
    """
    if not (0 < test_size < 1 and 0 < val_size < 1 and val_size + test_size < 1):
        raise ValueError(
            f"val_size ({val_size}) and test_size ({test_size}) must each be in (0, 1) "
            "and sum to less than 1."
        )

    # NOTE(review): train_test_split is never imported in this notebook; presumably
    # sklearn.model_selection.train_test_split. Add that import.

    # The validation split is carved out of the remaining train+val portion, so
    # rescale it to keep it at `val_size` of the whole dataset.
    relative_val_size = val_size / (1 - test_size)

    train_val_paths, test_paths, train_val_inputs, test_inputs, train_val_masks, test_masks, train_val_labels, test_labels = train_test_split(
        image_paths, input_ids, attention_masks, labels,
        test_size=test_size, stratify=labels, random_state=random_state
    )
    train_paths, val_paths, train_inputs, val_inputs, train_masks, val_masks, train_labels, val_labels = train_test_split(
        train_val_paths, train_val_inputs, train_val_masks, train_val_labels,
        test_size=relative_val_size, stratify=train_val_labels, random_state=random_state
    )

    train_data = (train_paths, train_inputs, train_labels, train_masks)
    val_data = (val_paths, val_inputs, val_labels, val_masks)
    test_data = (test_paths, test_inputs, test_labels, test_masks)

    return train_data, val_data, test_data


In [None]:
def plot_label_distribution(
        all_labels: dict,
        train_labels: torch.Tensor,
        val_labels: torch.Tensor,
        test_labels: torch.Tensor
    ) -> None:
    """Plot a stacked bar chart of label counts for the full dataset and each split.

    Args:
        all_labels: Mapping label -> count over the whole dataset.
        train_labels: Train labels (tensor or array; anything with ``.tolist()``).
        val_labels: Validation labels.
        test_labels: Test labels.
    """
    # Counter counts each split in a single pass (list.count per element was O(n^2)).
    # .tolist() yields plain Python ints so keys line up with those in all_labels.
    split_counts = [
        dict(Counter(split_labels.tolist()))
        for split_labels in (train_labels, val_labels, test_labels)
    ]
    df_count = pd.DataFrame(
        [all_labels, *split_counts],
        index=["Total", "Train", "Validation", "Test"],
    )
    # NOTE(review): px is never imported in this notebook; presumably plotly.express.
    fig = px.bar(
        df_count,
        barmode="stack",
        title="Label distribution per split",
        labels={"index": "Split", "value": "Documents", "variable": "Label"},
    )
    fig.show()

In [None]:
def extract_text_from_images(images: List["Image.Image"]) -> List[str]:
    """Extract text from images with Tesseract OCR, one single-line string per image.

    Args:
        images: PIL images (``load_images_from_directory`` yields grayscale PIL images).

    Returns:
        List: Extracted texts in input order. Every run of whitespace (newlines,
        tabs, form feeds, repeated spaces) is collapsed to a single space, and
        leading/trailing whitespace is removed.
    """
    return [" ".join(pytesseract.image_to_string(image).split()) for image in images]

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that pairs BERT token ids, attention masks and labels by index.

    Args:
        input_ids: Tokenized input ids, indexable by sample position.
        attention_masks: Attention masks aligned with ``input_ids``.
        labels: Labels aligned with ``input_ids``; their count is the dataset length.
    """
    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids, self.attention_masks, self.labels = input_ids, attention_masks, labels

    def __len__(self):
        # Dataset size is defined by the labels container.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_masks[idx],
            labels=self.labels[idx],
        )
        return sample
    

class BERTSequenceClassifier:
    """Wrapper bundling a BERT sequence-classification model with its optimizer.

    The Hugging Face model itself is exposed as ``self.model``. Pass that, not the
    wrapper, to helpers that expect a ``BertForSequenceClassification``.

    Args:
        pretrained_model_name: Name of the pretrained model.
        num_labels: Number of output classes.
        learning_rate: Learning rate for the AdamW optimizer.
    """
    def __init__(
            self,
            pretrained_model_name: str = "bert-base-uncased",
            num_labels: int = 4,
            learning_rate: float = 1e-5
        ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = BertForSequenceClassification.from_pretrained(pretrained_model_name, num_labels=num_labels)
        self.model.to(self.device)
        # PyTorch's AdamW, referenced explicitly (transformers' own AdamW is deprecated).
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

    def train(
            self,
            train_loader: DataLoader,
            val_loader: DataLoader,
            epochs: int = 10,
            log_dir: Optional[str] = None,
        ):
        """Train the model, printing the mean validation loss after each epoch.

        Mean train and validation losses per epoch are also written to TensorBoard.

        Args:
            train_loader: Train data loader.
            val_loader: Validation data loader.
            epochs: Number of epochs.
            log_dir: TensorBoard log directory. Defaults to ``./logs/<timestamp>``.
        """
        if log_dir is None:
            log_dir = f"./logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"

        # The context manager flushes and closes the event file even if training raises.
        with SummaryWriter(log_dir=log_dir) as writer:
            for epoch in range(epochs):
                self.model.train()
                train_loss = sum(self._train_batch(batch) for batch in train_loader)
                # max(..., 1) guards against ZeroDivisionError on an empty loader.
                mean_train_loss = train_loss / max(len(train_loader), 1)

                val_loss = self.evaluate(val_loader)
                mean_val_loss = val_loss / max(len(val_loader), 1)
                print(f"Epoch {epoch + 1}/{epochs}, Validation Loss: {mean_val_loss}")

                # Log per-epoch means so curves are comparable regardless of loader size.
                writer.add_scalar("Loss/train", mean_train_loss, epoch)
                writer.add_scalar("Loss/validation", mean_val_loss, epoch)

    def _train_batch(self, batch: dict) -> float:
        """Run one optimisation step on ``batch`` and return its loss as a Python float."""
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def evaluate(self, val_loader: DataLoader) -> float:
        """Evaluate model on validation set.

        Leaves the model in eval mode.

        Args:
            val_loader: Validation data loader.

        Returns:
            float: Sum of the per-batch losses (divide by ``len(val_loader)`` for the mean).
        """
        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)
                outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
                val_loss += outputs.loss.item()

        return val_loss

    def save_model(self, path: str = "bert_document_classification_model"):
        """Save the underlying Hugging Face model to disk with ``save_pretrained``.

        Args:
            path: Directory to save the model to.
        """
        self.model.save_pretrained(path)
        print(f"Model saved to {path}")

In [None]:
def predict_document_class(
        model: BertForSequenceClassification,
        tokenizer: BertTokenizer,
        text: str
    ) -> int:
    """Predict the class index of a single document's text.

    The model is switched to eval mode, and the tokenized inputs are moved to the
    device holding the model's parameters. This works whether the model is on
    CPU or GPU.

    Args:
        model: BERT model (the Hugging Face model, not a wrapper around it).
        tokenizer: BERT tokenizer.
        text: Text to classify.

    Returns:
        int: Index of the highest-scoring class.
    """
    model.eval()
    device = next(model.parameters()).device
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(inputs.input_ids, attention_mask=inputs.attention_mask).logits

    return torch.argmax(logits, dim=1).item()

In [None]:
def evaluate_model(
        model: BertForSequenceClassification,
        test_loader: DataLoader,
        display_labels: Optional[List[str]] = None,
    ) -> Tuple:
    """Evaluate a classifier on a test set: print accuracy and plot a confusion matrix.

    Args:
        model: BERT model (the Hugging Face model, not a wrapper around it).
        test_loader: Test dataloader yielding dicts with "input_ids",
            "attention_mask" and "labels" tensors.
        display_labels: Optional class names, where ``display_labels[i]`` names
            label ``i``. When given, they label the confusion-matrix axes and
            fix the matrix to ``len(display_labels)`` rows and columns.

    Returns:
        Tuple: ``(all_labels, all_predictions)`` as lists of ints, in loader order.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # NOTE(review): confusion_matrix and ConfusionMatrixDisplay are never imported in
    # this notebook; presumably from sklearn.metrics. Add that import.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_labels: List[int] = []
    all_predictions: List[int] = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits
            predictions = torch.argmax(logits, dim=1)

            all_labels.extend(batch["labels"].tolist())
            all_predictions.extend(predictions.cpu().tolist())

    if not all_labels:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy.")

    accuracy = sum(
        label == prediction for label, prediction in zip(all_labels, all_predictions)
    ) / len(all_labels)
    print(f"Accuracy: {accuracy:.2f}")

    # With class names given, pin the label set so the matrix keeps one row/column per
    # class even if some class never appears in this split.
    label_ids = list(range(len(display_labels))) if display_labels is not None else None
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
    disp.plot()

    return all_labels, all_predictions

In [None]:
def get_wrongly_classified_docs(paths: List[str], labels: List[int], pred: List[int], class_to_label: dict) -> List[str]:
    """Print and return the paths of documents whose prediction differs from the true label.

    Args:
        paths: Image paths, aligned with ``labels`` and ``pred``.
        labels: True labels.
        pred: Predicted labels.
        class_to_label: Class name to label mapping.

    Returns:
        List: Paths of wrongly classified images, in input order.

    Raises:
        ValueError: If ``paths``, ``labels`` and ``pred`` differ in length.
    """
    if not len(paths) == len(labels) == len(pred):
        raise ValueError(
            f"paths ({len(paths)}), labels ({len(labels)}) and pred ({len(pred)}) "
            "must have the same length."
        )

    # Invert the mapping once. setdefault keeps the first class name if a label repeats.
    label_to_class = {}
    for class_name, label_id in class_to_label.items():
        label_to_class.setdefault(label_id, class_name)

    wrongly_classified_docs = []
    for path, label, prediction in zip(paths, labels, pred):
        if label != prediction:
            # Fall back to the raw id rather than crashing on a label missing from the mapping.
            label_name = label_to_class.get(label, label)
            pred_name = label_to_class.get(prediction, prediction)
            print(f"{path} is <{label_name}>, but was predicted <{pred_name}>")
            wrongly_classified_docs.append(path)

    return wrongly_classified_docs


In [None]:
# Load the labelled images from the DATA directory, OCR them, and tokenize the text for BERT.
images, image_paths, labels, class_to_label = load_images_from_directory(DATA)
assert images, f"No images found under {DATA!r}"
# Per-label totals over the whole dataset (Counter keeps first-occurrence order).
label_count = dict(Counter(labels))
extracted_texts = extract_text_from_images(images)
assert len(extracted_texts) == len(images)

tokenizer = BertTokenizer.from_pretrained(MODEL)
tokenized_texts = tokenizer(extracted_texts, padding=True, truncation=True, return_tensors="pt")

train_data, val_data, test_data = create_train_val_test_split(
    image_paths,
    tokenized_texts.input_ids,
    torch.tensor(labels),
    tokenized_texts.attention_mask,
    val_size=0.2,
    test_size=0.2
)
train_paths, train_inputs, train_labels, train_masks = train_data
val_paths, val_inputs, val_labels, val_masks = val_data
test_paths, test_inputs, test_labels, test_masks = test_data
assert len(train_paths) + len(val_paths) + len(test_paths) == len(image_paths)

plot_label_distribution(label_count, train_labels, val_labels, test_labels)

In [None]:
batch_size = 16
SEED = 42

# Seed torch (classifier-head initialisation) and the training shuffle order so runs
# are more reproducible.
torch.manual_seed(SEED)
train_generator = torch.Generator().manual_seed(SEED)

train_dataset = CustomDataset(train_inputs, train_masks, train_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=train_generator)

# Evaluation loaders keep a fixed order. The test order must match test_paths for the
# misclassification report further down.
val_dataset = CustomDataset(val_inputs, val_masks, val_labels)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

test_dataset = CustomDataset(test_inputs, test_masks, test_labels)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Size the classification head from the classes actually loaded.
model = BERTSequenceClassifier(pretrained_model_name=MODEL, num_labels=len(class_to_label))

In [None]:
# Fine-tune the wrapped BERT model; the validation loss is printed after every epoch.
model.train(train_loader=train_loader, val_loader=val_loader, epochs=5)

In [None]:
# Classify a single document end-to-end: OCR -> tokenize -> BERT.
# Point this at any document image; the first held-out test image is used by default.
sample_image_path = test_paths[0]
with Image.open(sample_image_path) as sample_image:
    # Same preprocessing as training: grayscale, then the shared OCR/text-cleanup helper.
    img_text = extract_text_from_images([sample_image.convert("L")])[0]

# The wrapper has no eval()/forward; use the Hugging Face model it exposes as .model.
hf_model = model.model
hf_model.eval()
predicted_class = predict_document_class(hf_model, tokenizer, img_text)
label_to_class = {label: class_name for class_name, label in class_to_label.items()}
predicted_label = label_to_class[predicted_class]
print(f"Predicted Label: {predicted_label}")

In [None]:
# evaluate_model calls .to()/.eval() on its model, which the wrapper lacks, so pass the
# underlying Hugging Face model. New result names avoid overwriting the `test_labels`
# tensor that the label-distribution cell calls .numpy()/.tolist() on.
test_true_labels, test_pred = evaluate_model(model=model.model, test_loader=test_loader)
# test_loader is not shuffled, so the predictions line up with test_paths.
wrongly_classified_docs = get_wrongly_classified_docs(test_paths, test_true_labels, test_pred, class_to_label)