### 경량화

In [1]:
# torch
import torch

# built-in library
import os
import copy
import argparse

# custom modules
from exp_utils import get_udevice, load_cfg
from exp_manager import Manager

In [2]:
# ----- command-line options -----
# NOTE: when switching to a different model, change the default cfg path as well.
parser = argparse.ArgumentParser(description='AI Fashion Coordinator.')
parser.add_argument(
    '--cfg_path',
    default='./cfgs/08_cg_architecture_with_05.yaml',
    type=str,
    help="실험에 필요한 값들을 설정해둔 yaml 파일의 경로를 입력합니다.",
)

# parse_known_args (rather than parse_args) ignores unrecognised entries in
# sys.argv instead of raising, e.g. arguments passed to a notebook kernel.
args = parser.parse_known_args()[0]
# ----- end of command-line options -----

In [3]:
# Load the experiment configuration (a nested dict: section -> {key: value}).
cfg = load_cfg(args.cfg_path)

# Force prediction mode for this notebook.
cfg['global']['mode'] = 'pred'

# Override these two data paths with values that replace whatever the yaml specifies.
# NOTE(review): machine-specific relative paths — adjust for your environment.
cfg['path']['in_file_fashion'] = '../aif/data/mdata.wst.txt.2023.08.23'
cfg['path']['subWordEmb_path'] = '../aif/sstm_v0p5_deploy/sstm_v4p49_np_n36134_d128.dat'

# Print every configuration value, grouped by section.
print('<Parsed arguments>')
for category, section in cfg.items():
    print(f"##### {category} #####")
    for name, value in section.items():
        print(f"{name}: {value}")

    print('-' * 20)

<Parsed arguments>
##### global #####
seed: 2024
mode: pred
num_tasks: 6
use_multimodal: False
--------------------
##### path #####
in_file_trn_dialog: /aif/data/task1.ddata.wst.txt
in_file_tst_dialog: /aif/data/cl_eval_task1.wst.dev
in_file_fashion: ../aif/data/mdata.wst.txt.2023.08.23
in_file_img_feats: /aif/data/extracted_feat.json
model_path: ./model
model_file: 08_cg_architecture_with_05.pt
subWordEmb_path: ../aif/sstm_v0p5_deploy/sstm_v4p49_np_n36134_d128.dat
--------------------
##### data #####
permutation_iteration: 6
num_augmentation: 5
corr_thres: 0.9
mem_size: 16
--------------------
##### model #####
etc: {'use_batch_norm': False, 'use_dropout': False, 'zero_prob': 0.0}
ReqMLP: {'out_size': 300, 'req_node': '[3000,2000,1000,500]'}
PolicyNet: {'eval_node': '[6000,3000,1000,500,200][2000]'}
--------------------
##### exp #####
learning_rate: 0.0001
max_grad_norm: 40.0
batch_size: 100
epochs: 10
evaluation_iteration: 10
--------------------


In [11]:
# Build the experiment manager for the configured model. Judging by the cell
# output, construction loads the subword embedding and the fashion metadata.
device = get_udevice()
manager = Manager(cfg, device)

Using device: cpu

<Initialize subword embedding>
loading= ../aif/sstm_v0p5_deploy/sstm_v4p49_np_n36134_d128.dat

<Make metadata>
loading fashion item metadata
vectorizing data
_requirement.model.0.scores.0
_requirement.model.0.scores.1
_requirement.model.0.scores.2
_requirement.model.0.scores.3
_requirement.model.0.scores.4
_requirement.model.0.scores.5
_requirement.model.2.scores.0
_requirement.model.2.scores.1
_requirement.model.2.scores.2
_requirement.model.2.scores.3
_requirement.model.2.scores.4
_requirement.model.2.scores.5
_requirement.model.4.scores.0
_requirement.model.4.scores.1
_requirement.model.4.scores.2
_requirement.model.4.scores.3
_requirement.model.4.scores.4
_requirement.model.4.scores.5
_requirement.model.6.scores.0
_requirement.model.6.scores.1
_requirement.model.6.scores.2
_requirement.model.6.scores.3
_requirement.model.6.scores.4
_requirement.model.6.scores.5
_requirement.model.8.scores.0
_requirement.model.8.scores.1
_requirement.model.8.scores.2
_requirement.

### 모델 복사

In [13]:
# Work on an independent deep copy so manager._model stays full precision
# and can serve as the size/accuracy reference.
reference_model = manager._model
model = copy.deepcopy(reference_model)

### 경량화 이전 모델 크기 확인

In [14]:
# Memory footprint of the original model: bytes held by all parameters plus
# all registered buffers, reported in MiB.
param_size = sum(p.nelement() * p.element_size() for p in manager._model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in manager._model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 1396.400MB


### 경량화 진행

In [23]:
def _scores_to_half(module):
    """Cast, in place, every parameter of *module* whose name contains 'score' to float16.

    The owning submodule of each parameter is resolved from its dotted name,
    so this is not limited to the `_requirement.model`, `_policy._mlp_eval`
    and `_policy._mlp_rnk` layouts. A matching parameter elsewhere is converted
    too, instead of being silently skipped.

    Returns:
        List of the converted parameter names.
    """
    # Snapshot the names first: re-registering parameters while
    # named_parameters() is still being iterated would mutate the module mid-walk.
    target_names = [name for name, _ in module.named_parameters() if 'score' in name]
    for name in target_names:
        owner_path, _, leaf = name.rpartition('.')
        owner = module.get_submodule(owner_path) if owner_path else module
        old_param = getattr(owner, leaf)
        half = old_param.detach().clone().to(torch.float16)
        # Register a real Parameter so it remains in parameters()/state_dict(),
        # preserving the original requires_grad flag.
        setattr(owner, leaf, torch.nn.Parameter(half, requires_grad=old_param.requires_grad))
    return target_names


converted_score_names = _scores_to_half(model)
        

### 경량화 이후 모델 크기 확인

In [24]:
# Memory footprint of the converted copy (parameters + registered buffers), in MiB,
# for comparison with the size measured before conversion.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 797.943MB


### 경량화를 적용한 모델 저장

In [27]:
# Save the converted weights as "<stem>_lightweight.pt" next to the original model.
# os.path.splitext strips only the final extension, so a stem that itself
# contains dots (e.g. "model.v2.pt") is preserved; split(".")[0] would cut it
# at the first dot.
model_stem = os.path.splitext(cfg['path']['model_file'])[0]
model_name_final = model_stem + "_lightweight.pt"
file_name_final = os.path.join(cfg['path']['model_path'], model_name_final)
# Stored under the 'model' key.
# NOTE(review): presumably the checkpoint layout Manager expects; confirm.
torch.save({'model': model.state_dict()}, file_name_final)

### 경량화를 적용한 모델이 정상적으로 로드되는지 테스트

In [4]:
# Independent copy of the configuration so that pointing it at the
# lightweight checkpoint does not modify `cfg`; a fresh memo table is passed explicitly.
cfg_cp = copy.deepcopy(cfg, memo={})

In [5]:
# Point the copied config at "<stem>_lightweight.pt". os.path.splitext strips
# only the final extension, so a stem containing dots is preserved
# (split(".")[0] would cut it at the first dot).
cfg_cp['path']['model_file'] = os.path.splitext(cfg['path']['model_file'])[0] + "_lightweight.pt"

In [6]:
# Build a second manager from the copied config to check that the
# lightweight checkpoint loads.
lightweight_device = get_udevice()
manager_lightweight = Manager(cfg_cp, lightweight_device)

Using device: cpu

<Initialize subword embedding>
loading= ../aif/sstm_v0p5_deploy/sstm_v4p49_np_n36134_d128.dat

<Make metadata>
loading fashion item metadata
vectorizing data


TypeError: 'str' object cannot be interpreted as an integer

In [43]:
# Show the dtype of every 'score' parameter in the reloaded model.
# NOTE(review): the recorded output shows float32. Presumably Manager builds a
# float32 model and loading the float16 state dict copies (upcasts) the values
# into it; confirm against Manager's loading code.
score_params = [
    (param_name, tensor)
    for param_name, tensor in manager_lightweight._model.named_parameters()
    if 'score' in param_name
]
for param_name, tensor in score_params:
    print(param_name, tensor.dtype)

_requirement.model.0.scores.0 torch.float32
_requirement.model.0.scores.1 torch.float32
_requirement.model.0.scores.2 torch.float32
_requirement.model.0.scores.3 torch.float32
_requirement.model.0.scores.4 torch.float32
_requirement.model.0.scores.5 torch.float32
_requirement.model.2.scores.0 torch.float32
_requirement.model.2.scores.1 torch.float32
_requirement.model.2.scores.2 torch.float32
_requirement.model.2.scores.3 torch.float32
_requirement.model.2.scores.4 torch.float32
_requirement.model.2.scores.5 torch.float32
_requirement.model.4.scores.0 torch.float32
_requirement.model.4.scores.1 torch.float32
_requirement.model.4.scores.2 torch.float32
_requirement.model.4.scores.3 torch.float32
_requirement.model.4.scores.4 torch.float32
_requirement.model.4.scores.5 torch.float32
_requirement.model.6.scores.0 torch.float32
_requirement.model.6.scores.1 torch.float32
_requirement.model.6.scores.2 torch.float32
_requirement.model.6.scores.3 torch.float32
_requirement.model.6.scores.4 to