In [18]:
import torch
import os
import random
from collections import defaultdict, Counter
import numpy as np

In [3]:
# Display the kernel's current working directory.
os.getcwd()

'/home/santari/Projects/pred2control/notebooks'

In [4]:
# Resolve the dataset relative to the project root instead of a user-specific
# absolute path. The kernel's working directory is <project>/notebooks, so the
# project root is its parent directory.
PROJECT_ROOT = os.path.dirname(os.getcwd())
DATA_PATH = os.path.join(PROJECT_ROOT, "data", "pred2control_target.pt")

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)


device: cuda


In [5]:
# Load the serialized dataset onto the CPU regardless of the device it was
# saved from; the normalization statistics in this notebook are accumulated
# into CPU tensors.
# NOTE(review): torch.load unpickles the file, so only load trusted data.
# Consider weights_only=True if the payload holds only tensors and plain
# Python containers.
payload = torch.load(DATA_PATH, map_location="cpu")

In [6]:
# Inspect the container type and the top-level keys of the loaded payload.
print(type(payload))
payload.keys()

<class 'dict'>


dict_keys(['episodes', 'meta', 'fps', 'action_dim'])

In [7]:
# Peek at the first metadata record to see which fields it carries.
payload["meta"][0]

{'episode_id': 0, 'category': 1}

In [8]:
# Pull the per-episode tensors and their metadata records out of the payload.
episodes, meta = payload["episodes"], payload["meta"]

# Sanity checks: one metadata record per episode, and a uniform episode shape.
assert len(meta) == len(episodes)
assert not any(ep.shape != (300, 6) for ep in episodes), "Expected all episodes to be (300,6)"

# Summarize the dataset size and how many episodes each category has.
print("Loaded:", len(episodes), "episodes")
print("Categories:", Counter(record["category"] for record in meta))

Loaded: 50 episodes
Categories: Counter({1: 10, 2: 10, 3: 10, 4: 10, 5: 10})


In [15]:
SEED = 67  # seed for the RNG that drives the train/test split
TEST_PER_CATERGORY = 3  # episodes held out per category (sic: misspelled name kept, other cells reference it)
EPS = 1e-6  # numerical floor used by the normalization code
L = 150 # context length

In [10]:
# Dedicated RNG instance seeded from the config constant.
rng = random.Random(SEED)

# Map each category label to the indices of the episodes that belong to it.
cat_to_ids = defaultdict(list)
for idx, record in enumerate(meta):
    label = int(record["category"])
    cat_to_ids[label].append(idx)

In [11]:
# check the episodes in this category (1-5)
print(cat_to_ids[5])
print(cat_to_ids.keys())

[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
dict_keys([1, 2, 3, 4, 5])


In [12]:
# Stratified split: hold out TEST_PER_CATERGORY episodes from every category.
# Re-seeding here makes the cell idempotent: re-running it reproduces the same
# split instead of continuing from wherever the RNG stream was left. On a fresh
# run this is a no-op because rng was just constructed with the same seed.
rng.seed(SEED)

train_ids, test_ids = [], []

for cat, cat_ids in cat_to_ids.items():
    assert len(cat_ids) > TEST_PER_CATERGORY, f"category {cat} has too few episodes to split"
    # Shuffle a copy so the index lists stored in cat_to_ids keep their order.
    ids = list(cat_ids)
    rng.shuffle(ids)
    test_ids.extend(ids[:TEST_PER_CATERGORY])
    train_ids.extend(ids[TEST_PER_CATERGORY:])

train_ids = sorted(train_ids)
test_ids = sorted(test_ids)

# The two sides must partition the dataset: no overlap, nothing dropped.
assert set(train_ids).isdisjoint(test_ids), "train/test splits overlap"
assert len(train_ids) + len(test_ids) == len(meta), "split does not cover every episode"

print("Train:", len(train_ids), "Test:", len(test_ids))  # should be 35 / 15
print("Train cats:", Counter([meta[i]["category"] for i in train_ids]))
print("Test cats:", Counter([meta[i]["category"] for i in test_ids]))


Train: 35 Test: 15
Train cats: Counter({1: 7, 2: 7, 3: 7, 4: 7, 5: 7})
Test cats: Counter({1: 3, 2: 3, 3: 3, 4: 3, 5: 3})


### Normalization

In [16]:
def fit_action_normalizer(episodes_list, eps=1e-6):
    """Compute the per-dimension mean and standard deviation over all timesteps.

    Statistics are pooled over every row of every episode and use the
    population variance (divided by the total number of timesteps).

    Args:
        episodes_list: non-empty iterable of 2-D tensors shaped (T, D). T may
            differ between episodes; D must be the same for all of them.
        eps: lower bound applied to the variance before the square root, so a
            constant dimension yields std == sqrt(eps) instead of 0.

    Returns:
        (mean, std): two float32 tensors of shape (D,).

    Raises:
        ValueError: if episodes_list is empty or contains no timesteps.
    """
    episodes_list = list(episodes_list)
    if not episodes_list:
        raise ValueError("fit_action_normalizer needs at least one episode")

    first = episodes_list[0]
    dim = first.shape[-1]
    # Accumulate in float64: E[x^2] - E[x]^2 loses precision in float32 when
    # the mean is large relative to the spread.
    total = torch.zeros(dim, dtype=torch.float64, device=first.device)
    total2 = torch.zeros(dim, dtype=torch.float64, device=first.device)
    count = 0

    for ep in episodes_list:
        ep = ep.to(torch.float64)
        total += ep.sum(dim=0)
        total2 += (ep * ep).sum(dim=0)
        count += ep.shape[0]

    if count == 0:
        raise ValueError("fit_action_normalizer needs at least one timestep")

    mean = total / count
    var = total2 / count - mean * mean
    # Clamp before sqrt: rounding can push the variance slightly below zero.
    std = torch.sqrt(torch.clamp(var, min=eps))
    return mean.float(), std.float()

# Materialize both splits, then fit the statistics on the training episodes only.
train_eps = [episodes[idx] for idx in train_ids]
test_eps = [episodes[idx] for idx in test_ids]

mean, std = fit_action_normalizer(train_eps, eps=EPS)
print("mean:", mean)
print("std :", std)


mean: tensor([ -3.5770, -45.7268,  32.1744,  74.3475,  -2.3436,   0.5599])
std : tensor([14.9339, 54.8193, 55.3805, 10.9193,  3.7141,  0.1235])


In [17]:
def normalize_episode(ep, mean, std, eps=1e-6):
    """Standardize an episode with precomputed per-dimension statistics.

    Args:
        ep: tensor of actions; it is cast to float32 before normalizing.
        mean: per-dimension mean, broadcast against ep.
        std: per-dimension standard deviation, broadcast against ep.
        eps: small constant added to std to guard against division by zero.
            Defaults to 1e-6, the same default as fit_action_normalizer, so
            the function no longer depends on a module-level constant.

    Returns:
        The standardized episode, (ep - mean) / (std + eps), as a float tensor.
    """
    return (ep.float() - mean) / (std + eps)

# Apply the train-fitted statistics to both splits.
train_eps_n, test_eps_n = (
    [normalize_episode(ep, mean, std) for ep in split]
    for split in (train_eps, test_eps)
)

# quick sanity check: pooled over all training timesteps, each dimension
# should come out with ~0 mean and ~1 std
all_train = torch.cat(train_eps_n, dim=0)
print("norm train mean:", torch.mean(all_train, dim=0))
print("norm train std :", torch.std(all_train, dim=0))


norm train mean: tensor([ 4.0690e-08,  2.9337e-08,  1.1626e-07,  9.3006e-07, -1.0536e-08,
         1.3551e-07])
norm train std : tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [22]:
# train_eps_n is a Python list of per-episode tensors, so it has no `.shape`;
# report the number of episodes and the shape of one episode instead.
len(train_eps_n), train_eps_n[0].shape

(35, torch.Size([300, 6]))

### Batching

In [None]:
import numpy as np

def sample_batch_stage2(episodes, batch_size, L=150, rng=None):
    """Sample (context window, next-step target) pairs uniformly at random.

    For every batch element an episode is drawn uniformly, then a target index
    t is drawn uniformly from [L, T), where T is that episode's length. The
    context is the L rows immediately preceding t.

    Args:
        episodes: non-empty sequence of 2-D tensors shaped (T, D). Every
            episode must be longer than L and share the same D.
        batch_size: number of pairs to draw.
        L: context length in timesteps.
        rng: optional numpy.random.Generator for reproducible sampling. When
            omitted, the global numpy.random state is used.

    Returns:
        X: float32 tensor of shape (batch_size, L, D) holding the contexts.
        Y: float32 tensor of shape (batch_size, D) holding the targets.

    Raises:
        ValueError: if episodes is empty or a sampled episode has T <= L.
    """
    if len(episodes) == 0:
        raise ValueError("episodes must not be empty")

    # Both samplers take (low, high) with an exclusive upper bound.
    randint = np.random.randint if rng is None else rng.integers

    B = batch_size
    D = episodes[0].shape[-1]
    X = torch.empty(B, L, D, dtype=torch.float32)
    Y = torch.empty(B, D, dtype=torch.float32)

    for b in range(B):
        ep = episodes[randint(0, len(episodes))]
        T = ep.shape[0]
        if T <= L:
            raise ValueError(
                f"episode length {T} must exceed context length {L}"
            )
        t = randint(L, T)  # target index
        X[b] = ep[t - L:t]
        Y[b] = ep[t]
    return X, Y

# smoke test: draw one small batch and print the resulting tensor shapes
X, Y = sample_batch_stage2(train_eps_n, batch_size=8, L=L)
print(X.shape, Y.shape)  # expect torch.Size([8, 150, 6]) torch.Size([8, 6])
