# 80 · Vision Forgery Detection
        
Train and evaluate a simple vision classifier for document/image forgery detection.
- Point to the Intel image classification dataset (`seg_train` split) downloaded via KaggleHub.
- Inspect the classes and a few sample images.
- Fine-tune a small backbone (a simple CNN or ResNet) and plot training/validation loss and accuracy.


In [None]:
from pathlib import Path
import sys, random
import matplotlib.pyplot as plt
from PIL import Image

# Locate the repository root relative to the current working directory
# (assumes the notebook is launched from a direct sub-folder of the repo, e.g.
# `notebooks/` — TODO confirm) and make `<root>/src` importable so the `uais`
# package resolves without being installed.
project_root = Path('..').resolve()
src_path = project_root.joinpath('src')
if str(src_path) not in sys.path:  # guard keeps sys.path free of duplicates on cell re-runs
    sys.path.append(str(src_path))

from uais.vision.train_vision_model import VisionConfig, run_vision_experiment

# Default dataset: the training split of the Intel image classification dataset,
# in the folder layout KaggleHub produces for puneet6060/intel-image-classification v2.
data_dir = project_root.joinpath(
    'data', 'raw', 'vision', 'datasets',
    'puneet6060', 'intel-image-classification', 'versions', '2',
    'seg_train',
)
print('Data dir:', data_dir, '| exists:', data_dir.exists())


In [None]:
# Inspect classes and show sample images
if not data_dir.exists():
    raise FileNotFoundError(f'Missing dataset directory: {data_dir}')

# Every immediate sub-directory of data_dir is treated as one class.
class_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir()])
if not class_dirs:
    # Without this check plt.subplots(1, 0) below fails with an unhelpful error.
    raise FileNotFoundError(f'No class sub-directories found in: {data_dir}')
print('Classes:', [d.name for d in class_dirs])

IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}


def _list_images(folder):
    """Return the image files directly inside *folder*, sorted by path.

    Suffix matching is case-insensitive, so ``.JPG`` / ``.PNG`` files are
    included; sub-folders and non-image files (e.g. hidden metadata) are not.
    """
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


# List each class folder once and reuse the result for both the counts and the
# sample picks, so the printed counts describe exactly the files we can sample.
images_by_class = {d.name: _list_images(d) for d in class_dirs}
counts = {name: len(paths) for name, paths in images_by_class.items()}
print('Counts per class:', counts)

# squeeze=False always returns a 2-D array of Axes, so one class needs no special case.
n_show = min(3, len(class_dirs))
fig, axes = plt.subplots(1, n_show, figsize=(9, 3), squeeze=False)
for ax, d in zip(axes[0], class_dirs):
    ax.axis('off')  # hide ticks/frame even when the class has nothing to show
    imgs = images_by_class[d.name]
    if not imgs:
        ax.set_title(f'{d.name} (no images)')
        continue
    # Open inside `with` so the file handle is released; convert() returns an
    # independent in-memory copy that stays valid after the file is closed.
    with Image.open(random.choice(imgs)) as im:
        ax.imshow(im.convert('RGB'))
    ax.set_title(d.name)
plt.tight_layout(); plt.show()


In [None]:
# Train/evaluate vision model
config = VisionConfig(
    dataset_dir=data_dir,
    image_size=224,
    batch_size=32,
    epochs=3,
    backbone='resnet18',
)
metrics = run_vision_experiment(config)
# Print everything except the per-epoch history as a compact summary.
print('Validation metrics:', {k: v for k, v in metrics.items() if k != 'history'})

# NOTE(review): assumes `history` maps metric names ('loss', 'val_loss',
# 'accuracy', 'val_accuracy') to sequences with one value per epoch — confirm
# against run_vision_experiment.
hist = metrics.get('history', {})
if hist:
    panels = [
        ('Loss', 'loss', 'val_loss'),
        ('Accuracy', 'accuracy', 'val_accuracy'),
    ]
    fig, ax = plt.subplots(1, len(panels), figsize=(10, 4))
    for panel_ax, (title, train_key, val_key) in zip(ax, panels):
        plotted = False
        for key, label in ((train_key, 'train'), (val_key, 'val')):
            series = hist.get(key)
            if series is None:
                continue
            series = list(series)
            if not series:
                continue
            # 1-based x values so the first point reads as epoch 1, not 0.
            panel_ax.plot(range(1, len(series) + 1), series, label=label)
            plotted = True
        panel_ax.set_title(title)
        panel_ax.set_xlabel('epoch')
        # Missing series are skipped rather than drawn as empty lines, so the
        # legend only lists curves that actually exist.
        if plotted:
            panel_ax.legend()
    plt.tight_layout(); plt.show()
