In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # The values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent

assert fwAD.unpack_dual(dual_output).tangent is None

In [10]:
from tensordict import TensorDict

In [58]:
primal = TensorDict({'a': torch.arange(30).reshape(10, 3).float(), 'b': torch.ones(10, 4, 5)}, batch_size=10)

In [59]:
tangent = TensorDict({'a': torch.ones(10, 3), 'b': torch.ones(10, 4, 5)}, batch_size=10)

In [60]:
def f_dict(x):
    return x['a'] ** 2 + x['b'].sum()

In [61]:
torch.func.jvp(f_dict, (primal,), (tangent,))

(tensor([[ 200.,  201.,  204.],
         [ 209.,  216.,  225.],
         [ 236.,  249.,  264.],
         [ 281.,  300.,  321.],
         [ 344.,  369.,  396.],
         [ 425.,  456.,  489.],
         [ 524.,  561.,  600.],
         [ 641.,  684.,  729.],
         [ 776.,  825.,  876.],
         [ 929.,  984., 1041.]]),
 tensor([[200., 202., 204.],
         [206., 208., 210.],
         [212., 214., 216.],
         [218., 220., 222.],
         [224., 226., 228.],
         [230., 232., 234.],
         [236., 238., 240.],
         [242., 244., 246.],
         [248., 250., 252.],
         [254., 256., 258.]]))

In [62]:
from tensordict import TensorDict
import torch

# 2つのtensordictを作成
tensordict1 = TensorDict({
    "key1": torch.randn(3, 4),
    "key2": torch.randn(3, 4)
}, batch_size=[3])

tensordict2 = TensorDict({
    "key1": torch.randn(3, 4),
    "key2": torch.randn(3, 4)
}, batch_size=[3])

# 2つのtensordictを足す
result = tensordict1 + tensordict2

# 結果を確認
print(result["key1"])
print(result["key2"])

TypeError: unsupported operand type(s) for +: 'TensorDict' and 'TensorDict'

In [64]:
1.0 * tensordict1

TypeError: unsupported operand type(s) for *: 'float' and 'TensorDict'

In [65]:
tensordict1.shape

torch.Size([3])

In [66]:
type(tensordict1)

tensordict._td.TensorDict

In [67]:
tensordict1.apply(lambda x: 2 * x)

TensorDict(
    fields={
        key1: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        key2: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

In [72]:
tensordict1.apply(lambda x: 2 * x)["key2"]

tensor([[ 1.0484, -1.8561,  1.4272,  1.0533],
        [-1.5890,  0.2027, -0.1566,  1.5146],
        [ 1.2385,  2.0889,  1.0955, -1.2020]])

In [71]:
tensordict1['key2']

tensor([[ 0.5242, -0.9280,  0.7136,  0.5267],
        [-0.7945,  0.1013, -0.0783,  0.7573],
        [ 0.6193,  1.0444,  0.5478, -0.6010]])

In [76]:
def add_tensordicts(tensordict1, tensordict2):
    """
    2つのtensordictを足し合わせる関数
    
    Args:
        tensordict1 (TensorDict): 1つ目のtensordict
        tensordict2 (TensorDict): 2つ目のtensordict
        
    Returns:
        TensorDict: 2つのtensordictを足し合わせた結果
    """
    # 2つのtensordictのキーが同じであることを確認
    assert set(tensordict1.keys()) == set(tensordict2.keys()), "Keys in the two tensordicts must be the same."
    
    # 各keyに対応するtensorを足し合わせる
    result = {k:tensordict1[k] + tensordict2[k] for k in tensordict1.keys()}
    
    # 新しいtensordictを作成して返す
    return TensorDict(result, batch_size=tensordict1.batch_size)

In [84]:
add_tensordicts(tensordict1, tensordict2)["key2"]

tensor([[ 0.3702,  0.1352,  0.4325,  1.0031],
        [-1.2659, -1.4851,  0.7063, -0.8183],
        [ 0.7257,  0.3506,  0.5705, -0.2927]])

In [83]:
tensordict1["key2"] + tensordict2["key2"]

tensor([[ 0.3702,  0.1352,  0.4325,  1.0031],
        [-1.2659, -1.4851,  0.7063, -0.8183],
        [ 0.7257,  0.3506,  0.5705, -0.2927]])

In [86]:
isinstance(tensordict1, torch.Tensor)

False

In [1]:
from nigbms.utils.solver import rademacher_like

: 

In [None]:
import torch
from petsc4py import PETSc
from tensordict import TensorDict

: 