In [None]:
import sys

# Put the parent directory (presumably the project root) on the import path so
# the later `from src.utils import eqprop_util` resolves.
sys.path.extend(["../"])

# Torch

### Registered buffers

In [None]:
import torch

# Minimal module for inspecting buffers: a 1 -> 1 Linear layer plus one extra
# (non-parameter) tensor registered under the name "free_node".
mdl = torch.nn.Linear(in_features=1, out_features=1)
mdl.register_buffer(name="free_node", tensor=torch.rand(2, 1))

In [None]:
# One id() per registered buffer; compare with id(mdl.free_node) to check that
# buffers() yields the stored tensor object itself.
for buffer_id in map(id, mdl.buffers()):
    print(buffer_id)

In [None]:
# Bias parameter of the Linear layer above (note: `mdl` is rebound in a later cell).
getattr(mdl, "bias")

## `torch.no_grad`

In [None]:
import torch


@torch.no_grad()
def foo(x, y):
    x.grad = torch.ones_like(x)
    return x + bar(x, y)


def bar(x, y):
    return x + 2 * y


x, y = torch.rand((2, 1), requires_grad=True), torch.rand((2, 1))
z = foo(x, y)

In [None]:
@torch.no_grad()
def baz(x: torch.Tensor) -> torch.Tensor:
    """Return a fresh leaf copy of ``x`` that requires grad.

    ``clone().detach()`` gives new storage with no autograd history, and
    ``requires_grad_(True)`` then marks that copy as a trainable leaf. This is
    allowed inside ``torch.no_grad()`` because it only sets a flag on a leaf.

    Args:
        x: Source tensor. Must have a floating-point (or complex) dtype,
            otherwise ``requires_grad_`` raises. (Annotated as ``torch.Tensor``;
            ``torch.tensor`` is the factory function, not the type.)

    Returns:
        A new tensor with the same values, ``requires_grad=True``, and no
        connection to ``x``'s graph.
    """
    return x.clone().detach().requires_grad_(True)

In [None]:
# as_tensor on a Python list builds a new tensor with the same inferred dtype.
baz(torch.as_tensor([1.0]))

In [None]:
# Fresh tensor *without* requires_grad (shadows the x from the no_grad cell).
x = torch.rand(2, 1)

In [None]:
# Hand-assign a .grad of ones to a tensor that does not itself require grad.
x.grad = torch.full_like(x, 1)

## pytest

In [None]:
import ipytest

# Apply ipytest's default configuration for running pytest from inside this notebook.
ipytest.autoconfig()

## Bias update: batch mean of the term vs. the term at the batch means

In [None]:
import torch

batch_size = 3
# Two (batch_size, 5) batches of node values (presumably the nudged and free phases).
nudge_n = torch.rand(batch_size, 5)
free_n = torch.rand(batch_size, 5)
# Elementwise (n - f) * (n + f - 2), which equals (n - 1)**2 - (f - 1)**2.
(nudge_n - free_n) * (nudge_n + free_n - 2)

In [None]:
# Average the elementwise term over the batch dimension (one value per feature).
res = ((nudge_n - free_n) * (nudge_n + free_n - 2)).mean(dim=0)

In [None]:
# Display the batch-averaged term; compare with the term evaluated at the batch means below.
res

In [None]:
# Same expression, but evaluated at the batch means instead of averaged over
# the batch, for comparison with `res`.
mean_nudge_n = nudge_n.mean(dim=0)
mean_free_n = free_n.mean(dim=0)
(mean_nudge_n - mean_free_n) * (mean_nudge_n + mean_free_n - 2)

In [None]:
# torch.inner contracts the last dim of both (batch, 5) operands, so the result
# is a (batch, batch) matrix of pairwise inner products; only its diagonal holds
# the per-sample sums. Written with the equivalent tensordot over dim 1.
torch.tensordot(
    nudge_n - free_n,
    nudge_n + free_n - 2,
    dims=([1], [1]),
)

## Miscellaneous experiments

In [None]:
from torch import nn

# Linear layer constructed without a bias term.
lyr1 = nn.Linear(in_features=1, out_features=2, bias=False)

In [None]:
# nn.Linear(bias=False) still defines `bias` (registered as None), so hasattr()
# is True either way; `bias is not None` is what tells you whether a bias exists.
hasattr(lyr1, "bias"), lyr1.bias is not None

In [None]:
import matplotlib.pyplot as plt

# plot 3d logsumexp: z = log(exp(x) + exp(y)) over the square [-1, 1] x [-1, 1]
import torch

x = torch.linspace(-1, 1, 10)
y = torch.linspace(-1, 1, 10)
# meshgrid returns a *tuple* of grids; stack them into one (2, 10, 10) tensor so
# logsumexp can reduce over the leading dim.
grid_x, grid_y = torch.meshgrid(x, y, indexing="ij")
xy = torch.stack((grid_x, grid_y))
logsumexp = torch.logsumexp(xy, dim=0)

# pyplot has no `plot3d`; draw the surface on a 3-D Axes instead.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot_surface(grid_x.numpy(), grid_y.numpy(), logsumexp.numpy())
ax.set(xlabel="x", ylabel="y", zlabel="logsumexp(x, y)", title="logsumexp over [-1, 1]^2");

In [None]:
# 1-D tensor of 7 samples from U[0, 1); an int size means the same as the 1-tuple (7,).
y = torch.rand(7)

In [None]:
import torch
from torch import nn

# New 1 -> 1 Linear layer with an explicit bias; this rebinds `mdl` from the buffer cell.
mdl = nn.Linear(in_features=1, out_features=1, bias=True)

In [None]:
# Bias parameter of the freshly built layer.
getattr(mdl, "bias")

In [None]:
# Materialise the (name, Parameter) pairs of `mdl`.
[*mdl.named_parameters()]

In [None]:
# Random operands for timing column-wise scaling of W by b.
W = torch.rand(40, 5000)  # float32 values in [0, 1)
# int64 divisors in [1, 4]; the lower bound of 1 keeps W / b free of division by zero.
b = torch.randint(1, 5, (5000,))
print(W, b)

In [None]:
# Broadcast divide: column j of W is divided by b[j] (int divisors -> float result).
# Put %%timeit on the first line of this cell to benchmark it against the matmul below.
torch.div(W, b)

In [None]:
# Dense square matrix with b on its diagonal, cast to float32 for the matmul below.
# With the 5000-element b from the W/b cell this is 25M floats (~100 MB).
D = torch.diag(b).to(torch.float32)

In [None]:
%%timeit
W @ D

In [None]:
from src.utils import eqprop_util

In [None]:
# Batch size and layer widths for benchmarking eqprop_util.deltaV.
# Deliberately not named `In`/`Out`: IPython uses those names for its
# input/output history, and rebinding them hides that history.
B = 64
N_IN = N_OUT = 1000
a = torch.rand(B, N_IN)
# NOTE: this rebinds `b` (previously the int64 divisor vector used to build D).
b = torch.rand(B, N_OUT)

In [None]:
%%timeit
# Benchmark eqprop_util.deltaV on the random `a` and `b` tensors from the cell
# above (argument order here is b, then a).
eqprop_util.deltaV(b, a)

In [None]:
%run -c python -m foo.py cProfile

In [None]:
import functools

import torch

from src.utils import eqprop_util

# Eager partial with Is=1e-6 bound, kept for interactive use. TorchScript cannot
# compile a functools.partial, and functools.wraps(fn) only returns a decorator,
# so neither can be passed to torch.jit.script directly.
partial_fn = functools.partial(eqprop_util.rectifier_p3_i, Is=1e-6)


def rectifier_p3_i_fixed_is(x: torch.Tensor) -> torch.Tensor:
    """Named, type-annotated wrapper: ``eqprop_util.rectifier_p3_i(x, Is=1e-6)``.

    ``torch.jit.script`` compiles Python source, so it needs a real function
    with annotations rather than a partial object.
    """
    return eqprop_util.rectifier_p3_i(x, Is=1e-6)


# NOTE(review): assumes eqprop_util.rectifier_p3_i is itself TorchScript-compatible
# and accepts a float `Is` keyword. If scripting fails, tracing is a fallback:
#   torch.jit.trace(rectifier_p3_i_fixed_is, example_inputs=(torch.rand(2, 3),))
# but a trace is only valid if the function has no data-dependent control flow.
scrfn = torch.jit.script(rectifier_p3_i_fixed_is)

In [None]:
%%timeit
# Benchmark the TorchScript-compiled rectifier on a fresh (128, 3000) random input;
# torch.rand is inside the timed statement, so its cost is included in the timing.
scrfn(torch.rand(128, 3000))

In [None]:
%%timeit
# Eager baseline for the scripted version above (torch.rand cost included here too).
# NOTE(review): this uses rectifier_p3_i's default `Is`, while scrfn binds Is=1e-6;
# confirm they match before comparing the two timings.
eqprop_util.rectifier_p3_i(torch.rand(128, 3000))

In [None]:
# Emptiness check on a new empty list: evaluates to True (same result as `not a`).
a = list()
len(a) == 0