In [17]:
import importlib
import logging
import os
import pickle

import hydra
import numpy as np
from tqdm.auto import tqdm

from prol.process import (
    get_cycle,
    get_torch_dataset,
    get_task_indicies_and_map,
    get_sequence_indices
)
from prol.utils import get_dataloader

In [18]:
class SetParams:
    """Expose the entries of a mapping as instance attributes."""

    def __init__(self, dict) -> None:
        """Copy every ``key -> value`` pair of ``dict`` onto ``self``.

        Parameters
        ----------
        dict : mapping
            Any object providing ``.items()``; each key becomes an
            attribute name and each value the attribute's value.
        """
        for attr_name, attr_value in dict.items():
            setattr(self, attr_name, attr_value)

def get_modules(name):
    """Import the model module and its matching data-handler module.

    Parameters
    ----------
    name : str
        Method name; resolved to ``prol.models.<name>`` and
        ``prol.datahandlers.<name>_handle``.

    Returns
    -------
    tuple
        ``(model_module, handler_module)``, in that order.

    Raises
    ------
    ImportError
        If either module cannot be imported. The underlying import error
        is attached as ``__cause__``.
    """
    try:
        model_module = importlib.import_module(f"prol.models.{name}")
        handler_module = importlib.import_module(f"prol.datahandlers.{name}_handle")
    except ImportError as exc:
        # The previous version only printed here and then fell through to the
        # return statement, which failed with UnboundLocalError and hid the
        # real import problem. Fail loudly with the original cause chained.
        raise ImportError(f"Module {name} not found") from exc
    return model_module, handler_module

In [20]:
# Input parameters for this run.
# NOTE(review): "method" is passed to get_modules(), which imports
# prol.models.<method> and prol.datahandlers.<method>_handle, so it must name
# an existing module pair. The earlier hint listed {proformer, cnn, mlp} but
# the value here is "timecnn" -- confirm the full set of valid choices.
params = dict(
    dataset="mnist",
    method="timecnn",
    N=20,                     # time between two task switches
    t=500,                    # training time
    T=5000,                   # future time horizon
    task=[[0, 1], [2, 3]],    # task specification
    contextlength=200,
    seed=1996,
    image_size=64,
    device="cuda:0",
    lr=1e-3,
    batchsize=64,
    epochs=1000,
    verbose=True,
    reps=100,                 # number of test reps
    outer_reps=3,
)
# Attribute-style access to the settings (args.N, args.seed, ...).
args = SetParams(params)

In [22]:
# get source dataset
# The data directory can be overridden through the PROL_DATA_ROOT environment
# variable so the notebook is not tied to one machine; the original absolute
# path is kept as the fallback, so behaviour is unchanged when it is unset.
root = os.environ.get(
    "PROL_DATA_ROOT",
    '/cis/home/adesilva/ashwin/research/ProL/data',
)
torch_dataset = get_torch_dataset(root, name=args.dataset)

In [24]:
# Get the indices for each task from the dataset targets; `maplab` is handed
# to the dataset wrappers built further down.
taskInd, maplab = get_task_indicies_and_map(tasks=args.task, y=torch_dataset.targets.numpy())

In [25]:
# Draw the training sequence of length args.t. The second return value is the
# task index collection that the test-sequence sampling below consumes
# (presumably with the drawn training samples removed -- confirm against
# get_sequence_indices).
train_SeqInd, updated_taskInd = get_sequence_indices(
    N=args.N,
    total_time_steps=args.t,
    tasklib=taskInd,
    seed=args.seed,  # was a literal 1996 duplicating params["seed"]
    remove_train_samples=True,
)

In [26]:
# sample a bunch of test sequences (args.reps of them, each of length args.T)
# NOTE(review): every repetition passes the same seed. If get_sequence_indices
# is deterministic for a fixed seed, all sequences in this list are identical
# -- confirm whether a per-repetition seed was intended.
test_seqInds = [
    get_sequence_indices(args.N, args.T, updated_taskInd, seed=args.seed)  # was a literal 1996
    for _ in range(args.reps)
]

In [27]:
# Resolve the two modules for the configured method: `method` is the module
# imported from prol.models, `datahandler` the one from prol.datahandlers.
method, datahandler = get_modules(args.method)

In [28]:
# Wrap the source dataset and the training sequence indices in the
# method-specific training dataset class.
data_kwargs = dict(dataset=torch_dataset, seqInd=train_SeqInd, maplab=maplab)
train_dataset = datahandler.VisionSequentialDataset(args, **data_kwargs)

In [30]:
# Inspect one training item; subscript syntax is the idiomatic way to invoke
# __getitem__ rather than calling the dunder method directly.
train_dataset[1]

(tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.000

In [32]:
# Build the test dataset for the first sampled test sequence; the training
# sequence indices are passed along as well.
test_kwargs = dict(
    dataset=torch_dataset,
    train_seqInd=train_SeqInd,
    test_seqInd=test_seqInds[0],
    maplab=maplab,
)
test_dataset = datahandler.VisionSequentialTestDataset(args, **test_kwargs)

In [36]:
# Inspect one test item; subscript syntax is the idiomatic way to invoke
# __getitem__ rather than calling the dunder method directly.
test_dataset[10]

(tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.000