In [1]:
from biological_fuzzy_logic_networks.DREAM.DREAMBioFuzzNet import (
    DREAMBioFuzzNet,
    DREAMBioMixNet,
)
from biological_fuzzy_logic_networks.DREAM_analysis.utils import create_bfz
from biological_fuzzy_logic_networks.Synthetic_experiments.generate_gates_data import generate_gate_datasets

import torch
import pickle
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
gates_dir = "/dccstor/ipc1/CAR/BFN/Model/Gates/"
pkn = "/dccstor/ipc1/CAR/BFN/LiverDREAM_PKN.sif"
true_network_suffix="model_for_simulation.pt"
trained_network_suffix="model_for_simulation.pt"
optimized_network_suffix="model_for_prediction.pt"
n_gates = 15
repeat = 0

In [3]:
teacher = create_bfz(pkn, "DREAMBioFuzzNet")


In [4]:
teacher.fuzzy_nodes

['_and_1',
 '_and_2',
 '_and_3',
 '_and_4',
 '_and_5',
 '_and_6',
 '_and_7',
 '_and_8',
 '_and_9',
 '_and_10',
 '_and_11',
 'or1',
 'or2',
 'or3',
 'or4',
 'or5',
 'or6',
 'or7',
 'or8',
 'or9',
 'or10',
 'or11']

In [5]:
chosen_gates_idx = np.random.choice(len(teacher.fuzzy_nodes), n_gates, replace=False)
chosen_gates_idx

array([ 7,  3,  1,  0, 12,  5,  8,  9,  6, 17, 10, 11, 18,  4, 19])

In [6]:
torch.manual_seed(1356)

<torch._C.Generator at 0x14dc3c1366d0>

In [7]:
(train_input_df,
 train_true_df,
 test_input_df,
 test_true_df,
 teacher_network,
 changed_gates) = generate_gate_datasets(pkn, 
                                         train_size=100, 
                                         test_size= 10, 
                                         chosen_gates_idx=chosen_gates_idx)

In [8]:
for n in teacher_network.nodes():
    print(teacher_network.nodes()[n]["node_type"]==teacher.nodes()[n]["node_type"])

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
False
True
True
True
False
False
True
True
False
False
False
True
True
True
True
True
False
True
True
True
True


In [9]:
model_state_dict, gate_dict = teacher_network.get_checkpoint(save_gates=True)

In [10]:
teacher.load_from_checkpoint(model_state_dict, gate_dict)

In [11]:
for n in teacher_network.nodes():
    print(teacher_network.nodes()[n]["node_type"]==teacher.nodes()[n]["node_type"])

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [12]:
teacher

<biological_fuzzy_logic_networks.DREAM.DREAMBioFuzzNet.DREAMBioFuzzNet at 0x14db310ba7c0>

# Load from experiment checkpoint

In [13]:
changed_gates = pickle.load(open(f"{gates_dir}{n_gates}_gates_{repeat}_repeat_changed_gates.p", "rb"))

In [14]:
ckpts = [torch.load(f"{gates_dir}{n_gates}_gates_{repeat}_repeat_{suffix}") for suffix in [optimized_network_suffix, true_network_suffix ]
    ]

In [15]:
teacher = create_bfz(pkn, "DREAMBioFuzzNet")

teacher_network = create_bfz(pkn, "DREAMBioFuzzNet")
teacher_network.load_from_checkpoint(ckpts[1]["model_state_dict"], ckpts[1]["model_gate_dict"])

student_network = create_bfz(pkn, "DREAMBioMixNet")
student_network.load_from_checkpoint(ckpts[0]["model_state_dict"], ckpts[0]["model_gate_dict"])


In [16]:
student_network.fuzzy_nodes

['_and_2', '_and_3', '_and_5', '_and_8', '_and_11', 'or1', 'or6']

In [17]:
student_network.mixed_gates


['_and_1',
 '_and_4',
 '_and_6',
 '_and_7',
 '_and_9',
 '_and_10',
 'or2',
 'or3',
 'or4',
 'or5',
 'or7',
 'or8',
 'or9',
 'or10',
 'or11']

In [18]:
changed_gates

['_and_1',
 '_and_4',
 '_and_6',
 '_and_7',
 '_and_9',
 '_and_10',
 'or2',
 'or3',
 'or4',
 'or5',
 'or7',
 'or8',
 'or9',
 'or10',
 'or11']

In [19]:
for n in teacher_network.nodes():
    node_type = teacher_network.nodes()[n]["node_type"]
    if node_type.startswith("logic"):
        if n in changed_gates:
            print(n, node_type)

_and_1 logic_gate_AND
_and_4 logic_gate_OR
_and_6 logic_gate_AND
_and_7 logic_gate_OR
_and_9 logic_gate_OR
_and_10 logic_gate_AND
or2 logic_gate_OR
or3 logic_gate_OR
or4 logic_gate_AND
or5 logic_gate_AND
or7 logic_gate_AND
or8 logic_gate_OR
or9 logic_gate_AND
or10 logic_gate_AND
or11 logic_gate_AND


In [20]:
for n in student_network.nodes():
    node_type = student_network.nodes()[n]["node_type"]
    if node_type.startswith("logic"):
        if n in changed_gates:
            print(n, node_type)

_and_1 logic_gate_MIXED
_and_4 logic_gate_MIXED
_and_6 logic_gate_MIXED
_and_7 logic_gate_MIXED
_and_9 logic_gate_MIXED
_and_10 logic_gate_MIXED
or2 logic_gate_MIXED
or3 logic_gate_MIXED
or4 logic_gate_MIXED
or5 logic_gate_MIXED
or7 logic_gate_MIXED
or8 logic_gate_MIXED
or9 logic_gate_MIXED
or10 logic_gate_MIXED
or11 logic_gate_MIXED
