In [60]:
import torch
from PIL import Image, ImageFont, ImageDraw, ImageEnhance
import numpy as np
import ipywidgets as widgets
from IPython.display import display
import functools

Загружаем датасет

In [None]:
# Download the Penn-Fudan pedestrian dataset archive.
# -nc (no-clobber) skips the download when the archive is already present,
# so re-running the cell does not create PennFudanPed.zip.1 duplicates.
# (The stray "." argument was removed: wget treated it as a second URL and failed on it.)
!wget -nc https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip
# Extract it in the current folder.
# -n never overwrites existing files, which avoids the interactive
# "replace?" prompt when the cell is run a second time.
!unzip -n PennFudanPed.zip

Клонируем репозиторий torchvision и копируем вспомогательные файлы из references/detection

In [None]:
%%shell

# Fetch the TorchVision sources and pin them to the v0.8.2 tag;
# only the helper scripts under references/detection are used.
git clone https://github.com/pytorch/vision.git
cd vision
git checkout v0.8.2

# Copy each detection helper into the parent directory (next to the notebook).
for helper in utils.py transforms.py coco_eval.py engine.py coco_utils.py; do
    cp "references/detection/$helper" ../
done

Загружаем модель

In [61]:
# SECURITY NOTE: torch.load unpickles the file, and unpickling can execute
# arbitrary code -- only load checkpoints that come from a trusted source.
# map_location='cpu' restores all tensors onto the CPU regardless of the device
# the model was saved from, so the checkpoint also loads on machines without
# a GPU and matches the CPU tensors passed to the model in the predict cell.
the_model = torch.load('model_s_vesami.pth', map_location='cpu')
# Switch the module to evaluation mode; being the last expression of the
# cell, this also displays the model's repr.
the_model.eval()

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(in

Класс для работы с датасетом

In [62]:
class PennFudanDataset(torch.utils.data.Dataset):
    """Penn-Fudan pedestrian images paired with instance segmentation masks.

    ``root`` must contain two sub-directories, ``PNGImages`` and ``PedMasks``.
    Images and masks are paired by their position in the sorted directory
    listings.

    Args:
        root: path of the dataset directory.
        transforms: callable taking and returning ``(img, target)``, or
            ``None`` to return the sample untouched.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Sort both listings so that image i and mask i stay aligned.
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the sample at position ``idx``.

        ``target`` is a dict with the keys ``boxes`` (float32, one
        ``[xmin, ymin, xmax, ymax]`` row per instance), ``labels`` (int64,
        all ones), ``masks`` (uint8, one binary mask per instance),
        ``image_id``, ``area`` and ``iscrowd`` (int64, all zeros).
        A mask without any instance produces empty tensors, with ``boxes``
        keeping its two-dimensional ``(0, 4)`` shape.
        """
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # The mask is deliberately NOT converted to RGB: each distinct value
        # encodes a different instance, with 0 being the background.
        mask = np.array(Image.open(mask_path))

        # Instance ids are the distinct mask values except the background.
        # Filtering on the value (instead of dropping the first unique id)
        # keeps every instance even when the mask has no background pixel.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        num_objs = len(obj_ids)

        # Split the id-encoded mask into one boolean mask per instance.
        masks = mask == obj_ids[:, None, None]

        # Tight bounding box around the pixels of every instance.
        boxes = []
        for instance_mask in masks:
            rows, cols = np.where(instance_mask)
            boxes.append([np.min(cols), np.min(rows), np.max(cols), np.max(rows)])

        # The explicit reshape keeps boxes two-dimensional when there are no
        # instances; otherwise the column indexing below raises IndexError.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(num_objs, 4)
        # There is only one class.
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Suppose all instances are not crowd.
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of samples, i.e. number of files found in ``PNGImages``."""
        return len(self.imgs)

In [63]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T
import os


def get_transform(train):
    """Compose the preprocessing pipeline applied to every dataset sample.

    The pipeline always starts with ``T.ToTensor()``. When ``train`` is
    truthy, ``T.RandomHorizontalFlip(0.5)`` is appended as data
    augmentation for the images and their ground truth.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

In [64]:
# Dataset built with the non-training pipeline (get_transform(train=False)
# adds no random flip); the sample at index 1 is used as the demo image.
dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))
img, _ = dataset_test[1]

# Scale by 255, move the channel axis last and cast to bytes so that PIL
# can build an image from the tensor.
# NOTE(review): assumes img is a channel-first float tensor in [0, 1] -- confirm.
test_img = Image.fromarray(
    img.mul(255).permute(1, 2, 0).byte().numpy()
)


# Output area plus the button whose click handler draws into it.
output1 = widgets.Output()
show_img_button = widgets.Button(description="Show input image")
display(show_img_button, output1)

@output1.capture(clear_output=True, wait=True)
def on_show_img_button_clicked(b):
    """Click handler: show the sample input image inside ``output1``.

    Saves ``test_img`` to ``test_img.jpg``, reads the encoded bytes back and
    displays them in an image widget.

    Args:
        b: the button instance passed in by ``on_click`` (unused).
    """
    # The capture decorator already clears output1 and runs this function
    # inside the output context, so no nested ``with output1`` is needed.
    test_img.save('test_img.jpg')
    # Context manager so the file handle is closed after reading.
    with open('test_img.jpg', "rb") as image_file:
        image_bytes = image_file.read()
    display(
        widgets.Image(
            value=image_bytes,
            # The bytes are JPEG-encoded, so declare them as such.
            format='jpeg',
            width=300,
            height=400,
        )
    )

show_img_button.on_click(on_show_img_button_clicked)

Button(description='Show input image', style=ButtonStyle())

Output()

In [65]:
# Output area plus the "Predict" button whose click handler draws into it.
output2 = widgets.Output()
button = widgets.Button(description="Predict")
display(button, output2)

@output2.capture(clear_output=True, wait=True)
def on_button_clicked(b):
    """Click handler: run the model on the demo sample and show the result.

    Feeds sample 1 of ``dataset_test`` to ``the_model``, draws the first
    predicted box on the image, saves it to ``prediction_test.jpg`` and
    displays the saved file inside ``output2``. When the model returns no
    boxes, the image is shown without a rectangle and a short message is
    printed instead of failing with an IndexError.

    Args:
        b: the button instance passed in by ``on_click`` (unused).
    """
    # The capture decorator already clears output2 and runs this function
    # inside the output context, so no nested ``with output2`` is needed.
    sample, _ = dataset_test[1]
    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = the_model([sample])
    boxes = prediction[0]['boxes']

    annotated = Image.fromarray(sample.mul(255).permute(1, 2, 0).byte().numpy())
    if len(boxes) > 0:
        # Only the first box is drawn.
        # NOTE(review): presumably the highest-scoring detection -- confirm.
        # tolist() hands plain Python floats to PIL instead of tensors.
        x_min, y_min, x_max, y_max = boxes[0].tolist()
        ImageDraw.Draw(annotated).rectangle(((x_min, y_min), (x_max, y_max)))
    else:
        print("No objects detected")

    annotated.save('prediction_test.jpg')
    # Context manager so the file handle is closed after reading.
    with open('prediction_test.jpg', "rb") as image_file:
        image_bytes = image_file.read()
    display(
        widgets.Image(
            value=image_bytes,
            # The bytes are JPEG-encoded, so declare them as such.
            format='jpeg',
            width=300,
            height=400,
        )
    )


button.on_click(on_button_clicked)

Button(description='Predict', style=ButtonStyle())

Output()