In [1]:
import os
os.environ['PROJECTS_BASE'] = '.'
import numpy as np
import torch

from trajectories.representation_learner.example_dataset import PatientDataset
from trajectories.representation_learner.adapted_model import CNNGRUModel
from trajectories.representation_learner.args import Args, FineTuneArgs

In [2]:
# Load the main args and the FT args
main_args_path = './SampleArgs/testing/args.json'

ft_args_path = './SampleArgs/testing/fine_tune_args.json'

args = Args.from_json_file(main_args_path)
ft_args = FineTuneArgs.from_json_file(ft_args_path)

# Test `ssl_forward`

In [3]:
pt_dataset = PatientDataset(min_seq_len=8, max_seq_len=8, eval_seq_len=8, task='ssl', 
                         signal_seconds=10, signal_mask=0.25, 
                         history_cutout_prob = 0.5, history_cutout_frac= 0.25, 
                         spatial_dropout_rate= 0.0, corrupt_rate=0.25)

In [4]:
model = CNNGRUModel(data_shape=[2, 1, 1250], use_cuda=False,
            hidden_dim=128, num_layers=2,
            bidirectional=False,
            pooling_method='avg',
            expander_fcs = [128,128],
            args=args
        )

In [5]:
batch = {'signals_timeseries1': torch.randn(2,2,1,1250, dtype=torch.float32),
         'signals_timeseries2': torch.randn(2,2,1,1250, dtype=torch.float32),
         'structured_timeseries1' : torch.randn(2,2,4, dtype=torch.float32),
         'structured_timeseries2' : torch.randn(2,2,4, dtype=torch.float32),
         'statics1' : torch.randn(2,6, dtype=torch.float32),
         'statics2' : torch.randn(2,6, dtype=torch.float32),
         'end_idx' : torch.Tensor([2,2]).type(torch.float32)}

In [6]:
res = model.ssl_forward(batch)

In [7]:
res[2]

{'total_loss': 0.8187847137451172,
 'traj_loss': 0.7595401406288147,
 'component_loss': 0.23697832226753235}

# Check normal forward

In [15]:
args.do_simclr = False
model = CNNGRUModel(data_shape=[2, 1, 1250], use_cuda=False,
            hidden_dim=128, num_layers=2,
            bidirectional=False,
            task_weights = {'example_task':1},
            pooling_method='last',
            expander_fcs = [128,128],
            args=args
        )

In [16]:
import numpy as np

In [19]:
batch = {'signals_timeseries': torch.randn(2,2,1,1250, dtype=torch.float32),
         'structured_timeseries' : torch.randn(2,2,4, dtype=torch.float32),
         'statics' : torch.randn(2,6, dtype=torch.float32),
         'example_task': torch.Tensor([0,1]).type(torch.int32),
         'end_idx' : torch.Tensor([1,1]).type(torch.float32)}

In [20]:
res = model.forward(batch)

In [21]:
res

(None,
 tensor([[ 4.2631e-02, -5.7899e-02,  7.4890e-02,  7.7913e-03,  4.9472e-02,
          -1.3801e-01,  7.7029e-03,  1.0021e-01,  2.1431e-01, -2.3640e-02,
           1.4885e-03, -2.3019e-02,  9.3461e-02,  5.8661e-02,  3.4646e-02,
           1.6821e-01,  5.7278e-02,  6.2329e-02, -5.9221e-02, -2.4537e-02,
           1.0785e-01,  1.8473e-01, -1.9742e-02,  6.4092e-02, -7.2477e-03,
           2.4889e-02,  2.3452e-02,  1.3512e-02, -6.8853e-03, -1.6377e-01,
          -1.6342e-02, -2.0358e-01, -1.2599e-01, -1.1685e-01, -1.0839e-01,
           4.9287e-02,  3.0330e-02,  5.9497e-02,  1.0639e-01, -6.9850e-02,
          -1.3475e-01, -2.5958e-02, -8.3263e-02,  5.2883e-03, -5.9462e-03,
           3.0242e-03, -5.4740e-02,  1.0761e-01, -1.1868e-01, -7.8829e-02,
          -8.0273e-02, -1.4055e-01, -6.6015e-02,  1.6010e-03,  1.9331e-01,
          -1.4051e-02, -5.1826e-02,  6.6731e-02,  3.8115e-02, -1.4743e-01,
           2.4530e-01,  7.5396e-02, -1.0014e-02,  1.1970e-01,  1.5352e-01,
          -1.5532e

In [22]:
res[2]

{'example_task': (array([[-0.09427005,  0.11021505],
         [ 0.01032159,  0.09823366]], dtype=float32),
  array([0, 1]),
  array(0.7253821, dtype=float32))}