# This notebook is a test bench for working on the core model

In [1]:
import torch
import numpy as np
from torch import nn

In [2]:
import numpy as np
import torch
from typing import Optional
from scipy.optimize import linear_sum_assignment


def cluster_accuracy(y_true, y_predicted, cluster_number: Optional[int] = None):
    """
    Calculate clustering accuracy after using the linear_sum_assignment function in SciPy to
    determine reassignments.
    :param y_true: true cluster numbers, a 0-indexed integer array or array-like
    :param y_predicted: predicted cluster numbers, a 0-indexed integer array or array-like
    :param cluster_number: number of clusters, if None then calculated from input
    :return: reassignment dictionary (predicted label -> true label), clustering accuracy
    :raises ValueError: if the inputs are empty or do not have the same shape
    """
    # Accept plain lists as well as ndarrays (the attribute-style .max()/.size need arrays).
    y_true = np.asarray(y_true)
    y_predicted = np.asarray(y_predicted)

    if y_true.shape != y_predicted.shape:
        raise ValueError(
            f"y_true and y_predicted must have the same shape, "
            f"got {y_true.shape} and {y_predicted.shape}"
        )
    if y_predicted.size == 0:
        raise ValueError("cannot compute clustering accuracy on empty input")

    if cluster_number is None:
        cluster_number = (
            int(max(y_predicted.max(), y_true.max())) + 1
        )  # assume labels are 0-indexed

    # count_matrix[p, t] = number of samples predicted as cluster p whose true cluster is t.
    # np.add.at accumulates correctly when the same (p, t) pair occurs more than once.
    count_matrix = np.zeros((cluster_number, cluster_number), dtype=np.int64)
    np.add.at(count_matrix, (y_predicted, y_true), 1)

    # linear_sum_assignment minimises cost, so turn the counts into a cost matrix.
    row_ind, col_ind = linear_sum_assignment(count_matrix.max() - count_matrix)
    reassignment = dict(zip(row_ind, col_ind))
    accuracy = count_matrix[row_ind, col_ind].sum() / y_predicted.size
    return reassignment, accuracy

In [3]:
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """Map-style dataset yielding ``(text, label)`` pairs."""

    def __init__(self, texts, labels):
        # Texts are kept as given; labels are converted once to a float tensor.
        self.texts = texts
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        # The dataset length follows the number of texts.
        return len(self.texts)

    def __getitem__(self, index):
        text = self.texts[index]
        label = self.labels[index]
        return text, label

In [4]:
def lp_distance(X, Y, p=1):
    """
    Row-wise Minkowski (L_p) distance between X and Y.

    The element-wise absolute differences are raised to the power p, summed
    along dim=1, and the p-th root of that sum is returned.
    """
    abs_diff = (X - Y).abs()
    summed_powers = abs_diff.pow(p).sum(dim=1)
    return summed_powers.pow(1 / p)

In [5]:
from transformers import PreTrainedTokenizer
from typing import *

def mask_tokens(inputs: torch.Tensor, tokenizer, mlm_probability: float = 0.15) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

    :param inputs: tensor of token ids; it is modified IN PLACE (pass a clone to keep the original)
    :param tokenizer: tokenizer providing mask_token, get_special_tokens_mask, pad_token_id,
        convert_tokens_to_ids and __len__
    :param mlm_probability: probability with which each non-special, non-padding token is
        selected for masking (defaults to 0.15)
    :return: (inputs, labels) where labels holds the original ids at selected positions and
        -100 everywhere else
    :raises ValueError: if the tokenizer has no mask token or mlm_probability is outside [0, 1]
    """

    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )
    if not 0.0 <= mlm_probability <= 1.0:
        raise ValueError(f"mlm_probability must be in [0, 1], got {mlm_probability}")

    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training, each with probability mlm_probability.
    # float() keeps the probability matrix floating point even if an int (0 or 1) is passed.
    probability_matrix = torch.full(labels.shape, float(mlm_probability))
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    # Special tokens are never selected for masking
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    # NOTE(review): reads the tokenizer's private _pad_token attribute — confirm it exists in the installed version
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    # (half of the 20% that were selected but not replaced by the mask token)
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels

In [6]:
def mean_sparsity(X):
    """
    Average, over the rows of the feature-matrix X, of the fraction of
    non-zero entries in each row (non-zero count divided by the column count).
    """
    n_columns = X.shape[1]
    nonzero_per_row = np.count_nonzero(X, axis=1)
    row_fractions = nonzero_per_row / n_columns
    return np.mean(row_fractions)

In [7]:
# Sanity check: each row has one non-zero entry out of three, so the result is 1/3
mean_sparsity(np.array([
    [0, 1, 0],
    [1, 0, 0]
]))

0.3333333333333333

In [8]:
from dataclasses import dataclass
from transformers.file_utils import ModelOutput

@dataclass
class ClusterOutput(ModelOutput):
    """Container for the clustering results returned by ``ClusterLM.forward``."""
    
    # Clustering loss computed for the batch
    loss: torch.FloatTensor = None
    # Index of the nearest centroid for each input sample
    predicted_labels: torch.IntTensor = None
    # Detached CPU copy of the embeddings that were clustered
    embeddings: torch.FloatTensor = None

In [9]:
class ClusterLM(nn.Module):
    """
    Joint masked-language-modelling / clustering model.

    Wraps a language model together with a set of trainable centroids. A
    forward pass optionally computes the masked-LM output on a masked copy of
    the batch and always computes a softmin-weighted distance loss between the
    first-token embeddings of the batch and the centroids.
    """

    def __init__(self,
                 initial_centroids: torch.Tensor,
                 lm_model,
                 tokenizer,
                 metric=lp_distance,
                 do_language_modeling=True,
                 device='cpu'
                 ):
        """
        :param initial_centroids: tensor of starting centroids; a float copy becomes
            the trainable ``centroids`` parameter
        :param lm_model: language model; must accept ``labels`` and expose ``base_model``
        :param tokenizer: tokenizer used to encode the raw texts
        :param metric: callable ``metric(embedding_row, centroids)`` returning the
            distance of one embedding to every centroid
        :param do_language_modeling: if False, the masked-LM branch is skipped and
            ``forward`` returns ``None`` as its first output
        :param device: device the module and its inputs are moved to
        """
        super().__init__()

        self.initial_centroids = initial_centroids

        self.add_module('lm_model', lm_model)
        self.register_parameter(
            'centroids',
            nn.Parameter(initial_centroids.clone().float(), requires_grad=True))

        self.tokenizer = tokenizer
        self.metric = metric
        self.do_language_modeling = do_language_modeling
        self.device = device

        self.to(self.device)

    def forward(self, texts, alpha=1.0):
        """
        :param texts: batch of raw texts, passed straight to the tokenizer
        :param alpha: factor applied to the (shifted) distances inside the softmin exponent
        :return: (lm_outputs, cluster_outputs) where lm_outputs is the language model's
            output (None when do_language_modeling is False) and cluster_outputs is a
            ClusterOutput with the clustering loss, the predicted labels and the embeddings
        """
        # Tokenize once; the LM branch and the clustering branch both start
        # from this same, unmasked encoding.
        encoded = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True)

        # Language Modeling Part:
        lm_outputs = None

        if self.do_language_modeling:
            # mask_tokens edits its argument in place, so give it a copy and keep
            # the original ids for the clustering branch. Masking happens before
            # anything is moved to self.device.
            masked_ids, true_ids = mask_tokens(encoded['input_ids'].clone(), self.tokenizer)

            lm_inputs = {
                name: (masked_ids if name == 'input_ids' else tensor).to(self.device)
                for name, tensor in encoded.items()
            }
            lm_outputs = self.lm_model(labels=true_ids.to(self.device), **lm_inputs)

        # Clustering Part:
        encoded = encoded.to(self.device)

        # 0. Obtain embeddings for each input (hidden state of the first token)
        input_embeddings = self.lm_model.base_model(**encoded).last_hidden_state[:, 0, :].float()

        # 1. Compute distances from each input embedding to each centroid
        distances = torch.stack([
            self.metric(embedding.unsqueeze(0), self.centroids)
            for embedding in input_embeddings
        ])  # => shape (n_samples, n_centroids)
        nearest_centroids = torch.argmin(distances.detach().cpu(), dim=1)
        distances = torch.transpose(distances, 0, 1)  # => shape (n_centroids, n_samples)

        # 2. Compute the parameterized softmin of the distances.
        # NOTE(review): the minimum and the normalisation below run along dim=1 of the
        # transposed matrix, i.e. over the samples of the batch for each centroid, not
        # over the centroids for each sample — confirm this is the intended formulation.
        # Subtracting the row minimum keeps the exponentials from underflowing together.
        min_distances = torch.min(distances, dim=1).values
        # Compute exponentials
        exponentials = torch.exp(- alpha * (distances - min_distances.unsqueeze(1)))
        # Compute softmin
        softmin = exponentials / torch.sum(exponentials, dim=1).unsqueeze(1)

        # 3. Weight the distance between each sample and each centroid
        weighted_distances = distances * softmin

        # 4. Sum the weighted distances per centroid, then average over centroids
        clustering_loss = weighted_distances.sum(dim=1).mean()

        # Create clustering output container
        cluster_outputs = ClusterOutput(
            loss=clustering_loss,
            predicted_labels=nearest_centroids.long(),
            embeddings=input_embeddings.cpu().detach()
        )

        return lm_outputs, cluster_outputs


In [10]:
# test params
N_EPOCHS = 5
# Prefer the first GPU, but fall back to the CPU so the notebook still runs without CUDA
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [11]:
import pandas as pd
# Load the dataset from CSV; it is expected to provide a 'texts' and a 'labels' column
df = pd.read_csv('datasets/ag_news_subset5.csv')

# Keep both columns as NumPy arrays for the dataset wrapper and the embedding pass
texts = df['texts'].to_numpy()
labels = df['labels'].to_numpy()

In [12]:
#from sklearn.model_selection import train_test_split

#texts, _, labels, _ = train_test_split(texts, labels, test_size=0.99, random_state=42)

In [13]:
# Wrap the arrays in a dataset and iterate it in fixed order, four samples per batch
data = TextDataset(texts, labels)
data_loader = DataLoader(dataset=data, batch_size=4, shuffle=False)

In [None]:
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

# return_dict is a model option, so it is passed to the model only; it has no meaning for the tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
lm_model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased', return_dict=True)
lm_model.to(DEVICE)

In [None]:
from tqdm import tqdm
from sklearn.utils import resample
# Embed every text with the base model, one text at a time, keeping the
# first-token hidden state as the text's embedding.
embeddings = []
# These embeddings only seed the centroids, so no gradients are needed;
# no_grad avoids building an autograd graph for every text.
with torch.no_grad():
    for text in tqdm(texts):
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        inputs = inputs.to(DEVICE)
        outputs = lm_model.base_model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :].flatten().cpu().detach().numpy()

        embeddings.append(cls_embedding)

embeddings = np.array(embeddings)

In [None]:
from sklearn.utils.extmath import row_norms

# Using KMeans++ initialization, one centroid per distinct label
n_clusters = np.unique(labels).shape[0]
squared_norms = row_norms(embeddings, squared=True)  #aka np.linalg.norm(embeddings, axis=1)**2
seed = np.random.RandomState(42)

try:
    # Public k-means++ seeding API of current scikit-learn releases
    from sklearn.cluster import kmeans_plusplus
except ImportError:
    # Older scikit-learn: only the private helper is available
    from sklearn.cluster.k_means_ import _k_init
    initial_centroids = _k_init(
        embeddings,
        n_clusters=n_clusters,
        x_squared_norms=squared_norms,
        random_state=seed)
else:
    # kmeans_plusplus returns (centers, indices); only the centers are needed
    initial_centroids, _ = kmeans_plusplus(
        embeddings,
        n_clusters=n_clusters,
        x_squared_norms=squared_norms,
        random_state=seed)

initial_centroids = torch.from_numpy(initial_centroids).to('cpu')
initial_centroids

In [None]:
# Build the joint model from the k-means++ centroids and the pretrained LM, placed on DEVICE
model = ClusterLM(initial_centroids=initial_centroids, lm_model=lm_model, tokenizer=tokenizer, device=DEVICE)
# Display all trainable parameters (LM weights plus the centroids) as the cell output
list(model.parameters())

In [None]:
from tqdm import tqdm
import numpy as np

# Weight of the clustering loss relative to the masked-LM loss in the combined objective
CLUSTER_LOSS_WEIGHT = 0.025

opt = torch.optim.AdamW(
    params=model.parameters(),
    lr=2e-5, #2e-5, 5e-7, 5e-10
    eps=1e-8
)

model.train()


# One stacked embedding matrix per epoch
emb_hist = []
for epoch in range(N_EPOCHS):
    batch_embs = []

    true_labels = []
    predicted_labels = []

    pbar = tqdm(data_loader)
    # Batch variables get their own names so the module-level `texts` / `labels`
    # arrays are not overwritten by the last batch.
    for batch_texts, batch_labels in pbar:

        lm_outputs, cluster_outputs = model(texts=batch_texts)

        batch_embs.append(cluster_outputs.embeddings.numpy())

        combined_loss = lm_outputs.loss + CLUSTER_LOSS_WEIGHT * cluster_outputs.loss

        opt.zero_grad()
        combined_loss.backward()
        opt.step()

        pbar.set_description(f'Epoch: {epoch} | LM Loss: {lm_outputs.loss.item()} | Cluster Loss: {cluster_outputs.loss.item()}')

        # Labels are stored as floats in the dataset; cast back to int for the accuracy computation
        true_labels.extend(batch_labels.numpy().astype('int'))
        predicted_labels.extend(cluster_outputs.predicted_labels.numpy().astype('int'))

    emb_hist.append(np.vstack(batch_embs))
    true_labels = np.array(true_labels)
    predicted_labels = np.array(predicted_labels)
    print(f'Epoch: {epoch} | Cluster acc: {cluster_accuracy(true_labels, predicted_labels)}')
    

In [None]:
from umap import UMAP
from sklearn.decomposition import PCA

# Project the embeddings collected during the last epoch down to two dimensions
reducer = UMAP(n_components=2)
Xr = reducer.fit_transform(emb_hist[-1])

In [None]:
import seaborn as sns

# Scatter the 2-D projection, coloured by predicted cluster ('C<i>' labels make the hue categorical)
sns.scatterplot(x=Xr[:,0], y=Xr[:,1], hue=[f'C{i}' for i in predicted_labels])

In [None]:
# Shape of the last epoch's stacked embedding matrix
emb_hist[-1].shape