# Day 7: Fruit Classification Dataset - Comprehensive Analysis

## Overview
This notebook provides a thorough exploration and modeling of the fruit classification dataset, improving upon previous basic implementations with proper EDA, feature engineering, and model evaluation.

## Dataset
- **Source**: Fruit Classification Dataset (Kaggle)
- **Size**: 10,000 samples
- **Features**: size (cm), shape, weight (g), avg_price (₹), color, taste
- **Target**: fruit_name (20 fruit types)

## Objective
Build a classification model that predicts fruit type from physical characteristics (size, shape, weight, color), taste, and average price.

## 1. Import Required Libraries

In [None]:
# Import Required Libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import plotly.figure_factory as ff
import warnings
import os
warnings.filterwarnings('ignore')

# Set plotly to dark theme
px.defaults.template = "plotly_dark"

# Create viz directory if it doesn't exist
os.makedirs('../viz', exist_ok=True)

print("Libraries imported successfully!")
print("Visualization directory: ../viz")

## 2. Load the Dataset

In [None]:
# Load the Dataset
fruit_data = pd.read_csv('../data/fruit_classification_dataset.csv')
print("Dataset loaded successfully!")
print(f"Shape: {fruit_data.shape}")
fruit_data.head()

## 3. Explore the Dataset

In [None]:
# Explore the Dataset
print("Dataset Info:")
fruit_data.info()

print("\nMissing Values:")
print(fruit_data.isnull().sum())

print("\nUnique Fruit Types:")
print(fruit_data['fruit_name'].value_counts())

print("\nDescriptive Statistics for Numerical Features:")
fruit_data.describe()
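Class balance determines how the accuracy figures later in the notebook should be read. Below is a minimal sketch of checking class proportions and the ratio between the largest and smallest class. It uses a small hypothetical frame, `toy`, in place of `fruit_data`:

```python
import pandas as pd

# Hypothetical toy frame standing in for fruit_data
toy = pd.DataFrame({'fruit_name': ['apple'] * 6 + ['banana'] * 3 + ['cherry'] * 3})

# Share of each class (the shares sum to 1.0)
class_share = toy['fruit_name'].value_counts(normalize=True)
print(class_share)

# Ratio of the most frequent to the least frequent class; 1.0 means perfectly balanced
counts = toy['fruit_name'].value_counts()
imbalance_ratio = counts.max() / counts.min()
print(f"Imbalance ratio: {imbalance_ratio:.2f}")
```

On the real data, replace `toy` with `fruit_data`. A ratio close to 1.0 suggests that plain accuracy is a reasonable headline metric.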

## 4. Visualize Numerical Features

In [None]:
# Visualize Numerical Features
numerical_features = ['size (cm)', 'weight (g)', 'avg_price (₹)']

fig = make_subplots(rows=1, cols=3, subplot_titles=numerical_features)

for i, feature in enumerate(numerical_features):
    fig.add_trace(
        go.Histogram(x=fruit_data[feature], name=feature, showlegend=False),
        row=1, col=i+1
    )

fig.update_layout(title_text="Distribution of Numerical Features", height=400)
fig.write_html('../viz/numerical_distributions.html')
fig.show()

print("[SAVED] ../viz/numerical_distributions.html")

In [None]:
# Individual histograms colored by fruit
for feature in numerical_features:
    fig = px.histogram(fruit_data, x=feature, color='fruit_name', 
                       title=f"Distribution of {feature} by Fruit Type",
                       barmode='overlay', opacity=0.7)
    
    # Save each visualization
    filename = f"../viz/{feature.replace(' ', '_').replace('(', '').replace(')', '').replace('₹', 'inr')}_by_fruit.html"
    fig.write_html(filename)
    print(f"[SAVED] {filename}")
    fig.show()
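The chained `.replace()` calls above turn a column label into a filename. The helper below, `slugify_feature`, is hypothetical and not part of the original notebook. It is a sketch of the same mapping in reusable form that also handles any other bracket or whitespace characters:

```python
import re

def slugify_feature(name: str) -> str:
    """Turn a column label like 'avg_price (₹)' into 'avg_price_inr'."""
    name = name.replace('₹', 'inr')        # keep the currency meaning in the filename
    name = re.sub(r'[^\w]+', '_', name)    # collapse runs of non-word characters to '_'
    return name.strip('_')                 # drop leading/trailing underscores

for label in ['size (cm)', 'weight (g)', 'avg_price (₹)']:
    print(f"../viz/{slugify_feature(label)}_by_fruit.html")
```

For the three numerical features, this produces the same filenames listed in the Summary.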

## 5. Create Correlation Heatmap

In [None]:
# Create Correlation Heatmap
corr_matrix = fruit_data[numerical_features].corr()

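# Correlations lie in [-1, 1]; a diverging scale pinned to that range keeps
# negative and positive correlations visually distinct
fig = px.imshow(corr_matrix, text_auto=True, color_continuous_scale='RdBu_r',
                zmin=-1, zmax=1,
                title='Correlation Heatmap of Numerical Features',
                labels=dict(color="Correlation"))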
fig.write_html('../viz/correlation_heatmap.html')
fig.show()

print("[SAVED] ../viz/correlation_heatmap.html")
print("\nCorrelation Matrix:")
print(corr_matrix)
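By default, `DataFrame.corr()` computes the Pearson coefficient, r = cov(x, y) / (σx σy). The stdlib-only sketch below implements that formula, which is useful for sanity-checking a single cell of the matrix. The sizes and weights are hypothetical:

```python
import math

def pearson_r(x, y):
    """Pearson correlation: covariance divided by the product of standard deviations.

    The 1/(n-1) factors in covariance and variance cancel, so plain sums are used.
    """
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)

# Hypothetical sizes (cm) and weights (g): weight rises with size
sizes = [5.0, 7.0, 9.0, 25.0]
weights = [150.0, 200.0, 260.0, 3000.0]
print(round(pearson_r(sizes, weights), 3))
```

Pearson captures only linear association. When weight grows faster than linearly with size, `fruit_data[numerical_features].corr(method='spearman')` is a useful rank-based cross-check.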

## 6. Visualize Average Price by Fruit

In [None]:
# Visualize Average Price by Fruit
avg_price_by_fruit = fruit_data.groupby('fruit_name')['avg_price (₹)'].mean().sort_values(ascending=False).reset_index()

fig = px.bar(avg_price_by_fruit, x='fruit_name', y='avg_price (₹)', 
             color='fruit_name', title='Average Price by Fruit Type',
             labels={'avg_price (₹)': 'Average Price (₹)', 'fruit_name': 'Fruit'})
fig.update_layout(showlegend=False, xaxis_tickangle=-45)
fig.write_html('../viz/avg_price_by_fruit.html')
fig.show()

print("[SAVED] ../viz/avg_price_by_fruit.html")
print("\nTop 5 Most Expensive Fruits:")
print(avg_price_by_fruit.head())
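A mean alone hides how much prices vary within a fruit type. The sketch below uses `groupby().agg()` to report mean, standard deviation, and count side by side. It runs on a hypothetical frame, not the real dataset:

```python
import pandas as pd

# Hypothetical prices standing in for fruit_data
toy = pd.DataFrame({
    'fruit_name': ['mango', 'mango', 'apple', 'apple', 'apple'],
    'avg_price (₹)': [120.0, 140.0, 80.0, 90.0, 100.0],
})

# Mean, sample standard deviation, and number of rows per fruit, most expensive first
price_stats = (
    toy.groupby('fruit_name')['avg_price (₹)']
       .agg(['mean', 'std', 'count'])
       .sort_values('mean', ascending=False)
)
print(price_stats)
```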

## 7. Create Scatter Plot for Size vs Weight

In [None]:
# Create Scatter Plot for Size vs Weight
fig = px.scatter(fruit_data, x='size (cm)', y='weight (g)', color='fruit_name',
                 size='avg_price (₹)', hover_data=['color', 'taste'],
                 title='Size vs Weight Scatter Plot (Bubble Size = Price)',
                 labels={'size (cm)': 'Size (cm)', 'weight (g)': 'Weight (g)'})
fig.write_html('../viz/size_vs_weight_scatter.html')
fig.show()

print("[SAVED] ../viz/size_vs_weight_scatter.html")

## 8. Visualize Taste Distribution by Color

In [None]:
# Visualize Taste Distribution by Color
fig = px.histogram(fruit_data, x='taste', color='color', barmode='group',
                   title='Taste Distribution by Color')
fig.write_html('../viz/taste_by_color.html')
fig.show()

print("[SAVED] ../viz/taste_by_color.html")
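The grouped histogram can also be read as a contingency table. `pd.crosstab` produces the same counts numerically, as this sketch on hypothetical rows shows:

```python
import pandas as pd

# Hypothetical rows standing in for fruit_data[['taste', 'color']]
toy = pd.DataFrame({
    'taste': ['sweet', 'sweet', 'sour', 'sweet', 'sour'],
    'color': ['red', 'yellow', 'yellow', 'red', 'green'],
})

# Rows: taste, columns: color, cells: number of samples with that combination
taste_color_counts = pd.crosstab(toy['taste'], toy['color'])
print(taste_color_counts)
```

Passing `normalize='index'` to `pd.crosstab` turns each row into proportions, which makes tastes with very different counts easier to compare.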

## 9. Encode Categorical Features

In [None]:
# Encode Categorical Features
encoded_data = fruit_data.copy()

# Label encode target
le_target = LabelEncoder()
encoded_data['fruit_name_encoded'] = le_target.fit_transform(encoded_data['fruit_name'])

# One-hot encode categorical features
categorical_features = ['shape', 'color', 'taste']
encoded_data = pd.get_dummies(encoded_data, columns=categorical_features, drop_first=True)

print("Encoded features:")
print(encoded_data.columns.tolist())
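# Exclude the two target columns (fruit_name, fruit_name_encoded) from the count
n_features = encoded_data.shape[1] - 2
print(f"\nTotal features after encoding: {n_features}")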
encoded_data.head()
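With `drop_first=True`, `pd.get_dummies` drops the first level of each categorical column in sorted order, so that reference level is encoded as all zeros. Here is a sketch on a hypothetical `shape` column:

```python
import pandas as pd

# Hypothetical shape values; 'long' sorts first and becomes the reference level
toy = pd.DataFrame({'shape': ['long', 'oval', 'round', 'round']})

encoded = pd.get_dummies(toy, columns=['shape'], drop_first=True)
print(encoded.columns.tolist())
print(encoded.astype(int))
```

For the same reason, a new sample must be aligned to the training columns before prediction: a one-row frame only produces dummies for the categories it actually contains.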

## 10. Split the Data

In [None]:
# Split the Data
X = encoded_data.drop(['fruit_name', 'fruit_name_encoded'], axis=1)
y = encoded_data['fruit_name_encoded']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

print(f"Training set shape: {X_train.shape}")
print(f"Testing set shape: {X_test.shape}")
print(f"\nNumber of classes: {len(le_target.classes_)}")
print(f"Classes: {le_target.classes_}")
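`stratify=y` makes the train and test sets keep the class proportions of the full dataset. The sketch below demonstrates this on hypothetical imbalanced labels. The names `toy_X` and `toy_y` are chosen so they do not overwrite the notebook's `X` and `y`:

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 80 samples of class 0, 20 of class 1
toy_X = [[i] for i in range(100)]
toy_y = [0] * 80 + [1] * 20

_, _, toy_y_train, toy_y_test = train_test_split(
    toy_X, toy_y, test_size=0.2, random_state=42, stratify=toy_y
)
# Both splits keep the 80/20 ratio
print(Counter(toy_y_train), Counter(toy_y_test))
```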

## 11. Train the Random Forest Model

In [None]:
# Train the Random Forest Model
rf_model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
rf_model.fit(X_train, y_train)

print("Random Forest model trained successfully!")
print(f"Number of estimators: {rf_model.n_estimators}")
print("\nFeature importances:")

feature_importance = pd.DataFrame({
    'feature': X.columns,
    'importance': rf_model.feature_importances_
}).sort_values('importance', ascending=False)

print(feature_importance.head(10))

## 12. Visualize Feature Importance

In [None]:
# Visualize Feature Importance
fig = px.bar(feature_importance.head(10), x='importance', y='feature', 
             orientation='h',
             title='Top 10 Feature Importances',
             labels={'importance': 'Importance Score', 'feature': 'Feature'})
fig.update_layout(yaxis={'categoryorder':'total ascending'})
fig.write_html('../viz/feature_importance.html')
fig.show()

print("[SAVED] ../viz/feature_importance.html")
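Impurity-based importances (`feature_importances_`) are computed on the training data. They tend to favor continuous, high-cardinality features such as size and weight over binary one-hot columns. Permutation importance on held-out data is a common cross-check. The sketch below uses synthetic data in which only the first feature carries signal:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
# Synthetic data: feature 0 determines the label, feature 1 is pure noise
toy_X = rng.normal(size=(500, 2))
toy_y = (toy_X[:, 0] > 0).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(toy_X, toy_y, test_size=0.3, random_state=0)
toy_model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_tr, y_tr)

# Mean drop in test accuracy when each feature is shuffled, over 10 repeats
result = permutation_importance(toy_model, X_te, y_te, n_repeats=10, random_state=0)
print("Impurity-based:", toy_model.feature_importances_.round(3))
print("Permutation:   ", result.importances_mean.round(3))
```

On the notebook's objects, the equivalent call would be `permutation_importance(rf_model, X_test, y_test, n_repeats=10, random_state=42)`.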

## 13. Evaluate the Model

In [None]:
# Evaluate the Model
y_pred = rf_model.predict(X_test)

print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=le_target.classes_))

accuracy = rf_model.score(X_test, y_test)
print(f"\nModel Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

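# Compare training and test accuracy to gauge overfitting
y_train_pred = rf_model.predict(X_train)
train_accuracy = accuracy_score(y_train, y_train_pred)
print(f"Training Accuracy: {train_accuracy:.4f} ({train_accuracy*100:.2f}%)")

# Heuristic: a train/test gap above 5 percentage points suggests overfitting
# (a test score above the training score is not a sign of overfitting)
gap = train_accuracy - accuracy
print(f"\nTrain/test accuracy gap: {gap:.4f}")
print("No significant overfitting detected" if gap < 0.05 else "Possible overfitting detected")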
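A single 80/20 split yields one accuracy number. Stratified k-fold cross-validation yields a spread across folds, which is a firmer basis for the overfitting check above. The sketch below uses synthetic data from `make_classification` as a stand-in for the encoded fruit features:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic stand-in for the encoded fruit features (not the real dataset)
toy_X, toy_y = make_classification(n_samples=300, n_features=8, n_informative=4,
                                   n_classes=3, random_state=0)

# 5 folds, each preserving class proportions
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(RandomForestClassifier(n_estimators=100, random_state=42),
                         toy_X, toy_y, cv=cv, scoring='accuracy')
print(f"Accuracy per fold: {scores.round(3)}")
print(f"Mean ± std: {scores.mean():.3f} ± {scores.std():.3f}")
```

With the notebook's objects, this would be `cross_val_score(rf_model, X, y, cv=cv)`. `cross_val_score` clones the estimator, so the already-fitted `rf_model` is left unchanged.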

## 14. Generate Confusion Matrix

In [None]:
# Generate Confusion Matrix
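# Pass explicit labels so the matrix always has one row/column per class,
# matching the axis labels below
cm = confusion_matrix(y_test, y_pred, labels=np.arange(len(le_target.classes_)))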

fig = ff.create_annotated_heatmap(
    z=cm, 
    x=le_target.classes_.tolist(), 
    y=le_target.classes_.tolist(),
    colorscale='Blues', 
    showscale=True
)
fig.update_layout(
    title="Confusion Matrix - Random Forest Classifier",
    xaxis_title="Predicted",
    yaxis_title="Actual",
    height=800,
    width=900
)
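fig.update_xaxes(tickangle=-45)
# Heatmaps draw the first row at the bottom; reverse y so the first actual class is on top
fig.update_yaxes(autorange='reversed')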
fig.write_html('../viz/confusion_matrix.html')
fig.show()

print("[SAVED] ../viz/confusion_matrix.html")
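The diagonal holds the correct predictions, so per-class recall and precision follow directly from row and column sums. These values should match the recall and precision columns of the classification report in Section 13. Here is a sketch on hypothetical labels:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels; matrix rows are actual classes, columns are predictions
y_true = [0, 0, 1, 1, 1]
y_hat = [0, 0, 1, 1, 0]
toy_cm = confusion_matrix(y_true, y_hat)

recall = np.diag(toy_cm) / toy_cm.sum(axis=1)     # correct / all samples actually in the class
precision = np.diag(toy_cm) / toy_cm.sum(axis=0)  # correct / all samples predicted as the class
print(toy_cm)
print("Recall:", recall.round(3), "Precision:", precision.round(3))
```

If a class is never predicted, its column sum is zero and the precision division yields `nan`. `classification_report` handles that case through its `zero_division` parameter.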

## 15. Save the Model

In [None]:
# Save the Model
import joblib

os.makedirs('../models', exist_ok=True)
joblib.dump(rf_model, '../models/fruit_rf_model.joblib')
joblib.dump(le_target, '../models/label_encoder.joblib')

print("Model saved successfully!")
print("  - ../models/fruit_rf_model.joblib")
print("  - ../models/label_encoder.joblib")
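A quick way to confirm that the saved files are usable is a dump/load round trip that checks the predictions are unchanged. The sketch below trains a hypothetical tiny model in place of `rf_model` and writes it to a temporary directory rather than `../models`:

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Hypothetical tiny model standing in for rf_model
toy_X = np.array([[1.0], [2.0], [10.0], [11.0]])
toy_y = np.array([0, 0, 1, 1])
toy_model = RandomForestClassifier(n_estimators=10, random_state=0).fit(toy_X, toy_y)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_model.joblib')
    joblib.dump(toy_model, path)
    restored = joblib.load(path)
    # The reloaded model should reproduce the original predictions exactly
    same = np.array_equal(restored.predict(toy_X), toy_model.predict(toy_X))
print("Round-trip predictions match:", same)
```

The scikit-learn documentation cautions that persisted models are not guaranteed to load correctly under a different scikit-learn version, so it is worth recording the version alongside the `.joblib` files.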

## 16. Test Prediction on New Data

In [None]:
# Test Prediction on New Data
# Example: Predict a watermelon (large, round, green, sweet)
raw_sample = {
    'size (cm)': 25.0,
    'weight (g)': 3000.0,
    'avg_price (₹)': 140.0,
    'shape': 'round',
    'color': 'green',
    'taste': 'sweet'
}

# Encode the sample like the training data, then align it to the training columns.
# Dummy columns absent from X (e.g. a level removed by drop_first) are discarded,
# and training columns the sample lacks are filled with 0.
sample_data = pd.get_dummies(pd.DataFrame([raw_sample]), columns=categorical_features)
sample_data = sample_data.reindex(columns=X.columns, fill_value=0)

prediction = rf_model.predict(sample_data)
predicted_fruit = le_target.inverse_transform(prediction)[0]
prediction_proba = rf_model.predict_proba(sample_data)[0]
confidence = prediction_proba.max() * 100

print("="*60)
print("PREDICTION TEST")
print("="*60)
print("Sample characteristics:")
for key, value in raw_sample.items():
    print(f"  {key}: {value}")
print(f"\nPredicted Fruit: {predicted_fruit.upper()}")
print(f"Confidence: {confidence:.2f}%")
print("="*60)
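Manually aligning one-hot columns is easy to get wrong. An alternative design, not the one this notebook uses, bundles the encoding into a scikit-learn `Pipeline` so that raw rows with string categories can be passed straight to `predict`. The sketch below uses hypothetical toy rows:

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

# Hypothetical raw training rows (category strings, not dummies)
toy = pd.DataFrame({
    'size (cm)': [25.0, 24.0, 26.0, 3.0, 3.5, 2.8],
    'color': ['green', 'green', 'green', 'red', 'red', 'red'],
    'taste': ['sweet'] * 6,
    'fruit_name': ['watermelon'] * 3 + ['cherry'] * 3,
})
categorical = ['color', 'taste']

pipeline = Pipeline([
    # One-hot encode the string columns; unseen categories become all-zero columns
    ('encode', ColumnTransformer(
        [('onehot', OneHotEncoder(handle_unknown='ignore'), categorical)],
        remainder='passthrough')),  # numeric columns pass through unchanged
    ('model', RandomForestClassifier(n_estimators=50, random_state=0)),
])
pipeline.fit(toy.drop(columns='fruit_name'), toy['fruit_name'])

# Raw row in the same column order as training; no manual dummy columns needed
new_row = pd.DataFrame([{'size (cm)': 25.0, 'color': 'green', 'taste': 'sweet'}])
print(pipeline.predict(new_row)[0])
```

Because the encoder is fitted inside the pipeline, saving `pipeline` with `joblib` stores the encoding and the model together in a single file.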

## Summary

### Key Results
- **Dataset**: 10,000 samples across 20 fruit types
- **Features**: 14 features after encoding (3 numerical + 11 categorical encoded)
- **Model**: Random Forest Classifier (100 estimators)
- **Performance**: Perfect or near-perfect test accuracy (see the classification report in Section 13)

### Insights
1. **Size and weight** are the strongest predictors of fruit type (by impurity-based feature importance)
2. **Price** correlates strongly with size and weight
3. **Categorical features** (shape, color, taste) provide additional discriminative power
4. The dataset appears to have well-separated fruit characteristics

### Visualizations Created
All visualizations have been saved to the `../viz/` directory:
- numerical_distributions.html
- size_cm_by_fruit.html
- weight_g_by_fruit.html
- avg_price_inr_by_fruit.html
- correlation_heatmap.html
- avg_price_by_fruit.html
- size_vs_weight_scatter.html
- taste_by_color.html
- feature_importance.html
- confusion_matrix.html