In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
from easydict import EasyDict as EDict

import re

# Matches an identifier (optionally dotted) followed by a parenthesised
# argument list, either at the start of the string or after whitespace.
# Written as a raw string so the backslashes reach the regex engine unchanged
# instead of being treated as (invalid) Python string escapes.
func_call_pat = re.compile(r"(^|\s)[a-zA-Z_][a-zA-Z0-9_\.]*\(.*\)")
def is_func_call(string):
    """Return True if `string` contains something shaped like `name(...)`.

    The check is purely textual (see `func_call_pat`); it does not verify that
    the name refers to anything callable.
    """
    return func_call_pat.search(string) is not None

# Matches any "{...}" placeholder, including the empty "{}".
# Raw string: "\{" and "\}" are not valid Python string escapes.
format_str_pat = re.compile(r"\{.*\}")
def is_format_str(string):
    """Return True if `string` contains a `{...}` placeholder.

    `parse_params` uses this to leave such strings un-evaluated.
    """
    return format_str_pat.search(string) is not None


def parse_params(value, trial = None, conf=None):
    """Recursively turn a textual configuration into Python objects.

    * list  -> each element is parsed; the result has the same list type.
    * dict  -> a new plain dict is built key by key. While it is being built
      it also serves as the lookup table for the values that follow, so a
      string value equal to an earlier key of the same dict is replaced by
      that key's parsed value.
    * str   -> looked up in `conf` if it is one of its keys; otherwise, unless
      it contains a `{...}` placeholder, it is passed to `eval`. A string that
      evaluates to an undefined name is returned unchanged.
    * anything else is returned unchanged.

    `trial` is not used directly; it is kept as a local so that evaluated
    strings may refer to it (presumably a hyper-parameter search trial -- TODO
    confirm with the callers).

    SECURITY: strings are evaluated with `eval`; only pass trusted
    configuration to this function.

    Raises:
        RuntimeError: if a string is not valid Python syntax. The original
            SyntaxError is attached as the cause.
    """
    if isinstance(value,list):
        return type(value)([parse_params(v,trial,conf) for v in value])
    if isinstance(value,dict):
        # A fresh table: references resolve only against keys of *this* dict
        # that have already been parsed.
        conf = {}
        for k,v in value.items():
            conf[k] = parse_params(v,trial,conf)
        return conf
    res = value
    try:
        if isinstance(res,str):
            if conf is not None and res in conf.keys():
                res = conf[res]
            elif not is_format_str(res):
                # eval must stay in this frame so that `trial` is visible.
                res = eval(res)
    except SyntaxError as e:
        # Imported here so the function does not depend on a later notebook
        # cell having imported `warnings` already.
        import warnings
        warnings.warn("syntax error in '%s'"%value)
        raise RuntimeError(e.msg) from e
    except NameError:
        # Not a known name: treat the string as a literal.
        res = value
    return res

def to_list(l):
    """Coerce `l` to a list.

    An object whose type is exactly `list` is returned as-is (same object),
    `None` becomes an empty list, and anything else is wrapped in a
    one-element list.
    """
    if type(l) == list:
        return l
    return [] if l is None else [l]

class LossWeight:
    """Loss coefficient that is either constant or linearly ramped over steps.

    With `final=None` the weight is always `init`. Otherwise it is `init` up to
    and including step `start`, `final` after step `end`, and linearly
    interpolated in between.
    """

    def __init__(self,init,final=None,start=0,end=0):
        self.init, self.final = init, final
        self.start, self.end = start, end

    def __call__(self,curr_step):
        """Return the weight to use at step `curr_step`."""
        is_constant = self.final is None
        if is_constant or curr_step <= self.start:
            return self.init
        if curr_step > self.end:
            return self.final
        # Here start < curr_step <= end, so the span is strictly positive.
        span = self.end - self.start
        rate = (curr_step - self.start) / span
        return self.init + rate*(self.final-self.init)

    def __str__(self):
        """Show the constant value, or the two (step:value) end points."""
        if self.final is not None:
            return "(%d:%f,%d:%f)"%(self.start,self.init,self.end,self.final)
        return str(self.init)
    
import os

In [4]:
import warnings
import yaml

class Network(nn.Module):
    """Base class for networks assembled from a configuration dictionary.

    A sub-class implements `build()`, which calls `build_model()` once per
    named sub-model. Each sub-model is a list of blocks (conv, trans_conv, fc,
    maxpool, avgpool, upsample) that are registered as attributes of this
    module and executed in order by `forward_model()`.

    Configuration keys read by this class:
        'losses'          -- required; see `set_losses`.
        'initialize'      -- optional; see `_init_weights`.
        'sn_params'       -- optional kwargs for spectral normalisation.
        'LeakyReLU_param' -- optional negative slope (default 0.2).
    """

    def __init__(self,network_conf,logger=None, gpu=0,trial=None):
        """Parse the configuration, build the model and set up the losses.

        Args:
            network_conf: mapping with the (textual) configuration.
            logger: optional logger; every use of it is guarded.
            gpu: CUDA device index, or a negative value to stay on the CPU.
            trial: forwarded to `parse_params`.
        """
        super(Network, self).__init__()

        self.logger = logger

        # The unparsed configuration is what save() writes to disk.
        self.text_conf = dict(network_conf)
        self.conf = parse_params(network_conf, trial)
        if self.logger is not None:
            for k,v in self.conf.items():
                self.logger.debug("%s: %s"%(k,v))

        self.store('sn_params',None)
        self.store('LeakyReLU_param',0.2)

        self.models = EDict()
        self.losses = EDict()
        self.build()

        self.init_weights()

        self.set_losses(self.conf['losses'])

        self.gpu = gpu
        if gpu < 0:
            warnings.warn("Run on CPU.")
            return

        for k,(_,func) in self.losses.items():
            # Only objects that offer .cuda() (e.g. nn.Module losses) are
            # moved; plain Python callables are left untouched.
            if hasattr(func,'cuda'):
                self.losses[k][1] = func.cuda(gpu)
        self.cuda(gpu)

    def save(self,name=None,epoch=None):
        """Write `<name>.yaml`, and `<name><epoch:03d>.pth` if `epoch` is given.

        When `epoch > 0` the checkpoint of the previous epoch is deleted.
        `name` defaults to `<ClassName>_model`.
        """
        if name is None:
            name = type(self).__name__ + "_model"
        with open(name+'.yaml','w') as f:
            yaml.dump(self.text_conf,f)

        if epoch is not None:
            torch.save(self.state_dict(),name+'%03d.pth'%epoch)
            if epoch>0:
                try:
                    self.remove(name,epoch-1)
                except FileNotFoundError:
                    # Nothing to clean up (e.g. checkpoints are not written
                    # every epoch); the new checkpoint is already on disk.
                    pass

    @classmethod
    def load(cls, name, epoch=None, **kwargs):
        """Re-create a network from `<name>.yaml` and load its weights.

        Args:
            name: file prefix; `None` means `<ClassName>_model`.
            epoch: if given, weights are read from `<name><epoch:03d>.pth`
                (the name used by `save`); otherwise from `<name>.pth`.
            **kwargs: forwarded to the class constructor.
        """
        if name is None:
            name = cls.__name__ + "_model"
        with open(name+'.yaml','r') as f:
            # SECURITY: yaml.Loader can construct arbitrary Python objects.
            # Only load configuration files from a trusted source.
            conf = yaml.load(f, Loader=yaml.Loader)
        net = cls(conf, **kwargs)
        weight_file = name + ('.pth' if epoch is None else '%03d.pth'%epoch)
        # Load to CPU first; load_state_dict copies onto the model's device.
        net.load_state_dict(torch.load(weight_file, map_location='cpu'))
        return net

    @classmethod
    def remove(cls, name,epoch):
        """Delete `<name>.pth`, or `<name><epoch:03d>.pth` if `epoch` is given."""
        if name is None:
            name = cls.__name__ + "_model"
        if epoch is not None:
            name += '%03d'%epoch
        os.remove(name+'.pth')

    def store(self,name,default):
        """Set attribute `name` to `self.conf[name]`, or `default` if absent."""
        val = self.conf[name] if name in self.conf.keys() else default
        setattr(self,name,val)

    def get_act(self,name):
        """Return a new activation module for the `torch.nn` class `name`.

        'LeakyReLU' uses `self.LeakyReLU_param`; every other class is created
        without arguments (ReLU, Sigmoid, Softmax, Softmax2d, Tanh, ...).

        Raises:
            NameError: if `torch.nn` has no attribute called `name`.
        """
        if name == 'LeakyReLU':
            return nn.LeakyReLU(self.LeakyReLU_param)
        if hasattr(nn,name):
            return getattr(nn,name)()
        raise NameError('Unknown activation:'+name)

    def conv_block(self,c_in, c_out, k, s, p, norm='bn', activation=None, dropout=None,transpose=False):
        """Build conv (or transposed conv) -> [dropout] -> [batch norm] -> [activation].

        Spectral normalisation is applied to the convolution when
        `self.sn_params` is not None.
        """
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        layers = [conv_cls(c_in, c_out, kernel_size=k, stride=s, padding=p)]
        if self.sn_params is not None:
            layers[-1] = torch.nn.utils.spectral_norm(layers[-1],**self.sn_params)
        if dropout:
            layers.append(nn.Dropout2d(dropout))
        if norm == 'bn':
            layers.append(nn.BatchNorm2d(c_out))
        if activation is not None:
            layers.append(self.get_act(activation))
        return nn.Sequential(*layers)

    def fc_block(self,c_in, c_out, norm='bn', activation=None, dropout=None):
        """Build linear -> [dropout] -> [batch norm] -> [activation].

        Spectral normalisation is applied to the linear layer when
        `self.sn_params` is not None.
        """
        layers = [nn.Linear(c_in,c_out)]
        if self.sn_params is not None:
            layers[-1] = torch.nn.utils.spectral_norm(layers[-1],**self.sn_params)
        if dropout:
            layers.append(nn.Dropout(dropout))
        if norm == 'bn':
            layers.append(nn.BatchNorm1d(c_out))
        if activation is not None:
            layers.append(self.get_act(activation))
        return nn.Sequential(*layers)

    def build_model(self,conf,input_dim,name):
        """Create the blocks of sub-model `name` and register them on `self`.

        Args:
            conf: sequence of block descriptions. `para[0]` is the block type:
                'conv' / 'trans_conv': (type, c_out, k, s, p[, norm, activation, dropout])
                'fc':                  (type, c_out[, norm, activation, dropout])
                'maxpool' / 'avgpool': (type, *pooling_args)
                'upsample':            (type, scale_factor[, mode, ...])
            input_dim: channel/feature count fed to the first block.
            name: key under which the model is recorded in `self.models`.

        Each block becomes the attribute `<name>_<type>_<index:02d>`. An empty
        `conf` yields a model whose output_dim equals `input_dim`.

        Raises:
            RuntimeError: if `name` was already built.
            NameError: if a block type is not recognised.
        """
        if name in self.models.keys():
            raise RuntimeError("a model '%s' is already built."%name)
        model_info = []
        _input_dim = input_dim
        output_dim = input_dim

        for i,para in enumerate(conf):
            block_type = para[0]
            block_name = "%s_%s_%02d"%(name,block_type,i)
            if block_type == 'conv':
                output_dim = para[1]
                block = self.conv_block(_input_dim,output_dim,*para[2:],transpose=False)
            elif block_type == 'trans_conv':
                output_dim = para[1]
                block = self.conv_block(_input_dim,output_dim,*para[2:],transpose=True)
            elif block_type == 'fc':
                output_dim = para[1]
                block = self.fc_block(_input_dim,output_dim,*para[2:])
            elif block_type == 'maxpool':
                # Pooling and up-sampling keep the channel count.
                output_dim = _input_dim
                block = nn.MaxPool2d(*para[1:])
            elif block_type == 'avgpool':
                output_dim = _input_dim
                block = nn.AvgPool2d(*para[1:])
            elif block_type == 'upsample':
                output_dim = _input_dim
                # ignore size option.
                block = nn.Upsample(None, *para[1:])
            else:
                if self.logger is not None:
                    self.logger.error('Unknown block type: '+block_type)
                raise NameError('Unknown block type: '+block_type)
            setattr(self,block_name,block)
            model_info.append((block_type,block_name,para[2:]))
            if self.logger is not None:
                self.logger.debug("%s: %s -> %s"%(block_name,_input_dim,output_dim))
            _input_dim = output_dim

        self.models[name] = {
            'input_dim':input_dim,
            'output_dim':output_dim,
            'blocks': model_info,
        }

    def forward_model(self,name,x):
        """Run `x` through the blocks of sub-model `name`, in order.

        Before an 'fc' block, an input with more than two dimensions is
        flattened to (batch, -1).
        """
        for ltype,block_name,_ in self.models[name]['blocks']:
            if ltype == 'fc' and x.dim() > 2:
                x = x.view(x.size(0),-1)
            x = getattr(self,block_name)(x)
        return x

    def init_weights(self):
        """Apply `_init_weights` to every sub-module (and to self)."""
        self.apply(self._init_weights)

    def _init_weights(self,m):
        """Initialise the weight and bias of module `m` from conf['initialize'].

        conf['initialize'] maps a layer type to one or two initialisers
        (weight, then bias), e.g.
            {'Conv': '{}.data.normal_(0.0, 0.1)', 'BatchNorm': [0.02, 0.01]}
        A key is matched first by type equality with `m`, then as a substring
        of the class name of `m`. A string initialiser is formatted with
        'm.weight' / 'm.bias' and evaluated; any other value is written with
        `fill_`. When the bias initialiser is missing or is 'repeat', the
        weight initialiser is reused for the bias.

        Modules without a weight, or without a matching key, keep the torch
        default initialisation.
        """
        if getattr(m,'weight',None) is None:
            return
        init_conf = self.conf['initialize'] if 'initialize' in self.conf.keys() else None
        if not init_conf:
            return

        classname = m.__class__.__name__
        layer_type = next((key for key in init_conf.keys() if type(m)==key), None)
        if layer_type is None:
            layer_type = next(
                (key for key in init_conf.keys()
                 if isinstance(key,str) and classname.find(key)>=0),
                None)
        if layer_type is None:
            if self.logger is not None:
                self.logger.info("Layer type '%s' is initialized by torch default."%classname)
            return

        inits = to_list(init_conf[layer_type])
        if not inits:
            return

        init = inits[0]
        if isinstance(init,str):
            # SECURITY: evaluates configuration text; trusted configs only.
            eval(init.format('m.weight'))
        else:
            m.weight.data.fill_(init)

        # Layers created with bias=False have `bias` set to None.
        if getattr(m,'bias',None) is None:
            return

        if len(inits)>1 and inits[1] != 'repeat':
            init = inits[1]

        if isinstance(init,str):
            eval(init.format('m.bias'))
        else:
            m.bias.data.fill_(init)

    def build(self):
        """Create the sub-models; must be overridden by every sub-class."""
        raise RuntimeError("the function 'build' is a pure-virtual function. This must be implemented in any child class.")

    def set_losses(self,conf):
        """Fill `self.losses` from `conf`.

        Each entry `conf[k]` provides 'lambda' (arguments for `LossWeight`,
        a single value or a list) and 'func' (a callable, or a string that is
        evaluated to obtain one). The stored value is `[weight, func]`.
        """
        for k,v in conf.items():
            weight = LossWeight(*to_list(v['lambda']))
            if not callable(v['func']):
                # SECURITY: evaluates configuration text; trusted configs only.
                func = eval(v['func'])
            else:
                func = v['func']
            self.losses[k] = [weight,func]

    def calc_loss(self,name,global_step,y_pred,y_true=None, writer=None):
        """Return the loss `name`, scaled by its weight at `global_step`.

        The unweighted value is logged to `writer` (if given) via
        `add_scalar` under '<ClassName>/losses/<name>' (joined with
        `os.path.join`).
        """
        weight,func = self.losses[name]
        # NOTE(review): the loss is called as func(y_true, y_pred), whereas
        # torch.nn losses take (input, target). Confirm that the configured
        # functions expect this order before changing it.
        l = func(y_true,y_pred)
        if writer is not None:
            writer.add_scalar(os.path.join(type(self).__name__,'losses',name),l,global_step)

        return weight(global_step) * l
        
        

In [None]:
class AutoEncoder(Network):
    """Encoder/decoder network built from conf['encoder'] and conf['decoder'].

    The optional conf['AE_depth'] limits both sub-models to their first
    `AE_depth` blocks; a missing or non-positive value uses every block.
    """

    def __init__(self,*args,input_dim=3,**kwargs):
        """`input_dim` is the channel/feature count of the encoder input.

        It is stored before the base constructor runs because that
        constructor calls `build()`, which reads it.
        """
        self.input_dim = input_dim
        super(AutoEncoder,self).__init__(*args,**kwargs)

    def build(self):
        """Build the 'encoder' and 'decoder' sub-models."""
        depth = -1
        if 'AE_depth' in self.conf.keys():
            depth = self.conf['AE_depth']
        # Pass the configured depth on; it was previously read but never used.
        self.build_encoder(self.conf['encoder'],depth)
        self.build_decoder(self.conf['decoder'],depth)

    def forward(self,input_x, return_enc = False):
        """Return `(z, reconstruction)`, or only `z` when `return_enc` is True."""
        z = self.forward_model('encoder',input_x)
        if return_enc:
            return z
        reconst = self.forward_model('decoder',z)
        return z, reconst

    @property
    def feature_dim(self):
        """Output dimension recorded for the 'encoder' sub-model."""
        return self.models['encoder']['output_dim']

    @property
    def output_dim(self):
        """Output dimension recorded for the 'decoder' sub-model."""
        return self.models['decoder']['output_dim']

    def build_encoder(self,conf,depth=-1):
        """Build 'encoder' from the first `depth` blocks of `conf` (all if depth <= 0)."""
        if depth<=0:
            depth = len(conf)
        self.build_model(conf[:depth],self.input_dim,'encoder')

    def build_decoder(self,conf,depth=-1):
        """Build 'decoder' from the first `depth` blocks of `conf` (all if depth <= 0).

        The decoder input size is the encoder's output size.
        """
        if depth<=0:
            depth = len(conf)
        self.build_model(conf[:depth],self.feature_dim,'decoder')


In [None]:
class VAE(AutoEncoder):
    """Variational auto-encoder on top of `AutoEncoder`.

    Sub-models: 'encoder', then 'enc_mu' and 'enc_logvar' (both fed with the
    encoder output), and 'decoder' (fed with the sampled latent code).
    conf['enc_logvar'] is optional; when absent, the 'enc_mu' configuration is
    used for a second, separately parameterised head.
    """

    def __init__(self,*args, input_dim=3, **kwargs):
        # input_dim must be forwarded: AutoEncoder.__init__ assigns
        # self.input_dim itself and would otherwise reset it to its default.
        super(VAE,self).__init__(*args, input_dim=input_dim, **kwargs)

        self.set_vae_loss()

    def build(self):
        """Build the encoder, the two latent heads and the decoder."""
        self.build_encoder(self.conf['encoder'])

        mu_conf = self.conf['enc_mu']
        logvar_conf = mu_conf
        if 'enc_logvar' in self.conf.keys():
            logvar_conf = self.conf['enc_logvar']
        self.build_model(mu_conf,self.feature_dim,'enc_mu')
        self.build_model(logvar_conf,self.feature_dim,'enc_logvar')

        # The decoder consumes the latent code, whose size is the output size
        # of the 'enc_mu' head (not of the encoder).
        latent_dim = self.models['enc_mu']['output_dim']
        self.build_model(self.conf['decoder'],latent_dim,'decoder')

    def forward(self,input_x, return_enc = False):
        """Return `(z, reconstruction)`, or only `z` when `return_enc` is True.

        `z` is the output of `reparameterize(mu, logvar)`.
        """
        x = self.forward_model('encoder',input_x)
        mu = self.forward_model('enc_mu',x)
        logvar = self.forward_model('enc_logvar',x)
        z = self.reparameterize(mu, logvar)
        if return_enc:
            return z
        reconst = self.forward_model('decoder',z)
        return z, reconst

    def reparameterize(self, mu, logvar):
        """Sample `mu + eps * exp(0.5 * logvar)` with standard-normal `eps`.

        Outside training mode `mu` is returned unchanged.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def set_vae_loss(self):
        """Hook called at the end of `__init__`; currently does nothing."""
        pass