In [2]:
import torch
import numpy as np

# float32 is PyTorch's default dtype already; setting it explicitly documents the intent.
# (torch.set_default_tensor_type is deprecated since PyTorch 2.1 in favour of set_default_dtype.)
torch.set_default_dtype(torch.float32)

# Seed both RNGs used below (torch: layer init / randperm; numpy: data and candidate sampling)
# so that Restart-Kernel -> Run-All reproduces the recorded outputs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [8]:
def to_32(x):
    """Cast tensor ``x`` to ``torch.FloatTensor`` (CPU, float32)."""
    target_type = 'torch.FloatTensor'
    return x.type(target_type)

def to_64(x):
    """Cast tensor ``x`` to ``torch.DoubleTensor`` (CPU, float64)."""
    target_type = 'torch.DoubleTensor'
    return x.type(target_type)

def pos_gram(gram, regularlizer = None):
    """Return ``gram + r * I`` with a ridge ``r > 0`` large enough to be positive definite.

    Starting from ``regularlizer``, ``r`` is doubled until the smallest eigenvalue of
    ``gram + r * I`` is strictly positive.

    Parameters
    ----------
    gram : torch.Tensor
        Square matrix, presumably symmetric (e.g. ``X^T X``) -- only its upper
        triangle is read by the eigen-solver, as with the former ``torch.symeig``.
    regularlizer : float or 0-d tensor, optional
        Initial ridge. Defaults to ``1e-7 * max|gram|`` (about float32 resolution).
        A tensor passed here is never modified.

    Returns
    -------
    torch.Tensor
        ``gram + r * I`` with the same tensor type as ``gram``.

    Raises
    ------
    ValueError
        If ``gram`` contains NaN/inf, if ``gram`` is all zeros and no regulariser is
        given, or if the regulariser is not a positive finite number.
    """
    # Non-finite input would make the doubling loop below run forever.
    if not torch.isfinite(gram).all():
        raise ValueError("gram error, expect finite elements (no NaN/inf)")

    identity = torch.eye(len(gram)).type(gram.type())

    if regularlizer is None:
        max_abs = gram.abs().max()
        if max_abs == 0:
            raise ValueError("gram error, expect matrix with none-zero element")
        # The fraction of float32 is 2**(-23) ~ 10**(-7); start at 10**(-7) times the maximum element.
        regularlizer = max_abs * 0.0000001

    # `not (r > 0)` also rejects NaN, which `r <= 0` would let through.
    if not (regularlizer > 0):
        raise ValueError("regularlizer error, expect positive, got %s" %(regularlizer))
    if not torch.isfinite(torch.as_tensor(regularlizer)):
        raise ValueError("regularlizer error, expect finite, got %s" %(regularlizer))

    while True:
        regularized = gram + regularlizer * identity
        # torch.symeig was removed in PyTorch 1.13; eigvalsh with UPLO='U' reads the same triangle.
        if torch.linalg.eigvalsh(regularized, UPLO='U').min() > 0:
            return regularized
        # Rebind instead of `*=` so a tensor supplied by the caller is not mutated.
        regularlizer = regularlizer * 2.

class LinearExpander():
    """Rank candidate hidden units by how much of their output a linear model cannot explain.

    The same input batch is streamed through two layers:

    * ``linear_model`` followed by ``activation_function`` gives the regressors ``X``.
      An intercept column of ones is appended to ``X``.
    * ``candidate``, a fresh ``torch.nn.Linear`` with ``candidate_num`` outputs, followed
      by the same activation gives the candidate responses ``Y``.

    ``data_input`` accumulates ``X^T X``, ``X^T Y`` and ``diag(Y^T Y)`` across calls.
    ``take`` uses them to compute each candidate's residual variance after a regularised
    least-squares fit on ``X``, and samples candidates with it.
    """

    def __init__(self, linear_model, activation_function, candidate_num=1, std = None):
        """
        Parameters
        ----------
        linear_model : torch.nn.Linear
            Existing layer. Its ``in_features``/``out_features`` size the candidate layer
            and the statistics.
        activation_function : callable
            Applied to the outputs of both layers (e.g. ``torch.tanh``).
        candidate_num : int
            Number of candidate units to generate.
        std : float, optional
            If given, candidate weights are rescaled to have standard deviation ``std``.
            PyTorch's default ``nn.Linear`` weight init is U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
            whose std is 1/sqrt(3*fan_in). Multiplying by std*sqrt(3*fan_in) therefore
            yields ``std``. The bias is left unscaled.
        """
        self.linear_model = linear_model
        self.activation_function = activation_function
        self.candidate = torch.nn.Linear(self.linear_model.in_features, candidate_num)
        if std is not None:
            # In-place update of a Parameter must not be recorded by autograd.
            with torch.no_grad():
                self.candidate.weight *= std * (3 * self.candidate.in_features) ** 0.5

        self.reset()

    def reset(self):
        """Zero all accumulated statistics.

        With ``k = linear_model.out_features + 1`` (regressors plus intercept) and
        ``c = candidate.out_features``:

        * ``regressor_gram`` (k, k): running sum of X^T X
        * ``projector``      (k, c): running sum of X^T Y
        * ``responsor_ss``   (c,):   running sum of diag(Y^T Y), the per-candidate square sums
        * ``datums_acc``:            number of rows accumulated so far
        """
        self.regressor_gram = torch.zeros((self.linear_model.out_features+1, self.linear_model.out_features+1))
        self.projector = torch.zeros((self.linear_model.out_features+1, self.candidate.out_features))
        self.responsor_ss = torch.zeros((self.candidate.out_features))
        self.datums_acc = 0

    def data_input(self, data):
        """Accumulate the statistics for one batch.

        Parameters
        ----------
        data : torch.Tensor
            2-D batch ``(rows, linear_model.in_features)``. It is fed to both
            ``linear_model`` and ``candidate``.
        """
        datums = data.size()[0]
        # Only statistics are needed. Without no_grad, the in-place `+=` below would attach
        # each batch's autograd graph to the accumulators and keep it alive indefinitely.
        with torch.no_grad():
            regressor = self.activation_function(self.linear_model(data))
            # Intercept column so that the regression in `take` includes a bias term.
            ones = torch.ones((datums, 1), dtype=regressor.dtype, device=regressor.device)
            expand = torch.cat((regressor, ones), 1)
            self.regressor_gram += torch.mm(expand.t(), expand)

            responsor = self.activation_function(self.candidate(data))
            self.projector += torch.mm(expand.t(), responsor)
            self.responsor_ss += (responsor**2).sum(0)

        self.datums_acc += datums

    def take(self, take_num=1, choice_round=0):
        """Pick ``take_num`` distinct candidate indices.

        With ``choice_round == 0`` the result is the prefix of a uniform random permutation,
        returned as a ``torch.Tensor``.

        Otherwise, each index starts as a uniform draw among not-yet-taken candidates. It
        is then refined by ``choice_round`` Metropolis-style steps. Each step proposes
        another untaken candidate and accepts it with probability
        ``independency[new] / independency[current]`` (capped at 1). Here ``independency``
        is the candidate's residual variance after the least-squares fit on the regressors.
        In this case the result is a Python list.

        Raises
        ------
        ValueError
            If ``take_num`` exceeds the number of candidates, or if ``choice_round`` is
            non-zero before any data has been accumulated via ``data_input``.
        """
        lots_num = self.candidate.out_features
        if take_num > lots_num:
            raise ValueError("take_num exceed candidate")

        if choice_round == 0:
            return torch.randperm(lots_num)[:take_num]

        if self.datums_acc == 0:
            # Dividing empty statistics by zero would give an all-NaN Gram matrix.
            raise ValueError("no data accumulated, call data_input before take with choice_round != 0")

        independency = self._independency()
        took_idx = []
        for _ in range(take_num):
            hit = self._draw_untaken(lots_num, took_idx)
            for _ in range(choice_round):
                new_hit = self._draw_untaken(lots_num, took_idx)
                if np.random.rand() < (independency[new_hit]/independency[hit]):
                    hit = new_hit
            took_idx.append(hit)

        return took_idx

    def _independency(self):
        """Per-candidate residual variance of Y after a ridge-regularised fit on X (float64)."""
        # pos_gram adds the smallest doubling ridge that makes the mean Gram matrix positive definite.
        mean_gram = pos_gram(to_64(self.regressor_gram / self.datums_acc))
        # torch.symeig was removed in PyTorch 1.13; eigh with UPLO='U' reads the same triangle.
        lambdas, vectors = torch.linalg.eigh(mean_gram, UPLO='U')
        mean_projector = to_64(self.projector / self.datums_acc)
        lambdas_inv = 1 / (lambdas + 0.0000001)
        VtXtY = vectors.t().mm(mean_projector)
        # diag(Y^T X Gram^{-1} X^T Y) = sum_k (V^T X^T Y)_{k,j}^2 / lambda_k
        dependency = ((VtXtY.t() * lambdas_inv).t() * VtXtY).sum(0)
        return to_64(self.responsor_ss / self.datums_acc) - dependency

    @staticmethod
    def _draw_untaken(lots_num, took_idx):
        """Uniformly draw an index in ``range(lots_num)`` not already in ``took_idx`` (rejection sampling)."""
        hit = np.random.randint(lots_num)
        while hit in took_idx:
            hit = np.random.randint(lots_num)
        return hit

In [9]:
# Base layer whose (activated) outputs serve as regressors: 20 inputs -> 10 units.
foo = torch.nn.Linear(in_features=20, out_features=10)

In [10]:
# 20 tanh candidate units scored against the tanh outputs of `foo`.
myLE = LinearExpander(foo, activation_function=torch.tanh, candidate_num=20)

In [11]:
# 1000 standard-normal rows with as many features as the base layer expects (20).
data = np.random.normal(loc=0.0, scale=1.0, size=(1000, myLE.linear_model.in_features))
myLE.data_input(torch.FloatTensor(data))

In [12]:
# Empirical selection frequency: draw one candidate 10000 times, 10 refinement rounds each.
counter = torch.zeros(myLE.candidate.out_features)
for _draw in range(10000):
    picked = myLE.take(1, choice_round=10)
    counter[picked] += 1

In [13]:
# Number of times each of the 20 candidates was selected by the sampling loop above.
counter

tensor([286., 676., 464., 579., 645., 596., 577., 364., 532., 645., 610., 453.,
        426., 393., 371., 470., 469., 519., 409., 516.])

In [14]:
def activated_numpy(layer, inputs):
    """Return ``myLE.activation_function(layer(inputs))`` as a NumPy array.

    Runs under ``torch.no_grad()`` so no autograd graph is built. This allows a direct
    ``.numpy()`` call instead of going through the discouraged ``.data`` attribute.
    """
    with torch.no_grad():
        return myLE.activation_function(layer(inputs)).numpy()


data_tensor = torch.FloatTensor(data)
# Design matrix: activated outputs of the base layer plus an intercept column of ones,
# mirroring the `expand` matrix built in LinearExpander.data_input.
regressor = activated_numpy(myLE.linear_model, data_tensor)
regressor = np.concatenate((regressor, np.ones((len(regressor), 1))), axis=1)
# Targets: activated outputs of the candidate units.
responsor = activated_numpy(myLE.candidate, data_tensor)

In [15]:
# Selection count divided by each candidate's least-squares residual sum of squares.
# take() accepts a move with probability independency[new] / independency[current],
# which targets sampling proportional to the residual variance. The ratio should
# therefore be roughly constant, increasingly so as choice_round grows.
# rcond=None opts into NumPy's current default cut-off and silences the FutureWarning.
residual_ss = np.linalg.lstsq(regressor, responsor, rcond=None)[1]
np.array(counter) / residual_ss

  """Entry point for launching an IPython kernel.


array([5.16156042, 4.63571527, 4.96884496, 4.81084715, 4.57201779,
       4.4853007 , 4.69594333, 4.3789729 , 4.38589555, 4.47402833,
       4.46076148, 4.77183801, 4.49550764, 4.4203575 , 4.81240436,
       4.43681148, 4.639682  , 4.74345333, 4.23209378, 4.31664408])