In [0]:
import os
import glob
from PIL import Image, ImageDraw
import xml.etree.ElementTree as ET
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [0]:
class TissueDataset(Dataset):
    """Detection dataset pairing PNG images with XML annotation files.

    Images and annotations are paired purely by position after sorting each
    directory's filenames, so the i-th PNG (alphabetically) must belong to the
    i-th XML file.

    Each item is ``(img, annotations)``:
      - ``img`` is the RGB PIL image, or whatever ``transform`` returns for it.
      - ``annotations['boxes']``: float32 tensor of shape (N, 1, 4); each box is
        ``(x_center, y_center, width, height)`` with x/width divided by the image
        width and y/height divided by the image height.
      - ``annotations['labels']``: int64 tensor of shape (N, 1) holding the integer
        ``<name>`` of each ``<object>``.
    N may be 0 for an image with no ``<object>`` elements.
    """

    def __init__(self, img_dir='data/images', annotation_dir='data/annotations', transform=None):
        self.img_dir = img_dir
        self.annotation_dir = annotation_dir
        self.transform = transform

        self.img_names = [
            os.path.join(img_dir, name)
            for name in sorted(os.listdir(img_dir))
            if name.endswith('png')
        ]
        self.annotation_names = [
            os.path.join(annotation_dir, name)
            for name in sorted(os.listdir(annotation_dir))
            if name.endswith('xml')
        ]

        # Items are paired by sorted position, so a count mismatch means the pairs are wrong.
        if len(self.img_names) != len(self.annotation_names):
            raise ValueError(
                f'Found {len(self.img_names)} PNG images in {img_dir!r} but '
                f'{len(self.annotation_names)} XML annotations in {annotation_dir!r}; '
                'every image needs exactly one annotation file.'
            )

    def __getitem__(self, idx):
        # Use a context manager so the file handle is released, and force three channels:
        # PNGs can be RGBA / grayscale / palette, while the detector backbone expects RGB.
        with Image.open(self.img_names[idx]) as img_file:
            img = img_file.convert('RGB')
        img_width, img_height = img.size

        annotation_tree = ET.parse(self.annotation_names[idx])
        bndboxes = []
        labels = []

        for obj in annotation_tree.findall('object'):
            label = int(obj.find('name').text)
            bndbox_xml = obj.find('bndbox')

            xmin = int(bndbox_xml.find('xmin').text)
            ymin = int(bndbox_xml.find('ymin').text)
            xmax = int(bndbox_xml.find('xmax').text)
            ymax = int(bndbox_xml.find('ymax').text)

            w = xmax - xmin
            h = ymax - ymin
            # Keep the centre as a float: truncating it to int shifts odd-sized boxes by
            # half a pixel when they are converted back to corners.
            x = xmin + w / 2
            y = ymin + h / 2

            bndboxes.append([(x / img_width, y / img_height, w / img_width, h / img_height)])
            labels.append([label])

        if self.transform:
            img = self.transform(img)

        annotations = {
            # reshape gives empty images well-defined (0, 1, 4) / (0, 1) shapes.
            'boxes': torch.tensor(bndboxes, dtype=torch.float32).reshape(-1, 1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64).reshape(-1, 1),
        }
        return img, annotations

    def __len__(self):
        return len(self.img_names)

In [0]:
def get_transform():
    """Build the preprocessing pipeline applied to every dataset image.

    The pipeline currently has a single step, ``ToTensor``, which turns the PIL
    image into a tensor.
    """
    to_tensor = torchvision.transforms.ToTensor()
    return torchvision.transforms.Compose([to_tensor])

In [0]:
# Dataset over the default data/images + data/annotations directories,
# with each image converted to a tensor by get_transform().
tissueDataset = TissueDataset(
    transform=get_transform(),
)

In [0]:
def collate_fn(batch):
    """Transpose a list of ``(img, annotations)`` samples into ``(imgs, annotations)``.

    Samples may differ in image size and in number of boxes, so fields are gathered
    into tuples rather than stacked into single tensors.
    """
    images_and_targets = zip(*batch)
    return tuple(images_and_targets)

In [0]:
# One image per batch, reshuffled every epoch; collate_fn keeps samples as
# tuples instead of stacking them.
data_loader = torch.utils.data.DataLoader(
    tissueDataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

In [0]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [0]:
def GetIndividualBox(imgs, annotations):
    """Split a batch so that every object becomes its own (image, single-box target) pair.

    Args:
        imgs: sequence of images for one batch (e.g. the first element produced by
            ``collate_fn``).
        annotations: sequence of dicts parallel to ``imgs``, each with ``'boxes'`` and
            ``'labels'`` tensors whose first dimension indexes the objects.

    Returns:
        ``(imgs_replicated, new_annotations)``: two lists with one entry per object.
        ``imgs_replicated[k]`` is the (shared, not copied) image that object ``k`` came
        from. ``new_annotations[k]`` is a new dict whose ``'boxes'`` has shape (1, 4) and
        whose ``'labels'`` has shape (1,). Images without objects contribute nothing, and
        an empty batch yields two empty lists.
    """
    imgs_replicated = []
    new_annotations = []
    # Iterate over the batch that was passed in, and build a fresh dict per object so
    # the returned targets never alias one another.
    for img, annotation in zip(imgs, annotations):
        for box, label in zip(annotation['boxes'], annotation['labels']):
            imgs_replicated.append(img)
            new_annotations.append({
                'boxes': box.reshape(-1, 4),
                'labels': label.reshape(-1),
            })
    return imgs_replicated, new_annotations

In [0]:
def UnpackBndbox(bndbox, img):
    """Convert a normalised centre-format box into absolute corner coordinates for ``img``.

    Args:
        bndbox: four values ``(x_center, y_center, width, height)``. x and width are
            fractions of the image width; y and height are fractions of the image height
            (the format stored by ``TissueDataset``). May be a 1-D tensor or any
            4-element sequence.
        img: object with a PIL-style ``size`` attribute ``(width, height)``.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` as Python floats in pixels of ``img``. This is the
        flat layout accepted by ``ImageDraw.rectangle``.
    """
    img_width, img_height = img.size
    # float() copies the values out. Iterating a tensor yields 0-d views, so an in-place
    # ``x *= ...`` on them would silently rescale the caller's box tensor.
    x, y, w, h = (float(v) for v in bndbox)
    x *= img_width
    w *= img_width
    y *= img_height
    h *= img_height
    return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]

In [0]:
def Show(batch, pred_bndbox=None):
    """Draw a sample's ground-truth boxes (and optionally one predicted box) and show it.

    Args:
        batch: an ``(img, annotations)`` pair as returned by ``TissueDataset`` with
            ``get_transform()``. ``img`` must be an image tensor accepted by
            ``ToPILImage``. ``annotations['boxes']`` holds one ``[(x_c, y_c, w, h)]``
            entry per object, in coordinates normalised by the image size.
        pred_bndbox: optional single predicted box. It is assumed to use the same
            normalised ``(x_c, y_c, w, h)`` format as the ground truth — TODO confirm
            against whatever produces it.

    Returns:
        The annotated 512x512 PIL image, so a cell ending in ``Show(...)`` also renders
        it inline in the notebook.
    """
    img, annotations = batch

    img = transforms.ToPILImage()(img)
    img = transforms.Resize((512, 512))(img)
    draw = ImageDraw.Draw(img)

    # Boxes are normalised, so scaling them against the resized image keeps them aligned.
    for bndbox in annotations['boxes']:
        draw.rectangle(UnpackBndbox(bndbox[0], img))

    # The prediction is drawn once, outside the ground-truth loop, so it also appears
    # for images that have no ground-truth boxes.
    if pred_bndbox is not None:
        draw.rectangle(UnpackBndbox(pred_bndbox, img), outline='red')

    img.show()
    return img

In [0]:
Show(batch=tissueDataset[0])

In [16]:
# pretrained=False skips the full pretrained detector weights, but the ResNet-50 backbone
# weights are still downloaded (see the cell output below).
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=False,
)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/checkpoints/resnet50-19c8e357.pth


HBox(children=(IntProgress(value=0, max=102502400), HTML(value='')))




In [0]:
# Feature width entering the current box classifier; the replacement predictor below is
# built with this same input size.
in_features = (
    model.roi_heads.box_predictor
    .cls_score
    .in_features
)

In [0]:
# Replace the classification/regression head with one sized for this dataset.
# NOTE(review): torchvision's num_classes includes the background class, so 3 presumably
# means background + tissue labels 1 and 2 — confirm against the <name> values in the XML.
model.roi_heads.box_predictor = FastRCNNPredictor(in_channels=in_features, num_classes=3)

In [19]:
# Module.to moves the parameters and returns the same module, whose repr is displayed below.
model.to(device=device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform()
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d()
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d()
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d()
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d()
          )
  

In [0]:
def _to_frcnn_target(img, annotation):
    """Convert one ``TissueDataset`` annotation into the target format Faster R-CNN trains on.

    The dataset stores boxes as ``(x_center, y_center, width, height)``, divided by the
    image width/height, and possibly with an extra singleton dimension. torchvision's
    Faster R-CNN expects absolute-pixel ``[xmin, ymin, xmax, ymax]`` boxes of shape
    (N, 4) and int64 labels of shape (N,).

    Args:
        img: the image tensor the boxes belong to; assumed (C, H, W) as produced by
            ``ToTensor`` — TODO confirm if the transform changes.
        annotation: dict with ``'boxes'`` and ``'labels'`` tensors in the dataset's format.

    Returns:
        A new dict with ``'boxes'`` (float32, shape (N, 4)) and ``'labels'`` (int64, shape (N,)).
    """
    img_height, img_width = img.shape[-2], img.shape[-1]
    boxes = annotation['boxes'].reshape(-1, 4).to(torch.float32)
    x_c = boxes[:, 0] * img_width
    y_c = boxes[:, 1] * img_height
    w = boxes[:, 2] * img_width
    h = boxes[:, 3] * img_height
    corners = torch.stack([x_c - w / 2, y_c - h / 2, x_c + w / 2, y_c + h / 2], dim=1)
    labels = annotation['labels'].reshape(-1).to(torch.int64)
    return {'boxes': corners, 'labels': labels}


params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

num_epochs = 10
len_dataloader = len(data_loader)

for epoch in range(num_epochs):
    model.train()
    for i, (imgs, annotations) in enumerate(data_loader, start=1):
        # Faster R-CNN accepts several boxes per image, so each image keeps all of its
        # boxes in a single target. Splitting them into one-box copies of the image would
        # teach the model that the image's other objects are background.
        imgs_train, targets = [], []
        for img, annotation in zip(imgs, annotations):
            target = _to_frcnn_target(img, annotation)
            if len(target['boxes']) == 0:
                continue  # no annotated objects in this image: nothing to learn from it
            imgs_train.append(img.to(device))
            targets.append({k: v.to(device) for k, v in target.items()})
        if not imgs_train:
            continue

        loss_dict = model(imgs_train, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # .item() logs the plain number rather than the tensor repr with its grad_fn.
        print(f'Epoch: {epoch + 1}/{num_epochs}, Iteration: {i}/{len_dataloader}, '
              f'Loss: {losses.item():.4f}')