Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/membrain_pick/clustering/mean_shift_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def mean_shift_for_scores(
)
else:
raise ValueError("Unknown method for mean shift clustering.")
print("Found", out_pos.shape[0], "clusters.")
return out_pos, out_p_num


Expand Down
2 changes: 1 addition & 1 deletion src/membrain_pick/dataloading/mesh_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def load_from_cache(cur_cache_path: str) -> Optional[Dict[str, np.ndarray]]:
The loaded partitioning data if successful, None otherwise.
"""
if os.path.isfile(cur_cache_path):
print(f"Loading partitioning data from {cur_cache_path}")
# print(f"Loading partitioning data from {cur_cache_path}")
return np.load(cur_cache_path)
else:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def mesh_for_tomo_mb_folder(
mb_files = [
os.path.join(mb_folder, f) for f in os.listdir(mb_folder) if f.endswith(".mrc")
]
print(mb_files)

if tomo is None:
tomo = load_tomogram(tomo_file)
Expand Down
1 change: 0 additions & 1 deletion src/membrain_pick/napari_utils/surforama_cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def normalize_surface_values(surface_values, value_range=None):
np.percentile(surface_values, cutoff_pct * 100),
np.percentile(surface_values, (1 - cutoff_pct) * 100),
)
print("Normalized value range: ", value_range)
normalized_values = (surface_values - value_range[0]) / (
value_range[1] - value_range[0] + np.finfo(float).eps
)
Expand Down
4 changes: 2 additions & 2 deletions src/membrain_pick/networks/diffusion_net/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def get_operators(verts, faces, k_eig=128, op_cache_dir=None, normals=None, over
# If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
# entry below more eigenvalues
if overwrite_cache:
print(" overwriting cache by request")
# print(" overwriting cache by request")
os.remove(search_path)
break

Expand Down Expand Up @@ -516,7 +516,7 @@ def read_sp_mat(prefix):
break

except FileNotFoundError:
print(" cache miss -- constructing operators")
# print(" cache miss -- constructing operators")
break

except Exception as E:
Expand Down
22 changes: 21 additions & 1 deletion src/membrain_pick/optimization/diffusion_training_pylit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytorch_lightning as pl
import torch
from torch.optim import Adam, SGD
import matplotlib.pyplot as plt

from membrain_pick.networks.diffusion_net import DiffusionNet
from membrain_pick.clustering.mean_shift_utils import MeanShiftForwarder
Expand Down Expand Up @@ -30,11 +31,17 @@ def __init__(self,
fixed_time=None,
one_D_conv_first=False,
clamp_diffusion=False,
out_plot_file=None,
visualize_diffusion=False,
visualize_grad_rotations=False,
visualize_grad_features=False):
super().__init__()
self.max_epochs = max_epochs
self.epoch_losses = {
"train": [],
"val": []
}
self.out_plot_file = out_plot_file
# Initialize the DiffusionNet with the given arguments.
self.model = DiffusionNet(C_in=C_in,
C_out=C_out,
Expand Down Expand Up @@ -132,7 +139,7 @@ def training_step(self, batch, batch_idx):
# Log training loss
self.total_train_loss += loss.detach()
self.train_batches += 1
print(f"Training loss: {loss}")
# print(f"Training loss: {loss}")
return loss

def validation_step(self, batch, batch_idx):
Expand All @@ -156,12 +163,25 @@ def validation_step(self, batch, batch_idx):
def on_train_epoch_end(self):
# Log the average training loss
avg_train_loss = self.total_train_loss / self.train_batches
print("Train epoch loss: ", avg_train_loss)
self.epoch_losses["train"].append(avg_train_loss)
self.log('train_loss', avg_train_loss)

def on_validation_epoch_end(self):
# Log the average validation loss
avg_val_loss = self.total_val_loss / self.val_batches
print("Validation epoch loss: ", avg_val_loss)
self.epoch_losses["val"].append(avg_val_loss)
self.log('val_loss', avg_val_loss)
self.plot_losses()


def plot_losses(self):
    """Plot the per-epoch train/val loss curves and save them to ``self.out_plot_file``.

    Called at the end of every validation epoch. Reads the lists in
    ``self.epoch_losses`` (keys ``"train"`` and ``"val"``) and writes a
    single PNG/figure file to ``self.out_plot_file``. Does nothing when no
    output path was configured (``out_plot_file`` defaults to ``None`` in
    ``__init__``), instead of crashing in ``plt.savefig``.
    """
    if self.out_plot_file is None:
        # Plotting was not requested; avoid plt.savefig(None) raising.
        return
    plt.figure()
    plt.plot(self.epoch_losses["train"], label="Train loss")
    plt.plot(self.epoch_losses["val"], label="Validation loss")
    plt.legend()
    plt.savefig(self.out_plot_file)
    # Close the figure explicitly: this runs once per validation epoch,
    # and un-closed pyplot figures accumulate (memory leak + matplotlib
    # "more than 20 figures" warning) over a long training run.
    plt.close()


def unpack_batch(batch):
Expand Down
4 changes: 3 additions & 1 deletion src/membrain_pick/orientation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import numpy as np
import scipy.spatial as spatial
Expand All @@ -14,6 +15,7 @@
from membrain_seg.segmentation.dataloading.data_utils import load_tomogram



def orientation_from_mesh(coordinates, mesh):
"""
Get the orientation of a point cloud from a mesh.
Expand All @@ -39,7 +41,7 @@ def orientation_from_mesh(coordinates, mesh):
distances, vertex_indices = tree.query(coordinates)

if np.any(distances > 200):
print(
logging.warning(
"Warning: Some points are more than 200 units away from the mesh. This might be an error. Check rescaling factors."
)

Expand Down
3 changes: 3 additions & 0 deletions src/membrain_pick/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def train(
val_path = os.path.join(data_dir, "val")
cache_dir_mb = os.path.join(training_dir, "mesh_cache")
log_dir = os.path.join(training_dir, "logs")
out_plot_file = os.path.join(training_dir, "plots", f"training_curves_{project_name}_{sub_name}.png")
os.makedirs(os.path.join(training_dir, "plots"), exist_ok=True)

# Create the data module
data_module = MemSegDiffusionNetDataModule(
Expand Down Expand Up @@ -85,6 +87,7 @@ def train(
device=device,
one_D_conv_first=one_D_conv_first,
max_epochs=max_epochs,
out_plot_file=out_plot_file,
)

checkpointing_name = project_name + "_" + sub_name
Expand Down