# Basic Model Manipulations

## Imports

In [None]:
# mass includes
import os
import time
import torch as t
import math as m

## Basic methods

In [None]:
class BasicModule(t.nn.Module):
    """Base ``nn.Module`` with checkpoint save/load and weight-init helpers.

    Checkpoint files are named ``<model_name>_<%m%d-%H%M%S>.pth``.
    ``load`` picks the lexicographically last file with this model's prefix.
    """

    def __init__(self):
        super().__init__()
        # File-name prefix for checkpoints, e.g. "<class '__main__.Net'>".
        # NOTE(review): this repr contains '<', '>' and quotes, which are not
        # valid in Windows file names. It is kept unchanged so that existing
        # checkpoints still match the prefix test in ``load``.
        self.model_name = str(type(self))

    def load(self, root, device=None):
        """Load the most recent checkpoint for this model from ``root``.

        Args:
            root: directory to search for files whose names start with
                ``model_name``.
            device: forwarded to ``torch.load`` as ``map_location``.

        Returns:
            The number of matching checkpoint files found in ``root``.

        Raises:
            IndexError: if ``root`` holds no checkpoint for this model.
        """
        save_list = sorted(
            name for name in os.listdir(root)
            if name.startswith(self.model_name)
        )
        if not save_list:
            raise IndexError(
                'No saved weights for %s in %s' % (self.model_name, root))
        # ``save`` stamps names with %m%d-%H%M%S and no year. Lexicographic
        # order is therefore chronological only within one calendar year.
        file_path = os.path.join(root, save_list[-1])
        # Read the file once and reuse the result (it was loaded twice before).
        state_dict = t.load(file_path, map_location=device)
        self.load_state_dict(state_dict)
        print('Weights loaded: %s' % file_path)

        return len(save_list)

    def loadPartialDict(self, file_path, device=None):
        """Load only the checkpoint entries whose keys exist in this model.

        Checkpoint keys this model lacks are dropped. Parameters missing from
        the checkpoint keep their current values. A shared key whose tensor
        shape differs will still make ``load_state_dict`` raise.

        Args:
            file_path: path of a saved state dict.
            device: forwarded to ``torch.load`` as ``map_location``.
        """
        pretrained_dict = t.load(file_path, map_location=device)
        model_dict = self.state_dict()
        model_dict.update({
            key: value
            for key, value in pretrained_dict.items() if key in model_dict
        })
        self.load_state_dict(model_dict)
        print('Partial weights loaded: %s' % file_path)

    def save(self, root='./saves/'):
        """Save this model's state dict under a timestamped name in ``root``.

        ``root`` is created if it does not exist. Two saves within the same
        second produce the same name, and the later save overwrites the first.

        Args:
            root: output directory (defaults to ``./saves/``).

        Returns:
            The path of the written file.
        """
        os.makedirs(root, exist_ok=True)
        # Format only the timestamp, so a '%' in model_name is not read as a
        # strftime directive.
        stamp = time.strftime('%m%d-%H%M%S')
        file_name = os.path.join(root, '%s_%s.pth' % (self.model_name, stamp))
        t.save(self.state_dict(), file_name)
        print('Weights saved: %s' % file_name)
        return file_name

    def initLayers(self):
        """Re-initialise every Conv2d/Linear weight with Xavier-normal values.

        Biases and all other layer types are left untouched.
        """
        for module in self.modules():
            if isinstance(module, (t.nn.Conv2d, t.nn.Linear)):
                t.nn.init.xavier_normal_(module.weight)