# Basic Model Manipulations

## Imports

In [None]:
# standard-library and PyTorch imports
import os, sys
import time
import torch as t
import math as m

## Basic methods

In [None]:
class BasicModule(t.nn.Module):
    """``nn.Module`` base class with helpers to save/load timestamped weights.

    Weight files are written as ``<model_name>_<MMDD-HHMMSS>.pth``.
    Subclasses are expected to set ``self.model_name``. If they do not, the
    subclass's class name is used as the file-name prefix instead.
    """

    def __init__(self):
        super().__init__()

    def _resolve_model_name(self, model_name=None):
        """Return the file-name prefix for this model.

        Precedence: the explicit *model_name* argument, then
        ``self.model_name``, then the class name.
        """
        if model_name is not None:
            return model_name
        return getattr(self, 'model_name', type(self).__name__)

    def _weights_device(self):
        """Return the device of the first parameter, or CPU if there is none.

        Used as ``map_location`` so tensors saved on another device (e.g. a
        GPU that is unavailable here) are loaded where this model lives.
        """
        first_param = next(self.parameters(), None)
        if first_param is None:
            return t.device('cpu')
        return first_param.device

    def load(self, save_root=None, model_name=None):
        """Load the newest weight file for this model from *save_root*.

        Args:
            save_root: directory to search; defaults to ``'./saves'``.
            model_name: file-name prefix to match; defaults to
                ``self.model_name`` (or the class name if that is unset).

        Exits via ``sys.exit`` (``SystemExit``) if no file in *save_root*
        starts with the prefix.
        """
        if save_root is None:
            save_root = './saves'
        model_name = self._resolve_model_name(model_name)

        # NOTE: plain prefix match, so a model named "net" also matches files
        # saved by a model named e.g. "net2" -- choose distinct names.
        candidates = [
            name for name in os.listdir(save_root)
            if name.startswith(model_name)
        ]
        if not candidates:
            sys.exit('Weight file(s) not found!')

        # Timestamps are zero-padded MMDD-HHMMSS, so the lexicographically
        # largest name is the newest save. The year is not part of the name,
        # so this ordering only holds within a single calendar year.
        file_path = os.path.join(save_root, max(candidates))
        state_dict = t.load(file_path, map_location=self._weights_device())
        self.load_state_dict(state_dict)
        print('Weight file loaded: %s' % file_path)

    def loadPartialDict(self, file_path):
        """Load only the entries of *file_path* whose keys exist in this model.

        Keys present in the file but not in this model are ignored. Model
        keys absent from the file keep their current values. Keys whose
        tensor shapes differ are NOT filtered out, so ``load_state_dict``
        will still reject them.
        """
        pretrained_dict = t.load(file_path,
                                 map_location=self._weights_device())
        model_dict = self.state_dict()
        model_dict.update({
            key: value
            for key, value in pretrained_dict.items() if key in model_dict
        })
        self.load_state_dict(model_dict)
        print('Partial weights loaded: %s' % file_path)

    def save(self, save_root=None):
        """Save this model's ``state_dict`` into *save_root*.

        Args:
            save_root: target directory (created if missing); defaults to
                ``'./saves'``.

        Two saves within the same second overwrite each other, because the
        timestamp has one-second resolution.
        """
        if save_root is None:
            save_root = './saves'
        os.makedirs(save_root, exist_ok=True)

        # Only the timestamp goes through strftime; formatting the whole path
        # would misinterpret any '%' in save_root or the model name.
        file_name = (self._resolve_model_name() + '_'
                     + time.strftime('%m%d-%H%M%S') + '.pth')
        file_path = os.path.join(save_root, file_name)
        t.save(self.state_dict(), file_path)
        print('Weight file saved: %s' % file_path)