# Contrastive Learning

### Dependencies

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
from dataclasses import dataclass

from sklearn.manifold import TSNE
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.datasets import make_classification, fetch_20newsgroups

import torch
from torch import nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

### Data 

In [None]:
# Single seed reused as the default random_state of both config classes below.
SEED = 42

@dataclass
class DatasetConfig:
    """Keyword arguments for the synthetic classification dataset.

    The instance dict is unpacked directly into ``make_classification``,
    so every field name has to be a valid keyword of that function.
    """
    n_samples: int = 5000
    n_features: int = 32
    n_classes: int = 8
    n_clusters_per_class: int = 2
    n_informative: int = 5
    random_state: int = SEED

@dataclass
class SplitConfig:
    """Keyword arguments for the train/test split.

    The instance dict is unpacked directly into ``train_test_split``,
    so every field name has to be a valid keyword of that function.
    """
    random_state: int = SEED
    test_size: float = 0.25

In [None]:
# Generate the synthetic dataset from the config defaults.
X, y = make_classification(**vars(DatasetConfig()))

# Hold out a test part; it stays as NumPy arrays.
X_train, X_test, y_train, y_test = train_test_split(X, y, **vars(SplitConfig()))

# The training part is converted to float32 tensors for the torch code below.
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float()

### Visualization

In [None]:

def plot_tsne(data, labels, **kwargs):
    """Project ``data`` to 2D with t-SNE and draw a scatter plot coloured by label.

    Args:
        data: 2D array-like or ``torch.Tensor`` of features, one row per sample.
        labels: per-sample labels (``torch.Tensor``, NumPy array or any
            sequence accepted by ``np.asarray``).
        **kwargs: extra keyword arguments forwarded to ``TSNE``. If
            ``random_state`` is not given it defaults to 42; ``n_components``
            is fixed to 2 and must not be passed.
    """
    # Let the caller override the seed instead of failing with a duplicate
    # keyword argument.
    kwargs.setdefault("random_state", 42)
    tsne = TSNE(n_components=2, **kwargs)

    if isinstance(data, torch.Tensor):
        data = data.detach().to("cpu").numpy()
    decomposed_data = tsne.fit_transform(data)

    if isinstance(labels, torch.Tensor):
        labels = labels.detach().to("cpu").numpy()
    # Plain sequences (e.g. lists) would make `labels == cls` a scalar
    # comparison, so always work on an ndarray.
    labels = np.asarray(labels)

    plt.figure(figsize=(8, 6))
    for cls in np.unique(labels):
        idx = labels == cls
        plt.scatter(decomposed_data[idx, 0], decomposed_data[idx, 1], label=str(cls.item()), alpha=0.7)
    plt.legend(title="Label")
    plt.title("t-SNE visualization")
    plt.xlabel("Dim 1")
    plt.ylabel("Dim 2")
    plt.tight_layout()
    plt.show()

# t-SNE picture of the raw training features.
plot_tsne(data=X_train, labels=y_train)

### Base Classifier

In [None]:
# X_train, y_train = X_train.numpy(), y_train.numpy()
# X_test, y_test = X_test.numpy(), y_test.numpy()

In [None]:
# Baseline: 5-NN fitted directly on the raw features.
knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1).fit(X_train, y_train)

train_pred, test_pred = knn.predict(X_train), knn.predict(X_test)

# (train accuracy, test accuracy) -- shown as the cell output.
(
    accuracy_score(y_train, train_pred),
    accuracy_score(y_test, test_pred),
)

###  TripletLoss

В случае L2-нормализованных векторов:

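$$\max\left(0,\|f(x)-f(x^+)\|^2_2-\|f(x)-f(x^-)\|^2_2+\varepsilon\right)=\max\left(0,2\left(f(x)f^T(x^-)-f(x)f^T(x^+)\right)+\varepsilon\right)$$

Здесь использовано, что для единичных векторов $\|a-b\|^2_2=2-2ab^T$. Множитель $2$ можно убрать делением всего выражения на $2$ и переобозначением $\varepsilon/2\to\varepsilon$, поэтому далее используется форма $\max\left(0,f(x)f^T(x^-)-f(x)f^T(x^+)+\varepsilon\right)$.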


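Пусть $D=\{x_i, y_i\}_i$ -- выборка классификации, $X$ -- матрица, строки которой являются L2-нормализованными эмбеддингами $f(x_i)$, и $S=XX^T$ -- матрица их попарных скалярных произведений.  
Позитивами для объекта $i$ считаются такие $j\neq i$, что $y_i=y_j$ (поэтому диагональ $S$ вырезается). При этом по матрице $S$ можно сформировать два непересекающихся множества: позитивов $P$ и негативов $N$ (оставшихся пар, где $y_i\neq y_j$).  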
Пусть $L$ -- минимальная мощность этих двух множеств. Возьмем $\hat{P}=\{p_i\}_i$, $\hat{N}=\{n_i\}_i$ как сэмплы размера $L$ из $P$ и $N$ соответственно. Тогда итоговая функция ошибки выглядит так:

$$\mathcal{L}=\frac{1}{L}\sum\limits_{i=1}^L\max\left(0, n_i-p_i+\varepsilon\right)$$

В такой постановке в паре позитивов и негативов не обязательно должен быть один и тот же якорный элемент.

In [None]:
class TripletLoss:
    """Margin loss over sampled positive / negative cosine similarities.

    For a batch of embeddings the pairwise similarity matrix ``S`` of the
    L2-normalised rows is built. Off-diagonal pairs with equal labels are
    positives, pairs with different labels are negatives. ``L`` similarities
    are sampled without replacement from each set, where ``L`` is the size of
    the smaller set, and the loss is ``mean(max(0, n_i - p_i + margin))``.
    A positive and the negative it is paired with do not have to share an
    anchor.

    Args:
        margin: margin added to ``n_i - p_i`` before clamping at zero.
        random_state: seed for the pair sampler (``None`` for OS entropy).
    """

    def __init__(self, margin, random_state=None):
        self.margin = margin
        self.random_state = random_state
        # One generator for the lifetime of the object: runs stay reproducible
        # for a fixed seed, but successive calls (and the positive / negative
        # draws inside one call) no longer replay the same random stream.
        self._rng = np.random.default_rng(random_state)

    def __call__(self, x, labels):
        """Return the scalar loss for embeddings ``x`` and their ``labels``.

        Args:
            x: 2D tensor, one embedding per row.
            labels: tensor with one label per row of ``x`` (flattened here).

        Returns:
            A 0-dim tensor on the device of ``x``. When the batch has no
            positive or no negative pairs the value is 0 but still attached
            to the autograd graph of ``x``, so ``backward()`` remains valid.
        """
        x = F.normalize(input=x, dim=1, p=2)

        n = x.size(0)
        S = x @ x.T

        # The masks and index tensors are built on the CPU.
        labels = labels.view(-1).to("cpu")
        same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
        mask_pos = same_label & ~torch.eye(n, dtype=torch.bool)
        mask_neg = ~same_label

        pos_indices = mask_pos.nonzero(as_tuple=False)
        neg_indices = mask_neg.nonzero(as_tuple=False)

        L = min(pos_indices.size(0), neg_indices.size(0))
        if L == 0:
            # A fresh constant tensor would have no grad_fn and make
            # loss.backward() raise; keep the zero tied to the graph instead.
            return S.sum() * 0.0

        pos_perm = torch.as_tensor(self._rng.choice(pos_indices.size(0), size=L, replace=False))
        neg_perm = torch.as_tensor(self._rng.choice(neg_indices.size(0), size=L, replace=False))
        pos_samples = pos_indices[pos_perm]
        neg_samples = neg_indices[neg_perm]

        p_vals = S[pos_samples[:, 0], pos_samples[:, 1]]
        n_vals = S[neg_samples[:, 0], neg_samples[:, 1]]

        return (n_vals - p_vals + self.margin).clamp(min=0).mean()

In [None]:
# Stand-alone loss instance with a fixed sampling seed.
criterion = TripletLoss(margin=0.2, random_state=101)

### Simple Model

In [None]:
class MLP(nn.Module):
    """Feed-forward network: ``num_layers`` x (Linear -> ReLU [-> Dropout]) + Linear head.

    Args:
        in_dim: size of the input features.
        out_dim: size of the output produced by the final linear layer.
        hidden_dim: width of every hidden layer.
        num_layers: number of hidden blocks; ``0`` yields a single linear map
            from ``in_dim`` to ``out_dim``.
        dropout: dropout probability applied after each ReLU, or ``None`` to
            omit dropout layers.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=512,
        num_layers=2,
        dropout=None
    ):
        super().__init__()
        layers = []
        last_dim = in_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(last_dim, hidden_dim))
            layers.append(nn.ReLU())
            if dropout is not None:
                layers.append(nn.Dropout(dropout))
            last_dim = hidden_dim
        self.emb = nn.Sequential(*layers)
        # Use the width that actually leaves `emb`: with num_layers == 0 that
        # is in_dim, not hidden_dim.
        self.classifier = nn.Linear(last_dim, out_dim)

    def forward(self, x):
        """Return the output of the linear head applied to the hidden features."""
        feats = self.emb(x)
        logits = self.classifier(feats)

        return logits

In [None]:
def predict(model, sample, batch_size, device):
    """Run ``model`` over ``sample`` in mini-batches and return the outputs on the CPU.

    Args:
        model: module to evaluate; it is switched to eval mode and left there.
        sample: ``torch.Tensor`` or NumPy array with one row per sample
            (arrays are converted to float32 tensors).
        batch_size: number of rows sent through the model at once.
        device: device each batch is moved to before the forward pass.

    Returns:
        A CPU tensor with the concatenated model outputs, in input order.
    """
    if isinstance(sample, np.ndarray):
        sample = torch.from_numpy(sample).float()
    sample = sample.to('cpu')
    n = sample.size(0)
    preds = []

    model.eval()
    with torch.no_grad():
        # Honour the caller's batch size (it used to be ignored in favour of
        # a hard-coded 64).
        for i in range(0, n, batch_size):
            X_batch = sample[i:i + batch_size].to(device)
            embs = model(X_batch).detach().to("cpu")
            preds.append(embs)
    return torch.cat(preds)

#### Functions

In [None]:
def domain_adaptation(
    model, X, y, X_test, y_test,
    epochs=10,
    batch_size=64,
    lr=1e-3,
    margin=0.2,
    triplet_weight=0.2,
    device='cpu',
    random_state=None,
):
    """Train ``model`` on ``(X, y)`` with ``TripletLoss`` and return it.

    Each epoch visits the training rows once in a freshly shuffled order,
    split into mini-batches; the model outputs of a batch are passed to the
    triplet loss together with the batch labels, and Adam updates the model.

    Args:
        model: module mapping feature rows to embeddings.
        X: training features (tensor, or NumPy array converted to float32).
        y: training labels (tensor or NumPy array); cast to ``long`` per batch.
        X_test, y_test: accepted for interface compatibility; not used.
        epochs: number of passes over the training data.
        batch_size: rows per optimisation step.
        lr: Adam learning rate.
        margin: margin of the triplet loss.
        triplet_weight: accepted for interface compatibility; not used.
        device: device the model and training data are moved to.
        random_state: seed for batch shuffling and for the loss sampler.

    Returns:
        The trained model (on ``device``).
    """
    if isinstance(X, np.ndarray):
        X = torch.from_numpy(X).float()
    if isinstance(y, np.ndarray):
        y = torch.from_numpy(y)

    model = model.to(device)
    X = X.to(device)
    y = y.to(device)

    criterion_triplet = TripletLoss(margin, random_state)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_size = X.size(0)

    rng = np.random.default_rng(random_state)
    progress = tqdm(range(epochs), total=epochs, desc='Обучение')
    for _ in progress:
        model.train()
        training_indices = rng.choice(train_size, size=train_size, replace=False)
        losses_through_epoch = []
        for i in range(0, train_size, batch_size):
            batch_indices = training_indices[i:i + batch_size]
            X_batch = X[batch_indices]
            y_batch = y[batch_indices].long()

            optimizer.zero_grad()
            logits = model(X_batch)

            loss_triplet = criterion_triplet(logits, y_batch)
            # A batch without positive or negative pairs (e.g. a trailing
            # batch of one row) can yield a constant loss with no graph;
            # calling backward() on it would raise, so skip the step.
            if not loss_triplet.requires_grad:
                continue
            loss_triplet.backward()

            optimizer.step()
            losses_through_epoch.append(loss_triplet.item())

        # Surface the mean epoch loss on the progress bar.
        if losses_through_epoch:
            progress.set_postfix(loss=float(np.mean(losses_through_epoch)))

    return model

#### Training

In [None]:
# Make sure the training data are float tensors; tensors are left untouched.
X_train = X_train if isinstance(X_train, torch.Tensor) else torch.from_numpy(X_train).float()
y_train = y_train if isinstance(y_train, torch.Tensor) else torch.from_numpy(y_train).float()

In [None]:
# Use the first GPU when one is visible, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'

# Two hidden layers of width 128 with dropout, 64-dimensional output.
mlp = MLP(
    in_dim=X_train.shape[1],
    out_dim=64,
    hidden_dim=128,
    num_layers=2,
    dropout=.2,
)
mlp = domain_adaptation(
    model=mlp,
    X=X_train,
    y=y_train,
    X_test=X_test,
    y_test=y_test,
    epochs=40,
    batch_size=128,
    lr=1e-3,
    device=device,
    random_state=101,
)

#### New visualization

In [None]:
# Outputs of the trained MLP for both parts of the synthetic dataset.
train_scored = predict(model=mlp, sample=X_train, batch_size=64, device=device)
test_scored = predict(model=mlp, sample=X_test, batch_size=64, device=device)

# t-SNE picture of the training part in the learned space.
plot_tsne(data=train_scored, labels=y_train)

In [None]:
# t-SNE picture of the test part in the learned space.
plot_tsne(data=test_scored, labels=y_test)

In [None]:
# Same 5-NN classifier as the baseline, now fitted on the MLP outputs.
knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1).fit(train_scored, y_train)

train_pred, test_pred = knn.predict(train_scored), knn.predict(test_scored)

# (train accuracy, test accuracy) -- shown as the cell output.
(
    accuracy_score(y_train, train_pred),
    accuracy_score(y_test, test_pred),
)

### Real dataset

In [None]:
# Six 20 Newsgroups categories: three "sci.*" and three "comp.*" groups.
categories = [
    "sci.space",
    "sci.med",
    "sci.electronics",
    "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware",
    "comp.sys.mac.hardware",
]

# Fetch the predefined train and test subsets restricted to those categories.
newsgroups_train, newsgroups_test = (
    fetch_20newsgroups(subset=subset, categories=categories)
    for subset in ("train", "test")
)

# From here on X_* hold the raw documents and y_* the integer targets.
X_train, y_train = newsgroups_train.data, newsgroups_train.target
X_test, y_test = newsgroups_test.data, newsgroups_test.target

In [None]:
def test_logreg(X_train_mapped, y_train, X_test_mapped, y_test, target_names=newsgroups_test.target_names):
    """Fit logistic regression on the train features and print test metrics.

    Prints the test accuracy and the per-class classification report;
    nothing is returned. ``target_names`` defaults to the class names of the
    test subset as they were when this function was defined.
    """
    model = LogisticRegression(max_iter=10000).fit(X_train_mapped, y_train)
    predictions = model.predict(X_test_mapped)

    accuracy = accuracy_score(y_test, predictions)
    summary = classification_report(y_test, predictions, target_names=target_names)

    print(f"Accuracy: {accuracy:.3f}")
    print(f"Classification Report: {summary}")

#### Real model

In [None]:
# Load the pretrained sentence encoder on the GPU when one is available.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
st_model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=device
)

st_model.eval()

# Encode the training documents in chunks of 64 and stack the results into
# one tensor.
# NOTE(review): encode() is called with device='cpu' although the model was
# created on `device` -- confirm that CPU encoding is intended.
preds = []
with torch.no_grad():
    for i in tqdm(range(0, len(X_train), 64)):
        X_batch = X_train[i:i+64]
        embs = st_model.encode(X_batch, device='cpu') # Should have used return_tensors="pt" (NOTE(review): in sentence-transformers this is presumably `convert_to_tensor=True` -- confirm)
        preds.append(embs)
        torch.cuda.empty_cache()
# torch.from_numpy is applied to every chunk, i.e. encode() is expected to
# return NumPy arrays here.
train_embeds = torch.cat([torch.from_numpy(pred) for pred in preds])

# Same procedure for the test documents.
preds = []
with torch.no_grad():
    for i in tqdm(range(0, len(X_test), 64)):
        X_batch = X_test[i:i+64]
        embs = st_model.encode(X_batch, device='cpu')
        preds.append(embs)
        torch.cuda.empty_cache()
test_embeds = torch.cat([torch.from_numpy(pred) for pred in preds])

In [None]:
# t-SNE picture of the raw sentence embeddings (train subset).
plot_tsne(data=train_embeds, labels=y_train)

In [None]:
# t-SNE picture of the raw sentence embeddings (test subset).
plot_tsne(data=test_embeds, labels=y_test)

In [None]:
# 5-NN on the raw sentence embeddings, as a reference for the adapted ones.
knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1).fit(train_embeds, y_train)

train_pred, test_pred = knn.predict(train_embeds), knn.predict(test_embeds)

# (train accuracy, test accuracy) -- shown as the cell output.
(
    accuracy_score(y_train, train_pred),
    accuracy_score(y_test, test_pred),
)

#### Embeddings adaptation

In [None]:
# torch.as_tensor accepts NumPy arrays as well as tensors, so re-running this
# cell after the labels are already tensors no longer raises a TypeError.
y_train = torch.as_tensor(y_train)
y_test = torch.as_tensor(y_test)

In [None]:
# Same architecture as in the synthetic experiment, sized to the sentence
# embedding width, trained on the precomputed embeddings.
mlp = MLP(
    in_dim=train_embeds.shape[1],
    out_dim=64,
    hidden_dim=128,
    num_layers=2,
    dropout=.2,
)
mlp = domain_adaptation(
    model=mlp,
    X=train_embeds,
    y=y_train,
    X_test=test_embeds,
    y_test=y_test,
    epochs=40,
    batch_size=128,
    lr=1e-3,
    device=device,
    random_state=101,
)

#### New experiment metrics

In [None]:
# Outputs of the trained MLP for the train and test sentence embeddings.
train_scored = predict(model=mlp, sample=train_embeds, batch_size=64, device=device)
test_scored = predict(model=mlp, sample=test_embeds, batch_size=64, device=device)

# t-SNE picture of the adapted training embeddings.
plot_tsne(data=train_scored, labels=y_train)

In [None]:
# t-SNE picture of the adapted test embeddings.
plot_tsne(data=test_scored, labels=y_test)

In [None]:
# 5-NN on the adapted embeddings, to compare with the raw-embedding result.
knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1).fit(train_scored, y_train)

train_pred, test_pred = knn.predict(train_scored), knn.predict(test_scored)

# (train accuracy, test accuracy) -- shown as the cell output.
(
    accuracy_score(y_train, train_pred),
    accuracy_score(y_test, test_pred),
)