# In [None]:
import pathlib

import matplotlib.pyplot as plt
import matplotlib.transforms as mtrans
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from sklearn import linear_model
from torch import nn
from torch.utils import data
from tqdm import tqdm, trange

from influence_functions import BaseObjective, CGInfluenceModule

# load dataset
from imitation.data import serialize
import os
# Locate the expert demonstrations under the project root and load them.
root_dir = "/scr/aliang80/changepoint_aug"
relative_dataset_path = "datasets/expert_dataset/assembly-v2_50"
print("load dataset from ", relative_dataset_path)

# Everything after this point uses the absolute path.
dataset_file = os.path.join(root_dir, relative_dataset_path)
expert_trajectories = serialize.load(dataset_file)

num_expert_trajectories = len(expert_trajectories)
print("number of expert trajectories: ", num_expert_trajectories)


# Build the trainer from the environment's observation/action spaces, the
# demonstration transitions, an RNG, a policy and a device.
# NOTE(review): `bc`, `env`, `transitions`, `rng`, `policy` and `device` are
# not defined in the visible part of this file -- confirm they are created in
# an earlier cell. `bc` is presumably `imitation.algorithms.bc`, and
# `transitions` is presumably derived from `expert_trajectories`; verify.
bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
    policy=policy,
    device=device,
)

# ===========
# Initialize influence module using custom objective
# ===========
class BinClassObjective(BaseObjective):
    """Binary-classification objective handed to the influence module.

    Batches are indexed positionally: ``batch[0]`` is fed to the model and
    ``batch[1]`` is used as the binary cross-entropy target.
    """

    def train_outputs(self, model, batch):
        """Return the model's outputs for the inputs stored in ``batch[0]``."""
        inputs = batch[0]
        return model(inputs)

    def train_loss_on_outputs(self, outputs, batch):
        """Return the binary cross-entropy of ``outputs`` against ``batch[1]``."""
        targets = batch[1]
        return F.binary_cross_entropy(outputs, targets)

    def train_regularization(self, params):
        """Return the squared norm of ``params`` scaled by ``L2_WEIGHT``."""
        squared_norm = torch.square(params.norm())
        return L2_WEIGHT * squared_norm

    def test_loss(self, model, params, batch):
        """Return the binary cross-entropy of ``model`` on ``batch``.

        ``params`` is part of the signature but is not used here.
        """
        inputs, targets = batch[0], batch[1]
        return F.binary_cross_entropy(model(inputs), targets)

# Assemble the pieces of the influence module one at a time, then construct it.
# NOTE(review): the trainer above is given `policy` and `device`, while this
# call uses `bc_policy` and `DEVICE` -- confirm both names are defined and
# refer to the intended objects.
influence_objective = BinClassObjective()
influence_train_loader = data.DataLoader(train_set, batch_size=32)
influence_test_loader = data.DataLoader(test_set, batch_size=32)

module = CGInfluenceModule(
    model=bc_policy,
    objective=influence_objective,
    train_loader=influence_train_loader,
    test_loader=influence_test_loader,
    device=DEVICE,
    # Solver settings, passed through unchanged.
    damp=0.001,
    atol=1e-8,
    maxiter=1000,
)

# ===========
# For each test point:
#   1. Get the influence scores for all training points
#   2. Find the most helpful and harmful training points
# The most helpful training point is the one whose removal would most increase the
# loss at the test point of interest (as predicted by the influence scores). Conversely,
# the most harmful training point is the one whose removal would most decrease that test loss.
# ===========

# One captioned image per test point, for the highest- and lowest-scoring
# training point respectively.
helpful_images, harmful_images = [], []

num_train_points = X_train.shape[0]
all_train_idxs = list(range(num_train_points))

for test_idx in tqdm(test_idxs, desc="Computing Influences"):
    # The loss gradients over the whole training set are recomputed on every
    # iteration; with enough memory they could be cached instead.
    influences = module.influences(
        train_idxs=all_train_idxs,
        test_idxs=[test_idx],
    )

    # Largest score -> most helpful training point.
    most_helpful_idx = influences.argmax()
    helpful = captioned_image(clf, embeds, "train", most_helpful_idx, influences)

    # Smallest score -> most harmful training point.
    most_harmful_idx = influences.argmin()
    harmful = captioned_image(clf, embeds, "train", most_harmful_idx, influences)

    helpful_images.append(helpful)
    harmful_images.append(harmful)