In [10]:
import importlib
import logging
import os
import pickle
import time

import hydra
import numpy as np
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from prol.process import (
    get_cycle,
    get_torch_dataset,
    get_task_indicies_and_map,
    get_sequence_indices
)

class SetParams:
    """Attribute-style view of a parameter mapping.

    Every key/value pair of the mapping passed to the constructor is set as
    an attribute on the instance, so ``SetParams({"lr": 1e-3}).lr`` is
    ``1e-3``.
    """

    def __init__(self, dict) -> None:
        # NOTE: the parameter name shadows the built-in `dict`; it is kept
        # unchanged so keyword callers (`SetParams(dict=...)`) keep working.
        entries = dict.items()
        for attr_name, attr_value in entries:
            setattr(self, attr_name, attr_value)

In [2]:
# input parameters
params = {
    "method": "proformer",
    # time between two task switches
    "N": 20,
    # training time
    "t": 500,
    # future time horizon
    "T": 5000,
    # task specification
    "task": [[0, 1], [2, 3]],
    "contextlength": 200,
    "seed": 1996,
    "image_size": 28,
    "device": "cuda:0",
    "lr": 1e-3,
    "batchsize": 64,
    "epochs": 500,
    "verbose": True,
    # number of test reps
    "reps": 100,
    "outer_reps": 3,
}

# Expose every entry of `params` as an attribute (args.N, args.t, ...).
args = SetParams(params)


In [3]:
# Dataset root directory. Set the PROL_DATA_ROOT environment variable to point
# at a different copy of the data; when it is unset the original location is
# used, so behaviour on the original machine is unchanged.
root = os.environ.get(
    'PROL_DATA_ROOT',
    '/cis/home/adesilva/ashwin/research/ProL/data',
)

# Load the dataset named 'mnist' from `root`.
torch_dataset = get_torch_dataset(root=root, name='mnist')

In [24]:
from prol.models.proformer import SequentialDataset as proSD
from prol.models.smallconv import SequentialDataset as convSD

In [25]:
# get indices for each task (computed from the dataset targets as a numpy array)
taskInd, maplab = get_task_indicies_and_map(tasks=args.task, y=torch_dataset.targets.numpy())

# get a training sequence
# `seed` is kept as its own variable, taken from the configured args.seed.
seed = args.seed
train_SeqInd, updated_taskInd = get_sequence_indices(
    N=args.N, total_time_steps=args.t, tasklib=taskInd, seed=seed, remove_train_samples=True
)

# form the train dataset
# Keyword arguments shared by the dataset constructors.
data_kwargs = {"dataset": torch_dataset, "seqInd": train_SeqInd, "maplab": maplab}

In [26]:
# Build one dataset per model implementation from the same `args` and the same
# `data_kwargs` (dataset, sequence indices, label map), so both wrap identical
# training data. proSD / convSD are the SequentialDataset classes imported
# from prol.models.proformer and prol.models.smallconv respectively.
pro_dataset = proSD(args, **data_kwargs)
conv_dataset = convSD(args, **data_kwargs)

In [57]:
# One DataLoader per dataset; only batch_size is set, every other DataLoader
# option is left at its default.
# NOTE(review): batch_size is the literal 16 here, not args.batchsize (64 in
# the params cell) — confirm this is intentional for the timing comparison.
proDL = DataLoader(pro_dataset, batch_size=16)
convDL = DataLoader(conv_dataset, batch_size=16)

In [58]:
# Time how long it takes to create an iterator over convDL and fetch its
# first batch. time.perf_counter() is used instead of time.time(): it is a
# monotonic, high-resolution clock intended for measuring short intervals,
# whereas time.time() is wall-clock time and can jump if the system clock is
# adjusted mid-measurement.
start = time.perf_counter()
batch = next(iter(convDL))
end = time.perf_counter()
print(f'Time taken : {end - start}')
# Number of top-level elements in the fetched batch.
print(len(batch))

Time taken : 0.0021600723266601562
2


In [59]:
# Time how long it takes to create an iterator over proDL and fetch its
# first batch. time.perf_counter() is used instead of time.time(): it is a
# monotonic, high-resolution clock intended for measuring short intervals,
# whereas time.time() is wall-clock time and can jump if the system clock is
# adjusted mid-measurement.
start = time.perf_counter()
batch = next(iter(proDL))
end = time.perf_counter()
print(f'Time taken : {end - start}')
# Number of top-level elements in the fetched batch.
print(len(batch))

Time taken : 0.01221776008605957
4
