In [19]:
# Install required packages if not already installed
!pip install torch torchvision flask pillow numpy nest_asyncio

# Import necessary libraries
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from flask import Flask, render_template, request, jsonify
import io
import base64
import os
from datetime import datetime
import sys
import nest_asyncio
import threading

# nest_asyncio patches asyncio so that an event loop which is already running
# (as inside a Jupyter kernel) can be re-entered.
# NOTE(review): the Flask server in this cell is started with app.run() on a
# background threading.Thread, which does not appear to rely on asyncio, so
# this call may be unnecessary — confirm no other cell depends on it first.
nest_asyncio.apply()

# ------------------ Flask Setup ------------------
# The single application object: every route below is registered on it, and
# run_app() serves it.
app = Flask(__name__)

# ------------------ Model Setup ------------------
# Trained weights (a state_dict that is loaded into CNNModel further down),
# resolved relative to the current working directory.
MODEL_PATH = "cnn.pth"

# Fail fast: nothing below can work without the weights file.
_model_file_found = os.path.exists(MODEL_PATH)
if not _model_file_found:
    print(f"Model file '{MODEL_PATH}' not found!")
    raise FileNotFoundError(f"{MODEL_PATH} is missing")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class CNNModel(nn.Module):
    """Small three-stage convolutional classifier.

    Each stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU
    and 2x2 max-pooling, so the spatial size is halved three times. The
    classifier head expects the last feature map to be 64 x 28 x 28, which
    means inputs must be 3-channel 224x224 images (the size ``transform``
    resizes to).

    Args:
        num_classes: number of output logits.
    """

    # Flattened size of the final feature map for a 224x224 input
    # (224 / 2**3 = 28 per side, 64 channels).
    FEATURES = 64 * 28 * 28

    def __init__(self, num_classes=6):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(self.FEATURES, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, 3, 224, 224).
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten everything except the batch dimension. The previous
        # view(-1, FEATURES) could fold extra spatial positions into the
        # batch axis (e.g. a 448x448 image became 4 "samples"); now a
        # wrongly sized input fails in fc1 instead of returning bogus logits.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Build the network and load the trained weights, mapped onto `device`.
# weights_only=True limits torch.load to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code while being unpickled. This
# is sufficient here because the file is a plain state_dict, and it matches
# the default from PyTorch 2.6 onward.
model = CNNModel(num_classes=6)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.to(device)
# The app only ever runs inference, so switch the module to evaluation mode.
model.eval()

# Per-channel normalisation statistics. These are the standard ImageNet
# values; NOTE(review): they must match the normalisation used when cnn.pth
# was trained — confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before inference: resize to
# the 224x224 input CNNModel's classifier head requires, convert to a tensor,
# then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Class name for each model output index: class_labels[i] is the label of
# logit i.
class_labels = ["Cardboard", "Glass", "Metal", "Paper", "Plastic", "Trash"]

# Bin that each predicted class is sorted into.
# NOTE(review): glass goes to the "Hazardous" bin rather than "Recyclable" —
# confirm this is intended.
waste_type_mapping = dict(
    Cardboard="Recyclable",
    Glass="Hazardous",
    Metal="Recyclable",
    Paper="Recyclable",
    Plastic="Recyclable",
    Trash="Non-Recyclable",
)

# (bin name, map coordinates used for route planning, hex colour sent to
# the client)
_BIN_LAYOUT = (
    ("Recyclable", (10, 20), "#4CAF50"),
    ("Hazardous", (30, 40), "#FF5722"),
    ("Non-Recyclable", (50, 60), "#9E9E9E"),
)

# Live in-memory state of every bin. "level" is a fill percentage that
# classify() raises in steps of 10 up to 100; "status" is derived from it.
bins = {
    name: {"level": 0, "status": "Normal", "location": location, "color": color}
    for name, location, color in _BIN_LAYOUT
}

# Records of past classifications, oldest first (classify() appends).
classification_history = []

def update_bin_status(bin_data=None, threshold=75):
    """Recompute each bin's ``status`` from its fill ``level``, in place.

    A bin whose level is strictly above ``threshold`` becomes "Nearly Full";
    every other bin is set back to "Normal".

    Args:
        bin_data: mapping of bin name -> state dict with a "level" key.
            Defaults to the module-level ``bins``.
        threshold: fill level above which a bin counts as "Nearly Full".
    """
    if bin_data is None:
        bin_data = bins
    for state in bin_data.values():
        state["status"] = "Nearly Full" if state["level"] > threshold else "Normal"

def optimize_route(bin_data=None):
    """Plan a collection route through every bin marked "Nearly Full".

    Bins are visited in priority order (Hazardous, then Recyclable, then
    Non-Recyclable; unknown bin types last). Among bins of equal priority,
    the one nearest the current position is taken next (greedy nearest
    neighbour). The route starts at (0, 0) and distances are straight-line.

    Args:
        bin_data: mapping of bin name -> state dict with "status" and
            "location" keys. Defaults to the module-level ``bins``.

    Returns:
        ``(route, distance)``. ``route`` is a string such as
        ``"Start → Bin_Hazardous → Bin_Recyclable → End"``, and ``distance``
        is the total travelled distance (no return leg). When no bin needs
        collection, returns ``("No bins need collection", 0)``.
    """
    import math

    if bin_data is None:
        bin_data = bins

    priority = {"Hazardous": 3, "Recyclable": 2, "Non-Recyclable": 1}
    # Unknown bin types get the lowest priority instead of raising KeyError.
    pending = [
        (name, priority.get(name, 0), state["location"])
        for name, state in bin_data.items()
        if state["status"] == "Nearly Full"
    ]
    if not pending:
        return "No bins need collection", 0

    route = ["Start"]
    current = (0, 0)
    total_distance = 0
    while pending:
        # Priority decides first; distance only orders bins within a tier.
        # (Previously the list was sorted by priority but then re-picked
        # purely by distance, so the priority had no effect.)
        idx = min(
            range(len(pending)),
            key=lambda i: (-pending[i][1], math.dist(current, pending[i][2])),
        )
        name, _, location = pending.pop(idx)
        total_distance += math.dist(current, location)
        route.append(f"Bin_{name}")
        current = location

    route.append("End")
    return " → ".join(route), total_distance

# ------------------ Flask Routes ------------------

@app.route('/')
def index():
    """Render the front-end page from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)

@app.route('/get_bin_status')
def get_bin_status():
    """Return the current state of every bin as JSON."""
    snapshot = bins
    return jsonify(snapshot)

@app.route('/get_history')
def get_history():
    """Return up to the ten most recent classification records, oldest first."""
    recent_records = classification_history[-10:]
    return jsonify(recent_records)

@app.route('/classify', methods=['POST'])
def classify():
    """Classify an uploaded image, fill the matching bin and re-plan the route.

    Expects a multipart upload in the ``image`` form field. On success:
    - the predicted class's bin level rises by 10, capped at 100;
    - bin statuses are recomputed;
    - the result is appended to ``classification_history``, which is trimmed
      to its newest entries.

    Returns:
        JSON with the prediction, every bin's state, the new route and its
        distance. Returns HTTP 400 with an ``error`` key when no image was
        sent or the upload cannot be decoded as an image.
    """
    # /get_history only ever serves the last ten records, so keep a bounded
    # tail instead of holding a base64 copy of every upload forever.
    history_limit = 100

    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    img_bytes = request.files['image'].read()
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    except Exception as e:  # any decode failure means bad client input
        return jsonify({'error': f"Invalid image: {e}"}), 400

    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted = torch.max(outputs, 1)
    predicted_class = class_labels[predicted.item()]
    waste_type = waste_type_mapping.get(predicted_class, "Non-Recyclable")

    bins[waste_type]["level"] = min(bins[waste_type]["level"] + 10, 100)
    update_bin_status()

    classification_history.append({
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'image': base64.b64encode(img_bytes).decode('utf-8'),
        'predicted_class': predicted_class,
        'waste_type': waste_type
    })
    # Drop everything but the newest `history_limit` records, in place.
    del classification_history[:-history_limit]

    route, distance = optimize_route()

    return jsonify({
        'predicted_class': predicted_class,
        'waste_type': waste_type,
        'bins': bins,
        'route': route,
        'distance': f"{distance:.2f} units"
    })

@app.route('/optimize_route')
def get_route():
    """Return the current collection route and its total distance as JSON."""
    route_text, total = optimize_route()
    payload = {
        'route': route_text,
        'distance': f"{total:.2f} units",
    }
    return jsonify(payload)

@app.route('/reset_bins')
def reset_bins():
    """Empty every bin, mark it "Normal" and return the resulting bin states.

    NOTE(review): this mutates server state on a GET request; consider
    switching to POST once the front end is updated accordingly.
    """
    for state in bins.values():
        state.update(level=0, status="Normal")
    return jsonify(bins)


def run_app(host="0.0.0.0", port=5000, debug=False):
    """Run Flask's development server for ``app``; blocks until stopped.

    Args:
        host: interface to bind. The default "0.0.0.0" listens on every
            interface, so the app is reachable from other machines. Pass
            "127.0.0.1" to keep it local.
        port: TCP port to listen on.
        debug: enable Flask debug mode. This is off by default because debug
            mode exposes the Werkzeug interactive debugger, which can run
            arbitrary Python from the browser. It is PIN-protected, but
            Flask's documentation warns never to expose it, and with the
            default host it would be reachable from the network.
    """
    # Auto-reload stays disabled: this function is started on a background
    # thread inside an interactive session, not as a standalone script.
    app.run(host=host, port=port, debug=debug, use_reloader=False)

# Serve from a background thread so the notebook cell returns immediately
# instead of blocking inside app.run().
thread = threading.Thread(target=run_app)
thread.start()

startup_message = "Flask app is running! Access it at http://127.0.0.1:5000/"
print(startup_message)


Flask app is running! Access it at http://127.0.0.1:5000/
 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://10.98.223.191:5000
Press CTRL+C to quit
