In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Flatten, Dense,
                                     Dropout, BatchNormalization)
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.optimizers import Adam

# === 1. Load CSV and Prepare Data ===
# Dataset root directory; the label table train.csv sits directly inside it.
data_path = 'cassava_data'
csv_path = os.path.join(data_path, 'train.csv')

# Read the table and cast the 'label' column to str in a single expression.
# The generators further down use this column as y_col with
# class_mode='categorical', hence the string labels.
df = pd.read_csv(csv_path).astype({'label': str})

# === 2. Split into Train/Val/Test (70/15/15) ===
# Shares of the FULL dataset that are held out; whatever is left is used for
# training (1 - 0.15 - 0.15 = 0.70).
test_fraction = 0.15
val_fraction = 0.15

# Step 1: carve off the test set. Stratifying on the label keeps the class
# proportions the same in both resulting subsets.
train_df, test_df = train_test_split(
    df,
    test_size=test_fraction,
    stratify=df['label'],
    random_state=42,
)

# Step 2: carve the validation set out of what remains. This split only sees
# (1 - test_fraction) of the data, so the requested share is rescaled to make
# the validation set val_fraction of the full dataset (0.15 / 0.85 ~= 0.1765)
# instead of relying on a hand-rounded constant.
train_df, val_df = train_test_split(
    train_df,
    test_size=val_fraction / (1.0 - test_fraction),
    stratify=train_df['label'],
    random_state=42,
)

# === 3. Class Weights (for balanced training) ===
# Distinct training labels; np.unique returns them in sorted order.
label_names = np.unique(train_df['label'])

# One 'balanced' weight per entry of label_names, in the same order.
balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=label_names,
    y=train_df['label'],
)

# Re-key the weights by position (0, 1, 2, ...) for use as a class_weight dict.
# NOTE(review): this assumes the Keras generator assigns class indices in the
# same sorted-label order -- confirm against train_gen.class_indices.
class_weights = {position: weight for position, weight in enumerate(balanced_weights)}

# === 4. Data Augmentation ===
# Folder holding the image files, and the (height, width) every image is
# resized to when loaded.
image_dir = os.path.join(data_path, 'train_images')
target_size = (160, 160)

# Random distortions applied to training images only; 'rescale' multiplies
# pixel values by 1/255.
augmentation_options = {
    'rescale': 1./255,
    'rotation_range': 25,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'shear_range': 0.2,
    'zoom_range': 0.3,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}
train_datagen = ImageDataGenerator(**augmentation_options)

# Validation and test images get the same rescaling but no augmentation.
# Two separate generator objects are created, one for each subset.
val_datagen, test_datagen = (ImageDataGenerator(rescale=1./255) for _ in range(2))

# === 5. Data Generators ===
# Settings shared by all three generators: where the images live, which
# dataframe columns hold the file name and the label, the resize target, the
# batch size and one-hot ('categorical') label encoding.
flow_options = dict(
    directory=image_dir,
    x_col='image_id',
    y_col='label',
    target_size=target_size,
    batch_size=32,
    class_mode='categorical',
)

# Training batches: augmented images drawn from train_df.
train_gen = train_datagen.flow_from_dataframe(train_df, **flow_options)

# Validation batches: rescaled-only images drawn from val_df.
val_gen = val_datagen.flow_from_dataframe(val_df, **flow_options)

# Test batches: shuffling is switched off, presumably so predictions can be
# lined up with the generator's label order afterwards.
test_gen = test_datagen.flow_from_dataframe(test_df, shuffle=False, **flow_options)

# === 6. Model Architecture (Improved CNN) ===
# Take the network's input and output sizes from the data pipeline instead of
# repeating literals, so a change to target_size or to the label set cannot
# leave the model out of sync with what the generators produce.
input_shape = tuple(target_size) + (3,)  # (height, width, 3 colour channels)
num_classes = len(train_gen.class_indices)  # 5 disease classes in the original setup

model = Sequential([
    # Block 1: two 3x3 conv layers with 32 filters, each followed by batch
    # normalisation, then 2x2 max-pooling.
    Conv2D(32, (3,3), activation='relu', padding='same', input_shape=input_shape),
    BatchNormalization(),
    Conv2D(32, (3,3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(2,2),

    # Block 2: same layout with 64 filters.
    Conv2D(64, (3,3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(64, (3,3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(2,2),

    # Block 3: same layout with 128 filters.
    Conv2D(128, (3,3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(128, (3,3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(2,2),

    # Classifier head: two dense layers with 50% dropout after each, then a
    # softmax over the classes.
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax')
])

# === 7. Compile Model ===
# Adam optimizer with an initial learning rate of 1e-4.
optimizer = Adam(learning_rate=1e-4)

# Categorical cross-entropy loss (labels are one-hot encoded by the
# generators); accuracy is tracked as the reported metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)

# === 8. Callbacks ===
# Multiply the learning rate by 0.5 once val_loss has gone 3 epochs without
# improving.
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1)

# Stop training once val_loss has gone 6 epochs without improving, and restore
# the weights from the best epoch.
early_stopper = EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True, verbose=1)

callbacks = [lr_scheduler, early_stopper]

# === 9. Train Model ===
# Training settings: validate on val_gen each epoch, run for at most 15
# epochs, apply the callbacks defined above, and weight the loss per class
# using class_weights.
fit_options = {
    'validation_data': val_gen,
    'epochs': 15,
    'callbacks': callbacks,
    'class_weight': class_weights,
}

# The returned object is kept in `history` for later inspection.
history = model.fit(train_gen, **fit_options)
