Skip to content

Commit

Permalink
Grouping assays by assay type #369 (tests WIP)
Browse files: browse the repository at this point in the history
  • Loading branch information
Zigur committed Nov 11, 2020
1 parent 1c946bf commit e2cada1
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 31 deletions.
51 changes: 31 additions & 20 deletions isatools/create/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,13 +2053,12 @@ def _generate_sources(self, ontology_source_references):
src_map[s_arm.name] = list(srcs)
return src_map

def _generate_samples(self, sources_map, sampling_protocol, performer, split_assays_by_sample_type):
def _generate_samples_and_assays(self, sources_map, sampling_protocol, performer):
"""
Private method to be used in 'generate_isa_study'.
:param sources_map: dict - the output of '_generate_sources'
:param sampling_protocol: isatools.model.Protocol
:param performer
:param split_assays_by_sample_type: bool
:param performer: str
:return:
"""
factors = set()
Expand All @@ -2069,6 +2068,16 @@ def _generate_samples(self, sources_map, sampling_protocol, performer, split_ass
process_sequence = []
assays = []
protocols = set()
unique_assay_types = {
assay_graph for arm in self.study_arms
for sample_assay_plan in arm.arm_map.values() if sample_assay_plan is not None
for assay_graph in sample_assay_plan.assay_plan if assay_graph is not None
}
samples_grouped_by_assay_graph = {
assay_graph: [] for assay_graph in unique_assay_types
}

# generate samples
for arm in self.study_arms:
for cell, sample_assay_plan in arm.arm_map.items():
if not sample_assay_plan:
Expand Down Expand Up @@ -2110,22 +2119,24 @@ def _generate_samples(self, sources_map, sampling_protocol, performer, split_ass
process_sequence.append(process)
for sample_node in sample_assay_plan.sample_plan:
samples.extend(sample_batches[sample_node])

for assay_graph in sample_assay_plan.assay_plan:
protocols.update({node for node in assay_graph.nodes if isinstance(node, Protocol)})
if split_assays_by_sample_type is True:
for sample_node in sorted(sample_assay_plan.sample_plan, key=lambda st: st.id):
if assay_graph in sample_assay_plan.sample_to_assay_map[sample_node]:
assays.append(
self._generate_assay(assay_graph, sample_batches[sample_node], cell.name)
)
else:
sample_batch = []
for sample_node in sample_assay_plan.sample_plan:
if assay_graph in sample_assay_plan.sample_to_assay_map[sample_node]:
sample_batch.extend(sample_batches[sample_node])
assays.append(
self._generate_assay(assay_graph, sample_batch, cell.name)
)
for sample_node in sample_assay_plan.sample_plan:
if assay_graph in sample_assay_plan.sample_to_assay_map[sample_node]:
try:
samples_grouped_by_assay_graph[assay_graph] += sample_batches[sample_node]
except AttributeError:
log.error('Assay graph is: {}'.format(assay_graph))
problematic_sample_group = samples_grouped_by_assay_graph[assay_graph]
log.error('Sample bach for assay graph is: {}'.format(
problematic_sample_group
))

# generate assays
for assay_graph in unique_assay_types:
protocols.update({node for node in assay_graph.nodes if isinstance(node, Protocol)})
assays.append(self._generate_assay(assay_graph, samples_grouped_by_assay_graph[assay_graph]))

return factors, protocols, samples, assays, process_sequence, ontology_sources

@staticmethod
Expand Down Expand Up @@ -2261,8 +2272,8 @@ def generate_isa_study(self, split_assays_by_sample_type=False):
study.sources = [source for sources in sources_map.values() for source in sources]
study.factors, protocols, study.samples, study.assays, study.process_sequence, \
study.ontology_source_references = \
self._generate_samples(
sources_map, study.protocols[0], study_config['performers'][0]['name'], split_assays_by_sample_type
self._generate_samples_and_assays(
sources_map, study.protocols[0], study_config['performers'][0]['name']
)
for protocol in protocols:
study.add_protocol(protocol)
Expand Down
2 changes: 1 addition & 1 deletion isatools/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ class OntologyAnnotation(Commentable):
"""

def __init__(self, term='', term_source=None, term_accession='',
comments=None, id_=str(uuid.uuid4())) :
comments=None, id_=str(uuid.uuid4())):
super().__init__(comments)

self.__term = term
Expand Down
22 changes: 12 additions & 10 deletions tests/test_create_models_study_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,30 +1749,32 @@ def test_generate_isa_study_single_arm_single_cell_elements(self):
self.assertEqual(len(source.characteristics), 1)
self.assertEqual(source.characteristics[0], DEFAULT_SOURCE_TYPE)

expected_num_of_samples_per_plan = reduce(lambda acc_value, sample_node: acc_value+sample_node.size,
self.nmr_sample_assay_plan.sample_plan, 0) * single_arm.group_size
expected_num_of_samples = expected_num_of_samples_per_plan * len([
expected_num_of_samples = reduce(
lambda acc_value, sample_node: acc_value + sample_node.size,
self.nmr_sample_assay_plan.sample_plan, 0
) * single_arm.group_size * len([
a_plan for a_plan in single_arm.arm_map.values() if a_plan is not None
])
print('Expected number of samples is: {0}'.format(expected_num_of_samples))
log.debug('Expected number of samples is: {0}'.format(expected_num_of_samples))
self.assertEqual(len(study.samples), expected_num_of_samples)
self.assertEqual(len(study.assays), 2)
self.assertEqual(len(study.assays), 1)
treatment_assay = next(iter(study.assays))
self.assertIsInstance(treatment_assay, Assay)
# self.assertEqual(len(treatment_assay.samples), expected_num_of_samples_per_plan)
# self.assertEqual(len(treatment_assay.samples), expected_num_of_samples)
self.assertEqual(treatment_assay.measurement_type, nmr_assay_dict['measurement_type'])
self.assertEqual(treatment_assay.technology_type, nmr_assay_dict['technology_type'])
# pdb.set_trace()
extraction_processes = [process for process in treatment_assay.process_sequence
if process.executes_protocol.name == 'extraction']
nmr_processes = [process for process in treatment_assay.process_sequence
if process.executes_protocol.name == 'nmr spectroscopy']
self.assertEqual(len(extraction_processes), expected_num_of_samples_per_plan)
self.assertEqual(len(nmr_processes), 8 * nmr_assay_dict['nmr spectroscopy']['#replicates']
* expected_num_of_samples_per_plan)
self.assertEqual(len(extraction_processes), expected_num_of_samples)
self.assertEqual(
len(nmr_processes),
8 * nmr_assay_dict['nmr spectroscopy']['#replicates'] * expected_num_of_samples)
self.assertEqual(
len(treatment_assay.process_sequence),
(8 * nmr_assay_dict['nmr spectroscopy']['#replicates'] + 1) * expected_num_of_samples_per_plan
(8 * nmr_assay_dict['nmr spectroscopy']['#replicates'] + 1) * expected_num_of_samples
)
for ix, process in enumerate(extraction_processes):
self.assertEqual(process.inputs, [study.samples[ix]])
Expand Down

0 comments on commit e2cada1

Please sign in to comment.