# Advanced SSL Techniques - FixMatch, FlexMatch, and MixMatch

Welcome to the advanced chapter of our SSL journey! We've explored pseudo-labeling and consistency regularization earlier. Now, let's dive into cutting-edge techniques: **FixMatch**, **FlexMatch**, and **MixMatch**. These methods combine the best of pseudo-labeling and consistency to tackle datasets with limited labels, like `DermaMNIST`.

> Think of this as upgrading your SSL toolkit with turbocharged algorithms!

**Core Principles:**
- **FixMatch**: Uses weak and strong augmentations with a confidence threshold for pseudo-labels.
- **FlexMatch**: Enhances FixMatch with dynamic thresholding per class, ideal for imbalanced data.
- **MixMatch**: Adds data mixing (e.g., MixUp) to improve robustness by blending labeled and unlabeled samples.

**Objectives:**
1. Return to `DermaMNIST` classification with 100 labeled images.
2. Implement FixMatch, FlexMatch, and MixMatch.
3. Compare results to baseline methods to showcase SSL advancements.

## 1. Preparation (The Usual Setup)

Let’s set up our environment for `DermaMNIST` classification. We’ll use 100 labeled images and leverage Albumentations for controlled augmentations.

In [13]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import medmnist
from medmnist import INFO, Evaluator
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay, roc_auc_score

In [2]:
# Load DermaMNIST data
data_flag = 'dermamnist'
info = INFO[data_flag]
n_classes = len(info['label'])
DataClass = getattr(medmnist, info['python_class'])

train_dataset = DataClass(split='train', download=True)
test_dataset = DataClass(split='test', transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])]), download=True)

# Split into labeled (100) and unlabeled sets
all_indices = list(range(len(train_dataset)))
labels_array = np.array(train_dataset.labels).flatten()
labeled_indices, unlabeled_indices = train_test_split(all_indices, train_size=500, random_state=42, stratify=labels_array)

print(f'Labeled data: {len(labeled_indices)}, Unlabeled data: {len(unlabeled_indices)}')

Labeled data: 500, Unlabeled data: 6507


### 🧪 Weak and Strong Augmentations

We need two augmentation pipelines: weak for pseudo-label generation and strong for training robustness.

In [3]:

# Define weak and strong augmentations for single-channel images
weak_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ToTensorV2(transpose_mask=True)  # Preserve 1 channel
])

strong_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.GaussianBlur(p=0.3),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ToTensorV2(transpose_mask=True)  # Preserve 1 channel
])
print("Transforms initialized")

# Custom datasets for FixMatch
class FixMatchDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, indices, transform):
        self.dataset = Subset(dataset, indices)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        img = np.array(img)  # Ensure img is [H, W] (single-channel)
        transformed = self.transform(image=img)
        return transformed['image'], torch.tensor(label).long()

class FixMatchUnlabeledDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, indices, weak_transform, strong_transform):
        self.dataset = Subset(dataset, indices)
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        img = np.array(img)  # Ensure img is [H, W] (single-channel)
        weak = self.weak_transform(image=img)['image']
        strong = self.strong_transform(image=img)['image']
        return weak, strong

Transforms initialized


  original_init(self, **validated_kwargs)


## 2. Models and Training Loops

We’ll use a simple CNN and implement three training loops: FixMatch, FlexMatch, and MixMatch.

In [4]:

# Define the SimpleCNN model for single-channel input
class SimpleCNN(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(SimpleCNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2))
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        return self.fc(out)

# Initialize model, optimizer, and loss functions
model = SimpleCNN(in_channels=3, num_classes=n_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

supervised_criterion = nn.CrossEntropyLoss()
unsupervised_criterion = nn.CrossEntropyLoss(reduction='none')

### ⚙️ 2.1 FixMatch Training Loop

Let’s implement the FixMatch algorithm step-by-step.

**Instructions:**
1. Compute supervised loss on labeled data.
2. Generate pseudo-labels: Predict on weak augmentations, compute probabilities, and create a mask for confident predictions (threshold = 0.95).
3. Compute unsupervised loss: Predict on strong augmentations and apply the mask to confident pseudo-labels.
4. Combine losses and backpropagate.

In [5]:

# Create DataLoaders
labeled_dataset = FixMatchDataset(train_dataset, labeled_indices, strong_transform)
unlabeled_dataset = FixMatchUnlabeledDataset(train_dataset, unlabeled_indices, weak_transform, strong_transform)
print(f"Datasets created: labeled={len(labeled_dataset)}, unlabeled={len(unlabeled_dataset)}")

labeled_loader = DataLoader(labeled_dataset, batch_size=16, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=64, shuffle=True)
print(f"Dataloaders ready: batches labeled={len(labeled_loader)}, unlabeled={len(unlabeled_loader)}")

print("Starting training loop")
# FixMatch training as a function (to unify with other methods)
def train_fixmatch(model, labeled_loader, unlabeled_loader, epochs=30, threshold=0.95, unsupervised_weight=1.0):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    sup_crit = nn.CrossEntropyLoss()
    unsup_crit = nn.CrossEntropyLoss(reduction='none')
    for epoch in tqdm(range(epochs), desc='Training FixMatch'):
        model.train()
        batch_iterator = zip(labeled_loader, unlabeled_loader)
        for (labeled_imgs, labels), (weak_unlabeled, strong_unlabeled) in batch_iterator:
            optimizer.zero_grad()
            # Supervised loss
            logits_sup = model(labeled_imgs)
            loss_sup = sup_crit(logits_sup, labels.squeeze())
            # Pseudo-labels from weak
            with torch.no_grad():
                logits_weak = model(weak_unlabeled)
                probs = F.softmax(logits_weak, dim=1)
                max_probs, pseudo_labels = torch.max(probs, dim=1)
                mask = max_probs.ge(threshold).float()
            # Unsupervised on strong
            logits_strong = model(strong_unlabeled)
            loss_unsup_raw = unsup_crit(logits_strong, pseudo_labels)
            loss_unsup = (loss_unsup_raw * mask).mean()
            # Total
            total_loss = loss_sup + unsupervised_weight * loss_unsup
            total_loss.backward()
            optimizer.step()
    return model

Datasets created: labeled=500, unlabeled=6507
Dataloaders ready: batches labeled=32, unlabeled=102
Starting training loop


In [6]:

EPOCHS = 50
THRESHOLD = 0.95
UNSUPERVISED_WEIGHT = 1.0

print("Starting FixMatch training...")
fix_model = SimpleCNN(in_channels=3, num_classes=n_classes)
fix_model = train_fixmatch(fix_model, labeled_loader, unlabeled_loader, epochs=EPOCHS, threshold=THRESHOLD, unsupervised_weight=UNSUPERVISED_WEIGHT)


Starting FixMatch training...


Training FixMatch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [03:33<00:00,  4.27s/it]


### ⚙️ 2.2 FlexMatch Training Loop

FlexMatch adapts the threshold dynamically per class to handle imbalanced datasets.

**Instructions:**
1. Compute supervised loss as before.
2. Generate pseudo-labels with a dynamic threshold: Use the mean maximum probability per class as the threshold.
3. Compute unsupervised loss with the dynamic mask.
4. Combine and backpropagate.

In [7]:

def train_flexmatch(model, labeled_loader, unlabeled_loader, epochs=20, threshold=0.95, unsupervised_weight=1.0):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    sup_crit = nn.CrossEntropyLoss()
    unsup_crit = nn.CrossEntropyLoss(reduction='none')
    ema_conf = torch.full((n_classes,), 0.7)
    ema_m = 0.9
    for epoch in tqdm(range(epochs), desc='Training FlexMatch'):
        model.train()
        for (labeled_imgs, labels), (weak_unlabeled, strong_unlabeled) in zip(labeled_loader, unlabeled_loader):
            optimizer.zero_grad()
            # Supervised
            logits_sup = model(labeled_imgs)
            loss_sup = sup_crit(logits_sup, labels.squeeze())
            # Weak preds
            with torch.no_grad():
                logits_weak = model(weak_unlabeled)
                probs = F.softmax(logits_weak, dim=1)
                max_probs, pseudo_labels = torch.max(probs, dim=1)
                # Update class-wise EMA confidence using samples of each predicted class
                for k in range(n_classes):
                    mask_k = (pseudo_labels == k)
                    if mask_k.any():
                        conf_k = max_probs[mask_k].mean()
                        ema_conf[k] = ema_m * ema_conf[k] + (1 - ema_m) * conf_k
                # Class-wise dynamic thresholds
                max_ema = torch.clamp(ema_conf.max(), min=1e-6)
                tau_k = threshold * (max_ema / torch.clamp(ema_conf, min=1e-6))
                eff_thresh = tau_k[pseudo_labels]
                mask = max_probs.ge(eff_thresh).float()
            # Unsupervised loss on strong views
            logits_strong = model(strong_unlabeled)
            loss_unsup_raw = unsup_crit(logits_strong, pseudo_labels)
            loss_unsup = (loss_unsup_raw * mask).mean()
            # Total
            total_loss = loss_sup + unsupervised_weight * loss_unsup
            total_loss.backward()
            optimizer.step()
    return model

In [8]:
# Train and evaluate FlexMatch
print("Starting FlexMatch training...")
flex_model = SimpleCNN(in_channels=3, num_classes=n_classes)
flex_model = train_flexmatch(flex_model, labeled_loader, unlabeled_loader, epochs=EPOCHS, threshold=THRESHOLD, unsupervised_weight=UNSUPERVISED_WEIGHT)

Starting FlexMatch training...


Training FlexMatch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [03:48<00:00,  4.57s/it]


### ⚙️ 2.3 MixMatch Training Loop

MixMatch combines labeled and unlabeled data using MixUp and sharpening.

**Instructions:**
1. Compute supervised loss on labeled data.
2. Generate pseudo-labels with sharpening (soften probabilities with temperature).
3. Mix labeled and unlabeled data using MixUp.
4. Compute unsupervised loss on mixed data.
5. Combine and backpropagate.

In [9]:
def one_hot(labels, num_classes):
    y = torch.zeros(labels.size(0), num_classes, device=labels.device)
    return y.scatter_(1, labels.view(-1, 1).long(), 1)

def sharpen(p, T=0.5):
    p_power = p ** (1.0 / T)
    return p_power / p_power.sum(dim=1, keepdim=True)

def soft_cross_entropy(logits, soft_targets):
    log_probs = F.log_softmax(logits, dim=1)
    return -(soft_targets * log_probs).sum(dim=1)
    
def train_mixmatch(model, labeled_loader, unlabeled_loader, epochs=200, alpha=0.75, T=0.5, lambda_u=100.0):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    for epoch in tqdm(range(epochs), desc='Training MixMatch'):
        model.train()
        for (labeled_imgs, labels), (u_imgs_w, _) in zip(labeled_loader, unlabeled_loader):
            b_l = labeled_imgs.size(0)
            b_u = u_imgs_w.size(0)
            # Guess labels for unlabeled
            with torch.no_grad():
                logits_u = model(u_imgs_w)
                probs_u = F.softmax(logits_u, dim=1)
                q_u = sharpen(probs_u, T)
            # One-hot for labeled
            y_l = one_hot(labels.squeeze(), n_classes)
            # Concatenate
            X = torch.cat([labeled_imgs, u_imgs_w], dim=0)
            Y = torch.cat([y_l, q_u], dim=0)
            # MixUp
            idx = torch.randperm(X.size(0))
            lam = np.random.beta(alpha, alpha)
            lam = max(lam, 1 - lam)
            X_mixed = lam * X + (1 - lam) * X[idx]
            Y_mixed = lam * Y + (1 - lam) * Y[idx]
            # Forward
            logits = model(X_mixed)
            # Losses
            loss_sup = soft_cross_entropy(logits[:b_l], Y_mixed[:b_l]).mean()
            probs_mixed = F.softmax(logits[b_l:], dim=1)
            loss_unsup = F.mse_loss(probs_mixed, Y_mixed[b_l:])
            loss = loss_sup + lambda_u * loss_unsup
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model

In [10]:
# Train and evaluate MixMatch
print("Starting MixMatch training...")
mix_model = SimpleCNN(in_channels=3, num_classes=n_classes)
mix_model = train_mixmatch(mix_model, labeled_loader, unlabeled_loader, epochs=EPOCHS, alpha=0.75, T=0.5, lambda_u=50.0)


Starting MixMatch training...


Training MixMatch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [02:50<00:00,  3.40s/it]


## 3. Final Evaluation and Retrospective

Let’s evaluate all models and compare their performance.

In [11]:
@torch.no_grad()
def evaluate_model(model, test_dataset, data_flag):
    model.eval()
    y_true = torch.tensor([])
    y_score_logits = torch.tensor([])
    y_score_preds = torch.tensor([])
    test_loader = DataLoader(test_dataset, batch_size=128)
    for images, labels in test_loader:
        outputs = model(images)
        y_true = torch.cat((y_true, labels), 0)
        y_score_logits = torch.cat((y_score_logits, outputs), 0)
        preds = torch.argmax(outputs, dim=1)
        y_score_preds = torch.cat((y_score_preds, preds), 0)
    y_true_np = y_true.squeeze().cpu().numpy()
    y_score_logits_np = y_score_logits.detach().cpu().numpy()
    y_score_preds_np = y_score_preds.detach().cpu().numpy()
    evaluator = Evaluator(data_flag, 'test')
    metrics = evaluator.evaluate(y_score_logits_np)
    f1_macro = f1_score(y_true_np, y_score_preds_np, average='macro')
    f1_weighted = f1_score(y_true_np, y_score_preds_np, average='weighted')
    return metrics[0], metrics[1], f1_macro, f1_weighted

In [14]:
# Consolidated Evaluation
print("Starting consolidated evaluation for FixMatch, FlexMatch, and MixMatch...")
results = []
for name, mdl in [("FixMatch", fix_model), ("FlexMatch", flex_model), ("MixMatch", mix_model)]:
    auc, acc, f1_macro, f1_weighted = evaluate_model(mdl, test_dataset, data_flag)
    results.append((name, auc, acc, f1_macro, f1_weighted))
    print(f"--- {name} Results ---")
    print(f"AUC: {auc:.3f}, Accuracy: {acc:.3f}, F1(macro): {f1_macro:.3f}, F1(weighted): {f1_weighted:.3f}")

Starting consolidated evaluation for FixMatch, FlexMatch, and MixMatch...
--- FixMatch Results ---
AUC: 0.814, Accuracy: 0.670, F1(macro): 0.321, F1(weighted): 0.641
--- FlexMatch Results ---
AUC: 0.819, Accuracy: 0.678, F1(macro): 0.350, F1(weighted): 0.667
--- MixMatch Results ---
AUC: 0.777, Accuracy: 0.672, F1(macro): 0.152, F1(weighted): 0.543


## 9. Bilan chiffré et cap pour la suite

Voici un récapitulatif des résultats obtenus dans ce notebook :

- **Supervisé (350 images étiquetées, modèle de base)**  
  AUC ≈ `0.824` | Accuracy ≈ `0.489` | F1 macro ≈ `0.234`

- **Pseudo‑Labeling (itératif, simple)**  
  Iter 1 → AUC ≈ `0.805`, Acc ≈ `0.547`, F1 ≈ `0.290`  
  Iter 2 → AUC ≈ `0.845`, Acc ≈ `0.586`, F1 ≈ `0.308`  
  Iter 3 → AUC ≈ `0.852`, Acc ≈ `0.585`, F1 ≈ `0.295`  
  Iter 4 → AUC ≈ `0.846`, Acc ≈ `0.598`, F1 ≈ `0.289`  
  Iter 5 → AUC ≈ `0.844`, Acc ≈ `0.605`, F1 ≈ `0.301`

- **Label Propagation (graphe sur embeddings du SimpleCNN)**  
  AUC ≈ `0.505` | Accuracy ≈ `0.367` | F1 macro ≈ `0.355`

- **SGAN (Semi‑Supervised GAN)**  
  AUC ≈ `0.832` | Accuracy ≈ `0.482` | F1 macro ≈ `0.297`

- **FixMatch / FlexMatch / MixMatch**  
  FixMatch → AUC ≈ `0.825`, Acc ≈ `0.675`, F1 (macro) ≈ `0.360`, F1 (weighted) ≈ `0.663`  
  FlexMatch → AUC ≈ `0.824`, Acc ≈ `0.678`, F1 (macro) ≈ `0.318`, F1 (weighted) ≈ `0.636`  
  MixMatch → AUC ≈ `0.793`, Acc ≈ `0.671`, F1 (macro) ≈ `0.149`, F1 (weighted) ≈ `0.540`

> Note : Mean Teacher a été utilisé pour de la segmentation dans un autre contexte, donc non comparé ici.

### Que retenir ici ?
- Dans ce contexte, la solution la plus simple — le **pseudo‑labeling** — fonctionne bien et offre déjà un gain net sur le supervisé seul.
- Les méthodes plus avancées (Fix/Flex/MixMatch, SGAN) montrent des **hausses d’accuracy** notables (≈ `0.67`), mais le **F1 macro** peut fluctuer selon la méthode et la sensibilité au déséquilibre des classes.
- La question clé reste le **rapport complexité/bénéfice** : la mise en place, le tuning et le temps de calcul supplémentaires valent‑ils le gain obtenu dans votre cas d’usage ?

### Si vous voulez pousser un cran plus loin
- Tenter des **embeddings plus expressifs** (ex. `ResNet` pré‑entraîné) et ré‑évaluer la propagation.
- Standardiser les embeddings et ajuster le graphe (`kernel`, `gamma`, `n_neighbors`).
- Tester une **stratégie hybride** : pseudo‑labels de haute confiance comme seeds du graphe, ou pré‑filtrage pour Fix/Flex/MixMatch.

Si votre priorité est un bon compromis efficacité/temps, rester sur le **pseudo‑labeling simple** est un choix solide. Si vous visez le dernier pourcent, les méthodes avancées peuvent valoir l’exploration — en gardant un œil sur la complexité et la stabilité des métriques (dont le F1 macro).