In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Show whether PyTorch reports CUDA as available in this environment.
print(torch.cuda.is_available())

class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages followed by an MLP head.

    The feature extractor is ``conv(5x5) -> ReLU -> maxpool(2) -> conv(5x5) ->
    ReLU -> maxpool(2)`` on a single-channel image. The flattened feature size
    is fixed to ``4 * 4 * conv2c``, which is what this stack produces for
    e.g. a 1x28x28 input (28 -> 24 -> 12 -> 8 -> 4).

    Args:
        conv1c: number of output channels of the first convolution.
        conv2c: number of output channels of the second convolution.
        hidden_dims: iterable with the width of each hidden Linear layer
            (may be empty, in which case the head is a single Linear layer).
        output_c: number of output classes.
    """

    layers: torch.nn.Module

    def __init__(self, conv1c, conv2c, hidden_dims, output_c):
        super().__init__()
        # Size of the flattened conv features handed to the first Linear layer.
        self.view_input_dim = 4 * 4 * conv2c
        self.conv1 = nn.Conv2d(1, conv1c, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(conv1c, conv2c, kernel_size=5, stride=1)

        # Build the fully connected head: every hidden layer is Linear + ReLU,
        # the final projection to `output_c` has no activation.
        widths = [self.view_input_dim, *hidden_dims]
        head = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head.append(nn.Linear(fan_in, fan_out))
            head.append(nn.ReLU(inplace=True))
        head.append(nn.Linear(widths[-1], output_c))
        self.layers = nn.Sequential(*head)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        flat = x.view(-1, self.view_input_dim)
        logits = self.layers(flat)
        return F.log_softmax(logits, dim=1)

True


In [4]:
new_layer_n = [6, 7] # Indices (0-based) of the layers in the new model that are newly added.
def get_liner_name(net):
    """Return the names of all ``nn.Linear`` submodules of ``net``.

    Names are the dotted paths yielded by ``net.named_modules()``, in that
    order. The first entry of ``named_modules()`` (the root module itself)
    is skipped, so a bare ``nn.Linear`` passed as ``net`` yields ``[]``.
    """
    submodules = list(net.named_modules())[1:]
    return [path for path, module in submodules if isinstance(module, nn.Linear)]

def get_conv_name(net):
    """Return the names of all ``nn.Conv2d`` submodules of ``net``.

    Names are the dotted paths yielded by ``net.named_modules()``, in that
    order. The first entry of ``named_modules()`` (the root module itself)
    is skipped, so a bare ``nn.Conv2d`` passed as ``net`` yields ``[]``.
    """
    submodules = list(net.named_modules())[1:]
    return [path for path, module in submodules if isinstance(module, nn.Conv2d)]

def Surgery_move(new_w, new_b, old_w, old_b, eps=0):
    """Embed an old layer's weight/bias into the top-left of a larger layer, in place.

    Args:
        new_w: 2-D destination weight; modified in place.
        new_b: 1-D destination bias; modified in place.
        old_w: 2-D source weight of shape ``(old_out, old_in)``.
        old_b: source bias written into the first ``old_out`` entries of ``new_b``.
        eps: factor applied to the destination weights that connect the extra
            input columns to the first ``old_out`` rows (the default 0 zeroes them).

    Returns:
        The same ``(new_w, new_b)`` objects that were passed in. Rows and bias
        entries beyond ``old_out`` are left as they were.
    """
    n_rows, n_cols = old_w.shape
    kept_rows = slice(None, n_rows)
    # Old parameters occupy the top-left block of the destination weight.
    new_w[kept_rows, slice(None, n_cols)] = old_w
    # Scale the connections from the extra input columns into the old rows.
    new_w[kept_rows, slice(n_cols, None)] *= eps
    new_b[kept_rows] = old_b
    return new_w, new_b

def Surgery2(oldnet, newnet, new_layer_n):
    """Transplant the Linear-layer parameters of ``oldnet`` into ``newnet``.

    The Linear layers of ``newnet`` are visited in ``get_liner_name`` order.
    A layer whose position is listed in ``new_layer_n`` has no counterpart in
    ``oldnet`` and is filled with an identity weight and zero bias whose size
    is the output width of the previously transplanted layer. Every other
    layer receives the parameters of the next not-yet-used Linear layer of
    ``oldnet``. The copy itself is done by ``Surgery_move``, which writes into
    the top-left block of the destination tensors.

    Args:
        oldnet: source module providing the trained Linear parameters.
        newnet: destination module; expected to be at least as wide as
            ``oldnet`` in every matched layer.
        new_layer_n: positions (0-based, counted over the Linear layers of
            ``newnet``) of the layers that do not exist in ``oldnet``.

    Returns:
        The ``state_dict()`` of ``newnet`` after the in-place edits. The
        tensors in it are written in place; per the PyTorch documentation they
        reference the module's parameters, so ``newnet`` itself is updated.

    Raises:
        IndexError: if ``newnet`` has more non-new Linear layers than
            ``oldnet`` has Linear layers.
    """
    # call after surgery conv
    old_name = get_liner_name(oldnet)
    new_name = get_liner_name(newnet)
    old_p = oldnet.state_dict()
    new_p = newnet.state_dict()
    id_index = 0  # number of identity layers inserted so far
    # Width of the features entering the current layer. Seeded with the input
    # width of the old network's first Linear layer so that an identity layer
    # can also be inserted at position 0 (previously a -1 placeholder made
    # torch.eye fail). With no old Linear layers there is nothing to preserve,
    # so the identity block is empty.
    out = old_p[old_name[0] + '.weight'].shape[1] if old_name else 0
    for i, layer in enumerate(new_name):
        new_w, new_b = new_p[layer + '.weight'], new_p[layer + '.bias']
        if i in new_layer_n:  # identity
            # Build the identity on the destination's dtype/device so the copy
            # needs no implicit conversion.
            old_w = torch.eye(out, dtype=new_w.dtype, device=new_w.device)
            old_b = torch.zeros(out, dtype=new_b.dtype, device=new_b.device)
            id_index += 1
        else:
            # Skip over the identity layers when indexing into the old network.
            name = old_name[i - id_index]
            old_w, old_b = old_p[name + '.weight'], old_p[name + '.bias']
        # Surgery_move edits new_w / new_b in place; its return value is the
        # same pair of tensors, so it is not needed here.
        Surgery_move(new_w, new_b, old_w, old_b)
        out = old_b.shape[0]
    return new_p

0 0
25 torch.Size([25])
1 0
50 torch.Size([50])
2 0
100 torch.Size([100])
3 0
100 torch.Size([100])
4 0
100 torch.Size([100])
5 0
100 torch.Size([100])
6 0
100 torch.Size([100])
7 1
100 torch.Size([100])
8 2
2 torch.Size([2])


OrderedDict([('layers.0.weight',
              tensor([[ 8.5299e-03, -2.2137e-01,  4.2435e-01, -3.4890e-01,  1.6520e-01,
                        3.0243e-02,  5.2682e-03, -0.0000e+00, -0.0000e+00,  0.0000e+00,
                        0.0000e+00, -0.0000e+00, -0.0000e+00],
                      [-9.4769e-01,  2.8481e-02,  3.4451e-01, -6.2756e-01,  1.5863e-01,
                       -4.7748e-01, -8.3875e-01,  0.0000e+00, -0.0000e+00,  0.0000e+00,
                       -0.0000e+00, -0.0000e+00,  0.0000e+00],
                      [ 4.3869e-01, -4.8007e-01, -1.1632e-01, -3.6142e-01,  4.2263e-01,
                        1.9024e-01, -1.9510e-02,  0.0000e+00, -0.0000e+00,  0.0000e+00,
                        0.0000e+00,  0.0000e+00, -0.0000e+00],
                      [ 5.6465e-01, -9.9225e-02,  8.4840e-01,  4.3230e-01, -1.4533e-02,
                       -1.0773e-01,  3.3058e-01,  0.0000e+00,  0.0000e+00, -0.0000e+00,
                       -0.0000e+00, -0.0000e+00, -0.0000e+00],
           