In [497]:
import torch
from torch.autograd import Variable, Function
import inspect
import random
import copy

# import functools
# TODO: Use functools.wraps to preserve the original function/method attributes (__name__, __doc__, __dict__, ...)

class EncryptedAdd(Function):
    """Custom autograd function for elementwise ``a + b``.

    The inputs are regular PyTorch tensors; ``forward`` is the place where an
    addition on encrypted data would be performed.
    """

    @staticmethod
    def forward(ctx, a, b):
        # Stand-in for a + b on encrypted data: plain tensor addition here.
        return a + b

    @staticmethod
    def backward(ctx, grad_out):
        # d(a + b)/da = d(a + b)/db = 1, so the upstream gradient flows to both
        # inputs unchanged. Returning grad_out itself (instead of re-wrapping
        # grad_out.data in a fresh leaf Variable) keeps it attached to the
        # graph, so higher-order gradients (create_graph=True) still work.
        return grad_out, grad_out
        
class EncryptedMult(Function):
    """Custom autograd function for elementwise ``a * b``.

    The inputs are regular PyTorch tensors; ``forward`` is the place where a
    multiplication on encrypted data would be performed.
    """

    @staticmethod
    def forward(ctx, a, b):
        # Both factors are needed to form the gradients in backward.
        ctx.save_for_backward(a, b)
        # Stand-in for a * b on encrypted data: plain tensor multiplication here.
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None
        # Product rule: d(a*b)/da = b, d(a*b)/db = a. Working on grad_out
        # directly (not .data) keeps the result differentiable so
        # double-backward works; skip inputs that do not need a gradient.
        if ctx.needs_input_grad[0]:
            grad_a = grad_out * b
        if ctx.needs_input_grad[1]:
            grad_b = grad_out * a
        return grad_a, grad_b

class VariableProxy(object):
    """Wrap a tensor in an autograd ``Variable`` and route ``+`` / ``*``
    through the custom ``EncryptedAdd`` / ``EncryptedMult`` functions.

    The operators return the raw result Variable (not a new VariableProxy).
    The right-hand operand may be another VariableProxy or a plain
    tensor/Variable, so results can be fed back in, e.g. ``w * (x * y)``.
    """

    def __init__(self, var, requires_grad=True):
        # Gradients from backward() accumulate on self.var when requires_grad.
        self.var = Variable(var, requires_grad=requires_grad)

    @staticmethod
    def _unwrap(other):
        # Accept a VariableProxy or an already-unwrapped tensor/Variable.
        return other.var if isinstance(other, VariableProxy) else other

    def __add__(self, other):
        return EncryptedAdd.apply(self.var, self._unwrap(other))

    def __mul__(self, other):
        return EncryptedMult.apply(self.var, self._unwrap(other))

    def grad(self):
        """Return the gradient accumulated on the wrapped Variable (None before backward)."""
        return self.var.grad

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, self.var)
    
# Demo: z = x * y elementwise, then backpropagate a gradient of ones through z.
x = VariableProxy(torch.FloatTensor([1, 1, 1]), requires_grad=True)
y = VariableProxy(torch.FloatTensor([2, 3, 4]), requires_grad=True)

z = x * y

# The upstream gradient must match z's shape (3 elements); a size-1 tensor
# is rejected by current PyTorch with "Mismatch in shape".
z.backward(torch.ones_like(z))

# dz/dx = y elementwise, so this displays [2, 3, 4].
x.grad()

Variable containing:
 2
 3
 4
[torch.FloatTensor of size 3]

In [494]:
# Notebook cell: echo the product tensor z = x * y for display.
z

Variable containing:
 2
 3
 4
[torch.FloatTensor of size 3]

Variable containing:
 2
 3
 4
[torch.FloatTensor of size 3]