Commit
Merge branch 'master' of github.com:josegcpa/adell-mri
josegcpa committed Apr 29, 2024
2 parents 5742e7f + 216a83d commit f9be14c
Showing 6 changed files with 181 additions and 51 deletions.
1 change: 1 addition & 0 deletions adell_mri/entrypoints/segmentation/train.py
@@ -613,6 +613,7 @@ def train_loader_call(batch_size):
batch_size=network_config["batch_size"],
shuffle=False,
sampler=val_sampler,
drop_last=args.semi_supervised,
num_workers=nw,
collate_fn=collate_fn_train,
)
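
One effect of `drop_last=args.semi_supervised` is that the contrastive terms below always see full batches: a derangement is undefined for a single element, and the anchored KL terms compare tensors batch-wise. A minimal sketch of the behaviour, using an illustrative toy dataset rather than repository code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# With drop_last=True the trailing partial batch is discarded, so every
# batch reaching the loss has the same size.
dataset = TensorDataset(torch.randn(10, 1, 8, 8))
loader = DataLoader(dataset, batch_size=4, drop_last=True)
print([batch[0].shape[0] for batch in loader])  # [4, 4] -- the size-2 remainder is dropped
```
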
88 changes: 60 additions & 28 deletions adell_mri/modules/semi_supervised_segmentation/losses.py
@@ -34,6 +34,19 @@ def derangement(
return xs


def anchors_from_derangement(
X: torch.Tensor, rng: np.random.Generator = None
) -> torch.Tensor:
"""Generate anchors from a derangement of the elements of X."""
if rng is None:
rng = np.random.default_rng()
anchors = []
for idx in derangement(X.shape[0], rng=rng):
anchors.append(X[idx])
anchors = torch.stack(anchors)
return anchors
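
The body of `derangement` lies above this hunk; judging from its call sites (`derangement(X.shape[0], rng=rng)` yielding batch indices with no fixed point), a minimal rejection-sampling sketch could look like the following. This is an assumption about the helper, not the repository's implementation:

```python
import numpy as np

def derangement_sketch(n: int, rng: np.random.Generator = None) -> np.ndarray:
    """Sample a permutation of range(n) with no fixed point (assumes n >= 2).

    Rejection sampling succeeds with probability ~1/e per draw, so the
    expected number of attempts is small (~2.72).
    """
    if rng is None:
        rng = np.random.default_rng()
    while True:
        xs = rng.permutation(n)
        if not np.any(xs == np.arange(n)):
            return xs
```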


class AnatomicalContrastiveLoss(torch.nn.Module):
"""
Implementation of the anatomical loss method suggested in "Bootstrapping
@@ -292,32 +305,6 @@ def __init__(self, temperature: float = 0.1, seed: int = 42):
self.rng = np.random.default_rng(seed)
self.eps = torch.as_tensor(1e-8)

def forward_with_anchors(
self,
X_1: torch.Tensor,
X_2: torch.Tensor,
anchors: torch.Tensor = None,
) -> torch.Tensor:
if anchors is None:
# if no anchors are provided, the anchors are a random derangement
# of the input
anchors = []
for idx in derangement(X_1.shape[0], rng=self.rng):
anchors.append(
X_1[idx] if self.rng.random() < 0.5 else X_2[idx]
)
anchors = torch.stack(anchors)
X_1 = X_1.flatten(start_dim=2)
X_2 = X_2.flatten(start_dim=2)
anchors = anchors.flatten(start_dim=2)
sim_1 = F.cosine_similarity(X_1, anchors, dim=1) / self.temperature
sim_2 = F.cosine_similarity(X_2, anchors, dim=1) / self.temperature
return F.kl_div(
F.softmax(sim_1, dim=1),
F.softmax(sim_2, dim=1),
reduction="none",
)

def forward(
self,
X_1: torch.Tensor,
@@ -326,12 +313,57 @@
) -> torch.Tensor:
# based on LoCo [1]
# [1] https://proceedings.neurips.cc/paper/2020/file/7fa215c9efebb3811a7ef58409907899-Paper.pdf
if anchors is not None:
return self.forward_with_anchors(X_1, X_2, anchors)
X_1 = X_1.flatten(start_dim=2)[None, :, :, :]
X_2 = X_2.flatten(start_dim=2)[:, None, :, :]
sim = F.cosine_similarity(X_1, X_2, dim=2) / self.temperature
loss = -torch.log(
torch.max(F.softmax(sim, dim=1).diagonal().permute(1, 0), self.eps)
).mean(-1)
return loss
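
A shape walkthrough of the LoCo-style loss above, with illustrative sizes (`clamp_min` standing in for the `torch.max(..., self.eps)` guard):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 4, 16, 8, 8                                      # N = H * W = 64
X_1 = torch.randn(B, C, H, W).flatten(start_dim=2)[None]      # [1, B, C, N]
X_2 = torch.randn(B, C, H, W).flatten(start_dim=2)[:, None]   # [B, 1, C, N]
sim = F.cosine_similarity(X_1, X_2, dim=2) / 0.1              # [B, B, N] pairwise
probs = F.softmax(sim, dim=1)                                 # normalise over pairings
positives = probs.diagonal().permute(1, 0)                    # [B, N]: matched pairs
loss = -torch.log(positives.clamp_min(1e-8)).mean(-1)         # [B] per-sample loss
```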


class LocalContrastiveLossWithAnchors(torch.nn.Module):
"""
Implements a local contrastive loss function.
"""

def __init__(self, temperature: float = 0.1, seed: int = 42):
super().__init__()
self.temperature = temperature
self.seed = seed
self.rng = np.random.default_rng(seed)
self.eps = torch.as_tensor(1e-8)

def anchors_from_derangement(self, X: torch.Tensor) -> torch.Tensor:
anchors = []
for idx in derangement(X.shape[0], rng=self.rng):
anchors.append(X[idx])
anchors = torch.stack(anchors)
return anchors

def forward(
self,
X: torch.Tensor,
anchors_1: torch.Tensor = None,
anchors_2: torch.Tensor = None,
) -> torch.Tensor:
anchors_1 = (
anchors_from_derangement(X, self.rng)
if anchors_1 is None
else anchors_1
)
anchors_2 = (
anchors_from_derangement(X, self.rng)
if anchors_2 is None
else anchors_2
)
X = X.flatten(start_dim=2)
anchors_1 = anchors_1.flatten(start_dim=2)
anchors_2 = anchors_2.flatten(start_dim=2)
sim_1 = F.cosine_similarity(X, anchors_1, dim=1) / self.temperature
sim_2 = F.cosine_similarity(X, anchors_2, dim=1) / self.temperature
return F.kl_div(
# F.kl_div expects log-probabilities as its first argument
F.log_softmax(sim_1, dim=1),
F.softmax(sim_2, dim=1),
reduction="none",
).sum(-1)
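
A hedged usage sketch of the new anchored loss; shapes are illustrative. With `anchors_2` omitted, the second anchor set is drawn as a derangement of `X`, so each sample is also contrasted against another sample's features:

```python
import torch

loss_fn = LocalContrastiveLossWithAnchors(temperature=0.1, seed=42)
X = torch.randn(4, 16, 8, 8)        # [batch, channels, *spatial] feature maps
anchors = torch.randn(4, 16, 8, 8)  # e.g. teacher/EMA features for the same batch

loss = loss_fn(X, anchors)          # per-sample KL terms, shape [4]
print(loss.shape)
```
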
105 changes: 87 additions & 18 deletions adell_mri/modules/semi_supervised_segmentation/pl.py
@@ -1,5 +1,5 @@
import numpy as np
import torch

from ..segmentation.pl import UNetBasePL, update_metrics
from .unet import UNetSemiSL

@@ -124,7 +124,7 @@ def __init__(
self.all_true = []

self.bn_mult = 0.1
self.ssl_weight = 0.1
self.ssl_weight = 0.01

if (
self.semi_sl_image_key_1 is not None
@@ -165,41 +165,110 @@ def forward_features_ema_stop_grad(self, **kwargs):
else:
return op(**kwargs)

def calculate_loss_semi_sl(
self, output_1: torch.Tensor, output_2: torch.Tensor
def coerce_batch_size(self, *tensors):
batch_sizes = [
x.shape[0] if x is not None else np.inf for x in tensors
]
min_batch_size = min(batch_sizes)
tensors = [
x[:min_batch_size] if x is not None else None
for x, bs in zip(tensors, batch_sizes)
]
return tensors
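
The coercion trims every tensor to the smallest batch present and leaves `None` entries alone, so downstream similarity/KL terms line up element-wise. Schematically, with illustrative tensors:

```python
import numpy as np
import torch

tensors = [torch.randn(6, 1, 8, 8), torch.randn(4, 1, 8, 8), None]
sizes = [t.shape[0] if t is not None else np.inf for t in tensors]
m = min(sizes)                                               # 4
trimmed = [t[:m] if t is not None else None for t in tensors]
print([None if t is None else t.shape[0] for t in trimmed])  # [4, 4, None]
```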

def step_semi_sl_loco(
self,
x_1: torch.Tensor,
x_2: torch.Tensor,
x_cond: torch.Tensor,
x_fc: torch.Tensor,
*args,
**kwargs,
):
loss = self.loss_fn_semi_sl(output_1, output_2)
return loss.mean() * self.ssl_weight
features_1 = self.forward_features(
X=x_1,
X_skip_layer=x_cond,
X_feature_conditioning=x_fc,
)
features_2 = self.forward_features_ema_stop_grad(
X=x_2,
X_skip_layer=x_cond,
X_feature_conditioning=x_fc,
apply_linear_transformation=True,
)

return (
self.loss_fn_semi_sl(
features_1, features_2, *args, **kwargs
).mean()
* self.ssl_weight
)

def loss_wrapper_semi_sl(
def step_semi_sl_anchors(
self,
x: torch.Tensor,
x_1: torch.Tensor,
x_2: torch.Tensor,
x_cond: torch.Tensor,
x_fc: torch.Tensor,
*args,
**kwargs,
):
output_1 = self.forward_features(
X=x_1, X_skip_layer=x_cond, X_feature_conditioning=x_fc
x, x_1, x_2, x_cond, x_fc = self.coerce_batch_size(
x, x_1, x_2, x_cond, x_fc
)
with torch.no_grad():
anchor_1 = (
self.forward_features(
X=x_1,
X_skip_layer=x_cond,
X_feature_conditioning=x_fc,
)
if x_1 is not None
else None
)
anchor_2 = (
self.forward_features_ema_stop_grad(
X=x_2,
X_skip_layer=x_cond,
X_feature_conditioning=x_fc,
apply_linear_transformation=True,
)
if x_2 is not None
else None
)
features = self.forward_features(
X=x, X_skip_layer=x_cond, X_feature_conditioning=x_fc
)
output_2 = self.forward_features_ema_stop_grad(
X=x_2, X_skip_layer=x_cond, X_feature_conditioning=x_fc
return (
self.loss_fn_semi_sl(
features, anchor_1, anchor_2, *args, **kwargs
).mean()
* self.ssl_weight
)
return self.calculate_loss_semi_sl(output_1, output_2)

def step_semi_sl(
self,
x: torch.Tensor,
x_1: torch.Tensor,
x_2: torch.Tensor,
x_cond: torch.Tensor,
x_fc: torch.Tensor,
*args,
**kwargs,
):
loss_a = self.loss_wrapper_semi_sl(x_1, x_2, x_cond, x_fc)
# loss_b = self.loss_wrapper_semi_sl(x_2, x_1, x_cond, x_fc)
loss = loss_a # + loss_b
return loss
if x is not None:
return self.step_semi_sl_anchors(
x, x_1, x_2, x_cond, x_fc, *args, **kwargs
)
else:
return self.step_semi_sl_loco(
x_1, x_2, x_cond, x_fc, *args, **kwargs
)

def training_step(self, batch, batch_idx):
# supervised bit
x = None
if self.label_key is not None:
x, x_cond, x_fc, y, y_class = self.unpack_batch(batch)
pred_final, pred_class, loss, class_loss = self.step(
@@ -216,7 +285,7 @@ def training_step(self, batch, batch_idx):
and self.semi_sl_image_key_2 is not None
):
x_1, x_2, x_cond, x_fc = self.unpack_batch_semi_sl(batch)
self_sl_loss = self.step_semi_sl(x_1, x_2, x_cond, x_fc)
self_sl_loss = self.step_semi_sl(x, x_1, x_2, x_cond, x_fc)
self.log(
"train_self_sl_loss",
self_sl_loss,
@@ -311,7 +380,7 @@ def validation_step(self, batch, batch_idx):
and self.semi_sl_image_key_2 is not None
):
x_1, x_2, x_cond, x_fc = self.unpack_batch_semi_sl(batch)
self_sl_loss = self.step_semi_sl(x_1, x_2, x_cond, x_fc)
self_sl_loss = self.step_semi_sl(x, x_1, x_2, x_cond, x_fc)
self.log(
"val_self_sl_loss",
self_sl_loss,
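
Both anchor sets in `step_semi_sl_anchors` are computed under `torch.no_grad`, so gradients only flow through the labelled image's features; this is the stop-gradient pattern familiar from mean-teacher/BYOL-style training. A schematic of the gradient flow, using a toy linear model and MSE in place of the repository's modules and loss:

```python
import torch

student = torch.nn.Linear(8, 8)
teacher = torch.nn.Linear(8, 8)     # stands in for the EMA copy

x, view = torch.randn(4, 8), torch.randn(4, 8)
with torch.no_grad():
    anchor = teacher(view)          # constant target: no gradient path
feat = student(x)                   # gradients flow through this branch only
loss = torch.nn.functional.mse_loss(feat, anchor)
loss.backward()                     # teacher parameters receive no gradient
```
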
28 changes: 26 additions & 2 deletions adell_mri/modules/semi_supervised_segmentation/unet.py
@@ -11,6 +11,28 @@ class UNetSemiSL(UNet):
`return_features`.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.init_linear_transformation()

def init_linear_transformation(self, *args, **kwargs):
"""
Initialise the linear transformation layer for predictions.
"""
if self.spatial_dimensions == 2:
self.linear_transformation = torch.nn.Conv2d(
self.depth[0],
self.depth[0],
kernel_size=1,
)
elif self.spatial_dimensions == 3:
self.linear_transformation = torch.nn.Conv3d(
self.depth[0],
self.depth[0],
kernel_size=1,
)

def forward(
self,
X: torch.Tensor,
@@ -105,6 +127,7 @@ def forward_features(
X: torch.Tensor,
X_skip_layer: torch.Tensor = None,
X_feature_conditioning: torch.Tensor = None,
apply_linear_transformation: bool = False,
) -> torch.Tensor:
"""Forward pass for this class.
@@ -131,7 +154,6 @@ def forward_features(
encoding_out.append(curr)
curr = op_ds(curr)

deep_outputs = []
for i in range(len(self.decoding_operations)):
op = self.decoding_operations[i]
link_op = self.link_ops[i]
@@ -159,6 +181,8 @@ def forward_features(
curr = crop_to_size(curr, sh2)
curr = torch.concat((curr, encoded), dim=1)
curr = op(curr)
deep_outputs.append(curr)

if apply_linear_transformation is True:
curr = self.linear_transformation(curr)

return curr
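
The new `linear_transformation` is a 1×1 convolution over the decoder's final features, i.e. a learned per-voxel linear map across channels, applied to the student branch when `apply_linear_transformation=True`. A sketch of the equivalence with a plain `Linear` layer, with illustrative sizes:

```python
import torch

conv = torch.nn.Conv3d(16, 16, kernel_size=1)   # as in init_linear_transformation
x = torch.randn(2, 16, 4, 4, 4)
y = conv(x)                                     # channels mixed, spatial shape kept

lin = torch.nn.Linear(16, 16)                   # the same map at every voxel
lin.weight.data = conv.weight.data.view(16, 16)
lin.bias.data = conv.bias.data
y_ref = lin(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
print(torch.allclose(y, y_ref, atol=1e-5))      # True
```
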
4 changes: 3 additions & 1 deletion adell_mri/utils/dataset.py
@@ -56,7 +56,9 @@ def subsample_dataset(
data_dict = {k: data_dict[k] for k in ss}
else:
s = int(subsample_size * len(data_dict))
ss = rng.choice(
list(data_dict.keys()), s, replace=False
)
data_dict = {k: data_dict[k] for k in ss}
return data_dict

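
The fix replaces a single-key draw with a without-replacement sample of the requested size. For example, with an illustrative key list:

```python
import numpy as np

rng = np.random.default_rng(42)
keys = ["a", "b", "c", "d", "e"]

print(rng.choice(keys))                    # before: one key only
print(rng.choice(keys, 3, replace=False))  # after: 3 distinct keys
```
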
6 changes: 4 additions & 2 deletions adell_mri/utils/network_factories.py
@@ -38,7 +38,9 @@

# semi-supervised segmentation
from ..modules.semi_supervised_segmentation.pl import UNetContrastiveSemiSL
from ..modules.semi_supervised_segmentation.losses import LocalContrastiveLoss
from ..modules.semi_supervised_segmentation.losses import (
LocalContrastiveLossWithAnchors,
)
from ..utils import ExponentialMovingAverage

# self-supervised learning
@@ -404,7 +406,7 @@ def get_size(*size_list):
semi_sl_image_key_2="semi_sl_image_2",
deep_supervision=deep_supervision,
ema=ema,
loss_fn_semi_sl=LocalContrastiveLoss(seed=seed),
loss_fn_semi_sl=LocalContrastiveLossWithAnchors(seed=seed),
**boilerplate,
**network_config,
)