In [1]:
%cd ../src/

/mnt/c/Users/Jacob/Desktop/prosjektoppgave/tcav_atari/src


In [2]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn import linear_model
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import accuracy_score, r2_score
from sklearn.model_selection import KFold, train_test_split

from concepts import concept_instances
from train_model import load_model
from utils import load_data, prepare_folders

In [3]:
warnings.filterwarnings('ignore', category=ConvergenceWarning)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
q_network = load_model("../runs/20230927-233906/models/model_9999999.pt").to(device)

In [5]:
data = load_data()
for concept in concept_instances.values():
    concept.prepare_data(data, max_size=3000)

In [6]:
def calculate_r2(train_acts, train_values, test_acts, test_values):
    reg = linear_model.LassoCV(max_iter=50, cv=5, n_alphas=5)
    reg.fit(train_acts, train_values)
    pred = reg.predict(test_acts)
    score = r2_score(test_values, pred)
    return reg, score

In [7]:
def calculate_accuracy(train_acts, train_values, test_acts, test_values):
    reg = linear_model.LogisticRegressionCV(max_iter=100, cv=10, Cs=10)
    reg.fit(train_acts, train_values)
    pred = reg.predict(test_acts)
    score = accuracy_score(test_values, pred)
    return reg, 2*score-1

In [8]:
concept = concept_instances["ball left paddle (b)"]
_, train_acts_dict = q_network(torch.tensor(concept.obs_train).to(device), return_acts=True)
test_q_values, test_acts_dict = q_network(torch.tensor(concept.obs_test).to(device), return_acts=True)

layer = 5
train_acts = train_acts_dict[str(layer)].cpu().detach().numpy()
test_acts = test_acts_dict[str(layer)].cpu().detach().numpy()
train_acts = train_acts.reshape(len(train_acts), -1)
test_acts = test_acts.reshape(len(test_acts), -1)

reg, score = calculate_accuracy(train_acts, concept.values_train, test_acts, concept.values_test)
print(score)

0.8366666666666667


In [9]:
if reg.intercept_ > 0: # TODO: positive then cav points to positive class, otherwise it points away?
    cav = reg.coef_[0]
else:
    cav = -reg.coef_[0]

In [10]:
# pertubate a tiny bit of cav and see how q values change
test_acts_changed = torch.tensor(test_acts + (0.0001 * cav), dtype=torch.float32).to(device)
test_acts_changed = test_acts_changed.reshape(test_acts_dict[str(layer)].shape)
# forward activations from given layer
test_q_values_changed = q_network.network[layer + 1:](test_acts_changed)

In [11]:
print(test_q_values[0])
print(test_q_values_changed[0])
print(test_q_values[0] - test_q_values_changed[0])
q_values_diff = test_q_values_changed - test_q_values

tensor([4.3660, 4.3731, 4.3353, 4.3615], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([4.3660, 4.3730, 4.3352, 4.3615], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([4.9114e-05, 5.6744e-05, 1.1015e-04, 1.6689e-05], device='cuda:0',
       grad_fn=<SubBackward0>)


In [12]:
# how often does max increase?
max_diff = test_q_values_changed.max(dim=1)[0] - test_q_values.max(dim=1)[0]
improvements = sum(max_diff > 0) / len(max_diff)
print(improvements.item())

0.40833333134651184


In [13]:
# how often does the q value for each action increase?
improvement_counter = {0: 0, 1: 0, 2: 0, 3: 0}
for i in range(len(q_values_diff)):
    improvement = q_values_diff[i] > 0
    for j in range(4):
        if improvement[j]:
            improvement_counter[j] += 1
        
actions = ['None', 'Fire', 'Right', 'Left']
for i in range(4):
    print(f"{actions[i]}: {improvement_counter[i] / len(q_values_diff)}")

None: 0.415
Fire: 0.41833333333333333
Right: 0.155
Left: 0.5833333333333334


### Do same with other concept (todo refactor)

In [14]:
concept = concept_instances["ball right paddle (b)"]
_, train_acts_dict = q_network(torch.tensor(concept.obs_train).to(device), return_acts=True)
test_q_values, test_acts_dict = q_network(torch.tensor(concept.obs_test).to(device), return_acts=True)

layer = 5
train_acts = train_acts_dict[str(layer)].cpu().detach().numpy()
test_acts = test_acts_dict[str(layer)].cpu().detach().numpy()
train_acts = train_acts.reshape(len(train_acts), -1)
test_acts = test_acts.reshape(len(test_acts), -1)

reg, score = calculate_accuracy(train_acts, concept.values_train, test_acts, concept.values_test)
print(score)

if reg.intercept_ > 0: # TODO: positive then cav points to positive class, otherwise it points away?
    cav = reg.coef_[0]
else:
    cav = -reg.coef_[0]

# pertubate a tiny bit of cav and see how q values change
test_acts_changed = torch.tensor(test_acts + (0.0001 * cav), dtype=torch.float32).to(device)
test_acts_changed = test_acts_changed.reshape(test_acts_dict[str(layer)].shape)
# forward activations from given layer
test_q_values_changed = q_network.network[layer + 1:](test_acts_changed)

print(test_q_values[0])
print(test_q_values_changed[0])
print(test_q_values[0] - test_q_values_changed[0])
q_values_diff = test_q_values_changed - test_q_values

# how often does max increase?
max_diff = test_q_values_changed.max(dim=1)[0] - test_q_values.max(dim=1)[0]
improvements = sum(max_diff > 0) / len(max_diff)
print(improvements.item())

# how often does the q value for each action increase?
improvement_counter = {0: 0, 1: 0, 2: 0, 3: 0}
for i in range(len(q_values_diff)):
    improvement = q_values_diff[i] > 0
    for j in range(4):
        if improvement[j]:
            improvement_counter[j] += 1
        
actions = ['None', 'Fire', 'Right', 'Left']
for i in range(4):
    print(f"{actions[i]}: {improvement_counter[i] / len(q_values_diff)}")

0.8266666666666667
tensor([5.1732, 5.1733, 5.1466, 5.1700], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([5.1731, 5.1732, 5.1467, 5.1699], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([ 7.7724e-05,  8.7261e-05, -3.4809e-05,  1.6594e-04], device='cuda:0',
       grad_fn=<SubBackward0>)
0.4883333444595337
None: 0.48
Fire: 0.4766666666666667
Right: 0.7533333333333333
Left: 0.37166666666666665
