# What are the geometric properties of datasets known to be amenable to repL?

This file evaluates what sort of embeddings pretrained classifier networks learn. I'm particularly interested in whether classes cluster in the embedding space. I expect this to be the case for classifiers, but it would be particularly interesting if it were also the case for models trained using unsupervised learning, since that would suggest there is some intrinsic relation between the geometry of samples from $p(x)$ and the label distribution $p(y \mid x)$.

In [None]:
# some black magic from https://github.com/pytorch/pytorch/issues/30966#issuecomment-582747929
# Replaces TensorFlow's `tf.io.gfile` with the stub implementation bundled with TensorBoard.
# NOTE(review): presumably this is what lets `SummaryWriter.add_embedding` work when
# TensorFlow is installed next to TensorBoard -- see the linked issue to confirm.
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

In [None]:
# useful because it has pretrained SimCLR models
# !pip install lightning-bolts["extra"]

In [None]:
import os
import glob
import re

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, datasets, transforms as T
from torchvision.datasets.utils import download_url, download_and_extract_archive

## Downloading STL10

We're going to try clustering a subset of STL10 (just its test split). STL10 is an ImageNet subset with images resized to 96$\times$96. We upscale (naively) to the right size so we can feed the images into an actual ImageNet model; because the native STL10 resolution is only 96$\times$96, the images will look a bit blurry if visualised.

In [None]:
class MiniSTL10(datasets.STL10):
    """STL10 with just the test set (1/10th the size).

    Only the two class attributes below are overridden; everything else is inherited
    from ``datasets.STL10``.
    """
    # Archive that contains only the test split.
    # NOTE(review): presumably read by the parent class's download logic -- confirm
    # against the torchvision STL10 implementation.
    url = 'https://www.qxcv.net/il/stl10_binary_test.tar.gz'
    # MD5 digest for the archive above (assumed to be what the parent verifies -- confirm).
    tgz_md5 = '1f2186acdb97f6a4a99f6ae43314f288'
# Test-only STL10 images: resized to 256, centre-cropped to 224x224 and converted to
# tensors so that they can be fed to ImageNet-sized models.
dataset = MiniSTL10(
    root='./data/stl10/',
    split='test',
    transform=T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]),
    download=True,
)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=32)

## Getting a pretrained ImageNet model

We'll use this pretrained model to produce our embeddings. We want the output of the penultimate layer.

In [None]:
# Two ResNet-18s with the same architecture: one with ImageNet-pretrained weights and one
# left at its random initialisation (the 'random' baseline in the STL10 embedding runs).
model_pretrained = models.resnet18(pretrained=True, progress=True).eval().cuda()
# The random baseline is put in eval mode too, like every other model in this file. In
# training mode BatchNorm normalises with per-batch statistics (and keeps updating its
# running averages), so an image's embedding would depend on the rest of its batch.
model_random = models.resnet18(pretrained=False).eval().cuda()

def resnet_get_avg_embedding(resnet, x):
    """Return the flattened, average-pooled features of a torchvision ResNet for batch `x`.

    Mirrors the forward pass in
    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    but stops before the final fully-connected layer (`resnet.fc` is never applied).

    NOTE(review): the original author guessed the result is 2048-dimensional (and that a
    random projection might be needed); the actual width is whatever `resnet.avgpool`
    emits for the given network -- confirm for the architecture in use.
    """
    assert isinstance(resnet, models.resnet.ResNet)
    stem = (resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    stages = (resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)
    features = x
    for module in (*stem, *stages, resnet.avgpool):
        features = module(features)
    # keep the batch dimension and collapse everything after it
    return torch.flatten(features, 1)

## Getting a pretrained SimCLR model

In [None]:
from pl_bolts.models.self_supervised import SimCLR

# checkpoint URL copied from https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html
simclr_weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
simclr = SimCLR.load_from_checkpoint(simclr_weight_path, strict=False)
# Only the encoder is used below: move it to the GPU and switch it to eval mode.
# Chaining into the assignment also keeps the notebook from printing the module.
simclr_resnet50 = simclr.encoder.cuda().eval()

## Downloading some MAGICAL data

In [None]:
# Fetch the MAGICAL data archive into data/ and extract it; later cells read from
# data/magical-data-2021-05-18/.
download_and_extract_archive(
    'https://www.qxcv.net/il/magical-data-2021-05-18.tar.xz',
    'data/',
)

In [None]:
# def make_imagenet_ds_magical():
#     # use only latest frame
#     obs_array = mag_data['obs'][:, -3:]
#     # convert to [0,1]
#     obs_tensor = torch.as_tensor(obs_array.astype('float32') / 255)
#     # resize to 224*224
#     obs_tensor = F.interpolate(obs_tensor, size=(224, 224), mode='bilinear', align_corners=False)
#     # extract labels
#     acts_tensor = torch.as_tensor(mag_data['acts'])
#     tn_tensor = torch.as_tensor(mag_data['traj_num'])
#     fn_tensor = torch.as_tensor(mag_data['frame_num'])
#     return torch.utils.data.TensorDataset(obs_tensor, acts_tensor, tn_tensor, fn_tensor)
#
# magical_dl = torch.utils.data.DataLoader(make_imagenet_ds_magical(), batch_size=32, shuffle=True)

def make_ilr_dl_magical(env_name_prefix, *, bs=32, shuffle=True):
    """Make a DataLoader for example data from the given environment (e.g. 'MoveToCorner' or 'MatchRegions').

    Args:
        env_name_prefix: substring identifying the environment. Exactly one file must match
            `data/magical-data-2021-05-18/mtr-extra-data/*{env_name_prefix}*.pt`.
        bs: batch size of the returned DataLoader.
        shuffle: whether the returned DataLoader shuffles its samples.

    Returns:
        A DataLoader over a TensorDataset of `(obs, acts, traj_num, frame_num)`, where
        `obs` has been cast to float32 and divided by 255.

    Raises:
        ValueError: if zero or more than one data file matches the pattern.
    """
    pattern = f'data/magical-data-2021-05-18/mtr-extra-data/*{env_name_prefix}*.pt'
    matches = glob.glob(pattern)
    if len(matches) != 1:
        # a bare tuple-unpack would only say "not enough/too many values to unpack"
        raise ValueError(
            f"Expected exactly one data file matching '{pattern}' for environment "
            f"'{env_name_prefix}', but found {len(matches)}: {sorted(matches)}")
    matching_path = matches[0]
    # NOTE: torch.load unpickles the file, so it must only be pointed at trusted data.
    mag_data = torch.load(matching_path)
    # 'next_obs' is not used below; drop our reference to it (tolerating files without it)
    mag_data.pop('next_obs', None)
    # use all frames
    obs_array = mag_data['obs']
    # convert to [0,1] (assumes 0..255 pixel values stored in a numpy array -- TODO confirm)
    obs_tensor = torch.as_tensor(obs_array.astype('float32') / 255)
    # extract labels
    acts_tensor = torch.as_tensor(mag_data['acts'])
    tn_tensor = torch.as_tensor(mag_data['traj_num'])
    fn_tensor = torch.as_tensor(mag_data['frame_num'])
    ds = torch.utils.data.TensorDataset(obs_tensor, acts_tensor, tn_tensor, fn_tensor)
    return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=shuffle)

## Getting some pretrained MAGICAL models

In [None]:
# Known MAGICAL environment names; used to group the traced networks by environment.
magical_env_name_re = re.compile(r'MatchRegions|MoveToCorner|MoveToRegion|ClusterColour|ClusterShape|FindDupe|FixColour|MakeLine')
# Matches the '_<EnvName>-<letters>-v<digits>' chunk of a file name so it can be removed.
remove_env_re = re.compile(f'_({magical_env_name_re.pattern})-[a-zA-Z]+-v[0-9]+')
model_dir = 'data/magical-data-2021-05-18/traced_nets'
model_files = [
    os.path.join(model_dir, file_name)
    for file_name in os.listdir(model_dir)
    if file_name.endswith('.pt')
]
magical_models_by_env = {}
for model_file_path in model_files:
    env_match = magical_env_name_re.search(model_file_path)
    if env_match is None:
        print(f"Could not find environment name for '{model_file_path}'")
        env_key = 'unk'
    else:
        env_key = env_match.group()
    traced_net = torch.jit.load(model_file_path, map_location=torch.device('cuda'))
    traced_net = traced_net.eval().cuda()
    # name the model after its file: drop the '.pt' suffix, then the env-name chunk
    file_stem = os.path.basename(model_file_path)[:-len('.pt')]
    short_name = remove_env_re.sub('', file_stem)
    if env_key not in magical_models_by_env:
        magical_models_by_env[env_key] = []
    magical_models_by_env[env_key].append({'name': short_name, 'model': traced_net})

print('Models by environment:')
for env_key, entries in magical_models_by_env.items():
    joined_names = ", ".join(entry["name"] for entry in entries)
    print(f'{env_key}: {joined_names}')

## Computing some STL10 embeddings

We're going to run a few STL10 images through our embedding function, then add them to a TensorBoard file that we can visualise.

In [None]:
# delete old runs before doing this, so the `runs` log directory only holds embeddings
# written by this session (without -f, rm reports an error if `runs` does not exist yet)
!rm -r runs

def _plain_label(value):
    """Convert a tensor/array label to a plain Python scalar or list; pass other values through."""
    return value.tolist() if hasattr(value, 'tolist') else value


def save_embeddings_to_file(loader, compute_embeddings, run_name, env_name='magical', frame_width=48,
                            *, target_num_embeddings=512):
    """Embed batches from `loader` and write them to `runs/<run_name>` as a TensorBoard embedding.

    Batches are consumed until at least `target_num_embeddings` embeddings have been
    collected or the loader runs out. The final batch is kept whole, so slightly more
    than the target may be written.

    Args:
        loader: iterable yielding `(inputs, *labels)` batches.
            * no labels: the embeddings are written without metadata;
            * one label tensor: written as a single unnamed metadata column;
            * three label tensors: interpreted according to `env_name`.
        compute_embeddings: callable mapping an input batch to something
            `np.concatenate` accepts (the callers in this file return numpy arrays).
        run_name: used as the log sub-directory, the writer comment and the embedding tag.
        env_name: 'magical' means the labels are `(acts, traj_num, frame_num)` and three
            extra columns `frame_num // 10`, `// 20`, `// 30` are derived; any other
            value means the labels are `(acts, rews, trajs)`.
        frame_width: side length (pixels) each frame is resized to for the thumbnails.
        target_num_embeddings: minimum number of embeddings to collect before stopping.

    Raises:
        ValueError: if `loader` yields no batches at all.
    """
    embeddings = []
    imgs_32 = []
    label_rows = None
    label_names = None
    n_embed = 0
    print(f'Generating >={target_num_embeddings} embeddings')
    for batch_in, *batch_labs in loader:
        if not batch_labs:
            # unlabelled data: nothing to record, embeddings are written without metadata
            pass
        elif len(batch_labs) == 1:
            if label_rows is None:
                label_rows = []
            label_rows.extend(batch_labs[0].numpy())
        elif env_name == 'magical':
            if label_rows is None:
                label_rows = []
                label_names = ['acts', 'traj_num', 'frame_num', 'time_div10', 'time_div20', 'time_div30']
            acts, traj_nums, frame_nums = batch_labs
            for act, traj_num, frame_num in zip(acts, traj_nums, frame_nums):
                # plain Python values so the metadata reads `3` rather than `tensor(3)`
                label_rows.append([
                    _plain_label(act), _plain_label(traj_num), _plain_label(frame_num),
                    _plain_label(frame_num // 10), _plain_label(frame_num // 20),
                    _plain_label(frame_num // 30),
                ])
        else:
            if label_rows is None:
                label_rows = []
                label_names = ['acts', 'rews', 'trajs']
            acts, rews, trajs = batch_labs
            for act, rew, traj in zip(acts, rews, trajs):
                label_rows.append([_plain_label(act), _plain_label(rew), _plain_label(traj)])
        with torch.no_grad():
            batch_embeddings = compute_embeddings(batch_in)
            # resize each frame to frame_width x frame_width, then tile the frames of each
            # stack into one thumbnail (channels are regrouped in threes, one group per frame)
            batch_in_resize = F.interpolate(batch_in, size=(frame_width, frame_width), mode='bilinear',
                                            align_corners=False)
            batch_in_resize = batch_in_resize.detach()
            batch_in_stack = torch.reshape(batch_in_resize,
                                           (batch_in_resize.shape[0], -1, 3,) + batch_in_resize.shape[2:])
            batch_in_stack_t = torch.movedim(batch_in_stack, 1, 0)
            double_resize = torch.reshape(
                batch_in_stack_t,
                (max(1, batch_in_stack_t.shape[0] // 2), -1) + batch_in_stack_t.shape[1:])
            cat_h = torch.cat(list(double_resize), dim=4)
            cat_v = torch.cat(list(cat_h), dim=2)
            imgs_32.append(cat_v.detach().cpu())

        embeddings.append(batch_embeddings)
        n_embed += len(batch_embeddings)
        # Debug print, produces a lot of extra console output
        # print(f'Have {n_embed} embeddings')
        if n_embed >= target_num_embeddings:
            break
    if not embeddings:
        raise ValueError(f"loader produced no batches; nothing to write for run '{run_name}'")
    embeddings = np.concatenate(embeddings, axis=0)
    # thumbnails are assembled but currently not written (label_img is disabled below)
    imgs_32 = torch.cat(imgs_32, dim=0)

    # write to TB
    writer = SummaryWriter(log_dir=f'runs/{run_name}', comment=run_name)
    try:
        writer.add_embedding(
            embeddings,
            metadata=label_rows,
            metadata_header=label_names,
            tag=run_name,
            # XXX label_img=imgs_32,
        )
        writer.flush()
    finally:
        # close the event file even if add_embedding fails
        writer.close()

# magical_acts_dl, magical_fn_dl, magical_tn_dl
# save_embeddings_to_file(
#     magical_dl,
#     lambda b: simclr_resnet50(b.cuda())[0].detach().cpu().numpy(),
#     'magical_mtr_with_simclr_resnet50')

In [None]:
# One TensorBoard run per STL10 embedding function. Each function moves an image batch to
# the GPU, embeds it, and hands the result back as a numpy array on the CPU.
stl10_embedders = {
    'stl10_with_pretrained_resnet18':
        lambda b: resnet_get_avg_embedding(model_pretrained, b.cuda()).detach().cpu().numpy(),
    'stl10_with_random_resnet18':
        lambda b: resnet_get_avg_embedding(model_random, b.cuda()).detach().cpu().numpy(),
    # the encoder output is indexed with [0]
    # (presumably it returns a list of feature tensors -- confirm against pl_bolts)
    'stl10_with_simclr_resnet50':
        lambda b: simclr_resnet50(b.cuda())[0].detach().cpu().numpy(),
}
for run_name, embed_fn in stl10_embedders.items():
    save_embeddings_to_file(dataloader, embed_fn, run_name)

## Computing some MAGICAL embeddings

Going to repeat this for each dataset + model.

In [None]:
# For every MAGICAL environment, build one DataLoader and embed its data with each of that
# environment's networks; runs are named '<env>_<model name>'.
for env_name_prefix, model_list in magical_models_by_env.items():
    env_dl = make_ilr_dl_magical(env_name_prefix)
    for model_dict in model_list:
        model = model_dict['model'].cuda()
        run_name = f"{env_name_prefix}_{model_dict['name']}"

        def embed_batch(b, net=model):
            return net(b.cuda()).detach().cpu().numpy()

        save_embeddings_to_file(env_dl, embed_batch, run_name)
    # drop our reference to this environment's loader before building the next one
    del env_dl

## Get Procgen/DMC data and models

In [None]:
# Number of leading samples kept from each Procgen/DMC dataset.
data_size = 512
# Checkpoint directories. Only one is active at a time: `models_path` is the one that
# gets used, and it currently points at the DMC checkpoints.
procgen_models_path = '/home/cynthiachen/il-representations/analysis/data/procgen'
dmc_models_path = '/home/cynthiachen/il-representations/analysis/data/dmc'
models_path = dmc_models_path
# Roots of the demonstration data.
dmc_data_basepath = '/home/cynthiachen/il-representations/data/dm_control'
procgen_data_basepath = '/scratch/cynthiachen/procgen_demo/'

from il_representations.envs.utils import stack_obs_oldest_first
from il_representations.envs.dm_control_envs import _load_pkl_gz_lists


def make_ilr_dl_procgen(data_path, *, bs=32, shuffle=True):
    """Build a DataLoader over the first `data_size` samples of a Procgen demo file.

    The file is read with `np.load(..., allow_pickle=True)`. Its 'dones', 'obs', 'acts'
    and 'rews' entries are each concatenated along axis 0 and cut to the module-level
    `data_size`. Batches are `(obs, acts, rews, trajectory label)`; `bs` and `shuffle`
    are passed straight to the DataLoader.
    """
    demo = np.load(data_path, allow_pickle=True)

    def head(key):
        # join the stored arrays for `key` and keep only the first data_size entries
        return np.concatenate(demo[key], axis=0)[:data_size]

    # Trajectory label per sample: a counter that goes up right after each done flag.
    traj_labels = []
    traj_idx = 0
    for done in head('dones'):
        traj_labels.append(traj_idx)
        if done:
            traj_idx += 1

    # Divide by 255 (assumes 0..255 pixel values -- TODO confirm), then move the last
    # axis (channels) to position 1.
    frames = np.transpose(head('obs') / 255., (0, 3, 1, 2))
    stacked = stack_obs_oldest_first(frames, frame_stack=4, use_zeroed_frames=False)

    dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(stacked),
        torch.tensor(head('acts')),
        torch.tensor(head('rews')).int(),
        torch.tensor(traj_labels),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle)


def make_ilr_dl_dmc(data_path, *, bs=32, shuffle=True):
    """Build a DataLoader over the first `data_size` samples of a DMC demonstration file.

    Trajectories are loaded with `_load_pkl_gz_lists([data_path])`. Observations are
    frame-stacked per trajectory (dropping each trajectory's last observation), then
    everything is concatenated across trajectories and cut to the module-level
    `data_size`. Batches are `(obs, acts, rews, trajectory label)`; `bs` and `shuffle`
    are passed straight to the DataLoader.
    """
    trajectories = _load_pkl_gz_lists([data_path])

    def head(per_traj_arrays):
        # join the per-trajectory arrays and keep only the first data_size entries
        return np.concatenate(per_traj_arrays, axis=0)[:data_size]

    # Synthetic done flags: False for every action except the last one of a trajectory.
    # (This slice limits the number of trajectories, not samples; `head` trims the rest.)
    dones_lists = [
        np.array([False] * (len(traj.acts) - 1) + [True], dtype='bool')
        for traj in trajectories
    ][:data_size]
    stacked_per_traj = [
        stack_obs_oldest_first(traj.obs[:-1], frame_stack=3, use_zeroed_frames=True)
        for traj in trajectories
    ]
    # divide by 255 (assumes 0..255 pixel values -- TODO confirm)
    cat_obs = head(stacked_per_traj) / 255.
    acts_tensor = torch.tensor(head([traj.acts for traj in trajectories]))
    rews_tensor = torch.tensor(head([traj.rews for traj in trajectories])).int()

    # Trajectory label per sample: a counter that goes up right after each done flag.
    traj_labels = []
    traj_idx = 0
    for done in head(dones_lists):
        traj_labels.append(traj_idx)
        if done:
            traj_idx += 1

    dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(cat_obs),
        acts_tensor,
        rews_tensor,
        torch.tensor(traj_labels),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle)


def get_models(model_dir):
    """Load every `.ckpt` model in `model_dir`, grouped by the environment in its file name.

    The file name without its suffix is split on '-':
      * three parts (DMC names):     the first two parts, re-joined with '-', are the env name;
      * two parts (Procgen names):   the first part is the env name.
    The last part is taken to be the algorithm and is not used. Files whose names fit
    neither shape are reported and skipped without being loaded, so they can neither
    crash the loop nor be filed under the previous file's environment.

    Args:
        model_dir: directory to scan (non-recursively) for `.ckpt` files.

    Returns:
        Dict mapping env name to a list of `{'name': <file name without suffix>,
        'model': <loaded model, in eval mode, on CUDA>}` dicts.
    """
    suffix = '.ckpt'
    models_by_env = {}
    for file_name in os.listdir(model_dir):
        if not file_name.endswith(suffix):
            continue
        model_file_path = os.path.join(model_dir, file_name)
        # identify model by basename but strip the .ckpt suffix
        model_basename = file_name[:-len(suffix)]
        model_basename_parts = model_basename.split('-')
        if len(model_basename_parts) == 3:  # dmc names
            env_name = '-'.join(model_basename_parts[:2])
        elif len(model_basename_parts) == 2:  # procgen
            env_name = model_basename_parts[0]
        else:
            print(f"Could not find environment name for '{model_file_path}'; skipping it")
            continue
        # NOTE: torch.load unpickles the checkpoint, so only trusted files belong in model_dir.
        model = torch.load(model_file_path, map_location=torch.device('cuda')).eval().cuda()
        models_by_env.setdefault(env_name, []).append({'name': model_basename, 'model': model})
    return models_by_env

# Load every checkpoint under `models_path` and list the model names found per environment.
print('Models by environment:')
models_by_env = get_models(models_path)
for env, entries in models_by_env.items():
    model_names = [entry["name"] for entry in entries]
    print(f'{env}: {", ".join(model_names)}')


## Get representation encodings

In [None]:
# For each environment, build its DataLoader once and write one TensorBoard embedding run
# per model trained on that environment.
for env_name, model_list in models_by_env.items():
    # Environment names containing a '-' are treated as DMC, all others as Procgen.
    # This can be made more robust, though currently it's an easy way to tell.
    is_dmc = '-' in env_name
    print(env_name)
    if is_dmc:
        dmc_pattern = f'{dmc_data_basepath}/{env_name}-*.pkl.gz'
        # sorted so the same file is picked every time if several match
        dmc_matches = sorted(glob.glob(dmc_pattern))
        if not dmc_matches:
            raise FileNotFoundError(f"No DMC demonstration file matches '{dmc_pattern}'")
        env_dl = make_ilr_dl_dmc(data_path=dmc_matches[0])
    else:
        # os.path.join avoids a doubled '/' when the base path already ends with one
        data_path = os.path.join(procgen_data_basepath, f'demo_{env_name}.pickle')
        env_dl = make_ilr_dl_procgen(data_path=data_path)
    for model_dict in model_list:
        name = model_dict['name']
        model = model_dict['model'].cuda()
        save_embeddings_to_file(
            env_dl,
            # the `.mean` attribute of the model's output is used as the embedding;
            # `model=model` pins this iteration's model to the lambda
            lambda b, model=model: model(b.cuda(), traj_info=None).mean.detach().cpu().numpy(),
            name,
            env_name='procgen',
            frame_width=64
        )
    # drop our reference to this environment's loader before building the next one
    del env_dl

## Visualising it all in TensorBoard

In [None]:
%load_ext tensorboard
%tensorboard --logdir=runs