In [None]:
import os

IMAGE_DIR = "TrashBox/TrashBox_train_set"

# The top level of the train set holds one sub-directory per waste category
# (the image files live one level further down), so the entries listed here
# are category folders, not images. Report them as such.
if os.path.exists(IMAGE_DIR):
    images = os.listdir(IMAGE_DIR)  # top-level entries (category folders); name kept for compatibility
    print("Total top-level entries (category folders) found:", len(images))
    print("First 5 entries:", images[:5])  # quick sanity check of the folder names
else:
    print("Dataset directory not found! Make sure TrashBox is cloned properly.")


Total images found: 7
First 5 images: ['paper', 'glass', 'e-waste', 'plastic', 'metal']


In [None]:
# Per-category inventory: one summary line per sub-directory of IMAGE_DIR;
# plain files at the top level are skipped.
for category in os.listdir(IMAGE_DIR):
    category_path = os.path.join(IMAGE_DIR, category)
    if not os.path.isdir(category_path):
        continue
    images = os.listdir(category_path)
    count = len(images)
    print(f"Category: {category}, Total Images: {count}, First 3: {images[:3]}")


Category: paper, Total Images: 2156, First 3: ['paper 576.jpg', 'paper 474.jpg', 'paper 1049.jpg']
Category: glass, Total Images: 2022, First 3: ['glass 881.jpg', 'glass 2259.jpg', 'glass 90.jpg']
Category: e-waste, Total Images: 2406, First 3: ['e-waste 44.jpg', 'e-waste 1350.jpg', 'e-waste 1867.jpg']
Category: plastic, Total Images: 2135, First 3: ['plastic 2404.jpg', 'plastic 652.jpg', 'plastic 1961.jpg']
Category: metal, Total Images: 2068, First 3: ['metal 395.jpg', 'metal 1304.jpg', 'metal 1016.jpg']
Category: medical, Total Images: 1565, First 3: ['medical 854.jpg', 'medical 1787.jpg', 'medical 1126.jpg']
Category: cardboard, Total Images: 1930, First 3: ['cardboard 1845.jpg', 'cardboard 1997.jpg', 'cardboard 1223.jpg']


In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import glob

# Define the dataset class
class TrashDataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<category>/<image>``.

    Each immediate sub-directory of ``root_dir`` is one class. Its position in
    the sorted list of sub-directory names is the integer label.

    Args:
        root_dir: Directory whose sub-directories are the categories.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__``.
        pattern: Glob pattern, relative to each category directory, that selects
            image files. Defaults to ``"*.jpg"``.
    """

    def __init__(self, root_dir, transform=None, pattern="*.jpg"):
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are categories. A stray top-level file (README,
        # .DS_Store, ...) must not become a class or shift the label indices.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.image_paths = []
        self.labels = []

        # Collect image paths and their integer labels.
        for label, category in enumerate(self.classes):
            category_path = os.path.join(root_dir, category)
            # glob order depends on the filesystem; sort it so sample order (and
            # any index-based split) is reproducible across machines.
            for img_file in sorted(glob.glob(os.path.join(category_path, pattern))):
                self.image_paths.append(img_file)
                self.labels.append(label)

        print(f"Total images loaded: {len(self.image_paths)} across {len(self.classes)} categories.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``.

        ``image`` is the RGB image, passed through ``transform`` when one was
        given. ``label`` is the int class index.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Use a context manager so the underlying file handle is closed promptly.
        # convert() returns a new, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Define transformations
# Preprocessing pipeline: resize every image to 256x256, then convert it to a tensor.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

# Instantiate the dataset over the category folders under IMAGE_DIR.
dataset = TrashDataset(IMAGE_DIR, transform=transform)


Total images loaded: 14282 across 7 categories.


In [None]:
import torchvision.transforms as transforms

class TrashDataset(Dataset):
    """Image-level classification dataset laid out as ``image_dir/<category>/<file>``.

    NOTE(review): this redefinition shadows the ``TrashDataset`` class defined in
    an earlier cell. Keep only one of them.

    ``__getitem__`` returns ``(image, label)``. ``label`` is a 0-d ``torch.long``
    tensor holding the category index. This dataset has no segmentation masks.
    The ``mask_paths`` attribute stores each image's category *name* and keeps
    its original name only for compatibility.

    Args:
        image_dir: Directory whose sub-directories are the categories.
        transform: Optional callable applied to each RGB PIL image. When None, a
            default Resize((256, 256)) + ToTensor pipeline is used.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Only sub-directories are categories. A stray top-level file would
        # otherwise become a class and make os.listdir() below raise.
        self.categories = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isdir(os.path.join(image_dir, name))
        )
        self.image_paths = []
        self.mask_paths = []  # category name per image (not a mask path)
        self.transform = transform

        # Load image paths; sort within each category for a reproducible order.
        for category in self.categories:
            category_path = os.path.join(image_dir, category)
            for img_name in sorted(os.listdir(category_path)):
                self.image_paths.append(os.path.join(category_path, img_name))
                self.mask_paths.append(category)

        self.class_to_index = {cls: idx for idx, cls in enumerate(self.categories)}

        # Honour a caller-supplied transform. Previously it was stored but ignored.
        if transform is not None:
            self.image_transform = transform
        else:
            self.image_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
            ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``label`` is a 0-d long tensor."""
        img_path = self.image_paths[idx]
        category = self.mask_paths[idx]
        label = self.class_to_index[category]

        # Load the image, closing the file handle once it has been converted.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.image_transform:
            image = self.image_transform(image)

        # One class index per image: shape (), NOT an (H, W) mask. A DataLoader
        # batches these into shape (B,), which is not a valid target for per-pixel
        # (B, C, H, W) logits in CrossEntropyLoss.
        label_tensor = torch.tensor(label, dtype=torch.long)

        return image, label_tensor


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net that produces per-pixel class logits.

    Structure:
    - Four encoder stages, each a double 3x3 conv followed by 2x max-pooling.
    - A bottleneck.
    - Four decoder stages, each a 2x transposed-conv upsampling whose output is
      concatenated with the matching encoder feature map.

    The output has ``out_channels`` logits per pixel and the same H×W as the input.

    Args:
        in_channels: Channels of the input image.
        out_channels: Number of output logit channels (classes).
        base_channels: Width of the first encoder stage. Each deeper stage doubles
            it. The default 64 reproduces the original layer names and shapes.
    """

    def __init__(self, in_channels=3, out_channels=7, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * m for m in (1, 2, 4, 8, 16))

        # Encoder
        self.conv1 = self.double_conv(in_channels, c1)
        self.conv2 = self.double_conv(c1, c2)
        self.conv3 = self.double_conv(c2, c3)
        self.conv4 = self.double_conv(c3, c4)

        # Bottleneck
        self.bottleneck = self.double_conv(c4, c5)

        # Decoder: each dec* block sees upsampled channels + skip channels.
        self.upconv4 = self.upconv(c5, c4)
        self.dec4 = self.double_conv(c4 + c4, c4)
        self.upconv3 = self.upconv(c4, c3)
        self.dec3 = self.double_conv(c3 + c3, c3)
        self.upconv2 = self.upconv(c3, c2)
        self.dec2 = self.double_conv(c2 + c2, c2)
        self.upconv1 = self.upconv(c2, c1)
        self.dec1 = self.double_conv(c1 + c1, c1)

        # Output layer: 1x1 conv to per-pixel class logits.
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Two 3x3 conv + ReLU layers; padding=1 keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        """2x spatial upsampling with a learned transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _match_and_concat(upsampled, skip):
        """Zero-pad ``upsampled`` to ``skip``'s H×W, then concatenate on channels.

        max_pool2d floors odd sizes, so after upsampling the decoder map can be
        one row/column smaller than its skip connection. Without this, torch.cat
        fails for inputs whose H or W is not divisible by 16.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h or diff_w:
            upsampled = F.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Map a (B, in_channels, H, W) batch to (B, out_channels, H, W) logits."""
        # Encoder
        x1 = self.conv1(x)
        x2 = self.conv2(F.max_pool2d(x1, 2))
        x3 = self.conv3(F.max_pool2d(x2, 2))
        x4 = self.conv4(F.max_pool2d(x3, 2))

        # Bottleneck
        x_b = self.bottleneck(F.max_pool2d(x4, 2))

        # Decoder
        x = self.dec4(self._match_and_concat(self.upconv4(x_b), x4))
        x = self.dec3(self._match_and_concat(self.upconv3(x), x3))
        x = self.dec2(self._match_and_concat(self.upconv2(x), x2))
        x = self.dec1(self._match_and_concat(self.upconv1(x), x1))

        # Output
        return self.final_conv(x)

# Initialize model
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

# Build the network with its default arguments and place it on `device`.
model = UNet()
model = model.to(device)
print("U-Net model initialized successfully!")


U-Net model initialized successfully!


In [None]:
import torch.optim as optim

# Define loss function (CrossEntropy for multi-class segmentation)
# Objective: cross-entropy loss on the model's class logits.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

print("Loss function and optimizer set!")


Loss function and optimizer set!


In [None]:
import time

from torch.utils.data import DataLoader

# Training parameters
num_epochs = 10
batch_size = 8

# Build the loader in this cell. No `dataloader` is defined in the cells above,
# so the loop only ran against a stale kernel, and `batch_size` went unused.
# `dataset` is the TrashDataset instance created earlier in this notebook.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Move model to device
model.to(device)

print("Starting training...")
start_time = time.time()  # wall-clock start, for the summary at the end

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (images, masks) in enumerate(dataloader):
        images = images.to(device)  # (B, 3, H, W) RGB batch
        # TrashDataset yields one class index per image, so after batching this
        # is (B,). A true segmentation dataset would yield (B, H, W) masks.
        masks = masks.to(device).long()

        optimizer.zero_grad()
        outputs = model(images)  # per-pixel logits: (B, num_classes, H, W)

        if masks.dim() == 1:
            # Image-level labels: CrossEntropyLoss needs (B, C) logits for (B,)
            # targets, so average the per-pixel logits over H and W. Passing
            # 4-D logits with 1-D targets raised "only batches of spatial
            # targets supported".
            outputs = outputs.mean(dim=(2, 3))
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if batch_idx % 10 == 0:  # Print every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

    # Guard against an empty loader so the epoch summary cannot divide by zero.
    avg_loss = running_loss / max(len(dataloader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] Completed. Avg Loss: {avg_loss:.4f}")

end_time = time.time()  # End time
elapsed_time = (end_time - start_time) / 60  # Convert to minutes
print(f"Training completed in {elapsed_time:.2f} minutes.")


Starting training...


RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [8]

In [None]:
# Inspect the target batch (`masks`) produced by the training cell above.
target_shape, target_dtype = masks.shape, masks.dtype
print(target_shape, target_dtype)


torch.Size([8]) torch.int64


You cannot train this U-Net on this dataset as-is, because:

U-Net is a segmentation model, but your dataset has per-image category labels, not segmentation masks.

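U-Net predicts a pixel-wise mask (e.g., which pixels of an image are trash).
Your dataset only has one category label per image (e.g., "paper", "plastic"), not pixel-by-pixel masks.
For segmentation, PyTorch's `CrossEntropyLoss` expects logits of shape (B, C, H, W) and targets of shape (B, H, W). Your labels are just class IDs of shape (B,), which is exactly what the `RuntimeError: only batches of spatial targets supported` above reports.
Your labels are for classification, not segmentation.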

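If you had segmentation masks (each pixel labeled with a category), U-Net would work.
Since each image has a single category label, you need a classification model (ResNet, VGG, etc.).

Best approach: switch to a CNN-based classification model (ResNet, EfficientNet, etc.). A stop-gap that keeps the U-Net is to average its per-pixel logits over H and W. This gives (B, C) image-level logits, which `CrossEntropyLoss` accepts with (B,) targets.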