Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add vit isic #221

Merged
merged 14 commits into from Nov 29, 2022
94 changes: 84 additions & 10 deletions examples/fuse_examples/imaging/classification/isic/isic_runner.py
Expand Up @@ -52,6 +52,9 @@
from fuseimg.datasets.isic import ISIC, ISICDataModule
from fuse_examples.imaging.classification.isic.golden_members import FULL_GOLDEN_MEMBERS

import torch.nn as nn
from fuse.dl.models.model_wrapper import ModelWrapSeqToDict
from fuse.dl.models.backbones.backbone_vit import ViT

###########################################################################################################
# Fuse
Expand All @@ -71,9 +74,13 @@
##########################################
# Modality
##########################################

multimodality = True # Set: 'False' to use only imaging, 'True' to use imaging & meta-data

##########################################
# Model Type
##########################################
model_type = "Transformer" # Set: 'Transformer' to use ViT/MMViT, 'CNN' to use InceptionResNet

##########################################
# Output Paths
##########################################
Expand Down Expand Up @@ -126,15 +133,75 @@
# ===============
# Model
# ===============
TRAIN_COMMON_PARAMS["model"] = dict(
dropout_rate=0.5,
layers_description=(256,),
tabular_data_inputs=[("data.input.clinical.all", 19)] if multimodality else None,
tabular_layers_description=(128,) if multimodality else tuple(),
)
# Select model hyper-parameters according to the backbone chosen via `model_type`.
# The resulting dict is unpacked into create_cnn_model / create_transformer_model.
if model_type == "CNN":
    TRAIN_COMMON_PARAMS["model"] = dict(
        dropout_rate=0.5,
        layers_description=(256,),
        # 19 = width of the clinical feature vector — presumably matches
        # "data.input.clinical.all"; TODO confirm against the dataset definition
        tabular_data_inputs=[("data.input.clinical.all", 19)] if multimodality else None,
        tabular_layers_description=(128,) if multimodality else tuple(),
    )
elif model_type == "Transformer":
    # embedding dimension of each token (ViT-Base uses 768)
    token_dim = 768
    TRAIN_COMMON_PARAMS["model"] = dict(
        token_dim=token_dim,
        # 300x300 images split into 30x30 patches -> 10x10 = 100 image tokens
        projection_kwargs=dict(image_shape=[300, 300], patch_shape=[30, 30], channels=3),
        transformer_kwargs=dict(depth=12, heads=12, mlp_dim=token_dim * 4, dim_head=64, dropout=0.0, emb_dropout=0.0),
    )


def create_model(
def perform_softmax(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Post-forward hook: keep the raw logits and add per-class probabilities.

    :param logits: un-normalized class scores of shape (batch, num_classes)
    :return: tuple of (logits, probabilities), probabilities = softmax over dim 1
    """
    probabilities = logits.softmax(dim=1)
    return logits, probabilities


class MMViT(ViT):
    """Multimodal Vision Transformer: image patch tokens plus an optional clinical-data token.

    When ``multimodality`` is True, the clinical feature vector is zero-padded to
    ``token_dim``, appended as one extra token, and the positional embedding is
    enlarged by one position to accommodate it.
    """

    def __init__(
        self,
        token_dim: int,
        projection_kwargs: dict,
        transformer_kwargs: dict,
        multimodality: bool,
        num_classes: int = 8,
    ):
        """
        :param token_dim: embedding dimension of each token
        :param projection_kwargs: kwargs for the patch projection layer (see ViT)
        :param transformer_kwargs: kwargs for the transformer encoder (see ViT)
        :param multimodality: if True, forward() expects a clinical feature vector too
        :param num_classes: output size of the classification head (default 8, ISIC classes)
        """
        super().__init__(token_dim, projection_kwargs, transformer_kwargs)
        self.multimodality = multimodality
        self.token_dim = token_dim
        self._head = nn.Linear(token_dim, num_classes)
        if self.multimodality:
            # Enlarge the positional embedding to accept the additional clinical token
            # (num_tokens image tokens + CLS token + clinical token).
            num_tokens = self.projection_layer.num_tokens
            self.transformer.pos_embedding = nn.Parameter(torch.randn(1, num_tokens + 2, token_dim))

    # This forward can be Multimodal or just Imaging
    def forward(self, img_x: torch.Tensor, clinical_x: torch.Tensor = None) -> torch.Tensor:
        """Classify an image batch, optionally fused with clinical data.

        :param img_x: image batch; projected to patch tokens by the projection layer
        :param clinical_x: optional (batch, num_features) clinical vector with
            num_features <= token_dim; required when multimodality is True
        :return: (batch, num_classes) classification logits
        """
        img_x = self.projection_layer(img_x)
        if self.multimodality:
            # Zero-pad the clinical features up to token_dim and append as one token.
            # Allocate on the same device/dtype as the image tokens instead of the
            # previous hard-coded .cuda(), so CPU runs (e.g. inference loaded with
            # map_location="cpu") work as well.
            clinical_x = clinical_x.unsqueeze(1)
            clinical_token = torch.zeros(
                (img_x.shape[0], 1, self.token_dim), device=img_x.device, dtype=img_x.dtype
            )
            clinical_token[:, :, : clinical_x.shape[-1]] = clinical_x
            x = torch.cat((img_x, clinical_token), 1)
        else:
            x = img_x
        x = self.transformer(x)
        # Classify from the first (CLS) token only.
        x = self._head(x[:, 0])
        return x


def create_transformer_model(
    token_dim: int,
    projection_kwargs: dict,
    transformer_kwargs: dict,
) -> ModelWrapSeqToDict:
    """Build a (multimodal) ViT and wrap it to read/write batch-dict keys.

    :param token_dim: embedding dimension per token
    :param projection_kwargs: patch-projection settings forwarded to MMViT
    :param transformer_kwargs: transformer-encoder settings forwarded to MMViT
    :return: wrapped model emitting "model.logits.head_0" and "model.output.head_0"
    """
    backbone = MMViT(
        token_dim=token_dim,
        projection_kwargs=projection_kwargs,
        transformer_kwargs=transformer_kwargs,
        multimodality=multimodality,
    )
    # Imaging is always fed in; the clinical vector only in the multimodal setup.
    input_keys = ["data.input.img"]
    if multimodality:
        input_keys.append("data.input.clinical.all")
    return ModelWrapSeqToDict(
        model=backbone,
        model_inputs=input_keys,
        post_forward_processing_function=perform_softmax,
        model_outputs=["model.logits.head_0", "model.output.head_0"],
    )


def create_cnn_model(
dropout_rate: float,
layers_description: Sequence[int],
tabular_data_inputs: Sequence[Tuple[str, int]],
Expand Down Expand Up @@ -216,7 +283,10 @@ def run_train(paths: dict, train_common_params: dict) -> None:
# ==============================================================================
lgr.info("Model:", {"attrs": "bold"})

model = create_model(**train_common_params["model"])
if model_type == "Transformer":
model = create_transformer_model(**train_common_params["model"])
elif model_type == "CNN":
model = create_cnn_model(**train_common_params["model"])

lgr.info("Model: Done", {"attrs": "bold"})

Expand Down Expand Up @@ -339,7 +409,11 @@ def run_infer(paths: dict, infer_common_params: dict) -> None:
)

# load python lightning module
model = create_model(**infer_common_params["model"])
if model_type == "Transformer":
model = create_transformer_model(**infer_common_params["model"])
elif model_type == "CNN":
model = create_cnn_model(**infer_common_params["model"])

pl_module = LightningModuleDefault.load_from_checkpoint(
checkpoint_file, model_dir=paths["model_dir"], model=model, map_location="cpu", strict=True
)
Expand Down
49 changes: 48 additions & 1 deletion fuse/dl/lightning/pl_funcs.py
Expand Up @@ -20,7 +20,7 @@
Collection of useful functions to implement FuseMedML pytorch lightning based module and train loop
"""
import traceback
from typing import Any, Dict, List, OrderedDict, Sequence, Union
from typing import Any, Dict, List, OrderedDict, Sequence, Union, Mapping, TypeVar
from statistics import mean
from fuse.data.utils.sample import get_sample_id_key
from fuse.utils.data.collate import uncollate
Expand All @@ -36,6 +36,53 @@
from fuse.eval import MetricBase
from fuse.eval.metrics.utils import PerSampleData

# for clearml
from clearml import Task
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add clearml to dependency list

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added to the dependency list in 31a8f07.


TaskInstance = TypeVar("TaskInstance", bound="Task")


def start_clearml_logger(
    project_name: Union[str, None],
    task_name: Union[str, None],
    tags: Union[Sequence[str], None] = None,
    reuse_last_task_id: Union[bool, str] = True,
    continue_last_task: Union[bool, str, int] = False,
    output_uri: Union[str, bool, None] = None,
    auto_connect_arg_parser: Union[bool, Mapping[str, bool]] = True,
    auto_connect_frameworks: Union[bool, Mapping[str, bool]] = True,
    auto_resource_monitoring: bool = True,
    auto_connect_streams: Union[bool, Mapping[str, bool]] = True,
    deferred_init: bool = False,
) -> TaskInstance:
    """Start the ClearML logger with a single call.

    ClearML patches the PyTorch Lightning logging hooks itself, so the returned
    task doesn't need to be passed to any Lightning logger.
    Full argument reference: https://clear.ml/docs/latest/docs/references/sdk/task/
    or https://github.com/allegroai/clearml/blob/master/clearml/task.py

    General ClearML instructions:
    Unless using offline mode, first create an account at
    https://app.clear.ml/login?redirect=%2Fsettings%2Fworkspace-configuration and
    create a ~/clearml.conf file specifying the server address as shown at
    https://clear.ml/docs/latest/docs/configs/clearml_conf/.
    Offline-mode instructions: https://clear.ml/docs/latest/docs/guides/set_offline/

    Example usage:
        from fuse.dl.lightning.pl_funcs import start_clearml_logger
        start_clearml_logger(project_name="my_project_name", task_name="test_01")
    """
    init_kwargs = dict(
        project_name=project_name,
        task_name=task_name,
        tags=tags,
        reuse_last_task_id=reuse_last_task_id,
        continue_last_task=continue_last_task,
        output_uri=output_uri,
        auto_connect_arg_parser=auto_connect_arg_parser,
        auto_connect_frameworks=auto_connect_frameworks,
        auto_resource_monitoring=auto_resource_monitoring,
        auto_connect_streams=auto_connect_streams,
        deferred_init=deferred_init,
    )
    return Task.init(**init_kwargs)


def model_checkpoint_callbacks(model_dir: str, best_epoch_source: Union[Dict, List[Dict]]) -> List[pl.Callback]:
"""
Expand Down
3 changes: 2 additions & 1 deletion fuse/requirements.txt
Expand Up @@ -28,4 +28,5 @@ hydra-core
omegaconf
nibabel
vit-pytorch
lifelines
lifelines
clearml