## Imports

In [16]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import importlib

# Custom modules
import data_preprocessing
importlib.reload(data_preprocessing)

import sys
sys.path.append('../src')
from data_preprocessing import create_data_generators, calculate_class_weights
from config import Config


print("=== IMPORTS FINISHED ===")

=== IMPORTS FINISHED ===


## Data Preparation

In [5]:
train_datagen, val_datagen = create_data_generators()

# Create data generators
train_generator = train_datagen.flow_from_directory(
    '../data/processed/train',
    target_size=(Config.IMG_HEIGHT, Config.IMG_WIDTH),
    batch_size=Config.BATCH_SIZE,
    class_mode='categorical',
    shuffle=True
)

val_generator = val_datagen.flow_from_directory(
    '../data/processed/val', 
    target_size=(Config.IMG_HEIGHT, Config.IMG_WIDTH),
    batch_size=Config.BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

test_generator = val_datagen.flow_from_directory(
    '../data/processed/test',
    target_size=(Config.IMG_HEIGHT, Config.IMG_WIDTH), 
    batch_size=Config.BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

cat_classes = train_generator.class_indices

print(f"Classes: {cat_classes}")
print(f"Training samples: {train_generator.samples}")
print(f"Validation samples: {val_generator.samples}")
print(f"Test samples: {test_generator.samples}")

Found 380 images belonging to 3 classes.
Found 83 images belonging to 3 classes.
Found 83 images belonging to 3 classes.
Classes: {'both': 0, 'karamela': 1, 'lacta': 2}
Training samples: 380
Validation samples: 83
Test samples: 83


## Class Weights

In [17]:
class_weights = calculate_class_weights('../data/raw', cat_classes)
print("Using class weights:", class_weights)

Class weights for handling imbalance: {0: 2.1411764705882352, 1: 0.9479166666666666, 2: 0.6765799256505576}
Using class weights: {0: 2.1411764705882352, 1: 0.9479166666666666, 2: 0.6765799256505576}


## Baseline Model

In [18]:
from model_utils import create_baseline_model, compile_model

print("=== BASELINE MODEL ===")
baseline_model = create_baseline_model()
baseline_model = compile_model(baseline_model, learning_rate=1e-3)

# Display model architecture
baseline_model.summary()

=== BASELINE MODEL ===


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


### Train

In [None]:
# Cell 5: Train Baseline Model
print("=== TRAINING BASELINE ===")
history = baseline_model.fit(
    train_generator,
    epochs=15,
    validation_data=val_generator,
    class_weight=class_weights,  # Using our calculated weights
    verbose=1
)

=== TRAINING BASELINE ===
Epoch 1/15
[1m12/12[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 4s/step - accuracy: 0.3184 - loss: 1.3264 - val_accuracy: 0.3253 - val_loss: 1.1212
Epoch 2/15
[1m12/12[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 3s/step - accuracy: 0.3684 - loss: 1.1013 - val_accuracy: 0.5904 - val_loss: 1.0946
Epoch 3/15


### Evaluate

In [None]:
# Cell 6: Evaluate Baseline
print("=== BASELINE EVALUATION ===")
# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Baseline Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Baseline Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()

### Prediction and Metrics

In [None]:
# Cell 7: Baseline Predictions and Metrics
# Get true labels and predictions
val_generator.reset()
y_true = val_generator.classes
y_pred = baseline_model.predict(val_generator)
y_pred_classes = np.argmax(y_pred, axis=1)

# Classification report
class_names = list(train_generator.class_indices.keys())
print("\n=== BASELINE CLASSIFICATION REPORT ===")
print(classification_report(y_true, y_pred_classes, target_names=class_names))

# Confusion matrix
cm = confusion_matrix(y_true, y_pred_classes)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names)
plt.title('Baseline Model Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()