In [1]:
from data.transforms import get_tensorise_h_flip_transform
from data.constants import DEFAULT_REFLACX_LABEL_COLS
from data.paths import MIMIC_EYE_PATH
from data.datasets import ReflacxObjectDetectionWithFixations
from torch.utils.data import DataLoader
from data.datasets import collate_fn
from data.load import seed_worker, get_dataloader_g

dataset_params_dict = {
    "MIMIC_EYE_PATH": MIMIC_EYE_PATH,
    # "with_clinical": model_setup.use_clinical,
    "bbox_to_mask": True,
    "labels_cols": DEFAULT_REFLACX_LABEL_COLS,
}

train_dataset = ReflacxObjectDetectionWithFixations(
        **dataset_params_dict, split_str="train", transforms=get_tensorise_h_flip_transform(train=False), 
)

train_dataloader =  DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
    worker_init_fn=seed_worker,
    generator=get_dataloader_g(0),
)



In [2]:
from models.components.feature_extractors import ImageFeatureExtractor
from models.components.fusors import NoActionFusor
from models.components.task_performers import ObjectDetectionWithMaskParameters, ObjectDetectionWithMaskPerformer, HeatmapGeneratorParameters, HeatmapGenerator
from models.frameworks import ExtractFusePerform
from models.backbones import get_normal_backbone
from models.setup import ModelSetup
from data.constants import DEFAULT_REFLACX_LABEL_COLS


In [3]:
from utils.init import reproducibility, clean_memory_get_device

device = clean_memory_get_device()
reproducibility()

This notebook will running on device: [CPU]


In [4]:
setup = ModelSetup()
backbone = get_normal_backbone(setup)
image_extractor = ImageFeatureExtractor(backbone)
fusor = NoActionFusor()


obj_params = ObjectDetectionWithMaskParameters()
obj_performer = ObjectDetectionWithMaskPerformer(
    obj_params,
    image_extractor.backbone.out_channels,
    len(DEFAULT_REFLACX_LABEL_COLS) + 1
)

fix_params = HeatmapGeneratorParameters(input_channel=backbone.out_channels, decoder_channels=[64, 64, 64, 64, 64, 64, 64, 64, 64, 1]) # the output should be just one channel.
fix_performer = HeatmapGenerator(
    params= fix_params,
)

model = ExtractFusePerform(
    feature_extractors={"image": image_extractor},
    fusor=fusor,
    task_performers={"object-detection": obj_performer, "fixation-generation": fix_performer },
)

Using pretrained backbone. mobilenet_v3




In [5]:
data = next(iter(train_dataloader))
input, targets = train_dataset.prepare_input_from_data(data, device)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  boxes_df[k] = ellipse_df[
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  boxes_df[k] = ellipse_df[
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  boxes_df[k] = ellipse_df[
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See

In [6]:
model.train()
outputs = model(input, targets=targets)

fixations are being resized.
size should be [512, 512]
fixations are resized into: torch.Size([1, 512, 512])
fixations are being resized.
size should be [512, 512]
fixations are resized into: torch.Size([1, 512, 512])
fixations are being resized.
size should be [512, 512]
fixations are resized into: torch.Size([1, 512, 512])
fixations are being resized.
size should be [512, 512]
fixations are resized into: torch.Size([1, 512, 512])
after preparation, the size is torch.Size([1, 512, 512])
before task performers, the size is torch.Size([1, 512, 512])




: 

: 

In [None]:
model.task_performers['fixation-generation'].targets[0]['fixations'].shape

torch.Size([1, 2539, 3050])

In [None]:
model.task_performers['fixation-generation'].targets[0]['masks'].shape

torch.Size([1, 2539, 3050])

In [None]:
model.task_performers['fixation-generation'].targets[1]['fixations'].shape

torch.Size([1, 3056, 2544])

In [None]:
model.task_performers['fixation-generation'].targets[1]['masks'].shape

torch.Size([0, 3056, 2544])

In [None]:
for i in range(len(model.task_performers['fixation-generation'].targets)):
    print(model.task_performers['fixation-generation'].targets[i]['fixations'].shape)
    print(model.task_performers['fixation-generation'].targets[i]['masks'].shape)
    print("="*20)


torch.Size([1, 2539, 3050])
torch.Size([1, 2539, 3050])
torch.Size([1, 3056, 2544])
torch.Size([0, 3056, 2544])
torch.Size([1, 2544, 3056])
torch.Size([6, 2544, 3056])
torch.Size([1, 2544, 3040])
torch.Size([0, 2544, 3040])


In [None]:
targets.fixations

[{'fixations': tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]),
  'fixation_path': '/Users/jrhs/Desktop/mimic-eye/patient_14718365/REFLACX/main_data/P300R336037/fixations.csv',
  'masks': tensor([[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8),
  'image_path': '/Users/jrhs/Desktop/mimic-eye/patient_14718365/CXR-JPG/s59269152/5024f775-51ab5259-a943877e-ebaa6afd-2ed9fe6d.jpg',
  'dicom_id': '5024f775-51ab5259-a943877e-ebaa6afd-2ed9fe6d',
  'iscrowd': tensor([0]),
  'area': tensor([1724940.], dtype=torch.float64),
  'image_id': tensor([1101]),
  'labels': tensor([2]),
 

In [None]:
targets[0]['fixations'].shape

torch.Size([1, 2539, 3050])

In [None]:
torch.stack()