In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import math
def init_weight_variable(shape, sigma=0.01):
    """Create a trainable parameter filled with Uniform(-sigma, sigma) values.

    Args:
        shape: tensor shape (list/tuple of ints).
        sigma: half-width of the uniform range; defaults to 0.01, the value
            that used to be hard-coded.

    Returns:
        nn.Parameter placed on the GPU when CUDA is available, else on the CPU.
    """
    # Single-argument print(...) emits the same text on Python 2 and Python 3.
    print("initializing using sigma %f" % sigma)
    initial = torch.zeros(shape)
    initial.uniform_(-sigma, sigma)
    if torch.cuda.is_available():
        initial = initial.cuda()
    return nn.Parameter(initial)

class MultiChannelCompressor(nn.Module):
    """Compress an embedding matrix into additive codebooks learned with Gumbel-softmax.

    Each word gets ``num_pointers`` groups of ``M`` discrete codes; code ``m``
    selects one of ``K`` vectors from codebook ``m`` (rows ``[m*K, (m+1)*K)`` of
    ``self.codebooks``). A per-word, per-codebook on/off mask can drop a
    codebook. The selected vectors are summed over codebooks, gated per pointer
    by ``sigmoid(self.scale)`` and summed over pointers to give the reconstruction.
    """

    def __init__(self, embedding, n_codebooks, n_centroids, _TAU=1.0, lr=0.0001,
                 export_path=None, top_rate=0.05, num_pointers=1):
        """
        Args:
            embedding: 2-D tensor (V, embd_dim) of original vectors. Rows are
                presumably sorted by descending frequency (the first
                ``int(V * top_rate)`` rows are treated as "top" words).
            n_codebooks: M, number of codebooks (subcodes).
            n_centroids: K, number of vectors in each codebook.
            _TAU: Gumbel-softmax temperature for the codes (the mask uses _TAU / 2).
            lr: Adam learning rate.
            export_path: if set, `_test` writes the reconstructed matrix there.
            top_rate: fraction of leading rows treated as top words.
            num_pointers: number of code groups combined per word.
        """
        super(MultiChannelCompressor, self).__init__()
        self.lr = lr
        self.M = n_codebooks
        self.K = n_centroids
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(-1)
        self.V, self.embd_dim = embedding.size()
        self.export_path = export_path
        # Frozen lookup table holding the original vectors.
        self.embedding = nn.Embedding(self.V, self.embd_dim)
        self.embedding.weight = nn.Parameter(embedding, requires_grad=False)

        self.num_pointer = num_pointers

        # `//` keeps layer sizes integral on Python 3 (Python 2 `/` already floored ints).
        hidden_dim = self.M * self.K // 2
        mask_hidden_dim = self.M // 2

        self.linear_h = nn.Linear(self.embd_dim, hidden_dim)
        self.linear_h.weight = init_weight_variable([hidden_dim, self.embd_dim])

        self.linear_logit = nn.Linear(hidden_dim, self.num_pointer * self.M * self.K)
        self.linear_logit.weight = init_weight_variable(
            [self.num_pointer * self.M * self.K, hidden_dim])

        self.linear_mask_h = nn.Linear(self.embd_dim, mask_hidden_dim)
        self.linear_mask_h.weight = init_weight_variable([mask_hidden_dim, self.embd_dim])
        self.linear_mask_logit = nn.Linear(mask_hidden_dim, 2 * self.M)
        self.linear_mask_logit.weight = init_weight_variable([2 * self.M, mask_hidden_dim])
        self.num_top = int(self.V * top_rate)

        # (M * K, embd_dim): codebook m occupies rows [m*K, (m+1)*K).
        self.codebooks = init_weight_variable([self.M * self.K, self.embd_dim])

        # Original vectors of the top rows; `_test` restores these when their
        # reconstruction is too far off. Plain attribute, not a registered buffer.
        self.top_words_vecs = embedding[:self.num_top]

        self._TAU = _TAU
        self.sigmoid = nn.Sigmoid()
        # Per-pointer, per-dimension gate (passed through a sigmoid) applied to
        # each pointer's reconstruction before summing over pointers.
        self.scale = nn.Parameter(torch.rand([1, self.num_pointer, self.embd_dim]))

        parameters = [p for p in self.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(parameters, lr=self.lr)

    def _device(self):
        """Device of the frozen embedding table, i.e. where word-id tensors must live."""
        return self.embedding.weight.device

    @staticmethod
    def _num_batches(n_items, batch_size):
        """Number of batches of `batch_size` needed to cover `n_items` (ceil division)."""
        return (n_items + batch_size - 1) // batch_size

    def _gumbel_dist(self, shape, eps=1e-20, device=None):
        """Sample standard Gumbel noise -log(-log(U + eps) + eps), U ~ Uniform[0, 1).

        `device` selects where the noise is created (default device when None),
        so nothing is hard-wired to CUDA.
        """
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def _sample_gumbel_vectors(self, logits, temperature):
        """Softmax over the last axis of Gumbel-perturbed `logits` / `temperature`."""
        y = logits + self._gumbel_dist(logits.size(), device=logits.device)
        return self.softmax(y / temperature)

    def _gumbel_softmax(self, logits, temperature, sampling=True):
        """Relaxed one-hot over the last axis of `logits`.

        With ``sampling=True`` (what every caller here uses) Gumbel noise is added
        before the tempered softmax; with ``sampling=False`` the noise is skipped.
        """
        if not sampling:
            return self.softmax(logits / temperature)
        return self._sample_gumbel_vectors(logits, temperature)

    def _encode(self, word_ids, top_vecs):
        """Look up `word_ids` and compute code and mask scores.

        Args:
            word_ids: 1-D LongTensor of row ids.
            top_vecs: unused; kept for interface compatibility.

        Returns:
            input_embeds: (B, embd_dim) original vectors.
            logits: (B, num_pointer * M, K) log(softplus(.) + 1e-8) code scores.
            mask_logits: (B, M, 2) same transform; column 0 is the "keep" score.

        Raises:
            ValueError: if `word_ids` is not 1-D.
        """
        input_embeds = self.embedding(word_ids)
        if input_embeds.dim() != 2:
            raise ValueError("word_ids must be a 1-D tensor of row ids")

        h = self.tanh(self.linear_h(input_embeds))
        logits = self.linear_logit(h)
        # 1e-8 keeps the log finite where softplus underflows to 0.
        logits = torch.log(self.softplus(logits) + 1e-8)
        logits = logits.view([-1, self.num_pointer * self.M, self.K])

        mask_h = self.tanh(self.linear_mask_h(input_embeds))
        mask_logits = self.linear_mask_logit(mask_h)
        mask_logits = torch.log(self.softplus(mask_logits) + 1e-8)
        mask_logits = mask_logits.view([-1, self.M, 2])

        return input_embeds, logits, mask_logits

    def _decode(self, gumbel_output, top_out):
        """Map (rows, M*K) code weights through the codebooks to (rows, embd_dim).

        `top_out` is unused; kept for interface compatibility.
        """
        return gumbel_output.mm(self.codebooks)

    def _reconstruct(self, codes, codebooks):
        """Unimplemented placeholder; always returns None."""
        return None

    def var_loss(self, logits):
        """Penalty on how far the pointers' code distributions are from one-hot and mutually orthogonal.

        For each word and codebook, forms the (num_pointer x num_pointer) Gram
        matrix of the pointers' softmax distributions over K and takes its squared
        distance from the identity; the result is summed over codebooks.

        Args:
            logits: (B, num_pointer * M, K) scores as produced by `_encode`.

        Returns:
            (B,) tensor. Not currently used by `compress`.
        """
        prob = self.softmax(logits - logits.max(-1, keepdim=True)[0])
        prob = prob.view(-1, self.num_pointer, self.M, self.K)
        prob = prob.transpose(1, 2).contiguous().view(-1, self.num_pointer, self.K)
        gram = prob.bmm(prob.transpose(1, 2))
        identity = torch.eye(self.num_pointer, device=gram.device, dtype=gram.dtype)
        loss_ = ((gram - identity[None, :, :]) ** 2).sum(-1).sum(-1)
        return loss_.view(-1, self.M).sum(-1)

    def compress(self, word_ids, freq, top_vecs, freq_threshold=100):
        """One differentiable encode/decode pass plus the reconstruction losses.

        Args:
            word_ids: 1-D LongTensor of row ids.
            freq: tensor of per-word frequencies aligned with `word_ids`.
            top_vecs: unused; kept for interface compatibility.
            freq_threshold: words with ``freq > freq_threshold`` count towards
                `loss_freq`.

        Returns:
            loss: scalar mean over the batch of 0.5 * squared L2 error (the objective).
            diff: (B, embd_dim) reconstruction minus original.
            loss_freq: sum of the same per-word error over the frequent words.
        """
        input_embeds, logits, mask_logits = self._encode(word_ids, top_vecs)

        # Relaxed one-hot codes, and a sharper (tau / 2) relaxed keep/drop mask.
        # 2.0 avoids Python 2 integer division turning an int tau into 0.
        D = self._gumbel_softmax(logits, self._TAU, sampling=True)
        D_mask = self._gumbel_softmax(mask_logits, self._TAU / 2.0, sampling=True)
        D_mask = D_mask[:, :, 0]  # (B, M) soft "keep codebook" weight

        D = D.view(-1, self.num_pointer, self.M, self.K) * D_mask.view([-1, 1, self.M, 1])
        gumbel_output = D.view([-1, self.M * self.K])  # (B * num_pointer, M * K)

        y_hat = self._decode(gumbel_output, D_mask)
        y_hat = (self.sigmoid(self.scale)
                 * y_hat.view(-1, self.num_pointer, self.embd_dim)).sum(1)

        diff = y_hat - input_embeds
        per_word = 0.5 * (diff ** 2).sum(dim=1)
        mask = (freq > freq_threshold).float()
        loss_freq = (per_word * mask).sum()
        loss = per_word.mean()
        return loss, diff, loss_freq

    def compress_apply(self, word_ids, top_vecs):
        """Deterministically encode `word_ids` and reconstruct their vectors.

        Codes are the argmax of the code scores; a codebook is kept when its
        "keep" score beats its "drop" score. No Gumbel noise, no autograd graph.

        Returns:
            Tuple of numpy arrays: reconstruction (B, embd_dim), original
            vectors (B, embd_dim), and the 0/1 float mask (B, M).
        """
        batch_size = word_ids.size(0)
        with torch.no_grad():
            input_embeds, logits, mask_logits = self._encode(word_ids, top_vecs)
            D_mask = (mask_logits[:, :, 0] > mask_logits[:, :, 1]).float()

            # One code per (word, pointer, codebook) -> (B * num_pointer, M).
            codes = logits.argmax(-1).view(-1, self.M)
            # Shift code m into codebook m's row range.
            offset = torch.arange(self.M, device=codes.device).long() * self.K
            rows = (codes + offset.view(1, self.M)).view(-1)

            selected = torch.index_select(self.codebooks, 0, rows)
            selected = (selected.view(batch_size, self.num_pointer, self.M, -1)
                        * D_mask[:, None, :, None])
            selected = selected.sum(2)  # sum over codebooks -> (B, num_pointer, embd_dim)
            reconstructed_embed = (self.sigmoid(self.scale) * selected).sum(1)

        return (reconstructed_embed.detach().cpu().numpy(),
                input_embeds.detach().cpu().numpy(),
                D_mask.detach().cpu().numpy())

    def _train(self, freq, sorted_index, batch_size=64, num_epochs=10, max_grad_norm=0.001):
        """Optimise with Adam over `num_epochs` shuffled passes through all rows.

        Args:
            freq: per-row frequencies, indexable by row id.
            sorted_index: forwarded to `_test` (row -> output-position mapping).
            batch_size: rows per optimisation step (also used by `_test`).
            num_epochs: number of passes.
            max_grad_norm: gradient-norm clipping threshold.

        After every epoch whose mean loss is at least 1% below the best so far
        (the first epoch always qualifies), `_test` is run, which may export.
        """
        device = self._device()
        vocab_size = self.V
        word_list = np.arange(0, vocab_size)
        freq_arr = np.asarray(freq, dtype=np.float32)
        num_batches = self._num_batches(vocab_size, batch_size)
        best_loss = float("inf")

        for epoch in range(num_epochs):
            # `_test` switches to eval mode, so re-enter training mode each epoch.
            self.train()
            monitor_loss = 0.0
            np.random.shuffle(word_list)
            shuffled_freq = freq_arr[word_list]

            for i in range(num_batches):
                lo, hi = i * batch_size, (i + 1) * batch_size
                words = torch.from_numpy(word_list[lo:hi]).long().to(device)
                freq_ = torch.from_numpy(shuffled_freq[lo:hi]).to(device)
                loss, _, _ = self.compress(words, freq_, self.top_words_vecs)

                self.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
                self.optimizer.step()
                monitor_loss += loss.item()

            monitor_loss /= max(num_batches, 1)
            print("epoch: %d loss: %f" % (epoch, monitor_loss))
            if monitor_loss <= best_loss * 0.99:
                best_loss = monitor_loss
                self._test(sorted_index, batch_size)
                print("best_loss: %f" % monitor_loss)

    def _test(self, sorted_index, batch_size=64, distance_threshold=5.0):
        """Reconstruct all rows, optionally export them, and print compression stats.

        Row ``i`` of the table is written to ``new_embedding[sorted_index[i]]``.
        Among the first `num_top` rows, any whose reconstruction is not within
        L2 distance `distance_threshold` of the original is replaced by the
        original vector. If `export_path` is set, the matrix is written there
        with `np.save`.
        """
        self.eval()
        device = self._device()
        vocab_size = self.V
        top_reserve = self.num_top
        sorted_index = np.asarray(sorted_index)
        word_list = np.arange(0, vocab_size)
        new_embedding = np.zeros([self.V, self.embd_dim], dtype=np.float32)

        distance = []
        reconstructed_ = []
        num_codes = 0
        for i in range(self._num_batches(vocab_size, batch_size)):
            batch = word_list[i * batch_size:(i + 1) * batch_size]
            words = torch.from_numpy(batch).long().to(device)
            reconstructed_vecs, original_vecs, activate_code = self.compress_apply(
                words, self.top_words_vecs)
            num_codes += activate_code.sum()
            reconstructed_.append(reconstructed_vecs)
            distance.append(np.linalg.norm(reconstructed_vecs - original_vecs, 2, -1))
        reconstruct = np.concatenate(reconstructed_, 0)
        distance = np.concatenate(distance, -1)
        new_embedding[sorted_index] = reconstruct

        # Top rows reconstructed well enough stay compressed; the rest (including
        # NaN distances) fall back to their original vectors.
        top_vecs = self.top_words_vecs.detach().cpu().numpy()
        compressed_top = distance[:top_reserve] < distance_threshold
        restore = ~compressed_top
        new_embedding[sorted_index[:top_reserve][restore]] = top_vecs[restore]
        cnt = int(compressed_top.sum())

        if self.export_path:
            # np.save writes bytes; the context manager closes/flushes the file.
            with open(self.export_path, "wb") as f:
                np.save(f, new_embedding)
        print("test: %f" % distance.mean())
        print("%d top words get compressed" % cnt)
        # NOTE(review): reads as a size estimate in bits (16 per stored value for
        # restored top rows / scales and codebooks, one mask bit per codebook for
        # the other rows, log2(K) per active code) against 32-bit originals — confirm.
        compression = ((top_reserve - cnt + self.num_pointer) * self.embd_dim * 16
                       + self.M * self.K * self.embd_dim * 16
                       + (self.V - top_reserve) * self.M
                       + num_codes * self.num_pointer * math.log(self.K, 2))
        print("")
        print("compression_rate: %s" % ((self.V * self.embd_dim * 32) / compression))
        

In [15]:
# Source embedding matrix to compress; later code treats it as (rows, dim).
matrix_ = './data/W1.npy'  # path to W1.npy
mat = np.load(matrix_)
# Fail fast here rather than deep inside the model constructor.
assert mat.ndim == 2, "expected a 2-D (rows x dim) embedding matrix, got shape %r" % (mat.shape,)
def remap(matrix, freq):
    """Reorder the rows of `matrix` by descending `freq`.

    Returns:
        (reordered_matrix, order, reordered_freq): `order` is the reversed
        argsort permutation (highest frequency first) and `reordered_freq`
        is a list of the `freq` entries taken in that order.
    """
    descending = np.argsort(freq)[::-1]
    reordered_matrix = matrix[descending]
    reordered_freq = [freq[pos] for pos in descending]
    return reordered_matrix, descending, reordered_freq
try:  # cPickle only exists on Python 2; fall back to the Python 3 module.
    import cPickle as pickle
except ImportError:
    import pickle

# Seed the RNGs so shuffling, weight init and Gumbel noise are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): if ./data/freqs is a pickled object, newer numpy needs allow_pickle=True.
freq = np.load("./data/freqs")  # path to frequency
# One frequency per embedding row, then sort rows by descending frequency.
freq = [freq[i] for i in range(len(mat))]
mat_, index, freq = remap(mat, freq)
# float32 so the frozen embedding matches the float32 linear layers/codebooks.
matrix = torch.from_numpy(np.ascontiguousarray(mat_, dtype=np.float32))

M = 16  # number of codebooks
K = 16  # vectors per codebook

model = MultiChannelCompressor(matrix, M, K, _TAU=1.0, lr=0.001,
                               export_path="./data/W1.shu.top_cut.npy",  # output to data/W1.shu.top_cut.npy
                               top_rate=0.05, num_pointers=3)

if torch.cuda.is_available():
    model = model.cuda()
model._train(freq, index, 64, num_epochs=200)


initializing using sigma 0.010000
initializing using sigma 0.010000
initializing using sigma 0.010000
initializing using sigma 0.010000
initializing using sigma 0.010000




epoch: 0 loss: 19.267584
test: 6.299299
0 top words get compressed

compression_rate: 25.1247653975031
best_loss: 19.267584
epoch: 1 loss: 18.197708
test: 6.232256
0 top words get compressed

compression_rate: 25.10642508985485
best_loss: 18.197708


KeyboardInterrupt: 