# Predict cell cycle phase

In [1]:
import os
import torch
import numpy as np

from cnn_framework.utils.data_managers.default_data_manager import DefaultDataManager
from cnn_framework.utils.data_loader_generators.data_loader_generator import (
    DataLoaderGenerator,
)
from cnn_framework.utils.model_managers.cnn_model_manager import CnnModelManager
from cnn_framework.utils.metrics.abstract_metric import AbstractMetric

from cell_cycle_classification.utils.data_set import FucciClassificationDataSet
from cell_cycle_classification.backbone.fucci_classifier import FucciClassifier

from cell_cycle_classification.utils.model_params import FucciVAEModelParams

  check_for_updates()


### Define parameters

In [2]:
# Build the FUCCI VAE/classifier configuration object shared by every later cell.
params = FucciVAEModelParams()
# NOTE(review): the exact effect of update() lives in FucciVAEModelParams — presumably it
# refreshes derived settings before the pretrained model is resolved; confirm.
params.update()
# NOTE(review): presumably this fetches the weights and sets params.model_load_path,
# which is read later when the state dict is loaded — confirm.
params.load_classification_model()  # load trained model from HuggingFace

Model time id: 20250917-155627-local
epochs 10 | batch 32 | lr 0.0001 | weight decay 0.05 | dropout 0.0 | c [0] | z [0, 1, 2, 3, 4] | data set size None | latent dim 256 | beta 0.01 | gamma 100.0 | delta 10000.0 | depth 5 | kld loss standard | encoder name resnet18 | latent dim 256 | beta 0.01 | gamma 100.0 | delta 10000.0 | C 50 | depth 5 | kld loss standard | encoder name resnet18


In [3]:
# Set path to data directory
params.data_dir = os.path.join(os.path.abspath(''), "data")
params.test_ratio = 1.0  # use all data for testing
params.data_set_size = 280  # maximum nucleus diameter 

### Load data

In [4]:
# Validate the configuration, then build the data loaders. Only the test split is
# populated in this notebook. The unused train/val loaders are bound to explicit names
# rather than `_`, because assigning `_` in IPython shadows the kernel's last-output variable.
params.check_ready()
loader_generator = DataLoaderGenerator(
    params, FucciClassificationDataSet, DefaultDataManager
)
train_dl, val_dl, test_dl = loader_generator.generate_data_loader()
# Fail fast when no image was found in params.data_dir, instead of silently predicting nothing.
# NOTE(review): assumes test_dl is a sized loader (e.g. a torch DataLoader) — confirm.
assert len(test_dl) > 0, f"No test images found in {params.data_dir}"

File names correctly loaded.
Splitting file names ...
### Data source ###
No data is loaded for train
No data is loaded for val
test data is loaded from c:\Users\thoma\cell_cycle_classification\notebooks\data - 100% elements
###################
train has 0 images.
val has 0 images.
test has 1 images.
###################


### Load pretrained model

In [5]:
# Instantiate the classifier architecture from the configuration, then load the
# pretrained weights into it. The returned key-matching report is displayed as the cell output.
model = FucciClassifier(params)
# map_location="cpu": a checkpoint saved on GPU still loads on a CPU-only machine
# (load_state_dict copies the tensors into the freshly built model's parameters anyway).
# weights_only=True: the checkpoint is downloaded from a remote hub, so unpickling is
# restricted to tensors and plain containers (requires torch >= 1.13).
model.load_state_dict(
    torch.load(params.model_load_path, map_location="cpu", weights_only=True)
)

<All keys matched successfully>

### Predict cell cycle phase

In [6]:
# Run the pretrained classifier over the test loader. The result holds one score vector
# per image, which the next cell reduces with argmax.
predictions = CnnModelManager(
    model, params, AbstractMetric
).predict(test_dl)

Current commit hash: 7fcf4a97da8a68c06036377aef8b827806ddec9a


  A.PadIfNeeded(
  A.Resize(


Model evaluation in progress: 100.0% | Batch #0                                                    

In [7]:
# Map each argmax class index to its cell-cycle phase label.
# NOTE(review): assumes this order matches the classifier's output units — confirm.
cycle_phases = ['G1', 'S', 'G2/M']
for image_idx, scores in enumerate(predictions):
    phase = cycle_phases[np.argmax(scores)]
    print(f"Image {image_idx}: Predicted cycle phase: {phase}")

Image 0: Predicted cycle phase: G1
