In [1]:
%%capture
!pip install torchinfo

In [4]:
# Data handling
import pandas as pd
import numpy as np

# Data visualization
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2

# Preprocessing
from sklearn.model_selection import train_test_split
from collections import Counter

# Torch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights

# Metrics
from sklearn.metrics import accuracy_score, confusion_matrix

# os
import os

# Path
from pathlib import Path

# tqdm
from tqdm.auto import tqdm

# random
import random

# typing
from typing import Dict,List

# warnings
import warnings
warnings.filterwarnings("ignore")

In [17]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    print('success')
else:
    print ("MPS device not found.")

success


In [16]:
# Total Images
IMAGE_PATH = Path("/Users/maniksinghsarmaal/Downloads/TrashType_Image_Dataset")

IMAGE_PATH_LIST = list(IMAGE_PATH.glob("*/*.jpg"))

print(f"Total Images = {len(IMAGE_PATH_LIST)}")

Total Images = 2527


In [18]:
# Total Images per class
classes = os.listdir(IMAGE_PATH)
classes = sorted(classes)

print("**" * 30)
print(" " * 17, "Total Images per class")
print("**" * 30)
for c in classes:
    imgs_class = len(list(Path(os.path.join(IMAGE_PATH, c)).glob("*.jpg")))
    print(f"* {c} => {imgs_class} images")

************************************************************
                  Total Images per class
************************************************************
* .DS_Store => 0 images
* cardboard => 403 images
* glass => 501 images
* metal => 410 images
* paper => 594 images
* plastic => 482 images
* trash => 137 images


In [19]:
path = [None] * len(IMAGE_PATH_LIST)
label = [None] * len(IMAGE_PATH_LIST)

for i,img_path in enumerate(IMAGE_PATH_LIST):
    path[i] = img_path
    label[i] = img_path.parent.stem
    
df_path_label = pd.DataFrame({"path":path, 
                              "label":label})

df_path_label.head()

Unnamed: 0,path,label
0,/Users/maniksinghsarmaal/Downloads/TrashType_I...,paper
1,/Users/maniksinghsarmaal/Downloads/TrashType_I...,paper
2,/Users/maniksinghsarmaal/Downloads/TrashType_I...,paper
3,/Users/maniksinghsarmaal/Downloads/TrashType_I...,paper
4,/Users/maniksinghsarmaal/Downloads/TrashType_I...,paper


In [20]:
# We define the random seed for reproducibility.
SEED = 123

df_train, df_rest = train_test_split(df_path_label, 
                                     test_size = 0.3, 
                                     random_state = SEED, 
                                     shuffle = True, 
                                     stratify = df_path_label["label"])

df_valid, df_test = train_test_split(df_rest, 
                                     test_size = 0.5, 
                                     random_state = SEED, 
                                     shuffle = True, 
                                     stratify = df_rest["label"])

In [21]:
Counter(df_train["label"])

Counter({'paper': 416,
         'glass': 350,
         'plastic': 337,
         'metal': 287,
         'cardboard': 282,
         'trash': 96})

In [22]:
Counter(df_valid["label"])

Counter({'paper': 89,
         'glass': 75,
         'plastic': 72,
         'metal': 61,
         'cardboard': 61,
         'trash': 21})

In [23]:
Counter(df_test["label"])

Counter({'paper': 89,
         'glass': 76,
         'plastic': 73,
         'metal': 62,
         'cardboard': 60,
         'trash': 20})

In [24]:
label_map = dict(zip(classes, range(0, len(classes))))
label_map

{'.DS_Store': 0,
 'cardboard': 1,
 'glass': 2,
 'metal': 3,
 'paper': 4,
 'plastic': 5,
 'trash': 6}

In [25]:
df_train["label"] = df_train["label"].map(label_map)
df_valid["label"] = df_valid["label"].map(label_map)

In [26]:
df_train = df_train.reset_index(drop = True)
df_valid = df_valid.reset_index(drop = True)

In [27]:
weights = EfficientNet_B7_Weights.DEFAULT
auto_transforms = weights.transforms()
auto_transforms

ImageClassification(
    crop_size=[600]
    resize_size=[600]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BICUBIC
)

In [28]:
class CustomDataset:
    def __init__(self, df:pd.DataFrame, transforms):
        self.df = df
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        image_path = self.df.iloc[idx, 0]
        image = Image.open(image_path).convert("RGB")
        image = self.transforms(image)
        label = self.df.iloc[idx, 1]
        
        return image,label

In [29]:
train_dataset = CustomDataset(df_train, auto_transforms)
valid_dataset = CustomDataset(df_valid, auto_transforms)

In [30]:
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()

train_dataloader = DataLoader(dataset = train_dataset, 
                              batch_size = BATCH_SIZE, 
                              shuffle = True, 
                              num_workers = NUM_WORKERS)

valid_dataloader = DataLoader(dataset = valid_dataset, 
                              batch_size = BATCH_SIZE, 
                              shuffle = True, 
                              num_workers = NUM_WORKERS)

In [32]:
# Let's visualize a batch.
batch_images, batch_labels = next(iter(train_dataloader))

batch_images.shape, batch_labels.shape

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/maniksinghsarmaal/mambaforge3/envs/sbin/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/maniksinghsarmaal/mambaforge3/envs/sbin/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'CustomDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/maniksinghsarmaal/mambaforge3/envs/sbin/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/maniksinghsarmaal/mambaforge3/envs/sbin/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'CustomDataset' on <module '__main__' (built-in)>
Traceback (most recent call la

RuntimeError: DataLoader worker (pid(s) 54765) exited unexpectedly

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/maniksinghsarmaal/mambaforge3/envs/sbin/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/maniksinghsarmaal/mambaforge3/envs/sbin/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'CustomDataset' on <module '__main__' (built-in)>
