# 🔥 MixMatch: A Holistic Approach to Semi-Supervised Learning

## 1. Import Libraries

In [14]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision
import torch.nn.functional as F

import os
import shutil
import time
import numpy as np

import models.wideresnet as models
from utils.logger import *
from utils.misc import *
from utils.accuracy import *

## 2. Hyper-Parameters

Hyper-Parameter를 정의해 둠. 대부분은 Original Paper를 기반으로 정의되어 있음. 사용자가 조절할 만한 Hyper-Parameter들은 아래와 같음.

- epochs : 학습시킬 epochs 지정
- batch_size : batch 크기.
- random_seed : random seed값 설정.
- lr : learning rate. batch_size에 따라서 조절 필요. batch_size가 k배 커지면 lr은 sqrt(k)배만큼 크게 만들면 좋음
- is_continue_train : 이전 saved best model을 불러와서 이어서 학습할지 여부 세팅

In [15]:
# Beta(alpha, alpha) parameter used to draw the MixUp coefficient in train().
alpha = 0.75
# Maximum weight of the unsupervised loss; scaled by ramp_up() in Each_Loss.
lambda_u = 75
# Number of optimisation steps that train() runs per epoch.
# train_iteration = 1024
# train_iteration = 256 
train_iteration = 512 
# Decay factor handed to exponential_moving_average as `alpha`.
ema_decay = 0.999
# Sharpening temperature: guessed labels are raised to the power 1/Temperature.
Temperature = 0.5

# When True, the EMA (teacher) model is updated every step and used for evaluation.
is_ema = True 
# When True, main() resumes from the saved 'model_best.pth.tar' checkpoint.
is_continue_train = False 
start_epoch = 0
# Total number of epochs; also the default ramp-up length of ramp_up().
epochs = 1000
# Total number of labeled training samples, split evenly across classes.
# num_labeled = 500 
num_labeled = 250 


use_interleaving = False # With interleaving disabled the effective batch size triples, so lr should be multiplied by sqrt(3)

# batch_size = 64 # Original Batch Size
batch_size = 128 # 2x the original batch size

# lr = 0.002 # Original lr (interleaving), batch_size = 64
# lr = 0.00346 # sqrt('batch size') * 'Original lr' (non-interleaving), batch_size = 64
# 0.002 * sqrt(3) * sqrt(2) ~= 0.0049 (non-interleaving and doubled batch size).
lr = 0.0049 # sqrt('batch size') * 'Original lr' (non-interleaving), batch_size = 128 




# Use CUDA
# Restrict this process to GPU 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
use_cuda = torch.cuda.is_available()

# Random seed
# NOTE(review): only NumPy's global RNG is seeded here; no torch seed is set in
# the visible code, so torch-side randomness (weight init, randperm, loader
# shuffling) is presumably not reproducible - confirm.
np.random.seed(512)

# best_acc = 0  # best test accuracy

# Per-channel statistics on a 0-1 scale; normalize() multiplies them by 255.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) 
CIFAR10_STD = (0.2471, 0.2435, 0.2616)



## 3. Functions for Augmentation 

Data Augmentation을 위한 Function들을 모아둠. Pytorch Data Loader를 통해 Data를 Loading할 때 마다 Real-Time으로 Augmentation을 진행함

- Image Normalization (Standard Scaling)
- Image Transpose (from CIFAR-10 shape to Pytorch Shape)
- Augmentation Method #1 : Image Padding & Cropping
- Augmentation Method #2 : Image Horizontal Flipping

In [16]:
# Dataset Normalization (Standard Scaling)
def normalize(x, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Standard-scale raw 0-255 pixel values.

    Args:
        x: Array-like of pixel values; it is copied into a float32 array, so
            the caller's data is left untouched.
        mean: Means on a 0-1 scale (multiplied by 255 here); broadcast against
            the trailing axis of ``x``.
        std: Standard deviations on a 0-1 scale, broadcast the same way.

    Returns:
        A new float32 array equal to ``(x - 255 * mean) / (255 * std)``.
    """
    pixels = np.array(x, np.float32)
    channel_mean = np.array(mean, np.float32)
    channel_std = np.array(std, np.float32)

    # Shift and scale in place on the private float32 copy.
    pixels -= channel_mean * 255
    pixels *= 1.0 / (255 * channel_std)
    return pixels

# Transpose Image Shape for Pytorch shape from CIFAR-10 Original shape
def transpose(x, source='NHWC', target='NCHW'):
    """Reorder the axes of ``x`` from the ``source`` layout to ``target``.

    Args:
        x: Array with one axis per character of ``source``.
        source: Axis names of ``x`` in their current order.
        target: The same axis names in the desired order.

    Returns:
        ``x`` with its axes permuted so that they follow ``target``.
    """
    # For every axis wanted in the output, find where it sits in the input.
    axis_order = list(map(source.index, target))
    return x.transpose(axis_order)

# Padding to image borders for cropping image
def pad(x, border=4):
    """Reflect-pad the last two axes of a 3-D array by ``border`` on each side.

    Args:
        x: 3-D array; the first axis is left unpadded.
        border: Number of reflected elements added before and after each of
            the last two axes.

    Returns:
        The padded array; each of the last two axes grows by ``2 * border``.
    """
    margins = [(0, 0)] + [(border, border)] * 2
    return np.pad(x, margins, mode='reflect')

# Augmentation Method 1 : Random Padding and Cropping 
class RandomPadandCrop(object):
    """Reflect-pad a 3-D image array and crop a random window out of it.

    The input is expected to be a 3-D array with the two spatial axes last.

    Args:
        output_size: Crop size, either an int (square crop) or a
            ``(height, width)`` tuple.
        border: Reflect padding added on every side of the two spatial axes
            before cropping. Defaults to 4, the value that was previously
            hard-coded.
    """

    def __init__(self, output_size, border=4):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
        self.border = border

    def __call__(self, x):
        """Return a random ``output_size`` crop of the padded input."""
        # Same reflect padding as the module-level pad() helper, inlined so the
        # border width is configurable per instance.
        width = self.border
        x = np.pad(x, [(0, 0), (width, width), (width, width)], mode='reflect')

        h, w = x.shape[1:]
        new_h, new_w = self.output_size

        # randint's upper bound is exclusive, so "+ 1" is required for the
        # last valid offset (h - new_h) to be reachable. Without it the crop
        # could never touch the bottom/right edge, and a crop as large as the
        # padded image raised ValueError.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        return x[:, top: top + new_h, left: left + new_w]

# Augmentation Method 2 : Random Horizontal Flipping 
class RandomFlip(object):
    """Reverse the last axis of a 3-D array with probability 0.5."""

    def __call__(self, x):
        """Return a copy of ``x``, flipped along its last axis half the time."""
        should_flip = np.random.rand() < 0.5
        view = x[:, :, ::-1] if should_flip else x
        # Always copy: the reversed view has negative strides, and callers get
        # an array that does not alias the input either way.
        return view.copy()

class ToTensor(object):
    """Transform that wraps a NumPy array as a ``torch.Tensor``."""

    def __call__(self, x):
        """Return ``torch.from_numpy(x)`` for the given NumPy array."""
        return torch.from_numpy(x)

def split_dataset(labels, num_train, num_valid_per_class=500):
    """Split sample indices into labeled, unlabeled and validation sets.

    Each class is shuffled independently with NumPy's global RNG; the first
    ``num_train / num_classes`` indices of a class become labeled samples, the
    last ``num_valid_per_class`` become validation samples, and everything in
    between is unlabeled. The three resulting lists are shuffled again.

    Args:
        labels: Sequence of class labels, one per sample. Any set of label
            values is accepted (they need not be 0..K-1).
        num_train: Total number of labeled samples over all classes.
        num_valid_per_class: Validation samples held out per class. Defaults
            to 500, the value that was previously hard-coded.

    Returns:
        Tuple ``(index_train_labeled, index_train_unlabeled, index_valid)`` of
        index lists into ``labels``.

    Raises:
        ValueError: If a class has too few samples, which would otherwise put
            the same sample into both the labeled and the validation set.
    """
    labels = np.array(labels)
    index_train_labeled = []
    index_train_unlabeled = []
    index_valid = []

    classes = np.unique(labels)
    num_labeled_per_class = int(num_train / len(classes))

    # np.unique returns sorted values, so for labels 0..K-1 the classes are
    # visited in the same order as before.
    for class_label in classes:
        index = np.where(labels == class_label)[0]
        if len(index) < num_labeled_per_class + num_valid_per_class:
            raise ValueError(
                f'class {class_label} has {len(index)} samples, fewer than '
                f'the {num_labeled_per_class} labeled + '
                f'{num_valid_per_class} validation samples requested'
            )
        np.random.shuffle(index)

        # An explicit split point (rather than index[-n:]) keeps the
        # validation slice empty when num_valid_per_class is 0.
        valid_start = len(index) - num_valid_per_class
        index_train_labeled.extend(index[:num_labeled_per_class])
        index_train_unlabeled.extend(index[num_labeled_per_class:valid_start])
        index_valid.extend(index[valid_start:])

    np.random.shuffle(index_train_labeled)
    np.random.shuffle(index_train_unlabeled)
    np.random.shuffle(index_valid)

    return index_train_labeled, index_train_unlabeled, index_valid

# Apply the augmentation twice to each unlabeled image (i.e. K=2)
class Multi_Augmentation:
    """Apply one transform twice to the same input (K=2 augmentations).

    Args:
        transform_method: Callable applied to the input on each of the two
            passes; a random transform therefore yields two different views.
    """

    def __init__(self, transform_method):
        self.transform_method = transform_method

    def __call__(self, inp):
        """Return a 2-tuple with two independently transformed views of ``inp``."""
        first_view, second_view = (self.transform_method(inp) for _ in range(2))
        return first_view, second_view



## 4. Functions for loading CIFAR-10 dataset 

Pytorch의 dataloader에서 사용할 함수들. Data를 불러오면서 augmentation을 수행해 줌. unlabeled data의 augmentation횟수는 original paper에 따라 K=2로 맞춰서 2번 augmentation수행함

In [17]:
# Load CIFAR-10 Labeled Images with Random Augmentation
class get_cifar10_labeled(torchvision.datasets.CIFAR10):
    """CIFAR-10 subset whose images are normalized and transposed up front.

    Args:
        path_cifar10: Dataset root directory passed to the torchvision base class.
        indexs: Optional indices selecting a subset of samples; ``None`` keeps
            the whole split.
        train: Passed through to the base class.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each target.
        download: Passed through to the base class.
    """

    def __init__(self, path_cifar10, indexs=None, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(
            path_cifar10,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]
        # Standard-scale once, then move from the dataset's layout to the
        # layout produced by transpose()'s defaults (NHWC -> NCHW).
        scaled = normalize(self.data)
        self.data = transpose(scaled)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index`` after the optional transforms."""
        data_x = self.data[index]
        target_y = self.targets[index]

        if self.transform is not None:
            data_x = self.transform(data_x)
        if self.target_transform is not None:
            target_y = self.target_transform(target_y)

        return data_x, target_y
    

# Load CIFAR-10 Unlabeled Images with Random Augmentation (K=2)
class get_cifar10_unlabeled(get_cifar10_labeled):
    """Same data as ``get_cifar10_labeled`` with every target replaced by -1.

    Takes the same arguments as the parent class, except that ``indexs`` has
    no default here.
    """

    def __init__(self, path_cifar10, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(
            path_cifar10,
            indexs,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        # Overwrite the real labels so they cannot be used during training.
        self.targets = np.array([-1] * len(self.targets))
        


## 5. Functions for MixMatch learning method

MixMatch를 학습할 때 필요한 중요한 기능들을 Function으로 정의해 둠
- Ramp-Up Function (weight balancing between supervised and unsupervised loss)
- Calculation of semi-supervised loss (Each_Loss)
- Exponential Moving Average(EMA) Function
- Interleaving Functions : 이번 Tutorial에서 확인할 함수. MixUp된 Labeled Data와 Unlabeled Data를 섞어줌.


In [18]:
# Ramp-up Function for balancing the weight between supervised loss and unsupervised loss
# - loss_total = loss_l + weight * loss_u
def ramp_up(current, rampup_length=epochs):
    """Linear ramp from 0 to 1 over ``rampup_length``.

    Args:
        current: Current position, in the same unit as ``rampup_length``.
        rampup_length: Length of the ramp; 0 disables ramping.

    Returns:
        ``current / rampup_length`` clipped to ``[0.0, 1.0]`` as a Python
        float, or ``1.0`` when ``rampup_length`` is 0.
    """
    if rampup_length != 0:
        progress = np.clip(current / rampup_length, 0.0, 1.0)
        return float(progress)
    return 1.0

class Each_Loss(object):
    """Compute the supervised and unsupervised MixMatch loss terms."""

    def __call__(self, pred_l, target_l, pred_u, target_u, epoch):
        """Return ``(loss_l, loss_u, weight)``.

        Args:
            pred_l: Logits for the labeled part of the batch.
            target_l: Soft targets for ``pred_l`` (same shape).
            pred_u: Logits for the unlabeled part of the batch.
            target_u: Soft targets for ``pred_u`` (same shape).
            epoch: Training progress passed to ``ramp_up``.

        Returns:
            The soft-target cross-entropy on the labeled part, the mean
            squared error between softmax probabilities and targets on the
            unlabeled part, and the weight ``lambda_u * ramp_up(epoch)``.
        """
        log_probs_l = F.log_softmax(pred_l, dim=1)
        supervised = -(log_probs_l * target_l).sum(dim=1).mean()

        squared_error = (torch.softmax(pred_u, dim=1) - target_u) ** 2
        unsupervised = squared_error.mean()

        return supervised, unsupervised, lambda_u * ramp_up(epoch)

class exponential_moving_average(object):
    """Keep a teacher model as an exponential moving average of a student.

    Args:
        student_model: Model being optimised.
        teacher_model: Model holding the averaged weights.
        alpha: Decay factor; each step computes
            ``teacher = alpha * teacher + (1 - alpha) * student``.

    On construction the student's state-dict tensors are overwritten with the
    teacher's values, so both models start from identical weights.
    """

    def __init__(self, student_model, teacher_model, alpha=0.999):
        self.model = student_model
        self.ema_model = teacher_model
        self.alpha = alpha
        self.student_params = [*student_model.state_dict().values()]
        self.teacher_params = [*teacher_model.state_dict().values()]
        # Per-step decay applied to the student in step(); tied to the
        # module-level learning rate.
        self.wd = 0.02 * lr

        # Copy direction is teacher -> student.
        for student_tensor, teacher_tensor in zip(self.student_params, self.teacher_params):
            student_tensor.data.copy_(teacher_tensor.data)

    def step(self):
        """Update the teacher in place, then decay the student's tensors.

        Only float32 entries are touched; entries of any other dtype are left
        unchanged in both models.
        """
        student_share = 1.0 - self.alpha
        for student_tensor, teacher_tensor in zip(self.student_params, self.teacher_params):
            if teacher_tensor.dtype != torch.float32:
                continue
            teacher_tensor.mul_(self.alpha)
            teacher_tensor.add_(student_tensor * student_share)
            # customized weight decay, applied after the teacher update
            student_tensor.mul_(1 - self.wd)

def interleave_offsets(batch, nu):
    """Return the boundaries that cut ``batch`` items into ``nu + 1`` groups.

    The groups are as equal as possible; when ``batch`` is not divisible by
    ``nu + 1`` the surplus items go, one each, to the last groups.

    Args:
        batch: Number of items to partition.
        nu: Number of groups minus one.

    Returns:
        List of ``nu + 2`` cumulative offsets, from 0 up to ``batch``.
    """
    num_groups = nu + 1
    base, surplus = divmod(batch, num_groups)
    first_enlarged = num_groups - surplus

    offsets = [0]
    for position in range(num_groups):
        size = base + 1 if position >= first_enlarged else base
        offsets.append(offsets[-1] + size)

    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    """Swap chunks between the first tensor and each of the others.

    Every tensor in ``xy`` is cut along dim 0 at the boundaries given by
    ``interleave_offsets(batch, len(xy) - 1)``. Chunk ``i`` of tensor 0 is then
    exchanged with chunk ``i`` of tensor ``i`` (for ``i >= 1``) and the chunks
    are concatenated back. Applying the function twice restores the input.

    Args:
        xy: List of tensors to mix.
        batch: Number of rows used to compute the chunk boundaries.

    Returns:
        List with one concatenated tensor per entry of ``xy``.
    """
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    bounds = list(zip(offsets, offsets[1:]))

    chunked = []
    for tensor in xy:
        chunked.append([tensor[start:stop] for start, stop in bounds])

    for i in range(1, nu + 1):
        chunked[0][i], chunked[i][i] = chunked[i][i], chunked[0][i]

    return [torch.cat(pieces, dim=0) for pieces in chunked]

## 6. Utility Functions (Tedious)

In [19]:
def save_checkpoint(state, is_best, checkpoint='./Results/', filename='checkpoint.pth.tar'):
    """Save ``state`` to disk and optionally mark it as the best model.

    Args:
        state: Object serialised with ``torch.save``.
        is_best: When true, the saved file is also copied to
            ``model_best.pth.tar`` in the same directory.
        checkpoint: Directory that receives the files; it must already exist.
        filename: Name of the checkpoint file inside ``checkpoint``.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    shutil.copyfile(target, os.path.join(checkpoint, 'model_best.pth.tar'))

## 7. Training Function

MixMatch구현에 있어서 가장 중요한 함수이다. 이 함수에서 MixMatch의 모든 기법이 사용되며, 알고리즘의 Sequence대로 구현이 되어있다.
또한 이 함수에 Interleaving이 존재한다. Batch-Normalization을 고려하여, Supervised Logits과 Unsupervised Logits을 Model에 각각 태우면 Interleaving이 필요하며,
그렇지 않고 한번에 Model에 태워 계산 할 경우 Interleaving은 필요하지 않다.

- 1. Data Augmentation
- 2. Label Guessing and Label Sharpening
- 3. MixUp 

In [20]:
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, each_loss, epoch, use_cuda, is_ema=False):
    """Run one MixMatch training epoch of ``train_iteration`` steps.

    Each step draws one labeled and one unlabeled batch (restarting a loader
    when it is exhausted), guesses and sharpens labels for the unlabeled
    images, applies MixUp over the concatenated batch, computes the
    semi-supervised loss and updates the model.

    Args:
        labeled_trainloader: Loader yielding ``(inputs, targets)``.
        unlabeled_trainloader: Loader yielding ``((view_1, view_2), _)``.
        model: Network being trained.
        optimizer: Optimizer for ``model``.
        ema_optimizer: Object whose ``step()`` updates the EMA model.
        each_loss: Callable returning ``(loss_l, loss_u, weight)``.
        epoch: Index of the current epoch, used for the loss ramp-up.
        use_cuda: Move the batches to the GPU when true.
        is_ema: ``ema_optimizer.step()`` is called only when this is ``True``.

    Returns:
        Tuple ``(total_loss_avg, labeled_loss_avg, unlabeled_loss_avg)``.
    """
    losses_total = AverageMeter()
    losses_l = AverageMeter()
    losses_u = AverageMeter()
    weights = AverageMeter()

    iter_train_labeled = iter(labeled_trainloader)
    iter_train_unlabeled = iter(unlabeled_trainloader)

    model.train()
    for batch_index in range(train_iteration):
        #########################################
        # 1. Data Augmentation
        #########################################
        # Catch only StopIteration: a bare `except` would also swallow
        # KeyboardInterrupt and genuine data-loading errors, silently
        # restarting the loader instead of surfacing them.
        try:
            inputs_l, targets_l = next(iter_train_labeled)
        except StopIteration:
            iter_train_labeled = iter(labeled_trainloader)
            inputs_l, targets_l = next(iter_train_labeled)

        try:
            (inputs_u, inputs_u2), _ = next(iter_train_unlabeled)
        except StopIteration:
            iter_train_unlabeled = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2), _ = next(iter_train_unlabeled)

        # Transform label to one-hot.
        # `batch_size` is local here: the size of the labeled batch just drawn.
        batch_size = inputs_l.size(0)
        targets_l = torch.zeros(batch_size, 10).scatter_(1, targets_l.view(-1,1).long(), 1)

        if use_cuda:
            inputs_l, targets_l = inputs_l.cuda(), targets_l.cuda(non_blocking=True)
            inputs_u = inputs_u.cuda()
            inputs_u2 = inputs_u2.cuda()

        #########################################
        # 2. Label Guessing & Label Sharpening
        #########################################
        with torch.no_grad():
            # Average the predictions over both augmented views, then sharpen
            # with the temperature and renormalise.
            outputs_u = model(inputs_u)
            outputs_u2 = model(inputs_u2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p**(1/Temperature)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        #########################################
        # 3. MixUp
        #########################################
        all_inputs = torch.cat([inputs_l, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_l, targets_u, targets_u], dim=0)

        Lambda_ = np.random.beta(alpha, alpha)
        # Keep the mixed sample closer to its first component.
        Lambda_ = max(Lambda_, 1-Lambda_)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = Lambda_ * input_a + (1 - Lambda_) * input_b
        mixed_target = Lambda_ * target_a + (1 - Lambda_) * target_b

        ########################################
        # 4. Interleaving or No-Interleaving
        ########################################
        if use_interleaving:
            # 1) interleave labeled and unlabeled images between batches so
            #    that every forward pass sees a mix of both.
            mixed_input = list(torch.split(mixed_input, batch_size))
            mixed_input = interleave(mixed_input, batch_size)

            # 2) + 3) one forward pass per chunk (labeled chunk first).
            logits = [model(mixed_chunk) for mixed_chunk in mixed_input]

            # 4) undo the interleaving so the labeled and unlabeled losses are
            #    computed on the right rows.
            logits = interleave(logits, batch_size)

            logits_l = logits[0]
            logits_u = torch.cat(logits[1:], dim=0)
        else:
            # No interleaving: a single forward pass over the whole mixed
            # batch. The effective batch size is k times larger, so the
            # learning rate should be scaled by sqrt(k).
            logits = model(mixed_input)
            split_logits = list(torch.split(logits, batch_size))
            logits_l = split_logits[0]
            logits_u = torch.cat(split_logits[1:], dim=0)

        ########################################
        # 5. Semi-Supervised Loss
        ########################################
        loss_l, loss_u, weight = each_loss(logits_l, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_index/train_iteration)

        loss_total = loss_l + weight * loss_u

        # record loss
        losses_total.update(loss_total.item(), inputs_l.size(0))
        losses_l.update(loss_l.item(), inputs_l.size(0))
        losses_u.update(loss_u.item(), inputs_l.size(0))
        weights.update(weight, inputs_l.size(0))

        ########################################
        # 6. Backpropagation
        ########################################
        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        ########################################
        # 7. EMA Learning for Teacher Model
        ########################################
        if is_ema is True:
            ema_optimizer.step()

        if batch_index % 10 == 0:
            print(f'{batch_index+1}/{train_iteration} : Total Loss {losses_total.avg:.3f}')

    return (losses_total.avg, losses_l.avg, losses_u.avg,)




## 8. Validation Function

In [21]:
def validate(valloader, model, criterion, use_cuda, mode):
    """Evaluate ``model`` on ``valloader`` and return loss and top-1 accuracy.

    Args:
        valloader: Loader yielding ``(inputs, targets)`` with integer class
            targets; it must support ``len()``.
        model: Network to evaluate (switched to eval mode here).
        criterion: Loss callable; it receives the logits and one-hot targets.
        use_cuda: Move the batches to the GPU when true.
        mode: Label supplied by the caller; currently unused.

    Returns:
        Tuple ``(average_loss, average_top1_accuracy)``.
    """
    losses_total = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    num_batches = len(valloader)

    with torch.no_grad():
        for batch_index, (inputs, targets) in enumerate(valloader):

            batch_size = inputs.size(0)

            # Keep the integer targets for the accuracy computation.
            ori_targets = targets
            targets = torch.zeros(batch_size, 10).scatter_(1, targets.view(-1,1).long(), 1)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
                # Previously moved unconditionally, which failed on CPU-only
                # machines even though use_cuda was False.
                ori_targets = ori_targets.cuda(non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # accuracy top-1 and top-5 (was topk=(1, 1), i.e. top-1 twice)
            prec1, prec5 = accuracy(outputs, ori_targets, topk=(1, 5))

            losses_total.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            if batch_index % 10 == 0:
                # The total is this loader's batch count, not the number of
                # training iterations per epoch.
                print(f'{batch_index+1}/{num_batches} : Total Loss {losses_total.avg:.3f}, Top1 Acc {top1.avg:.3f}')
    return (losses_total.avg, top1.avg)


## 9. 🌊Main - Start to Train the MixMatch!

In [23]:
def main():
    """Build the CIFAR-10 loaders and models, then train and evaluate MixMatch.

    Reads the module-level hyper-parameters, optionally resumes from
    ``model_best.pth.tar``, and for every epoch trains the student, evaluates
    on the labeled-train, validation and test loaders, appends the metrics to
    the log file and saves a checkpoint. ``start_epoch`` is a module-level
    global and is overwritten when resuming.
    """
    global start_epoch
    best_acc = 0

    output_folder = './Results/'
    path_cifar10 = './data'
    # Always defined, so the resume branch below cannot hit an unbound name.
    load_model = output_folder+'model_best.pth.tar'

    os.makedirs(output_folder, exist_ok=True)

    print('1. Load Dataset (CIFAR-10)')
    augmentation_train = transforms.Compose([
        RandomPadandCrop(32), # 32x32 Random Cropping
        RandomFlip(), # Random Horizontal Flipping
        ToTensor(),
    ])

    augmentation_test = transforms.Compose([
        ToTensor(),
    ])

    # Preparing CIFAR-10 Dataset (train, valid, test sets)
    cifar10_dataset = torchvision.datasets.CIFAR10(path_cifar10, train=True, download=True)
    train_labeled_index, train_unlabeled_index, valid_index = split_dataset(cifar10_dataset.targets, num_labeled)

    train_labeled_dataset = get_cifar10_labeled(path_cifar10, train_labeled_index, train=True, transform=augmentation_train)
    train_unlabeled_dataset = get_cifar10_unlabeled(path_cifar10, train_unlabeled_index, train=True, transform=Multi_Augmentation(augmentation_train))
    valid_dataset = get_cifar10_labeled(path_cifar10, valid_index, train=True, transform=augmentation_test, download=True)
    test_dataset = get_cifar10_labeled(path_cifar10, train=False, transform=augmentation_test, download=True)

    print (f"Labeled: {len(train_labeled_index)} Unlabeled: {len(train_unlabeled_index)} Validation: {len(valid_index)}")

    # Define DataLoader for pytorch
    train_loader_labeled = data.DataLoader(train_labeled_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
    train_loader_unlabeled = data.DataLoader(train_unlabeled_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
    valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Define Model (WideResNet)
    student_model = models.WideResNet(num_classes=10)
    teacher_model = models.WideResNet(num_classes=10)
    # Honour use_cuda instead of calling .cuda() unconditionally, which
    # crashed on machines without a GPU.
    if use_cuda:
        student_model = student_model.cuda()
        teacher_model = teacher_model.cuda()

    # The teacher is only updated through the EMA, never by backpropagation.
    for param in teacher_model.parameters():
        param.detach_()

    train_criterion = Each_Loss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    ema_optimizer = exponential_moving_average(student_model, teacher_model, alpha=ema_decay)

    # Load Model & Logger Setting
    title = 'MixMatch Semi-Supervised Learning'
    if is_continue_train:
        print(' > Load Checkpoint')
        output_folder = os.path.dirname(load_model)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets a GPU-saved checkpoint be resumed on CPU.
        checkpoint = torch.load(load_model, map_location=None if use_cuda else 'cpu')
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        student_model.load_state_dict(checkpoint['state_dict'])
        teacher_model.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(output_folder, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(output_folder, 'log.txt'), title=title)
        logger.set_names(['Train Loss', 'Train Loss X', 'Train Loss U',  'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.'])

    test_accs = []

    # Evaluate the EMA teacher when EMA is enabled, otherwise the student.
    use_ema = is_ema is True
    eval_model = teacher_model if use_ema else student_model

    print('2. Start to train the MixMatch Model')
    ####################################################
    # Train and Validation
    ####################################################
    for epoch in range(start_epoch, epochs):

        print(f'Epoch: {epoch+1} / {epochs}\n')

        train_loss, train_loss_x, train_loss_u = train(train_loader_labeled, train_loader_unlabeled, student_model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda, use_ema)
        # The accuracy on the labeled training set is only printed, not logged.
        validate(train_loader_labeled, eval_model, criterion, use_cuda, mode='Train Accuracy')
        val_loss, val_acc = validate(valid_loader, eval_model, criterion, use_cuda, mode='Valid Accuracy')
        test_loss, test_acc = validate(test_loader, eval_model, criterion, use_cuda, mode='Test Accuracy')

        # append logger file
        logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc])

        # save model (into the same folder as the log file)
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': student_model.state_dict(),
                'ema_state_dict': teacher_model.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best, checkpoint=output_folder)
        test_accs.append(test_acc)
    logger.close()

    print('Best acc:')
    print(best_acc)

    print('Mean acc:')
    print(np.mean(test_accs[-20:]))

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


1. Load Dataset (CIFAR-10)
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Labeled: 250 Unlabeled: 44750 Validation: 5000
Epoch: 1 / 1000

1/512 : Total Loss 2.300
11/512 : Total Loss 2.187
21/512 : Total Loss 2.090
31/512 : Total Loss 2.017
41/512 : Total Loss 1.946
51/512 : Total Loss 1.902
61/512 : Total Loss 1.865
71/512 : Total Loss 1.833
81/512 : Total Loss 1.761
91/512 : Total Loss 1.760
101/512 : Total Loss 1.723
111/512 : Total Loss 1.708
121/512 : Total Loss 1.680
131/512 : Total Loss 1.662
141/512 : Total Loss 1.659
151/512 : Total Loss 1.653
161/512 : Total Loss 1.646
171/512 : Total Loss 1.634
181/512 : Total Loss 1.620
191/512 : Total Loss 1.616
201/512 : Total Loss 1.604
211/512 : Total Loss 1.583
221/512 : Total Loss 1.574
231/512 : Total Loss 1.548
241/512 : Total Loss 1.530
251/512 : Total Loss 1.510
261/512 : Total Loss 1.493
271/512 : Total Loss 1.488
281/512 : Total Loss 1.476
291/512 : Total Loss 1.

KeyboardInterrupt: 