# 準備

In [8]:
import argparse
import datetime
import os
import xml.etree.ElementTree as ET
from glob import glob

import cv2
import torch
import torchvision
from PIL import Image
from torch.utils.data import TensorDataset
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor


def FasterRCNN(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a new box head.

    The network is created with ``pretrained=True`` and its box predictor is
    replaced by a ``FastRCNNPredictor`` that outputs ``num_classes`` classes.

    Args:
        num_classes: Number of output classes of the replacement predictor.

    Returns:
        The detection model with the replaced box predictor.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # Keep the input width of the existing classifier so the new head fits.
    old_head = detector.roi_heads.box_predictor
    detector.roi_heads.box_predictor = FastRCNNPredictor(
        old_head.cls_score.in_features, num_classes
    )
    return detector


# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Two outputs: label 0 is the background, label 1 is the single 'defect' class.
model = FasterRCNN(num_classes=2).to(device)
dataset_class = ['defect']

# Drawing colours, indexed by the predicted label id.
colors = ((0, 0, 0), (255, 0, 0))

Using device: cuda


# 学習させる

学習に使いたいデータを入力してください． <br>
path_model が未入力（空文字列）の場合は，事前学習済みの Faster R-CNN の分類器を FastRCNNPredictor に付け替えたモデルを新規に使います．

In [34]:
# Checkpoint to continue training from; leave empty ("") to train a newly
# built FasterRCNN model instead.
path_model = ""
# Dataset root: annotations are read from "<root>/label" and images from
# "<root>/img".
path_dataset = "dataset/20211022"

出力先の設定

In [35]:
# Directory in which the trained weights are saved as "<timestamp>.pth".
path_model_o = "model"

以下学習

In [41]:
# Annotation files and images live in two sub-directories of the dataset root.
xml_dir = f"{path_dataset}/label"
img_dir = f"{path_dataset}/img"

# The checkpoint file is named after the time at which this cell was run.
dt_now = datetime.datetime.now()
model_name = f"{path_model_o}/{dt_now:%Y%m%d%H%M%S}.pth"

# Training hyper-parameters.
set_num_epochs = 20
set_batch_size = 1
set_lr = 0.005


def xml2list(xml_path, classes):
    """Read the bounding boxes of the known classes from an annotation XML.

    Every ``<object>`` element of the document is examined. Objects whose
    ``<name>`` is not listed in ``classes`` are skipped.

    Args:
        xml_path: Path (or file object) of the annotation XML, as accepted by
            ``xml.etree.ElementTree.parse``.
        classes: Sequence of class names; the position of a name in this
            sequence is used as its label index.

    Returns:
        A tuple ``(anno, count)`` where ``anno`` is
        ``{'bboxes': [[xmin, ymin, xmax, ymax], ...], 'labels': [...]}`` and
        ``count`` is the number of boxes that were kept.
    """
    root = ET.parse(xml_path).getroot()
    boxes, labels = [], []

    for obj in root.iter('object'):
        label = obj.find('name').text
        # Filter inside the loop so that every object is considered and a
        # file without any object simply yields empty lists.
        if label not in classes:
            continue

        bndbox = obj.find('bndbox')
        # Coordinates may be written as decimals ("12.0"); only the part
        # before the decimal point is used.
        xmin, ymin, xmax, ymax = (
            int(bndbox.find(tag).text.split(".")[0])
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(classes.index(label))

    anno = {'bboxes': boxes, 'labels': labels}
    return anno, len(boxes)


class MyDataset(torch.utils.data.Dataset):
    """Detection dataset pairing each annotation file with its ``.jpg`` image.

    The samples are enumerated from the files found in the annotation
    directory; the image of a sample is ``<image_dir>/<id>.jpg`` where
    ``<id>`` is the annotation file name without directory and extension.
    """

    def __init__(self, image_dir, xml_paths, classes):
        """Collect the annotation files and set up the image transform.

        Args:
            image_dir: Directory containing the ``<id>.jpg`` images.
            xml_paths: Directory containing the ``<id>.xml`` annotations.
            classes: Class names handed to ``xml2list``.
        """
        super().__init__()
        self.image_dir = image_dir
        self.xml_paths = xml_paths
        self.image_ids = sorted(glob('{}/*'.format(xml_paths)))
        self.classes = classes
        self.transform = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _image_id_from_path(path):
        """Return the file name of ``path`` without directory and extension."""
        # os.path understands the separators of the running platform, so the
        # result no longer depends on the paths containing backslashes.
        return os.path.splitext(os.path.basename(path))[0]

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position of the sample in ``self.image_ids``.

        Returns:
            A tuple ``(image, target, image_id)``: the image tensor produced
            by ``ToTensor`` from the RGB image, the target dict with the keys
            ``boxes``, ``labels``, ``image_id``, ``area`` and ``iscrowd``,
            and the sample's file-name stem.
        """
        image_id = self._image_id_from_path(self.image_ids[index])

        # Convert to RGB so grayscale, palette and RGBA files all give three
        # channels; the context manager closes the file handle right away.
        with Image.open(f"{self.image_dir}/{image_id}.jpg") as pil_image:
            image = self.transform(pil_image.convert("RGB"))

        path_xml = f'{self.xml_paths}/{image_id}.xml'
        annotations, obje_num = xml2list(path_xml, self.classes)

        # reshape(-1, 4) keeps the tensor two-dimensional even when the
        # annotation holds no boxes, so the column indexing below still works.
        boxes = torch.as_tensor(
            annotations['bboxes'], dtype=torch.float32
        ).reshape(-1, 4)
        labels = torch.as_tensor(annotations['labels'], dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((obje_num,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            # Class indices are shifted by one because label 0 is the
            # background.
            "labels": labels + 1,
            "image_id": torch.tensor([index]),
            "area": area,
            "iscrowd": iscrowd,
        }
        return image, target, image_id

    def __len__(self):
        """Return the number of annotation files found."""
        return len(self.image_ids)


def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    Args:
        batch: Iterable of equally long sample tuples.

    Returns:
        A tuple whose k-th element is the tuple of the k-th field of every
        sample; an empty batch gives an empty tuple.
    """
    fields = zip(*batch)
    return tuple(fields)


# Start from a newly built model unless a checkpoint to resume from was given.
if not path_model:
    model = FasterRCNN(num_classes=2).to(device)
else:
    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # machine that only has a CPU (and vice versa).
    model.load_state_dict(torch.load(path_model, map_location=device))


dataset = MyDataset(img_dir, xml_dir, dataset_class)
# collate_fn keeps the samples as tuples instead of stacking them into one
# tensor.
train_dataloader = torch.utils.data.DataLoader(dataset,
                                               batch_size=set_batch_size,
                                               shuffle=True,
                                               collate_fn=collate_fn)


# Only parameters that require gradients are handed to the optimizer.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=set_lr, momentum=0.9, weight_decay=0.0005)

# Make sure the checkpoint directory exists before the first torch.save call.
os.makedirs(os.path.dirname(model_name) or ".", exist_ok=True)

# The context manager closes the log file even when training raises.
with open('logfile.txt', 'a') as f:
    f.write("==================================================\n")
    f.write(f"Started {dt_now}\n")
    f.write(f"\tpath_model = \"{path_model}\"\n")
    f.write(f"\tpath_dataset = \"{path_dataset}\"\n")
    f.write(f"\toutput model = \"{model_name}\"\n")
    f.write(f"\tset_num_epochs = \"{set_num_epochs}\"\n")
    f.write(f"\tset_batch_size = \"{set_batch_size}\"\n")
    f.write(f"\n")

    for epoch in range(set_num_epochs):
        model.train()
        train_loss = 0
        num_samples = 0

        for i, batch in enumerate(train_dataloader):
            images, targets, _ = batch
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # In training mode the model returns a dict of losses.
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            train_loss += losses.item() * len(images)
            num_samples += len(images)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            # An epoch processes at most 101 batches.
            if i >= 100:
                break

        # Average over the samples actually processed; the epoch may have
        # stopped before the whole dataset was seen.
        f.write(f"epoch {epoch+1}\tloss: {train_loss / max(num_samples, 1)}\n")
        # Flush so the progress is visible while training is still running.
        f.flush()
        # The same file is overwritten after every epoch.
        torch.save(model.state_dict(), model_name)

    f.write("\n")

# 判定する

判定に使うモデルと判定したい画像ディレクトリを入力してください．

In [2]:
# Weights file that is loaded into the model for inference.
path_model_i = "model/20211025162458.pth"
# Directory whose files are run through the model.
path_test = "input"

出力先の設定

In [3]:
# Output directory: annotated images are saved here, and the generated
# annotation XML files under "<path_output>/label".
path_output = "output"

以下判定

In [None]:
test_dir = path_test
save_path = path_output

# At most this many detections are drawn and written per image.
max_item_in_picture = 5
# Detections whose (rounded) score is below this value are dropped.
threshold = 0.5

# Create the output folders up front so that writing the results cannot fail
# on a missing directory.
os.makedirs(save_path + "/label", exist_ok=True)

# map_location lets a checkpoint that was saved on a GPU be loaded on a
# machine that only has a CPU.
model.load_state_dict(torch.load(path_model_i, map_location=device))
model.eval()
test_classes = ['__background__'] + dataset_class
font = cv2.FONT_HERSHEY_SIMPLEX

for imgfile in sorted(glob(test_dir + '/*')):
    img = cv2.imread(imgfile)
    if img is None:
        # cv2.imread returns None for anything it cannot decode, e.g. a
        # sub-directory or a non-image file.
        print('Skipping unreadable file:', imgfile)
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    image_tensor = torchvision.transforms.functional.to_tensor(img)
    # os.path understands the separators of the running platform.
    file_name = os.path.splitext(os.path.basename(imgfile))[0]

    xml_lines = [
        '<annotation verified="yes">\n',
        '\t<folder>Annotation</folder>\n',
        f'\t<filename>{file_name}.jpg</filename>\n',
        '\t<source>\n\t\t<database>Unknown</database>\n\t</source>\n',
        '\t<size>\n',
        f'\t\t<width>{img.shape[1]}</width>\n',
        f'\t\t<height>{img.shape[0]}</height>\n',
        f'\t\t<depth>{img.shape[2]}</depth>\n',
        '\t</size>\n',
        '\t<segmented>0</segmented>\n',
    ]

    with torch.no_grad():
        prediction = model([image_tensor.to(device)])[0]

    detections = zip(prediction['boxes'][:max_item_in_picture],
                     prediction['scores'][:max_item_in_picture],
                     prediction['labels'][:max_item_in_picture])
    for box, score, cat in detections:
        score = round(float(score), 2)
        if score < threshold:
            # Stop at the first detection below the threshold.
            # NOTE(review): this relies on the scores being in descending
            # order - confirm against the torchvision documentation.
            break

        cat = int(cat)
        class_name = test_classes[cat]
        txt = '{} {}'.format(class_name, str(score))
        cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
        c = colors[cat]
        # Plain Python ints (truncated toward zero) are passed to OpenCV.
        x1, y1, x2, y2 = (int(v) for v in box.tolist())
        cv2.rectangle(img, (x1, y1), (x2, y2), c, 2)
        # Filled rectangle behind the caption, placed just above the box.
        cv2.rectangle(img, (x1, y1 - cat_size[1] - 2), (x1 + cat_size[0], y1 - 2), c, -1)
        cv2.putText(img, txt, (x1, y1 - 2), font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)
        xml_lines += [
            '\t<object>\n',
            f'\t<name>{class_name}</name>\n',
            '\t<pose>Unspecified</pose>\n',
            '\t<truncated>0</truncated>\n',
            '\t<difficult>0</difficult>\n',
            '\t<bndbox>\n',
            f'\t\t<xmin>{x1}</xmin>\n',
            f'\t\t<ymin>{y1}</ymin>\n',
            f'\t\t<xmax>{x2}</xmax>\n',
            f'\t\t<ymax>{y2}</ymax>\n',
            '\t</bndbox>\n',
            '\t</object>\n',
        ]

    xml_lines.append('</annotation>\n')

    # img is in RGB order here, which is what PIL expects.
    Image.fromarray(img).save(save_path + f"/{file_name}_result.jpg")

    # The document carries no XML declaration, so it is written as UTF-8,
    # the encoding XML parsers assume in that case.
    with open(save_path + f"/label/{file_name}.xml", mode='w', encoding='utf-8') as xml_file:
        xml_file.writelines(xml_lines)