In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from PIL import ImageFile
import os

ImageFile.LOAD_TRUNCATED_IMAGES = True

data_dir = r"E:\Project\images\dataset-original"

In [5]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [None]:
from torchvision import datasets

class CustomImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        try:
            return super(CustomImageFolder, self).__getitem__(index)
        except (OSError, IOError) as e:
            print(f"Skipping corrupted image at index {index}: {e}")
            return self.__getitem__((index + 1) % len(self))


In [None]:
full_dataset = CustomImageFolder(root=data_dir, transform=train_transforms)

class_names = full_dataset.classes
print("클래스:", class_names)
num_classes = len(class_names)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

val_dataset.dataset.transform = val_transforms

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

클래스: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']


In [None]:
class GarbageClassifier(nn.Module):
    def __init__(self, num_classes):
        super(GarbageClassifier, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.resnet(x)

model = GarbageClassifier(num_classes)



In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
from tqdm.notebook import tqdm 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_model(model, criterion, optimizer, num_epochs=50):
    model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        for inputs, labels in tqdm(train_loader, desc="Training", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            try:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * inputs.size(0)
            except Exception as e:
                print(f"Error processing batch, skipping: {e}")
                continue
        
        epoch_loss = running_loss / len(train_dataset)
        print(f"Training Loss: {epoch_loss:.4f}")
        
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validating", leave=False):
                inputs, labels = inputs.to(device), labels.to(device)
                try:
                    outputs = model(inputs)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                except Exception as e:
                    print(f"Error processing validation batch, skipping: {e}")
                    continue
        
        val_accuracy = 100 * correct / total
        print(f'Validation Accuracy: {val_accuracy:.2f}%\n')


In [16]:
train_model(model, criterion, optimizer)

Epoch 1/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.7107


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 75.30%

Epoch 2/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.5519


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 79.35%

Epoch 3/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.4382


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 82.95%

Epoch 4/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.3839


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 84.19%

Epoch 5/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.3144


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 85.09%

Epoch 6/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.2829


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 85.59%

Epoch 7/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.2448


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 85.76%

Epoch 8/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1955


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 85.99%

Epoch 9/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1804


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 86.89%

Epoch 10/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1556


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 84.07%

Epoch 11/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1429


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 90.21%

Epoch 12/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1254


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 83.23%

Epoch 13/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1150


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 88.24%

Epoch 14/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1143


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 85.59%

Epoch 15/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.0834


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 87.45%

Epoch 16/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.1004


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 89.98%

Epoch 17/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.0767


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 84.30%

Epoch 18/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.0794


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 88.69%

Epoch 19/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.0856


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 90.43%

Epoch 20/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.0511


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 88.58%

Epoch 21/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

Training Loss: 0.0661


Validating:   0%|          | 0/56 [00:00<?, ?it/s]

Validation Accuracy: 84.69%

Epoch 22/50


Training:   0%|          | 0/222 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [17]:
torch.save(model.state_dict(), 'Model.pth')