# Transfer Learning

- [What is transfer learning](#What-is-transfer-learning)
- [What we need to do in transfer learning](#What-we-need-to-do-in-transfer-learning)
- [Same ResNet different input shape](#Let's-transfer-a-ResNet)
- [Transfer ResNet to UNet](#ResNet-to-ResUNet)
- [Transfer UNet to ResNet](#Transfer-ResUNet-to-ResNet)

### What is transfer learning
Transfer learning extracts transferable representations from one or more source tasks and then adapts those representations to improve learning on related target tasks.

### What we need to do in transfer learning
- Load the weights and biases from the source model
- Set those weights and biases on the target model
- It works best when the source model and the target model share some structure/blocks

### Let's transfer a ResNet

In [1]:
import torch
from ptimz.model_zoo import ResNet
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

In [2]:
from mmcv.cnn import constant_init, normal_init
def weight_init(module):
    """Initialise a single layer; meant to be passed to ``Module.apply``.

    Convolution, transposed-convolution and linear layers are handed to
    ``normal_init`` with mean 0.5 and std 0.02. Batch-norm and group-norm
    layers are handed to ``constant_init`` with the value 1. Every other
    module type is left untouched.

    NOTE(review): the non-zero mean is presumably chosen so that transferred
    weights are easy to recognise when their mean is printed later.

    Args:
        module: the layer currently visited by ``apply``.
    """
    conv_or_linear_types = (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d, nn.Linear)
    norm_types = (_BatchNorm, nn.GroupNorm)

    if isinstance(module, conv_or_linear_types):
        normal_init(module, mean=0.5, std=0.02)
        return
    if isinstance(module, norm_types):
        constant_init(module, 1)

The source network takes a 2-channel input and predicts 4 classes.

In [3]:
# Layer configuration passed to the model constructor: 3D convolutions and
# 3D batch norm.
conv_cfg = {'type': 'Conv3d'}
norm_cfg = {'type': 'BN3d', 'requires_grad': True}

# Source model: classification head, 2 input channels, 4 output classes.
source_net = ResNet(
    50,
    in_channels=2,
    deep_stem=True,
    conv_cfg=conv_cfg,
    norm_cfg=norm_cfg,
    head_type='classification',
    num_classes=4,
)

# Replace the default initialisation with the recognisable one from
# weight_init, then move the model to the GPU.
source_net.apply(weight_init)
source_net = source_net.cuda()

# Smoke test: one random 64^3 volume with 2 channels, no gradients needed.
input = torch.rand(1, 2, 64, 64, 64).cuda()
with torch.no_grad():
    output = source_net(input)
print(output.shape)

torch.Size([1, 4])


In [4]:
# Grab the source parameters and peek at the first and last ten entry names.
src_state_dict = source_net.state_dict()
param_keys = [key for key in src_state_dict]
first_ten, last_ten = param_keys[:10], param_keys[-10:]
print(first_ten, last_ten)

['stem.0.conv.weight', 'stem.0.bn.weight', 'stem.0.bn.bias', 'stem.0.bn.running_mean', 'stem.0.bn.running_var', 'stem.0.bn.num_batches_tracked', 'stem.1.conv.weight', 'stem.1.bn.weight', 'stem.1.bn.bias', 'stem.1.bn.running_mean'] ['layer4.2.bn2.running_var', 'layer4.2.bn2.num_batches_tracked', 'layer4.2.conv3.weight', 'layer4.2.bn3.weight', 'layer4.2.bn3.bias', 'layer4.2.bn3.running_mean', 'layer4.2.bn3.running_var', 'layer4.2.bn3.num_batches_tracked', 'head_fc.weight', 'head_fc.bias']


In [5]:
# NOTE(review): the two header strings below contain typos; they are kept
# verbatim so the cell still matches its recorded output.
print('First layer weightout_channels x in_channels x kernel_size')
print('OUT CHANNESL x IN CHANNELS x KERNEL SIZE')
# The first convolution depends on the number of input channels.
first_conv_weight = src_state_dict['stem.0.conv.weight']
print(first_conv_weight.shape, first_conv_weight.mean())

print('Last layer weight')
# The classification head depends on the number of classes.
head_weight = src_state_dict['head_fc.weight']
head_bias = src_state_dict['head_fc.bias']
print(head_weight.shape, head_bias.shape)

First layer weightout_channels x in_channels x kernel_size
OUT CHANNESL x IN CHANNELS x KERNEL SIZE
torch.Size([32, 2, 3, 3, 3]) tensor(0.5007, device='cuda:0')
Last layer weight
torch.Size([4, 2048]) torch.Size([4])


The target network takes a 4-channel input and predicts 8 classes.

In [6]:
# Same layer configuration as the source network.
conv_cfg = {'type': 'Conv3d'}
norm_cfg = {'type': 'BN3d', 'requires_grad': True}

# Target model: same constructor arguments as the source except for
# 4 input channels and 8 output classes. It keeps its default initialisation.
target_net = ResNet(
    50,
    in_channels=4,
    deep_stem=True,
    conv_cfg=conv_cfg,
    norm_cfg=norm_cfg,
    head_type='classification',
    num_classes=8,
)
target_net = target_net.cuda()

In [7]:
# Shape and mean of the target's first convolution BEFORE the transfer,
# for comparison with the values printed after loading the source weights.
target_state_dict = target_net.state_dict()
target_first_conv = target_state_dict['stem.0.conv.weight']
print(target_first_conv.shape, target_first_conv.mean())

torch.Size([32, 4, 3, 3, 3]) tensor(-3.7450e-05, device='cuda:0')


Transfer parameters

In [8]:
from torch.nn import functional as F
def adapt_input_conv(in_chans, conv_weight):
    """Resample a convolution kernel along its input-channel axis.

    The weight is laid out as ``(out_channels, in_channels, *kernel_size)``.
    Each vector along the input-channel axis is linearly interpolated to
    length ``in_chans``; every other axis keeps its size and order.

    Args:
        in_chans: number of input channels the returned kernel should have.
        conv_weight: kernel tensor with at least three dimensions.

    Returns:
        A tensor of shape ``(out_channels, in_chans, *kernel_size)`` with the
        same dtype as ``conv_weight``.
    """
    original_dtype = conv_weight.dtype
    # Interpolate in float32 (the input may be half precision); the original
    # dtype is restored on return.
    weight_fp32 = conv_weight.float()
    dims = list(weight_fp32.shape)
    n_dims = len(dims)
    out_chans, kernel_dims = dims[0], dims[2:]

    # (out, in, *kernel) -> (out, *kernel, in), then flatten to (N, 1, in):
    # F.interpolate in 'linear' mode resamples the last axis of a 3-D tensor.
    channels_last = weight_fp32.permute(0, *range(2, n_dims), 1)
    rows = channels_last.reshape(-1, 1, dims[1])
    resampled = F.interpolate(rows, in_chans, mode='linear')

    # Undo the flattening and move the channel axis back to position 1.
    restored = resampled.reshape(out_chans, *kernel_dims, in_chans)
    restored = restored.permute(0, n_dims - 1, *range(1, n_dims - 1))
    return restored.to(original_dtype)

In [9]:
# Transfer the first layer: resample the 2-channel stem kernel to the
# 4 input channels of the target network.
src_state_dict['stem.0.conv.weight'] = adapt_input_conv(4, src_state_dict['stem.0.conv.weight'])
print(src_state_dict['stem.0.conv.weight'].shape)

# Drop the last layer: the 4-class head does not fit the 8-class target.
# pop(..., None) instead of del keeps this cell re-runnable; a second run
# would otherwise raise KeyError because the keys are already gone.
for head_key in ('head_fc.weight', 'head_fc.bias'):
    src_state_dict.pop(head_key, None)

# Load parameters. strict=False tolerates the head entries missing from the
# source dict; they are reported in the returned _IncompatibleKeys and the
# target keeps its own head parameters.
target_net.load_state_dict(src_state_dict, strict=False)

torch.Size([32, 4, 3, 3, 3])


_IncompatibleKeys(missing_keys=['head_fc.weight', 'head_fc.bias'], unexpected_keys=[])

In [10]:
# After loading: the stem weight mean should now equal the source's mean.
target_state_dict = target_net.state_dict()
loaded_first_conv = target_state_dict['stem.0.conv.weight']
print("weight", loaded_first_conv.shape, loaded_first_conv.mean())

# Forward a random 4-channel volume through the target network.
input = torch.rand(1, 4, 64, 64, 64).cuda()
with torch.no_grad():
    output = target_net(input)
print("resnet output", output.shape, output.mean())

weight torch.Size([32, 4, 3, 3, 3]) tensor(0.5007, device='cuda:0')
resnet output torch.Size([1, 8]) tensor(-0.1146, device='cuda:0')


### ResNet to ResUNet

Trace the data flow through the ResNet backbone

In [11]:
# Only the shapes are of interest here, so run without autograd like the
# other inference cells. Otherwise `out` keeps the whole computation graph
# (and its activations) alive on the GPU.
with torch.no_grad():
    print("input shape", input.shape)
    out = target_net.stem(input)
    print("stem out", out.shape)
    out = target_net.maxpool(out)
    print("maxpool out", out.shape)

    # reslayers: layer1 .. layer4 applied one after another
    for i in range(1, 5):
        reslayer = getattr(target_net, f'layer{i}')
        out = reslayer(out)
        print(f"layer{i} {out.shape}")

input shape torch.Size([1, 4, 64, 64, 64])
stem out torch.Size([1, 64, 32, 32, 32])
maxpool out torch.Size([1, 64, 16, 16, 16])
layer1 torch.Size([1, 256, 16, 16, 16])
layer2 torch.Size([1, 512, 8, 8, 8])
layer3 torch.Size([1, 1024, 4, 4, 4])
layer4 torch.Size([1, 2048, 2, 2, 2])


Construct UNet with ResNet backbone/encoders

In [12]:
from ptimz.model_zoo import UNet
# Encoder stages are the target ResNet's own modules (not copies):
#   stage 0: stem
#   stage 1: maxpool followed by the first residual layer
#   stage 2+: the remaining residual layers, one per stage
res_layer_names = target_net.res_layers
unet_encoders = [
    target_net.stem,
    nn.Sequential(target_net.maxpool, getattr(target_net, res_layer_names[0])),
]
for layer_name in res_layer_names[1:]:
    unet_encoders.append(getattr(target_net, layer_name))

# num_channels lists the channel width after each encoder stage, as printed
# in the dataflow trace (stem, layer1 .. layer4).
extra = {'num_channels': [64, 256, 512, 1024, 2048], 'nclasses': 8}
resunet = UNet(extra, unet_encoders, conv_cfg=conv_cfg, norm_cfg=norm_cfg)
resunet = resunet.cuda()

# Smoke test on the 4-channel volume from the previous cells.
print("input", input.shape)
with torch.no_grad():
    output = resunet(input)
print("output", output.shape)

input torch.Size([1, 4, 64, 64, 64])
output torch.Size([1, 8, 64, 64, 64])


In [13]:
# The UNet exposes the ResNet stem under the 'encoder_0' prefix. Its first
# conv weight mean is printed for comparison with the ResNet stem's.
unet_state_dict = resunet.state_dict()
unet_keys = [key for key in unet_state_dict]
print(unet_keys[:10])
unet_first_conv = unet_state_dict['encoder_0.0.conv.weight']
print(unet_first_conv.mean())

['encoder_0.0.conv.weight', 'encoder_0.0.bn.weight', 'encoder_0.0.bn.bias', 'encoder_0.0.bn.running_mean', 'encoder_0.0.bn.running_var', 'encoder_0.0.bn.num_batches_tracked', 'encoder_0.1.conv.weight', 'encoder_0.1.bn.weight', 'encoder_0.1.bn.bias', 'encoder_0.1.bn.running_mean']
tensor(0.5007, device='cuda:0')


### Transfer ResUNet to ResNet

In [14]:
def weight_init2(module):
    """Initialise a single layer with a negative-mean normal; for ``Module.apply``.

    Same dispatch as ``weight_init`` but with mean -0.5. Convolution,
    transposed-convolution and linear layers go through ``normal_init``
    (mean -0.5, std 0.02). Batch-norm and group-norm layers go through
    ``constant_init`` with the value 1. Other modules are left unchanged.

    Args:
        module: the layer currently visited by ``apply``.
    """
    weighted_layer_types = (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d, nn.Linear)
    normalisation_types = (_BatchNorm, nn.GroupNorm)

    if isinstance(module, weighted_layer_types):
        normal_init(module, mean=-0.5, std=0.02)
        return
    if isinstance(module, normalisation_types):
        constant_init(module, 1)

In [15]:
# Re-initialise the UNet only. Its encoders are target_net's own modules,
# so the ResNet's stem weights are expected to change as well.
resunet.apply(weight_init2)

unet_state_dict = resunet.state_dict()
unet_first_conv = unet_state_dict['encoder_0.0.conv.weight']
print("unet first layer", unet_first_conv.mean())

resnet_state_dict = target_net.state_dict()
resnet_first_conv = resnet_state_dict['stem.0.conv.weight']
print("resnet first layer", resnet_first_conv.mean())

# The classifier still runs after the re-initialisation.
with torch.no_grad():
    output = target_net(input)
print("resnet output", output.shape, output.mean())

unet first layer tensor(-0.4998, device='cuda:0')
resnet first layer tensor(-0.4998, device='cuda:0')
resnet output torch.Size([1, 8]) tensor(-0.0794, device='cuda:0')


In [16]:
# First-layer weight as reached through the UNet ...
unet_fl = resunet.encoder_0[0].conv.weight
print(unet_fl.shape)
# ... and as reached through the ResNet.
resnet_fl = target_net.stem[0].conv.weight
print(resnet_fl.shape)

# Both checks test whether the two names refer to one and the same
# Parameter object, i.e. the weight is shared rather than copied.
same_object = unet_fl is resnet_fl
same_id = id(unet_fl) == id(resnet_fl)
print(same_object, same_id)

torch.Size([32, 4, 3, 3, 3])
torch.Size([32, 4, 3, 3, 3])
True True
