In [1]:
import torch
import numpy as np
from types import SimpleNamespace
import torch.nn.functional as F

In [2]:
class Net(torch.nn.Module):
    """Small fully connected regressor: n_feature -> 50 -> 30 -> n_output.

    Both hidden layers are followed by a ReLU; the output layer is linear.
    """

    def __init__(self, n_feature, n_output):
        """Create the three linear layers.

        Args:
            n_feature: number of input features per sample.
            n_output: number of values predicted per sample.
        """
        super().__init__()
        # Attribute names and creation order define the order of
        # `parameters()` (weight, bias per layer), which other cells index into.
        self.hidden_1 = torch.nn.Linear(n_feature, 50)
        self.hidden_2 = torch.nn.Linear(50, 30)
        self.predict = torch.nn.Linear(30, n_output)

    def forward(self, x):
        """Map a batch `x` with n_feature columns to n_output columns."""
        first_activation = F.relu(self.hidden_1(x))
        second_activation = F.relu(self.hidden_2(first_activation))
        return self.predict(second_activation)

In [3]:
# --- Synthetic regression data: y = x^2 on 1000 evenly spaced points in [-1, 1] ---
torch.manual_seed(4)  # seed torch's RNG before the network weights are initialised
x_train = torch.linspace(-1, 1, 1000).unsqueeze(dim=1)  # x data (tensor), shape=(1000, 1)
y_train = x_train ** 2

# --- Fit the network with Adam on a mean-squared-error objective ---
net_mse = Net(n_feature=x_train.shape[1], n_output=1)
loss_func = torch.nn.MSELoss()  # regression loss: mean squared error
optimizer = torch.optim.Adam(net_mse.parameters(), lr=0.0005)

for epoch in range(100):  # full-batch updates: every step sees all of x_train
    prediction = net_mse(x_train)          # forward pass
    loss = loss_func(prediction, y_train)  # scalar training loss
    optimizer.zero_grad()                  # drop gradients left from the previous step
    loss.backward()                        # back-propagate to fill .grad
    optimizer.step()                       # apply the Adam update

In [4]:
def find_groups(args,W_ij_ab,W_ij_bc,debug=False):
    """Softly cluster the units of a middle layer ``b`` by their thresholded links.

    Weights are binarised (``|w| > args.xi``) into an incoming adjacency ``A``
    (layer a -> layer b) and an outgoing adjacency ``B`` (layer b -> layer c).
    Starting from random Dirichlet assignments, the soft memberships ``q_kc``
    are re-estimated for up to ``args.iterations`` rounds; the run counts as
    converged once the summed absolute change between two consecutive rounds
    drops below ``args.threshold``.

    Args:
        args: object with attributes ``c`` (number of clusters), ``xi``
            (absolute-weight threshold), ``iterations`` (maximum number of
            rounds) and ``threshold`` (convergence tolerance).
        W_ij_ab: 2-D array of shape (units in a, units in b).
        W_ij_bc: 2-D array of shape (units in b, units in c).
        debug: accepted for compatibility; not used.

    Returns:
        A 5-tuple ``(converged, memberships, diff, t, likelihood)``:
        * converged -- True if the tolerance was met within the round budget.
        * memberships -- array with one cluster index per unit of layer b, or
          the list ``[-1] * n_b`` when not converged.
        * diff -- last summed absolute change of ``q_kc`` (0 if no round ran).
        * t -- index of the last round executed (-1 if no round ran).
        * likelihood -- log-likelihood of the hard assignment, or ``-inf``
          when not converged.
    """
    # Binary adjacency matrices: 1 where the absolute weight exceeds xi.
    A = (np.abs(W_ij_ab) > args.xi).astype(int)
    B = (np.abs(W_ij_bc) > args.xi).astype(int)

    layer_b = W_ij_ab.shape[1]

    # Random initial soft assignment: one Dirichlet sample (length c) per unit of b.
    q_kc = np.random.dirichlet(np.ones(args.c), size=layer_b)
    prev = q_kc
    diff = 0
    t = -1  # keeps the return value defined when args.iterations == 0

    for t in range(args.iterations):
        # "Eq 8" in the original notes: mixing proportions and per-cluster
        # connection profiles, each column normalised to sum to one.
        Pi = q_kc.sum(axis=0) / q_kc.shape[0]                 # (c,)

        tau_a = A.dot(q_kc)                                   # (n_a, c)
        tau_a = tau_a / tau_a.sum(axis=0, keepdims=True)

        tau_b = q_kc.transpose().dot(B).transpose()           # (n_c, c)
        tau_b = tau_b / tau_b.sum(axis=0, keepdims=True)

        # Unnormalised responsibility of cluster c for every unit of b.
        # NOTE(review): this is a plain product of probabilities and may
        # underflow for very wide layers; a log-domain form would avoid that.
        q_kc_num = np.array([
            Pi[c] * np.multiply(
                np.prod(tau_a[:, c:(c + 1)] ** A, axis=0),
                np.prod(tau_b[:, c:(c + 1)] ** B.transpose(), axis=0),
            )
            for c in range(args.c)
        ])                                                    # (c, n_b)
        q_kc = np.array(q_kc_num / q_kc_num.sum(axis=0)).transpose()  # (n_b, c)

        diff = (np.abs(prev - q_kc)).sum()
        if diff < args.threshold:
            memberships = np.argmax(q_kc, axis=1)

            # Avoid log(0) in the likelihood below.
            tau_a = np.where(tau_a == 0, 1e-10, tau_a)
            tau_b = np.where(tau_b == 0, 1e-10, tau_b)

            # Log-likelihood of the hard assignment: prior term plus the
            # log-profile of each unit's cluster summed over its links.
            likelihood = (
                np.log(Pi[memberships]).sum()
                + (A * np.log(tau_a[:, memberships])).sum()
                + (B * np.log(tau_b[:, memberships]).transpose()).sum()
            )
            return True, memberships, diff, t, likelihood

        prev = q_kc

    # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
    return False, [-1] * layer_b, diff, t, -np.inf

In [5]:
# List every learnable tensor of the trained network with its position in
# `parameters()`; weights and biases alternate, one pair per linear layer.
row_template = "index:{}  p:{}"
for i, p in enumerate(net_mse.parameters()):
    print(row_template.format(i, p.shape))

index:0  p:torch.Size([50, 1])
index:1  p:torch.Size([50])
index:2  p:torch.Size([30, 50])
index:3  p:torch.Size([30])
index:4  p:torch.Size([1, 30])
index:5  p:torch.Size([1])


In [6]:
def _best_grouping(args, W_ij_ab, W_ij_bc, n_trials=10, max_attempts=1000):
    """Run `find_groups` repeatedly and return the most likely converged grouping.

    Stops once more than `n_trials` runs have converged and returns, for each
    cluster index 0..max, the array of unit indices assigned to it (possibly
    empty). Between attempts `args.c` takes a random step of -1, 0 or +1,
    clamped to [3, n_b // 4], so that different cluster counts are explored.

    Raises:
        RuntimeError: if `max_attempts` runs do not yield enough converged
            trials with a usable likelihood.
    """
    # Upper bound on the cluster count; at least 1 so tiny layers stay valid.
    max_classes = max(1, int(W_ij_bc.shape[0] / 4))
    trials, best_memberships, best_likelihood = 0, None, -np.inf

    for _ in range(max_attempts):
        converge, memberships, _diff, _t, likelihood = find_groups(args, W_ij_ab, W_ij_bc)
        if converge:
            trials += 1
            if likelihood > best_likelihood:
                best_memberships, best_likelihood = memberships, likelihood
            if trials > n_trials and best_memberships is not None:
                return [
                    np.where(best_memberships == c)[0]
                    for c in range(max(best_memberships) + 1)
                ]
        # Perturb the cluster count actually used by the next attempt.
        args.c = min(max(3, args.c + np.random.randint(-1, 2)), max_classes)

    raise RuntimeError(
        "find_groups did not produce {} converged trials within {} attempts".format(
            n_trials + 1, max_attempts
        )
    )


def creategroups(class_number, net=None, xi=0.05, iterations=2, threshold=6,
                 max_attempts=1000):
    """Group the units of every hidden layer of a fully connected network.

    Walks the 2-D (weight) parameters of `net` in order. For each pair of
    consecutive weight matrices, the units of the layer between them are
    clustered from their incoming and outgoing weights.

    Args:
        class_number: initial number of clusters tried for every layer.
        net: model whose parameters are analysed; defaults to the notebook
            global `net_mse`. Assumed to be a plain stack of linear layers.
        xi: absolute-weight threshold forwarded to `find_groups`.
        iterations: maximum update rounds per `find_groups` run.
        threshold: convergence tolerance forwarded to `find_groups`.
        max_attempts: cap on `find_groups` runs per layer.

    Returns:
        A list with one entry per clustered layer; each entry is a list of
        integer index arrays, one per cluster.
    """
    if net is None:
        net = net_mse  # notebook-global model trained in an earlier cell

    groups_all = []
    W_ij_ab = None  # transposed weights feeding the layer being clustered

    for p in net.parameters():
        if p.dim() != 2:
            continue  # biases carry no connectivity information

        # Transposed so that rows index the source layer, columns the target.
        W_ij_bc = p.detach().cpu().numpy().transpose()

        if W_ij_ab is not None:
            print("ab:{}".format(W_ij_ab.shape))
            print("bc:{}".format(W_ij_bc.shape))

            # Fresh settings per layer, so the random walk over the cluster
            # count in one layer does not leak into the next one.
            args = SimpleNamespace(c=class_number, xi=xi,
                                   iterations=iterations, threshold=threshold)
            groups_all.append(
                _best_grouping(args, W_ij_ab, W_ij_bc, max_attempts=max_attempts)
            )

        W_ij_ab = W_ij_bc

    return groups_all

In [7]:
# The grouping draws from NumPy's global RNG (Dirichlet initialisation and the
# random walk over the cluster count); seed it so a fresh Restart & Run All
# reproduces the same groups. torch.manual_seed does not affect NumPy.
np.random.seed(4)

# Cluster the hidden units of the trained network, starting from 3 clusters.
result = creategroups(class_number=3)
result

ab:(1, 50)
bc:(50, 30)
ab:(50, 30)
bc:(30, 1)


[[array([ 0,  2,  4,  5,  6, 10, 12, 15, 21, 23, 25, 27, 28, 30, 31, 41, 42,
         44]),
  array([ 1,  3,  8, 11, 13, 14, 16, 19, 20, 29, 33, 36, 39, 40, 43, 45]),
  array([ 7,  9, 17, 18, 22, 24, 26, 32, 34, 35, 37, 38, 46, 47, 48, 49])],
 [array([ 2,  3, 11, 15, 16, 22, 23, 26]),
  array([ 0,  1,  5,  9, 12, 13, 14, 18, 19]),
  array([17, 20, 27, 28]),
  array([ 4,  7,  8, 10, 21, 24, 25]),
  array([ 6, 29])]]