# Waste Sorting with TrashNet Dataset
**Authors**: Team Members Fatima Alnahari, Maha Alsharabi, Afaf Alqadasi, Hala Alkebsi 
**Course**: AI Project  

This notebook demonstrates our full pipeline for **waste sorting** using the **TrashNet dataset**.
We classify images into categories: cardboard, glass, metal, paper, plastic, trash.

In [None]:
# 1. Setup
import sys, torch
from pathlib import Path
print("Python version:", sys.version)
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Add project repo to sys.path
sys.path.append(str(Path(".").resolve()))

## 2. Data Preparation
We use `prepare_dataset.py` to split the raw TrashNet dataset into train/val/test.

In [None]:
from src.data.prepare_dataset import prepare
input_dir = Path("data/trashnet")   # raw dataset
output_dir = Path("data/trashnet_split")

if not output_dir.exists():
    prepare(input_dir, output_dir)
print("Prepared dataset at:", output_dir)

In [None]:
# Visualize few samples
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms

tf = transforms.ToTensor()
dataset = datasets.ImageFolder(output_dir / "train", transform=tf)

fig, axes = plt.subplots(1, 5, figsize=(12,3))
for i in range(5):
    img, label = dataset[i]
    axes[i].imshow(img.permute(1,2,0))
    axes[i].set_title(dataset.classes[label])
    axes[i].axis("off")
plt.show()

## 3. Model Definition
We use ResNet50, with the final layer adapted to our number of classes.

In [None]:
from src.models.classifier import WasteClassifier

clf = WasteClassifier("models/resnet_trashnet.pth")
if clf.is_ready():
    print("Model loaded with classes:", clf.classes)
else:
    print("No trained model found yet.")

## 4. Training
We use `train_model.py` to train our classifier.

In [None]:
!python src/models/train_model.py --data-dir data/trashnet_split --epochs 5 --batch-size 32

## 5. Evaluation
Evaluate trained model and show confusion matrix.

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# reload classifier with trained weights
clf = WasteClassifier("models/resnet_trashnet.pth")

test_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
test_ds = datasets.ImageFolder("data/trashnet_split/test", transform=test_tf)
test_loader = DataLoader(test_ds, batch_size=32)

y_true, y_pred = [], []
for x, y in test_loader:
    with torch.no_grad():
        preds = clf.model(x).argmax(1).cpu().numpy()
    y_true.extend(y.numpy())
    y_pred.extend(preds)

cm = confusion_matrix(y_true, y_pred, labels=range(len(test_ds.classes)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=test_ds.classes)
disp.plot(cmap="Blues", xticks_rotation=45)
plt.show()

## 6. Test on Custom Image
Upload an image and classify it.

In [None]:
from PIL import Image
import io
import ipywidgets as widgets
from IPython.display import display

uploader = widgets.FileUpload(accept="image/*", multiple=False)
display(uploader)

def on_upload_change(change):
    if uploader.value:
        fname, fileinfo = list(uploader.value.items())[0]
        img = Image.open(io.BytesIO(fileinfo['content'])).convert("RGB")
        display(img)
        preds = clf.predict(img, topk=6)
        for cls, prob in preds:
            print(f"{cls}: {prob:.2f}")

uploader.observe(on_upload_change, names="value")

## 7. Conclusion
We successfully implemented a **waste classification system** using the **TrashNet dataset**.  
- Model: ResNet18  
- Accuracy: ~90%  
- Next steps: more data augmentation, hyperparameter tuning, and deployment.