Add vit isic (#221)
shatz01 committed Nov 29, 2022
1 parent 2148e88 commit 8b12ad8
Showing 3 changed files with 134 additions and 12 deletions.
94 changes: 84 additions & 10 deletions examples/fuse_examples/imaging/classification/isic/isic_runner.py
@@ -52,6 +52,9 @@
from fuseimg.datasets.isic import ISIC, ISICDataModule
from fuse_examples.imaging.classification.isic.golden_members import FULL_GOLDEN_MEMBERS

import torch.nn as nn
from fuse.dl.models.model_wrapper import ModelWrapSeqToDict
from fuse.dl.models.backbones.backbone_vit import ViT

###########################################################################################################
# Fuse
@@ -71,9 +74,13 @@
##########################################
# Modality
##########################################

multimodality = True # Set: 'False' to use only imaging, 'True' to use imaging & meta-data

##########################################
# Model Type
##########################################
model_type = "Transformer" # Set: 'Transformer' to use ViT/MMViT, 'CNN' to use InceptionResNet

##########################################
# Output Paths
##########################################
@@ -126,15 +133,75 @@
# ===============
# Model
# ===============
TRAIN_COMMON_PARAMS["model"] = dict(
    dropout_rate=0.5,
    layers_description=(256,),
    tabular_data_inputs=[("data.input.clinical.all", 19)] if multimodality else None,
    tabular_layers_description=(128,) if multimodality else tuple(),
)
if model_type == "CNN":
    TRAIN_COMMON_PARAMS["model"] = dict(
        dropout_rate=0.5,
        layers_description=(256,),
        tabular_data_inputs=[("data.input.clinical.all", 19)] if multimodality else None,
        tabular_layers_description=(128,) if multimodality else tuple(),
    )
elif model_type == "Transformer":
    token_dim = 768
    TRAIN_COMMON_PARAMS["model"] = dict(
        token_dim=token_dim,
        projection_kwargs=dict(image_shape=[300, 300], patch_shape=[30, 30], channels=3),
        transformer_kwargs=dict(depth=12, heads=12, mlp_dim=token_dim * 4, dim_head=64, dropout=0.0, emb_dropout=0.0),
    )
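
A quick sanity check on the Transformer configuration above (a sketch, assuming the fuse ViT projection layer produces one token per non-overlapping patch; the exact CLS-token accounting lives inside the backbone):

num_patches = (300 // 30) * (300 // 30)  # 10 * 10 = 100 image tokens per 300x300 image
mlp_dim = 768 * 4                        # 3072, matching mlp_dim in transformer_kwargs
# MMViT (defined below) resizes the positional embedding to num_tokens + 2,
# reserving room for the CLS token plus one extra token that carries the 19 clinical features.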


def create_model(
def perform_softmax(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    cls_preds = F.softmax(logits, dim=1)
    return logits, cls_preds


class MMViT(ViT):
    def __init__(self, token_dim: int, projection_kwargs: dict, transformer_kwargs: dict, multimodality: bool):
        super().__init__(token_dim, projection_kwargs, transformer_kwargs)
        self.multimodality = multimodality
        num_tokens = self.projection_layer.num_tokens
        self.token_dim = token_dim
        self._head = nn.Linear(token_dim, 8)
        if self.multimodality:
            # change pos embedding to accept additional token for multimodal
            self.transformer.pos_embedding = nn.Parameter(torch.randn(1, num_tokens + 2, token_dim))

    # This forward can be Multimodal or just Imaging
    def forward(self, img_x: torch.Tensor, clinical_x: torch.Tensor = None) -> torch.Tensor:
        img_x = self.projection_layer(img_x)
        if self.multimodality:
            clinical_x = clinical_x.unsqueeze(1)
            clinical_x_zeros = torch.zeros((img_x.shape[0], 1, self.token_dim))
            clinical_x_zeros[:, :, :19] = clinical_x
            clinical_x = clinical_x_zeros.cuda()
            x = torch.cat((img_x, clinical_x), 1)
        else:
            x = img_x
        x = self.transformer(x)
        x = self._head(x[:, 0])
        return x
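
A minimal usage sketch for MMViT, assuming an NCHW image layout that matches the projection_kwargs above and a CUDA device (the multimodal forward moves the clinical token to the GPU):

mmvit = MMViT(
    token_dim=768,
    projection_kwargs=dict(image_shape=[300, 300], patch_shape=[30, 30], channels=3),
    transformer_kwargs=dict(depth=12, heads=12, mlp_dim=768 * 4, dim_head=64, dropout=0.0, emb_dropout=0.0),
    multimodality=True,
).cuda()
img = torch.randn(2, 3, 300, 300).cuda()  # batch of 2 RGB images
clinical = torch.randn(2, 19).cuda()      # 19 clinical features per sample
logits = mmvit(img, clinical)             # expected shape: [2, 8], matching the nn.Linear(token_dim, 8) head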


def create_transformer_model(
    token_dim: int,
    projection_kwargs: dict,
    transformer_kwargs: dict,
) -> ModelWrapSeqToDict:
    torch_model = MMViT(
        token_dim=token_dim,
        projection_kwargs=projection_kwargs,
        transformer_kwargs=transformer_kwargs,
        multimodality=multimodality,
    )
    model = ModelWrapSeqToDict(
        model=torch_model,
        model_inputs=["data.input.img", "data.input.clinical.all"] if multimodality else ["data.input.img"],
        post_forward_processing_function=perform_softmax,
        model_outputs=["model.logits.head_0", "model.output.head_0"],
    )
    return model
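
For context, ModelWrapSeqToDict wires MMViT into FuseMedML's dictionary-based batch flow: roughly, it reads the model_inputs keys from the batch dict, calls the wrapped model, applies perform_softmax, and stores the two returned tensors under model_outputs. A rough sketch of the key mapping (the batch construction here is hypothetical; in the runner it comes from the ISIC data pipeline):

batch_dict = {
    "data.input.img": torch.randn(2, 3, 300, 300),
    "data.input.clinical.all": torch.randn(2, 19),
}
# after the wrapped forward pass the batch dict is expected to also hold:
#   "model.logits.head_0" -> raw MMViT logits,      shape [2, 8]
#   "model.output.head_0" -> softmax class scores,  shape [2, 8]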


def create_cnn_model(
    dropout_rate: float,
    layers_description: Sequence[int],
    tabular_data_inputs: Sequence[Tuple[str, int]],
@@ -216,7 +283,10 @@ def run_train(paths: dict, train_common_params: dict) -> None:
# ==============================================================================
lgr.info("Model:", {"attrs": "bold"})

model = create_model(**train_common_params["model"])
if model_type == "Transformer":
    model = create_transformer_model(**train_common_params["model"])
elif model_type == "CNN":
    model = create_cnn_model(**train_common_params["model"])

lgr.info("Model: Done", {"attrs": "bold"})

@@ -339,7 +409,11 @@ def run_infer(paths: dict, infer_common_params: dict) -> None:
)

# load the PyTorch Lightning module
model = create_model(**infer_common_params["model"])
if model_type == "Transformer":
    model = create_transformer_model(**infer_common_params["model"])
elif model_type == "CNN":
    model = create_cnn_model(**infer_common_params["model"])

pl_module = LightningModuleDefault.load_from_checkpoint(
    checkpoint_file, model_dir=paths["model_dir"], model=model, map_location="cpu", strict=True
)
49 changes: 48 additions & 1 deletion fuse/dl/lightning/pl_funcs.py
@@ -20,7 +20,7 @@
Collection of useful functions to implement FuseMedML pytorch lightning based module and train loop
"""
import traceback
from typing import Any, Dict, List, OrderedDict, Sequence, Union
from typing import Any, Dict, List, OrderedDict, Sequence, Union, Mapping, TypeVar
from statistics import mean
from fuse.data.utils.sample import get_sample_id_key
from fuse.utils.data.collate import uncollate
@@ -36,6 +36,53 @@
from fuse.eval import MetricBase
from fuse.eval.metrics.utils import PerSampleData

# for clearml
from clearml import Task

TaskInstance = TypeVar("TaskInstance", bound="Task")


def start_clearml_logger(
    project_name: Union[str, None],
    task_name: Union[str, None],
    tags: Union[Sequence[str], None] = None,
    reuse_last_task_id: Union[bool, str] = True,
    continue_last_task: Union[bool, str, int] = False,
    output_uri: Union[str, bool, None] = None,
    auto_connect_arg_parser: Union[bool, Mapping[str, bool]] = True,
    auto_connect_frameworks: Union[bool, Mapping[str, bool]] = True,
    auto_resource_monitoring: bool = True,
    auto_connect_streams: Union[bool, Mapping[str, bool]] = True,
    deferred_init: bool = False,
) -> TaskInstance:
    """
    A Fuse convenience function to quickly start the ClearML logger. It patches the PyTorch Lightning logging hooks,
    so it does not need to be passed to any Lightning logger.
    For information on all the arguments see: https://clear.ml/docs/latest/docs/references/sdk/task/ or https://github.com/allegroai/clearml/blob/master/clearml/task.py
    General ClearML instructions:
    Unless using offline mode, you must first create an account at https://app.clear.ml/login?redirect=%2Fsettings%2Fworkspace-configuration
    and then create a ~/clearml.conf file specifying the server address as shown in https://clear.ml/docs/latest/docs/configs/clearml_conf/.
    Offline-mode instructions can be found here: https://clear.ml/docs/latest/docs/guides/set_offline/
    Example usage:
        from fuse.dl.lightning.pl_funcs import start_clearml_logger
        start_clearml_logger(project_name="my_project_name", task_name="test_01")
    """
    task = Task.init(
        project_name=project_name,
        task_name=task_name,
        tags=tags,
        reuse_last_task_id=reuse_last_task_id,
        continue_last_task=continue_last_task,
        output_uri=output_uri,
        auto_connect_arg_parser=auto_connect_arg_parser,
        auto_connect_frameworks=auto_connect_frameworks,
        auto_resource_monitoring=auto_resource_monitoring,
        auto_connect_streams=auto_connect_streams,
        deferred_init=deferred_init,
    )
    return task
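
A short usage sketch following the docstring above (project and task names here are hypothetical); because the Lightning hooks are patched, a single call before building the trainer is enough:

from fuse.dl.lightning.pl_funcs import start_clearml_logger

task = start_clearml_logger(
    project_name="isic_vit",          # hypothetical project name
    task_name="transformer_run_01",   # hypothetical task name
    tags=["isic", "vit"],
)
# training then proceeds as usual; metrics logged via the Lightning module are picked up by ClearML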


def model_checkpoint_callbacks(model_dir: str, best_epoch_source: Union[Dict, List[Dict]]) -> List[pl.Callback]:
"""
3 changes: 2 additions & 1 deletion fuse/requirements.txt
@@ -28,4 +28,5 @@ hydra-core
omegaconf
nibabel
vit-pytorch
lifelines
lifelines
clearml
