In [None]:
import os
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image

In [None]:
def load_model():
    model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
    model.eval()
    return model

In [None]:
def make_transparent_foreground(pic, mask):
    b, g, r = cv2.split(np.array(pic).astype('uint8'))
    a = np.ones(mask.shape, dtype='uint8') * 255
    alpha_im = cv2.merge([b, g, r, a], 4)
    bg = np.zeros(alpha_im.shape)
    new_mask = np.stack([mask, mask, mask, mask], axis=2)
    foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
    return foreground

In [None]:
def remove_background(model, input_file):
    input_image = Image.open(input_file)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)

    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)['out'][0]
    output_predictions = output.argmax(0)
    mask = output_predictions.byte().cpu().numpy()
    background = np.zeros(mask.shape)
    bin_mask = np.where(mask, 255, background).astype(np.uint8)
    foreground = make_transparent_foreground(input_image, bin_mask)
    return foreground

In [None]:
def process_images(input_folder, output_folder, model):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    for filename in os.listdir(input_folder):
        if filename.endswith(('.png', '.jpg', '.jpeg')):
            input_path = os.path.join(input_folder, filename)
            output_path = os.path.join(output_folder, filename)

            foreground = remove_background(model, input_path)

            cv2.imwrite(output_path, cv2.cvtColor(foreground, cv2.COLOR_RGBA2BGRA))
            print(f'Processed: {filename}')

In [None]:
!wget https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip -O dataset-resized.zip

--2025-01-22 08:01:20--  https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip
Resolving github.com (github.com)... 140.82.116.4
Connecting to github.com (github.com)|140.82.116.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/garythung/trashnet/master/data/dataset-resized.zip [following]
--2025-01-22 08:01:21--  https://raw.githubusercontent.com/garythung/trashnet/master/data/dataset-resized.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42834870 (41M) [application/zip]
Saving to: ‘dataset-resized.zip’


2025-01-22 08:01:22 (274 MB/s) - ‘dataset-resized.zip’ saved [42834870/42834870]



In [None]:
!unzip dataset-resized.zip -d dataset

Archive:  dataset-resized.zip
   creating: dataset/dataset-resized/
  inflating: dataset/dataset-resized/.DS_Store  
   creating: dataset/__MACOSX/
   creating: dataset/__MACOSX/dataset-resized/
  inflating: dataset/__MACOSX/dataset-resized/._.DS_Store  
   creating: dataset/dataset-resized/cardboard/
  inflating: dataset/dataset-resized/cardboard/cardboard1.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard10.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard100.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard101.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard102.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard103.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard104.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard105.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard106.jpg  
  inflating: dataset/dataset-resized/cardboard/cardboard107.jpg  
  inflating: dataset/dataset-resized/car

In [None]:
# just to make new folders named train, test data
import os
import shutil
from sklearn.model_selection import train_test_split

# Original dataset path
dataset_dir = "dataset/dataset-resized"
# Directories for train and test sets
train_dir = "dataset/train"
test_dir = "dataset/test"

# Create train and test directories if they don't exist
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# Define split ratio
test_size = 0.2  # 20% for testing, adjust as needed

# Iterate through each category folder
for category in os.listdir(dataset_dir):
    category_path = os.path.join(dataset_dir, category)
    if os.path.isdir(category_path):  # Check if it's a folder
        # List all files in the category folder
        files = os.listdir(category_path)
        files = [f for f in files if os.path.isfile(os.path.join(category_path, f))]

        # Split the files into train and test sets
        train_files, test_files = train_test_split(files, test_size=test_size, random_state=42)

        # Create subdirectories in train and test folders
        train_category_dir = os.path.join(train_dir, category)
        test_category_dir = os.path.join(test_dir, category)
        os.makedirs(train_category_dir, exist_ok=True)
        os.makedirs(test_category_dir, exist_ok=True)

        # Move files to train folder
        for file in train_files:
            shutil.copy(os.path.join(category_path, file), os.path.join(train_category_dir, file))

        # Move files to test folder
        for file in test_files:
            shutil.copy(os.path.join(category_path, file), os.path.join(test_category_dir, file))

print(f"Dataset split completed. Train data is in '{train_dir}', Test data is in '{test_dir}'.")


Dataset split completed. Train data is in 'dataset/train', Test data is in 'dataset/test'.


In [None]:
# Define input and output directories
input_folder = '/content/dataset/dataset-resized/cardboard'
output_folder = '/content/dataset-bgremoved/cardboard_bgremoved'

deeplab_model = load_model()
process_images(input_folder, output_folder, deeplab_model)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|██████████| 161M/161M [00:03<00:00, 43.8MB/s]


Processed: cardboard297.jpg
Processed: cardboard94.jpg
Processed: cardboard282.jpg
Processed: cardboard190.jpg
Processed: cardboard220.jpg
Processed: cardboard116.jpg
Processed: cardboard211.jpg
Processed: cardboard252.jpg
Processed: cardboard280.jpg
Processed: cardboard143.jpg
Processed: cardboard402.jpg
Processed: cardboard56.jpg
Processed: cardboard371.jpg
Processed: cardboard169.jpg
Processed: cardboard164.jpg
Processed: cardboard240.jpg
Processed: cardboard159.jpg
Processed: cardboard257.jpg
Processed: cardboard49.jpg
Processed: cardboard370.jpg
Processed: cardboard18.jpg
Processed: cardboard38.jpg
Processed: cardboard244.jpg
Processed: cardboard204.jpg
Processed: cardboard185.jpg
Processed: cardboard346.jpg
Processed: cardboard200.jpg
Processed: cardboard15.jpg
Processed: cardboard165.jpg
Processed: cardboard209.jpg
Processed: cardboard89.jpg
Processed: cardboard88.jpg
Processed: cardboard103.jpg
Processed: cardboard24.jpg
Processed: cardboard270.jpg
Processed: cardboard20.jpg
Pr

In [12]:
# glass
input_folder1 = '/content/dataset/dataset-resized/glass'
output_folder1 = '/content/dataset-bgremoved/glass_bgremoved'

# metal
input_folder2 = '/content/dataset/dataset-resized/metal'
output_folder2 = '/content/dataset-bgremoved/metal_bgremoved'

# paper
input_folder3 = '/content/dataset/dataset-resized/paper'
output_folder3 = '/content/dataset-bgremoved/paper_bgremoved'

# plastic
input_folder4 = '/content/dataset/dataset-resized/plastic'
output_folder4 = '/content/dataset-bgremoved/plastic_bgremoved'

# trash
input_folder5 = '/content/dataset/dataset-resized/trash'
output_folder5 = '/content/dataset-bgremoved/trash_bgremoved'

deeplab_model = load_model()
process_images(input_folder1, output_folder1, deeplab_model)
process_images(input_folder2, output_folder2, deeplab_model)
process_images(input_folder3, output_folder3, deeplab_model)
process_images(input_folder4, output_folder4, deeplab_model)
process_images(input_folder5, output_folder5, deeplab_model)



Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Processed: glass227.jpg
Processed: glass490.jpg
Processed: glass241.jpg
Processed: glass142.jpg
Processed: glass471.jpg
Processed: glass361.jpg
Processed: glass496.jpg
Processed: glass119.jpg
Processed: glass45.jpg
Processed: glass146.jpg
Processed: glass291.jpg
Processed: glass249.jpg
Processed: glass157.jpg
Processed: glass289.jpg
Processed: glass356.jpg
Processed: glass200.jpg
Processed: glass226.jpg
Processed: glass281.jpg
Processed: glass472.jpg
Processed: glass44.jpg
Processed: glass121.jpg
Processed: glass183.jpg
Processed: glass31.jpg
Processed: glass105.jpg
Processed: glass300.jpg
Processed: glass208.jpg
Processed: glass35.jpg
Processed: glass409.jpg
Processed: glass167.jpg
Processed: glass110.jpg
Processed: glass221.jpg
Processed: glass493.jpg
Processed: glass28.jpg
Processed: glass491.jpg
Processed: glass94.jpg
Processed: glass242.jpg
Processed: glass405.jpg
Processed: glass328.jpg
Processed: glass376.jpg
Processed: glass310.jpg
Processed: glass271.jpg
Processed: glass274.jp