In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [None]:
from easydict import EasyDict as EDict

import re
import numpy as np
# Raw string: "\s", "\." and "\(" are regex escapes, not Python string escapes.
func_call_pat = re.compile(r"(^|\s)[a-zA-Z_][a-zA-Z0-9_\.]*\(.*\)")


def is_func_call(string):
    """Return True if ``string`` contains text shaped like a call expression.

    A match is an identifier (dots allowed) at the start of the string or
    after whitespace, immediately followed by ``(...)``.
    """
    return func_call_pat.search(string) is not None

# Raw string so the brace escapes reach the regex engine untouched.
format_str_pat = re.compile(r"\{.*\}")


def is_format_str(string):
    """Return True if ``string`` contains a ``{...}`` placeholder on one line."""
    return format_str_pat.search(string) is not None

# Raw string: "\d" is a regex class, not a Python string escape.
format_model_path = re.compile(r"(.+)_E(\d+)$")


def parse_model_path(string):
    """Split a checkpoint name of the form ``<base>_E<digits>``.

    Returns:
        ``(base, digits)`` where ``digits`` is the epoch suffix as a string,
        or ``(string, None)`` when ``string`` has no such suffix.
    """
    m = format_model_path.match(string)
    if m is None:
        return string, None
    return m.group(1), m.group(2)

def parse_params(value, trial = None, conf=None,external_params=None):
    """Recursively resolve a configuration value.

    * list: every element is resolved and the result is rebuilt with
      ``type(value)``.
    * dict: a new plain dict is returned. Each value is resolved with the
      partially built result passed as ``conf``, so a string equal to a key
      that appears earlier in the same dict is replaced by that key's
      resolved value. The ``conf`` passed in by the caller is not used for
      dict values.
    * str: an empty string is kept; a string equal to a key of ``conf`` or
      of ``external_params`` (checked in that order) is replaced by the
      mapped value; otherwise, unless it contains a ``{...}`` placeholder,
      it is evaluated with ``eval``. If evaluation raises ``SyntaxError`` or
      ``NameError`` the original string is returned.
    * anything else is returned unchanged.

    ``trial`` is not referenced by this function itself, but it is a local
    name at the point of the ``eval`` call and therefore visible to the
    evaluated expression, as are ``conf``, ``external_params``, ``value``
    and ``res``.

    SECURITY: strings are passed to ``eval``. Only use this on configuration
    that comes from a trusted source.
    """
    if isinstance(value,list):
        return type(value)([parse_params(v,trial,conf,external_params) for v in value])
    if isinstance(value,dict):
        # Rebinding ``conf`` here discards the caller's mapping: lookups
        # inside this dict only see keys already resolved in this dict.
        conf = {}
        for k,v in value.items():
            conf[k] = parse_params(v,trial,conf,external_params)
        return conf
    res = value
    #print(type(res),res,is_func_call(res),is_format_str(res))
    try:
        if isinstance(res,str):
            if len(res)==0:
                pass
            elif conf is not None and res in conf.keys():
                res = conf[res]
            elif external_params is not None and res in external_params.keys():
                # The value taken from external_params is used as-is (it is
                # not resolved again here).
                res = external_params[res]
            elif not is_format_str(res):
                # NOTE(review): eval on configuration text -- trusted input only.
                res = eval(res)             
    except SyntaxError as e:
        # Not a Python expression: keep the literal string.
        res = value
        #warnings.warn("syntax error in '%s'"%value)
        #raise RuntimeError(e.msg)
    except NameError:
        # Unknown identifier (e.g. a plain word such as a block type): keep
        # the literal string. Other exceptions raised by eval propagate.
        res = value
    return res

def to_list(l):
    """Normalise ``l`` to a list.

    ``None`` becomes ``[]``, an exact ``list`` instance is returned as the
    same object, and any other value is wrapped in a one-element list.
    """
    if l is None:
        return []
    # Exact type check on purpose: tuples and list subclasses get wrapped.
    return l if type(l) is list else [l]

def calc_acc(pred,gt):
    """Return the fraction of samples whose arg-max prediction matches ``gt``.

    Args:
        pred: tensor of class scores; the class axis is the last one.
        gt: either a tensor with the same number of dimensions and the same
            last-axis size as ``pred`` (treated as one-hot / per-class
            scores, reduced with arg-max over the last axis), or a tensor of
            class indices holding one entry per prediction.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if there are no samples.
    """
    pred_np = pred.detach().cpu().numpy()
    gt_np = gt.detach().cpu().numpy()

    pred_labels = np.argmax(pred_np, axis=-1)
    # Also compare the rank, so that N index labels are not mistaken for a
    # one-hot row when the batch size happens to equal the class count.
    if gt_np.ndim == pred_np.ndim and gt_np.shape[-1] == pred_np.shape[-1]:
        # The arg-max must run per sample; without axis=-1 it would return a
        # single index into the flattened array.
        gt_labels = np.argmax(gt_np, axis=-1)
    else:
        gt_labels = gt_np.reshape(pred_labels.shape)

    matches = pred_labels == gt_labels
    return float(np.count_nonzero(matches)) / matches.size
def calc_acc_binary(pred,gt):
    """Element-wise binary accuracy with a 0.5 decision threshold.

    Both tensors are thresholded at 0.5 and compared; the agreement is
    averaged over axes 0 and 1.
    """
    pred_host = pred.cpu() if pred.is_cuda else pred
    gt_host = gt.cpu() if gt.is_cuda else gt
    pred_positive = pred_host.detach().numpy() >= 0.5
    gt_positive = gt_host.numpy() >= 0.5
    return np.mean(pred_positive == gt_positive, axis=(0,1))

class LossWeight:
    """Loss coefficient that is constant or linearly ramped over steps.

    With ``final`` left as ``None`` the weight is always ``init``. Otherwise
    it equals ``init`` up to and including step ``start``, ``final`` after
    step ``end``, and is linearly interpolated in between.
    """

    def __init__(self,init,final=None,start=0,end=0):
        self.init, self.final = init, final
        self.start, self.end = start, end

    def __call__(self,curr_step):
        """Return the weight to use at ``curr_step``."""
        is_constant = self.final is None
        if is_constant or curr_step <= self.start:
            return self.init
        if curr_step > self.end:
            return self.final
        # start < curr_step <= end here, so the span is non-zero.
        progress = (curr_step - self.start) / (self.end - self.start)
        return self.init + progress*(self.final-self.init)

    def __str__(self):
        if self.final is not None:
            return "(%d:%f,%d:%f)"%(self.start,self.init,self.end,self.final)
        return str(self.init)
    
import os

In [None]:
import warnings
import yaml

from copy import deepcopy

class Network(nn.Module):
    def __init__(self,network_conf,input_dim,logger=None, gpu=0,trial=None):
        """Parse the configuration, build the sub-models and move to ``gpu``.

        Args:
            network_conf: configuration mapping. A shallow copy is kept in
                ``text_conf`` and the result of ``parse_params`` in ``conf``.
            input_dim: stored as ``self.input_dim`` for ``build``.
            logger: optional logger; when given, the parsed configuration is
                written at debug level.
            gpu: CUDA device index. A negative value keeps everything on the
                CPU and emits a warning.
            trial: forwarded to ``parse_params``.
        """
        super().__init__()

        self.input_dim = input_dim
        self.logger = logger

        # text_conf is what save() dumps; conf is the resolved version.
        self.text_conf = dict(network_conf)
        self.conf = parse_params(network_conf, trial, None, network_conf)
        if self.logger:
            for key, val in self.conf.items():
                self.logger.debug("%s: %s"%(key,val))

        # Attributes that the block builders read while build() runs.
        self.store('sn_params',None)
        self.store('LeakyReLU_param',0.2)

        self.models = EDict()
        self.losses = EDict()
        self.build()
        self.init_weights()

        if 'losses' in self.conf:
            self.set_losses(self.conf['losses'])
        elif self.logger:
            self.logger.debug('No loss functions are set to the network.')

        self.gpu = gpu
        if gpu >= 0:
            self.cuda(gpu)
            # Loss objects that expose cuda() are moved along with the model.
            for key, (_, func) in self.losses.items():
                if hasattr(func,'cuda'):
                    self.losses[key][1] = func.cuda(gpu)
        else:
            warnings.warn("Run on CPU.")
    
    def _save_best_model(self,name,epoch,score):
        self.best_score = score
        torch.save(self.state_dict(),name+'_best.pth')
        if epoch is not None:
            with open(name+'_best.txt','w') as f:
                f.write("%0.10f at epoch %d\n"%(score,epoch))
    
    def save(self,name=None,epoch=None,score=None,keep_prev_epoch=False):
        """Write the configuration and the weights to disk.

        Files written (``<name>`` defaults to ``<ClassName>_model``):

        * ``<name>.yaml`` -- the unparsed configuration (``text_conf``).
        * ``<name>_best.pth`` (and ``_best.txt``) -- when ``score`` is given
          and is higher than every score seen so far.
        * ``<name>_E<epoch:06d>.pth`` when ``epoch`` is given, otherwise
          ``<name>.pth``. These are the names ``load`` and ``remove`` expect.

        Args:
            name: path prefix for all files.
            epoch: epoch number used in the checkpoint name.
            score: validation score; higher is treated as better.
            keep_prev_epoch: when False, the checkpoint of ``epoch - 1`` is
                deleted after the new one has been written.
        """
        if name is None:
            name = type(self).__name__ + "_model"
        with open(name+'.yaml','w') as f:
            yaml.dump(self.text_conf,f)

        if score is not None:
            if not hasattr(self,'best_score') or score>self.best_score:
                self._save_best_model(name,epoch,score)

        if epoch is None:
            # Use the '.pth' suffix so that load(name) can find the file.
            torch.save(self.state_dict(),name+'.pth')
            return

        torch.save(self.state_dict(),name+'_E%06d.pth'%epoch)
        if epoch>0 and not keep_prev_epoch:
            try:
                self.remove(name,epoch-1)
            except FileNotFoundError:
                # Nothing to clean up, e.g. when not every epoch is saved.
                pass
        
    @classmethod
    def load(cls, name, *args, strict=True,**kwargs):
        """Rebuild a network from ``<base>.yaml`` and load ``<name>.pth``.

        Args:
            name: checkpoint name without the ``.pth`` suffix. An optional
                ``_E<digits>`` epoch suffix is stripped to locate the YAML
                file. ``None`` selects ``<ClassName>_model``.
            *args, **kwargs: forwarded to the class constructor after the
                loaded configuration.
            strict: forwarded to ``load_state_dict``.

        Returns:
            The constructed network with the stored weights loaded.
        """
        if name is None:
            name = cls.__name__ + "_model"

        name_chomp, _epoch = parse_model_path(name)

        with open(name_chomp+'.yaml','r') as f:
            # PyYAML >= 6 requires an explicit Loader. FullLoader is what the
            # 5.x default resolved to; fall back to Loader on older releases.
            # NOTE(review): only load configuration files you trust.
            conf = yaml.load(f, Loader=getattr(yaml,'FullLoader',yaml.Loader))
        net = cls(conf, *args, **kwargs)

        # Deserialise on the CPU first so that a checkpoint saved on a GPU
        # also loads on CPU-only hosts; load_state_dict copies the values
        # into the network's own parameters.
        state = torch.load(name+'.pth', map_location='cpu')
        net.load_state_dict(state,strict=strict)
        return net
    
    @classmethod
    def remove(cls, name,epoch):
        if name is None:
            name = cls.__name__ + "_model"
        if epoch is not None:
            name += '_E%06d'%epoch
        os.remove(name+'.pth')
    
    def extract_feature(self,data_loader,input_idxs=[0],path=None,ret_val_idx=None,
                        ret_gt_idxs=[],**kwargs4forward):
        """Run the network over ``data_loader`` in eval mode and collect outputs.

        Args:
            data_loader: iterable of batches; each batch is indexed with the
                entries of ``input_idxs`` and ``ret_gt_idxs``.
            input_idxs: index (or list of indices) of the batch entries that
                are passed to ``forward``.
            path: when set, the stacked features are written with ``np.save``.
            ret_val_idx: index/key of the value to keep when ``forward``
                returns a list, tuple or dict. ``0`` is a valid index.
            ret_gt_idxs: index (or list of indices) of batch entries to
                return alongside the features.
            **kwargs4forward: extra keyword arguments for ``forward``.

        Returns:
            ``(Z, gts)``. ``Z`` is a NumPy array of the per-batch outputs
            stacked along the first axis, or ``None`` for an empty loader.
            ``gts`` has one array (or ``None``) per entry of ``ret_gt_idxs``.

        Raises:
            ValueError: if ``forward`` returns several values and
                ``ret_val_idx`` is ``None``.
        """
        # The list defaults are only read, never mutated.
        input_idxs = to_list(input_idxs)
        ret_gt_idxs = to_list(ret_gt_idxs)
        assert(len(input_idxs)>0)

        was_in_training = self.training
        self.eval()

        gt_chunks = [[] for _ in ret_gt_idxs]
        feat_chunks = []
        try:
            # No gradients are needed for feature extraction.
            with torch.no_grad():
                for data in data_loader:
                    for chunks, idx in zip(gt_chunks, ret_gt_idxs):
                        chunks.append(data[idx].detach().cpu().numpy())

                    inputs = [data[i] for i in input_idxs]
                    if self.gpu >= 0:
                        inputs = [x.cuda(self.gpu) for x in inputs]
                    outputs = self.forward(*inputs,**kwargs4forward)
                    if isinstance(outputs,(list,tuple,dict)):
                        if ret_val_idx is None:
                            raise ValueError(
                                "forward() returned several values; "
                                "ret_val_idx must select one of them.")
                        outputs = outputs[ret_val_idx]
                    feat_chunks.append(outputs.detach().cpu().numpy())
        finally:
            if was_in_training:
                # recover the mode, also when a batch failed.
                self.train()

        def _merge(chunks, always_rows):
            # Stack once at the end instead of re-copying after every batch.
            if not chunks:
                return None
            if len(chunks) == 1:
                return chunks[0]
            if always_rows or chunks[0].ndim > 1:
                return np.vstack(chunks)
            return np.hstack(chunks)

        Z = _merge(feat_chunks, True)
        gts = [_merge(chunks, False) for chunks in gt_chunks]

        if path:
            np.save(path,Z)
        return Z,gts
    
    def store(self,name,default):
        val = default
        if name in self.conf.keys():
            val = self.conf[name]
        setattr(self,name,val)        
        
    def get_act(self,name):
        if name is None or len(name)==0:
            return None
        if name == 'LeakyReLU':
            return nn.LeakyReLU(self.LeakyReLU_param)
        if hasattr(nn,name):
            # ReLU, Sigmoid, Softmax, Softmax2d, Tanh, Softplus, ...
            return eval('nn.%s()'%name)
        else:
            raise NameError('Unknown activation:'+name)
            
    def conv_block(self,c_in, c_out, k, s, p, norm='bn', activation=None, dropout=None,transpose=False):
        layers = []
        if transpose:
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=k, stride=s, padding=p))
        else:
            layers.append(         nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p))
        if self.sn_params is not None:
            layers[-1] = torch.nn.utils.spectral_norm(layers[-1],**self.sn_params)
        if dropout:
            layers.append(nn.Dropout2d(dropout))
        if norm == 'bn':
            layers.append(nn.BatchNorm2d(c_out))
        act = self.get_act(activation)
        if act is not None:
            layers.append(act)
        return nn.Sequential(*layers)
    
    # create a fully connected layer
    def fc_block(self,c_in, c_out, norm='bn', activation=None, dropout=None):
        layers = []
        layers.append(nn.Linear(c_in,c_out))
        if self.sn_params is not None:
            layers[-1] = torch.nn.utils.spectral_norm(layers[-1],**self.sn_params)
            
        if dropout:
            layers.append(nn.Dropout(dropout))
        if norm == 'bn':
            layers.append(nn.BatchNorm1d(c_out))
        act = self.get_act(activation)
        if act is not None:
            layers.append(act)            
        return nn.Sequential(*layers)
    
        
    def build_model(self,conf,input_dim,name):
        if name in self.models.keys():
            raise RuntimeError("a model '%s' is already built."%name)
        model_info = []
        block_types = []
        params = []
        _input_dim = input_dim
        
        
        for i,para in enumerate(conf):  
            output_dim = para[1]            
            block_name = "%s_%s_%02d"%(name,para[0],i)
            print(block_name)
            if para[0] == 'conv':
                block = self.conv_block(_input_dim,output_dim,*para[2:],transpose=False)
            elif para[0] in 'trans_conv':
                block = self.conv_block(_input_dim,output_dim,*para[2:],transpose=True)
            elif para[0] == 'fc':
                block = self.fc_block(_input_dim,output_dim,*para[2:])
            elif para[0] == 'maxpool':
                output_dim = _input_dim
                block = nn.MaxPool2d(*para[1:])
            elif para[0] == 'avgpool':
                output_dim = _input_dim
                block = nn.AvgPool2d(*para[1:])
            elif para[0] == 'upsample':
                output_dim = _input_dim
                # ignore size option.
                block = nn.Upsample(None, *para[1:])                
            else:
                if self.logger:
                    self.logger.error('Unknown block type: '+para[0])
                raise NameError('Unknown block type: '+para[0])
            setattr(self,block_name,block)
            model_info.append((para[0],block_name,para[1:]))
            if self.logger:
                self.logger.debug("%s: %s -> %s"%(block_name,_input_dim,output_dim))
            _input_dim = output_dim

        self.models[name] = {
            'input_dim':input_dim,
            'output_dim':output_dim,
            'blocks': model_info,
        }
        

    
    def forward_model(self,name,x):        
        for i,(ltype,name,_) in enumerate(self.models[name]['blocks']):
            if (ltype == 'fc') and (len(x.size())>2):
                batch_size = x.size()[0]        
                x = x.view(batch_size,-1)
            prev = x.shape
            x = getattr(self,name)(x)
            #print(name,prev,'->',x.shape)
        return x
    
    
    def init_weights(self):
        """Call ``_init_weights`` on modules of this network via ``nn.Module.apply``."""
        self.apply(self._init_weights)

    def _init_weights(self,m):
        if not hasattr(m,'weights'):
            return
        inits = []
        
        classname = m.__class__.__name__
        # e.g.) self.conf['initialize'] = {'Conv':'{}.data.normal_(0.0, 0.1)', 'BatchNorm':[0.02,0.01]}
        try:
            layer_type = next(filter(lambda x: type(m)==x,self.conf['initialize'].keys()))
        except StopIteration:       
            try:               
                layer_type = next(filter(lambda x: classname.find(x)>=0,self.conf['initialize'].keys()))
            except StopIteration:
                if self.logger:
                    logger.info("Layer type '%s' is initialized by torch default."%classname)
                    return
        inits = to_list(self.conf['initialize'][layer_type])
        
        init = inits[0]
        if isinstance(init,str):
            eval(init.format('m.weight'))
        else:
            m.weight.data.fill_(init)
        
        if not hasattr(m,'bias'):
            return
        
        if len(inits)>1 and inits[1] != 'repeat':
            init = inits[1]

        if isinstance(init,str):
            eval(init.format('m.bias'))
        else:
            m.bias.data.fill_(init)
           

    def build(self):
        raise RuntimeError("the function 'build' is a pure-virtual function. This must be implemented in any child class.")

    def set_losses(self,conf):
        """Register the loss functions described by ``conf`` in ``self.losses``.

        Args:
            conf: mapping ``loss name -> {'lambda': ..., 'func': ...}``.
                ``'lambda'`` holds the ``LossWeight`` constructor arguments
                (a single value or a list). ``'func'`` is a callable, the
                name of an attribute of this network, or an expression that
                evaluates to a callable.

        An entry whose ``'func'`` cannot be resolved (``NameError``) is
        skipped with a warning, so that ``calc_loss`` is not the first place
        where the problem becomes visible.
        """
        for k,v in conf.items():
            weight = LossWeight(*to_list(v['lambda']))
            func = v['func']
            if not callable(func):
                if hasattr(self,func):
                    func = getattr(self,func)
                else:
                    try:
                        # NOTE(review): eval of a configuration string --
                        # trusted configuration only.
                        func = eval(func)
                    except NameError:
                        msg = "Loss '%s' is skipped: cannot resolve function '%s'."%(k,v['func'])
                        if self.logger:
                            self.logger.warning(msg)
                        else:
                            warnings.warn(msg)
                        continue
            self.losses[k] = [weight,func]
    
    def calc_loss(self, name, global_step, y_pred, y_true=None, writer=None, acc=False, return4print=False):
        weight,func = self.losses[name]
        #print(name, type(y_pred),y_pred.dtype,type(y_true),y_true.dtype)
        if y_true is not None:
            l = func(y_pred,y_true)
        else:
            l = func(y_pred)
        if writer is not None:
            writer.add_scalar(os.path.join(type(self).__name__,'losses',name),l,global_step)
        w = weight(global_step)
        outputs = [w*l]
        if acc:
            if acc == 'categorical':
                accuracy = calc_acc(y_pred,y_true)
            elif acc == 'binary':
                accuracy = calc_acc_binary(y_pred,y_true)
            else:
                raise RuntimeError('unknown accuracy type.')
            outputs.append(accuracy)
            writer.add_scalar(os.path.join(type(self).__name__,'accs',name),accuracy,global_step)
        if return4print:
            if l.is_cuda:
                lp = l.cpu().detach()
            else:
                lp = l.detach()
            if weight.final is None:
                lp_str = "%.5f"%lp
            else:
                lp_str = "%.3f x %.5f"%(w,lp)
            if acc:
                lp_str += ", acc=%2.1f%%"%(accuracy*100)
            outputs.append(lp_str)
        return outputs
        
        

In [None]:
class AutoEncoder(Network):
    """Encoder/decoder network built from ``conf['encoder']`` and ``conf['decoder']``.

    The optional ``conf['AE_depth']`` limits the model to the first
    ``AE_depth`` encoder blocks and the last ``AE_depth`` decoder blocks; a
    value <= 0 (the default) uses every block.
    """

    def __init__(self,*args,**kwargs):
        super().__init__(*args,**kwargs)

    def build(self):
        """Build the ``'encoder'`` and ``'decoder'`` sub-models."""
        depth = self.conf.get('AE_depth',-1)
        # The depth has to be handed on, otherwise 'AE_depth' has no effect.
        self.build_encoder(self.conf['encoder'],depth)
        self.build_decoder(self.conf['decoder'],depth)

    def forward(self,input_x, return_enc = False):
        """Return ``(z, reconstruction)``, or only ``z`` if ``return_enc``."""
        z = self.forward_model('encoder',input_x)
        if return_enc:
            return z
        reconst = self.forward_model('decoder',z)
        return z, reconst

    @property
    def feature_dim(self):
        """Output dimension recorded for the encoder."""
        return self.models['encoder']['output_dim']

    @property
    def output_dim(self):
        """Output dimension recorded for the decoder."""
        return self.models['decoder']['output_dim']

    def build_encoder(self,conf,depth=-1):
        """Build the encoder from the first ``depth`` blocks of ``conf`` (all if <= 0)."""
        if depth<=0:
            depth = len(conf)
        self.build_model(conf[:depth],self.input_dim,'encoder')

    def build_decoder(self,conf,depth=-1):
        """Build the decoder from the last ``depth`` blocks of ``conf`` (all if <= 0)."""
        blocks = conf if depth<=0 else conf[-depth:]
        self.build_model(blocks,self.feature_dim,'decoder')


In [None]:
class VAE(AutoEncoder):
    """Variational auto-encoder on top of ``AutoEncoder``.

    Configuration keys: ``'encoder'``, ``'enc_mu'``, optional
    ``'enc_logvar'`` (defaults to the ``'enc_mu'`` layout) and ``'decoder'``.
    The two heads are separate sub-models fed by the encoder output; the
    decoder consumes the latent code, whose size is the output dimension of
    ``'enc_mu'``.
    """

    def __init__(self,*args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def feature_dim(self):
        """Size of the latent code (output dimension of the ``enc_mu`` head)."""
        return self.models['enc_mu']['output_dim']

    def build(self):
        """Build encoder, the mu / log-variance heads and the decoder."""
        self.build_encoder(self.conf['encoder'])
        hidden_dim = self.models['encoder']['output_dim']

        mu_conf = self.conf['enc_mu']
        logvar_conf = self.conf.get('enc_logvar',mu_conf)
        # build_model registers the blocks itself and returns nothing, so
        # its result must not be stored.
        self.build_model(mu_conf,hidden_dim,'enc_mu')
        self.build_model(logvar_conf,hidden_dim,'enc_logvar')

        # Uses feature_dim, i.e. the latent size, as decoder input size.
        self.build_decoder(self.conf['decoder'])

    def forward(self,input_x, return_enc = False, return_dist = False):
        """Encode, sample and (optionally) decode.

        Args:
            input_x: input batch.
            return_enc: stop after sampling and do not run the decoder.
            return_dist: additionally return ``mu`` and ``logvar`` (needed
                for the KL term).

        Returns:
            ``z`` or ``(z, reconst)``; with ``return_dist`` the tuple is
            extended by ``mu, logvar``.
        """
        x = self.forward_model('encoder',input_x)
        mu = self.forward_model('enc_mu',x)
        logvar = self.forward_model('enc_logvar',x)
        z = self.reparameterize(mu, logvar)
        if return_enc:
            return (z, mu, logvar) if return_dist else z
        reconst = self.forward_model('decoder',z)
        if return_dist:
            return z, reconst, mu, logvar
        return z, reconst

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps*std + mu

    def vae_kl(self, mu,logvar):
        """KL divergence to a standard normal, summed over every element."""
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def vae_kl_mean(self, mu,logvar):
        """Mean over the last axis of ``vae_kl``.

        NOTE(review): ``vae_kl`` already reduces to a single value, so this
        appears to return the same number as ``vae_kl`` -- confirm whether a
        per-sample average was intended.
        """
        return torch.mean(self.vae_kl(mu,logvar),-1)

In [None]:
class Classifier(Network):
    """Single feed-forward model built from ``conf[<name>]``.

    ``name`` (default ``'classifier'``) selects the configuration entry and
    is also the key of the model in ``self.models``. With
    ``conf[<name>_depth]`` set to ``d``, only the first ``d - 1`` blocks plus
    the last block are used.
    """

    def __init__(self,*args,name=None,**kwargs):
        # Set before Network.__init__ runs, because that calls build().
        self.name = 'classifier' if name is None else name
        super().__init__(*args,**kwargs)

    def build(self):
        """Build the model, applying the optional depth limit."""
        layers = self.conf[self.name]
        depth_key = self.name+'_depth'
        if depth_key in self.conf:
            depth = self.conf[depth_key]
            head = layers[-1:]
            layers = head if depth==1 else layers[:depth-1] + head
        self.build_model(layers,self.input_dim,self.name)

    def forward(self,input_x, return_enc = False):
        """Return the model output for ``input_x`` (``return_enc`` is not used)."""
        return self.forward_model(self.name,input_x)
    
