# Differential for Meta Learning

## 0. Settings

Suppose we train a model $f(\cdot;w)$ on data $x$ with a loss function $\mathcal{L}$ and a learning rate $\alpha$. Writing $w_T$ for the weight after the $T$-th iteration, we get the following two equations: the update rule for the weight, and the gradient of the loss with respect to $x$.

$$w_T = w_{T-1} - \alpha \nabla_{w_{T-1}} \mathcal{L} (f(x;w_{T-1}))$$


$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial}{\partial x} \mathcal{L} (f(x;w_{T}))$$

In the usual setting, we treat $w_T$ as a variable independent of $x$. PyTorch also adopts this view through nn.Module and nn.Parameter. However, in a meta-learning framework, we want a gradient of a gradient. That is, in the second equation, **you have to note that $w_T$ is also a function of $x$** unless you detach it, because every update step used a gradient that was computed from $x$. As a formula, we need $\frac{\partial}{\partial x} \mathcal{L} (f(x;w_{T}(x)))$.

Now, consider the following setting.

$f(x;w)=(wx)^2$, $\mathcal{L}(f(x;w))=2-(wx)^2$, $\alpha=0.1$, $w_0=1$, $x=1$.

Then, $\frac{\partial \mathcal{L}}{\partial w} = -2wx^2$.

Assume we train this model for two iterations.

(1) 1st iteration :

$w_1 = w_0 + 0.1\cdot 2 w_0 \cdot x^2 = w_0 + 0.2w_0x^2 = 1.2$

(2) 2nd iteration :

$w_2 = w_1 + 0.1\cdot 2 w_1 \cdot x^2$

$= (w_0 + 0.2w_0x^2) + 0.2(w_0 + 0.2w_0x^2)x^2$

$= w_0 + 0.4w_0x^2 + 0.04w_0x^4 = 1.44$
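The two updates above can be reproduced with a few lines of plain Python. This is only a minimal sketch under the toy setting of this section; the helper name `grad_w` is introduced here for illustration and does not appear in the notebook code.

```python
# Toy setting: L(w, x) = 2 - (w*x)**2, alpha = 0.1, w0 = 1, x = 1.
def grad_w(w, x):
    """Analytic dL/dw = -2 * w * x**2 for the toy loss."""
    return -2.0 * w * x ** 2

alpha, x = 0.1, 1.0
w0 = 1.0
w1 = w0 - alpha * grad_w(w0, x)  # first SGD step
w2 = w1 - alpha * grad_w(w1, x)  # second SGD step

print(w1, w2)  # approximately 1.2 and 1.44 (up to floating-point rounding)
```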

Then, the gradient with respect to $x$ at each iteration is as follows:

(1) 1st iteration :

$$\frac{\partial}{\partial x} \mathcal{L} (f(x;w_{0})) = \frac{\partial}{\partial x} \{ 2 - (w_0x)^2\} = -2w_0^2x = -2$$

(2) 2nd iteration :

$$\frac{\partial}{\partial x} \mathcal{L} (f(x;w_{1})) = \frac{\partial}{\partial x} \{ 2 - (w_1x)^2\}$$
$$= - \frac{\partial}{\partial x} (\{w_0 + 0.2w_0x^2\}x)^2$$
$$= - \frac{\partial}{\partial x} (w_0x + 0.2w_0x^3)^2$$
$$= - 2(w_0x + 0.2w_0x^3)(w_0 + 0.6w_0x^2)$$
$$= -2\cdot1.2\cdot1.6 = -3.84$$

Again, note that $\frac{\partial}{\partial x} \mathcal{L} (f(x;w_{1})) \neq -2w_1^2x = -2.88$, which is the value we would obtain by treating $w_1 = 1.2$ as a constant independent of $x$.
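As a sanity check on the two numbers, the sketch below approximates both derivatives with central finite differences in plain Python. The helper names `sgd_step` and `central_diff` are hypothetical and only serve this illustration; the step size `eps` is an arbitrary choice.

```python
ALPHA, W0 = 0.1, 1.0

def loss(w, x):
    """Toy loss L = 2 - (w*x)**2."""
    return 2.0 - (w * x) ** 2

def sgd_step(w, x):
    """One SGD step: w - alpha * dL/dw, with dL/dw = -2*w*x**2."""
    return w + ALPHA * 2.0 * w * x ** 2

def central_diff(fn, x, eps=1e-6):
    """Central finite-difference approximation of d fn / d x."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

x = 1.0
w1_fixed = sgd_step(W0, x)  # w_1 frozen at its value for x = 1

# w_1 is recomputed for every perturbed x: the meta-learning gradient.
tracked = central_diff(lambda t: loss(sgd_step(W0, t), t), x)
# w_1 is held constant: the "normal" gradient.
detached = central_diff(lambda t: loss(w1_fixed, t), x)

print(tracked, detached)  # close to -3.84 and -2.88
```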

## 1. Codes

Next, we compare the two frameworks in PyTorch. In the code below, the loss is nn.L1Loss with target $Y=2$, which equals $2-(wx)^2$ as long as $(wx)^2 < 2$.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import collections

## 1.1. Normal Situation

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc = nn.Linear(1, 1, bias=False)
        self.fc.weight.data.uniform_(1,1)
            
    def forward(self, x):
        x = self.fc(x)
        return x**2

In [3]:
model = Net()

In [4]:
x = torch.tensor([[1.]])
Y = torch.tensor([[2.]])

In [5]:
x.requires_grad = True

In [6]:
for j in range(2) :
    print("[",j, "th iteration]")
    
    pre = model(x)
    cost = nn.L1Loss()(pre, Y)
    
    grad = torch.autograd.grad(cost, x, retain_graph=True)[0]
    print("Grad:", grad)
    
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    cost.backward(retain_graph=True)
    optimizer.step()
        
    print("Weight:",model.fc.weight)
    print()

[ 0 th iteration]
Grad: tensor([[-2.]])
Weight: Parameter containing:
tensor([[1.2000]], requires_grad=True)

[ 1 th iteration]
Grad: tensor([[-2.8800]])
Weight: Parameter containing:
tensor([[1.4400]], requires_grad=True)



Thus, we can see that PyTorch treats $w_T$ as a variable independent of $x$: the second gradient is $-2.88$, not $-3.84$. Then, what if we save the parameters at each iteration?

In [7]:
def get_module(model, names):
    if len(names) == 0 :
        return model
    name = names[0]
    del names[0]
    return get_module(getattr(model, name), names)

In [8]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc = nn.Linear(1, 1, bias=False)
        self.fc.weight.data.uniform_(1,1)
            
    def forward(self, x):
        x = self.fc(x)
        return x**2

In [9]:
model = Net()

In [10]:
x = torch.tensor([[1.]])
Y = torch.tensor([[2.]])

In [11]:
x.requires_grad = True

In [12]:
current_param = collections.OrderedDict(model.named_parameters())
params = []
params.append(current_param.copy())

for j in range(2) :
    print("[",j, "th iteration]")
    
    # model.load_state_dict(params[j]) does not work here: it only copies the
    # "values" into the existing parameters, so autograd.grad below raises
    # RuntimeError: One of the differentiated Tensors appears to not have been
    # used in the graph. Set allow_unused=True if this is the desired behavior.
    # model.load_state_dict(params[j])
    
    for key in params[j].keys() :
        a = key.split(".")
        setattr(get_module(model, a[:-1]), a[-1], params[j][key])
    
    pre = model(x)
    cost = nn.L1Loss()(pre, Y)
    
    grad = torch.autograd.grad(cost, x, retain_graph=True)[0]
    print("Grad:", grad)
    
    # If you use below lines, it changes the original params
    # so that the gradient calculation will be inaccurate.
    
    # optimizer = optim.SGD(params[j].values(), lr=0.1)
    # optimizer.zero_grad()
    # cost.backward(retain_graph=True)
    # optimizer.step()

    grads = torch.autograd.grad(cost, params[j].values(), 
                                retain_graph=True, create_graph=True)

    params.append(params[j].copy())
    
    for i, key in enumerate(current_param.keys()):
        params[j+1][key] = nn.Parameter(params[j+1][key] - 0.1*grads[i])

    print("Weight:",params[j+1])
    print()

[ 0 th iteration]
Grad: tensor([[-2.]])
Weight: OrderedDict([('fc.weight', Parameter containing:
tensor([[1.2000]], requires_grad=True))])

[ 1 th iteration]
Grad: tensor([[-2.8800]])
Weight: OrderedDict([('fc.weight', Parameter containing:
tensor([[1.4400]], requires_grad=True))])



Same results. The problem is nn.Parameter. If you use nn.Module, which is the obvious choice in a normal situation, the Module stores its weights as nn.Parameter objects. **A characteristic of nn.Parameter is that its history is NOT tracked**: wrapping the updated weight in nn.Parameter creates a new leaf tensor, so the dependence of $w_1$ on $x$ is dropped. (Of course, this is exactly why we can use torch models in a normal situation without worrying about it.) Thus, with nn.Module or nn.Parameter, we can't get the meta-learning gradient with respect to $x$. Then, what can we use? **The answer is the functional F instead of nn.Module.**
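To see concretely what wrapping in nn.Parameter does to the graph, here is a small sketch for the same toy setting. The variable names `w1_tracked` and `w1_wrapped` are made up for this illustration; the loss is written out directly as $2-(wx)^2$ instead of going through nn.L1Loss.

```python
import torch
import torch.nn as nn

x = torch.tensor([[1.0]], requires_grad=True)
w0 = torch.tensor([[1.0]], requires_grad=True)

# Toy loss 2 - (w*x)^2 and its gradient w.r.t. w, keeping the graph.
cost = (2.0 - (x @ w0.t()) ** 2).sum()
(g,) = torch.autograd.grad(cost, w0, create_graph=True)

w1_tracked = w0 - 0.1 * g                # plain tensor: still linked to x
w1_wrapped = nn.Parameter(w0 - 0.1 * g)  # new leaf: the link to x is dropped

print(w1_tracked.grad_fn)                        # a backward node
print(w1_wrapped.is_leaf, w1_wrapped.grad_fn)    # leaf tensor, no grad_fn

# d w1 / d x exists only for the tracked tensor.
(dw1_dx,) = torch.autograd.grad(w1_tracked.sum(), x)
(unused,) = torch.autograd.grad(w1_wrapped.sum(), x, allow_unused=True)
print(dw1_dx, unused)                            # about 0.4, and None
```

The value $0.4$ matches $\frac{\partial w_1}{\partial x} = 0.4w_0x$ from the derivation in Section 0.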

In [13]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
            
    def forward(self, x, w):
        x = F.linear(x, w, bias=None)
        return x**2

In [14]:
model = Net()

In [15]:
x = torch.tensor([[1.]])
Y = torch.tensor([[2.]])

In [16]:
x.requires_grad = True

In [17]:
w = torch.tensor([[1.]])
w.requires_grad = True

for j in range(2) :
    print("[",j, "th iteration]")
    
    pre = model(x, w)
    cost = nn.L1Loss()(pre, Y)
    
    grad = torch.autograd.grad(cost, x, retain_graph=True)[0]
    print("Grad:", grad)
    
    grads = torch.autograd.grad(cost, w, 
                                retain_graph=True, create_graph=True)

    w = w - 0.1*grads[0]  # out-of-place update, so w keeps its graph back to x
    
    print("Weight:", w)
    print()

[ 0 th iteration]
Grad: tensor([[-2.]])
Weight: tensor([[1.2000]], grad_fn=<SubBackward0>)

[ 1 th iteration]
Grad: tensor([[-3.8400]])
Weight: tensor([[1.4400]], grad_fn=<SubBackward0>)



**Right results for a meta learning framework!** However, rewriting every nn.Module in functional form is quite tedious. In this case, we can use the package [Higher](https://github.com/facebookresearch/higher). Thanks, facebookresearch!

In [18]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc = nn.Linear(1, 1, bias=False)
        self.fc.weight.data.uniform_(1,1)
            
    def forward(self, x):
        x = self.fc(x)
        return x**2

In [19]:
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.1)

In [20]:
x = torch.tensor([[1.]])
Y = torch.tensor([[2.]])

In [21]:
x.requires_grad = True

In [22]:
import higher

In [23]:
with higher.innerloop_ctx(model, optimizer) as (fmodel, diffopt):
    for j in range(2):
        print("[",j, "th iteration]")
        pre = fmodel(x)
        cost = nn.L1Loss()(pre, Y)
        
        grad = torch.autograd.grad(cost, x, retain_graph=True)[0]
        print("Grad:", grad)
        
        diffopt.step(cost)
        print("Weight:", fmodel.fc.weight)
        print()

[ 0 th iteration]
Grad: tensor([[-2.]])
Weight: tensor([[1.2000]], grad_fn=<AddBackward0>)

[ 1 th iteration]
Grad: tensor([[-3.8400]])
Weight: tensor([[1.4400]], grad_fn=<AddBackward0>)



In [24]:
# Warning! The original model is not changed; only fmodel was updated.
model.fc.weight

Parameter containing:
tensor([[1.]], requires_grad=True)
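As an alternative sketch, assuming PyTorch 2.0 or later, `torch.func.functional_call` can run an unmodified nn.Module with externally supplied weights, which gives the same differentiable inner loop without rewriting the model in functional form. The manual update rule and the names `params` and `grads_x` are choices made for this illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1, bias=False)
        nn.init.constant_(self.fc.weight, 1.0)

    def forward(self, x):
        return self.fc(x) ** 2

model = Net()
x = torch.tensor([[1.0]], requires_grad=True)
Y = torch.tensor([[2.0]])

# Start from the module's own parameters, then carry plain tensors forward.
params = dict(model.named_parameters())
grads_x = []

for j in range(2):
    pre = functional_call(model, params, (x,))
    cost = F.l1_loss(pre, Y)

    # Gradient w.r.t. x, including the path through the earlier updates.
    grads_x.append(torch.autograd.grad(cost, x, retain_graph=True)[0].item())

    # Differentiable SGD step: no nn.Parameter wrapping, so history is kept.
    gw = torch.autograd.grad(cost, list(params.values()), create_graph=True)
    params = {k: p - 0.1 * g for (k, p), g in zip(params.items(), gw)}

print(grads_x)          # close to [-2.0, -3.84]
print(model.fc.weight)  # the original model is still unchanged
```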