---
### SimCLR

##### 1. Overview

**SimCLR (Simple Contrastive Learning of Representations)** is a self-supervised learning framework that trains neural networks to learn meaningful representations **without labels** by solving a *contrastive prediction task*.

The key objective is:

> Learn embeddings in which two augmented views of the **same sample** are close together, while embeddings of **different samples** are far apart.

This is achieved using a contrastive loss known as **NT-Xent** (Normalized Temperature-scaled Cross-Entropy), which is a specific instantiation of the **InfoNCE** objective.



##### 2. Data Pipeline

For each data point $x \sim \mathcal{D}$:

1. Two stochastic augmentations are sampled:
   $$
   x_i = t_a(x), \qquad x_j = t_b(x).
   $$

2. Both views are passed through:
   - an **encoder** $ f_\theta $,
   - a **projection head** $ g_\phi $,

   yielding:
   $$
   h = f_\theta(x), \qquad
   z = g_\phi(h), \qquad
   \tilde z = \frac{z}{\|z\|_2}.
   $$

3. $\tilde z $ is the representation used for the contrastive loss.


##### 3. Positive and Negative Pairs

With a minibatch of size $ N $ original samples:

- We produce $2N $ augmented views.
- For each embedding $ \tilde z_i $:
  - the paired view $ \tilde z_j $ of the same data sample is the **positive example**,
  - the remaining $2N - 2 $ embeddings are treated as **negatives**.

No explicit negative sampling is required — negatives come from the batch.



##### 4. The NT-Xent (SimCLR) Loss

The similarity function is cosine similarity with temperature scaling:

$$
s(\tilde z_i, \tilde z_k)
=
\frac{\tilde z_i^\top \tilde z_k}{\tau},
$$

where $ \tau > 0 $ is the temperature hyperparameter.

For anchor  $i$ and its positive partner $ j $, the normalized temperature-scaled cross-entropy (**NT-Xent**) loss is:

$$
\ell(i,j)
=
-\log
\frac{
\exp\!\big(s(\tilde z_i,\tilde z_j)\big)
}{
\sum_{k=1}^{2N}
\mathbf{1}_{[k \neq i]}
\exp\!\big(s(\tilde z_i,\tilde z_k)\big)
}.
$$

**Properties**

- The **positive example is included in the denominator**.
- The denominator forms a softmax over all candidates except the trivial self-pair $ i $.
- Each embedding acts as an anchor once; the loss is **symmetrized**:

$$
\mathcal{L}_{\text{SimCLR}}
=
\frac{1}{2N}
\sum_{(i,j)}
\big( \ell(i,j) + \ell(j,i) \big).
$$

This is mathematically equivalent to performing a $2N-1$-way classification task:

Given anchor $ i $, predict which candidate  $k \neq i $ is its true positive.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dropout_edge


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

dataset = Planetoid(root="C:/Users/halac/Graph-SSL/data", name="Cora")
data = dataset[0].to(device)

num_nodes  = data.num_nodes
num_feats  = data.num_features
num_classes = dataset.num_classes


# Graph Augmentations (SimCLR views)


def random_edge_dropout(edge_index, p=0.2):
    edge_index, _ = dropout_edge(edge_index, p=p)
    return edge_index

def random_feature_mask(x, p=0.2):
    mask = torch.rand_like(x) > p
    return x * mask


def augment_graph(data, edge_p=0.2, feat_p=0.2):
    """
    Create a stochastic augmented view of the full graph.
    """
    x = random_feature_mask(data.x, p=feat_p)
    edge_index = random_edge_dropout(data.edge_index, p=edge_p)

    return x, edge_index



# GCN Encoder


class GCNEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim=128, out_dim=128):
        super().__init__()

        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = self.conv2(h, edge_index)
        return h



# Projection Head (SimCLR)


class ProjectionHead(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, h):
        return self.mlp(h)



# NT-Xent Contrastive Loss (SimCLR)


def nt_xent_loss(z1, z2, temperature=0.5):
    """
    Node-wise SimCLR loss.
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    N = z1.size(0)

    z = torch.cat([z1, z2], dim=0)     # (2N, d)
    sim = torch.matmul(z, z.T)         # cosine similarity

    # positive pairs along diagonal
    pos = torch.cat([
        torch.diag(sim, N),
        torch.diag(sim, -N)
    ], dim=0)

    neg_mask = ~torch.eye(2 * N, device=device, dtype=torch.bool)

    negatives = sim[neg_mask].view(2 * N, -1)

    logits = torch.cat([pos.unsqueeze(1), negatives], dim=1)
    logits /= temperature

    labels = torch.zeros(2 * N, dtype=torch.long, device=device)

    return F.cross_entropy(logits, labels)



# Model


encoder = GCNEncoder(num_feats).to(device)
projector = ProjectionHead(128).to(device)

optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(projector.parameters()),
    lr=1e-3,
    weight_decay=1e-4
)



# Training Loop (Graph SimCLR)

epochs = 300
temperature = 0.5

encoder.train()
projector.train()

for epoch in range(1, epochs + 1):

    # Two random augmented graph views
    x1, edge1 = augment_graph(data)
    x2, edge2 = augment_graph(data)

    # Encode
    h1 = encoder(x1, edge1)
    h2 = encoder(x2, edge2)

    # Project
    z1 = projector(h1)
    z2 = projector(h2)

    # Contrastive loss
    loss = nt_xent_loss(z1, z2, temperature)

    # Optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d} | Loss = {loss.item():.4f}")




  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
Epoch 020 | Loss = 7.1274
Epoch 040 | Loss = 6.9606
Epoch 060 | Loss = 6.8973
Epoch 080 | Loss = 6.8680
Epoch 100 | Loss = 6.8393
Epoch 120 | Loss = 6.8237
Epoch 140 | Loss = 6.8112
Epoch 160 | Loss = 6.7968
Epoch 180 | Loss = 6.7886
Epoch 200 | Loss = 6.7779
Epoch 220 | Loss = 6.7767
Epoch 240 | Loss = 6.7713
Epoch 260 | Loss = 6.7627
Epoch 280 | Loss = 6.7559
Epoch 300 | Loss = 6.7564


In [12]:
encoder.eval()

with torch.no_grad():
    embeddings = encoder(data.x, data.edge_index)

# Train linear classifier on labeled train mask
clf = nn.Linear(128, num_classes).to(device)
optimizer = torch.optim.Adam(clf.parameters(), lr=0.01)

for _ in range(200):
    clf.train()
    optimizer.zero_grad()

    out = clf(embeddings[data.train_mask])
    loss = F.cross_entropy(out, data.y[data.train_mask])

    loss.backward()
    optimizer.step()

# Evaluation
clf.eval()
pred = clf(embeddings).argmax(dim=1)

acc = (
    pred[data.test_mask] ==
    data.y[data.test_mask]
).float().mean().item()

print(f"✅ Test accuracy: {acc:.4f}")


✅ Test accuracy: 0.7910


---

## MixCo
SimCLR is extended by introducing **mixup** and **soft positives**.

#### Mixup

For anchor $z_i$:

1. Keep **true positive** $z_j$.
2. Sample a random partner $z_q$.
3. Sample a coefficient $\lambda \sim \text{Beta}(\alpha,\alpha)$.
4. Create a mixed representation:

$z_{\text{mix}} = \lambda z_j + (1-\lambda) z_q.$



#### Soft targets

The anchor now has **two positives instead of one**, with weights (label smoothing?):

$y_{ik} =
\begin{cases}
\lambda & k = j \\
1 - \lambda & k = q \\
0 & \text{otherwise}
\end{cases}$

All other embeddings remain negatives.



#### Soft NT-Xent loss

Using the same probabilities:

$p_{ik} = \frac{\exp(s(z_i,z_k))}{\sum_{m \neq i} \exp(s(z_i,z_m))},$

the MixCo loss for anchor $i$ becomes:

$L_i^{\text{MixCo}}
= -\sum_k y_{ik}\log p_{ik}.$



#### Key differences

| Method | Positives per anchor | Target type | Effect |
|--------|------------------------|----------------|---------|
| **SimCLR** | 1 (paired view) | Hard (0 or 1) | Strong pull to one positive, strong repulsion to all others |
| **MixCo** | 2 (paired view + random mix partner) | Soft weights via $\lambda$ | Reduced false-negative repulsion, smoother gradients, improved stability |



#### About $\lambda$

- $\lambda$ is **not fixed**.
- For each mixed pair and each forward pass:

$\lambda \sim \text{Beta}(\alpha,\alpha)$

- $\alpha$ is the **hyperparameter** controlling the distribution:
  - Larger $\alpha$ → $\lambda$ concentrates near $0.5$.
  - Smaller $\alpha$ → more extreme values (near 0 or 1).

In practice $\alpha \approx 1$ or greater is used so that:

- Very small $\lambda$ values are rare.
- The true positive usually retains significant weight.

Occasional small $\lambda$ acts as **benign regularization**, not destructive noise.

In [13]:

# MixCo Loss (Graph MixCo)


def mixco_loss(z1, z2, temperature=0.5, alpha=1.0):
    """
    Node-wise MixCo loss.
    z1, z2: node embeddings from two augmented views
    """

    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    N, d = z1.shape

    
    # choose random mix partner for each node
    
    perm = torch.randperm(N, device=device)

    z2_perm = z2[perm]

    
    # sample mixing coefficient lambda per node
    
    beta = torch.distributions.Beta(alpha, alpha)
    lam = beta.sample((N,)).to(device)              # (N,)
    lam = lam.view(N, 1)

    
    # create mixed positive representations
    
    z_mix = lam * z2 + (1. - lam) * z2_perm          # (N,d)

    
    # Step 4: similarity scores against all candidates
    
    z_all = torch.cat([z2, z2_perm], dim=0)          # (2N,d)

    sim = torch.matmul(z_mix, z_all.T) / temperature  # (N,2N)

    
    # construct soft targets
    # targets[i, j] = lambda_i
    # targets[i, perm[i]+N] = (1-lambda_i)
    
    targets = torch.zeros(N, 2*N, device=device)

    ij = torch.arange(N, device=device)

    # true positives (z2)
    targets[ij, ij] = lam.squeeze()

    # mixed positives (permuted z2)
    targets[ij, perm + N] = (1. - lam).squeeze()

    # soft cross-entropy
    log_probs = F.log_softmax(sim, dim=1)

    loss = -(targets * log_probs).sum(dim=1).mean()

    return loss


In [16]:
encoder = GCNEncoder(num_feats).to(device)
projector = ProjectionHead(128).to(device)

optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(projector.parameters()),
    lr=1e-3,
    weight_decay=1e-4
)



# Training Loop (Graph MixCo)

epochs = 300
temperature = 0.2

encoder.train()
projector.train()

for epoch in range(1, epochs + 1):

    # Two random augmented graph views
    x1, edge1 = augment_graph(data)
    x2, edge2 = augment_graph(data)

    # Encode
    h1 = encoder(x1, edge1)
    h2 = encoder(x2, edge2)

    # Project
    z1 = projector(h1)
    z2 = projector(h2)

    # Contrastive loss
    loss = mixco_loss(z1, z2, temperature, alpha = 2.0)

    # Optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d} | Loss = {loss.item():.4f}")




Epoch 020 | Loss = 7.5789
Epoch 040 | Loss = 7.3942
Epoch 060 | Loss = 7.3576
Epoch 080 | Loss = 7.3073
Epoch 100 | Loss = 7.2677
Epoch 120 | Loss = 7.3000
Epoch 140 | Loss = 7.2399
Epoch 160 | Loss = 7.3007
Epoch 180 | Loss = 7.2183
Epoch 200 | Loss = 7.2222
Epoch 220 | Loss = 7.2636
Epoch 240 | Loss = 7.2283
Epoch 260 | Loss = 7.2715
Epoch 280 | Loss = 7.2086
Epoch 300 | Loss = 7.2205


In [18]:
encoder.eval()

with torch.no_grad():
    embeddings = encoder(data.x, data.edge_index)

# Train linear classifier on labeled train mask
clf = nn.Linear(128, num_classes).to(device)
optimizer = torch.optim.Adam(clf.parameters(), lr=0.01)

for _ in range(200):
    clf.train()
    optimizer.zero_grad()

    out = clf(embeddings[data.train_mask])
    loss = F.cross_entropy(out, data.y[data.train_mask])

    loss.backward()
    optimizer.step()

# Evaluation
clf.eval()
pred = clf(embeddings).argmax(dim=1)

acc = (
    pred[data.test_mask] ==
    data.y[data.test_mask]
).float().mean().item()

print(f"✅ Test accuracy: {acc:.4f}")


✅ Test accuracy: 0.6930


### Problem Addressed by **MoCHi**

In standard contrastive learning methods such as SimCLR, each anchor embedding is compared against a large number of negative samples, and all negatives are treated equally in the loss. 

In practice, however, most of these negatives are **easy negatives**: they are already far away from the anchor in representation space and therefore contribute almost no gradient to the training objective. Only a small fraction of negatives are **hard negatives**; negatives actually teach the model how to separate similar-but-different samples and generate informative gradients that refine the decision boundary. Because hard negatives occur rarely by chance, contrastive learning becomes inefficient: it requires very **large** batch sizes or large memory queues to encounter enough useful negatives. 

**MoCHi addresses this inefficiency by actively generating hard negatives instead of relying on random sampling, ensuring that each training batch contains challenging and informative contrastive examples.**

### MoCHi: Algorithm and Updated Loss



#### Algorithm (per training step)

For each anchor embedding $z_i$:

1. **Compute similarities** to all negatives in the batch / memory bank:
   
   $s_{ik} = z_i^\top z_k.$

2. **Select hard negatives**:
   - Rank negatives by similarity.
   - Choose two from the highest-similarity set:
     
     $z_p, z_q \in \text{HardNeg}(i).$

3. **Mix hard negatives**:
   
   Sample $\lambda \sim \text{Beta}(\alpha,\alpha)$ and form

   $z_{\text{mix}} = \lambda z_p + (1-\lambda) z_q.$

4. **Add the mixed sample to the negative set** for anchor $z_i$.




#### Updated Contrastive Loss

The loss keeps the standard **InfoNCE / NT-Xent form**, but with an **augmented negative set**.

##### Similarity

$s_{ik} = \frac{z_i^\top z_k}{\tau}$



#### MoCHi loss for anchor $i$

$$
L_i^{\text{MoCHi}}
=
- \log
\frac{
\exp(s_{ij})
}{
\exp(s_{ij})
+
\sum_{k \in \mathcal N_i}
\exp(s_{ik})
+
\exp(s_{i,\text{mix}})
}.
$$

Where:

- $j$ is the true positive for anchor $i$,
- $\mathcal N_i$ is the set of normal negatives,
- $z_{\text{mix}}$ is the **synthetic hard negative**.



#### What changed vs SimCLR

| Element | SimCLR | MoCHi |
|--------|---------|-------|
| Positives | 1 hard positive | Same |
| Negatives | Random batch negatives | Random + **synthetic hard negatives** |
| Mixing | Not used | **Mix hard negatives** |
| Loss type | NT-Xent / InfoNCE | **Same formula, larger denominator** |
| Main benefit | Needs huge batches | **Creates hard negatives explicitly** |



In [4]:

# MoCHi Loss (Hard Negative Mixing for Graph SimCLR)


def mochi_loss(z1, z2, temperature=0.5, alpha=1.0):
    """
    Node-wise MoCHi loss.
    For each anchor z1[i], we:
        - use z2[i] as the positive
        - find the two hardest negatives among z2[j!=i]
        - mix them to create a synthetic negative
        - add it to the NT-Xent denominator
    """
    
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    N, d = z1.size()

    # similarity matrix between anchors and candidate negatives
    sim = torch.matmul(z1, z2.T)   # (N, N)

    beta = torch.distributions.Beta(alpha, alpha)

    total_loss = 0.0

    for i in range(N):

        # Positive
        pos_sim = sim[i, i]

        # remove true positive from negatives
        neg_sims = sim[i].clone()
        neg_sims[i] = -float("inf")

        # Select 2 hardest negatives
        hard_idx = torch.topk(neg_sims, k=2).indices

        z_p = z2[hard_idx[0]]
        z_q = z2[hard_idx[1]]

        # Mix the two negatives
        lam = beta.sample().to(device)

        z_mix = lam * z_p + (1.0 - lam) * z_q
        z_mix = F.normalize(z_mix, dim=0)

        mix_sim = torch.dot(z1[i], z_mix)

        # Build denominator
        denom = torch.exp(pos_sim / temperature)

        # original batch negatives
        denom = denom + torch.sum(
            torch.exp(neg_sims / temperature)
        )

        # add synthetic negative
        denom = denom + torch.exp(mix_sim / temperature)


        # InfoNCE loss
        total_loss += -torch.log(
            torch.exp(pos_sim / temperature) / denom
        )

    return total_loss / N


In [5]:
encoder = GCNEncoder(num_feats).to(device)
projector = ProjectionHead(128).to(device)

optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(projector.parameters()),
    lr=1e-3,
    weight_decay=1e-4
)



# Training Loop (Graph SimCLR)

epochs = 300
temperature = 0.5

encoder.train()
projector.train()

for epoch in range(1, epochs + 1):

    # Two random augmented graph views
    x1, edge1 = augment_graph(data)
    x2, edge2 = augment_graph(data)

    # Encode
    h1 = encoder(x1, edge1)
    h2 = encoder(x2, edge2)

    # Project
    z1 = projector(h1)
    z2 = projector(h2)

    # Contrastive loss
    loss = mochi_loss(z1, z2, temperature)

    # Optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d} | Loss = {loss.item():.4f}")




Epoch 020 | Loss = 6.4316
Epoch 040 | Loss = 6.2627
Epoch 060 | Loss = 6.2109
Epoch 080 | Loss = 6.1763
Epoch 100 | Loss = 6.1510
Epoch 120 | Loss = 6.1375
Epoch 140 | Loss = 6.1235
Epoch 160 | Loss = 6.1103
Epoch 180 | Loss = 6.1012


KeyboardInterrupt: 

In [6]:
encoder.eval()

with torch.no_grad():
    embeddings = encoder(data.x, data.edge_index)

# Train linear classifier on labeled train mask
clf = nn.Linear(128, num_classes).to(device)
optimizer = torch.optim.Adam(clf.parameters(), lr=0.01)

for _ in range(200):
    clf.train()
    optimizer.zero_grad()

    out = clf(embeddings[data.train_mask])
    loss = F.cross_entropy(out, data.y[data.train_mask])

    loss.backward()
    optimizer.step()

# Evaluation
clf.eval()
pred = clf(embeddings).argmax(dim=1)

acc = (
    pred[data.test_mask] ==
    data.y[data.test_mask]
).float().mean().item()

print(f"✅ Test accuracy: {acc:.4f}")


✅ Test accuracy: 0.7840
