# FixMatch implementation
An unofficial pytorch implementation of FixMatch [[paper]](https://arxiv.org/pdf/2001.07685.pdf) on CIFAR-10.

## Initialisation

In [1]:
import math
import os.path as osp
import torch

from argparse import Namespace

from augments import (
	RandAugment,
	RAND_AUGMENT_POOL_2,
	CutOutImgPIL,
)
from utils import (
	reset_seed,
	get_datetime,
	get_lr,
	generate_indexes,
	ZipDataset,
	Metric,
	IncrementalMean,
	ColumnPrinter,
	MaxTracker,
)

from torch import Tensor, nn
from torch.nn import Module
from torch.nn.functional import one_hot
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.transforms import RandomHorizontalFlip, Compose, ToTensor, RandomCrop, Normalize

from typing import Callable, Iterable, Optional, Sized

In [2]:
args = Namespace()

# Hyperparameters
args.nb_epochs = 1000
args.lambda_u = 1.0
args.threshold = 0.95
args.bsize = 64
args.mu = 7
args.nb_labels = 4000  # 4000 of 50000 examples = 8% of labeled data
args.lr = 0.03
args.seed = 1234
args.num_classes = 10
args.loader_policy = "max"
args.sched_coef = 7.0 / 16.0

# SGD parameters
args.weight_decay = 0.0005
args.momentum = 0.9  # called "beta" in paper
args.nesterov = True

# RandAugment
args.nb_augm_apply = 1

reset_seed(args.seed)

dataset_root = osp.join("..", "datasets")
tensorboard_root = osp.join("..", "results", "tensorboard")
device = torch.device("cuda")

In [3]:
print(args)

Namespace(bsize=64, lambda_u=1.0, loader_policy='max', lr=0.03, momentum=0.9, mu=7, nb_augm_apply=1, nb_epochs=1000, nb_labels=4000, nesterov=True, num_classes=10, sched_coef=0.4375, seed=1234, threshold=0.95, weight_decay=0.0005)


### Model

In [4]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
	"""3x3 convolution with padding"""
	return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
					 padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
	"""1x1 convolution"""
	return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
	expansion = 1

	def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
				 base_width=64, dilation=1, norm_layer=None):
		super(BasicBlock, self).__init__()

		# Both self.conv1 and self.downsample layers downsample the input when stride != 1
		self.conv1 = conv3x3(inplanes, planes, stride)
		self.bn1 = norm_layer(planes)
		self.relu = nn.ReLU(inplace=True)
		self.conv2 = conv3x3(planes, planes)
		self.bn2 = norm_layer(planes)
		self.downsample = downsample
		self.stride = stride

		self.expansion = 2

	def forward(self, x):
		identity = x

		out = self.conv1(x)
		out = self.bn1(out)
		out = self.relu(out)

		out = self.conv2(out)
		out = self.bn2(out)

		if self.downsample is not None:
			identity = self.downsample(x)

		out += identity
		out = self.relu(out)

		return out


class ResNet(Module):
	def __init__(self, layers, width: int = 2, num_classes=10, zero_init_residual=False,
				 groups=1, width_per_group=16, replace_stride_with_dilation=None,
				 norm_layer=None):
		Module.__init__(self)

		if norm_layer is None:
			norm_layer = nn.BatchNorm2d
		self._norm_layer = norm_layer

		block = BasicBlock
		self.inplanes = 16*width
		self.dilation = 1
		if replace_stride_with_dilation is None:
			# each element in the tuple indicates if we should replace
			# the 2x2 stride with a dilated convolution instead
			replace_stride_with_dilation = [False, False, False]
		if len(replace_stride_with_dilation) != 3:
			raise ValueError("replace_stride_with_dilation should be None "
							 "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
		self.groups = groups
		self.base_width = width_per_group
		self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = norm_layer(self.inplanes)
		self.relu = nn.ReLU(inplace=True)
		self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
		self.layer1 = self._make_layer(block, 16*width, layers[0])
		self.layer2 = self._make_layer(block, 32*width, layers[1], stride=2,
									   dilate=replace_stride_with_dilation[0])
		self.layer3 = self._make_layer(block, 64*width, layers[2], stride=2,
									   dilate=replace_stride_with_dilation[1])

		self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
		self.fc = nn.Linear(64 * width * block.expansion, num_classes)

		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

			elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)

		# Zero-initialize the last BN in each residual branch,
		# so that the residual branch starts with zeros, and each residual block behaves like an identity.
		# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
		if zero_init_residual:
			for m in self.modules():
				if isinstance(m, BasicBlock):
					nn.init.constant_(m.bn2.weight, 0)

	def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
		norm_layer = self._norm_layer
		downsample = None
		previous_dilation = self.dilation
		if dilate:
			self.dilation *= stride
			stride = 1
		if stride != 1 or self.inplanes != planes * block.expansion:
			downsample = nn.Sequential(
				conv1x1(self.inplanes, planes * block.expansion, stride),
				norm_layer(planes * block.expansion),
			)

		layers = []
		layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
							self.base_width, previous_dilation, norm_layer))
		self.inplanes = planes * block.expansion
		for _ in range(1, blocks):
			layers.append(block(self.inplanes, planes, groups=self.groups,
								base_width=self.base_width, dilation=self.dilation,
								norm_layer=norm_layer))

		return nn.Sequential(*layers)

	def _forward_impl(self, x):
		# See note [TorchScript super()]
		x = self.conv1(x)
		x = self.bn1(x)
		x = self.relu(x)
		x = self.maxpool(x)

		x = self.layer1(x)
		x = self.layer2(x)
		x = self.layer3(x)

		x = self.avgpool(x)
		x = torch.flatten(x, 1)
		x = self.fc(x)

		return x

	def forward(self, x):
		return self._forward_impl(x)


class WideResNet28(ResNet):
	def __init__(self, num_classes: int, width: int = 2):
		super().__init__(layers=[4, 4, 4], width=width, num_classes=num_classes)

### Learning rate Scheduler

In [5]:
class CosineLRScheduler(LambdaLR):
	"""
		Scheduler that decreases the learning rate from lr0 to almost 0 by using the following rule :
		lr = lr0 * cos(7 * pi * step / (16 * nb_steps))
	"""
	def __init__(self, optim: Optimizer, nb_steps: int, coef: float = 7.0 / 16.0):
		"""
			:param optim: The optimizer to update.
			:param nb_steps: The number of step() call. Can be the number of epochs or iteration.
			:param coef: The coefficient in [0, 0.5] for controlling the decrease cosine rate.
				If closer to 0.5, the final lr will be close to 0.0
		"""
		lr_lambda = lambda step: math.cos(math.pi * coef * step / nb_steps)
		super().__init__(optim, lr_lambda)

In [6]:
class CategoricalAccuracy(Metric):
	"""
		Compute the categorical accuracy between a batch of prediction and labels.
	"""
	def __init__(self, vector_input: bool = True, vector_target: bool = True, dim: int = 1):
		super().__init__()
		self.vector_input = vector_input
		self.vector_target = vector_target
		self.dim = dim

	def compute_score(self, input_: Tensor, target: Tensor) -> Tensor:
		if self.vector_input:
			input_ = input_.argmax(dim=self.dim)
		if self.vector_target:
			target = target.argmax(dim=self.dim)

		assert input_.shape == target.shape, "Input and target must have the same shape."
		score = input_.eq(target).float().mean()
		return score

### Build models, optimizer, metrics and utilities

In [7]:
# Build WideResNet-28-2 model
model = WideResNet28(num_classes=args.num_classes, width=2).to(device)
activation = torch.softmax
optimizer = SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=args.nesterov)
scheduler = CosineLRScheduler(optimizer, nb_steps=args.nb_epochs, coef=args.sched_coef)

# Build metrics for labeled, unlabeled and validation predictions.
metrics_s = {"train/acc_s": CategoricalAccuracy()}
metrics_u = {"train/acc_u": CategoricalAccuracy()}
metrics_u_true = {"train/acc_u_true": CategoricalAccuracy()}
metrics_val = {"val/acc": CategoricalAccuracy()}

# Tensorboard writer and the Recorder wrapper for tracking max, std & min of the values stored.
writer = SummaryWriter(osp.join(tensorboard_root, "CIFAR10_%s_WideResNet28_FixMatch_Notebook" % get_datetime()))
writer.add_hparams(args.__dict__, {})

# Class for managing how the values are print in terminal
printer = ColumnPrinter()

## Data preparation

### Augmentations

In [8]:
transform_post_process = Compose([
	ToTensor(),
	Normalize(
        mean=(0.4914009, 0.48215896, 0.4465308),
        std=(0.24703279, 0.24348423, 0.26158753)
    ),
])

transform_train_augm_weak = Compose([
	RandomHorizontalFlip(0.5),
	RandomCrop((32, 32), padding=8),
	transform_post_process,
])

transform_train_augm_strong = Compose([
	RandAugment(
		nb_augm_apply=args.nb_augm_apply,
		magnitude_policy="random",
		augm_pool=RAND_AUGMENT_POOL_2,
	),
	CutOutImgPIL(scales=(0.5, 0.5)),
	transform_post_process,
])

transform_val = transform_post_process

target_transform = lambda x: one_hot(torch.as_tensor(x), args.num_classes).numpy()

### Builds datasets

In [9]:
dataset_train_augm_weak = CIFAR10(
	dataset_root, train=True, download=True, transform=transform_train_augm_weak, target_transform=target_transform)
dataset_train_augm_strong = CIFAR10(
	dataset_root, train=True, download=True, transform=transform_train_augm_strong, target_transform=target_transform)

# Use 4000 data with labels (8%) and 46000 data without labels (92%)
supervised_ratio = args.nb_labels / len(dataset_train_augm_weak)
indexes_s, indexes_u = generate_indexes(
	dataset_train_augm_weak,
	nb_classes=args.num_classes,
	ratios=[supervised_ratio, 1.0 - supervised_ratio],
	target_one_hot=True,
)

dataset_train_augm_weak_s = Subset(dataset_train_augm_weak, indexes_s)
dataset_train_augm_weak_u = Subset(dataset_train_augm_weak, indexes_u)
dataset_train_augm_strong_u = Subset(dataset_train_augm_strong, indexes_u)

dataset_train_augms_weak_strong_u = ZipDataset([dataset_train_augm_weak_u, dataset_train_augm_strong_u])

# Create validation dataset
dataset_val = CIFAR10(
	dataset_root, train=False, download=True, transform=transform_val, target_transform=target_transform)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### ZipCycle class
Used for iterate on labeled and unlabeled dataloaders at the same time.

In [10]:
class ZipCycle(Iterable, Sized):
	def __init__(self, iterables: list, policy: str = "max"):
		"""
			Zip through a list of iterables and sized objects of different lengths.
			Reset the iterators when there and finish iteration when the longest one is over.

			Example :
				r1 = range(1, 4)
				r2 = range(1, 6)
				cycle = ZipCycle([r1, r2])
				for v1, v2 in cycle:
					print("(", v1, v2, ")")

			will print :
				( 1 1 )
				( 2 2 )
				( 3 3 )
				( 1 4 )
				( 2 5 )

			:param iterables: A list of Sized Iterables to browse.
			:param policy: The policy to use during iteration.
				If policy = "max", the output will stop when the last iterable is finished. (like in the example above)
				If policy = "min", the class will stop when the first iterable is finished. (like in the built-in "zip" python)
		"""
		assert policy in ["min", "max"]
		lens = [len(iterable) for iterable in iterables]
		for len_ in lens:
			if len_ == 0:
				raise RuntimeError("An iterable is empty.")

		self._iterables = iterables
		self._len = max(lens) if policy == "max" else min(lens)
		self._policy = policy

	def __iter__(self) -> list:
		cur_iters = [iter(iterable) for iterable in self._iterables]
		cur_count = [0 for _ in self._iterables]

		for _ in range(len(self)):
			items = []

			for i, _ in enumerate(cur_iters):
				if cur_count[i] < len(self._iterables[i]):
					item = next(cur_iters[i])
					cur_count[i] += 1
				else:
					cur_iters[i] = iter(self._iterables[i])
					item = next(cur_iters[i])
					cur_count[i] = 1
				items.append(item)

			yield items

	def __len__(self) -> int:
		return self._len

	def set_policy(self, policy: str):
		assert policy in ["min", "max"]
		lens = [len(iterable) for iterable in self._iterables]
		self._len = max(lens) if policy == "max" else min(lens)
		self._policy = policy

### Build loaders

In [11]:
bsize_s = args.bsize
bsize_u = args.bsize * args.mu

loader_train_s_augm = DataLoader(
	dataset=dataset_train_augm_weak_s, batch_size=bsize_s, shuffle=True, num_workers=1, drop_last=False)

loader_train_u_augms = DataLoader(
	dataset=dataset_train_augms_weak_strong_u, batch_size=bsize_u, shuffle=True, num_workers=args.mu, drop_last=False)

loader_train = ZipCycle([loader_train_s_augm, loader_train_u_augms], policy=args.loader_policy)

loader_val = DataLoader(dataset=dataset_val, batch_size=bsize_s, shuffle=False, drop_last=False)


## Criterion

### Cross Entropy with probabilities

Same as CrossEntropy but accept non-"onehot encoding" vectors as targets (labels)

In [12]:
class CrossEntropyWithVectors(Module):
	"""
		Compute Cross-Entropy between two distributions.
		Input and targets must be a batch of probabilities distributions of shape (batch_size, nb_classes) tensor.
	"""
	def __init__(self, reduction: str = "mean", dim: Optional[int] = 1, log_input: bool = False):
		super().__init__()
		self.reduce_fn = torch.mean if reduction == "mean" else lambda x: x
		self.dim = dim
		self.log_input = log_input

	def forward(self, input_: Tensor, targets: Tensor, dim: Optional[int] = None) -> Tensor:
		"""
			Compute cross-entropy with targets.
			Input and target must be a (batch_size, nb_classes) tensor.
		"""
		if dim is None:
			dim = self.dim
		if not self.log_input:
			input_ = torch.log(input_)
		loss = -torch.sum(input_ * targets, dim=dim)
		return self.reduce_fn(loss)

### FixMatch loss

In [13]:
class FixMatchLoss(Module):
	"""
		FixMatch loss module.

		Loss formula : loss = CE(pred_s, label_s) + lambda_u * mask * CE(pred_u, label_u)

		The mask used is 1 if the confidence prediction on weakly augmented data is above a specific threshold.
	"""

	def __init__(
		self,
		criterion_s: Callable = CrossEntropyWithVectors(reduction="none"),
		criterion_u: Callable = CrossEntropyWithVectors(reduction="none"),
	):
		"""
			:param criterion_s: The criterion used for labeled loss component.
			:param criterion_u: The criterion used for unlabeled loss component. No reduction must be applied.
		"""
		super().__init__()
		self.criterion_s = criterion_s
		self.criterion_u = criterion_u
		self.reduce_fn = torch.mean

	def forward(
		self,
		pred_s_augm_weak: Tensor,
		pred_u_augm_strong: Tensor,
		mask: Tensor,
		labels_s: Tensor,
		labels_u: Tensor,
		lambda_s: float = 1.0,
		lambda_u: float = 1.0,
	) -> (Tensor, Tensor, Tensor):
		"""
			Compute FixMatch loss.

			Generic :
				loss = lambda_s * mean(criterion_s(pred_s, labels_s)) + lambda_u * mean(criterion_u(pred_u, labels_u) * mask)

			:param pred_s_augm_weak: Output of the model for labeled batch s of shape (batch_size, nb_classes).
			:param pred_u_augm_strong: Output of the model for unlabeled batch u of shape (batch_size, nb_classes).
			:param mask: Binary confidence mask used to avoid using low-confidence labels as targets of shape (batch_size).
			:param labels_s: True label of labeled batch s of shape (batch_size, nb_classes).
			:param labels_u: Guessed label of unlabeled batch u of shape (batch_size, nb_classes).
			:param lambda_s: Coefficient used to multiply the supervised loss component.
			:param lambda_u: Coefficient used to multiply the unsupervised loss component.
		"""
		loss_s = self.criterion_s(pred_s_augm_weak, labels_s)

		loss_u = self.criterion_u(pred_u_augm_strong, labels_u)
		loss_u *= mask

		loss_s = self.reduce_fn(loss_s)
		loss_u = self.reduce_fn(loss_u)

		loss = lambda_s * loss_s + lambda_u * loss_u

		return loss, loss_s, loss_u

In [14]:
criterion = FixMatchLoss(
	criterion_s=CrossEntropyWithVectors(reduction="none"),
	criterion_u=CrossEntropyWithVectors(reduction="none")
)

## Training

In [15]:
def guess_label(batch_u_augm_weak: Tensor) -> (Tensor, Tensor):
	logits_u_augm_weak = model(batch_u_augm_weak)
	pred_u_augm_weak = activation(logits_u_augm_weak, dim=1)

	nb_classes = pred_u_augm_weak.shape[1]
	labels_u = one_hot(pred_u_augm_weak.argmax(dim=1), nb_classes)
	return labels_u, pred_u_augm_weak

In [16]:
def confidence_mask(pred_weak: Tensor, threshold: float, dim: int) -> Tensor:
	max_values, _ = pred_weak.max(dim=dim)
	return (max_values > threshold).float()

In [17]:
def train(epoch: int):
	model.train()

	metric_names = list(metrics_s.keys()) + list(metrics_u.keys()) + list(metrics_u_true.keys()) + \
		[f"train/{name}" for name in ["loss", "loss_s", "loss_u", "labels_used", "lr"]]
	continue_metrics = {name: IncrementalMean() for name in metric_names}

	continue_metrics["train/lr"].add(get_lr(optimizer))

	for i, ((batch_s_augm_weak, labels_s), ((batch_u_augm_weak, true_labels_u), (batch_u_augm_strong, _))) in enumerate(loader_train):
		batch_s_augm_weak = batch_s_augm_weak.to(device).float()
		labels_s = labels_s.to(device).float()
		batch_u_augm_weak = batch_u_augm_weak.to(device).float()
		batch_u_augm_strong = batch_u_augm_strong.to(device).float()
		true_labels_u = true_labels_u.to(device).float()

		# Guess label with prediction of weakly augment of u
		with torch.no_grad():
			labels_u, pred_u_augm_weak = guess_label(batch_u_augm_weak)
			mask = confidence_mask(pred_u_augm_weak, args.threshold, dim=1)

		optimizer.zero_grad()

		# Compute predictions
		logits_s_augm_weak = model(batch_s_augm_weak)
		logits_u_augm_strong = model(batch_u_augm_strong)

		pred_s_augm_weak = activation(logits_s_augm_weak, dim=1)
		pred_u_augm_strong = activation(logits_u_augm_strong, dim=1)

		# Update model
		loss, loss_s, loss_u = criterion(
			pred_s_augm_weak,
			pred_u_augm_strong,
			mask,
			labels_s,
			labels_u,
			lambda_u=args.lambda_u
		)

		loss.backward()
		optimizer.step()

		# Compute metrics
		with torch.no_grad():
			continue_metrics["train/loss"].add(loss.item())
			continue_metrics["train/loss_s"].add(loss_s.item())
			continue_metrics["train/loss_u"].add(loss_u.item())
			continue_metrics["train/labels_used"].add(mask.mean().item())

			for name, metric in metrics_s.items():
				score = metric(pred_s_augm_weak, labels_s)
				continue_metrics[name].add(score.item())

			for name, metric in metrics_u.items():
				score = metric(pred_u_augm_strong, labels_u)
				continue_metrics[name].add(score.item())

			for name, metric in metrics_u_true.items():
				score = metric(pred_u_augm_strong, true_labels_u)
				continue_metrics[name].add(score.item())

			current_values = {name: continue_metric.get_current() for name, continue_metric in continue_metrics.items()}
			printer.print_current_values(current_values, i, len(loader_train), epoch)

	for name, continue_metric in continue_metrics.items():
		writer.add_scalar(name, continue_metric.get_current(), epoch)

In [18]:
max_val_scores = {name: MaxTracker() for name in metrics_val.keys()}

def val(epoch: int):
	model.eval()

	continue_metrics = {name: IncrementalMean() for name in metrics_val.keys()}

	for i, (x, y) in enumerate(loader_val):
		x = x.to(device).float()
		y = y.to(device).float()

		# Compute logits
		logits = model(x)
		pred = activation(logits, dim=1)

		for name, metric in metrics_val.items():
			score = metric(pred, y)
			continue_metrics[name].add(score.item())

		current_values = {name: continue_metric.get_current() for name, continue_metric in continue_metrics.items()}
		printer.print_current_values(current_values, i, len(loader_val), epoch)

	for name, continue_metric in continue_metrics.items():
		writer.add_scalar(name, continue_metric.get_current(), epoch)
		max_val_scores[name].add(continue_metric.get_current())

## Start learning

In [19]:
for e in range(args.nb_epochs):
	train(e)
	with torch.no_grad():
		val(e)
	if scheduler is not None:
		scheduler.step()
	print("")

writer.close()

-      train       -   acc_s    -   acc_u    - acc_u_true - labels_use -    loss    -   loss_s   -   loss_u   -     lr     -  took (s)  -
- Epoch   1 - 100% - 2.4333e-01 - 4.7074e-01 - 1.8394e-01 - 5.6345e-04 - 2.0648e+00 - 2.0646e+00 - 2.4622e-04 - 3.0000e-02 -   12.36    -
-       val        -    acc     -  took (s)  -
- Epoch   1 - 100% - 3.4942e-01 -    2.56    -

-      train       -   acc_s    -   acc_u    - acc_u_true - labels_use -    loss    -   loss_s   -   loss_u   -     lr     -  took (s)  -
- Epoch   2 - 100% - 3.4208e-01 - 4.1621e-01 - 2.5332e-01 - 8.8852e-04 - 1.7543e+00 - 1.7531e+00 - 1.1513e-03 - 3.0000e-02 -   12.56    -
-       val        -    acc     -  took (s)  -
- Epoch   2 - 100% - 3.5659e-01 -    2.54    -

-      train       -   acc_s    -   acc_u    - acc_u_true - labels_use -    loss    -   loss_s   -   loss_u   -     lr     -  took (s)  -
- Epoch   3 - 100% - 4.0049e-01 - 4.1038e-01 - 2.7078e-01 - 3.2074e-03 - 1.6430e+00 - 1.6399e+00 - 3.1692e-03 - 3.0000e-

## Print validation results

In [20]:
print("Best val scores :")
for name, tracker in max_val_scores.items():
	print("{} : {} at epoch {}".format(name, tracker.get_current(), tracker.get_index()))

Best val scores :
val/acc : 0.8767914012738853 at epoch 988


In [21]:
print("Terminated")


Terminated
