In [1]:
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from torch/torchvision;
# consider narrowing the filter once the notebook is stable.
warnings.filterwarnings('ignore')

In [2]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Load Model
**Load pretrained model**

In [3]:
# Build torchvision's Faster R-CNN (ResNet-50 FPN backbone) with pretrained=True.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# Bare last expression so the notebook displays the current box predictor head.
model.roi_heads.box_predictor

FastRCNNPredictor(
  (cls_score): Linear(in_features=1024, out_features=91, bias=True)
  (bbox_pred): Linear(in_features=1024, out_features=364, bias=True)
)

**Change Last Layer**

In [4]:
# Number of outputs for the replacement classification head.
# NOTE(review): torchvision detection heads count background as class 0 and
# ATZDataset.ATZ_CLASSES lists 12 names -- confirm that 11 is the intended value.
num_classes = 11
# Input width of the existing classification layer, reused for the new head.
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Swap the box predictor for a fresh one sized for num_classes.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# Bare last expression so the notebook displays the new predictor.
model.roi_heads.box_predictor

FastRCNNPredictor(
  (cls_score): Linear(in_features=1024, out_features=11, bias=True)
  (bbox_pred): Linear(in_features=1024, out_features=44, bias=True)
)

# Load Dataset
**Create dataset wrapper**

In [5]:
import os.path
import pathlib
from xml.etree import ElementTree

import torch
import torchvision.transforms
from PIL import Image
from torch.utils.data import Dataset


class ATZDataset(Dataset):
    """
    Map-style detection dataset over a Pascal-VOC style directory.

    The data_dir directory must contain:
        JPEGImages
        Annotations
        train.txt
        test.txt
        val.txt
    to provide Active Terahertz Imaging Dataset for Concealed Object Detection(https://github.com/LingLIx/THz_Dataset)

    Each item is ``(image, target)`` where ``image`` is the RGB image converted
    with ``torchvision.transforms.ToTensor`` and ``target`` is a dict with
    ``boxes`` (float32, shape ``(N, 4)``, ``xmin, ymin, xmax, ymax``) and
    ``labels`` (int64, shape ``(N,)``).

    Labels are ``ATZ_CLASSES.index(name) + LABEL_OFFSET`` so that label 0 stays
    free for the background class used by torchvision detection models. A
    detector trained on this dataset therefore needs
    ``len(ATZ_CLASSES) + 1`` output classes.
    """
    ATZ_CLASSES = ["CP", "MD", "GA", "KK", "SS",
                   "KC", "WB", "LW", "CK", "CL", "UNKNOWN", "UN"]
    IGNORE_CLASSES = ["HUMAN", ]
    # torchvision detection models reserve label 0 for background, so object
    # labels start at 1.
    LABEL_OFFSET = 1

    def __init__(self, root: str, split: str):
        """
        Args:
            root: dataset directory laid out as described in the class docstring.
            split: one of ``"train"``, ``"test"`` or ``"val"``; selects ``<split>.txt``.

        Raises:
            AssertionError: if ``split`` is invalid or an expected file/directory is missing.
        """
        assert split in ("train", "test", "val"), \
            "split must be 'train', 'test' or 'val', got %r" % (split,)
        self.data_dir = pathlib.Path(root)
        self.image_dir = self.data_dir / "JPEGImages"
        self.anno_dir = self.data_dir / "Annotations"
        self.collection_file = self.data_dir / ("%s.txt" % split)
        assert os.path.isdir(self.data_dir), "missing directory: %s" % self.data_dir
        assert os.path.isdir(self.image_dir), "missing directory: %s" % self.image_dir
        assert os.path.isdir(self.anno_dir), "missing directory: %s" % self.anno_dir
        assert os.path.isfile(self.collection_file), "missing file: %s" % self.collection_file
        # Load file stems from the split file; blank lines (e.g. a trailing
        # newline) are skipped so they do not turn into an empty filename.
        with open(self.collection_file) as fp:
            stems = (line.strip() for line in fp)
            self.filenames = [stem for stem in stems if stem]

    def read_vocxml_content(self, xml_file: str):
        """
        Parse one VOC annotation file into a detection target.

        Objects whose ``name`` is in ``IGNORE_CLASSES`` are skipped.

        Args:
            xml_file: path of the annotation XML.

        Returns:
            dict with ``boxes`` (float32 ``(N, 4)``) and ``labels`` (int64 ``(N,)``);
            both are empty (``(0, 4)`` / ``(0,)``) when no object is kept.

        Raises:
            ValueError: if an object name is neither ignored nor in ``ATZ_CLASSES``.
        """
        root = ElementTree.parse(xml_file).getroot()
        rows = []
        class_ids = []
        for obj in root.iter('object'):
            class_ = obj.find("name").text
            if class_ in self.IGNORE_CLASSES:
                continue
            # float() accepts both "12" and "12.0"; boxes are stored as float32 anyway.
            xmin, ymin, xmax, ymax = (
                float(obj.find("bndbox/%s" % tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )
            rows.append([xmin, ymin, xmax, ymax])
            class_ids.append(self.ATZ_CLASSES.index(class_) + self.LABEL_OFFSET)
        # Build the tensors in one go (a Python tuple cannot be assigned into a
        # tensor slice); reshape keeps the (0, 4) shape for images without objects.
        boxes = torch.tensor(rows, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(class_ids, dtype=torch.int64)
        return dict(boxes=boxes, labels=labels)

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th entry of the split file."""
        filename = self.filenames[idx]
        # The context manager closes the underlying file as soon as the pixels
        # have been converted.
        with Image.open(self.image_dir / ("%s.jpg" % filename)) as img:
            image = torchvision.transforms.ToTensor()(img.convert("RGB"))
        targets = self.read_vocxml_content(self.anno_dir / ("%s.xml" % filename))
        return image, targets


**Create dataloader function by split**

In [6]:
from torch.utils.data import DataLoader
def _collate_detection_batch(batch):
    """Regroup ``[(image, target), ...]`` into ``(images, targets)`` tuples.

    The default collation tries to stack every field into one tensor, which
    fails for detection targets because the number of boxes differs per image.
    """
    return tuple(zip(*batch))


def get_dataloader(split):
    """
    Returns dataloader for given split

    Args:
        split: ``"train"``, ``"test"`` or ``"val"``; forwarded to ``ATZDataset``.

    Returns:
        A ``DataLoader`` whose batches are ``(images, targets)`` tuples, the
        per-sample list layout torchvision detection models take as input.
    """
    batch_size = 32
    num_workers = 4
    data_dir = "data/THZ_dataset/THZ_dataset_det_VOC"
    # Shuffling and dropping the incomplete last batch only make sense while
    # training; for test/val they would silently skip samples.
    is_train = split == "train"
    dataset = ATZDataset(root=data_dir, split=split)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            drop_last=is_train,
                            shuffle=is_train, num_workers=num_workers,
                            collate_fn=_collate_detection_batch)
    return dataloader

# Train Loop
**Define model, optimizer, and train / test / validation helpers**

In [7]:
import torch
import torchvision
from torch import optim
from torch.utils.data import DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor



def _collate_detection_batch(batch):
    """Regroup ``[(image, target), ...]`` into ``(images, targets)`` tuples.

    The default collation tries to stack every field into one tensor, which
    fails for detection targets because the number of boxes differs per image.
    """
    return tuple(zip(*batch))


def get_dataloader(split):
    """
    Returns dataloader for given split

    Args:
        split: ``"train"``, ``"test"`` or ``"val"``; forwarded to ``ATZDataset``.

    Returns:
        A ``DataLoader`` whose batches are ``(images, targets)`` tuples, the
        per-sample list layout torchvision detection models take as input.
    """
    batch_size = 32
    num_workers = 4
    data_dir = "data/THZ_dataset/THZ_dataset_det_VOC"
    # Shuffling and dropping the incomplete last batch only make sense while
    # training; for test/val they would silently skip samples.
    is_train = split == "train"
    dataset = ATZDataset(root=data_dir, split=split)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            drop_last=is_train,
                            shuffle=is_train, num_workers=num_workers,
                            collate_fn=_collate_detection_batch)
    return dataloader


def get_net(num_classes):
    """
    Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box predictor.

    The network is created with ``pretrained=True`` and its
    ``roi_heads.box_predictor`` is replaced by a ``FastRCNNPredictor`` that
    keeps the original input width and outputs ``num_classes`` classes.

    Reference:
    https://pytorch.org/vision/stable/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn.html
    #torchvision.models.detection.fasterrcnn_resnet50_fpn

    Args:
        num_classes: number of outputs of the new classification layer.

    Returns:
        The modified detection model.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    roi_heads = detector.roi_heads
    # Read the input width before the old head is discarded.
    head_width = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(head_width, num_classes)
    return detector


def get_optimizers(net, lr=0.001):
    """Create an RMSprop optimizer over all parameters of ``net``.

    Args:
        net: module whose ``parameters()`` are optimised.
        lr: learning rate handed to ``optim.RMSprop``.

    Returns:
        The constructed ``optim.RMSprop`` instance.
    """
    trainable = net.parameters()
    return optim.RMSprop(trainable, lr=lr)


def _batch_to_device(batch_image, batch_targets, device):
    """Turn one loader batch into per-sample lists whose tensors live on ``device``.

    Accepts both layouts a DataLoader can produce: sequences of per-sample
    items (detection-style ``collate_fn``) or the default collation, i.e. one
    stacked image tensor plus a single dict of stacked target tensors.
    """
    # Iterating a stacked (B, C, H, W) tensor yields the individual images.
    images = [image.to(device) for image in batch_image]
    if isinstance(batch_targets, dict):
        # Default collation merged the per-sample dicts; split them again.
        batch_targets = [
            {key: value[i] for key, value in batch_targets.items()}
            for i in range(len(images))
        ]
    targets = [
        {key: value.to(device) for key, value in target.items()}
        for target in batch_targets
    ]
    return images, targets


def training(net, optimizer, device, max_epoch=5, start_epoch=0):
    """Train ``net`` on the 'train' split.

    Args:
        net: detection model that returns a dict of losses when called as
            ``net(images, targets)`` in train mode.
        optimizer: optimizer already bound to ``net``'s parameters.
        device: ``torch.device`` the model and every batch are moved to.
        max_epoch: epoch index at which training stops (exclusive).
        start_epoch: first epoch index, for resuming.

    Prints a one-line progress indicator per batch and the last batch loss of
    every epoch.
    """
    train_dataloader = get_dataloader('train')
    num_batches = len(train_dataloader)

    net.train()
    net.to(device)
    for epoch_idx in range(start_epoch, max_epoch, 1):
        # Stays None when the loader yields nothing, so the epoch summary
        # below cannot hit an undefined name.
        loss = None
        for batch_idx, (batch_image, batch_targets) in enumerate(train_dataloader):
            print("\rEpoch [%d/%d] Batch [%d/%d] "
                  % (epoch_idx + 1, max_epoch, batch_idx + 1, num_batches),
                  end="", flush=True)
            # Tensor.to() is not in-place: keep the returned copies.
            images, targets = _batch_to_device(batch_image, batch_targets, device)
            # Reset gradients accumulated by the previous step
            optimizer.zero_grad()
            # Forward pass; in train mode the model returns its loss terms
            losses = net(images, targets)
            # Sum all loss terms into the scalar that is optimised
            loss = sum(losses.values())
            # Calculate gradient
            loss.backward()
            # Update parameters
            optimizer.step()
        if loss is not None:
            print("Loss:", loss.detach().cpu())


def testing(net, optimizer):
    """Placeholder for the test-split pass.

    Currently only builds the 'test' dataloader; ``net`` and ``optimizer`` are
    not used yet and nothing is returned.
    """
    loader = get_dataloader('test')
    ...  # TODO: test loop not implemented yet


def validation(net):
    """Placeholder for the validation-split pass.

    Currently only builds the 'val' dataloader; ``net`` is not used yet and
    nothing is returned.
    """
    loader = get_dataloader('val')
    ...  # TODO: validation loop not implemented yet




In [None]:
def main():
    """Build the detector and optimizer, then run the training loop."""
    max_epoch = 5
    start_epoch = 0
    # torchvision detection heads count background as class 0, so the head
    # needs one output more than the dataset has object classes.
    num_classes = len(ATZDataset.ATZ_CLASSES) + 1
    lr = 0.001
    # Kept on CPU as before; switch to torch.device('cuda') to train on a GPU.
    device = torch.device('cpu')
    net = get_net(num_classes=num_classes)
    optimizer = get_optimizers(net, lr=lr)
    # Pass the settings defined above instead of repeating the literals.
    training(net, optimizer, device, max_epoch=max_epoch, start_epoch=start_epoch)
main()