In [1]:
import re
import torch
from functools import partial, reduce

In [3]:
# Load both checkpoints onto the CPU so the conversion runs regardless of the
# device the tensors were saved on (no GPU required).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints obtained from a trusted source.
a = torch.load('./pytorch_i3d/models/flow_charades.pt', map_location='cpu')
b = torch.load('./pytorch_i3d/models/flow_kinetics_origin.pth', map_location='cpu')

In [6]:
# Keep the key collections of both checkpoints around for later lookups, then
# reduce every key to its first dot-separated component to obtain the set of
# top-level names present in each checkpoint.
a_keys = a.keys()
b_keys = b.keys()
a_set = {name.split('.')[0] for name in a_keys}
b_set = {name.split('.')[0] for name in b_keys}

In [8]:
# discard() rather than remove(): re-running this cell once 'logits' is
# already gone must not raise KeyError. The 'logits' entries are copied
# explicitly in a later cell instead of through the per-name loop.
a_set.discard('logits')

In [9]:
# Display the first dot-separated key components found in checkpoint `a`
# (after 'logits' has been removed from the set).
a_set

{'Conv3d_1a_7x7',
 'Conv3d_2b_1x1',
 'Conv3d_2c_3x3',
 'Mixed_3b',
 'Mixed_3c',
 'Mixed_4b',
 'Mixed_4c',
 'Mixed_4d',
 'Mixed_4e',
 'Mixed_4f',
 'Mixed_5b',
 'Mixed_5c'}

In [10]:
# Display the first dot-separated key components found in checkpoint `b`,
# for comparison against the names in `a_set`.
b_set

{'conv3d_0c_1x1',
 'conv3d_1a_7x7',
 'conv3d_2b_1x1',
 'conv3d_2c_3x3',
 'mixed_3b',
 'mixed_3c',
 'mixed_4b',
 'mixed_4c',
 'mixed_4d',
 'mixed_4e',
 'mixed_4f',
 'mixed_5b',
 'mixed_5c'}

In [11]:
# Keys of checkpoint `a` whose name starts with 'logit'.
[name for name in a_keys if re.match('logit', name)]

['logits.conv3d.weight', 'logits.conv3d.bias']

In [12]:
# Keys of checkpoint `b` that start with 'Mixed_5c' once capitalized
# (str.capitalize upper-cases the first character and lower-cases the rest).
[name for name in b_keys if re.match('Mixed_5c', name.capitalize())]

['mixed_5c.branch_0.conv3d.weight',
 'mixed_5c.branch_0.batch3d.weight',
 'mixed_5c.branch_0.batch3d.bias',
 'mixed_5c.branch_0.batch3d.running_mean',
 'mixed_5c.branch_0.batch3d.running_var',
 'mixed_5c.branch_1.0.conv3d.weight',
 'mixed_5c.branch_1.0.batch3d.weight',
 'mixed_5c.branch_1.0.batch3d.bias',
 'mixed_5c.branch_1.0.batch3d.running_mean',
 'mixed_5c.branch_1.0.batch3d.running_var',
 'mixed_5c.branch_1.1.conv3d.weight',
 'mixed_5c.branch_1.1.batch3d.weight',
 'mixed_5c.branch_1.1.batch3d.bias',
 'mixed_5c.branch_1.1.batch3d.running_mean',
 'mixed_5c.branch_1.1.batch3d.running_var',
 'mixed_5c.branch_2.0.conv3d.weight',
 'mixed_5c.branch_2.0.batch3d.weight',
 'mixed_5c.branch_2.0.batch3d.bias',
 'mixed_5c.branch_2.0.batch3d.running_mean',
 'mixed_5c.branch_2.0.batch3d.running_var',
 'mixed_5c.branch_2.1.conv3d.weight',
 'mixed_5c.branch_2.1.batch3d.weight',
 'mixed_5c.branch_2.1.batch3d.bias',
 'mixed_5c.branch_2.1.batch3d.running_mean',
 'mixed_5c.branch_2.1.batch3d.running_v

In [13]:
def update_keys(pattern, new_b_dict, target_keys=None, source_state=None):
    """Copy source-checkpoint values into ``new_b_dict`` under target-style keys.

    Keys are paired purely by position: the i-th target key matching
    ``pattern`` receives the value stored under the i-th source key whose
    capitalized form matches ``pattern``. Both key collections must therefore
    list the matching entries in the same order.

    Args:
        pattern: Regular expression applied with ``re.match`` (anchored at the
            start of the key only, so it acts as a prefix test).
        new_b_dict: Dict that receives ``target_key -> source value`` entries;
            it is updated in place.
        target_keys: Iterable of key names to emit. Defaults to the
            notebook-level ``a_keys``.
        source_state: Mapping from source key names to values. Defaults to the
            notebook-level ``b``.

    Returns:
        The same ``new_b_dict`` object that was passed in.

    Raises:
        ValueError: If the number of matching target keys differs from the
            number of matching source keys. Positional pairing would otherwise
            silently drop the surplus keys. Nothing is written to
            ``new_b_dict`` in that case.
    """
    # Fall back to the notebook globals so existing two-argument calls still work.
    if target_keys is None:
        target_keys = a_keys
    if source_state is None:
        source_state = b

    a_matched = [key for key in target_keys if re.match(pattern, key)]
    # Source keys are lower-case; capitalize() upper-cases the first letter so
    # the same pattern can be used for both naming schemes.
    b_matched = [key for key in source_state if re.match(pattern, key.capitalize())]

    # zip() stops at the shorter list, so verify the counts before pairing.
    if len(a_matched) != len(b_matched):
        raise ValueError(
            'Pattern {!r} matched {} target keys but {} source keys'.format(
                pattern, len(a_matched), len(b_matched)))

    for a_key, b_key in zip(a_matched, b_matched):
        print('{:30s} -> {:30s}'.format(a_key, b_key))
        print()

        # Sanity check: the final component (e.g. weight / bias) should agree.
        if a_key.split('.')[-1] != b_key.split('.')[-1]:
            print('Not matched1')

        # Sanity check: the second-to-last components should at least start
        # with the same letter.
        if a_key.split('.')[-2][0] != b_key.split('.')[-2][0]:
            print('Not matched2')

        new_b_dict[a_key] = source_state[b_key]

    return new_b_dict

In [14]:
new_b_dict = {}
# Iterate in sorted order: the iteration order of a set of strings can change
# between interpreter runs, which would reorder both the printed log and the
# insertion order of the resulting dict.
for pattern in sorted(a_set):
    update_keys(pattern, new_b_dict)

Mixed_3b.b0.conv3d.weight      -> mixed_3b.branch_0.conv3d.weight

Mixed_3b.b0.bn.weight          -> mixed_3b.branch_0.batch3d.weight

Mixed_3b.b0.bn.bias            -> mixed_3b.branch_0.batch3d.bias

Mixed_3b.b0.bn.running_mean    -> mixed_3b.branch_0.batch3d.running_mean

Mixed_3b.b0.bn.running_var     -> mixed_3b.branch_0.batch3d.running_var

Mixed_3b.b1a.conv3d.weight     -> mixed_3b.branch_1.0.conv3d.weight

Mixed_3b.b1a.bn.weight         -> mixed_3b.branch_1.0.batch3d.weight

Mixed_3b.b1a.bn.bias           -> mixed_3b.branch_1.0.batch3d.bias

Mixed_3b.b1a.bn.running_mean   -> mixed_3b.branch_1.0.batch3d.running_mean

Mixed_3b.b1a.bn.running_var    -> mixed_3b.branch_1.0.batch3d.running_var

Mixed_3b.b1b.conv3d.weight     -> mixed_3b.branch_1.1.conv3d.weight

Mixed_3b.b1b.bn.weight         -> mixed_3b.branch_1.1.batch3d.weight

Mixed_3b.b1b.bn.bias           -> mixed_3b.branch_1.1.batch3d.bias

Mixed_3b.b1b.bn.running_mean   -> mixed_3b.branch_1.1.batch3d.running_mean

Mixed_3b.b1

In [15]:
# 'logits' was excluded from the pattern loop; copy its two entries from the
# differently named source keys explicitly.
for dst_key, src_key in (
    ('logits.conv3d.weight', 'conv3d_0c_1x1.conv3d.weight'),
    ('logits.conv3d.bias', 'conv3d_0c_1x1.conv3d.bias'),
):
    new_b_dict[dst_key] = b[src_key]

In [None]:
# Number of entries collected so far in the renamed dict.
# NOTE(review): presumably this should equal len(a_keys) -- confirm before saving.
len(new_b_dict)

In [16]:
# Write the renamed dict to disk as a new checkpoint file.
torch.save(new_b_dict, './pytorch_i3d/models/flow_kinetics.pth')