Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UNETR - Make skip connections optional #9

Merged
merged 7 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions experiments/vision-transformer/unetr/livecell/train_by_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def prune_prefix(checkpoint_path):
return updated_model_state


def get_custom_unetr_model(device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder):
def get_custom_unetr_model(
device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder, joint_training
):
if checkpoint_path is not None:
if checkpoint_path.endswith("pt"): # for finetuned models
model_state = prune_prefix(checkpoint_path)
Expand All @@ -37,8 +39,10 @@ def get_custom_unetr_model(device, model_name, sam_initialization, output_channe
out_channels=output_channels,
use_sam_stats=sam_initialization,
final_activation="Sigmoid",
encoder_checkpoint=model_state
encoder_checkpoint=model_state,
use_skip_connection=not joint_training # if joint_training, no skip con. else, use skip con. by default
)

model.to(device)

# if expected, let's freeze the image encoder
Expand Down Expand Up @@ -66,7 +70,7 @@ def main(args):
# get the custom model for the training and inference on livecell dataset
model = get_custom_unetr_model(
device, args.model_name, sam_initialization=args.do_sam_ini, output_channels=3,
checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder
checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder, joint_training=args.joint_training
)

# determining where to save the checkpoints and tensorboard logs
Expand Down Expand Up @@ -123,5 +127,8 @@ def main(args):
parser.add_argument(
"--freeze_encoder", action="store_true", help="Experiments to freeze the encoder."
)
parser.add_argument(
"--joint_training", action="store_true", help="Uses VNETR for training"
)
args = parser.parse_args()
main(args)
41 changes: 28 additions & 13 deletions torch_em/model/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ def __init__(
use_mae_stats: bool = False,
encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
final_activation: Optional[Union[str, nn.Module]] = None,
use_skip_connection: bool = True
) -> None:
super().__init__()

self.use_sam_stats = use_sam_stats
self.use_mae_stats = use_mae_stats
self.use_skip_connection = use_skip_connection

print(f"Using {encoder} from {backbone.upper()}")
self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
Expand Down Expand Up @@ -107,10 +109,12 @@ def __init__(
self.deconv1 = Deconv2DBlock(self.encoder.embed_dim, features_decoder[0])
self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

self.deconv4 = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])

self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

self.decoder_head = ConvBlock2d(2*features_decoder[-1], features_decoder[-1])
self.final_activation = self._get_activation(final_activation)

def _get_activation(self, activation):
Expand Down Expand Up @@ -167,26 +171,37 @@ def forward(self, x):
# backbone used for reshaping inputs to the desired "encoder" shape
x = torch.stack([self.preprocess(e) for e in x], dim=0)

z0 = self.z_inputs(x)
use_skip_connection = getattr(self, "use_skip_connection", True)

z12, from_encoder = self.encoder(x)
x = self.base(z12)

from_encoder = from_encoder[::-1]
z9 = self.deconv1(from_encoder[0])
if use_skip_connection:
# TODO: we share the weights in the deconv(s), and should preferably avoid doing that
from_encoder = from_encoder[::-1]
z9 = self.deconv1(from_encoder[0])

z6 = self.deconv1(from_encoder[1])
z6 = self.deconv2(z6)

z6 = self.deconv1(from_encoder[1])
z6 = self.deconv2(z6)
z3 = self.deconv1(from_encoder[2])
z3 = self.deconv2(z3)
z3 = self.deconv3(z3)
anwai98 marked this conversation as resolved.
Show resolved Hide resolved

z3 = self.deconv1(from_encoder[2])
z3 = self.deconv2(z3)
z3 = self.deconv3(z3)
z0 = self.z_inputs(x)

else:
z9 = self.deconv1(z12)
z6 = self.deconv2(z9)
z3 = self.deconv3(z6)
z0 = self.deconv4(z3)

updated_from_encoder = [z9, z6, z3]

x = self.base(z12)
x = self.decoder(x, encoder_inputs=updated_from_encoder)
x = self.deconv4(x)
x = torch.cat([x, z0], dim=1)
x = self.deconv_out(x)

x = torch.cat([x, z0], dim=1)
x = self.decoder_head(x)

x = self.out_conv(x)
Expand Down