# Trash Classification

This notebook is dedicated to determining whether a given piece of trash is recyclable or not, based on its image. We'll adapt the previous model to PyTorch (the `torch` library).


## Importing modules

In [32]:
import pandas as pd
import numpy as np
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm

## Loading Data

### Defining a dataset class to load the data (downloaded from Kaggle)

In [None]:
class TrashDataset(Dataset):
    """Image dataset for binary trash classification.

    ``dir_path`` must contain two sub-folders named ``R`` and ``O``. Every
    regular, non-hidden file inside them is treated as one image; files in
    ``R`` get label 0 and files in ``O`` get label 1.
    NOTE(review): presumably R = recyclable and O = organic -- confirm
    against the Kaggle dataset description.

    Args:
        dir_path: Root folder holding the ``R`` and ``O`` sub-folders.
        transform: Optional callable applied to each RGB PIL image before it
            is returned (e.g. a ``torchvision.transforms.Compose``).

    Raises:
        FileNotFoundError: If one of the class sub-folders does not exist.
    """

    def __init__(self, dir_path: str, transform=None):
        self.images = []  # full path of every image file
        self.labels = []  # integer class label, parallel to self.images
        self.transform = transform
        class_map = {'R': 0, 'O': 1}

        for label_folder, label in class_map.items():
            folder_path = os.path.join(dir_path, label_folder)
            # os.listdir returns entries in arbitrary order; sorting makes
            # the index -> sample mapping reproducible across runs/machines.
            for file_name in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, file_name)
                # Skip sub-directories and hidden entries (.DS_Store,
                # .ipynb_checkpoints, ...) that would make Image.open fail.
                if file_name.startswith('.') or not os.path.isfile(file_path):
                    continue
                self.images.append(file_path)
                self.labels.append(label)

    def __len__(self) -> int:
        """Return the number of images found under both class folders."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, label)``.

        The image is converted to RGB and, if a transform was given, passed
        through it. The label is the plain Python int 0 or 1.
        """
        # The context manager closes the underlying file right away;
        # convert() returns an independent in-memory image.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

In [29]:
# Preprocessing applied to every image: resize to a fixed 64x64 resolution,
# then convert the PIL image to a torch tensor.
transform = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)

### CNN Model (equivalent to Sequential)

In [33]:
class TrashCNN(nn.Module):
    """Small CNN for binary image classification.

    Architecture: two (3x3 convolution -> ReLU -> 2x2 max-pool) stages with
    32 and 64 output channels, followed by a 128-unit fully connected layer,
    dropout (p=0.5) and a single sigmoid output unit.

    The first linear layer is sized for 64 x 14 x 14 feature maps, which is
    what 3 x 64 x 64 inputs produce (64 -> 62 -> 31 -> 29 -> 14).
    """

    # (channels, height, width) the conv stack must produce for fc1 to fit.
    _FEATURE_SHAPE = (64, 14, 14)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        channels, height, width = self._FEATURE_SHAPE
        self.fc1 = nn.Linear(channels * height * width, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)  # binary task: a single output neuron

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Image tensor of shape ``(N, 3, 64, 64)``; a single unbatched
                ``(3, 64, 64)`` image is treated as a batch of one.

        Returns:
            Tensor of shape ``(N, 1)`` holding sigmoid activations in [0, 1].

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield
                64 x 14 x 14 feature maps.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))

        # Flattening with an inferred batch dimension would silently re-batch
        # wrongly sized inputs whenever the element count happens to divide
        # evenly, so the feature-map shape is checked explicitly first.
        feature_shape = tuple(x.shape[-3:])
        if feature_shape != self._FEATURE_SHAPE:
            raise RuntimeError(
                f"TrashCNN expects 3x64x64 inputs producing feature maps of "
                f"shape {self._FEATURE_SHAPE}, got {feature_shape}"
            )

        channels, height, width = self._FEATURE_SHAPE
        # reshape (unlike view) also accepts non-contiguous tensors.
        x = x.reshape(-1, channels * height * width)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc2(x))
        return x
