In [1]:
import torch.nn as nn
import torch

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GINConv
from torch.autograd import Variable
import numpy as np



class MetaModule(nn.Module):
    """Base class for modules whose learnable tensors can be swapped out.

    Sub-classes expose their learnable tensors through ``named_leaves()``.
    ``update_params`` / ``set_param`` then replace those tensors with new
    (possibly non-leaf) tensors via ``setattr``, so an inner-loop update can
    stay part of the autograd graph.
    """

    @staticmethod
    def to_var(x, requires_grad=True):
        """Wrap ``x`` in a ``Variable``, moving it to CUDA first when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, requires_grad=requires_grad)

    def params(self):
        """Yield every tensor reported by ``named_params`` (names dropped)."""
        for _, param in self.named_params(self):
            yield param

    def named_leaves(self):
        """Return ``(name, tensor)`` pairs owned directly by this module.

        The base implementation owns nothing; sub-classes override it.
        """
        return []

    def named_submodules(self):
        """Return an empty list; kept as an overridable hook."""
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        """Recursively yield ``(dotted_name, tensor)`` pairs.

        Args:
            curr_module: module to walk. Defaults to ``self`` so the method
                can be called without arguments.
            memo: set of tensors already yielded, used to skip duplicates.
            prefix: dotted path of ``curr_module`` relative to the root.
        """
        if curr_module is None:
            curr_module = self
        if memo is None:
            memo = set()

        # Modules that define named_leaves() report their own tensors; any
        # other module falls back to its registered parameters.
        if hasattr(curr_module, 'named_leaves'):
            leaves = curr_module.named_leaves()
        else:
            leaves = curr_module._parameters.items()

        for name, p in leaves:
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p

    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        """Apply one gradient-descent step ``p <- p - lr_inner * grad`` to every leaf.

        Args:
            lr_inner: step size.
            first_order: when True the gradient is detached and re-wrapped
                with ``to_var`` before it is used.
            source_params: optional iterable of gradients, matched in order
                with ``named_params``. When omitted, ``param.grad`` is used.
            detach: when True (and ``source_params`` is None) no step is
                taken; each tensor is detached in place and re-assigned.

        Leaves whose gradient is ``None`` are left untouched.
        """
        if source_params is not None:
            for (name_t, param_t), grad in zip(self.named_params(self), source_params):
                if grad is None:
                    # No gradient supplied for this leaf; nothing to apply.
                    continue
                if first_order:
                    grad = self.to_var(grad.detach().data)
                self.set_param(self, name_t, param_t - lr_inner * grad)
        else:
            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if grad is None:
                        # backward() has not produced a gradient for this leaf.
                        continue
                    if first_order:
                        grad = self.to_var(grad.detach().data)
                    self.set_param(self, name, param - lr_inner * grad)
                else:
                    param = param.detach_()  # https://blog.csdn.net/qq_39709535/article/details/81866686
                    self.set_param(self, name, param)

    def set_param(self, curr_mod, name, param):
        """Assign ``param`` to the attribute addressed by dotted ``name``.

        The first path component is matched against the children of
        ``curr_mod`` and the rest is resolved recursively. A path whose
        first component matches no child is silently ignored.
        """
        if '.' in name:
            module_name, rest = name.split('.', 1)
            for child_name, mod in curr_mod.named_children():
                if module_name == child_name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def detach_params(self):
        """Replace every leaf with a detached copy of itself."""
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        """Copy the leaves of ``other`` into this module.

        Args:
            other: module to read from; must expose ``named_params``.
            same_var: when True the very same tensor objects are shared;
                otherwise each tensor's data is cloned into a fresh
                variable that requires grad.
        """
        for name, param in other.named_params(other):
            if not same_var:
                param = self.to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)

In [6]:
# Two small graph layers (3 input channels, 1 output channel) whose
# parameter names are inspected in the following cells.
g = GATConv(in_channels=3, out_channels=1)
n = GCNConv(in_channels=3, out_channels=1)

In [7]:
# List the parameters of the GCNConv layer `n`. The loop targets use their
# own names so the module `n` is not overwritten and the cell can be re-run.
for param_name, param in n.named_parameters():
    print(param_name)
    print(param)

bias
Parameter containing:
tensor([0.], requires_grad=True)
lin.weight
Parameter containing:
tensor([[-0.8139,  0.8214, -0.0269]], requires_grad=True)


In [4]:
# List the parameters of the GATConv layer `g`. The loop targets use their
# own names so the global `n` (a GCNConv layer) is not clobbered.
for param_name, param in g.named_parameters():
    print(param_name)
    print(param)

att_src
Parameter containing:
tensor([[[-0.2089]]], requires_grad=True)
att_dst
Parameter containing:
tensor([[[1.5129]]], requires_grad=True)
bias
Parameter containing:
tensor([0.], requires_grad=True)
lin_src.weight
Parameter containing:
tensor([[ 0.1886,  0.0371, -1.0170]], requires_grad=True)


In [34]:
class MetaGCN(MetaModule):
    """GCN layer whose weight and bias are replaceable meta leaves.

    The linear transform and the bias are applied with ``self.weight`` and
    ``self.bias`` so that tensors swapped in through ``set_param`` take
    effect in ``forward``. A parameter-free ``GCNConv`` is used only for
    normalisation and neighbourhood aggregation.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int
    ) -> None:
        super().__init__()
        # Throw-away layer used only to obtain initial values for the leaves.
        ignore = GCNConv(in_channel, out_channel)
        self.register_buffer('weight', self.to_var(ignore.lin.weight, requires_grad=True))
        self.register_buffer('bias', self.to_var(ignore.bias, requires_grad=True))
        print(self.weight.shape[0])
        print(self.weight.shape[1])

        # The weight and bias must NOT be passed positionally to GCNConv: its
        # third and fourth positional parameters are `improved` and `cached`.
        # Instead build a conv without bias and with an identity transform,
        # leaving it no learnable state of its own.
        propagate = GCNConv(out_channel, out_channel, bias=False)
        propagate.lin = nn.Identity()
        # Kept inside a list so it is not registered as a child module and
        # therefore never shows up in named_params().
        self.gcn = [propagate]

    def forward(self, x, edge_index):
        """Apply the meta weight, aggregate over ``edge_index``, add the meta bias."""
        h = F.linear(x, self.weight)
        return self.gcn[0](h, edge_index) + self.bias

    def named_leaves(self):
        """Return the two replaceable tensors of this layer."""
        return [('weight', self.weight), ('bias', self.bias)]

In [37]:
# Instantiate the meta GCN layer with 6 input and 2 output channels.
model = MetaGCN(in_channel=6, out_channel=2)

2
6


In [30]:
# Scratch GCNConv layer with 3 input channels and 1 output channel.
t = GCNConv(in_channels=3, out_channels=1)

In [9]:
class MetaLinear(MetaModule):
    """Linear layer whose weight and bias are stored as replaceable buffers."""

    _LEAF_NAMES = ('weight', 'bias')

    def __init__(
        self,
        num_features: int,
        embedding_dim: int
    ) -> None:
        super().__init__()
        # A stock nn.Linear supplies the initial values; only its tensors
        # are kept, wrapped by to_var and registered as buffers.
        template = nn.Linear(num_features, embedding_dim)
        for leaf_name in self._LEAF_NAMES:
            initial = self.to_var(getattr(template, leaf_name), requires_grad=True)
            self.register_buffer(leaf_name, initial)

    def forward(self, x):
        """Compute the affine map of ``x`` using the current weight and bias."""
        return F.linear(x, self.weight, self.bias)

    def named_leaves(self):
        """Return the ``(name, tensor)`` pairs for weight and bias, in that order."""
        return [(leaf_name, getattr(self, leaf_name)) for leaf_name in self._LEAF_NAMES]

In [10]:
# Meta linear layer mapping 4 input features to a 1-dimensional output.
l = MetaLinear(num_features=4, embedding_dim=1)

In [41]:
# One sample with four integer-valued float features drawn from [10, 20).
x = torch.randint(low=10, high=20, size=(1, 4), dtype=torch.float32)
# Forward pass through the meta linear layer `l`.
o = l(x)

In [34]:
# Stock linear layer with the same shape as the meta layer (4 -> 1).
ignore = nn.Linear(in_features=4, out_features=1)

In [43]:
# Run the same input `x` through the stock nn.Linear layer `ignore`.
z = ignore(x)