##### Import and Setup

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from PIL import Image
from torch.utils.data import DataLoader

##### Common Data Setup

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

train_data = datasets.ImageFolder(r'C:\Users\chang\OneDrive\文件\ELLIE\Garbage Classification\dataset\train', transform=transform)
val_data = datasets.ImageFolder(r'C:\Users\chang\OneDrive\文件\ELLIE\Garbage Classification\dataset\val', transform=transform)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

num_classes = len(train_data.classes)  

##### Training ResNet18

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

resnet = models.resnet18(pretrained=True)

# Freeze all layers
for param in resnet.parameters():
    param.requires_grad = False

# Replace final FC layer
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
resnet = resnet.to(device)

# Train only the final layer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.fc.parameters(), lr=0.001)

# Training loop
for epoch in range(5):  # train more if needed
    resnet.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = resnet(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

Epoch 1, Loss: 1.3299
Epoch 2, Loss: 0.8340
Epoch 3, Loss: 0.7129
Epoch 4, Loss: 0.6144
Epoch 5, Loss: 0.5758


##### Save trained resnet

In [7]:
torch.save(resnet.state_dict(), 'resnet_model.pth')

##### Train MobileNetV2

In [8]:
mobilenet = models.mobilenet_v2(pretrained=True)

for param in mobilenet.parameters():
    param.requires_grad = False

mobilenet.classifier[1] = nn.Linear(mobilenet.last_channel, num_classes)
mobilenet = mobilenet.to(device)

optimizer = optim.Adam(mobilenet.classifier[1].parameters(), lr=0.001)

# Training loop 
for epoch in range(5):  # train more if needed
    mobilenet.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = mobilenet(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to C:\Users\chang/.cache\torch\hub\checkpoints\mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:10<00:00, 1.37MB/s]


Epoch 1, Loss: 1.0844
Epoch 2, Loss: 0.6622
Epoch 3, Loss: 0.5794
Epoch 4, Loss: 0.5253
Epoch 5, Loss: 0.5039


##### Save trained mobilenet

In [9]:
torch.save(mobilenet.state_dict(), 'mobilenet_model.pth')

##### Load Models

In [10]:
# Load ResNet18
resnet = models.resnet18(pretrained=True)
for param in resnet.parameters():
    param.requires_grad = False
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
resnet.load_state_dict(torch.load("resnet_model.pth"))
resnet.eval()

# Load MobileNetV2
mobilenet = models.mobilenet_v2(pretrained=True)
for param in mobilenet.parameters():
    param.requires_grad = False
mobilenet.classifier[1] = nn.Linear(mobilenet.last_channel, num_classes)
mobilenet.load_state_dict(torch.load("mobilenet_model.pth"))
mobilenet.eval()


MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

##### Define Transform and Class Names

In [11]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # match input size for ResNet and MobileNet
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

class_names = ['cardboard', 'glass', 'metal', 'organic', 'paper', 'plastic']

##### Define Prediction Function

In [12]:
def ensemble_predict(image_path):
    image = Image.open(image_path).convert('RGB')
    img_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        out1 = resnet(img_tensor)
        out2 = mobilenet(img_tensor)

        # Average the outputs
        avg_output = (out1 + out2) / 2
        _, pred = torch.max(avg_output, 1)

    return class_names[pred.item()]


##### Use the Model

In [13]:
result = ensemble_predict("example.jpg")
print("Predicted class:", result)

Predicted class: cardboard


##### Evaluate the model

In [14]:
resnet.eval()
mobilenet.eval()    

correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)

        # Get outputs from both models
        output_resnet = resnet(images)
        output_mobilenet = mobilenet(images)

        # Average the outputs (soft voting)
        avg_output = (output_resnet + output_mobilenet) / 2

        # Get predicted class
        _, predicted = torch.max(avg_output, 1)

        # Update counters
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Ensemble Validation Accuracy: {accuracy:.2f}%')


Ensemble Validation Accuracy: 82.68%
