In [1]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch
import torchvision
from torchvision import transforms
from torch import optim, nn
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from PIL import Image
import os

train_path = 'garbage/garbage classification/Garbage classification'
valid_path = 'garbage/garbage classification/Garbage classification'

class MyDataset(Dataset):
    def __init__(self, txt_path, img_dir, transform=None):
        with open(txt_path, 'r') as file:
            lines = file.readlines()

        self.class_names = ['glass', 'paper', 'cardboard', 'plastic', 'metal', 'trash']
        self.img_list = [os.path.join(img_dir, ''.join(filter(str.isalpha, line.split('.')[0])), line.split()[0]) for line in lines]
        self.label_list = [self.class_names[int(line.split()[1]) - 1] for line in lines]  # convert indices to class names
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_path = self.img_list[idx]
        label = self.label_list[idx]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform is not None:
            image = self.transform(image)
        
        return image, label

def class_to_index(class_name):
    class_names = ['glass', 'paper', 'cardboard', 'plastic', 'metal', 'trash']
    return class_names.index(class_name)

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.RandomApply([
        transforms.RandomRotation(20),
        transforms.RandomResizedCrop(224),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    ], p=0.5),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = MyDataset('garbage/one-indexed-files-notrash_train.txt', 'garbage/garbage classification/Garbage classification', transform=train_transform)
valid_dataset = MyDataset('garbage/one-indexed-files-notrash_val.txt', 'garbage/garbage classification/Garbage classification', transform=valid_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

model = torchvision.models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 6)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
  
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = torch.tensor([class_to_index(label) for label in labels]).to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs = inputs.to(device)
            labels = torch.tensor([class_to_index(label) for label in labels]).to(device)
            
            # Forward pass
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the validation images: %d %%' % (
        100 * correct / total))


torch.save(model.state_dict(), 'model.pth')



Let's use 2 GPUs!




Accuracy of the network on the validation images: 77 %
Accuracy of the network on the validation images: 84 %
Accuracy of the network on the validation images: 86 %
Accuracy of the network on the validation images: 88 %
Accuracy of the network on the validation images: 89 %
Accuracy of the network on the validation images: 89 %
Accuracy of the network on the validation images: 89 %
Accuracy of the network on the validation images: 91 %
Accuracy of the network on the validation images: 91 %
Accuracy of the network on the validation images: 91 %
Accuracy of the network on the validation images: 92 %
Accuracy of the network on the validation images: 93 %
Accuracy of the network on the validation images: 92 %
Accuracy of the network on the validation images: 93 %
Accuracy of the network on the validation images: 92 %
Accuracy of the network on the validation images: 93 %
Accuracy of the network on the validation images: 92 %
Accuracy of the network on the validation images: 90 %
Accuracy o