In [None]:
import sys
# Add the local './assert' directory to the module search path so that later
# imports can resolve from it (presumably this is where `model` lives — TODO confirm).
# NOTE(review): this path is spelled './assert' while the checkpoints below are
# read from './ASSERT/pretrained/...'; on a case-sensitive filesystem these are
# different directories — confirm which spelling is correct.
sys.path.append('./assert')

In [None]:
import torch
from model import E2E

In [None]:
import itertools
# Protocol names; each one is used as a sub-directory of the pretrained
# checkpoint folder when a checkpoint path is built.
protocols = ['pa', 'la']

# Network names; each one doubles as the checkpoint file name inside a
# protocol's sub-directory.
networks = [
    'attentive_filtering_network',
    'dilated_resnet',
    'senet34',
    'senet50',
]

# Every (protocol, network) pair, protocol-major: all 'pa' entries first,
# then all 'la' entries, with networks in the order listed above.
all_networks = [*itertools.product(protocols, networks)]

In [None]:

def _decode_bytes_keys(mapping):
    """Return a plain ``dict`` copy of ``mapping`` with ``bytes`` keys decoded to ``str``.

    Only the keys of this one level are touched; values are carried over as-is.
    Non-bytes keys are left unchanged.
    """
    decoded = dict(mapping)
    # Iterate over a snapshot of the keys because the dict is modified in the loop.
    for key in list(decoded):
        if isinstance(key, bytes):
            # Re-insert the value under the text key and drop the bytes key.
            decoded[key.decode()] = decoded.pop(key)
    return decoded


def port_weights(protocol, network, pretrained_dir='./ASSERT/pretrained'):
    """Re-save one pretrained checkpoint with ``str`` keys instead of ``bytes`` keys.

    The checkpoint at ``{pretrained_dir}/{protocol}/{network}`` is loaded on the
    CPU with ``encoding='bytes'``, its top-level keys and the keys of its
    ``'state_dict'`` entry are decoded from ``bytes`` to ``str``, the decoded
    state dict is loaded into a freshly built ``E2E`` model, and the converted
    checkpoint is written next to the original with a ``.py3.ckpt`` suffix.

    Because ``load_state_dict`` runs before ``torch.save``, an error while
    loading the weights into the model prevents the output file from being
    written.

    Args:
        protocol: Name of the sub-directory of ``pretrained_dir`` holding the
            checkpoint (the notebook uses ``'pa'`` and ``'la'``).
        network: Checkpoint file name; must be one of
            ``'attentive_filtering_network'``, ``'dilated_resnet'``,
            ``'senet34'`` or ``'senet50'``.
        pretrained_dir: Root directory of the pretrained checkpoints. The
            default reproduces the previously hard-coded location.

    Returns:
        None. The result is the ``.py3.ckpt`` file written to disk and a
        confirmation line printed to stdout.

    Raises:
        KeyError: If ``network`` is not a known network name, or if the
            loaded checkpoint has no ``'state_dict'`` entry.
    """
    # Maps the checkpoint/network name to the integer passed as MODEL_SELECT.
    models_dict = {'attentive_filtering_network': 5, 'dilated_resnet': 1, 'senet34': 7,
                   'senet50': 6}
    # The same hyper-parameters are used for every network; only MODEL_SELECT varies.
    model_params = {
        'MODEL_SELECT' : models_dict[network], # which model
        'NUM_SPOOF_CLASS' : 2, # x-class classification
        'FOCAL_GAMMA' : None, # gamma parameter for focal loss; if obj is not focal loss, set this to None
        'NUM_RESNET_BLOCK' : 5, # number of resnet blocks in ResNet
        'AFN_UPSAMPLE' : 'Bilinear', # upsampling method in AFNet: Conv or Bilinear
        'AFN_ACTIVATION' : 'sigmoid', # activation function in AFNet: sigmoid, softmaxF, softmaxT
        'NUM_HEADS' : 3, # number of heads for multi-head att in SAFNet
        'SAFN_HIDDEN' : 10, # hidden dim for SAFNet
        'SAFN_DIM' : 'T', # SAFNet attention dim: T or F
        'RNN_HIDDEN' : 128, # hidden dim for RNN
        'RNN_LAYERS' : 4, # number of hidden layers for RNN
        'RNN_BI': True, # bidirectional/unidirectional for RNN
        'DROPOUT_R' : 0.0, # dropout rate
    }
    model = E2E(**model_params)

    checkpoint_path = f'{pretrained_dir}/{protocol}/{network}'
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code;
    # only port checkpoints obtained from a trusted source.
    # encoding='bytes' makes the unpickler yield bytes for legacy string objects
    # (presumably the checkpoints were written under Python 2 — TODO confirm),
    # which is why the keys are decoded below.
    raw_checkpoint = torch.load(checkpoint_path, map_location='cpu', encoding='bytes')

    # Decode the first-level keys, then the keys of the nested state dict.
    data_dict = _decode_bytes_keys(raw_checkpoint)
    data_dict['state_dict'] = _decode_bytes_keys(data_dict['state_dict'])

    # Load into the model before saving so that incompatible weights are
    # reported here rather than silently written to the converted file.
    model.load_state_dict(data_dict['state_dict'])
    torch.save(data_dict, f'{checkpoint_path}.py3.ckpt')
    print(f"Ported {network} - {protocol}")

In [None]:
# Convert every (protocol, network) checkpoint listed in `all_networks`, in order.
for protocol_name, network_name in all_networks:
    port_weights(protocol_name, network_name)

attentive filtering network
Ported attentive_filtering_network - pa
resnet
Ported dilated_resnet - pa
squeeze-and-excitation network
Ported senet34 - pa
squeeze-and-excitation network
Ported senet50 - pa
attentive filtering network
Ported attentive_filtering_network - la
resnet
Ported dilated_resnet - la
squeeze-and-excitation network
Ported senet34 - la
squeeze-and-excitation network
Ported senet50 - la
