In [12]:
import torchvision.transforms as transforms
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torchvision.models as models
import random

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [4]:
category_dict = {'glass': [9, 26], 'plastic': [4, 5, 29, 42, 47, 48, 49], 'paper': [33, 35], 'styrofoam': [57]}
class_to_category_idx = {6: 0, 9: 0, 23: 0, 26: 0, 21: 3, 24: 3, 7: 3, 4: 3, 5: 3,  29: 3, 42: 3, 47: 3, 48: 3, 49: 3, 33: 1, 35: 1, 57: 2}
idx_to_category = {0: 'glass', 1: 'paper', 2: 'styrofoam', 3: 'plastic', 4: 'other'}

In [34]:
class TacoTrashDataset(Dataset):
    def __init__(self, data_dir, transform=None, isTrain=True):
        self.data_dir = data_dir
        self.transform = transform
        self.isTrain = isTrain

        with open(data_dir + '/annotations.json') as f:
            self.data_info = json.load(f)
        
        self.images_info = self.data_info['images']
        self.all_annotation_info = self.data_info['annotations']

        self.restrictAnnotations()
        self.splitTrainTest()
    
    def restrictAnnotations(self):
        self.annotation_info  = []
        for annotation in self.all_annotation_info:
            if annotation['category_id'] in [4, 5, 6, 7, 9, 26, 29, 42, 47, 48, 49, 33, 35, 57]:
                self.annotation_info.append(annotation)
    
    def splitTrainTest(self):
        random.shuffle(self.annotation_info)
        self.train_annotation_info = self.annotation_info[:int(0.8*len(self.annotation_info))]
        self.test_annotation_info = self.annotation_info[int(0.8*len(self.annotation_info)):]
        
    
    def __len__(self):
        if self.isTrain:
            return len(self.train_annotation_info)
        else:
            return len(self.test_annotation_info)
    
    def __getitem__(self, idx):
        if self.isTrain:
            annotation = self.train_annotation_info[idx]
        else:            
            annotation = self.test_annotation_info[idx]
        
        img_idx = annotation['image_id']
        img_path = os.path.join(self.data_dir, self.images_info[img_idx]['file_name'])
        if annotation['category_id'] in [9, 26, 29, 42, 47, 48, 49, 33, 35, 57]:
            label = class_to_category_idx[annotation['category_id']]
        else:
            label = 4
        
        image = Image.open(img_path).convert("RGB")
        bbox = annotation['bbox']
        image = image.crop((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]))
        labels = torch.tensor(label, dtype=torch.long)
        
        if self.transform:
            image = self.transform(image)

        return image, labels

In [35]:
train_dataset = TacoTrashDataset(
    data_dir='D:/Work/My Projects/LitterDetection/Litter-Voxel51-Hackathon/TACO/data',
    transform=transform,
    isTrain=True
)
test_dataset = TacoTrashDataset(
    data_dir='D:/Work/My Projects/LitterDetection/Litter-Voxel51-Hackathon/TACO/data',
    transform=transform,
    isTrain=False
)


In [36]:
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=0)

In [39]:
model = models.resnet50(pretrained=True)

num_custom_classes = 5
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_custom_classes)

for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-4)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)



In [42]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(train_dataloader.dataset)
    epoch_acc = 100. * correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
    

    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()
    
    val_loss /= len(test_dataloader.dataset)
    val_acc = 100. * val_correct / val_total
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
    


Epoch 1/10, Loss: 0.5847, Accuracy: 85.60%
Validation Loss: 0.5754, Validation Accuracy: 85.58%
Epoch 2/10, Loss: 0.5562, Accuracy: 85.60%
Validation Loss: 0.5408, Validation Accuracy: 85.58%


KeyboardInterrupt: 

In [47]:
train_dataset.data_info['categories']

[{'supercategory': 'Aluminium foil', 'id': 0, 'name': 'Aluminium foil'},
 {'supercategory': 'Battery', 'id': 1, 'name': 'Battery'},
 {'supercategory': 'Blister pack', 'id': 2, 'name': 'Aluminium blister pack'},
 {'supercategory': 'Blister pack', 'id': 3, 'name': 'Carded blister pack'},
 {'supercategory': 'Bottle', 'id': 4, 'name': 'Other plastic bottle'},
 {'supercategory': 'Bottle', 'id': 5, 'name': 'Clear plastic bottle'},
 {'supercategory': 'Bottle', 'id': 6, 'name': 'Glass bottle'},
 {'supercategory': 'Bottle cap', 'id': 7, 'name': 'Plastic bottle cap'},
 {'supercategory': 'Bottle cap', 'id': 8, 'name': 'Metal bottle cap'},
 {'supercategory': 'Broken glass', 'id': 9, 'name': 'Broken glass'},
 {'supercategory': 'Can', 'id': 10, 'name': 'Food Can'},
 {'supercategory': 'Can', 'id': 11, 'name': 'Aerosol'},
 {'supercategory': 'Can', 'id': 12, 'name': 'Drink can'},
 {'supercategory': 'Carton', 'id': 13, 'name': 'Toilet tube'},
 {'supercategory': 'Carton', 'id': 14, 'name': 'Other carton'

In [49]:
torch.save({
            'epoch': 2,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, 'classification_resnet.pt')