In [21]:
import torch
import numpy as np
import os
from PIL import Image
import xml.etree.ElementTree as ET
import random
import utils
import transforms as T
from torch.utils.data import random_split
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import gc
from engine import train_one_epoch, evaluate

In [2]:
# The bounding boxes were drawn with LabelImg; its class list is ordered as follows:
classes = [
    "dog",
    "person",
    "cat",
    "tv",
    "car",
    "meatballs",
    "marinara sauce",
    "tomato soup",
    "chicken noodle soup",
    "french onion soup",
    "chicken breast",
    "ribs",
    "pulled pork",
    "hamburger",
    "cavity",
]

len(classes)

15

In [3]:
# Every image annotation is stored as an XML file.
def extract_boxes_from_xml(file_path, label='person'):
    """Read the bounding boxes of one class from an annotation XML file.

    Tags are looked up by name (``name``, ``bndbox``, ``xmin`` ...) rather than
    by child position, so the result does not depend on the order in which the
    annotation tool wrote the elements.

    Args:
        file_path: path to the XML annotation file.
        label: class name to keep; objects with any other ``<name>`` are skipped.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax]`` integer lists, one per matching
        ``<object>``. Empty when the file has no object of that class.

    Raises:
        ValueError: if a matching object has no ``<bndbox>`` element.
    """
    root = ET.parse(file_path).getroot()

    boxes = []  # one entry per matching object
    for obj in root.iter('object'):
        name = obj.findtext('name')
        if name is None or name.strip() != label:
            continue

        bndbox = obj.find('bndbox')
        if bndbox is None:
            raise ValueError(f"object '{label}' in {file_path} has no <bndbox> element")

        # int(float(...)) also accepts coordinates written as "294.0"
        boxes.append([int(float(bndbox.findtext(tag)))
                      for tag in ('xmin', 'ymin', 'xmax', 'ymax')])

    return boxes

# Smoke test: parse a single annotation file and print every box it returns.
boxes = extract_boxes_from_xml('dataset/human-label/510.xml')
print(f'found {len(boxes)} objects.')
for bbox in boxes:
    print(bbox)

found 1 objects.
[294, 87, 865, 374]


In [4]:
# List the annotation and image folders in sorted order, then check that
# both hold the same number of files.
directory_boxes = sorted(os.listdir(os.path.join("dataset", "human-label")))
directory_images = sorted(os.listdir(os.path.join("dataset", "human")))

len(directory_boxes) == len(directory_images)

True

In [5]:
class CameraDataset(torch.utils.data.Dataset):
    """Detection dataset of camera images with XML person annotations.

    ``root`` must contain a ``human`` folder (images) and a ``human-label``
    folder (XML annotations). Images and annotations are paired by position
    after sorting both listings, so the two folders are expected to hold
    matching file names.

    Each item is ``(img, target)`` where ``target`` is a dict with the keys
    ``boxes`` (N x 4 float32, ``[xmin, ymin, xmax, ymax]``), ``labels``
    (N int64, all 1), ``image_id``, ``area`` and ``iscrowd``.
    """

    def __init__(self, root, transforms):
        """Store the root folder and transforms, and index both sub-folders."""
        self.root = root
        self.transforms = transforms

        self.imgs = sorted(os.listdir(os.path.join(root, "human")))
        self.boxes = sorted(os.listdir(os.path.join(root, "human-label")))

    def __getitem__(self, idx):
        """Load image ``idx`` and build its detection target dict."""
        img_path = os.path.join(self.root, "human", self.imgs[idx])
        boxes_path = os.path.join(self.root, "human-label", self.boxes[idx])
        # PNGs may be RGBA (4 channels); the model expects RGB (3 channels)
        img = Image.open(img_path).convert("RGB")

        # Read the annotation from under self.root instead of a hard-coded
        # 'dataset/' prefix, so the class works for any root folder.
        boxes = extract_boxes_from_xml(boxes_path)
        num_objs = len(boxes)
        # reshape(-1, 4) keeps the tensor 2-D even when the image has no
        # boxes; otherwise the column indexing below raises IndexError.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # numeric class labels: every box is class 1 (person)
        labels = torch.ones((num_objs,), dtype=torch.int64)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # No box is a crowd region. NOTE(review): the torchvision detection
        # reference evaluation reads target["iscrowd"]; this assumes `engine`
        # is that reference script -- confirm.
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of images found in the ``human`` folder."""
        return len(self.imgs)

In [6]:
def get_model(num_classes):
    """Return a Faster R-CNN (ResNet-50 FPN) whose box predictor outputs ``num_classes`` classes.

    The detector is created with ``weights="DEFAULT"`` and only its box
    predictor is replaced by a fresh ``FastRCNNPredictor``.
    """
    # same architecture as the one used in the degree project
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # swap the classification head, keeping its input width
    roi_heads = detector.roi_heads
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)

    return detector

In [7]:
def get_transform(train):
    """Compose the image/target pipeline; a random horizontal flip is added only when ``train`` is truthy."""
    # always: PIL image -> tensor -> float dtype
    steps = [T.PILToTensor(), T.ConvertImageDtype(torch.float)]
    if train:
        # augmentation, applied with probability 0.5
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

In [8]:
def split_dataset(dataset, train, seed=None):
    """Randomly split ``dataset`` into a training subset and a test subset.

    Args:
        dataset: a sized dataset accepted by ``torch.utils.data.random_split``.
        train: fraction of samples (between 0 and 1) assigned to training;
            the training size is ``int(train * len(dataset))``.
        seed: optional integer. When given, the split is drawn from a
            dedicated generator seeded with it, so repeated calls return the
            same partition. ``None`` keeps using the global RNG.

    Returns:
        ``(train_dataset, test_dataset)`` as returned by ``random_split``.
    """
    assert 0 <= train <= 1, "'train' should be between 0 and 1."

    dataset_size = len(dataset)
    train_size = int(train * dataset_size)
    test_size = dataset_size - train_size
    lengths = [train_size, test_size]

    if seed is None:
        train_dataset, test_dataset = random_split(dataset, lengths)
    else:
        generator = torch.Generator().manual_seed(seed)
        train_dataset, test_dataset = random_split(dataset, lengths, generator=generator)

    return train_dataset, test_dataset

In [19]:
def train_model():
    """Fine-tune the Faster R-CNN person detector on the ``dataset`` folder.

    Holds out 15% of the samples for evaluation, then trains on the CPU for
    10 epochs with SGD and a step learning-rate schedule, evaluating on the
    held-out samples after every epoch.
    """
    device = torch.device('cpu')

    # Two classes: index 0 is the background class Faster R-CNN reserves,
    # index 1 is "person" (the dataset labels every box as 1). With a single
    # output the label 1 would be out of range for the classifier.
    num_classes = 2

    # Same files seen through two pipelines: only the training view is
    # augmented with random flips.
    dataset = CameraDataset('dataset', get_transform(train=True))
    dataset_test = CameraDataset('dataset', get_transform(train=False))

    # 85% for training, 15% for test. The held-out indices are applied to
    # the un-augmented dataset so evaluation never sees random flips.
    train_subset, held_out = split_dataset(dataset, train=0.85)
    dataset = train_subset
    dataset_test = torch.utils.data.Subset(dataset_test, held_out.indices)

    # define training and validation data loaders
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=32, shuffle=True, #num_workers=4,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=10, shuffle=False, #num_workers=4,
        collate_fn=utils.collate_fn)

    # get the model using our helper function
    model = get_model(num_classes)

    # move model to the right device
    model.to(device)

    # construct an optimizer over the trainable parameters only
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)
    # and a learning rate scheduler: multiply the rate by 0.1 every 3 epochs
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    # let's train it for 10 epochs
    num_epochs = 10

    for epoch in range(num_epochs):
        # train for one epoch, printing every 10 iterations
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        # update the learning rate
        lr_scheduler.step()
        # evaluate on the test dataset
        evaluate(model, data_loader_test, device=device)

    print("Training process over.")

In [26]:
# Start training. The run recorded below was stopped by hand (KeyboardInterrupt).
train_model()

KeyboardInterrupt: 

In [27]:
gc.collect() # run the garbage collector to free memory in case it was not released

3817