# <Semi-supervised learning tutorial 3 - consistency regularization>

In [1]:
import os
import math
import random
import time
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import easydict
from tqdm import tqdm
from PIL import Image

from augmentation import RandAugmentCIFAR
from models import WideResNet, ModelEMA

### 1. 하이퍼파라미터 세팅

In [2]:
# Global hyperparameters for the notebook, reachable as attributes
# (args.lr, args.T, ...). Trailing numbers in comments are presumably the
# full-scale settings that were reduced here to shorten the run.
args = easydict.EasyDict(
    # run control and semi-supervised loss
    seed=0,
    gpu=0,
    start_step=0,
    total_steps=2000,  # 30000
    eval_step=20,  # 100
    lambda_u=1,
    threshold=0.95,
    T=0.6,

    # for data
    data_path="./data",
    num_data=10000,  # 50000
    num_labeled=1000,  # 5000
    num_classes=10,  # number of classes
    resize=32,  # resize image
    batch_size=64,
    mu=1,  # coefficient of unlabeled batch size

    # for WideResNet model
    depth=10,
    widen_factor=1,
    teacher_dropout=0,  # dropout on last dense layer of teacher model
    student_dropout=0,  # dropout on last dense layer of student model

    # for optimizing
    lr=0.01,  # train learning rate of model
    momentum=0.9,  # SGD Momentum
    nesterov=True,  # use nesterov
    weight_decay=0.01,  # train weight decay
)

In [3]:
# Use the configured GPU when CUDA is usable; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
args.device = (torch.device('cuda', args.gpu)
               if torch.cuda.is_available()
               else torch.device('cpu'))

In [4]:
# Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) generators with
# the same value.
for _seed_fn in (random.seed, np.random.seed,
                 torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

### 2. 데이터셋 & 데이터로더 준비

In [5]:
# Raw CIFAR-10 splits without transforms. The training split is downloaded
# if missing; the test split is expected to be on disk already.
base_dataset = datasets.CIFAR10(root=args.data_path, train=True, download=True)
test_dataset = datasets.CIFAR10(root=args.data_path, train=False, download=False)

Files already downloaded and verified


In [6]:
def l_u_split(args, labels):
    """Split sample indices into a class-balanced labeled set and an unlabeled set.

    Args:
        args: config object; reads ``num_labeled``, ``num_classes`` and
            ``num_data``.
        labels: sequence of integer class labels, one per sample.

    Returns:
        Tuple ``(labeled_idx, unlabeled_idx)`` of shuffled integer index
        arrays. ``labeled_idx`` holds ``num_labeled // num_classes`` samples
        per class; ``unlabeled_idx`` holds up to
        ``(num_data // num_classes - num_labeled // num_classes) * num_classes``
        of the remaining samples.

    Raises:
        ValueError: propagated from ``np.random.choice`` when a class has
            fewer samples than the per-class labeled quota.
    """
    label_per_class = args.num_labeled // args.num_classes
    # The unlabeled pool is capped so that a reduced dataset size can be used
    # to shorten training time.
    num_unlabel_data = ((args.num_data // args.num_classes) - label_per_class) * args.num_classes

    print(f'클래스별 labeled data 개수 : {label_per_class}')
    print(f'Labeled data 개수 : {label_per_class * args.num_classes}')
    print(f'Unlabeled data 개수 : {num_unlabel_data}')

    labels = np.array(labels)

    # Draw the same number of labeled samples from every class, without
    # replacement.
    labeled_idx = []
    for i in range(args.num_classes):
        idx = np.where(labels == i)[0]
        idx = np.random.choice(idx, label_per_class, False)
        labeled_idx.extend(idx)
    # Explicit integer dtype keeps the array usable as an index even when empty.
    labeled_idx = np.array(labeled_idx, dtype=np.int64)
    np.random.shuffle(labeled_idx)

    # Everything not picked as labeled is a candidate for the unlabeled set.
    # A boolean mask gives the same ascending index array as filtering with
    # ``i not in labeled_idx`` but in linear rather than quadratic time.
    is_unlabeled = np.ones(len(labels), dtype=bool)
    is_unlabeled[labeled_idx] = False
    unlabeled_idx = np.flatnonzero(is_unlabeled)
    np.random.shuffle(unlabeled_idx)
    unlabeled_idx = unlabeled_idx[:num_unlabel_data]

    return labeled_idx, unlabeled_idx

In [7]:
# Split the training-set indices into labeled / unlabeled index arrays using
# the class labels of the raw training split.
labeled_idxs, unlabeled_idxs = l_u_split(args, base_dataset.targets)

클래스별 labeled data 개수 : 100
Labeled data 개수 : 1000
Unlabeled data 개수 : 9000


In [8]:
# Per-channel mean and standard deviation used for normalization
cifar10_mean = (0.491400, 0.482158, 0.4465231)
cifar10_std = (0.247032, 0.243485, 0.2615877)

# Final steps shared by the labeled and test pipelines: PIL image -> tensor,
# then per-channel normalization.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
]

# Transform for the labeled dataset: random horizontal flip and padded random
# crop, followed by the shared tensor conversion and normalization.
transform_labeled = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(
            size=args.resize,
            padding=int(args.resize * 0.125),
            fill=128,
            padding_mode='constant',
        ),
    ]
    + _to_normalized_tensor
)

# Transform for the test dataset: no augmentation, only tensor conversion and
# normalization.
transform_test = transforms.Compose(list(_to_normalized_tensor))

# Transform for the unlabeled dataset: a custom callable that returns two
# differently augmented views of the same image.
class CustomTransform(object):
    """Produce two augmented, normalized views of one input image.

    The first view uses random horizontal flip + padded random crop; the
    second view applies the same two operations followed by
    ``RandAugmentCIFAR``.

    Args:
        args: config object; only ``args.resize`` is read (crop size, with
            padding ``int(args.resize * 0.125)``).
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel std passed to ``transforms.Normalize``.
        n: forwarded to ``RandAugmentCIFAR(n=...)``. Defaults to 5.
        m: forwarded to ``RandAugmentCIFAR(m=...)``. Defaults to 10.
            NOTE(review): presumably number of operations and magnitude;
            confirm against the augmentation module.
    """

    def __init__(self, args, mean, std, n=5, m=10):
        # View 1: flip + crop only.
        self.ori = transforms.Compose(self._weak_ops(args))

        # View 2: the same flip + crop, then RandAugment on top.
        self.aug = transforms.Compose(
            self._weak_ops(args) + [RandAugmentCIFAR(n=n, m=m)])

        # Applied to both views after augmentation.
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    @staticmethod
    def _weak_ops(args):
        """Return a fresh list with the flip and padded-crop transforms."""
        return [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=args.resize,
                                  padding=int(args.resize * 0.125),
                                  fill=128,
                                  padding_mode='constant'),
        ]

    def __call__(self, x):
        """Return ``(normalized first view, normalized second view)`` of ``x``."""
        ori = self.ori(x)
        aug = self.aug(x)
        return self.normalize(ori), self.normalize(aug)
    
# Two-view transform applied to every unlabeled image.
transform_unlabeled = CustomTransform(args, mean=cifar10_mean, std=cifar10_std)

In [9]:
class CustomCIFAR10SSL(datasets.CIFAR10):
    """CIFAR-10 dataset restricted to a chosen subset of sample indices.

    Args:
        root: dataset root directory, forwarded to ``datasets.CIFAR10``.
        indexs: index array selecting the samples to keep, or ``None`` to
            keep every sample.
        train: forwarded to ``datasets.CIFAR10``.
        transform: callable applied to the PIL image in ``__getitem__``.
        target_transform: callable applied to the target in ``__getitem__``.
        download: forwarded to ``datasets.CIFAR10``.
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None, download=False):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            # Keep only the selected samples; targets become a NumPy array so
            # they can be fancy-indexed.
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        """Return ``(transformed image, target)`` for the sample at ``index``."""
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        # Honour target_transform, which the constructor accepts and forwards
        # to the base class.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

In [11]:
# Training subsets: labeled samples get the single-view transform, unlabeled
# samples get the two-view transform.
labeled_dataset = CustomCIFAR10SSL(
    args.data_path,
    labeled_idxs,
    train=True,
    transform=transform_labeled,
)
unlabeled_dataset = CustomCIFAR10SSL(
    args.data_path,
    unlabeled_idxs,
    train=True,
    transform=transform_unlabeled,
)

# Full test split, this time with the test transform attached.
test_dataset = datasets.CIFAR10(
    args.data_path,
    train=False,
    transform=transform_test,
    download=False,
)

In [12]:
# Training loaders sample randomly and drop the last incomplete batch; the
# unlabeled batch is args.mu times the labeled batch size.
labeled_loader = DataLoader(
    labeled_dataset,
    batch_size=args.batch_size,
    sampler=RandomSampler(labeled_dataset),
    drop_last=True,
)
unlabeled_loader = DataLoader(
    unlabeled_dataset,
    batch_size=args.batch_size * args.mu,
    sampler=RandomSampler(unlabeled_dataset),
    drop_last=True,
)

# The test loader walks the test set once, in order, keeping every sample.
test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    sampler=SequentialSampler(test_dataset),
)

### 3. Labeled & unlabeled 데이터셋을 사용한 semi-supervised learning
### 3-1 Consistency regularization 예시1

In [13]:
# Build the network and place it on the target device.
model = WideResNet(
    num_classes=args.num_classes,
    depth=args.depth,
    widen_factor=args.widen_factor,
    dropout=0,
    dense_dropout=args.teacher_dropout,
).to(args.device)

# SGD with momentum.
# NOTE(review): args.weight_decay is defined but not passed here - confirm
# whether that is intentional.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    nesterov=args.nesterov,
)

In [14]:
def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, (re)starting ``loader`` when needed.

    ``batch_iter`` may be ``None`` (not started yet) or an exhausted iterator;
    in both cases a fresh iterator over ``loader`` is created.
    """
    if batch_iter is None:
        batch_iter = iter(loader)
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        # The loader has been fully consumed: start another pass over it.
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def train_consistency_regularization(args, model, optimizer):
    """Train ``model`` with cross entropy plus a consistency-regularization term.

    Each step draws one labeled batch and one unlabeled batch (two augmented
    views per image) from the module-level ``labeled_loader`` and
    ``unlabeled_loader``. The loss is the labeled cross entropy plus
    ``args.lambda_u`` times a masked soft cross entropy between the
    temperature-scaled prediction on the first unlabeled view (used as a
    fixed pseudo label) and the prediction on the second view.

    Args:
        args: config object; reads ``start_step``, ``total_steps``,
            ``eval_step``, ``device``, ``T``, ``threshold`` and ``lambda_u``.
        model: network mapping an image batch to class logits.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        None. Averaged losses are printed every ``args.eval_step`` steps and
        the elapsed time is printed at the end.
    """
    since = time.time()

    # Iterators are created lazily on first use and restarted whenever a
    # loader is exhausted.
    labeled_iter = None
    unlabeled_iter = None

    # Running loss buffers; initialised up front so they exist even when
    # start_step is not a multiple of eval_step.
    losses, ce_losses, cr_losses = [], [], []

    for step in range(args.start_step, args.total_steps):
        if step % args.eval_step == 0:
            # Only report when something has been accumulated since the last
            # report (nothing has on the very first step).
            if losses:
                print('{} Step - loss: {:.4f} cross entropy : {:.4f} consistency reg : {:.4f}'.format(
                    step,
                    np.mean(losses),
                    np.mean(ce_losses),
                    np.mean(cr_losses)))

            losses, ce_losses, cr_losses = [], [], []

        model.train()

        (images_l, targets), labeled_iter = _next_batch(labeled_iter, labeled_loader)
        ((images_uw, images_us), _), unlabeled_iter = _next_batch(unlabeled_iter, unlabeled_loader)

        images_l = images_l.to(args.device)
        images_uw = images_uw.to(args.device)
        images_us = images_us.to(args.device)
        targets = targets.to(args.device, dtype=torch.long)

        # Reset parameter gradients
        optimizer.zero_grad()

        # Single forward pass over labeled + both unlabeled views, then split
        # the logits back into the three groups.
        batch_size = images_l.shape[0]
        images = torch.cat((images_l, images_uw, images_us))
        logits = model(images)
        logits_l = logits[:batch_size]
        logits_uw, logits_us = logits[batch_size:].chunk(2)
        del logits

        loss_l = F.cross_entropy(logits_l, targets, reduction='mean')

        # Pseudo label: temperature-scaled softmax of the first view, detached
        # so no gradient flows through the target.
        soft_pseudo_label = torch.softmax(logits_uw.detach() / args.T, dim=-1)
        max_probs, _ = torch.max(soft_pseudo_label, dim=-1)
        # Only samples whose top probability reaches the threshold contribute.
        mask = max_probs.ge(args.threshold).float()

        # Soft cross entropy between pseudo label and second-view prediction,
        # averaged over the whole unlabeled batch (masked samples count as 0).
        loss_u = (-(soft_pseudo_label * torch.log_softmax(logits_us, dim=-1)).sum(dim=-1) * mask).mean()
        loss = loss_l + args.lambda_u * loss_u

        # backward
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        ce_losses.append(loss_l.item())
        cr_losses.append(loss_u.item())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [15]:
# Run the training loop; averaged losses are printed every args.eval_step steps.
train_consistency_regularization(args, model, optimizer)

20 Step - loss: 2.2874 cross entropy : 2.2874 consistency reg : 0.0000
40 Step - loss: 2.0385 cross entropy : 2.0366 consistency reg : 0.0019
60 Step - loss: 1.9537 cross entropy : 1.9537 consistency reg : 0.0000
80 Step - loss: 1.8988 cross entropy : 1.8954 consistency reg : 0.0033
100 Step - loss: 1.8394 cross entropy : 1.8384 consistency reg : 0.0010
120 Step - loss: 1.8435 cross entropy : 1.8408 consistency reg : 0.0028
140 Step - loss: 1.7899 cross entropy : 1.7887 consistency reg : 0.0013
160 Step - loss: 1.7733 cross entropy : 1.7694 consistency reg : 0.0039
180 Step - loss: 1.7618 cross entropy : 1.7586 consistency reg : 0.0032
200 Step - loss: 1.7248 cross entropy : 1.7200 consistency reg : 0.0048
220 Step - loss: 1.6829 cross entropy : 1.6807 consistency reg : 0.0022
240 Step - loss: 1.6999 cross entropy : 1.6862 consistency reg : 0.0136
260 Step - loss: 1.6755 cross entropy : 1.6683 consistency reg : 0.0071
280 Step - loss: 1.6564 cross entropy : 1.6541 consistency reg : 0.0

In [16]:
def test(args, model, loader):
    """Evaluate ``model`` on ``loader`` and print its classification accuracy.

    Args:
        args: config object; only ``args.device`` is read.
        model: network mapping an input batch to class scores.
        loader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        None. Prints ``Testing Acc: <value>``; the value is ``nan`` when the
        loader yields no samples. The model is left in eval mode.
    """
    model.eval()

    corrects = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device, dtype=torch.long)

            # forward
            outputs = model(inputs)

            # Predict the class at the position of the maximum output
            _, preds = torch.max(outputs, 1)

            # Accumulate the number of correct predictions per batch as a
            # plain Python int so the final division never depends on
            # tensor-only methods.
            corrects += int(torch.sum(preds == targets).item())
            total += targets.size(0)

    # Guard against an empty loader instead of failing on the division.
    test_acc = corrects / total if total else float('nan')
    print('Testing Acc: {:.4f}'.format(test_acc))

In [17]:
# Report accuracy of the trained model on the CIFAR-10 test loader.
test(args, model, test_loader)

Testing Acc: 0.4319
