In [1]:
import os
# Move the working directory two levels up from wherever the kernel started.
# The data path used later ('./demo_data/...') is relative, so it is resolved
# against this new directory.
# NOTE(review): the target is relative to the *current* directory, so running
# this cell a second time in the same kernel climbs two more levels; restart
# the kernel before re-running.
os.chdir('../..')

import pandas as pd
import numpy as np

from pytrial.data.trial_data import TrialDatasetStructured
from pytrial.tasks.site_selection.data import TrialSiteModalities
from pytrial.tasks.site_selection.framm import FRAMM

# Base Data
# Seed NumPy's global RNG so that the synthetic labels, enrollments and site
# features drawn below are identical on every Restart & Run All.
# NOTE(review): this only seeds NumPy; any other RNG used during model
# training is not controlled here.
SEED = 0
np.random.seed(SEED)

NUM_SITES = 50
STATIC_DIM = 48
DX_DIM = 100
RX_DIM = 100
# Number of sites attached to each trial (length of each 'label' and
# 'enrollment' entry).
SITES_PER_TRIAL = 10

trial_data = pd.read_csv('./demo_data/demo_trial_patient_data/data_processed.csv')
# Measure the trial feature width BEFORE the synthetic 'label'/'enrollment'
# columns are appended, so those columns are not part of what is measured.
TRIAL_DIM = TrialDatasetStructured(trial_data)[0].shape[1]

n_trials = len(trial_data)
# For every trial, draw SITES_PER_TRIAL distinct site indices from [0, NUM_SITES).
trial_data['label'] = [
    np.random.choice(NUM_SITES, SITES_PER_TRIAL, replace=False).tolist()
    for _ in range(n_trials)
]
# For every trial, draw one random integer in [0, 100) per selected site.
trial_data['enrollment'] = np.random.randint(
    0, 100, size=(n_trials, SITES_PER_TRIAL)
).tolist()


def _random_visit_codes(code_dim):
    """Draw a random, variable-length visit history of code lists for every site.

    For each of the NUM_SITES sites this draws a visit count in [0, 20) and,
    for each visit, a set of 0-9 distinct codes from range(code_dim). Each
    visit is stored as a one-element list wrapping the list of codes.
    """
    vocabulary = list(range(code_dim))
    all_sites = []
    for _ in range(NUM_SITES):
        visit_count = np.random.randint(0, 20)
        site_visits = []
        for _ in range(visit_count):
            codes_in_visit = np.random.randint(0, 10)
            drawn = np.random.choice(vocabulary, codes_in_visit, replace=False)
            site_visits.append([drawn.tolist()])
        all_sites.append(site_visits)
    return all_sites


# Random static features: one row of STATIC_DIM values per site.
static_features = np.random.rand(NUM_SITES, STATIC_DIM)

# Random code histories, drawn with the same procedure for both modalities.
dx_visits = _random_visit_codes(DX_DIM)
rx_visits = _random_visit_codes(RX_DIM)

# History tensor: 10 slots per site, each TRIAL_DIM random values plus one
# trailing random integer in [0, 100). Slots that are not filled stay all-zero.
history = np.zeros((NUM_SITES, 10, TRIAL_DIM + 1))
for site_history in history:
    filled_slots = np.random.randint(0, 10)
    for slot in range(filled_slots):
        site_history[slot, :TRIAL_DIM] = np.random.rand(TRIAL_DIM)
        site_history[slot, TRIAL_DIM] = np.random.randint(0, 100)

# Per-site 4-way vector: random scores passed through a softmax so each
# vector is positive and sums to 1.
eth_distributions = []
for _ in range(NUM_SITES):
    scores = np.random.rand(4)
    eth_distributions.append(np.exp(scores) / np.exp(scores).sum())

site_data = {
    'x': static_features,
    'dx': dx_visits,
    'rx': rx_visits,
    'hist': history,
    'eth_label': eth_distributions,
}

# Bundle the site-level dict and the trial-level frame into one dataset object
trial_site_data = TrialSiteModalities(site_data, trial_data)

# FRAMM settings, gathered in one place so they are easy to inspect or tweak
framm_config = dict(
    # dimensions taken from the data prepared above
    trial_dim=TRIAL_DIM,
    static_dim=STATIC_DIM,
    dx_dim=DX_DIM,
    rx_dim=RX_DIM,
    # network sizes
    lstm_dim=32,
    embedding_dim=32,
    num_layers=1,
    hidden_dim=32,
    n_heads=2,
    # model variant options
    missing_type='MCAT',
    scoring_type='Fully Connected',
    enrollment_only=False,
    K=5,
    lam=1,
    # training settings
    batch_size=16,
    epochs=2,
    device='cpu',
)

# Build the model from the settings above
model = FRAMM(**framm_config)




In [2]:
# Train the model on the compiled trial-site dataset, using the settings given
# to the FRAMM constructor (epochs=2, batch_size=16).
# NOTE(review): the recorded output below shows checkpoints written under
# ./checkpoints/ at steps 10, 20 and 30 and then a KeyboardInterrupt, i.e. this
# run was stopped manually before training completed.
model.fit(trial_site_data)

***** Running training *****
  Num examples = 3871
  Num Epochs = 2
  Total optimization steps = 484


Training Epoch:   0%|          | 0/2 [00:00<?, ?it/s]


######### Train Loss 10 #########
0 7.6102 

Checkpoint saved in ./checkpoints/./10 at 10 steps.





######### Train Loss 20 #########
0 7.6127 

Checkpoint saved in ./checkpoints/./20 at 20 steps.


Iteration:  13%|█▎        | 31/242 [00:04<00:33,  6.27it/s]
Training Epoch:   0%|          | 0/2 [00:04<?, ?it/s]


######### Train Loss 30 #########
0 7.6513 

Checkpoint saved in ./checkpoints/./30 at 30 steps.





KeyboardInterrupt: 

In [4]:
# Predict site selections for the trials in trial_site_data. This is the same
# dataset that was passed to model.fit; the cells shown here do not create a
# separate held-out test split.
# Only the first entry of the result is printed.
# NOTE(review): presumably each entry is a list of K=5 selected site indices;
# confirm against FRAMM.predict.
site_selections = model.predict(trial_site_data)
print(site_selections[0])

[0, 8, 2, 6, 4]
