In [1]:
import torch
from torch.utils.data import DataLoader,TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
sys.path.append("..")
from torch import nn, optim
from torch.nn import functional as F
from src.datasets.dataset_wrappers import AugRepresentationDataset
from IPython.display import display,clear_output
import tqdm
from sklearn.metrics import normalized_mutual_info_score as nmi
import umap

In [2]:
#import the data
org = torch.load("../originals.pt")
aug = torch.load("../augmentations.pt")
labels = torch.load("../labels.pt")
nn_indices = torch.load("../nn_indices.pt")
dataset = AugRepresentationDataset(aug,org,labels)


In [3]:
reducer = umap.UMAP(n_components=2,random_state=42,metric="cosine")
emb = reducer.fit_transform(org)


  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")


In [4]:
#get umap kmeans labels
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=10,random_state=42)
umap_labels = kmeans.fit_predict(emb)


In [None]:
class WarmStarter(nn.Module):
    def __init__(self):
        super(WarmStarter,self).__init__()
        self.block1 = nn.Sequential(
            nn.Linear(512,1024),
            nn.Dropout(0.1),
            nn.GELU(),
            nn.Linear(1024,1024),
            nn.Dropout(0.1),
            nn.GELU(),
            nn.Linear(1024,10)
        )

    def forward(self,w):
        return self.block1(w)
    

# Pretraining

In [9]:
print(nmi(umap_labels,labels))
def hungarian_matching(y_pred, y_true):
    """
    Perform Hungarian matching to align predicted labels with true labels.
    
    Args:
        y_pred: predicted cluster labels
        y_true: true labels
        
    Returns:
        y_pred_hungarianmatched: predicted labels after Hungarian matching
    """
    from scipy.optimize import linear_sum_assignment
    
    # Convert to numpy if tensors
    if hasattr(y_pred, 'cpu'):
        y_pred = y_pred.cpu().numpy()
    if hasattr(y_true, 'cpu'):
        y_true = y_true.cpu().numpy()
    
    # Get unique labels
    pred_labels = np.unique(y_pred)
    true_labels = np.unique(y_true)
    
    # Create confusion matrix
    confusion_matrix = np.zeros((len(pred_labels), len(true_labels)))
    
    for i, pred_label in enumerate(pred_labels):
        for j, true_label in enumerate(true_labels):
            confusion_matrix[i, j] = np.sum((y_pred == pred_label) & (y_true == true_label))
    
    # Apply Hungarian algorithm (minimize cost, so use negative of confusion matrix)
    row_indices, col_indices = linear_sum_assignment(-confusion_matrix)
    
    # Create mapping from predicted labels to true labels
    label_mapping = {}
    for i, j in zip(row_indices, col_indices):
        label_mapping[pred_labels[i]] = true_labels[j]
    
    # Apply mapping to predicted labels
    y_pred_hungarianmatched = np.array([label_mapping[label] for label in y_pred])
    
    return y_pred_hungarianmatched

umap_mapped = hungarian_matching(umap_labels,labels)
print((sum(umap_mapped == labels)/len(labels)))

0.9623271475908873
tensor(0.9857)


In [30]:
model = WarmStarter()
device = 'cuda'
model.to(device)
optimizer = optim.Adam(model.parameters(),lr=0.0003)
#pretraining batches
umap_dataset = AugRepresentationDataset(aug,org,umap_labels)
umap_dataloader = DataLoader(umap_dataset,batch_size=1000,shuffle=True)
criterion = nn.CrossEntropyLoss()
pretraining_loss = []
#run 10 epochs
progress_bar = tqdm.tqdm(range(25))
for epoch in progress_bar:
    temp_loss = []
    
    for (w1,w2,w,l) in umap_dataloader:
        p = model(w1.to(device))
        loss = criterion(p,l.to(device).long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        temp_loss.append(loss.item())
    
    pretraining_loss.append(np.mean(temp_loss))
    progress_bar.set_description(f"Pretraining Loss: {pretraining_loss[-1]:.4f}")


  0%|          | 0/25 [00:00<?, ?it/s]

Pretraining Loss: 1.4685: 100%|██████████| 25/25 [00:25<00:00,  1.04s/it]


In [34]:

umap_dataloader = DataLoader(umap_dataset,batch_size=1000,shuffle=False)
predictions = []
model.train()
for (w1,w2,w,l) in umap_dataloader:
    with torch.no_grad():
        p = model(w.to(device))
        
        predictions.append(p.argmax(dim=1).cpu().numpy())

predictions = np.concatenate(predictions)
print(nmi(predictions,labels))

0.9612386386305357


In [16]:

def iic_loss(p1,p2,lamb=1.0,eps=1e-12,scaling=True):
    """
    Args:
        p1 (batch_size, num_clusters): logits 1
        p2 (batch_size, num_clusters): logits 2
    """
    p1,p2 = F.softmax(p1,dim=1),F.softmax(p2,dim=1)
    
    #joint probability dist
    #P = (p1.unsqueeze(2) * p2.unsqueeze(1)).mean(dim=0)
    P = 1/p1.shape[0] * p1.T @ p2

    #marginal probabilities
    P1 = p1.mean(dim=0).unsqueeze(1)
    P2 = p2.mean(dim=0).unsqueeze(0)

    #compute the mutual information
    MI = (P * (torch.log(P + eps) - lamb * torch.log((P1+eps)*(P2+eps)))).sum()
    if scaling:
        k = p1.shape[1]
        MI -= np.log(k)*(2*lamb-2)
        MI /= (np.log(k))

    return -MI



def distribution_compression(x1,x2,p1,p2,tau=0.05,lamb=1.0,eps=1e-4):

    Px = torch.exp((x1 @ x2.t()-1)/tau)
    Qx = 1 - Px
    Py = p1 @ p2.t()
    Qy = 1 - Py

    PlogP = -(Py*torch.log(Px+eps)).mean()
    QlogQ = -(Qy*torch.log(Qx+eps)).mean()
    
    return PlogP + lamb * QlogQ

In [27]:
iic_losses = [] 
progress_bar = tqdm.tqdm(range(50))
umap_dataloader = DataLoader(umap_dataset,batch_size=1000,shuffle=True)
for epoch in progress_bar:
    temp_loss = []
    for w1,w2,w,l in umap_dataloader:
        p1 = model(w1.to(device))
        p2 = model(w2.to(device))
        loss = iic_loss(p1,p2,lamb=1.0)
        print(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        temp_loss.append(loss.item())
    iic_losses.append(np.mean(temp_loss))
    progress_bar.set_description(f"IIC Loss: {iic_losses[-1]:.4f}")
    

  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0'

IIC Loss: -0.0009:   2%|▏         | 1/50 [00:01<01:13,  1.51s/it]

tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0'

IIC Loss: -0.0009:   4%|▍         | 2/50 [00:02<01:05,  1.37s/it]

tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)
tensor(-0.0009, device='cuda:0', grad_fn=<NegBackward0>)





KeyboardInterrupt: 

In [28]:
p1[0]

tensor([7.0085e-20, 4.1113e-19, 1.7329e-17, 8.4034e-27, 5.8317e-24, 1.0000e+00,
        2.0896e-16, 5.0117e-25, 8.5955e-24, 1.8630e-22], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [29]:

umap_dataloader = DataLoader(umap_dataset,batch_size=1000,shuffle=False)
predictions = []
for (w1,w2,w,l) in umap_dataloader:
    with torch.no_grad():
        p = model(w.to(device))
        predictions.append(p.argmax(dim=1).cpu().numpy())


predictions = np.concatenate(predictions)
print(nmi(predictions,labels))

0.9558750264671436
