Pretrain-Model and save it 

In [1]:
%cd ..
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

import torch
import pickle
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import time

In [3]:
dataset = get_dataset(name = "rel-trial", download = True)
db = dataset.get_db()
cache_path = "data/rel-trial-tensor-frame"

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.81 seconds.


In [4]:
# [NOTE]: the dataset has been materialized

# get infer_type in cache
type_path = os.path.join(cache_path,"col_type_dict.pkl")
col_type_dict = pickle.load(open(type_path, "rb"))
len(col_type_dict)

# add "compress_text" in each table in case 
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            # "all-MiniLM-L12-v2",
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=512
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table:Table) -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature."""
    if table.pkey_col is not None:
        if table.pkey_col in col_to_stype:
            col_to_stype.pop(table.pkey_col)
    for fkey in table.fkey_col_to_pkey_table.keys():
        if fkey in col_to_stype:
            col_to_stype.pop(fkey)

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds)."""
    assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")]
    unix_time = ser.astype("int64").values
    if ser.dtype == np.dtype("datetime64[ns]"):
        unix_time //= 10**9
    return unix_time

  return self.fget.__get__(instance, owner)()


In [6]:
# build graph

# start build graph
cache_dir = "./data/rel-trial-tensor-frame"
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()
    
    col_to_stype = col_type_dict[table_name]
    
    # remove pkey, fkey
    remove_pkey_fkey(col_to_stype, table)
    
    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")
    
    path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )
    
    print(f"-----> Materialize {table_name} Tensor Frame")
    dataset = Dataset(
        df = df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)
    
    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats
    
    # Add time attribute
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )
    
    # Add edges normal edges
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_col_name]
        # Filter out dangling foreign keys
        mask = ~pkey_index.isna()
        fkey_index = torch.arange(len(pkey_index))
        
        # filter dangling foreign keys:
        pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
        fkey_index = fkey_index[torch.from_numpy(mask.values)]
        
        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)
    
data.validate()

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [8]:
# add the additional edges
from utils.util import load_np_dict
edge_dict = load_np_dict("./edges/rel-trial-edges.npz")

In [9]:
for edge_name, edge_np in edge_dict.items():
    src_table, dst_table = edge_name.split('-')[0], edge_name.split('-')[1]
    edge_index = torch.from_numpy(edge_np.astype(int)).t()
    # [2, edge_num]
    edge_type = (src_table, f"appendix", dst_table)
    data[edge_type].edge_index = sort_edge_index(edge_index)
data.validate()

True

In [10]:
# read the pre-extracted sample
from utils.util import load_np_dict
sample_dict = load_np_dict("./samples/rel-trail-samples.npz")
sample_dict.keys()

dict_keys(['interventions', 'sponsors', 'eligibilities', 'designs', 'studies', 'conditions'])

In [11]:
from relbench.base import Database
def neighborsample_batch(
    db: Database,
    entity_table: str,
    node_idxs: np.ndarray,
    num_neighbors: List[int] = [128,128],
):
    # node_idxs: [n]
    nodes = (entity_table, torch.from_numpy(node_idxs))
    n = node_idxs.shape[0]
    input_time = torch.from_numpy(
        to_unix_time(pd.Series([db.max_timestamp] * n)))

    if db.table_dict[entity_table].time_col:
        time_col = db.table_dict[entity_table].time_col
        time_values = db.table_dict[entity_table].df[time_col].loc[node_idxs.tolist(
        )]
        input_time = torch.from_numpy(to_unix_time(time_values))

    loader = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        input_nodes=nodes,
        time_attr = "time",
        input_time=input_time,
        batch_size=n,
        temporal_strategy="uniform",
        shuffle=False,
        disjoint=True,
        num_workers=0,
        persistent_workers=False,
    )
    return next(iter(loader))
    

In [12]:
# construct bottom model
channels = 128
args = {
    "channels": channels,
    "num_layers": 2,
    "dropout_prob": 0.2,
}

temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=args["channels"],
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=args["channels"],
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=args["channels"],
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)

net = CompositeModel(
    data=data,
    channels=args["channels"],
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="batch_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [13]:
from model.utils import InfoNCE
lr = 5e-4
negative_sample_pool_size = 512
temprature = 0.01
net.reset_parameters()
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
epoches = 20
early_restart_steps = 20
batch_size = 256
max_steps_in_epoch = 30
loss_fn = InfoNCE(temperature=temprature, negative_mode='paired')


In [14]:
net.to(device)
for epoch in range(1, epoches + 1):
    net.train()
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    for sample_table, sample_np in sample_dict.items():
        loss_accum = count_accum = 0
        shuffle_sample_np = sample_np[np.random.permutation(len(sample_np))]
        anchor_nodes_np = shuffle_sample_np[:, 0]
        positive_pool_np = shuffle_sample_np[:, 1:]
        # choose the positive samples
        n = sample_np.shape[0]
        negative_num = 20
        m = len(db.table_dict[sample_table].df)
        now = time.time()
        cnt = 0
        for batch_idx in tqdm(range(0, n, batch_size), leave=False):
            cnt += 1
            if cnt > max_steps_in_epoch:
                break
            anchor_nodes = anchor_nodes_np[batch_idx:batch_idx+batch_size]
            positive_pool_batch_np = positive_pool_np[batch_idx:batch_idx+batch_size]
            positive_nodes = []
            # random select the positive samples
            for row in positive_pool_batch_np:
                valid = row[row != -1]
                random_choice = np.random.choice(valid, 1)[0]
                positive_nodes.append(random_choice)

            positive_nodes = np.array(positive_nodes)
            B = positive_nodes.size
            # random select the negative sample, negative ratio is 1:20
            # for one batch, we still extract batch_size negative samples
            # for each positive-negative pair, we extract 20 from this 256 batch as negative samples
            excluded = set(positive_nodes.tolist()).union(
                set(anchor_nodes.tolist()))
            negative_candidates = list(set(range(m)) - excluded)
            # print(negative_candidates)
            
            sample_size = min(negative_sample_pool_size, len(negative_candidates))
            
            # if sample_size < positive_nodes.size:
            #     # special case, for those number of positive pairs is too small, 
            #     # we employ the pure random to select the negative samples
            #     negative_candidates = list(range(m))
            #     sample_size = positive_nodes.size
                
            #     print("==> Candidate:" + str(negative_candidates))
            
            negative_nodes = np.random.choice(
                negative_candidates, size=sample_size, replace=True)
            # [batch_size]
            # print(negative_nodes.shape)
            # print(B)
            # neighbor hood loader
            anchor_nodes_batch = neighborsample_batch(
                db, sample_table, anchor_nodes)
            positive_nodes_batch = neighborsample_batch(
                db, sample_table, positive_nodes)
            negative_nodes_batch = neighborsample_batch(
                db, sample_table, negative_nodes)

            optimizer.zero_grad()

            anchor_nodes_batch, positive_nodes_batch, negative_nodes_batch = \
                anchor_nodes_batch.to(device), positive_nodes_batch.to(
                    device), negative_nodes_batch.to(device)

            anchor_nodes_embedding = net.get_node_embedding(
                anchor_nodes_batch, sample_table)[sample_table][:B]
            positive_nodes_embedding = net.get_node_embedding(
                positive_nodes_batch, sample_table)[sample_table][:B]
            negative_nodes_embedding = net.get_node_embedding(
                negative_nodes_batch, sample_table)[sample_table][:sample_size]

            # negative_nodes_embedding = net.get_node_embedding(negative_nodes_batch, sample_table)[sample_table][:B]
            # [B, D]
            
            negative_indices = torch.stack([torch.randperm(sample_size)[
                                           :negative_num] for _ in range(B)]).to(device)
            negative_nodes_embedding = negative_nodes_embedding[negative_indices]
            # [B, negative_num, D]

            loss = loss_fn(anchor_nodes_embedding,
                           positive_nodes_embedding, negative_nodes_embedding)
            loss.backward()
            optimizer.step()
            loss_accum += loss.detach().item()
            count_accum += B
        end = time.time()
        mins, secs = divmod(end - now, 60)
        train_loss = loss_accum / count_accum
        
        print(f"====> In {sample_table}, Train loss: {train_loss}, Cost Time {mins:.0f}m {secs:.0f}s")

******************************<Epoch: 01>******************************


  0%|          | 0/4 [00:00<?, ?it/s]

                                             

====> In interventions, Train loss: 0.05045639544611813, Cost Time 0m 10s


                                               

====> In sponsors, Train loss: 0.02468927749699593, Cost Time 0m 32s


                                               

====> In eligibilities, Train loss: 0.02806696227543424, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.029616755287119205, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.02251334054899049, Cost Time 0m 49s


                                             

====> In conditions, Train loss: 0.04090505604380548, Cost Time 0m 14s
******************************<Epoch: 02>******************************


                                             

====> In interventions, Train loss: 0.029344714430114667, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.014865026082078072, Cost Time 0m 32s


                                               

====> In eligibilities, Train loss: 0.01692450080038505, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.015704039839290133, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.012862883962494014, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.0264259816317305, Cost Time 0m 13s
******************************<Epoch: 03>******************************


                                             

====> In interventions, Train loss: 0.025905380709501505, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.01264045451390205, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.012294835519809791, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.01212904733299409, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.011233200007329492, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.02092016651779069, Cost Time 0m 13s
******************************<Epoch: 04>******************************


                                             

====> In interventions, Train loss: 0.023005327396804828, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.011235979049717454, Cost Time 0m 32s


                                               

====> In eligibilities, Train loss: 0.010371227141892944, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.010798420362024498, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.010263980039034715, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.018994133503644442, Cost Time 0m 13s
******************************<Epoch: 05>******************************


                                             

====> In interventions, Train loss: 0.021309310932474488, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.009161323172861039, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.009223417985013383, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.009741271022182182, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.009385878391332416, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.01813650131225586, Cost Time 0m 13s
******************************<Epoch: 06>******************************


                                             

====> In interventions, Train loss: 0.02021520777003132, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.007934855271692146, Cost Time 0m 30s


                                               

====> In eligibilities, Train loss: 0.008594246525115538, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.00899614735737743, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.008703126446276651, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.016428684068698896, Cost Time 0m 13s
******************************<Epoch: 07>******************************


                                             

====> In interventions, Train loss: 0.01892111898225932, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.007134444667626492, Cost Time 0m 30s


                                               

====> In eligibilities, Train loss: 0.007489149133061272, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.008331292987669874, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.008108276923959646, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.015789002616007203, Cost Time 0m 13s
******************************<Epoch: 08>******************************


                                             

====> In interventions, Train loss: 0.017658483876055047, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.006551490042912422, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.006817821181476092, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.007899787802024176, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.0075996934941766035, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.01457648629679691, Cost Time 0m 13s
******************************<Epoch: 09>******************************


                                             

====> In interventions, Train loss: 0.01726695149190241, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.006017883165060476, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.006244399135058206, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.00750828509362752, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.007081647453206701, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.013888390455179898, Cost Time 0m 13s
******************************<Epoch: 10>******************************


                                             

====> In interventions, Train loss: 0.016390660514976958, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.005262361783657142, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.005980827443717427, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.007208401044743173, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.00672068911189751, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.012827332153790175, Cost Time 0m 13s
******************************<Epoch: 11>******************************


                                             

====> In interventions, Train loss: 0.015516727670478094, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.00508109764784257, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.005665268538133154, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.006704626507407067, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.006306446178903372, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.012213451665947306, Cost Time 0m 13s
******************************<Epoch: 12>******************************


                                             

====> In interventions, Train loss: 0.015099709642887722, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.004721940788504196, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.005411288002225493, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.006310499314493781, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.006154087319051651, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.011936753781783388, Cost Time 0m 13s
******************************<Epoch: 13>******************************


                                             

====> In interventions, Train loss: 0.015169899830218343, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.00427487061629738, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.004910741595081037, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.0060931667385485345, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.005563497959250409, Cost Time 0m 48s


                                             

====> In conditions, Train loss: 0.01151901708372379, Cost Time 0m 13s
******************************<Epoch: 14>******************************


                                             

====> In interventions, Train loss: 0.014438226232080326, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.004192763055901065, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.004792862654402283, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.005618148242067171, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.005147363165522517, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.010729101825623075, Cost Time 0m 13s
******************************<Epoch: 15>******************************


                                             

====> In interventions, Train loss: 0.014053800660439944, Cost Time 0m 8s


                                               

====> In sponsors, Train loss: 0.0039127479949498285, Cost Time 0m 30s


                                               

====> In eligibilities, Train loss: 0.004388229715828913, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.00536778977253293, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.004788191537560766, Cost Time 0m 46s


                                             

====> In conditions, Train loss: 0.009751023468005831, Cost Time 0m 13s
******************************<Epoch: 16>******************************


                                             

====> In interventions, Train loss: 0.013207587383905239, Cost Time 0m 8s


                                               

====> In sponsors, Train loss: 0.0037482009056193887, Cost Time 0m 30s


                                               

====> In eligibilities, Train loss: 0.004340870631167038, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.005177228362768288, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.0045786685851706005, Cost Time 0m 46s


                                             

====> In conditions, Train loss: 0.009507620215324185, Cost Time 0m 13s
******************************<Epoch: 17>******************************


                                             

====> In interventions, Train loss: 0.013852127019453231, Cost Time 0m 8s


                                               

====> In sponsors, Train loss: 0.003595490539696905, Cost Time 0m 30s


                                               

====> In eligibilities, Train loss: 0.0040577502673718405, Cost Time 0m 18s


                                               

====> In designs, Train loss: 0.004855453047976398, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.004353768911882975, Cost Time 0m 46s


                                             

====> In conditions, Train loss: 0.009361742183371081, Cost Time 0m 13s
******************************<Epoch: 18>******************************


                                             

====> In interventions, Train loss: 0.012399127553045522, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.003520415435280067, Cost Time 0m 30s


                                               

====> In eligibilities, Train loss: 0.003820111812812073, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.0047017354093142, Cost Time 0m 23s


                                               

====> In studies, Train loss: 0.004329971131107542, Cost Time 0m 46s


                                             

====> In conditions, Train loss: 0.009638367569198785, Cost Time 0m 13s
******************************<Epoch: 19>******************************


                                             

====> In interventions, Train loss: 0.01178445082910482, Cost Time 0m 8s


                                               

====> In sponsors, Train loss: 0.003441160112876447, Cost Time 0m 31s


                                               

====> In eligibilities, Train loss: 0.003695749174108088, Cost Time 0m 19s


                                               

====> In designs, Train loss: 0.004637607492056469, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.004132830258876625, Cost Time 0m 46s


                                             

====> In conditions, Train loss: 0.00853195485929235, Cost Time 0m 13s
******************************<Epoch: 20>******************************


                                             

====> In interventions, Train loss: 0.011515272314999882, Cost Time 0m 9s


                                               

====> In sponsors, Train loss: 0.0033631375119022038, Cost Time 0m 30s


                                               

====> In eligibilities, Train loss: 0.00349579856176541, Cost Time 0m 18s


                                               

====> In designs, Train loss: 0.004446183815098449, Cost Time 0m 22s


                                               

====> In studies, Train loss: 0.00410617131651349, Cost Time 0m 47s


                                             

====> In conditions, Train loss: 0.009337407979164975, Cost Time 0m 13s




In [15]:
# pre-trained state
# record
pre_trained_state = copy.deepcopy(net.state_dict())

In [16]:
import torch
import json
torch.save(pre_trained_state, "./static/rel-trial-pre-trained-channel128-ep40.pth")