Skip to content

Commit

Permalink
use dance logger
Browse files Browse the repository at this point in the history
  • Loading branch information
RemyLau committed Mar 8, 2023
1 parent 665ac11 commit 6ce5f0a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 22 deletions.
22 changes: 11 additions & 11 deletions dance/datasets/multimodality.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os
import pickle

Expand All @@ -7,6 +6,7 @@
import scanpy as sc
import torch

from dance import logger
from dance.transforms.preprocess import lsiTransformer
from dance.utils.download import download_file, unzip_file

Expand Down Expand Up @@ -183,7 +183,7 @@ def __init__(self, subtask, data_dir="./data"):

def preprocess(self, kind='feature_selection', selection_threshold=10000):
if kind == 'pca':
logging.info('Preprocessing method not supported.')
logger.info('Preprocessing method not supported.')
return self
elif kind == 'feature_selection':
if self.modalities[0].shape[1] > selection_threshold:
Expand All @@ -193,9 +193,9 @@ def preprocess(self, kind='feature_selection', selection_threshold=10000):
for i in [0, 2]:
self.modalities[i] = self.modalities[i][:, self.modalities[i].var['highly_variable']]
else:
logging.info('Preprocessing method not supported.')
logger.info('Preprocessing method not supported.')
return self
logging.info('Preprocessing done.')
logger.info('Preprocessing done.')
return self


Expand Down Expand Up @@ -294,9 +294,9 @@ def preprocess(self, kind='pca', pkl_path=None, selection_threshold=10000):
self.modalities[i] = self.modalities[i][:, self.modalities[i].var['highly_variable']]
self.modalities[i + 2] = self.modalities[i + 2][:, self.modalities[i + 2].var['highly_variable']]
else:
logging.info('Preprocessing method not supported.')
logger.info('Preprocessing method not supported.')
return self
logging.info('Preprocessing done.')
logger.info('Preprocessing done.')
self.preprocessed = True
return self

Expand Down Expand Up @@ -373,7 +373,7 @@ def preprocess(self, kind='aux', pretrained_folder='.', selection_threshold=1000
# cell types, batch labels, cell cycle
self.nb_cell_types, self.nb_batches, self.nb_phases = pickle.load(f)
self.preprocessed = True
logging.info('Preprocessing done.')
logger.info('Preprocessing done.')
return self

##########################################
Expand Down Expand Up @@ -466,8 +466,8 @@ def preprocess(self, kind='aux', pretrained_folder='.', selection_threshold=1000
'AURKA', 'PSRC1', 'ANLN', 'LBR', 'CKAP5', 'CENPE', 'CTCF', \
'NEK2', 'G2E3', 'GAS2L3', 'CBX5', 'CENPA']

logging.info('Data loading and pca done', mod1_pca.shape, mod2_pca.shape)
logging.info('Start to calculate cell_cycle score. It may roughly take an hour.')
logger.info('Data loading and pca done', mod1_pca.shape, mod2_pca.shape)
logger.info('Start to calculate cell_cycle score. It may roughly take an hour.')

cell_type_labels = self.test_sol.obs['cell_type'].to_numpy() #mod1_obs['cell_type']
batch_ids = mod1_obs['batch']
Expand Down Expand Up @@ -523,10 +523,10 @@ def preprocess(self, kind='aux', pretrained_folder='.', selection_threshold=1000
n_top_genes=selection_threshold)
self.modalities[i] = self.modalities[i][:, self.modalities[i].var['highly_variable']]
else:
logging.info('Preprocessing method not supported.')
logger.info('Preprocessing method not supported.')
return self
self.preprocessed = True
logging.info('Preprocessing done.')
logger.info('Preprocessing done.')
return self

def get_preprocessed_data(self):
Expand Down
4 changes: 2 additions & 2 deletions dance/modules/multi_modality/predict_modality/babel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
multiomic profiles at single-cell resolution." Proceedings of the National Academy of Sciences 118, no. 15 (2021).
"""
import logging
import math
from typing import Callable, List, Tuple, Union

Expand All @@ -18,6 +17,7 @@
from torch.utils.data import DataLoader

import dance.utils.loss as loss_functions
from dance import logger

REDUCE_LR_ON_PLATEAU_PARAMS = {
"mode": "min",
Expand Down Expand Up @@ -317,7 +317,7 @@ def __init__(
self.final_activations["act1"] = final_activations
else:
raise ValueError(f"Unrecognized type for final_activation: {type(final_activations)}")
logging.info(f"ChromDecoder with {len(self.final_activations)} output activations")
logger.info(f"ChromDecoder with {len(self.final_activations)} output activations")

self.final_decoders = nn.ModuleList() # List[List[Module]]
for n in self.num_outputs:
Expand Down
9 changes: 4 additions & 5 deletions dance/transforms/graph/scmogcn_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os
import pickle
from collections import defaultdict
Expand All @@ -9,8 +8,9 @@
import torch
from sklearn.decomposition import TruncatedSVD

from dance import logger
from dance.data.base import Data
from dance.typing import Optional, Tuple, Union
from dance.typing import Union

from ..base import BaseTransform

Expand Down Expand Up @@ -76,9 +76,8 @@ def create_pathway_graph(gex_features: scipy.sparse.spmatrix, gene_names: Union[

pk_path = f'pw_{subtask}_{pathway_weight}.pkl'
if os.path.exists(pk_path):
logging.warning(
'Pathway file exist. Load pickle file by default. Auguments "--pathway_weight" and "--pathway_path" will not take effect.'
)
logger.warning("Pathway file exist. Load pickle file by default. "
"Auguments '--pathway_weight' and '--pathway_path' will not take effect.")
uu, vv, ee = pickle.load(open(pk_path, 'rb'))
else:
# Load Original Pathway File
Expand Down
6 changes: 2 additions & 4 deletions examples/multi_modality/predict_modality/babel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import mudata
import torch

from dance import logger
from dance.data import Data
from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.babel import BabelWrapper
from dance.utils import set_seed

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

OPTIMIZER_DICT = {
"adam": torch.optim.Adam,
"rmsprop": torch.optim.RMSprop,
Expand Down Expand Up @@ -52,13 +51,12 @@
os.makedirs(os.path.dirname(args.outdir))

# Specify output log file
logger = logging.getLogger()
fh = logging.FileHandler(f"{args.outdir}/training_{args.subtask}_{args.rnd_seed}.log", "w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)

for arg in vars(args):
logging.info(f"Parameter {arg}: {getattr(args, arg)}")
logger.info(f"Parameter {arg}: {getattr(args, arg)}")

# Construct data object
mod1 = anndata.concat((dataset.modalities[0], dataset.modalities[2]))
Expand Down

0 comments on commit 6ce5f0a

Please sign in to comment.