In [3]:
import dash
from dash import dcc, html, Input, Output, State, dash_table
import plotly.express as px
import pandas as pd
import os
from PIL import Image
import base64
import io
import torch
from torchvision import transforms

# Ensure the uploads directory exists
if not os.path.exists("uploads"):
    os.makedirs("uploads")

# Load the classification model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "trashbox_model.pth"

if os.path.exists(model_path):
    try:
        from torchvision import models
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        num_ftrs = model.fc.in_features
        model.fc = torch.nn.Linear(num_ftrs, 7)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        categories = ["Cardboard", "E-Waste", "Glass", "Medical", "Metal", "Paper", "Plastic"]
    except Exception as e:
        print(f"Error loading model: {e}")
        model = None
        categories = []
else:
    model = None
    categories = []

# Waste Type Mapping
waste_type_mapping = {
    "Cardboard": "Recyclable",
    "E-Waste": "Hazardous",
    "Glass": "Recyclable",
    "Medical": "Hazardous",
    "Metal": "Recyclable",
    "Paper": "Recyclable",
    "Plastic": "Recyclable"
}

# Dash App Initialization
app = dash.Dash(__name__)
server = app.server

# Session Data
uploaded_images = []
waste_counts = {category: 0 for category in categories}

def encode_image(image_path):
    with open(image_path, "rb") as img_file:
        return "data:image/png;base64," + base64.b64encode(img_file.read()).decode()

# Layout
app.layout = html.Div([
    html.H1("Smart Waste Management System", style={'textAlign': 'center', 'color': 'blue'}),
    
    # Upload Section
    dcc.Upload(
        id='upload-image',
        children=html.Button("Upload Images", style={'backgroundColor': 'green', 'color': 'white'}),
        multiple=True,
        style={
            'width': '100%',
            'height': '50px',
            'lineHeight': '50px',
            'borderWidth': '1px',
            'borderStyle': 'dashed',
            'borderRadius': '5px',
            'textAlign': 'center',
            'marginBottom': '10px'
        }
    ),
    html.Div(id='upload-message', style={'color': 'green', 'textAlign': 'center'}),
    
    # Image Preview
    html.Div(id='image-preview', style={'display': 'flex', 'flexWrap': 'wrap', 'gap': '10px', 'marginTop': '10px'}),
    
    # Classification Section
    html.Button("Start Classification", id='start-classification', n_clicks=0, style={'backgroundColor': 'blue', 'color': 'white', 'marginTop': '10px'}),
    html.Button("Clear All", id='clear-all', style={'backgroundColor': 'red', 'color': 'white', 'marginLeft': '10px'}),
    
    # Waste Composition Analysis
    dcc.Graph(id='waste-composition-pie', style={'width': '50%', 'height': '400px', 'marginTop': '20px'}),
    
    # Waste Distribution Graph
    dcc.Graph(id='waste-distribution-graph', style={'width': '100%', 'height': '400px'}),
    
    # Classification Results Table
    dash_table.DataTable(
        id='classification-results',
        columns=[
            {'name': 'Item', 'id': 'Item'},
            {'name': 'Filename', 'id': 'Filename'},
            {'name': 'Resolution', 'id': 'Resolution'},
            {'name': 'Category', 'id': 'Category'},
            {'name': 'Confidence (%)', 'id': 'Confidence'},
            {'name': 'Waste Type', 'id': 'Waste Type'}  # New column for Waste Type
        ],
        data=[],
        style_table={'width': '100%', 'marginTop': '10px'}
    ),
    
    # Export Section
    html.Button("Export to Excel", id='export-excel', style={'backgroundColor': 'orange', 'color': 'white', 'marginTop': '10px'}),
    html.Div(id='export-message', style={'color': 'blue', 'textAlign': 'center'})
])

# Callbacks
@app.callback(
    Output('upload-message', 'children'),
    Input('upload-image', 'contents'),
    State('upload-image', 'filename')
)
def upload_images(contents, filenames):
    if contents is not None:
        if not os.path.exists("uploads"):
            os.makedirs("uploads")
        for content, filename in zip(contents, filenames):
            try:
                data = content.split(",")[1]
                img = Image.open(io.BytesIO(base64.b64decode(data)))
                path = os.path.join("uploads", filename)
                img.save(path)
                if path not in uploaded_images:
                    uploaded_images.append(path)
            except Exception as e:
                print(f"Error uploading image {filename}: {e}")
        return "Images uploaded successfully!"
    return ""

@app.callback(
    [Output('classification-results', 'data'), Output('waste-distribution-graph', 'figure')],
    Input('start-classification', 'n_clicks')
)
def start_classification(n_clicks):
    if n_clicks > 0 and uploaded_images:
        global waste_counts
        waste_counts = {category: 0 for category in categories}
        
        table_data = []
        processed_images = set()
        for i, image_path in enumerate(uploaded_images):
            if image_path in processed_images:
                continue
            processed_images.add(image_path)

            # Classify image using model
            try:
                img = Image.open(image_path)
                if model:
                    category, confidence = classify_image(img)
                else:
                    category, confidence = "Unknown", 0.0

                resolution = img.size  # Use actual resolution
                if category in waste_counts:
                    waste_counts[category] += 1
                else:
                    category = 'Unknown'

                # Add Waste Type based on category
                table_data.append({
                    'Item': f'Item {i+1}',
                    'Filename': os.path.basename(image_path),
                    'Resolution': f'{resolution[0]}x{resolution[1]}',
                    'Category': category,
                    'Confidence': f"{confidence:.2f}",
                    'Waste Type': waste_type_mapping.get(category, "Unknown")  # Add Waste Type
                })
            except Exception as e:
                print(f"Error processing image {image_path}: {e}")
        
        fig = px.bar(x=list(waste_counts.keys()), y=list(waste_counts.values()), labels={'x': 'Category', 'y': 'Count'})
        return table_data, fig
    return [], px.bar()

@app.callback(
    [Output('classification-results', 'data', allow_duplicate=True), Output('waste-distribution-graph', 'figure', allow_duplicate=True)],
    Input('clear-all', 'n_clicks'),
    prevent_initial_call=True
)
def clear_all(n_clicks):
    if n_clicks > 0:
        uploaded_images.clear()
        global waste_counts
        waste_counts = {category: 0 for category in categories}
        return [], px.bar()
    return [], px.bar()

@app.callback(
    Output('export-message', 'children'),
    Input('export-excel', 'n_clicks'),
    State('classification-results', 'data')
)
def export_to_excel(n_clicks, data):
    if n_clicks is None:  # Handle case where n_clicks is None
        return ""
    
    if n_clicks > 0 and data:
        try:
            df = pd.DataFrame(data)
            df.to_excel("classification_results.xlsx", index=False)
            return "Data exported to Excel successfully!"
        except Exception as e:
            return f"Error exporting data: {e}"
    return ""

# Callback for Waste Composition Analysis
@app.callback(
    Output('waste-composition-pie', 'figure'),
    Input('classification-results', 'data')
)
def update_waste_composition(data):
    if not data:
        return px.pie(labels=['No Data'], values=[1], title="Waste Composition")
    
    df = pd.DataFrame(data)
    category_counts = df['Category'].value_counts()
    fig = px.pie(
        values=category_counts.values,
        names=category_counts.index,
        title="Waste Composition",
        hole=0.3,  # Add a hole in the middle for a donut chart
        labels={'names': 'Category', 'values': 'Count'}
    )
    fig.update_traces(textposition='inside', textinfo='percent+label')
    fig.update_layout(showlegend=False)
    return fig

# Helper Functions
def preprocess_image(image):
    try:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),  # Resize to the input size expected by ResNet
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return transform(image).unsqueeze(0)
    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None

def classify_image(image):
    try:
        input_tensor = preprocess_image(image).to(device)
        with torch.no_grad():
            output = model(input_tensor)
        _, predicted = torch.max(output, 1)
        confidence = torch.nn.functional.softmax(output, dim=1)[0][predicted].item() * 100
        return categories[predicted], confidence
    except Exception as e:
        print(f"Error classifying image: {e}")
        return "Unknown", 0.0

# Run the app in Jupyter Notebook
if __name__ == '__main__':
    app.run_server(mode='inline')

Error preprocessing image: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
Error classifying image: 'NoneType' object has no attribute 'to'
