In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm
import numpy as np
import os

In [2]:
data_dir = '../New_Data'
num_classes = 6
batch_size = 32
img_size = 128
epochs = 25
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [4]:
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
class_names = full_dataset.classes
targets = np.array(full_dataset.targets)

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, val_idx in splitter.split(np.zeros(len(targets)), targets):
    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(full_dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [5]:
model = models.resnet18(pretrained=True)



In [6]:
for name, param in model.named_parameters():
    if "layer4" in name or "fc" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, num_classes)
)

model = model.to(device)

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [8]:
best_acc = 0
patience = 5
counter = 0

In [9]:
for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    print(f"\nEpoch {epoch+1}/{epochs}")
    loop = tqdm(train_loader, desc="Training", leave=False)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_correct += (predicted == labels).sum().item()
        train_total += labels.size(0)
        loop.set_postfix(loss=loss.item(), acc=100. * train_correct / train_total)

    train_acc = 100. * train_correct / train_total
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")

    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

    val_acc = 100. * val_correct / val_total
    print(f"Validation Loss: {val_loss:.4f} | Validation Acc: {val_acc:.2f}%")

    scheduler.step()

    if val_acc > best_acc:
        best_acc = val_acc
        counter = 0
        torch.save(model.state_dict(), "best_waste_classifier_resnet18.pth")
        print("✅ Model improved. Saved!")
    else:
        counter += 1
        if counter >= patience:
            print("⏹️ Early stopping triggered.")
            break


Epoch 1/25


                                                                                 

Train Loss: 174.3715 | Train Acc: 66.29%




Validation Loss: 36.4176 | Validation Acc: 71.99%
✅ Model improved. Saved!

Epoch 2/25


                                                                                 

Train Loss: 135.0763 | Train Acc: 74.74%




Validation Loss: 29.3635 | Validation Acc: 78.05%
✅ Model improved. Saved!

Epoch 3/25


                                                                                 

Train Loss: 117.9824 | Train Acc: 78.77%




Validation Loss: 27.5105 | Validation Acc: 79.70%
✅ Model improved. Saved!

Epoch 4/25


                                                                                 

Train Loss: 104.4071 | Train Acc: 80.42%




Validation Loss: 28.2822 | Validation Acc: 78.73%

Epoch 5/25


                                                                                 

Train Loss: 88.5298 | Train Acc: 83.90%




Validation Loss: 27.7754 | Validation Acc: 80.52%
✅ Model improved. Saved!

Epoch 6/25


                                                                                 

Train Loss: 83.3309 | Train Acc: 84.81%




Validation Loss: 26.8437 | Validation Acc: 81.00%
✅ Model improved. Saved!

Epoch 7/25


                                                                                  

Train Loss: 77.3984 | Train Acc: 86.14%




Validation Loss: 27.1012 | Validation Acc: 81.28%
✅ Model improved. Saved!

Epoch 8/25


                                                                                  

Train Loss: 53.6370 | Train Acc: 90.29%




Validation Loss: 23.9631 | Validation Acc: 82.52%
✅ Model improved. Saved!

Epoch 9/25


                                                                                  

Train Loss: 44.8445 | Train Acc: 91.43%




Validation Loss: 23.0551 | Validation Acc: 84.38%
✅ Model improved. Saved!

Epoch 10/25


                                                                                  

Train Loss: 43.0721 | Train Acc: 92.18%




Validation Loss: 23.6834 | Validation Acc: 83.69%

Epoch 11/25


                                                                                  

Train Loss: 38.6831 | Train Acc: 92.85%




Validation Loss: 23.7382 | Validation Acc: 83.34%

Epoch 12/25


                                                                                  

Train Loss: 35.9906 | Train Acc: 93.90%




Validation Loss: 24.1947 | Validation Acc: 83.07%

Epoch 13/25


                                                                                  

Train Loss: 33.4306 | Train Acc: 94.23%




Validation Loss: 25.2038 | Validation Acc: 83.69%

Epoch 14/25


                                                                                  

Train Loss: 32.1018 | Train Acc: 94.20%




Validation Loss: 24.8892 | Validation Acc: 84.17%
⏹️ Early stopping triggered.


In [10]:
print(f"\nBest Validation Accuracy: {best_acc:.2f}%")
print("Class indices:", dict(zip(class_names, range(num_classes))))


Best Validation Accuracy: 84.38%
Class indices: {'biodegradable': 0, 'cardboard': 1, 'glass': 2, 'metal': 3, 'paper': 4, 'plastic': 5}
