# FixMatch implementation
Unofficial PyTorch implementation of FixMatch [[paper]](https://arxiv.org/pdf/2001.07685.pdf) on CIFAR-10.

## Initialisation

In [None]:
import math
import os.path as osp
from typing import Callable, Dict, Iterable, Iterator, Optional, Sized, Tuple

import numpy as np
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import one_hot
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.transforms import RandomHorizontalFlip, RandomChoice, Compose, ToTensor, Normalize, RandomCrop

from metric_utils.metrics import Metrics

from sslh.augments.rand_augment import RandAugment
from sslh.datasets.utils import split_classes_idx, get_classes_idx, shuffle_classes_idx
from sslh.datasets.wrappers.multiple_dataset import MultipleDataset
from sslh.datasets.wrappers.onehot_dataset import OneHotDataset
from sslh.datasets.wrappers.no_label_dataset import NoLabelDataset
from sslh.models.wrn28_2 import WideResNet28
from sslh.utils.display import ColumnDisplay
from sslh.utils.misc import reset_seed, get_datetime
from sslh.utils.other_metrics import CategoricalAccuracyOnehot
from sslh.utils.recorder.recorder import Recorder
from sslh.utils.torch import CrossEntropyWithVectors, get_lr, get_reduction_from_name

In [2]:
# Called first, before any other object is built (presumably seeds the random generators
# for reproducibility - confirm in sslh.utils.misc).
reset_seed(1234)

# Hyperparameters
nb_epochs = 200  # number of training epochs
lambda_u = 1.0  # coefficient applied to the unsupervised loss component
lr = 1e-3  # initial learning rate given to the optimizer
threshold = 0.95  # minimal confidence required to use a guessed label
bsize = 128  # total batch size (labeled + unlabeled)
mu = 7  # = bsize_u / bsize_s, must be in range [1..bsize[

dataset_root = osp.join("..", "dataset")
tensorboard_root = osp.join("..", "results", "tensorboard")

# Use the GPU when one is available and fall back to the CPU otherwise,
# instead of failing as soon as a model or a tensor is moved to a missing CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same as CrossEntropy but accepts non-"onehot encoding" vectors as targets (labels).
# No reduction is applied here : the per-sample losses are reduced by the FixMatch loss.
criterion_s = CrossEntropyWithVectors(reduction="none")
criterion_u = CrossEntropyWithVectors(reduction="none")

### Learning rate Scheduler

In [None]:
class CosineLRScheduler(LambdaLR):
	"""
		Cosine decay of the learning rate.

		At each epoch the initial learning rate lr0 is multiplied by the factor :
		cos(7 * pi * epoch / (16 * nb_epochs))

		The factor is 1 at epoch 0 and reaches cos(7 * pi / 16) (about 0.195) at epoch nb_epochs.
	"""
	def __init__(self, optim: Optimizer, nb_epochs: int):
		"""
			:param optim: The optimizer whose learning rate is scheduled.
			:param nb_epochs: The total number of epochs of the training.
		"""
		denominator = 16.0 * nb_epochs

		def cosine_factor(p_epoch: int) -> float:
			# Multiplicative factor applied by LambdaLR to the initial learning rate.
			return math.cos(7.0 * math.pi * p_epoch / denominator)

		super().__init__(optim, cosine_factor)

### Build objects

In [None]:
# Build WideResNet-28-2 model and move it to the selected device
model = WideResNet28(num_classes=10, width=2).to(device)
# Activation applied on the logits of the model to get the predictions (called with dim=1 in this file)
activation = torch.softmax
optim = Adam(model.parameters(), lr=lr)

# Build metrics. You can add a metric in this dictionary.
# One dictionary per subset : labeled train data (s), unlabeled train data (u) and validation data.
metrics_s = {"acc_s": CategoricalAccuracyOnehot(dim=1)}
metrics_u = {"acc_u": CategoricalAccuracyOnehot(dim=1)}
metrics_val = {"acc": CategoricalAccuracyOnehot(dim=1)}

# Tensorboard writer and the Recorder wrapper for tracking max, std & min of the values stored.
# The run directory name contains the current datetime.
writer = SummaryWriter(osp.join(tensorboard_root, "CIFAR10_%s_WideResNet28_FixMatch" % get_datetime()))
recorder = Recorder(writer)

# Class for managing how the values are printed in the terminal
display = ColumnDisplay()

# Learning rate scheduler, stepped once per epoch by the main loop
sched = CosineLRScheduler(optim, nb_epochs=nb_epochs)

## Build Dataloaders

### ZipCycle class

In [None]:
class ZipCycle(Iterable, Sized):
	"""
		Zip through a list of iterables and sized objects of different lengths.
		The shorter iterables are restarted each time they are exhausted and the iteration
		finishes when the longest one is over.

		Example :
		r1 = range(1, 4)
		r2 = range(1, 6)
		iters = ZipCycle([r1, r2])
		for v1, v2 in iters:
			print(v1, v2)

		will print :
		1 1
		2 2
		3 3
		1 4
		2 5
	"""

	def __init__(self, iterables: list):
		"""
			:param iterables: The list of iterable and sized objects to zip.
				An empty list gives an empty ZipCycle (length 0, nothing yielded).
			:raises RuntimeError: If one of the iterables is empty.
		"""
		# Copy the container so that a later mutation of the caller's list cannot desynchronise the cached length.
		iterables = list(iterables)

		for iterable in iterables:
			if len(iterable) == 0:
				raise RuntimeError("An iterable is empty.")

		self._iterables = iterables
		# default=0 : zipping nothing yields nothing instead of crashing inside max().
		self._len = max((len(iterable) for iterable in iterables), default=0)

	def __iter__(self) -> Iterator[list]:
		"""
			Yield lists containing one item per iterable, in the order given to the constructor.
		"""
		cur_iters = [iter(iterable) for iterable in self._iterables]
		cur_count = [0] * len(cur_iters)

		for _ in range(len(self)):
			items = []

			for i, iterable in enumerate(self._iterables):
				# Restart the iterable when all of its items have been consumed.
				if cur_count[i] >= len(iterable):
					cur_iters[i] = iter(iterable)
					cur_count[i] = 0

				items.append(next(cur_iters[i]))
				cur_count[i] += 1

			yield items

	def __len__(self) -> int:
		"""
			Return the length of the longest iterable.
		"""
		return self._len

In [4]:
def get_loaders() -> Tuple["ZipCycle", DataLoader]:
	"""
		Build the training and validation loaders for CIFAR-10.

		The train set is split per class in a labeled subset (8%) and an unlabeled subset (92%).
		The labeled data are weakly augmented and the unlabeled data are provided twice :
		weakly augmented and strongly augmented.

		:returns: A tuple (loader_train, loader_val) where
			loader_train is a ZipCycle yielding at each step a list [labeled batch, unlabeled batch],
			loader_val is the loader of the CIFAR-10 test set with one-hot labels.
	"""
	# Weak augmentation : ONE transform is picked randomly between horizontal flip and padded random crop.
	# NOTE : `np.array` is given directly instead of a local lambda, because a local lambda cannot be pickled
	# when the DataLoader workers are started with the "spawn" method.
	augm_weak = Compose([
		RandomChoice([
			RandomHorizontalFlip(0.5),
			RandomCrop((32, 32), padding=8),
		]),
		np.array,
	])
	augm_strong = Compose([
		np.array,
		RandAugment(
			ratio=1.0,
			magnitude_m=2.0,
			nb_choices_n=1
		)
	])

	# Add postprocessing after each augmentation (shape : [32, 32, 3] -> [3, 32, 32])
	post_process_fn = Compose([
		ToTensor(),
		Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
	])

	transform_augm_weak = Compose([
		augm_weak,
		post_process_fn,
	])
	transform_augm_strong = Compose([
		augm_strong,
		post_process_fn,
	])

	# Same train data wrapped twice, one dataset per augmentation pipeline
	dataset_train_augm_weak = CIFAR10(dataset_root, train=True, download=True, transform=transform_augm_weak)
	dataset_train_augm_strong = CIFAR10(dataset_root, train=True, download=True, transform=transform_augm_strong)

	# Use 4000 data with labels (8%) and 46000 data without labels (92%)
	cls_idx_all = get_classes_idx(dataset_train_augm_weak, 10, is_one_hot=False)
	cls_idx_all = shuffle_classes_idx(cls_idx_all)
	idx_split = split_classes_idx(cls_idx_all, ratios=[0.08, 0.92])
	idx_s, idx_u = idx_split

	dataset_train_augm_weak_s = Subset(dataset_train_augm_weak, idx_s)
	# The same indexes are used for the weak and strong versions of the unlabeled data
	dataset_train_augm_weak_u = Subset(dataset_train_augm_weak, idx_u)
	dataset_train_augm_strong_u = Subset(dataset_train_augm_strong, idx_u)

	dataset_train_augm_weak_s = OneHotDataset(dataset_train_augm_weak_s, nb_classes=10)
	dataset_train_augm_weak_u = NoLabelDataset(dataset_train_augm_weak_u)
	dataset_train_augm_strong_u = NoLabelDataset(dataset_train_augm_strong_u)

	dataset_train_augms_weak_strong_u = MultipleDataset([dataset_train_augm_weak_u, dataset_train_augm_strong_u])

	# Build loaders for supervised data and unsupervised data
	# The total batch size is shared between labeled and unlabeled data with the ratio 1:mu
	bsize_s = int(bsize / (mu + 1))
	bsize_u = int(bsize * mu / (mu + 1))
	loader_train_s_augm = DataLoader(
		dataset=dataset_train_augm_weak_s, batch_size=bsize_s, shuffle=True, num_workers=1, drop_last=True)
	loader_train_u_augms = DataLoader(
		dataset=dataset_train_augms_weak_strong_u, batch_size=bsize_u, shuffle=True, num_workers=mu, drop_last=True)

	# The labeled loader is shorter : it is restarted until the unlabeled loader is over
	loader_train = ZipCycle([loader_train_s_augm, loader_train_u_augms])

	# Create validation dataloader (no augmentation, only the postprocessing)
	dataset_val = CIFAR10(dataset_root, train=False, download=True, transform=post_process_fn)
	dataset_val = OneHotDataset(dataset_val, nb_classes=10)

	loader_val = DataLoader(dataset=dataset_val, batch_size=bsize, shuffle=False, drop_last=False)
	return loader_train, loader_val

# Build once the training loader (labeled and unlabeled batches zipped together) and the validation loader.
loader_train, loader_val = get_loaders()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


## Criterion

### Cross Entropy with probabilities

In [None]:
class CrossEntropyWithVectors(Module):
	"""
		Compute Cross-Entropy between two distributions.
		Input and targets must be a batch of probabilities distributions of shape (batch_size, nb_classes) tensor.
	"""
	def __init__(self, reduction: str = "batchmean", dim: Optional[int] = 1, log_input: bool = False):
		"""
			:param reduction: The name of the reduction applied on the losses, given to get_reduction_from_name.
			:param dim: The default dimension to sum over (the classes dimension).
			:param log_input: If True, the input is expected to contain log-probabilities and no log is applied.
		"""
		super().__init__()
		self.reduce_fn = get_reduction_from_name(reduction)
		self.dim = dim
		self.log_input = log_input

	def forward(self, input_: Tensor, targets: Tensor, dim: Optional[int] = None) -> Tensor:
		"""
			Compute cross-entropy with targets.
			Input and target must be a (batch_size, nb_classes) tensor.

			:param input_: The predicted probabilities (or log-probabilities if log_input is True).
			:param targets: The target distributions, with the same shape as input_.
			:param dim: The dimension to sum over. If None, the dimension given to the constructor is used.
			:returns: The reduced cross-entropy.
		"""
		if dim is None:
			dim = self.dim
		if not self.log_input:
			if input_.is_floating_point():
				# A probability can underflow to exactly 0 : log(0) = -inf and (-inf * 0) = NaN.
				# Clamping to the smallest positive normal number keeps the log finite
				# and leaves every other input unchanged.
				input_ = input_.clamp(min=torch.finfo(input_.dtype).tiny)
			input_ = torch.log(input_)
		loss = -torch.sum(input_ * targets, dim=dim)
		return self.reduce_fn(loss)

### FixMatch loss

In [5]:
class FixMatchLoss(Module):
	"""
		FixMatch loss module.

		Loss formula : loss = CE(pred_s, label_s) + lambda_u * mask * CE(pred_u, label_u)

		The mask used is 1 if the confidence prediction on weakly augmented data is above a specific threshold.
	"""

	def __init__(
		self,
		criterion_s: Optional[Callable] = None,
		criterion_u: Optional[Callable] = None,
		reduction: str = "batchmean",
	):
		"""
			:param criterion_s: The criterion used for labeled loss component.
				If None, CrossEntropyWithVectors(reduction="none") is used.
			:param criterion_u: The criterion used for unlabeled loss component. No reduction must be applied.
				If None, CrossEntropyWithVectors(reduction="none") is used.
			:param reduction: The main reduction to use. Can be 'none', 'mean', 'batchmean' or 'sum'.
		"""
		super().__init__()
		# The default criteria are built here and not in the signature : a default value is created only once,
		# when the class is defined, and would be shared by all the instances of FixMatchLoss.
		if criterion_s is None:
			criterion_s = CrossEntropyWithVectors(reduction="none")
		if criterion_u is None:
			criterion_u = CrossEntropyWithVectors(reduction="none")

		self.criterion_s = criterion_s
		self.criterion_u = criterion_u
		self.reduce_fn = get_reduction_from_name(reduction)

	def forward(
		self,
		pred_s_augm_weak: Tensor,
		pred_u_augm_strong: Tensor,
		mask: Tensor,
		labels_s: Tensor,
		labels_u: Tensor,
		lambda_s: float = 1.0,
		lambda_u: float = 1.0,
	) -> Tuple[Tensor, Tensor, Tensor]:
		"""
			Compute FixMatch loss.

			Generic :
				loss = lambda_s * mean(criterion_s(pred_s, labels_s)) + lambda_u * mean(criterion_u(pred_u, labels_u) * mask)

			:param pred_s_augm_weak: Output of the model for labeled batch s of shape (batch_size, nb_classes).
			:param pred_u_augm_strong: Output of the model for unlabeled batch u of shape (batch_size, nb_classes).
			:param mask: Binary confidence mask used to avoid using low-confidence labels as targets of shape (batch_size).
			:param labels_s: True label of labeled batch s of shape (batch_size, nb_classes).
			:param labels_u: Guessed label of unlabeled batch u of shape (batch_size, nb_classes).
			:param lambda_s: Coefficient used to multiply the supervised loss component.
			:param lambda_u: Coefficient used to multiply the unsupervised loss component.
			:returns: A tuple (loss, loss_s, loss_u) where loss_s and loss_u are the reduced components
				before the multiplication by their coefficient.
		"""
		loss_s = self.criterion_s(pred_s_augm_weak, labels_s)

		# Out-of-place product : the tensor returned by the criterion is never modified.
		loss_u = self.criterion_u(pred_u_augm_strong, labels_u) * mask

		loss_s = self.reduce_fn(loss_s)
		loss_u = self.reduce_fn(loss_u)

		loss = lambda_s * loss_s + lambda_u * loss_u

		return loss, loss_s, loss_u

# FixMatch loss built with the two criteria created with the hyperparameters (reduction="none").
criterion = FixMatchLoss(criterion_s, criterion_u)

## Training

In [6]:
def guess_label(batch_u_augm_weak: Tensor) -> (Tensor, Tensor):
	"""
		Guess the labels of a weakly augmented unlabeled batch with the current model.

		:param batch_u_augm_weak: The weakly augmented unlabeled batch given to the model.
		:returns: A tuple (labels_u, pred_u_augm_weak) where
			labels_u is the one-hot encoding of the most probable class of each sample,
			pred_u_augm_weak is the output of the model after the activation.
	"""
	probabilities = activation(model(batch_u_augm_weak), dim=1)
	predicted_classes = probabilities.argmax(dim=1)
	# The number of classes is the size of the second dimension of the predictions.
	return one_hot(predicted_classes, probabilities.shape[1]), probabilities

In [7]:
def confidence_mask(pred_weak: Tensor, threshold: float, dim: int) -> Tensor:
	"""
		Build the binary mask of the confident predictions.

		:param pred_weak: The predictions computed on the weakly augmented data.
		:param threshold: The confidence threshold. The maximal value must be strictly above it.
		:param dim: The dimension over which the maximum is taken.
		:returns: A float tensor containing 1.0 where the maximal prediction is above the threshold and 0.0 elsewhere.
	"""
	confidences = pred_weak.max(dim=dim)[0]
	return confidences.gt(threshold).float()

In [8]:
def reset_metrics(metrics: Dict[str, Metrics]):
	"""
		Call reset() on every metric of the dictionary, in the order of the dictionary.

		:param metrics: The dictionary (name -> metric) of the metrics to reset.
	"""
	for metric in metrics.values():
		metric.reset()

In [9]:
def train(epoch: int):
	"""
		Run one training epoch of FixMatch on `loader_train`.

		For each step, the labels of the unlabeled batch are guessed from its weakly augmented version,
		then the model is updated with the FixMatch loss computed on the weakly augmented labeled batch
		and on the strongly augmented unlabeled batch. Losses and metrics are stored in the recorder.

		:param epoch: The index of the current epoch, used by the recorder and the display.
	"""
	model.train()
	reset_metrics(metrics_s)
	reset_metrics(metrics_u)

	recorder.start_record(epoch)
	# Names of the columns printed in the terminal
	keys = list(metrics_s.keys()) + list(metrics_u.keys()) + ["loss", "loss_s", "loss_u", "labels_used"]
	display.print_header("train", keys)

	iter_loader = iter(loader_train)

	# Each item of loader_train is a list [labeled batch, unlabeled batch] :
	# the labeled batch is (data, labels), the unlabeled batch is (weakly augmented data, strongly augmented data).
	for i, ((batch_s_augm_weak, labels_s), (batch_u_augm_weak, batch_u_augm_strong)) in enumerate(iter_loader):
		batch_s_augm_weak = batch_s_augm_weak.to(device).float()
		labels_s = labels_s.to(device).float()
		batch_u_augm_weak = batch_u_augm_weak.to(device).float()
		batch_u_augm_strong = batch_u_augm_strong.to(device).float()

		# Guess the labels with the predictions on the weakly augmented version of u.
		# No gradient is computed here : the guessed labels and the mask are only used as constant targets.
		with torch.no_grad():
			labels_u, pred_u_augm_weak = guess_label(batch_u_augm_weak)
			mask = confidence_mask(pred_u_augm_weak, threshold, dim=1)

		optim.zero_grad()

		# Compute predictions
		logits_s_augm_weak = model(batch_s_augm_weak)
		logits_u_augm_strong = model(batch_u_augm_strong)

		pred_s_augm_weak = activation(logits_s_augm_weak, dim=1)
		pred_u_augm_strong = activation(logits_u_augm_strong, dim=1)

		# Update model
		loss, loss_s, loss_u = criterion(
			pred_s_augm_weak,
			pred_u_augm_strong,
			mask,
			labels_s,
			labels_u,
			lambda_u=lambda_u
		)

		loss.backward()
		optim.step()

		# Compute metrics
		with torch.no_grad():
			recorder.add_point("train/loss", loss.item())
			recorder.add_point("train/loss_s", loss_s.item())
			recorder.add_point("train/loss_u", loss_u.item())

			# Proportion of the guessed labels kept by the confidence mask in this batch
			proportion_labels_used = mask.sum() / mask.shape[0]
			recorder.add_point("train/labels_used", proportion_labels_used.item())

			# The value returned by the metric call is not used : the recorded point is `metric.value`.
			for metric_name, metric in metrics_s.items():
				_mean = metric(pred_s_augm_weak, labels_s)
				recorder.add_point("train/{:s}".format(metric_name), metric.value.item())

			# NOTE : the metrics of u are computed against the guessed labels, not against the true labels.
			for metric_name, metric in metrics_u.items():
				_mean = metric(pred_u_augm_strong, labels_u)
				recorder.add_point("train/{:s}".format(metric_name), metric.value.item())

			display.print_current_values(recorder.get_current_means(), i, len(loader_train), epoch)

	# The learning rate is recorded once per epoch, after the last step
	recorder.add_point("train/lr", get_lr(optim))
	recorder.end_record(epoch)

In [10]:
def val(epoch: int):
	"""
		Evaluate the model on `loader_val` and write the final value of each validation metric in tensorboard.

		The evaluation is done without gradient tracking, even if the caller does not use torch.no_grad().

		:param epoch: The index of the current epoch, used by the display and as tensorboard step.
	"""
	model.eval()
	reset_metrics(metrics_val)
	current_means = {"val/%s" % name: 0.0 for name in metrics_val.keys()}

	display.print_header("val", metrics_val.keys(), False)

	# Disable gradient tracking here so that no autograd graph is built during the evaluation.
	with torch.no_grad():
		for i, (x, y) in enumerate(loader_val):
			x = x.to(device).float()
			y = y.to(device).float()

			# Compute logits
			logits = model(x)
			pred = activation(logits, dim=1)

			# Keep the last value returned by each metric
			for name, metric in metrics_val.items():
				current_means["val/{:s}".format(name)] = metric(pred, y)

			display.print_current_values(current_means, i, len(loader_val), epoch)

	for name, mean_ in current_means.items():
		writer.add_scalar(name, mean_, epoch)

## Start learning

In [11]:
# Main loop : one training pass, one validation pass and one scheduler step per epoch.
# The writer is closed in the `finally` clause so that it is also closed when the training
# raises or is interrupted.
try:
	for epoch in range(nb_epochs):
		train(epoch)
		with torch.no_grad():
			val(epoch)
		sched.step()
finally:
	writer.close()


-      train       -   acc_s    -   acc_u    - labels_use -    loss    -   loss_s   -   loss_u   -  took (s)  -
- Epoch   1 - 100% - 2.8384e-01 - 5.8691e-01 - 1.5244e-04 - 1.9456e+00 - 1.9455e+00 - 4.2729e-05 -   13.70    -
-       val        -    acc     -  took (s)  -
- Epoch   1 - 100% - 3.7866e-01 -    1.77    -

-      train       -   acc_s    -   acc_u    - labels_use -    loss    -   loss_s   -   loss_u   -  took (s)  -
- Epoch   2 - 100% - 3.8095e-01 - 6.1433e-01 - 1.2848e-03 - 1.6743e+00 - 1.6740e+00 - 2.7921e-04 -   13.26    -
-       val        -    acc     -  took (s)  -
- Epoch   2 - 100% - 4.3750e-01 -    1.79    -

-      train       -   acc_s    -   acc_u    - labels_use -    loss    -   loss_s   -   loss_u   -  took (s)  -
- Epoch   3 - 100% - 4.3628e-01 - 6.2267e-01 - 8.6455e-03 - 1.5454e+00 - 1.5428e+00 - 2.5869e-03 -   13.51    -
-       val        -    acc     -  took (s)  -
- Epoch   3 - 100% - 4.9773e-01 -    1.78    -

-      train       -   acc_s    -   