In [1]:
from models.components.feature_extractors import ImageFeatureExtractor
from models.components.fusors import NoActionFusor
from models.components.task_performers import ObjectDetectionWithMaskParameters, ObjectDetectionWithMaskPerformer, HeatmapGeneratorParameters, HeatmapGenerator
from models.frameworks import ExtractFusePerform
from models.backbones import get_normal_backbone
from models.setup import ModelSetup
from data.constants import DEFAULT_REFLACX_LABEL_COLS


In [2]:
# Shared configuration and the image branch of the framework.
setup = ModelSetup()
backbone = get_normal_backbone(setup)
image_extractor = ImageFeatureExtractor(backbone)
# Only one extractor ("image") is used, so this fusor presumably passes its
# features through unchanged — confirm in models.components.fusors.
fusor = NoActionFusor()

# Detector output classes: one per REFLACX label column plus one extra slot
# (presumably the background class of torchvision-style detectors — confirm).
NUM_DETECTION_CLASSES = len(DEFAULT_REFLACX_LABEL_COLS) + 1

# Object-detection (+ mask) head, sized from the extractor's backbone channels.
obj_params = ObjectDetectionWithMaskParameters()
obj_performer = ObjectDetectionWithMaskPerformer(
    obj_params,
    image_extractor.backbone.out_channels,
    NUM_DETECTION_CLASSES,
)

Using pretrained backbone. mobilenet_v3




In [3]:
# Note: the input image size for this setup is 512 (presumably 512x512 px — confirm against ModelSetup).

In [4]:
# Channel count of the backbone's output features; both task heads size their
# inputs from it. Check that the extractor wraps a backbone reporting the same
# value, so the heads cannot silently disagree.
assert image_extractor.backbone.out_channels == backbone.out_channels, (
    "image_extractor.backbone and backbone report different out_channels"
)
backbone.out_channels

64

In [5]:
# Fixation-heatmap head. Its input channel count comes from the same
# `image_extractor.backbone` the object-detection head is sized from, so both
# heads agree on the feature-map width.
fix_params = HeatmapGeneratorParameters(
    input_channel=image_extractor.backbone.out_channels,
    # Five entries of 64 — presumably one per decoder stage; confirm in
    # models.components.task_performers.
    decoder_channels=[64] * 5,
)
fix_performer = HeatmapGenerator(params=fix_params)

In [6]:
# Assemble the extract -> fuse -> perform pipeline: one image extractor, the
# fusor, and two task heads keyed by task name.
framework = ExtractFusePerform(
    feature_extractors={
        "image": image_extractor,
    },
    fusor=fusor,
    task_performers={
        "object-detection": obj_performer,
        "fixation-generation": fix_performer,
    },
)