# Step 17

## 17.1 Memory Management

When working with neural networks, execution can take an extremely long time (even on a GPU) if large amounts of data are not managed well in memory. 

### How does Python (CPython) manage memory?

1. Reference Counting
2. Garbage Collection

## 17.2 Reference Counting

1. The mechanism is simple and fast.
2. Every object is created with a reference count of 0.
3. Each time something refers to the object (a variable, an argument, or another object's attribute), the count is incremented.
4. When a reference goes away, the count is decremented. When it reaches 0, the Python interpreter reclaims the object.
5. As a result, an object is deleted immediately once it is no longer needed.
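
The rules above can be observed directly on CPython. The sketch below is only an illustration: the `Tracked` class and the `deleted` list are hypothetical helpers, and `__del__` is used purely to record the moment the interpreter reclaims the object.

```python
deleted = []  # names of objects that have been reclaimed so far


class Tracked:
    """Hypothetical helper that records when it is reclaimed."""

    def __init__(self, name):
        self.name = name

    def __del__(self):
        deleted.append(self.name)


a = Tracked("obj")  # one reference: the name a
b = a               # two references: a and b
a = None            # one reference left, so nothing is deleted yet
still_alive = (deleted == [])
b = None            # no references left: reclaimed right away, without the GC
print(still_alive, deleted)  # on CPython: True ['obj']
```

Dropping the first name has no visible effect; the object disappears at the exact statement that removes the last reference.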


In [1]:
class obj:
    pass
def f(x):
    print(x)

a=obj() # assigned to a variable: reference count 1
f(a) # passed to a function: reference count 2 inside the function
# function has returned: reference count back to 1
a=None # reference dropped: reference count 0

<__main__.obj object at 0x107ff3010>


In [None]:
a=obj()
b=obj()
c=obj()

a.b=b
b.c=c

a=b=c=None

When `a=b=c=None` is executed, the reference counts change as shown in the following image.  
<img src="image/step17.jpeg"><br><br>
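
The chain `a -> b -> c` is released as a cascade: once nothing refers to `a`, deleting it removes the only remaining reference to `b`, and so on. A minimal sketch on CPython, again with a hypothetical `Tracked` helper that logs deletions (the names `b` and `c` are dropped first here so that the cascade is visible in a single step):

```python
deleted = []  # deletion order


class Tracked:
    """Hypothetical helper that logs its name when it is reclaimed."""

    def __init__(self, name):
        self.name = name

    def __del__(self):
        deleted.append(self.name)


a, b, c = Tracked("a"), Tracked("b"), Tracked("c")
a.b = b
b.c = c

b = c = None                       # b and c are still reachable through a
after_dropping_b_c = list(deleted)
a = None                           # a hits 0, which releases b, which releases c
print(after_dropping_b_c, deleted)  # on CPython: [] ['a', 'b', 'c']
```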

## 17.3 Circular Reference

In [2]:
a=obj()
b=obj()
c=obj()

a.b=b
b.c=c
c.a=a

a=b=c=None

<img src="image/step17_1.jpeg">

In the situation above, the user can no longer access any of the three objects.<br>
However, each object is still referenced by another one, so no reference count reaches 0 and reference counting alone never frees their memory. 

This is why we need Garbage Collection (GC).
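
A small sketch of this, assuming CPython and using only the standard `gc` and `weakref` modules. The `Node` class and the `probe` variable are illustrative names; the weak reference is only there to observe whether the object still exists.

```python
import gc
import weakref


class Node:
    pass


gc.disable()  # keep the automatic collector out of the way during the demo

a, b, c = Node(), Node(), Node()
a.b, b.c, c.a = b, c, a        # a -> b -> c -> a forms a cycle

probe = weakref.ref(a)         # observes a without keeping it alive
a = b = c = None               # the cycle is now unreachable

alive_before = probe() is not None  # reference counting did not free it
gc.collect()                        # the cycle detector finds and frees it
alive_after = probe() is not None

gc.enable()
print(alive_before, alive_after)  # on CPython: True False
```

The memory does come back eventually, but only when the collector runs, which is later than the immediate release that reference counting gives.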

## 17.4 weakref Module

A __weak reference__ refers to an object without incrementing its reference count. 

In [16]:
import weakref
import numpy as np

a=np.array([1,2,3])
b=weakref.ref(a)
print(b)

<weakref at 0x1090c2340; to 'numpy.ndarray' at 0x107f5b2d0>


In [17]:
b()

array([1, 2, 3])

In [19]:
a=None
b

<weakref at 0x1090c2340; to 'numpy.ndarray' at 0x107f5b2d0>

<img src="image/step17_2.png">

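When the same steps are run in a plain Python terminal, the weak reference is shown as `dead` after `a=None` (see the image above). In the notebook it still appears alive, most likely because the notebook keeps its own reference to the array that `b()` displayed.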
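
The same check can be run as a script, where nothing but the variable `a` holds the object. This is a sketch with a hypothetical `Payload` class standing in for the NumPy array, so that it needs only the standard library:

```python
import weakref


class Payload:
    """Hypothetical stand-in for the NumPy array used above."""

    def __init__(self, values):
        self.values = values


a = Payload([1, 2, 3])
b = weakref.ref(a)     # does not increment the reference count of a

before = b().values    # calling the weak reference returns the object
a = None               # the only strong reference is gone
after = b()            # a dead weak reference returns None

print(before, after)   # on CPython: [1, 2, 3] None
```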

### Applying in the DeZero

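In DeZero, a `Function` refers to its output `Variable`s, and each output refers back to the function through `creator`, which is a circular reference. By holding the outputs through __weakref__ we can solve this circular reference problem.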

In [22]:

import numpy as np
import weakref

class Variable:
    def __init__(self, data):
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))

        self.data = data
        self.grad = None
        self.creator = None
        self.generation=0 #number of generations

    def set_creator(self, func): # a variable is one generation after the function that created it
        self.creator = func
        self.generation= func.generation+1 #parent_generation+1

    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        funcs = []
        seen_set=set()
        
        def add_func(f):
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)
        add_func(self.creator)
        
        while funcs:
            f = funcs.pop()
            gys = [output().grad for output in f.outputs] # outputs are weak references, so call output() to get the Variable (prev: output.grad)
            gxs = f.backward(*gys) 
            if not isinstance(gxs, tuple): 
                gxs=(gxs, )
            for x, gx in zip(f.inputs, gxs): 
                if x.grad is None:
                    x.grad=gx
                else:
                    x.grad=x.grad+gx 
            
                if x.creator is not None:
                    add_func(x.creator)
    def cleargrad(self):
        self.grad=None

def as_array(x):
    if np.isscalar(x):
        return np.array(x)
    return x

class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]
        
        self.generation=max([x.generation for x in inputs]) 

        
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs=[weakref.ref(output) for output in outputs] # weak references replace the previous self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

    def forward(self, xs):
        raise NotImplementedError()

    def backward(self, gys):
        raise NotImplementedError()
        
class Add(Function):
    def forward(self, x0, x1):
        y=x0+x1
        return y
    def backward(self, gy):
        return gy, gy

def add(x0, x1):
    return Add()(x0, x1)    
    
class Square(Function):
    def forward(self, x):
        y=x**2
        return y
    def backward(self, gy):
        x=self.inputs[0].data # prev: x=self.input.data
        gx=2*x*gy
        return gx
    

def square(x):
    return Square()(x)

## 17.5 Check

In [24]:
for i in range(10):
    x=Variable(np.random.randn(10000)) # big data
    y=square(square(square(x))) # Complex computation 
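
Each pass of the loop rebinds `x` and `y`, so the previous graph loses its last strong reference and can be freed by reference counting alone. The sketch below illustrates that effect with stripped-down, hypothetical stand-ins (`Var`, a forward-only `Square` with a `weak` switch); it is not the DeZero code itself, and it assumes CPython.

```python
import gc
import weakref


class Var:
    def __init__(self, data):
        self.data = data
        self.creator = None


class Square:
    """Forward-only stand-in; `weak` selects how the output is held."""

    def __init__(self, weak):
        self.weak = weak

    def __call__(self, x):
        y = Var(x.data ** 2)
        y.creator = self                                   # output -> function
        self.input = x                                     # function -> input
        self.output = weakref.ref(y) if self.weak else y   # function -> output
        return y


def intermediate_freed(weak):
    x = Var(3.0)
    y = Square(weak)(Square(weak)(x))
    probe = weakref.ref(y.creator.input)  # watches the intermediate variable
    y = None                              # drop the only user-visible handle
    return probe() is None


gc.disable()                                   # rule out help from the cycle collector
freed_with_weakref = intermediate_freed(True)
freed_with_strong = intermediate_freed(False)
gc.collect()                                   # clean up the cycles left by the strong variant
gc.enable()
print(freed_with_weakref, freed_with_strong)   # on CPython: True False
```

With weak output references the whole graph is released as soon as `y` is dropped; with strong references it lingers until the garbage collector runs.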