In [1]:
from torch import nn
import torch

class VGG(nn.Module):
    """Tiny VGG-style convolutional backbone.

    Stacks three conv -> batch-norm -> ReLU -> 2x2 max-pool stages with
    channel widths 1 -> 4 -> 16 -> 1. Each stage halves the spatial size,
    so with the default 3x3 / padding-1 convolutions an ``(N, 1, H, W)``
    input becomes an ``(N, 1, H // 8, W // 8)`` feature map.
    """

    def __init__(self):
        super().__init__()

        def basic_block(in_channel, out_channel, **kwargs):
            """Build one conv/BN/ReLU/max-pool stage.

            Args:
                in_channel: number of input channels of the convolution.
                out_channel: number of output channels of the convolution.
                **kwargs: forwarded to ``nn.Conv2d``; ``kernel_size`` and
                    ``padding`` default to 3 and 1 when not supplied.
            """
            # nn.Conv2d has no default kernel_size, so calling it without one
            # raises TypeError. Default to a size-preserving 3x3 convolution
            # while still letting callers override either value.
            kwargs.setdefault("kernel_size", 3)
            kwargs.setdefault("padding", 1)
            return nn.Sequential(
                # bias is redundant because BatchNorm2d follows immediately.
                nn.Conv2d(in_channel, out_channel, **kwargs, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )

        self.feature_layer = nn.Sequential(
            basic_block(1, 4),
            basic_block(4, 16),
            basic_block(16, 1),
        )

    def forward(self, x):
        """Return the feature map produced by the three stacked stages."""
        return self.feature_layer(x)

class FastRCNN(nn.Module):
    """Minimal Fast R-CNN style detector head on top of the ``VGG`` backbone.

    The backbone produces a feature map; every region of interest is cropped
    from that map, adaptively average-pooled to 7x7, flattened and passed
    through two linear layers. The resulting 256-d vector feeds a class head
    and a 4-value box-regression head.

    Args:
        num_classes: number of outputs of the classification head.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        self.feature_layer = VGG()

        self.roi_pooling = nn.AdaptiveAvgPool2d((7, 7))

        # Linear(7*7, ...) relies on the backbone emitting a single channel,
        # so a pooled ROI flattens to exactly 49 values.
        # NOTE(review): there is no activation between the two Linear layers,
        # so together they are still one affine map; left as-is to keep
        # parameter names stable.
        self.roi_feature = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7*7, 256),
            nn.Linear(256, 256),
        )

        # NOTE(review): the class scores leave this head through a Sigmoid;
        # a loss that expects raw logits should not be fed these directly.
        self.cls_layer = nn.Sequential(
            nn.Linear(256, num_classes),
            nn.Sigmoid()
        )

        self.loc_layer = nn.Sequential(
            nn.Linear(256, 4)
        )

    def forward(self, x, bboxs):
        """Classify and regress every region of interest.

        Args:
            x: image batch fed to the backbone; its last two dimensions are
                taken as height and width.
            bboxs: ``(K, 5)`` table of ``[img_idx, x_min, y_min, x_max, y_max]``
                rows, with coordinates given in input-image pixels and
                ``img_idx`` selecting the image inside ``x``.

        Returns:
            Tuple ``(cls, bbox)`` where ``cls`` has shape ``(K, num_classes)``
            (after the Sigmoid) and ``bbox`` has shape ``(K, 4)``.
        """
        feature = self.feature_layer(x)

        # Image-to-feature scale, derived from the actual tensor shapes so it
        # stays right if the backbone stride changes.
        feat_h, feat_w = feature.shape[-2:]
        scale_y = x.shape[-2] / feat_h
        scale_x = x.shape[-1] / feat_w

        bboxs = torch.as_tensor(bboxs, device=feature.device).reshape(-1, 5)
        # The image index is a plain lookup key and must not be rescaled.
        img_indices = bboxs[:, 0].long()
        coords = bboxs[:, 1:].float()

        # Floor the min corner and ceil the max corner so the projected box
        # covers the original one, then clamp to the feature map.
        x_min = torch.floor(coords[:, 0] / scale_x).long().clamp(0, feat_w - 1)
        y_min = torch.floor(coords[:, 1] / scale_y).long().clamp(0, feat_h - 1)
        x_max = torch.ceil(coords[:, 2] / scale_x).long().clamp(max=feat_w)
        y_max = torch.ceil(coords[:, 3] / scale_y).long().clamp(max=feat_h)
        # Every ROI needs at least one cell, otherwise pooling gets an empty crop.
        x_max = torch.max(x_max, x_min + 1)
        y_max = torch.max(y_max, y_min + 1)

        rois = [
            # Feature maps are laid out (N, C, H, W): rows are y, columns are x.
            self.roi_pooling(feature[idx, :, y0:y1, x0:x1])
            for idx, x0, y0, x1, y1 in zip(
                img_indices.tolist(),
                x_min.tolist(),
                y_min.tolist(),
                x_max.tolist(),
                y_max.tolist(),
            )
        ]
        if rois:
            rois = torch.stack(rois, dim=0)
        else:
            # torch.stack rejects an empty list; emit a zero-length batch instead.
            rois = feature.new_zeros(
                (0, feature.shape[1], *self.roi_pooling.output_size)
            )

        feature_vec = self.roi_feature(rois)

        cls = self.cls_layer(feature_vec)
        bbox = self.loc_layer(feature_vec)

        return cls, bbox

In [None]:
from pytorch_lightning import LightningModule
from torch.nn import CrossEntropyLoss, SmoothL1Loss

class Net(LightningModule):
    """Lightning wrapper around the ``FastRCNN`` detector.

    Holds the detector as ``backbone`` together with a classification loss
    and a box-regression loss. The training and test steps are still
    placeholders.
    """

    def __init__(self, ):
        super().__init__()

        self.backbone = FastRCNN()

        # losses
        # NOTE(review): CrossEntropyLoss applies its own log-softmax, while the
        # FastRCNN class head already ends in a Sigmoid -- confirm which one is
        # intended before wiring this loss into training_step.
        self.cls_loss = CrossEntropyLoss()

        self.bbox_loss = SmoothL1Loss()

    def forward(self, x, bboxs):
        """Run the detector and return its ``(cls, bbox)`` outputs.

        Args:
            x: image batch forwarded to the backbone.
            bboxs: region-of-interest table forwarded to the backbone.
        """
        # The outputs were previously computed and then dropped, so calling
        # the module returned None.
        return self.backbone(x, bboxs)

    def training_step(self, batch, batch_idx):
        """Placeholder: no loss is computed yet, so this returns None."""
        # TODO: unpack the batch, run forward, and combine cls_loss / bbox_loss.
        pass

    def test_step(self, batch, batch_idx):
        """Placeholder: no evaluation is performed yet, so this returns None."""
        pass

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters (lr = 1e-4)."""
        from torch.optim import SGD
        return SGD(self.parameters(), lr=1e-4)