# MixMatch implementation
Unofficial pytorch implementation of MixMatch [[paper]](https://arxiv.org/pdf/1905.02249.pdf) on CIFAR-10.

## Initialisation

In [1]:
import numpy as np
import os.path as osp
import torch

from metric_utils.metrics import Metrics

from mlu.utils.misc import reset_seed, get_datetime

from mlu.datasets.wrappers import NoLabelDataset, OneHotDataset, ZipDataset
from mlu.datasets.utils import split_dataset
from mlu.metrics import CategoricalAccuracy
from mlu.nn import CrossEntropyWithVectors
from mlu.utils.printers import ColumnPrinter
from mlu.utils.zip_cycle import ZipCycle

from sslh.models.wideresnet28 import WideResNet28
from sslh.mixmatch.warmup import WarmUp
from sslh.utils.torch import collapse_first_dimension

from torch import Tensor
from torch.optim import Adam
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.transforms import RandomHorizontalFlip, RandomChoice, Compose, ToTensor, Normalize, RandomCrop
from typing import Dict

In [2]:
reset_seed(1234)

lambda_u = 1.0
temperature = 0.5
alpha = 0.75
nb_epochs = 200
nb_augms = 2  # "K" in paper
lr = 1e-3

dataset_root = osp.join("..", "datasets")
tensorboard_root = osp.join("..", "results", "tensorboard")

device = torch.device("cuda")

# Same as CrossEntropy but accept non-"onehot encoding" vectors as targets (labels)
criterion_s = CrossEntropyWithVectors()
criterion_u = CrossEntropyWithVectors()

In [3]:
# Build WideResNet-28-2 model
model = WideResNet28(num_classes=10, width=2).to(device)
activation = torch.softmax
optim = Adam(model.parameters(), lr=lr)

# Build metrics. You can add a metric in this dictionary.
metrics_train_s = {"acc_s_mix": CategoricalAccuracy(dim=1)}
metrics_train_u = {"acc_u_mix": CategoricalAccuracy(dim=1)}
metrics_val = {"acc": CategoricalAccuracy(dim=1)}

# Tensorboard writer
writer = SummaryWriter(osp.join(tensorboard_root, "CIFAR10_%s_WideResNet28_MixMatch_Notebook" % get_datetime()))

# Class for managing how the values are print in terminal
printer = ColumnPrinter()

# Linearly increase a value from 0 to lambda_u during 16000 steps.
warmup = WarmUp(nb_steps=16000, max_value=lambda_u)

In [4]:
def get_loaders() -> (DataLoader, DataLoader):
	augm_weak = RandomChoice([
		RandomHorizontalFlip(0.5),
		RandomCrop((32, 32), padding=8),
	])
	# Add postprocessing after each augmentation (shape : [32, 32, 3] -> [3, 32, 32])
	post_process_fn = Compose([
		ToTensor(),
		Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
	])

	transform_augm_weak = Compose([
		augm_weak,
		post_process_fn,
	])

	# Create train dataloader
	dataset_train_augm_weak = CIFAR10(dataset_root, train=True, download=True, transform=transform_augm_weak)

	# Use 4000 data with labels (8%) and 46000 data without labels (92%)
	dataset_train_s_augm_weak, dataset_train_u_augm_weak = \
		split_dataset(dataset_train_augm_weak, nb_classes=10, ratios=[0.08, 0.92], target_one_hot=False)

	dataset_train_s_augm_weak = OneHotDataset(dataset_train_s_augm_weak, nb_classes=10)
	dataset_train_u_augm_weak = NoLabelDataset(dataset_train_u_augm_weak)

	# Duplicate augmentations calls
	dataset_train_u_augm_weaks = ZipDataset([dataset_train_u_augm_weak] * nb_augms)

	# Build loaders for supervised data and unsupervised data
	loader_train_s_augm = DataLoader(
		dataset=dataset_train_s_augm_weak, batch_size=64, shuffle=True, num_workers=2, drop_last=True)
	loader_train_u_augms = DataLoader(
		dataset=dataset_train_u_augm_weaks, batch_size=64, shuffle=True, num_workers=6, drop_last=True)

	loader_train = ZipCycle([loader_train_s_augm, loader_train_u_augms])

	# Create validation dataloader
	dataset_val = CIFAR10(dataset_root, train=False, download=True, transform=post_process_fn)
	dataset_val = OneHotDataset(dataset_val, nb_classes=10)

	loader_val = DataLoader(dataset=dataset_val, batch_size=64, shuffle=False, drop_last=False)
	return loader_train, loader_val

loader_train, loader_val = get_loaders()

Files already downloaded and verified
Files already downloaded and verified


## Prepare functions for training

In [5]:
def criterion(pred_s: Tensor, pred_u: Tensor, labels_s: Tensor, labels_u: Tensor, lambda_u: float) -> (Tensor, Tensor, Tensor):
	loss_s = criterion_s(pred_s, labels_s).mean()
	loss_u = criterion_u(pred_u, labels_u).mean()
	loss = loss_s + lambda_u * loss_u
	return loss, loss_s, loss_u

In [6]:
def guess_label(batch_u_multiple: Tensor, temperature: float, dim: int) -> Tensor:
	nb_augms = batch_u_multiple.shape[0]
	preds_u = [torch.zeros(0) for _ in range(nb_augms)]
	for k in range(nb_augms):
		logits_u = model(batch_u_multiple[k])
		preds_u[k] = activation(logits_u, dim=dim)
	preds_u = torch.stack(preds_u)
	labels_u = preds_u.mean(dim=0)

	labels_u = sharpen(labels_u, temperature, dim=1)
	return labels_u

In [7]:
def sharpen(pred: Tensor, temperature: float, dim: int = 1) -> Tensor:
	pred = pred ** (1.0 / temperature)
	pred = pred / pred.norm(p=1, dim=dim, keepdim=True)
	return pred

In [8]:
def mixup(batch_1: Tensor, labels_1: Tensor, batch_2: Tensor, labels_2: Tensor) -> (Tensor, Tensor):
	lambda_ = np.random.beta(alpha, alpha)
	lambda_ = max(lambda_, 1.0 - lambda_)
	batch_mixed = batch_1 * lambda_ + batch_2 * (1.0 - lambda_)
	labels_mixed = labels_1 * lambda_ + labels_2 * (1.0 - lambda_)

	return batch_mixed, labels_mixed

In [9]:
def mixmatch(
	batch_s_augm: Tensor, batch_u_augm_multiple: Tensor, labels_s: Tensor, labels_u: Tensor
) -> (Tensor, Tensor, Tensor, Tensor):
	# Duplicate label u for all the augmented versions of batch u
	nb_augms = batch_u_augm_multiple.shape[0]
	repeated_size = [nb_augms] + [1] * (len(labels_u.size()) - 1)
	labels_u = labels_u.repeat(repeated_size)
	batch_u_augm = collapse_first_dimension(batch_u_augm_multiple)

	# Concatenate s and u in w
	batch_w = torch.cat((batch_s_augm, batch_u_augm))
	labels_w = torch.cat((labels_s, labels_u))

	# Shuffle batch and labels with the same order
	indices = torch.randperm(batch_w.shape[0])
	batch_w, labels_w = batch_w[indices], labels_w[indices]

	# Apply mixup with (s, w[:len(s)]) and with (u, w[len(s):])
	len_s = len(batch_s_augm)
	batch_s_mix, labels_s_mix = mixup(batch_s_augm, labels_s, batch_w[:len_s], labels_w[:len_s])
	batch_u_mix, labels_u_mix = mixup(batch_u_augm, labels_u, batch_w[len_s:], labels_w[len_s:])

	return batch_s_mix, batch_u_mix, labels_s_mix, labels_u_mix

In [10]:
def reset_metrics(metrics: Dict[str, Metrics]):
	for name, metric in metrics.items():
		metric.reset()

## Training loop

In [11]:
def train(epoch: int):
	reset_metrics(metrics_train_s)
	reset_metrics(metrics_train_u)
	metrics_train_names = list(metrics_train_s.keys()) + list(metrics_train_u.keys()) + ["loss", "loss_s", "loss_u"]
	current_means = {"train/%s" % name: 0.0 for name in metrics_train_names}

	model.train()
	printer.print_header("train", metrics_train_names)

	for i, ((batch_s_augm, labels_s), batch_u_augm_multiple) in enumerate(loader_train):
		batch_s_augm = batch_s_augm.to(device).float()
		labels_s = labels_s.to(device).float()
		batch_u_augm_multiple = torch.stack(batch_u_augm_multiple).to(device).float()

		with torch.no_grad():
			# Guess label
			labels_u = guess_label(batch_u_augm_multiple, temperature, dim=1)

			# Apply mix
			batch_s_mix, batch_u_mix, labels_s_mix, labels_u_mix = mixmatch(
				batch_s_augm, batch_u_augm_multiple, labels_s, labels_u
			)

		# Compute logits
		logits_s_mix = model(batch_s_mix)
		logits_u_mix = model(batch_u_mix)

		pred_s_mix = activation(logits_s_mix, dim=1)
		pred_u_mix = activation(logits_u_mix, dim=1)

		# Update model
		loss, loss_s, loss_u = criterion(pred_s_mix, pred_u_mix, labels_s_mix, labels_u_mix, warmup.get_value())
		optim.zero_grad()
		loss.backward()
		optim.step()

		# Compute metrics
		with torch.no_grad():
			for name, metric in metrics_train_s.items():
				current_means["train/%s" % name] = metric(pred_s_mix, labels_s_mix)

			for name, metric in metrics_train_u.items():
				current_means["train/%s" % name] = metric(pred_u_mix, labels_u_mix)

			current_means.update({
				"train/loss": loss.item(),
				"train/loss_s": loss_s.item(),
				"train/loss_u": loss_u.item(),
			})

			printer.print_current_values(current_means, i, len(loader_train), epoch)
			warmup.step()

	for name, mean_ in current_means.items():
		writer.add_scalar(name, mean_, epoch)

	writer.add_scalar("hparams/lambda_u", warmup.get_value(), epoch)

## Validation loop

In [12]:
def val(epoch: int):
	reset_metrics(metrics_val)
	current_means = {"val/%s" % name: 0.0 for name in metrics_val.keys()}

	model.eval()
	printer.print_header("val", metrics_val.keys(), False)

	for i, (x, y) in enumerate(loader_val):
		x = x.to(device).float()
		y = y.to(device).float()

		# Compute logits
		logits = model(x)
		pred = activation(logits, dim=1)

		for name, metric in metrics_val.items():
			current_means["val/{:s}".format(name)] = metric(pred, y)

		printer.print_current_values(current_means, i, len(loader_val), epoch)

	for name, mean_ in current_means.items():
		writer.add_scalar(name, mean_, epoch)

## Start learning

In [13]:
for epoch in range(nb_epochs):
	train(epoch)
	with torch.no_grad():
		val(epoch)

writer.close()


-      train       -   acc_s    -   acc_u    -    loss    -   loss_s   -   loss_u   -  took (s)  -
- Epoch   1 - 100% - 4.1326e-01 - 7.0674e-01 - 1.2355e+00 - 1.1637e+00 - 1.6034e+00 -   30.18    -
-       val        -    acc     -  took (s)  -
- Epoch   1 - 100% - 5.7514e-01 -    2.13    -

-      train       -   acc_s    -   acc_u    -    loss    -   loss_s   -   loss_u   -  took (s)  -
- Epoch   2 - 100% - 5.7118e-01 - 7.3028e-01 - 1.8668e+00 - 1.7084e+00 - 1.7656e+00 -   30.86    -
-       val        -    acc     -  took (s)  -
- Epoch   2 - 100% - 6.4480e-01 -    2.06    -

-      train       -   acc_s    -   acc_u    -    loss    -   loss_s   -   loss_u   -  took (s)  -
- Epoch   3 - 100% - 6.4909e-01 - 7.4182e-01 - 1.9540e+00 - 1.8438e+00 - 8.1886e-01 -   29.62    -
-       val        -    acc     -  took (s)  -
- Epoch   3 - 100% - 6.6809e-01 -    2.09    -

-      train       -   acc_s    -   acc_u    -    loss    -   loss_s   -   loss_u   -  took (s)  -
- Epoch   4 - 1