forked from prophesier/diff-svc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
simplify.py
28 lines (21 loc) · 824 Bytes
/
simplify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from argparse import ArgumentParser
import torch
def simplify_pth(pth_name, project_name):
    """Write a stripped-down copy of a training checkpoint.

    Loads ``./checkpoints/{project_name}/{pth_name}`` and saves a new file
    ``./clean_{pth_name}`` in the current working directory. Only the
    ``epoch`` and ``state_dict`` entries are carried over; the
    ``global_step``, ``checkpoint_callback_best``, ``optimizer_states`` and
    ``lr_schedulers`` keys are kept but set to ``None``.

    Args:
        pth_name: File name of the checkpoint inside the project directory.
        project_name: Name of the sub-directory under ``./checkpoints``.

    Raises:
        KeyError: If the loaded checkpoint is a dict without an ``epoch``
            or ``state_dict`` entry.
    """
    model_path = f'./checkpoints/{project_name}'
    # Load onto the CPU so that a checkpoint saved on a GPU can still be
    # simplified on a machine without that device; the resulting file is
    # then device-independent as well.
    checkpoint_dict = torch.load(f'{model_path}/{pth_name}',
                                 map_location='cpu')
    simplified = {
        'epoch': checkpoint_dict['epoch'],
        'state_dict': checkpoint_dict['state_dict'],
        # Training-only bookkeeping is blanked out; the keys are kept so
        # the file still has the layout of a full checkpoint.
        'global_step': None,
        'checkpoint_callback_best': None,
        'optimizer_states': None,
        'lr_schedulers': None,
    }
    torch.save(simplified, f'./clean_{pth_name}')
def main():
    """Parse the command line and simplify the selected checkpoint.

    Expects ``--proj`` (project directory name under ``./checkpoints``) and
    ``--steps`` (the step count embedded in the checkpoint file name). The
    checkpoint file name is built as ``model_ckpt_steps_{steps}.ckpt`` and
    passed to ``simplify_pth`` together with the project name.
    """
    parser = ArgumentParser()
    # Both options are required: without them the path would silently be
    # built from the string "None" and fail later with a confusing
    # missing-file error instead of a usage message.
    parser.add_argument('--proj', type=str, required=True,
                        help='project directory name under ./checkpoints')
    parser.add_argument('--steps', type=str, required=True,
                        help='step count in the checkpoint file name')
    args = parser.parse_args()
    model_name = f"model_ckpt_steps_{args.steps}.ckpt"
    simplify_pth(model_name, args.proj)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()