# 02. Baseline CNN Model

## Introduction
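This notebook implements a custom Convolutional Neural Network (CNN), built from scratch, as a supervised baseline for anomaly (defect) detection on the MVTec AD dataset. Because MVTec's `train` split contains only defect-free (`good`) images, the model is trained and validated on a split of one category's `test` set. It learns to classify each image as `good` or as one of that category's defect types.
It is fully self-contained and includes: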
1. Data Loading & Preprocessing
2. Model Architecture Definition
3. Training Loop
4. Evaluation

## Setup

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Seed TensorFlow's and NumPy's global random number generators for reproducibility.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

print(f"TensorFlow Version: {tf.__version__}")

## 1. Configuration & Data Loading
We define the hyperparameters and dataset paths, then load the images from the `../data/raw` directory (a path relative to the notebook's working directory).

In [None]:
# Hyperparameters
IMG_HEIGHT, IMG_WIDTH = 256, 256
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-3
DATA_DIR = "../data/raw"

# MVTec AD models are usually trained per category, so this demo uses a single
# category ('bottle') to keep it quick to run.
#
# Task: supervised multi-class classification over the sub-directories of the
# category's `test` split ('good' plus each defect type), NOT binary
# good-vs-defect. MVTec's `train` split only contains 'good' images, so
# TRAIN_DIR is kept for reference while the model is fit on a split of TEST_DIR.
TARGET_CATEGORY = 'bottle'
CATEGORY_DIR = os.path.join(DATA_DIR, TARGET_CATEGORY)
TRAIN_DIR = os.path.join(CATEGORY_DIR, 'train')
TEST_DIR = os.path.join(CATEGORY_DIR, 'test')

print(f"Training on category: {TARGET_CATEGORY}")

In [None]:
def load_data(category_path, img_size, batch_size, subset='train',
              validation_split=None, seed=123):
    """Load an image folder as a batched, integer-labelled ``tf.data.Dataset``.

    Each sub-directory of ``category_path`` (e.g. ``good``, ``broken_large``)
    becomes one class; Keras assigns label indices in alphanumeric order of
    the directory names and exposes them as ``.class_names``.

    MVTec's ``train`` split contains only ``good`` images, so a supervised
    classifier cannot be trained on it. This baseline therefore splits the
    ``test`` split (which contains every defect type) into train/validation.

    Args:
        category_path: Directory whose sub-directories are the classes.
        img_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per batch.
        subset: Which part to return when ``validation_split`` is set:
            ``'train'``/``'training'`` or ``'val'``/``'validation'``.
            Ignored when ``validation_split`` is ``None``.
        validation_split: Fraction of files reserved for validation, or
            ``None`` (default) to load every file.
        seed: Seed for Keras' file-level shuffle. Use the same seed for the
            training and validation calls so the two subsets are disjoint.

    Returns:
        The dataset returned by ``tf.keras.utils.image_dataset_from_directory``.

    Raises:
        ValueError: If ``validation_split`` is set and ``subset`` is unknown.
    """
    print(f"Loading data from {category_path}...")
    split_kwargs = {}
    if validation_split is not None:
        subset_aliases = {
            'train': 'training', 'training': 'training',
            'val': 'validation', 'validation': 'validation',
        }
        if subset not in subset_aliases:
            raise ValueError(
                f"subset must be one of {sorted(subset_aliases)}, got {subset!r}"
            )
        # Split at the *file* level (before batching) so the subsets never overlap.
        split_kwargs = {
            'validation_split': validation_split,
            'subset': subset_aliases[subset],
        }
    return tf.keras.utils.image_dataset_from_directory(
        category_path,
        seed=seed,
        image_size=img_size,
        batch_size=batch_size,
        label_mode='int',  # one integer label per class directory (good, broken_large, ...)
        **split_kwargs,
    )

# Fraction of the category's test images held out for validation.
VAL_SPLIT = 0.2

# The test set contains all classes, so split it for this supervised demo.
# Do NOT split the batched dataset with take()/skip(): the loader shuffles by
# default (and recent Keras versions reshuffle on every iteration), so the train
# and validation pipelines could draw overlapping images. Instead, use Keras'
# file-level split with the same seed for both subsets.
train_ds = load_data(TEST_DIR, (IMG_HEIGHT, IMG_WIDTH), BATCH_SIZE,
                     subset='train', validation_split=VAL_SPLIT)
val_ds = load_data(TEST_DIR, (IMG_HEIGHT, IMG_WIDTH), BATCH_SIZE,
                   subset='val', validation_split=VAL_SPLIT)

class_names = train_ds.class_names
assert class_names == val_ds.class_names, "train/val class lists differ"
print(f"Classes: {class_names}")

# Performance optimization: cache decoded images, reshuffle training data each epoch.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

## 2. Model Architecture
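We define a custom CNN from three convolutional blocks (Conv2D → Batch Normalization → ReLU → max-pooling). These are followed by global average pooling, an L2-regularized dense layer with Dropout, and a softmax classifier.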

In [None]:
def create_cnn_custom(input_shape, num_classes):
    """Build the baseline CNN classifier.

    Architecture: three convolutional stages with 32, 64 and 128 filters, each
    Conv2D(3x3, 'same') -> BatchNormalization -> ReLU -> MaxPooling2D. These are
    followed by global average pooling, a 128-unit L2-regularised dense layer
    with ReLU and 50% dropout, and a softmax output layer.

    Args:
        input_shape: Image shape ``(height, width, channels)``. Pixels are
            rescaled from [0, 255] to [0, 1] inside the model.
        num_classes: Number of output classes.

    Returns:
        An uncompiled ``Model`` named ``"cnn_custom"``.
    """
    image_input = layers.Input(shape=input_shape)
    features = layers.Rescaling(1. / 255)(image_input)

    # One conv stage per filter width; the layer order matches a hand-unrolled version.
    for n_filters in (32, 64, 128):
        features = layers.Conv2D(n_filters, 3, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.Activation('relu')(features)
        features = layers.MaxPooling2D()(features)

    # Classification head
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.Dense(128, kernel_regularizer=regularizers.l2(0.01))(features)
    features = layers.Activation('relu')(features)
    features = layers.Dropout(0.5)(features)

    class_probs = layers.Dense(num_classes, activation='softmax')(features)
    return models.Model(image_input, class_probs, name="cnn_custom")

model = create_cnn_custom((IMG_HEIGHT, IMG_WIDTH, 3), len(class_names))
model.summary()

## 3. Training
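We compile the model with the Adam optimizer and sparse categorical cross-entropy. We then train it with early stopping and learning-rate reduction on plateau.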

In [None]:
# Integer labels + softmax probabilities -> sparse categorical cross-entropy.
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Both callbacks use their default monitored quantity (documented as 'val_loss'):
# stop after 5 stagnant epochs and restore the best weights; multiply the LR by
# 0.2 after 3 stagnant epochs.
training_callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.2, patience=3),
]

history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=training_callbacks,
)

## 4. Evaluation
Visualize the training and validation accuracy and loss curves, then the confusion matrix and per-class report on the validation split.

In [None]:
# Plot History: per-epoch metrics recorded by model.fit for training and validation.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# 1-based epoch numbers, as in Keras's training log.
epochs_range = range(1, len(acc) + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(epochs_range, acc, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc, label='Validation Accuracy')
ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend(loc='lower right')

ax_loss.plot(epochs_range, loss, label='Training Loss')
ax_loss.plot(epochs_range, val_loss, label='Validation Loss')
ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Loss')
ax_loss.legend(loc='upper right')

fig.tight_layout()
plt.show()

In [None]:
# Confusion Matrix (and per-class report) on the held-out validation split.
y_true = []
y_pred = []

# Collect labels and predictions batch by batch so they stay aligned.
for images, labels in val_ds:
    preds = model.predict(images, verbose=0)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(preds, axis=1))

# Pin the label set: without `labels=`, a class absent from this small validation
# split (in both y_true and y_pred) would be dropped from the matrix. Its rows and
# columns would then no longer line up with the `class_names` tick labels.
class_ids = list(range(len(class_names)))
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

print(classification_report(y_true, y_pred, labels=class_ids,
                            target_names=class_names, zero_division=0))

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
plt.show()