In [200]:
import numpy as np
from concepts import concept_instances
from utils import prepare_folders
from dqn import load_model
import torch
import torch.nn as nn
from torchvision.transforms import v2
import matplotlib.pyplot as plt

### autograd for a

In [207]:
x = torch.ones(10, dtype=torch.float64)
y = torch.zeros(3, dtype=torch.float64) # target tensor
w = torch.randn(10, 3, requires_grad=True, dtype=torch.float64) # weights
b = torch.randn(3, requires_grad=True, dtype=torch.float64) # bias
z = x @ w + b
a = torch.sigmoid(z)
a.retain_grad()
print('a: ', a)
loss = nn.functional.mse_loss(a, y)
print('loss: ', loss.item())
loss.backward()
print('a.grad: ', a.grad)
print('b.grad: ', b.grad)
'''
# gradient descent on a
a.data -= 0.1 * a.grad.data
print('a: ', a)
loss = nn.functional.mse_loss(a, y)
print('loss: ', loss.item())
'''

a:  tensor([0.5702, 0.1111, 0.0018], dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)
loss:  0.11249074916824343
a.grad:  tensor([0.3801, 0.0741, 0.0012], dtype=torch.float64)
b.grad:  tensor([9.3160e-02, 7.3130e-03, 2.0852e-06], dtype=torch.float64)


"\n# gradient descent on a\na.data -= 0.1 * a.grad.data\nprint('a: ', a)\nloss = nn.functional.mse_loss(a, y)\nprint('loss: ', loss.item())\n"

### finite differences for b

In [208]:
# estimate b gradients with finite differences
with torch.no_grad():
    h = 1e-10
    b1 = b.clone()
    for i in range(3):
        b1[i] += h
        z1 = x @ w + b1
        a1 = torch.sigmoid(z1)
        loss1 = nn.functional.mse_loss(a1, y)
        print('finite difference: ', (loss1.item() - loss.item()) / h)
        b1[i] -= h

finite difference:  0.09316020177507767
finite difference:  0.007313039063205906
finite difference:  2.0816681711721685e-06


### finite differences for a

In [211]:
# estimate a gradients with finite differences
with torch.no_grad():
    h = 1e-10
    a1 = a.clone()
    for i in range(3):
        a1[i] += h
        loss1 = nn.functional.mse_loss(a1, y)
        print('finite difference: ', (loss1.item() - loss.item()) / h)
        a1[i] -= h

finite difference:  0.380133979849262
finite difference:  0.07405798196913338
finite difference:  0.0011801670751765414


### autograd for cav

In [294]:
cav = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)

x = torch.ones(10, dtype=torch.float64)
y = torch.zeros(3, dtype=torch.float64) # target tensor
w = torch.randn(10, 3, requires_grad=True, dtype=torch.float64) # weights
b = torch.randn(3, requires_grad=True, dtype=torch.float64) # bias
z = x @ w + b
a = torch.sigmoid(z)
a.retain_grad()
print('a: ', a)
loss = nn.functional.mse_loss(a, y)
print('loss: ', loss.item())
loss.backward()
print('a.shape: ', a.shape)
print('cav.shape: ', cav.shape)
print('a.grad: ', a.grad)
cav_grad = a.grad @ cav
print('cav_grad: ', cav_grad)
# update a with cav_grad
with torch.no_grad():
    a1 = a.clone()
    a1 -= 0.5 * cav * cav_grad
    print('a1: ', a1)
    print('a diff: ', a1 - a)
    loss1 = nn.functional.mse_loss(a1, y)
    print('loss1: ', loss1.item())

a:  tensor([0.1506, 0.7710, 0.0926], dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)
loss:  0.2085440093811767
a.shape:  torch.Size([3])
cav.shape:  torch.Size([3])
a.grad:  tensor([0.1004, 0.5140, 0.0617], dtype=torch.float64)
cav_grad:  tensor(0.3187, dtype=torch.float64)
a1:  tensor([ 0.1506,  0.6913, -0.0668], dtype=torch.float64)
a diff:  tensor([ 0.0000, -0.0797, -0.1594], dtype=torch.float64)
loss1:  0.16833698155411003


### finite differences for cav

In [293]:
# d L/d cav, when cav is a direction in a?
# estimate cav gradients with finite differences
with torch.no_grad():
    h = 1e-10
    a1 = a.clone()
    perturb = h * cav
    print('perturb: ', perturb)
    a1 += perturb
    loss1 = nn.functional.mse_loss(a1, y)
    print('loss0: ', loss.item())
    print('loss1: ', loss1.item())
    print('finite difference cav: ', (loss1.item() - loss.item()) / h)

perturb:  tensor([0.0000e+00, 5.0000e-11, 1.0000e-10], dtype=torch.float64)
loss0:  0.2980911451220287
loss1:  0.29809114518388335
finite difference ca:  0.6185463252705858
