In [1]:
import os
os.chdir('../..')

import pytrial
pytrial.manual_seed(42)

import torch
import numpy as np
import pandas as pd

from pytrial.tasks.site_selection.data import TrialSiteSimple
from pytrial.tasks.site_selection import PolicyGradientEntropy

In [2]:
# Build Dummy Dataset
M = 10

siteDf = pd.DataFrame(np.random.randint(0,10,size=(1000, 26)), columns=list('ABCDEFGHIJKLMNOPQRSTUVWXYZ'))
siteDf['demographics'] = [torch.softmax(torch.randn(5), dim=0).numpy() for _ in range(len(siteDf))]

trialDf = pd.DataFrame(np.random.randint(0,10,size=(100, 10)), columns=list('ABCDEFGHIJ'))
trialDf['label'] = [list(np.random.choice(list(range(1000)), M, replace=False)) for _ in range(len(trialDf))]
trialDf['enrollment'] = [[np.random.randint(0,100) for _ in range(M)] for _ in range(len(trialDf))]

trial_site_data = TrialSiteSimple(siteDf, trialDf)



In [3]:
# Initialize Model
model = PolicyGradientEntropy(
    trial_dim=10, 
    site_dim=26, 
    embedding_dim=16, 
    enrollment_only=False,
    K=5,
    lam=1,
    learning_rate=1e-4,
    batch_size=4,
    epochs=3,
    device='cpu'
)

model.fit(trial_site_data)

***** Running training *****
  Num examples = 100
  Num Epochs = 3
  Total optimization steps = 75


Training Epoch:   0%|          | 0/3 [00:00<?, ?it/s]


######### Train Loss 10 #########
0 8.4602 



Iteration: 100%|██████████| 25/25 [00:00<00:00, 95.29it/s]
Training Epoch:  33%|███▎      | 1/3 [00:00<00:00,  3.69it/s]


######### Train Loss 20 #########
0 8.5250 






######### Train Loss 30 #########
0 8.4397 






######### Train Loss 40 #########
0 8.4717 



Iteration: 100%|██████████| 25/25 [00:00<00:00, 95.02it/s]
Training Epoch:  67%|██████▋   | 2/3 [00:00<00:00,  3.68it/s]


######### Train Loss 50 #########
0 8.4162 






######### Train Loss 60 #########
0 8.4310 






######### Train Loss 70 #########
0 8.4120 



Iteration: 100%|██████████| 25/25 [00:00<00:00, 127.57it/s]
Training Epoch: 100%|██████████| 3/3 [00:00<00:00,  4.03it/s]


Training completes.


In [4]:
# Make Site Selections
selections = model.predict(trial_site_data)
print(selections)

[[5, 8, 3, 0, 6], [3, 5, 8, 6, 7], [8, 9, 4, 0, 7], [7, 0, 6, 3, 1], [9, 1, 0, 3, 6], [0, 5, 2, 3, 1], [3, 4, 9, 1, 8], [1, 2, 5, 9, 0], [8, 4, 3, 2, 7], [7, 0, 9, 3, 1], [3, 9, 2, 7, 1], [0, 4, 3, 9, 5], [8, 9, 3, 7, 4], [5, 0, 4, 1, 7], [2, 4, 5, 9, 7], [2, 6, 9, 0, 8], [3, 9, 4, 8, 1], [3, 5, 8, 9, 6], [7, 3, 6, 2, 8], [6, 9, 2, 4, 1], [5, 0, 8, 6, 2], [5, 3, 6, 7, 0], [8, 4, 5, 1, 9], [1, 7, 0, 5, 8], [3, 2, 0, 1, 9], [0, 9, 7, 4, 8], [7, 9, 0, 5, 8], [1, 7, 9, 4, 5], [8, 2, 4, 6, 1], [6, 0, 7, 4, 8], [9, 6, 7, 5, 8], [5, 6, 2, 4, 8], [7, 2, 3, 1, 4], [7, 3, 8, 1, 4], [4, 3, 9, 0, 6], [7, 1, 5, 6, 0], [9, 6, 7, 5, 8], [7, 1, 9, 4, 2], [7, 3, 6, 8, 9], [6, 8, 7, 5, 4], [9, 5, 7, 4, 6], [1, 9, 8, 2, 6], [6, 9, 3, 2, 5], [2, 6, 3, 5, 8], [3, 6, 2, 4, 8], [8, 6, 0, 9, 5], [8, 7, 5, 9, 1], [4, 7, 5, 3, 1], [8, 7, 1, 9, 5], [4, 9, 0, 2, 5], [0, 2, 5, 8, 4], [2, 7, 1, 8, 6], [6, 3, 5, 7, 9], [1, 2, 5, 7, 3], [3, 6, 8, 4, 5], [2, 4, 8, 5, 3], [7, 3, 8, 9, 4], [7, 3, 4, 6, 8], [0, 9, 4, 3, 