Conception of arbitrary gradients #44

Closed
adtzlr opened this issue Jan 26, 2023 · 0 comments
Labels
question Further information is requested

Comments

adtzlr (Owner) commented Jan 26, 2023

This is a first conception of how to take arbitrary gradients:

import numpy as np

class Tensor:
    """A dual tensor with a `value` and an associated derivative `dual`."""

    def __init__(self, value, dual=None):
        # a missing dual defaults to zeros (a constant w.r.t. differentiation)
        if dual is None:
            dual = np.zeros_like(value)
        self.value = value
        self.dual = dual
        self.size = self.value.size
        self.shape = self.value.shape
    
    def __mul__(self, other):
        # product rule: d(x y) = x dy + dx y
        other = asdual(other)
        return Tensor(
            self.value * other.value,
            self.value * other.dual + self.dual * other.value
        )

    def __rmul__(self, other):
        other = asdual(other)
        return Tensor(
            other.value * self.value,
            other.value * self.dual + other.dual * self.value
        )
    
    def __matmul__(self, other):
        other = asdual(other)
        return matmul(self, other)
    
    def __rmatmul__(self, other):
        other = asdual(other)
        return matmul(other, self)

    def __add__(self, other):
        other = asdual(other)
        return Tensor(self.value + other.value, self.dual + other.dual)
    
    @property
    def T(self):
        return transpose(self)
    
    def reshape(self, newshape):
        return reshape(self, newshape)
    
    # disable NumPy ufuncs so that `ndarray * Tensor` falls back to __rmul__
    __array_ufunc__ = None

def reshape(x, newshape):
    "Reshape the `value` and `dual` of a (possibly nested) Tensor."
    rshape = lambda a, newshape: (
        reshape(a, newshape) if isinstance(a, Tensor) else np.reshape(a, newshape)
    )
    return Tensor(rshape(x.value, newshape), rshape(x.dual, newshape))

def matmul(x, y):
    "Matrix product over the leading axes, with the product rule for the dual."
    mmul = lambda a, b: (
        matmul(a, b)
        if isinstance(a, Tensor) or isinstance(b, Tensor)
        else np.einsum("ik...,kj...->ij...", a, b)
    )
    x = asdual(x)
    y = asdual(y)
    return Tensor(mmul(x.value, y.value), mmul(x.value, y.dual) + mmul(x.dual, y.value))

def trace(x):
    "Trace over the leading pair of axes."
    tr = lambda a: trace(a) if isinstance(a, Tensor) else np.trace(a)
    return Tensor(tr(x.value), tr(x.dual))

def transpose(x):
    "Transpose the leading pair of axes; trailing (elementwise) axes stay in place."
    T = lambda a: transpose(a) if isinstance(a, Tensor) else np.einsum("ij...->ji...", a)
    return Tensor(T(x.value), T(x.dual))
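
# quick shape check (illustration, not part of the original sketch): only
# the leading pair of axes is swapped
assert transpose(Tensor(np.zeros((2, 3, 5)))).shape == (3, 2, 5)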

def fsum(x, axis):
    "Sum along a given axis."
    fs = lambda a, axis: fsum(a, axis) if isinstance(a, Tensor) else np.sum(a, axis=axis)
    return Tensor(fs(x.value, axis), fs(x.dual, axis))

def asarray(x):
    "Stack a list of Tensors into a single Tensor."
    asarr = lambda a: asarray(a) if isinstance(a[0], Tensor) else np.asarray(a)
    return Tensor(
        asarr([xi.value for xi in x]),
        asarr([xi.dual for xi in x])
    )

def asdual(x):
    "Convert array-like input to a Tensor with a zero dual."
    return x if isinstance(x, Tensor) else Tensor(np.asarray(x))
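
# quick check (illustration): the forward-mode product rule gives
# d(x * x)/dx = 2 x = 6.0 at x = 3.0 for a seed dx = 1.0
x0 = Tensor(np.array(3.0), np.array(1.0))
assert (x0 * x0).value == 9.0 and (x0 * x0).dual == 6.0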

def duals(x, elementwise_axes):
    "Identity seed arrays: one dual per component, repeated along the elementwise axes."
    # note: this assumes the elementwise axes are the trailing axes of `x`,
    # as in the example below
    shape = tuple(np.delete(x.shape, elementwise_axes))
    trailing = tuple(np.array(x.shape)[list(elementwise_axes)])
    size = np.prod(shape, dtype=int)
    eye = np.eye(size).reshape(size, *shape, *(1,) * len(trailing))
    return np.broadcast_to(eye, (size, *shape, *trailing))
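
# illustration of the seed layout: nine identity seeds for a (3, 3) tensor
# with a trailing elementwise axis of length five
assert duals(np.zeros((3, 3, 5)), (-1,)).shape == (9, 3, 3, 5)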

def reduce(fun, axis=-1):
    "Reduce the returned function result by a sum over a given axis."
    def apply(*args, **kwargs):
        return fsum(fun(*args, **kwargs), axis=axis)
    return apply

def grad(fun, elementwise_axes=(-1,)):
    "Return a function which evaluates the gradient of a function."

    def evaluate(x, *args, **kwargs):
        "Evaluate the gradient of a function; the result is located in the `dual` attribute."

        # one forward-mode evaluation of `fun` per identity seed
        res = [fun(Tensor(x, dx), *args, **kwargs) for dx in duals(x, elementwise_axes)]

        # drop the elementwise axes from the shape of the function value
        shape = res[0].value.shape
        for axis in elementwise_axes:
            shape = np.delete(shape, axis)

        return reshape(asarray(res), (*shape, *x.shape))

    return evaluate

def W(F):
    "Trace of the right Cauchy-Green deformation tensor C = Fᵀ F."
    C = transpose(F) @ F
    return trace(C)

# a 3x3 test matrix, tiled 1000 times along a trailing elementwise axis
F = np.eye(3) + np.arange(9).reshape(3, 3) / 10
FF = np.tile(F.reshape(3, 3, -1), 1000)

dWdF = grad(W)(FF)  # gradient in dWdF.dual, shape (3, 3, 1000)
d2WdF2 = grad(grad(W))(FF)  # hessian in d2WdF2.dual.dual, shape (3, 3, 3, 3, 1000)
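
Since W = trace(Fᵀ F), the analytic gradient is ∂W/∂F = 2 F and the hessian is constant, ∂²W/∂F_iJ ∂F_kL = 2 δ_ik δ_JL. This gives a minimal sanity check of the results above (a sketch, assuming the seed layout of `duals` and the reshapes in `grad`):

assert np.allclose(dWdF.dual, 2 * FF)
hess = 2 * np.einsum("ik,jl->ijkl", np.eye(3), np.eye(3))
assert np.allclose(d2WdF2.dual.dual, hess[..., None])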
@adtzlr adtzlr added the question Further information is requested label Jan 26, 2023
@adtzlr adtzlr closed this as completed Jan 26, 2023
Repository owner locked and limited conversation to collaborators Jan 26, 2023
@adtzlr adtzlr converted this issue into discussion #45 Jan 26, 2023

This issue was moved to a discussion.

You can continue the conversation there. Go to discussion →
