In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Make the parent directory importable so the project packages imported below
# (`models`, `utils`) can be resolved. The entry is relative to the kernel's
# working directory, so this assumes the notebook is run from its own folder.
# The membership check keeps re-runs of this cell from stacking duplicate
# entries onto sys.path.
if '..' not in sys.path:
    sys.path.append('..')

# Import project modules
from models.classifier import DocumentClassifier
from utils.preprocess import preprocess_for_classification


In [None]:
# Document categories; one training sub-folder is created per category.
categories = ['invoice', 'resume', 'receipt']

# exist_ok=True makes this cell safe to re-run: folders that are already
# present are left untouched, and missing parent folders are created too.
for category in categories:
    category_dir = f'../data/train/{category}'
    os.makedirs(category_dir, exist_ok=True)


In [None]:
# Initialize document classifier
# Builds a DocumentClassifier (project class imported from models.classifier)
# for the category list defined above. The later cells that train, fine-tune,
# save and predict all go through this single `classifier` object.
classifier = DocumentClassifier(categories=categories)


In [None]:
# Initial training run. The hyperparameters are gathered in one dict so they
# are easy to spot and tweak; they are forwarded unchanged as keyword
# arguments.
train_settings = {
    'train_dir': '../data/train',
    'epochs': 10,
    'batch_size': 32,
}

# The returned object is kept because the plotting cell reads its `.history`
# mapping.
history = classifier.train(**train_settings)


In [None]:
# Visualize the training run: accuracy in the left panel, loss in the right.
# Each series is read from the `history.history` mapping and plotted against
# its index, which the x-axis labels as the epoch.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (key in history.history, panel title, y-axis label)
panels = [
    ('accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'Model Loss', 'Loss'),
]

for ax, (metric_key, panel_title, y_label) in zip(axes, panels):
    ax.plot(history.history[metric_key])
    ax.set_title(panel_title)
    ax.set_ylabel(y_label)
    ax.set_xlabel('Epoch')

fig.tight_layout()
plt.show()


In [None]:
# Fine-tuning pass over the same training directory. Compared with the
# initial training call it uses fewer epochs, a smaller batch size, and an
# explicit learning rate.
fine_tune_settings = {
    'train_dir': '../data/train',
    'epochs': 5,
    'batch_size': 16,
    'learning_rate': 1e-4,
}

# Kept separately from the initial training result so both runs can be
# plotted on their own.
fine_tune_history = classifier.fine_tune(**fine_tune_settings)


In [None]:
# Visualize the fine-tuning run: accuracy in the left panel, loss in the
# right. Each series is read from `fine_tune_history.history` and plotted
# against its index, which the x-axis labels as the epoch.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (key in fine_tune_history.history, panel title, y-axis label)
fine_tune_panels = [
    ('accuracy', 'Model Accuracy (Fine-tuning)', 'Accuracy'),
    ('loss', 'Model Loss (Fine-tuning)', 'Loss'),
]

for ax, (metric_key, panel_title, y_label) in zip(axes, fine_tune_panels):
    ax.plot(fine_tune_history.history[metric_key])
    ax.set_title(panel_title)
    ax.set_ylabel(y_label)
    ax.set_xlabel('Epoch')

fig.tight_layout()
plt.show()


In [None]:
# Save model
# Hands the target path ../models/document_classifier.h5 to the project's
# save_model method. The .h5 extension suggests an HDF5 file —
# NOTE(review): confirm against DocumentClassifier.save_model.
# This cell does not create the ../models directory; presumably it must
# already exist — verify before running on a fresh checkout.
classifier.save_model('../models/document_classifier.h5')


In [None]:
# Smoke-test the classifier on a single example document.
# Replace with path to a sample image
sample_image_path = '../data/examples/sample_invoice.jpg'

# The sample is optional: when the file is missing, a hint is printed instead
# of raising.
if os.path.exists(sample_image_path):
    # Image.open is lazy and keeps the underlying file open. The context
    # manager guarantees the handle is released even if displaying or
    # predicting raises. `image` is only used inside this block.
    with Image.open(sample_image_path) as image:
        # Display image
        plt.figure(figsize=(8, 8))
        plt.imshow(image)
        plt.axis('off')
        plt.title('Sample Document')
        plt.show()

        # Predict document type. The PIL image is passed to predict() as-is.
        # NOTE(review): preprocess_for_classification is imported at the top
        # of the notebook but never called — presumably predict() does its
        # own preprocessing; confirm against DocumentClassifier.predict.
        doc_type, confidence, all_scores = classifier.predict(image)

    print(f"Predicted document type: {doc_type}")
    print(f"Confidence: {confidence:.2f}")
    print("All scores:")
    # List every category with its score, highest score first.
    ranked_scores = sorted(all_scores.items(), key=lambda item: item[1], reverse=True)
    for category, score in ranked_scores:
        print(f"  {category}: {score:.2f}")
else:
    print(f"Sample image not found at '{sample_image_path}'. Please add a sample image for testing.")
