# Model Training & Comparison (ANN vs CNN)

This notebook trains and compares simple **ANN** and **CNN** models on:

- **MNIST** (ANN & CNN)
- **CIFAR-10** (CNN)
- **Titanic** (ANN)

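Run the **Setup** and **Helper Functions** cells first. Sections A–C are independent of each other, so you can run only the experiments you want; the summary cell at the end reports whichever results exist.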

## Setup

In [None]:
import os, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist, cifar10


## Helper Functions

In [None]:
def plot_history(history, title_prefix=''):
    """Show training-vs-validation curves: one figure for accuracy, one for loss.

    Args:
        history: Object exposing a ``.history`` mapping from metric name to a
            per-epoch list of values (in this notebook, what ``model.fit``
            returns). Metrics absent from the mapping are drawn as empty series.
        title_prefix: Text placed in front of "Accuracy" / "Loss" in each title.

    Returns:
        None. Each figure is displayed with ``plt.show()``.
    """
    recorded = history.history

    # (train key, validation key, train label, validation label, display name);
    # one figure is drawn per row, in this order.
    panels = (
        ('accuracy', 'val_accuracy', 'train_acc', 'val_acc', 'Accuracy'),
        ('loss', 'val_loss', 'train_loss', 'val_loss', 'Loss'),
    )

    for train_key, val_key, train_label, val_label, display_name in panels:
        plt.figure()
        plt.plot(recorded.get(train_key, []), label=train_label)
        plt.plot(recorded.get(val_key, []), label=val_label)
        plt.title(f"{title_prefix} {display_name}")
        plt.xlabel('Epoch')
        plt.ylabel(display_name)
        plt.legend()
        plt.show()

def evaluate_model(model, X_test, y_test, is_logits=False):
    """Predict on a test set, print standard classification metrics, return accuracy.

    Args:
        model: Object with a ``predict(X)`` method returning an array of scores.
        X_test: Test features, passed to ``model.predict`` unchanged.
        y_test: True integer class labels for ``X_test``.
        is_logits: Set to True when the model emits raw logits instead of
            probabilities. This only matters for single-output (binary) models,
            where the decision boundary is 0.0 for logits and 0.5 for
            probabilities. For multi-class outputs argmax gives the same class
            either way.

    Returns:
        The accuracy computed by ``accuracy_score(y_test, y_pred)``.
    """
    preds = model.predict(X_test)

    if preds.ndim > 1 and preds.shape[-1] > 1:
        # Multi-class: pick the highest-scoring class along the same (last)
        # axis that the width check above inspects.
        y_pred = np.argmax(preds, axis=-1)
    else:
        # Binary: a sigmoid probability of 0.5 corresponds to a logit of 0.
        threshold = 0.0 if is_logits else 0.5
        y_pred = (preds.flatten() > threshold).astype(int)

    acc = accuracy_score(y_test, y_pred)
    print('Accuracy:', acc)
    print('Classification Report:')
    print(classification_report(y_test, y_pred))
    print('Confusion Matrix:')
    print(confusion_matrix(y_test, y_pred))
    return acc


## Section A: MNIST — ANN vs CNN

In [None]:
# Seed Python, NumPy and TensorFlow at the start of the section so weight
# initialisation, dropout and batch shuffling start from the same state
# regardless of which other sections were run before this one.
# (Some GPU kernels may still introduce small run-to-run differences.)
tf.keras.utils.set_random_seed(42)

# Load MNIST
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Convert pixels to float32 and scale by 1/255
X_train = X_train.astype('float32')/255.0
X_test = X_test.astype('float32')/255.0

# Labels are kept as integer class ids: both models below are compiled with
# sparse_categorical_crossentropy, so no one-hot encoding is performed.
num_classes = 10

# -------- ANN --------
# Fully connected baseline: flatten each 28x28 image and classify it with two
# hidden layers, with dropout after each hidden layer.
ann = models.Sequential([
    layers.Input(shape=(28,28)),
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(num_classes, activation='softmax')
])
ann.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
hist_ann = ann.fit(X_train, y_train, validation_split=0.1, epochs=5, batch_size=128, verbose=1)
plot_history(hist_ann, 'MNIST ANN')
acc_ann = evaluate_model(ann, X_test, y_test)

# -------- CNN --------
# Conv2D needs an explicit channel axis, so append one: (N, 28, 28) -> (N, 28, 28, 1).
X_train_cnn = np.expand_dims(X_train, -1)
X_test_cnn = np.expand_dims(X_test, -1)

cnn = models.Sequential([
    layers.Input(shape=(28,28,1)),
    layers.Conv2D(32, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Flatten(),
    layers.Dropout(0.3),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes, activation='softmax')
])
cnn.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
hist_cnn = cnn.fit(X_train_cnn, y_train, validation_split=0.1, epochs=5, batch_size=128, verbose=1)
plot_history(hist_cnn, 'MNIST CNN')
acc_cnn = evaluate_model(cnn, X_test_cnn, y_test)

# acc_ann / acc_cnn are picked up by the summary cell at the end of the notebook.
print('MNIST Results -> ANN:', acc_ann, '| CNN:', acc_cnn)


## Section B: CIFAR-10 — CNN baseline

In [None]:
# Seed Python, NumPy and TensorFlow at the start of the section so weight
# initialisation, dropout and batch shuffling start from the same state
# regardless of which other sections were run before this one.
# (Some GPU kernels may still introduce small run-to-run differences.)
tf.keras.utils.set_random_seed(42)

# Load CIFAR-10
(X_train_c, y_train_c), (X_test_c, y_test_c) = cifar10.load_data()
# Collapse the label arrays to 1-D so they can be used directly as sparse
# class ids by the loss and by the scikit-learn metrics in evaluate_model.
y_train_c = y_train_c.flatten()
y_test_c = y_test_c.flatten()

# Convert pixels to float32 and scale by 1/255
X_train_c = X_train_c.astype('float32')/255.0
X_test_c = X_test_c.astype('float32')/255.0

num_classes = 10
# Three conv/pool stages with 'same' padding (each 2x2 pool halves the spatial
# size), followed by a dropout-regularised dense classifier head.
cifar_cnn = models.Sequential([
    layers.Input(shape=(32,32,3)),
    layers.Conv2D(32, (3,3), activation='relu', padding='same'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu', padding='same'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(128, (3,3), activation='relu', padding='same'),
    layers.MaxPooling2D((2,2)),
    layers.Flatten(),
    layers.Dropout(0.4),
    layers.Dense(256, activation='relu'),
    layers.Dense(num_classes, activation='softmax')
])
cifar_cnn.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
hist_cifar = cifar_cnn.fit(X_train_c, y_train_c, validation_split=0.1, epochs=5, batch_size=128, verbose=1)
plot_history(hist_cifar, 'CIFAR-10 CNN')

# acc_cifar is picked up by the summary cell at the end of the notebook.
acc_cifar = evaluate_model(cifar_cnn, X_test_c, y_test_c)
print('CIFAR-10 Result -> CNN:', acc_cifar)


## Section C: Titanic — ANN baseline

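Place `train.csv` at `../datasets/titanic/train.csv`. This is a quick baseline that trains a small ANN on four tabular features (`Pclass`, `Sex`, `Age`, `Fare`) to predict `Survived`. If the file is not found, the cell prints a message and skips training.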

In [None]:
csv_path = '../datasets/titanic/train.csv'
if os.path.exists(csv_path):
    # Seed Python, NumPy and TensorFlow so weight initialisation, dropout and
    # batch shuffling are repeatable regardless of which sections ran before.
    tf.keras.utils.set_random_seed(42)

    df = pd.read_csv(csv_path)
    # Encode Sex before dropping incomplete rows: any value other than
    # 'male'/'female' maps to NaN and is then removed by dropna instead of
    # reaching the network as a NaN feature.
    df['Sex'] = df['Sex'].map({'male':0, 'female':1})
    df = df.dropna(subset=['Age','Fare','Sex','Pclass','Survived'])
    X = df[['Pclass','Sex','Age','Fare']].values.astype('float32')
    y = df['Survived'].values.astype('int32')

    # Split first, then fit the scaler on the training rows only. Fitting it on
    # the full dataset would leak the test set's mean/variance into training
    # and make the reported test accuracy optimistic.
    X_train_t, X_test_t, y_train_t, y_test_t = train_test_split(X, y, test_size=0.2, random_state=42)
    scaler = StandardScaler()
    X_train_t = scaler.fit_transform(X_train_t)
    X_test_t = scaler.transform(X_test_t)

    # Small binary classifier: a single sigmoid unit outputs P(Survived = 1).
    titanic_ann = models.Sequential([
        layers.Input(shape=(X_train_t.shape[1],)),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(32, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])
    titanic_ann.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    hist_titanic = titanic_ann.fit(X_train_t, y_train_t, validation_split=0.1, epochs=10, batch_size=32, verbose=1)
    plot_history(hist_titanic, 'Titanic ANN')

    # acc_titanic is picked up by the summary cell at the end of the notebook.
    acc_titanic = evaluate_model(titanic_ann, X_test_t, y_test_t)
    print('Titanic Result -> ANN:', acc_titanic)
else:
    # No data file: skip the section and leave acc_titanic undefined.
    print('Titanic CSV not found at', csv_path)


## Summary Cell — Compare Accuracies (if you ran multiple sections)

In [None]:
# Chart label -> name of the variable that the corresponding section defines
# once it has been run. Insertion order is the order of the bars in the chart.
result_vars = {
    'MNIST_ANN': 'acc_ann',
    'MNIST_CNN': 'acc_cnn',
    'CIFAR10_CNN': 'acc_cifar',
    'Titanic_ANN': 'acc_titanic',
}

# Include only results that actually exist in the notebook namespace. Checking
# for the name explicitly (instead of wrapping the lookup in a blanket
# `except Exception: pass`) skips sections that were not run without hiding
# unrelated errors, and treats each result independently.
summary = {
    label: float(globals()[var_name])
    for label, var_name in result_vars.items()
    if var_name in globals()
}

print('Accuracy Summary:')
for label, accuracy in summary.items():
    print(label, ':', accuracy)

# Draw the comparison chart only when at least one section produced a result.
if summary:
    plt.figure()
    plt.bar(list(summary.keys()), list(summary.values()))
    plt.xticks(rotation=45)
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy Comparison')
    plt.tight_layout()
    plt.show()
