In [1]:
from sklearn.cluster import MiniBatchKMeans
import numpy as np
# Twelve 2-D sample points used by both fitting strategies below.
X = np.array([
    [1, 2], [1, 4], [1, 0],
    [4, 2], [4, 0], [4, 4],
    [4, 5], [0, 1], [2, 2],
    [3, 2], [5, 5], [1, -1],
])

# Strategy 1: incremental fitting, feeding the estimator two batches of six rows.
kmeans = MiniBatchKMeans(n_clusters=2, random_state=0, batch_size=6)
kmeans = kmeans.partial_fit(X[:6])   # rows 0..5
kmeans = kmeans.partial_fit(X[6:])   # rows 6..11
kmeans.cluster_centers_


kmeans.predict([[0, 0], [4, 4]])

# Strategy 2: a single fit() call on the full array (rebinds `kmeans`).
kmeans = MiniBatchKMeans(
    n_clusters=2,
    random_state=0,
    batch_size=6,
    max_iter=10,
)
kmeans = kmeans.fit(X)
kmeans.cluster_centers_


kmeans.predict([[0, 0], [4, 4]])

array([1, 0], dtype=int32)

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Kmeans_Batch(nn.Module):
    """Nearest-codeword quantizer applied to sliding windows of a sequence.

    ``forward`` zero-pads dim 1 of the input, extracts one window of
    ``win_size`` consecutive frames per position along that dim, flattens each
    window into a single vector and replaces it with the closest row (squared
    Euclidean distance) of a learnable codebook.

    Args:
        commitment_cost: Stored on the module as ``_commitment_cost``. It is
            not used by ``forward``; the returned loss is a plain MSE.
        win_size: Number of consecutive frames per window (default 41).
        num_embeddings: Number of codebook rows (default 40).
        num_features: Size of the last input axis (default 12). The codebook
            row length is ``win_size * num_features``.
    """

    def __init__(self, commitment_cost=0.25, win_size=41, num_embeddings=40,
                 num_features=12):
        super().__init__()

        self.win_size = win_size

        self._embedding_dim = win_size * num_features
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantize every sliding window of ``inputs``.

        Args:
            inputs: Tensor laid out as [B, T, A]; windows are taken along
                dim 1. The size of the last axis is not validated here, so an
                ``A`` different from the configured ``num_features`` may still
                run but will regroup values across window boundaries.

        Returns:
            Tuple ``(loss_vq, quantized, encoding_indices)``:
              * ``loss_vq`` - scalar MSE between the quantized and the
                original windows.
              * ``quantized`` - [B, T, A * win_size] codebook rows chosen for
                each window.
              * ``encoding_indices`` - [B * T, 1] integer index of the chosen
                codebook row for each window.
        """
        # Split the win_size - 1 padding frames so the number of windows equals
        # T for any window size (odd sizes pad equally on both sides).
        pad_total = self.win_size - 1
        pad_before = pad_total // 2
        pad_after = pad_total - pad_before
        x_pad = F.pad(inputs, pad=(0, 0, pad_before, pad_after, 0, 0),
                      mode='constant', value=0)  # [B, T + win - 1, A]

        x_unfold = x_pad.unfold(1, self.win_size, 1)  # [B, T, A, win]
        windows = x_unfold.reshape(x_unfold.shape[0],
                                   x_unfold.shape[1],
                                   x_unfold.shape[2] * x_unfold.shape[3])  # [B, T, A*win]

        input_shape = windows.shape
        # reshape (not view): the window tensor is not guaranteed contiguous
        # for every window/feature size.
        flat_input = windows.reshape(-1, self._embedding_dim)  # [B*T, D]

        # Squared Euclidean distance via ||x||^2 + ||e||^2 - 2 x.e
        codebook = self._embedding.weight
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        # Index of the nearest codebook row for every window.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # Direct lookup of the selected rows; follows the codebook's dtype and
        # device instead of building a float32 one-hot matrix.
        quantized = self._embedding(encoding_indices.squeeze(1)).view(input_shape)

        loss_vq = F.mse_loss(quantized, windows)

        return loss_vq, quantized, encoding_indices
    
    
    

In [14]:
# Build the sliding-window quantizer with its default constructor arguments.
model = Kmeans_Batch()


In [17]:
# Run the quantizer; forward() returns (loss_vq, quantized, encoding_indices).
# NOTE(review): `x_unfold_reshape` is not assigned in any cell visible here — the
# only visible assignment is a local variable inside Kmeans_Batch.forward, which
# already performs the pad/unfold/reshape windowing itself. Confirm which tensor
# is meant to be passed; as written this cell depends on leftover kernel state.
loss_vq, quantized, encoding_indices = model(x_unfold_reshape)