In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from copy import deepcopy
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import random
import numpy as np
import math

In [42]:
dataset = torchvision.datasets.MNIST(
            root="data",
            train=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
                transforms.Lambda(lambda x : torch.flatten(x))
            ]),
            download=True
        )

In [3]:
def get_task_indicies_and_map(tasks, y):
    tasklib = {}
    for i, task in enumerate(tasks):
        tasklib[i] = []
        for lab in task:
            tasklib[i].extend(
                np.where(y == lab)[0].tolist()
            )
    mapdict = {}
    for task in tasks:
        for i, lab in enumerate(task):
            mapdict[lab] = i
    maplab = lambda lab : mapdict[lab]
    return tasklib, maplab

In [4]:
def get_unit_cycle(N):
    return [1] * N + [0] * N

In [45]:
def get_sequence_indices(N, total_time_steps, tasklib, seed=1996):
    unit = get_unit_cycle(N)
    pattern = np.array((unit * math.ceil(total_time_steps/(len(unit))))[:total_time_steps]).astype("bool")
    seqInd = np.zeros((total_time_steps,)).astype('int')
    np.random.seed(seed)
    seqInd[pattern] = np.random.choice(tasklib[0], sum(pattern), replace=False)
    seqInd[~pattern] = np.random.choice(tasklib[1], sum(~pattern), replace=False)
    return seqInd

In [33]:
class SequentialDataset(Dataset):
    def __init__(self, dataset, seqInd, maplab, contextlength=200):
        self.dataset = dataset
        self.contextlength = contextlength
        self.t = len(seqInd)
        self.time = torch.arange(self.t).float()
        self.seqInd = seqInd
        self.maplab = maplab

        self.transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
                transforms.Lambda(lambda x : x.view(x.shape[0], x.shape[-1]**2))
            ])

    def __len__(self):
        return len(self.seqInd)

    def __getitem__(self, idx):
        r = np.random.randint(0, len(self.seqInd)-2*self.contextlength) # select the end of the subsequence
        s = np.random.randint(r+self.contextlength, r+2*self.contextlength)  # select a 'future' time beyond the subsequence

        id = list(range(r, r+self.contextlength)) + [s]
        dataid = self.seqInd[id]

        data = torch.cat(
            [self.dataset.__getitem__(i)[0][None, :] for i in dataid], axis=0
        )
        labels = torch.Tensor(
            [self.dataset.__getitem__(i)[1] for i in dataid]
        ).long().apply_(self.maplab)
        time = self.time[id]
        
        target = labels[-1].clone()
        labels[-1] = np.random.binomial(1, 0.5)

        return data, time, labels, target

In [34]:
total_time_steps = 1000
N = 10
tasks = [[0, 1], [2, 3]]

In [35]:
tasklib, maplab = get_task_indicies_and_map(tasks)
seqInd = get_sequence_indices(N, total_time_steps, tasklib)
seqDataset = SequentialDataset(dataset, seqInd, maplab)
loader = DataLoader(seqDataset)

In [36]:
data, time, labels, target = next(iter(loader))

In [37]:
data.shape

torch.Size([1, 201, 784])

In [38]:
time

tensor([[225., 226., 227., 228., 229., 230., 231., 232., 233., 234., 235., 236.,
         237., 238., 239., 240., 241., 242., 243., 244., 245., 246., 247., 248.,
         249., 250., 251., 252., 253., 254., 255., 256., 257., 258., 259., 260.,
         261., 262., 263., 264., 265., 266., 267., 268., 269., 270., 271., 272.,
         273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283., 284.,
         285., 286., 287., 288., 289., 290., 291., 292., 293., 294., 295., 296.,
         297., 298., 299., 300., 301., 302., 303., 304., 305., 306., 307., 308.,
         309., 310., 311., 312., 313., 314., 315., 316., 317., 318., 319., 320.,
         321., 322., 323., 324., 325., 326., 327., 328., 329., 330., 331., 332.,
         333., 334., 335., 336., 337., 338., 339., 340., 341., 342., 343., 344.,
         345., 346., 347., 348., 349., 350., 351., 352., 353., 354., 355., 356.,
         357., 358., 359., 360., 361., 362., 363., 364., 365., 366., 367., 368.,
         369., 370., 371., 3

In [39]:
data

tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]])

In [40]:
labels

tensor([[1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
         1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0,
         1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1,
         1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
         1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1,
         0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0,
         0, 1, 0, 1, 0, 1, 1, 0, 0]])

In [51]:
seqInds = [
        list(get_sequence_indices(10, 100, tasklib, seed=1996 + 1000*rep))
        for rep in range(10)
    ]

In [52]:
seqInds

[[54087,
  8472,
  42625,
  43527,
  17960,
  46797,
  46936,
  14794,
  52444,
  22752,
  42971,
  48781,
  53488,
  34710,
  22386,
  25645,
  40416,
  31513,
  40447,
  17046,
  34168,
  47501,
  5694,
  23492,
  17381,
  52995,
  42170,
  58234,
  8894,
  31377,
  28709,
  43964,
  9685,
  45342,
  53583,
  5762,
  29071,
  46092,
  58483,
  25688,
  49964,
  8255,
  24893,
  2218,
  9204,
  46975,
  30486,
  56668,
  35262,
  45539,
  34573,
  3553,
  16556,
  31379,
  12260,
  51909,
  49034,
  49321,
  25442,
  31868,
  11112,
  45721,
  55963,
  51190,
  27202,
  37405,
  15977,
  39685,
  31348,
  4013,
  24790,
  26509,
  58200,
  45781,
  45216,
  53283,
  24610,
  46286,
  1689,
  15093,
  32165,
  18093,
  44127,
  40974,
  825,
  56349,
  47735,
  15156,
  58322,
  30238,
  6269,
  44312,
  22083,
  12204,
  13325,
  46974,
  1129,
  15533,
  13332,
  34171],
 [40042,
  28131,
  52745,
  34638,
  24397,
  43567,
  36780,
  53173,
  8124,
  29405,
  4241,
  27518,
  32545,

In [50]:
list(range(-5, 0))

[-5, -4, -3, -2, -1]