# In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization

# Display names for the ten Fashion MNIST classes. The position of a name in
# this list is the integer class id it is looked up by further down.
fashion_labels = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]

# Load Fashion MNIST dataset
print("Loading Fashion MNIST dataset...")
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()

# Add an explicit single-channel axis (N, 28, 28, 1) for the Conv2D layers and
# scale pixel values by 1/255 (the Keras docs describe the images as uint8
# grayscale, so this maps them into [0, 1]).
# The cast to float32 happens *before* the division on purpose: dividing the
# integer array directly by 255.0 yields float64, which doubles the memory
# held by both arrays for no benefit, since float32 is Keras' default float
# type.
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

print("Dataset shapes:")
print(f"x_train: {x_train.shape}")
print(f"y_train: {y_train.shape}")
print(f"x_test: {x_test.shape}")
print(f"y_test: {y_test.shape}")

# ============================================================
# MODEL 1: SIMPLE MODEL (< 93% accuracy)
# ============================================================
print("", "=" * 60, "TRAINING SIMPLE MODEL (Expected < 93% accuracy)", "=" * 60, sep="\n")

# Baseline network: a single 16-filter 3x3 convolution, 2x2 max-pooling,
# then a 64-unit dense layer feeding a 10-way softmax.
simple_model = Sequential()
simple_model.add(
    Conv2D(filters=16, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1))
)
simple_model.add(MaxPooling2D(pool_size=(2, 2)))
simple_model.add(Flatten())
simple_model.add(Dense(units=64, activation='relu'))
simple_model.add(Dense(units=10, activation='softmax'))

# Sparse cross-entropy because the targets are passed as integer class ids,
# not one-hot vectors.
simple_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

print("\nSimple Model Architecture:")
simple_model.summary()

# 10% of the training data is held out by Keras for per-epoch validation.
history_simple = simple_model.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=5,
    validation_split=0.1,
    verbose=1,
)

# evaluate() returns the loss followed by the compiled metrics (accuracy).
simple_loss, simple_accuracy = simple_model.evaluate(x=x_test, y=y_test, verbose=0)
print("", "=" * 60, "SIMPLE MODEL RESULTS", "=" * 60, sep="\n")
print("Test Accuracy: {:.2f}%".format(simple_accuracy * 100))
print("Test Loss: {:.4f}".format(simple_loss))

# ============================================================
# MODEL 2: ADVANCED MODEL (>= 93% accuracy)
# ============================================================
print("", "=" * 60, "TRAINING ADVANCED MODEL (Expected >= 93% accuracy)", "=" * 60, sep="\n")

advanced_model = Sequential()

# Stage 1: two 32-filter 3x3 convolutions ('same' padding), each followed by
# batch normalization, then 2x2 max-pooling and 25% dropout.
advanced_model.add(
    Conv2D(
        filters=32,
        kernel_size=(3, 3),
        padding='same',
        activation='relu',
        input_shape=(28, 28, 1),
    )
)
advanced_model.add(BatchNormalization())
advanced_model.add(Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu'))
advanced_model.add(BatchNormalization())
advanced_model.add(MaxPooling2D(pool_size=(2, 2)))
advanced_model.add(Dropout(rate=0.25))

# Stage 2: same layout as stage 1 with 64 filters per convolution.
advanced_model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'))
advanced_model.add(BatchNormalization())
advanced_model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'))
advanced_model.add(BatchNormalization())
advanced_model.add(MaxPooling2D(pool_size=(2, 2)))
advanced_model.add(Dropout(rate=0.25))

# Classifier head: 512-unit dense layer with batch normalization and 50%
# dropout, followed by the 10-way softmax output.
advanced_model.add(Flatten())
advanced_model.add(Dense(units=512, activation='relu'))
advanced_model.add(BatchNormalization())
advanced_model.add(Dropout(rate=0.5))
advanced_model.add(Dense(units=10, activation='softmax'))

# Sparse cross-entropy because the targets are passed as integer class ids.
advanced_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

print("\nAdvanced Model Architecture:")
advanced_model.summary()

# 10% of the training data is held out by Keras for per-epoch validation.
history_advanced = advanced_model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=15,
    validation_split=0.1,
    verbose=1,
)

# evaluate() returns the loss followed by the compiled metrics (accuracy).
advanced_loss, advanced_accuracy = advanced_model.evaluate(x=x_test, y=y_test, verbose=0)
print("", "=" * 60, "ADVANCED MODEL RESULTS", "=" * 60, sep="\n")
print("Test Accuracy: {:.2f}%".format(advanced_accuracy * 100))
print("Test Loss: {:.4f}".format(advanced_loss))

# ============================================================
# COMPARISON
# ============================================================
# Side-by-side test accuracy of both models plus the difference between them
# (advanced minus simple), all printed as percentages with two decimals.
print("", "=" * 60, "MODEL COMPARISON", "=" * 60, sep="\n")
print("Simple Model Accuracy:   {:.2f}%".format(simple_accuracy * 100))
print("Advanced Model Accuracy: {:.2f}%".format(advanced_accuracy * 100))
print("Improvement: {:.2f}%".format((advanced_accuracy - simple_accuracy) * 100))

# ============================================================
# PREDICTIONS AND ERROR ANALYSIS (Using Advanced Model)
# ============================================================
print("", "=" * 60, "PREDICTIONS & ERROR ANALYSIS (Advanced Model)", "=" * 60, sep="\n")

# One row of model outputs per test image; the position of the largest value
# in each row is taken as the predicted class id.
predictions = advanced_model.predict(x_test, verbose=0)
predicted_labels = predictions.argmax(axis=1)

# Positions in the test set where the prediction disagrees / agrees with the
# ground-truth label.
wrong_indices = np.flatnonzero(predicted_labels != y_test)
correct_indices = np.flatnonzero(predicted_labels == y_test)

print("Total test images: {}".format(len(y_test)))
print("Correct predictions: {}".format(len(correct_indices)))
print("Wrong predictions: {}".format(len(wrong_indices)))
print("Error rate: {:.2f}%".format((len(wrong_indices) / len(y_test)) * 100))

# ============================================================
# VISUALIZE WRONG PREDICTIONS
# ============================================================
# Shows the first two misclassified test images side by side; the whole
# section is skipped when fewer than two misclassifications exist.
if len(wrong_indices) >= 2:
    print("", "=" * 60, "DISPLAYING 2 WRONG PREDICTION EXAMPLES", "=" * 60, sep="\n")

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    for rank, (ax, idx) in enumerate(zip(axes, wrong_indices[:2]), start=1):
        predicted_name = fashion_labels[predicted_labels[idx]]
        actual_name = fashion_labels[y_test[idx]]

        # Drop the channel axis so imshow receives a plain 28x28 image.
        ax.imshow(x_test[idx].reshape(28, 28), cmap='gray')
        ax.set_title(
            f"Predicted: {predicted_name}\nActual: {actual_name}",
            fontsize=11, color='red', fontweight='bold'
        )
        ax.axis('off')

        print(f"\nWrong prediction #{rank}:")
        print(f"  Index: {idx}")
        print(f"  Predicted: {predicted_name}")
        print(f"  Actual: {actual_name}")

    plt.tight_layout()
    # Save before show(): the figure is written to disk first, then displayed.
    plt.savefig('fashion_wrong_predictions.png', dpi=150, bbox_inches='tight')
    print("\nWrong predictions saved as 'fashion_wrong_predictions.png'")
    plt.show()

# ============================================================
# VISUALIZE CORRECT PREDICTIONS
# ============================================================
# Shows up to four correctly classified test images in a 2x2 grid and saves
# the figure as 'fashion_correct_predictions.png'.
print("\n" + "="*60)
print("DISPLAYING 4 CORRECT PREDICTION EXAMPLES")
print("="*60)

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = axes.flatten()

# Slice instead of indexing positions 0..3 so that a model with fewer than
# four correct predictions does not raise IndexError here.
shown_correct = correct_indices[:len(axes)]

for ax, idx in zip(axes, shown_correct):
    # Drop the channel axis so imshow receives a plain 28x28 image.
    ax.imshow(x_test[idx].reshape(28, 28), cmap='gray')
    # "\u2713" is the check mark. It is written as an escape so the title no
    # longer depends on the encoding the file is saved or read with (the
    # previous literal was a mis-decoded UTF-8 sequence).
    ax.set_title(
        f"Predicted: {fashion_labels[predicted_labels[idx]]}\n"
        f"Actual: {fashion_labels[y_test[idx]]} \u2713",
        fontsize=11, color='green', fontweight='bold'
    )
    ax.axis('off')

# Hide any grid cell that received no image so it does not show empty axes.
for ax in axes[len(shown_correct):]:
    ax.axis('off')

plt.tight_layout()
plt.savefig('fashion_correct_predictions.png', dpi=150, bbox_inches='tight')
print("Correct predictions saved as 'fashion_correct_predictions.png'")
plt.show()

# ============================================================
# TRAINING HISTORY COMPARISON
# ============================================================
# 2x2 grid: top row is the simple model, bottom row the advanced model;
# left column is accuracy, right column is loss.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))

# One entry per model: (accuracy axes, loss axes, recorded history, title
# prefix, extra styling for the training curve, extra styling for the
# validation curve). The simple model uses the default color cycle.
panel_rows = [
    (ax1, ax2, history_simple.history, 'Simple Model', {}, {}),
    (ax3, ax4, history_advanced.history, 'Advanced Model',
     {'color': 'green'}, {'color': 'orange'}),
]

for acc_axes, loss_axes, curves, model_title, train_style, val_style in panel_rows:
    for panel, metric, metric_title in (
        (acc_axes, 'accuracy', 'Accuracy'),
        (loss_axes, 'loss', 'Loss'),
    ):
        panel.plot(curves[metric], label='Training', linewidth=2, **train_style)
        panel.plot(curves['val_' + metric], label='Validation', linewidth=2, **val_style)
        panel.set_xlabel('Epoch', fontsize=11)
        panel.set_ylabel(metric_title, fontsize=11)
        panel.set_title(f'{model_title} - {metric_title}', fontsize=12, fontweight='bold')
        panel.legend()
        panel.grid(True, alpha=0.3)
        # Only the accuracy panels get a fixed y-range, so both models are
        # drawn on the same scale; loss panels keep automatic limits.
        if metric == 'accuracy':
            panel.set_ylim([0.7, 1.0])

plt.tight_layout()
plt.savefig('fashion_training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================================
# CONFUSION MATRIX (for Advanced Model)
# ============================================================
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Counts of (true label, predicted label) pairs over the whole test set.
cm = confusion_matrix(y_true=y_test, y_pred=predicted_labels)

# Keyword options for the heatmap: integer annotations in every cell and the
# class names on both axes.
heatmap_options = {
    'annot': True,
    'fmt': 'd',
    'cmap': 'Blues',
    'xticklabels': fashion_labels,
    'yticklabels': fashion_labels,
    'cbar_kws': {'label': 'Count'},
}
axis_label_font = {'fontsize': 12, 'fontweight': 'bold'}

plt.figure(figsize=(12, 10))
sns.heatmap(cm, **heatmap_options)
plt.xlabel('Predicted Label', **axis_label_font)
plt.ylabel('True Label', **axis_label_font)
plt.title('Confusion Matrix - Advanced Model', fontsize=14, fontweight='bold')
# Slant the class names on the x axis; y-axis names are set horizontal.
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('fashion_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

# Closing banner followed by the list of image files the script saves.
print("", "=" * 60, "ASSIGNMENT 7 COMPLETED!", "=" * 60, sep="\n")
print("\nGenerated files:")
for generated_file in (
    "fashion_wrong_predictions.png",
    "fashion_correct_predictions.png",
    "fashion_training_comparison.png",
    "fashion_confusion_matrix.png",
):
    print(f"  - {generated_file}")