In [1]:
!pip install easydict



In [2]:
import copy
import os
import math
import random
import time
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import easydict
from tqdm import tqdm
from PIL import Image

from augmentation import RandAugmentCIFAR
from models import WideResNet, ModelEMA

In [3]:
# Experiment configuration; every tunable knob of the notebook lives here.
args = easydict.EasyDict(
    seed=0,
    gpu=0,
    start_step=0,
    total_steps=30000,
    eval_step=100,       # print averaged losses every `eval_step` steps
    lambda_u=0.01,       # weight of the student's pseudo-label (unlabeled) loss

    # data
    data_path="./data",
    num_labeled=5000,    # number of labeled training images
    num_classes=10,      # number of classes
    resize=32,           # crop size used by the augmentation pipelines
    batch_size=64,
    mu=1,                # unlabeled batch size = mu * batch_size

    # WideResNet model
    depth=28,
    widen_factor=2,
    teacher_dropout=0,   # dropout on the last dense layer of the teacher model
    student_dropout=0,   # dropout on the last dense layer of the student model

    # optimization
    teacher_lr=0.01,     # learning rate of the teacher model
    student_lr=0.01,     # learning rate of the student model
    momentum=0.9,        # SGD momentum
    nesterov=True,       # use Nesterov momentum
    weight_decay=0,      # SGD weight decay
)

In [4]:
# Use the configured GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a GPU instead of failing.
args.device = torch.device('cuda', args.gpu) if torch.cuda.is_available() else torch.device('cpu')

In [5]:
# Seed every RNG the notebook draws from (Python, NumPy, PyTorch CPU and all CUDA
# devices) so the data split and weight initialization are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

In [6]:
# Raw CIFAR-10 splits without transforms. `base_dataset` is only needed for its
# labels when building the labeled/unlabeled split; `test_dataset` is rebuilt with
# the evaluation transform further down.
base_dataset = datasets.CIFAR10(root=args.data_path, train=True, download=True)
test_dataset = datasets.CIFAR10(root=args.data_path, train=False, download=False)

Files already downloaded and verified


In [7]:
def l_u_split(args, labels, exclude_labeled=False):
    """Split training indices into a class-balanced labeled set and an unlabeled set.

    Args:
        args: config object providing ``num_labeled`` and ``num_classes``.
        labels: sequence of integer class labels, one per training sample.
        exclude_labeled: if False (default), the unlabeled set is *all* training
            indices, including the labeled ones. If True, the labeled indices are
            removed from the unlabeled set.

    Returns:
        ``(labeled_idx, unlabeled_idx)`` as 1-D integer index arrays.
        ``labeled_idx`` holds ``num_labeled // num_classes`` randomly chosen
        indices per class (any remainder of the division is dropped), shuffled.

    Raises:
        ValueError: if ``num_labeled // num_classes`` is 0, or (raised by
            ``np.random.choice``) a class has fewer samples than requested.

    Uses NumPy's global RNG, so the result depends on ``np.random.seed``.
    """
    label_per_class = args.num_labeled // args.num_classes
    if label_per_class <= 0:
        # Otherwise an empty (float-typed) index array is returned, which breaks
        # dataset indexing later with a far less obvious error.
        raise ValueError(
            f"num_labeled ({args.num_labeled}) must be at least num_classes ({args.num_classes})")
    labels = np.asarray(labels)

    labeled_idx = []
    for i in range(args.num_classes):
        class_idx = np.where(labels == i)[0]
        labeled_idx.extend(np.random.choice(class_idx, label_per_class, replace=False))
    labeled_idx = np.array(labeled_idx)
    np.random.shuffle(labeled_idx)

    unlabeled_idx = np.arange(len(labels))  # default: all training data
    if exclude_labeled:
        # Vectorized set difference instead of a per-element `in` scan (O(n*m)).
        unlabeled_idx = np.setdiff1d(unlabeled_idx, labeled_idx)
    return labeled_idx, unlabeled_idx

In [8]:
labeled_idxs, unlabeled_idxs = l_u_split(args, base_dataset.targets)

In [9]:
# Per-channel CIFAR-10 normalization statistics.
cifar10_mean = (0.491400, 0.482158, 0.4465231)
cifar10_std = (0.247032, 0.243485, 0.2615877)

_crop_padding = int(args.resize * 0.125)  # 12.5% of the crop size (4 px for 32)

# Labeled training images: flip + padded random crop, then tensor + normalize.
transform_labeled = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(size=args.resize, padding=_crop_padding,
                          fill=128, padding_mode='constant'),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

# Evaluation images: no augmentation, only tensor conversion + normalization.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

class CustomTransform(object):
    """Produce a (weak, strong) augmented pair from one PIL image.

    Both views get a random horizontal flip and a padded random crop; the strong
    view additionally goes through ``RandAugmentCIFAR(n=n, m=m)``. Each view is
    then converted to a tensor and normalized with ``mean`` / ``std``.

    Args:
        args: config object providing ``resize`` (crop size; padding is 12.5% of it).
        mean, std: per-channel normalization statistics.
        n, m: RandAugmentCIFAR parameters; defaults (2, 10) match the previously
            hard-coded values.
    """

    def __init__(self, args, mean, std, n=2, m=10):
        def flip_and_crop():
            # Weak augmentation steps shared by both views (one fresh list per view).
            return [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(size=args.resize,
                                      padding=int(args.resize * 0.125),
                                      fill=128,
                                      padding_mode='constant'),
            ]

        self.ori = transforms.Compose(flip_and_crop())
        self.aug = transforms.Compose(flip_and_crop() + [RandAugmentCIFAR(n=n, m=m)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        """Return ``(normalize(weak(x)), normalize(strong(x)))`` for image ``x``."""
        ori = self.ori(x)
        aug = self.aug(x)
        return self.normalize(ori), self.normalize(aug)

In [10]:
class CustomCIFAR10SSL(datasets.CIFAR10):
    """CIFAR-10 restricted to a subset of sample indices.

    Args:
        root, train, transform, target_transform, download: forwarded to
            ``torchvision.datasets.CIFAR10``.
        indexs: index array selecting which samples to keep, in that order
            (``None`` keeps every sample).
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None, download=False):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        """Return ``(image, target)`` with ``transform`` / ``target_transform`` applied."""
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)
        # target_transform is accepted by __init__ and forwarded to the parent, so
        # apply it here too instead of silently ignoring it.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

In [11]:
# Labeled subset with weak augmentation; yields (image, label).
labeled_dataset = CustomCIFAR10SSL(
    root=args.data_path, indexs=labeled_idxs, train=True,
    transform=transform_labeled,
)
# Unlabeled subset; CustomTransform yields a (weak, strong) view pair per image.
unlabeled_dataset = CustomCIFAR10SSL(
    root=args.data_path, indexs=unlabeled_idxs, train=True,
    transform=CustomTransform(args, mean=cifar10_mean, std=cifar10_std),
)
# Test split with the evaluation transform (replaces the untransformed test_dataset).
test_dataset = datasets.CIFAR10(
    root=args.data_path, train=False, transform=transform_val, download=False,
)

In [12]:
# Training loaders sample randomly and drop the last partial batch; the unlabeled
# batch is `mu` times larger than the labeled one.
labeled_loader = DataLoader(
    labeled_dataset,
    batch_size=args.batch_size,
    sampler=RandomSampler(labeled_dataset),
    drop_last=True,
)
unlabeled_loader = DataLoader(
    unlabeled_dataset,
    batch_size=args.mu * args.batch_size,
    sampler=RandomSampler(unlabeled_dataset),
    drop_last=True,
)
# Evaluation loader: fixed order, keeps every sample.
test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    sampler=SequentialSampler(test_dataset),
)

# Supervised learning

In [13]:
# Supervised-only teacher; dense_dropout is the dropout on its last dense layer.
teacher_model = WideResNet(
    num_classes=args.num_classes,
    depth=args.depth,
    widen_factor=args.widen_factor,
    dropout=0,
    dense_dropout=args.teacher_dropout,
).to(args.device)
n_params_millions = sum(p.numel() for p in teacher_model.parameters()) / 1e6
print(f"Params: {n_params_millions:.2f}M")
# Unit prefixes: K = kilo (1e3), M = mega (1e6, million), G = giga (1e9, billion).

Params: 1.47M


In [14]:
criterion = nn.CrossEntropyLoss()
# Forward the configured weight decay too, so args.weight_decay actually takes
# effect (it is 0 in the default config, so default results are unchanged).
optimizer = optim.SGD(teacher_model.parameters(),
                      lr=args.teacher_lr,
                      momentum=args.momentum,
                      nesterov=args.nesterov,
                      weight_decay=args.weight_decay)

In [15]:
# Supervised baseline: train the teacher on the labeled subset for 100 epochs.
since = time.time()

for epoch in range(100):
    teacher_model.train()  # put the model in training mode
    loss_sum, n_seen = 0, 0

    for inputs, targets in labeled_loader:
        inputs = inputs.to(args.device)
        targets = targets.to(args.device, dtype=torch.long)

        optimizer.zero_grad()              # reset gradients from the previous step
        outputs = teacher_model(inputs)    # forward
        loss = criterion(outputs, targets)
        loss.backward()                    # backward
        optimizer.step()

        # Weight each batch loss by its size so the epoch mean is per-sample.
        batch_n = inputs.size(0)
        loss_sum += loss.item() * batch_n
        n_seen += batch_n

    # Mean loss over the epoch.
    epoch_loss = loss_sum / n_seen
    print(f'{epoch+1} Loss : {epoch_loss:.4f}')

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

1 Loss : 1.9969
2 Loss : 1.7568
3 Loss : 1.6429
4 Loss : 1.5526
5 Loss : 1.4886
6 Loss : 1.4297
7 Loss : 1.3584
8 Loss : 1.3083
9 Loss : 1.2663
10 Loss : 1.2027
11 Loss : 1.1688
12 Loss : 1.1478
13 Loss : 1.1140
14 Loss : 1.0737
15 Loss : 1.0463
16 Loss : 1.0108
17 Loss : 1.0036
18 Loss : 0.9770
19 Loss : 0.9250
20 Loss : 0.9342
21 Loss : 0.9000
22 Loss : 0.8721
23 Loss : 0.8501
24 Loss : 0.8400
25 Loss : 0.8084
26 Loss : 0.7942
27 Loss : 0.7591
28 Loss : 0.7505
29 Loss : 0.7238
30 Loss : 0.7215
31 Loss : 0.7000
32 Loss : 0.6843
33 Loss : 0.6584
34 Loss : 0.6646
35 Loss : 0.6220
36 Loss : 0.6251
37 Loss : 0.6028
38 Loss : 0.6084
39 Loss : 0.5687
40 Loss : 0.5522
41 Loss : 0.5259
42 Loss : 0.5245
43 Loss : 0.5137
44 Loss : 0.4872
45 Loss : 0.4931
46 Loss : 0.4537
47 Loss : 0.4440
48 Loss : 0.4277
49 Loss : 0.4479
50 Loss : 0.4093
51 Loss : 0.4158
52 Loss : 0.3942
53 Loss : 0.3851
54 Loss : 0.3645
55 Loss : 0.3498
56 Loss : 0.3550
57 Loss : 0.3467
58 Loss : 0.3091
59 Loss : 0.3463
60 Los

In [16]:
# Evaluate the supervised-only teacher on the test set.
teacher_model.eval()
with torch.no_grad():
    corrects = 0
    total = 0
    for inputs, targets in test_loader:
        inputs = inputs.to(args.device)
        targets = targets.to(args.device, dtype=torch.long)

        # Predict the class with the largest output (first one on ties).
        preds = teacher_model(inputs).argmax(dim=1)

        # Accumulate the number of correct predictions per batch.
        corrects += (preds == targets).sum()
        total += targets.size(0)

test_acc = corrects.double() / total
print('Testing Acc: {:.4f}'.format(test_acc))

Testing Acc: 0.6774


In [17]:
# Snapshot the supervised teacher's weights. state_dict() returns references to the
# model's live tensors, so deep-copy it; otherwise any later in-place update of this
# model would silently change the snapshot as well.
teacher_model_parameter = copy.deepcopy(teacher_model.state_dict())

# Semi-supervised learning using pseudo-labeling

In [18]:
# Fresh teacher/student pair for pseudo-label training. Each model gets its own
# configured dense-layer dropout; the student previously used teacher_dropout.
teacher_model = WideResNet(num_classes=args.num_classes,
                           depth=args.depth,
                           widen_factor=args.widen_factor,
                           dropout=0,
                           dense_dropout=args.teacher_dropout)
teacher_model.to(args.device)

student_model = WideResNet(num_classes=args.num_classes,
                           depth=args.depth,
                           widen_factor=args.widen_factor,
                           dropout=0,
                           dense_dropout=args.student_dropout)
student_model.to(args.device)

WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
        (relu1): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
        (relu2): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
        (relu1): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv1): Conv2d(32, 32, kernel_size=(

In [19]:
# Initialize the new teacher from the supervised snapshot; strict=True (the
# default) requires the keys to match exactly.
teacher_model.load_state_dict(teacher_model_parameter, strict=True)

<All keys matched successfully>

In [20]:
# One SGD optimizer per model. weight_decay is forwarded from the config so the
# setting takes effect (it is 0 by default, so default results are unchanged).
t_optimizer = optim.SGD(teacher_model.parameters(),
                        lr=args.teacher_lr,
                        momentum=args.momentum,
                        nesterov=args.nesterov,
                        weight_decay=args.weight_decay)
s_optimizer = optim.SGD(student_model.parameters(),
                        lr=args.student_lr,
                        momentum=args.momentum,
                        nesterov=args.nesterov,
                        weight_decay=args.weight_decay)
criterion = nn.CrossEntropyLoss()

In [21]:
def train_pseudo_labeling(args, teacher_model, student_model, t_optimizer, s_optimizer, criterion,
                          l_loader=None, u_loader=None):
    """Train a teacher (supervised) and a student (supervised + hard pseudo-labels).

    Each step:
    - The teacher is trained with ``criterion`` on a labeled batch.
    - The teacher's most probable class for each weakly augmented unlabeled image
      becomes a hard pseudo-label.
    - The student is trained on
      ``criterion(labeled) + args.lambda_u * criterion(strong view, pseudo-label)``.

    Averaged losses are printed every ``args.eval_step`` steps, and once more
    for the final window.

    Args:
        args: config with ``start_step``, ``total_steps``, ``eval_step``,
            ``lambda_u`` and ``device``.
        teacher_model, student_model: classifiers returning logits.
        t_optimizer, s_optimizer: optimizers over teacher / student parameters.
        criterion: loss called as ``criterion(logits, targets)``.
        l_loader: iterable of ``(images, targets)`` labeled batches; defaults to
            the notebook-level ``labeled_loader``.
        u_loader: iterable of ``((weak_images, strong_images), _)`` unlabeled
            batches; defaults to the notebook-level ``unlabeled_loader``.

    Both iterables are restarted whenever they are exhausted.

    Raises:
        ValueError: if a loader yields no batches at all.
    """
    # Only fall back to the notebook globals when no loader is passed explicitly.
    l_loader = labeled_loader if l_loader is None else l_loader
    u_loader = unlabeled_loader if u_loader is None else u_loader

    def _next_batch(it, loader):
        # Return (iterator, batch). The iterator is created lazily on first use and
        # restarted once exhausted. Only StopIteration is handled: data-loading
        # errors and KeyboardInterrupt propagate instead of being swallowed by a
        # bare `except:`. The builtin next() is used because the `.next()` method
        # is not available on DataLoader iterators in recent PyTorch releases.
        if it is not None:
            try:
                return it, next(it)
            except StopIteration:
                pass
        it = iter(loader)
        try:
            return it, next(it)
        except StopIteration:
            raise ValueError('loader yielded no batches') from None

    def _report(step):
        # Print the mean of each loss over the current window.
        print('{} Step - Teacher loss: {:.4f} Student loss: {:.4f}\nl_loss: {:.4f} u_loss: {:.4f}'.format(
            step, np.mean(t_losses), np.mean(s_losses), np.mean(l_losses), np.mean(u_losses)))

    since = time.time()
    labeled_iter = None
    unlabeled_iter = None
    # Initialized up front: previously a start_step that is not a multiple of
    # eval_step raised UnboundLocalError on the first append.
    s_losses, t_losses, l_losses, u_losses = [], [], [], []

    # Nothing in the loop switches modes, so set training mode once.
    teacher_model.train()
    student_model.train()

    for step in range(args.start_step, args.total_steps):
        if step % args.eval_step == 0:
            if step != 0 and s_losses:
                _report(step)
            s_losses, t_losses, l_losses, u_losses = [], [], [], []

        labeled_iter, (images_l, targets) = _next_batch(labeled_iter, l_loader)
        unlabeled_iter, ((images_uw, images_us), _) = _next_batch(unlabeled_iter, u_loader)

        images_l = images_l.to(args.device)
        images_uw = images_uw.to(args.device)
        images_us = images_us.to(args.device)
        targets = targets.to(args.device, dtype=torch.long)

        t_optimizer.zero_grad()
        s_optimizer.zero_grad()

        # Teacher forward on labeled + weak unlabeled images as one batch.
        batch_size = images_l.shape[0]
        t_logits = teacher_model(torch.cat((images_l, images_uw)))
        t_logits_l, t_logits_uw = t_logits[:batch_size], t_logits[batch_size:]
        del t_logits
        t_loss_l = criterion(t_logits_l, targets)

        # Hard pseudo-labels from the teacher, detached so the student loss sends
        # no gradient into the teacher.
        soft_pseudo_label = torch.softmax(t_logits_uw.detach(), dim=-1)
        _, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1)

        # Student forward on labeled + strong unlabeled images as one batch.
        s_logits = student_model(torch.cat((images_l, images_us)))
        s_logits_l, s_logits_us = s_logits[:batch_size], s_logits[batch_size:]
        del s_logits

        s_loss_l = criterion(s_logits_l, targets)
        s_loss_u = criterion(s_logits_us, hard_pseudo_label)
        s_loss = s_loss_l + args.lambda_u * s_loss_u

        # The teacher and student graphs share no parameters; update each model.
        t_loss_l.backward()
        t_optimizer.step()
        s_loss.backward()
        s_optimizer.step()

        s_losses.append(s_loss.item())
        t_losses.append(t_loss_l.item())
        l_losses.append(s_loss_l.item())
        u_losses.append(s_loss_u.item())

    # The in-loop check never reports the last window; do it here.
    if s_losses:
        _report(args.total_steps)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [None]:
train_pseudo_labeling(args, teacher_model, student_model,
                      t_optimizer=t_optimizer, s_optimizer=s_optimizer, criterion=criterion)

100 Step - Teacher loss: 1.9447 Student loss: 0.0625
l_loss: 1.9232 u_loss: 2.1544
200 Step - Teacher loss: 1.6967 Student loss: 0.0662
l_loss: 1.6770 u_loss: 1.9711
300 Step - Teacher loss: 1.5648 Student loss: 0.0860
l_loss: 1.5459 u_loss: 1.8891
400 Step - Teacher loss: 1.4846 Student loss: 0.0906
l_loss: 1.4665 u_loss: 1.8103
500 Step - Teacher loss: 1.3872 Student loss: 0.0643
l_loss: 1.3700 u_loss: 1.7265
600 Step - Teacher loss: 1.3109 Student loss: 0.0838
l_loss: 1.2940 u_loss: 1.6907
700 Step - Teacher loss: 1.2397 Student loss: 0.0744
l_loss: 1.2228 u_loss: 1.6849
800 Step - Teacher loss: 1.1722 Student loss: 0.0582
l_loss: 1.1560 u_loss: 1.6169
900 Step - Teacher loss: 1.1222 Student loss: 0.0656
l_loss: 1.1065 u_loss: 1.5739
1000 Step - Teacher loss: 1.1049 Student loss: 0.0793
l_loss: 1.0895 u_loss: 1.5463
1100 Step - Teacher loss: 1.0529 Student loss: 0.0524
l_loss: 1.0374 u_loss: 1.5508
1200 Step - Teacher loss: 0.9979 Student loss: 0.0613
l_loss: 0.9826 u_loss: 1.5296
1

In [None]:
def test(args, model, loader):
    model.eval()
    with torch.no_grad():
        corrects = 0
        total = 0
        for inputs, targets in loader:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device, dtype=torch.long)

            # forward
            outputs = model(inputs)

            # output 중 최대값의 위치에 해당하는 class로 예측 수행
            _, preds = torch.max(outputs, 1)

            # batch별 정답 개수를 축적함
            corrects += torch.sum(preds == targets.data)
            total += targets.size(0)

    test_acc = corrects.double() / total
    print('Testing Acc: {:.4f}'.format(test_acc))

In [None]:
test(args, model=student_model, loader=test_loader)