In [2]:
import torch
import torch.nn.functional as F
from torch.autograd.functional import vjp
import sys

In [3]:
'''
Testing designs for Module class:

Important requirements:
1) The class should support architectures of weird shapes, even circular ones. It shouldn't be restricted to sequential
    layers.
2) each module has 2 important functions:
    a) forward pass
    b) backward pass
        - gradients of outputs wrt inputs
        - gradients of outputs wrt parameters 
3) should support any number of inputs to a module?
'''

In [None]:
class Node():
    '''
    Wraps a tensor that flows between modules so that backward can find the module
    that produced it. May be pruned once the final backward is implemented.

    Design invariants:
    - Each Node has exactly one parent: the Module whose forward() produced it.
    - A Module may have several parent Nodes (one per input).
    - A Node may feed several child modules; its gradient is the sum of what they send back.
    - Each Node wraps the tensor of a single call, so tensors are not reused across calls,
      which may increase memory use.
    '''
    def __init__(self, o):
        self.parent = None  # producing Module; assigned by the producer after construction
        self.child = []  # consumers registered via append_child(); one entry per use
        self.o = o  # the wrapped tensor
        self.outgoing_gradient = []  # gradients received from children in the current backward
        self.output_order = None  # index of this tensor among the parent's outputs
        self.pass_number = None  # which call (pass) of the parent produced this tensor

    def append_child(self, n):
        '''Register n as a consumer of this Node's tensor.'''
        self.child.append(n)

    def backward(self, gradient):
        '''
        Receive one child's gradient with respect to this Node's tensor.

        Once every registered child has reported, the received gradients are summed and
        forwarded as parent.backward(total, output_order, pass_number). The accumulator
        is cleared first, so a later backward pass through the same graph starts fresh
        instead of over-counting.

        Raises AssertionError if the gradient's size differs from the wrapped tensor's.
        '''
        assert self.o.size() == gradient.size()
        self.outgoing_gradient.append(gradient)

        if len(self.outgoing_gradient) == len(self.child):
            total = torch.stack(self.outgoing_gradient, dim=0).sum(dim=0)
            # Reset before dispatching so re-entrant/circular backward calls see a clean state.
            self.outgoing_gradient = []
            self.parent.backward(total, self.output_order, self.pass_number)

    def __iter__(self):
        # Deliberately not iterable for now; if this is implemented, the input loop
        # in Module.__call__ has to change accordingly.
        raise NotImplementedError

class Module():
    '''
    Base class for a graph-building module with a hand-rolled backward pass.

    Usage constraints so far:
    1) forward() must be defined as forward(input_1, ..., input_n, *args); during
       backward the trainable tensors are passed through *args (see backward()).
    2) Call the module (module(...)) instead of forward() directly; __call__ records the
       graph (parents, inputs, output Nodes) that backward relies on.
    3) Only graph inputs and targets may be plain tensors. Every other transformation,
       even a simple (learnable) matrix multiplication, has to be done by a Module.
    4) All outputs must be used later. E.g. if a module outputs a, b, c, all three must be
       consumed by other modules, otherwise backward for that pass never fires.
    '''
    def __init__(self):
        # Number of times this module has been called (one "pass" per call).
        # `pass` is a reserved keyword, so the counter is named pass_count.
        self.pass_count = 0
        # Per-pass list of the tensors/Nodes that were fed into that call.
        self.parents = []

        # Per-pass records, indexed by pass number.
        self.inputs = []  # tensors handed to forward(); needed to re-run it inside vjp
        self.outputs = []  # raw return values of forward()
        self.gradients_from_output = []  # one slot per output, filled as output Nodes report
        self.output_nodes = []  # Nodes wrapping each output tensor
        # Parameter gradients, one entry per pass; summed before update_parameters().
        self.gradients_for_trainable_params = []

        # Saved clones of forward() tensors are not needed as long as a single loss
        # function drives backward.

    def __call__(self, *inputs):
        '''
        Run forward() on the given tensors/Nodes and record the call for backward.

        Each argument must be a tensor (a graph input) or a Node (an output of another
        module); every Node gets this module registered as a child.

        Returns the output Node when forward() returns a single tensor, otherwise a
        list of Nodes, one per output tensor, in output order.

        Raises TypeError if an input is neither a tensor nor a Node, or if forward()
        does not return a tensor or a non-empty list/tuple of tensors.
        '''
        # Validate everything first so no Node gets a child registered for a failed call.
        for i in inputs:
            if not (torch.is_tensor(i) or isinstance(i, Node)):
                raise TypeError("inputs should only be tensors or instances of class Node")

        inputs_for_forward = []
        parents_ = []
        for i in inputs:
            parents_.append(i)
            if torch.is_tensor(i):
                inputs_for_forward.append(i)
            else:
                i.append_child(self)
                inputs_for_forward.append(i.o)

        outputs_ = self.forward(*inputs_for_forward)

        # Test is_tensor before anything len()-based: len() of a 0-d tensor (e.g. a
        # scalar loss) raises TypeError.
        single_output = torch.is_tensor(outputs_)
        if single_output:
            output_tensors = [outputs_]
        else:
            try:
                output_tensors = list(outputs_)
            except TypeError:
                raise TypeError("Only lists or tuples of tensors allowed as output of forward()") from None
            if not output_tensors or not all(torch.is_tensor(t) for t in output_tensors):
                raise TypeError("Only lists or tuples of tensors allowed as output of forward()")

        output_node = []
        for order, tensor in enumerate(output_tensors):
            c = Node(tensor)
            c.parent = self
            c.output_order = order
            c.pass_number = self.pass_count
            output_node.append(c)

        self.inputs.append(inputs_for_forward)
        self.outputs.append(outputs_)
        self.output_nodes.append(output_node)
        self.gradients_from_output.append([None] * len(output_node))
        self.parents.append(parents_)
        self.pass_count += 1

        return output_node[0] if single_output else output_node

    def forward(self, input, *args):
        '''
        Subclasses implement this with the signature
            forward(input_1, input_2, ..., input_n, *args)
        where input_1..input_n are the expected inputs and *args receives the module's
        trainable tensors when backward() re-runs forward inside vjp. When called from
        __call__, *args is empty.
        '''
        raise NotImplementedError

    def get_trainable_params(self):
        '''
        Return this module's tensor attributes that require grad, in attribute
        insertion order (instance dicts keep insertion order in Python 3.7+).

        The live tensors are returned, not clones; vjp only needs their values, and
        update_parameters() can rely on the same ordering.
        '''
        trainable_params = []
        for name in vars(self):
            value = getattr(self, name)
            if torch.is_tensor(value) and value.requires_grad:
                trainable_params.append(value)
        return trainable_params

    def update_parameters(self, gradients):
        '''
        Apply the summed parameter gradients (one tensor per trainable parameter, in
        get_trainable_params() order). Subclasses with trainable tensors must implement
        this; it is called once gradients from all passes have arrived.
        '''
        raise NotImplementedError

    def prepare_gradients_for_trainable_params(self, gradients):
        '''
        Store the parameter gradients of one pass. When every recorded pass has
        reported, sum them per parameter and call update_parameters() with the totals.
        '''
        self.gradients_for_trainable_params.append(gradients)
        if len(self.gradients_for_trainable_params) == self.pass_count:
            summed = [torch.stack(per_param, dim=0).sum(dim=0)
                      for per_param in zip(*self.gradients_for_trainable_params)]
            self.update_parameters(summed)

    def make_tuple_for_vjp(self):
        '''Placeholder; not implemented yet.'''
        pass

    def backward(self, v, output_order, pass_no):
        '''
        Receive the gradient v for output `output_order` of call `pass_no`.

        Once gradients for every output of that pass have arrived, forward() is re-run
        under vjp to get gradients w.r.t. the inputs and the trainable parameters. Input
        gradients are sent to parent Nodes (plain-tensor parents are graph inputs and
        receive nothing); parameter gradients go to
        prepare_gradients_for_trainable_params().

        Gradients are sent back per pass rather than all at once so that circular
        architectures and modules reused with different inputs keep working.
        Assumes every output Node is used later and reports exactly once per pass.
        TODO: handle output Nodes with no child.
        '''
        grads_out = self.gradients_from_output[pass_no]
        grads_out[output_order] = v

        # Identity check: `None in grads_out` would fall back to `tensor == None`.
        if any(g is None for g in grads_out):
            return

        def _forward_as_tuple(*args):
            # vjp accepts a tensor or a tuple of tensors, but not a list.
            out = self.forward(*args)
            return out if torch.is_tensor(out) else tuple(out)

        trainable_params = self.get_trainable_params()
        n_inputs = len(self.inputs[pass_no])
        # All output gradients go into vjp's single `v` argument as one tuple; unpacking
        # them would spill into the create_graph/strict positional parameters.
        _, gradients = vjp(_forward_as_tuple,
                           (*self.inputs[pass_no], *trainable_params),
                           tuple(grads_out))

        gradients_for_inputs = gradients[:n_inputs]
        gradients_for_params = gradients[n_inputs:]

        # One input gradient per parent recorded for this pass.
        assert len(gradients_for_inputs) == len(self.parents[pass_no])
        for parent, grad in zip(self.parents[pass_no], gradients_for_inputs):
            # Only intermediates are Nodes; plain tensors are graph inputs.
            if not torch.is_tensor(parent):
                parent.backward(grad)

        if trainable_params:
            self.prepare_gradients_for_trainable_params(gradients_for_params)
            
        

In [5]:
# How to assign parents to intermediate tensors between modules

tensor([1.1508])

In [15]:
class A:
    '''
    Scratch experiment: a single linear layer followed by ReLU and an MSE loss,
    differentiated with torch.autograd.functional.vjp.
    '''
    def __init__(self):
        # Stored on the instance so forward() can reach them.
        self._x = torch.randn([2, 3], dtype=torch.float)  # shape (2, 3)
        self._W = torch.randn([3, 1], dtype=torch.float)  # shape (3, 1)
        self._b = torch.randn([2, 1], dtype=torch.float)  # shape (2, 1)
        self._t = torch.randn([2, 1], dtype=torch.float)  # targets, shape (2, 1)

    def forward(self, W=None, b=None):
        '''
        Return mse_loss(relu(x @ W + b), t).

        W and b default to the stored tensors; backward() passes them explicitly so
        vjp can differentiate the loss with respect to them.
        '''
        W = self._W if W is None else W
        b = self._b if b is None else b
        z = self._x.matmul(W) + b
        y = F.relu(z)
        return F.mse_loss(y, self._t)

    def backward(self):
        '''
        Compute the loss and its gradients w.r.t. (W, b) with vjp, print them and
        return (loss, (grad_W, grad_b)).

        forward() returns a single-element tensor, so vjp's v can be omitted.
        '''
        loss, grads = vjp(self.forward, (self._W, self._b))
        print(loss, grads)
        return loss, grads
    

parents
child
a
b
c
[tensor([-1.0955], requires_grad=True)]


In [17]:
def test():
    """Return three values at once; Python packs them into the tuple (10, 20, 30)."""
    values = (10, 20, 30)
    return values

a = test()
print(a)

(10, 20, 30)


In [16]:
class A():
    '''Scratch class for experimenting with *args forwarding and name lookup.'''
    def __init__(self):
        self.no_of_inputs = 2

    def test(self, *args):
        print(args)
#         return c+e # this fails: e is neither a parameter nor a global
        # c and b are not parameters either; they resolve to module-level globals at
        # call time, so this only works while those globals exist.
        return c + b

    def test2(self, *args):
        # Re-spread the collected positional args into test3's signature.
        self.test3(*args)

    def test3(self, a, b, *args):
        print(*args)
        print(a + b)

    def __iter__(self):
        # Returning self is only valid if the object is also an iterator, so
        # __next__ is defined below; an A yields no items.
        return self

    def __next__(self):
        raise StopIteration

# class B(A):
#     def test()

a = A()
c = 10
b = 5
print(a.test(c, b, 10, 20))

x = A()
x.test2(15, 20, 30)

m = [1, 2, 3, 4]
# Generator of lambdas: every lambda closes over the same loop variable `i`
# (late binding), so once the generator is exhausted each of them returns 4.
x = (lambda: i for i in m)
print(list(x))

# False on a fresh kernel: x was rebound to the generator above, not an A instance.
isinstance(x, A)

(10, 5, 10, 20)
15
30
35
[<function <genexpr>.<lambda> at 0x7faf6a6a1b70>, <function <genexpr>.<lambda> at 0x7faf6a6a1ae8>, <function <genexpr>.<lambda> at 0x7faf6a6a1c80>, <function <genexpr>.<lambda> at 0x7faf6a5f0048>]


TypeError: isinstance() arg 2 must be a type or tuple of types

In [13]:
def f():
    """Return three values; Python packs them into the tuple (10, 20, 30)."""
    return 10, 20, 30

h = f()
print(h)
r = 10
# len() needs an object with __len__; an int has none, so this raises TypeError.
# The error is the point of the demo, so catch and show it instead of aborting Run All.
try:
    print(len(r))
except TypeError as err:
    print(f"TypeError: {err}")

(10, 20, 30)


TypeError: object of type 'int' has no len()

In [17]:
import torch
class Base:
    '''Holds a registry of trainable tensors, keyed by attribute name.'''
    def __init__(self):
        self.trainable_params = {}

class A(Base):
    '''
    Experiment: register tensors that require grad automatically via __setattr__,
    then hand them to forward() as keyword arguments.
    '''
    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b

    def __setattr__(self, k, v):
        # Keep the registry in sync with the attribute: rebinding a registered name
        # to anything that is not a grad-requiring tensor drops the stale entry, so
        # backward() never passes an outdated tensor. The registry does not exist yet
        # while Base.__init__ is creating it, hence the lookup through __dict__.
        registry = self.__dict__.get("trainable_params")
        if torch.is_tensor(v) and v.requires_grad:
            self.trainable_params[k] = v
        elif registry is not None:
            registry.pop(k, None)
        super().__setattr__(k, v)

    def forward(self, input, **kwargs):
        # Show that the trainable tensors arrive as keyword arguments.
        print(kwargs)
        print(locals())

    def backward(self):
        # Feed a random input plus every registered trainable tensor to forward().
        self.forward(torch.randn(1), **self.trainable_params)


a = A(torch.randn(1, requires_grad=True), torch.randn(1, requires_grad=True))
        

In [18]:
# Calls forward() with a random input and the registered trainable tensors (a, b) as
# keyword arguments; forward() prints the kwargs and its locals().
a.backward()

{'a': tensor([0.6448], requires_grad=True), 'b': tensor([0.7502], requires_grad=True)}
{'self': <__main__.A object at 0x7fd78ce823d0>, 'input': tensor([0.4845]), 'kwargs': {'a': tensor([0.6448], requires_grad=True), 'b': tensor([0.7502], requires_grad=True)}}
