In [None]:
import time
import logging
import os
import threading

from datetime import datetime
from collections import defaultdict

class TrainLog:
    """Accumulate per-step training metrics and save them as a pandas msgpack file.

    Values are kept in memory as ``{step: {column: value}}``. Recording a value
    also writes the whole log to ``<directory>/<name>.msgpack`` whenever at
    least ``INCREMENTAL_UPDATE_TIME`` seconds have passed since the last
    automatic save.
    """

    # Minimum number of seconds between two automatic saves.
    INCREMENTAL_UPDATE_TIME = 300

    def __init__(self, directory, name):
        """Create an empty log that will be stored at ``<directory>/<name>.msgpack``."""
        self.log_file_path = "{}/{}.msgpack".format(directory, name)
        self._log = defaultdict(dict)
        # Re-entrant because _record() calls save() -> _as_dataframe() while
        # already holding the lock.
        self._log_lock = threading.RLock()
        # Back-dated so that the very first recorded value triggers a save.
        self._last_update_time = time.time() - self.INCREMENTAL_UPDATE_TIME

    def record_single(self, step, column, value):
        """Record one ``column`` value for ``step``."""
        self._record(step, {column: value})

    def record(self, step, col_val_dict):
        """Record several ``{column: value}`` pairs for ``step``."""
        self._record(step, col_val_dict)

    def save(self):
        """Write the current log to ``self.log_file_path`` (zlib-compressed msgpack)."""
        df = self._as_dataframe()
        # NOTE(review): DataFrame.to_msgpack was removed in pandas 1.0, so this
        # call needs an older pandas -- confirm the environment before relying on it.
        df.to_msgpack(self.log_file_path, compress='zlib')

    def _record(self, step, col_val_dict):
        """Merge ``col_val_dict`` into the row for ``step`` and save if the interval elapsed."""
        with self._log_lock:
            self._log[step].update(col_val_dict)
            now = time.time()
            if now - self._last_update_time >= self.INCREMENTAL_UPDATE_TIME:
                self._last_update_time = now
                self.save()

    def _as_dataframe(self):
        """Return the log as a DataFrame with one row per step and one column per metric."""
        # Imported here because no pandas name is imported at the top of the
        # notebook; the original referenced an undefined `DataFrame`.
        from pandas import DataFrame

        with self._log_lock:
            return DataFrame.from_dict(self._log, orient='index')

class RunContext:
    """Prepare the output directory tree of one run and create its TrainLogs.

    The result directory is ``results/<runner name>/<timestamp>/<run_idx>``,
    where the runner name is the base name of ``runner_file`` up to its first
    dot; a ``transient`` sub-directory is created inside it.
    """

    def __init__(self, runner_file, run_idx):
        logging.basicConfig(level=logging.INFO, format='%(message)s')
        script_name = os.path.basename(runner_file).split(".")[0]
        started_at = datetime.now()
        self.result_dir = "results/{}/{:%Y-%m-%d_%H:%M:%S}/{}".format(
            script_name, started_at, run_idx)
        self.transient_dir = self.result_dir + "/transient"
        # Both calls fail if the directory already exists (no exist_ok).
        for directory in (self.result_dir, self.transient_dir):
            os.makedirs(directory)

    def create_train_log(self, name):
        """Return a TrainLog stored as ``<result_dir>/<name>.msgpack``."""
        log = TrainLog(self.result_dir, name)
        return log


# Global variables and objects shared by the cells below.
# logging.basicConfig does nothing once the root logger has handlers, so the
# second basicConfig call inside RunContext.__init__ does not change the format.
logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger('main')
# Side effect: creates results/<cwd base name>/<timestamp>/0 and its transient/ sub-directory.
context = RunContext(os.getcwd(), 0)
checkpoint_path = context.transient_dir
# One TrainLog file per metric stream, all stored in context.result_dir.
training_log = context.create_train_log("training")
validation_log = context.create_train_log("validation")
ema_validation_log = context.create_train_log("ema_validation")


# Placeholder; replaced by the parsed arguments in the "Config: args" cell.
args = None
# presumably the best top-1 precision seen so far -- TODO confirm against the training loop
best_prec1 = 0
global_step = 0
# Label index written into dataset.imgs for samples without a label (see relabel_dataset).
NO_LABEL = -1



### Divide images into 11 groups and generate 11 txt files, one 'filename label' pair per line

In [None]:
import os
from math import ceil

# Input locations (machine-specific; adjust before running).
video_dir_path = '/media/yszhao/Data/PublicDataset/Kvasir-Capsule/labeled_videos'  # need to modify
nor_image_dir_path = '/media/yszhao/Data/PublicDataset/Kvasir-Capsule/labeled_images/Normal'
abnor_image_dir_path = '/media/yszhao/Data/PublicDataset/Kvasir-Capsule/labeled_images/Abnormal'

# A video is identified by its file name up to the first dot.
videos = os.listdir(video_dir_path)
videos = [video.split('.')[0] for video in videos]
num_videos = len(videos)


# Split the video ids into consecutive groups of n videos each.
n = 4
videos = [videos[i:i+n] for i in range(0, len(videos), n)]

num_groups = ceil(num_videos/n)  # =11 for the author's data, per the original note

image_groups = [[] for i in range(num_groups)]   # one list of 'filename label' lines per group


def append_labeled_images(image_dir_path, label):
    """Append '<filename> <label>' to the group of every image in image_dir_path.

    An image belongs to a group when the part of its file name before the
    first underscore is one of that group's video ids.
    """
    for image in os.listdir(image_dir_path):
        video_id = image.split('_')[0]
        for group_idx, group_videos in enumerate(videos):
            if video_id in group_videos:
                image_groups[group_idx].append(image + ' ' + label)


append_labeled_images(nor_image_dir_path, 'normal')
append_labeled_images(abnor_image_dir_path, 'abnormal')


# Write one '<group index>.txt' file per group; `with` closes the file even
# if a write fails.
for i, group in enumerate(image_groups):
    with open('{}.txt'.format(i), 'w') as textfile:
        for element in group:
            textfile.write(element + "\n")


### Link kv_capsule dataset

In [None]:
import random
import os 

# Source locations (machine-specific; adjust before running).
nor_image_dir_path = '/media/yszhao/Data/PublicDataset/Kvasir-Capsule/labeled_images/Normal'
abnor_image_dir_path = '/media/yszhao/Data/PublicDataset/Kvasir-Capsule/labeled_images/Abnormal'
unlabeled_image_dir_path = '/media/yszhao/Data/PublicDataset/Kvasir-Capsule/unlabeled_images'

# Upper bound on how many randomly chosen unlabeled images get linked.
max_unlabeled = 100000


def link_commands(source_dir, target_dir, images):
    """Return one 'ln -s <source_dir>/<image> <target_dir><image>' line per image."""
    return ['ln -s ' + source_dir + '/' + image + ' ' + target_dir + image
            for image in images]


text = []
text += link_commands(nor_image_dir_path,
                      './data-local/images/kv_capsule/train/normal/',
                      os.listdir(nor_image_dir_path))
text += link_commands(abnor_image_dir_path,
                      './data-local/images/kv_capsule/train/abnormal/',
                      os.listdir(abnor_image_dir_path))

unlabel_files = os.listdir(unlabeled_image_dir_path)
random.shuffle(unlabel_files)

# Slicing instead of indexing range(100000): no IndexError when the directory
# holds fewer files than the cap.
text += link_commands(unlabeled_image_dir_path,
                      './data-local/images/kv_capsule/train/unlabeled/',
                      unlabel_files[:max_unlabeled])

# The file is opened only after every listing succeeded and is closed by
# `with` even if a write fails.
with open('Link kv_capsule dataset.txt', 'w') as textfile:
    for element in text:
        textfile.write(element + "\n")

### Calculate mean and std of dataset

In [None]:
import torchvision
import torchvision.transforms as transforms

# ImageFolder reads one class per sub-directory of traindir; ToTensor turns
# each image into a tensor.
traindir = './data-local/images/cifar/cifar10/by-image/train'
dataset = torchvision.datasets.ImageFolder(traindir, transforms.ToTensor())
# Show the size and a short preview instead of dumping every (path, class index) pair.
len(dataset.imgs), dataset.imgs[:5]

In [None]:
import itertools

def iterate_eternally(indices):
    """Return an endless iterator over ``indices``, reshuffled on every pass.

    Each pass yields every element of ``indices`` exactly once, in a fresh
    random order drawn from numpy's global random state.
    """
    # Local import: numpy is not imported in this cell, which made the
    # original fail with NameError as soon as the iterator was advanced.
    import numpy as np

    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


# Print the first two shuffled passes over range(10). islice bounds the
# otherwise endless iterator. (The original looped over
# chain.from_iterable(range(10)), which raises TypeError because ints are not
# iterable.)
secondary_iter = iterate_eternally(range(10))
for i in itertools.islice(secondary_iter, 20):
    print(i)

In [None]:
from torch import nn
import torch  # this cell uses torch.randn, but only `nn` was imported here

# BatchNorm2d over 100 channels without learnable scale/shift (affine=False).
m = nn.BatchNorm2d(100, affine=False)
# Renamed from `input` so the builtin input() is not shadowed.
inputs = torch.randn(20, 100, 35, 45)
output = m(inputs)
output.shape

In [None]:
import torch

def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation over a whole loader.

    Sums are accumulated per element, so batches of different sizes (such as a
    smaller final batch) are weighted correctly; averaging per-batch means
    would give every batch the same weight.

    Args:
        loader: iterable of ``(data, target)`` pairs. ``data`` is reduced over
            dimensions 0, 2 and 3, leaving dimension 1 as the channel axis.

    Returns:
        ``(mean, std)``: two 1-D tensors with one entry per channel. ``std`` is
        the population standard deviation, ``sqrt(E[x^2] - E[x]^2)``.

    Raises:
        ValueError: if the loader yields no elements.
    """
    channels_sum, channels_squared_sum, num_elements = 0, 0, 0
    out_dtype = None

    for data, _ in loader:
        if out_dtype is None:
            out_dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
        # Accumulate in float64 to limit rounding error over many batches.
        values = data.double()
        channels_sum += values.sum(dim=[0, 2, 3])
        channels_squared_sum += (values ** 2).sum(dim=[0, 2, 3])
        num_elements += data.size(0) * data.size(2) * data.size(3)

    if num_elements == 0:
        raise ValueError("get_mean_std needs at least one non-empty batch")

    mean = channels_sum / num_elements
    # Rounding can push the variance slightly below zero; clamp before the root.
    variance = (channels_squared_sum / num_elements - mean ** 2).clamp(min=0)
    std = variance ** 0.5

    return mean.to(out_dtype), std.to(out_dtype)

# Feed `dataset` to get_mean_std in shuffled batches of 64 images.
train_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
)
get_mean_std(train_loader)

### Extract Validation Dataset

In [None]:
import os


label_file = './data-local/labels/kv_capsule/10.txt'
original_path = './data-local/images/kv_capsule/train'
target_path = './data-local/images/kv_capsule/validate'

# Second column of the label file -> class sub-directory.
class_dirs = {'0': 'normal', '1': 'abnormal'}

# os.rename fails when the destination directory is missing, so create both first.
for class_dir in class_dirs.values():
    os.makedirs(target_path + '/' + class_dir, exist_ok=True)

with open(label_file) as f:
    for line in f:
        # Strip the line ending before splitting, so a last line without a
        # trailing newline is matched too.
        fields = line.rstrip('\r\n').split(' ')
        if len(fields) != 2 or fields[1] not in class_dirs:
            # Blank or unrecognised lines are skipped instead of raising IndexError.
            continue
        filename, class_dir = fields[0], class_dirs[fields[1]]
        os.rename(original_path + '/' + class_dir + '/' + filename,
                  target_path + '/' + class_dir + '/' + filename)

### Config: args

In [None]:
import re
import argparse
import logging

# from . import architectures, datasets


# __all__ = ['parse_cmd_args', 'parse_dict_args']

def create_parser():
    """Build the argparse parser holding every training option and its default.

    Boolean options other than ``--pretrained`` take an explicit value parsed
    by ``str2bool``. ``--pretrained`` is a plain flag.

    Returns:
        argparse.ArgumentParser: a fresh parser; nothing is parsed here.
    """
    parser = argparse.ArgumentParser(description='PyTorch Cifar-10 Training')

    parser.add_argument('--arch', '-a', metavar='ARCH', default='cifar_shakeshake26',
#                         choices=architectures.__all__,
                        help='model architecture: ')

    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size (default: 256)')

    parser.add_argument('--checkpoint-epochs', default=1, type=int,
                        metavar='EPOCHS', help='checkpoint frequency in epochs, 0 to turn checkpointing off (default: 1)')

    # Help text used to say "(default: None)" although the default is 100.
    parser.add_argument('--consistency', default=100, type=float, metavar='WEIGHT',
                        help='use consistency loss with given weight (default: 100)')

    parser.add_argument('--consistency-type', default="mse", type=str, metavar='TYPE',
                        choices=['mse', 'kl'],
                        help='consistency loss type to use')

    parser.add_argument('--consistency-rampup', default=5, type=int, metavar='EPOCHS',
                        help='length of the consistency loss ramp-up')

    parser.add_argument('--dataset', metavar='DATASET', default='cifar10',
#                         choices=datasets.__all__,
                        help='dataset: ')

    parser.add_argument('--ema-decay', default=0.999, type=float, metavar='ALPHA',     # momentum of the EMA
                        help='ema variable decay rate (default: 0.999)')

    parser.add_argument('--epochs', default=180, type=int, metavar='N',
                        help='number of total epochs to run')

    # No default is given, so args.evaluate is None unless the option is passed.
    parser.add_argument('-e', '--evaluate', type=str2bool,
                        help='evaluate model on evaluation set')

    parser.add_argument('--eval-subdir', type=str, default='val',
                        help='the subdirectory inside the data directory that contains the evaluation data')

    parser.add_argument('--evaluation-epochs', default=1, type=int,
                        metavar='EPOCHS', help='evaluation frequency in epochs, 0 to turn evaluation off (default: 1)')

    parser.add_argument('--exclude-unlabeled', default=False, type=str2bool, metavar='BOOL',
                        help='exclude unlabeled examples from the training set')

    parser.add_argument('--initial-lr', default=0.0, type=float,
                        metavar='LR', help='initial learning rate when using linear rampup')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    # Help text used to claim a directory-structure default; the default is this file.
    parser.add_argument('--labels', default='data-local/labels/cifar10/1000_balanced_labels/00.txt',
                        type=str, metavar='FILE',
                        help='list of image labels (default: data-local/labels/cifar10/1000_balanced_labels/00.txt)')

    # Help text used to say "(default: no constrain)" although the default is 62.
    parser.add_argument('--labeled-batch-size', default=62, type=int,
                        metavar='N', help="labeled examples per minibatch (default: 62)")

    parser.add_argument('--logit-distance-cost', default=-1, type=float, metavar='WEIGHT',
                        help='let the student model have two outputs and use an MSE loss between the logits with the given weight (default: only have one output)')

    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='max learning rate')

    parser.add_argument('--lr-rampup', default=0, type=int, metavar='EPOCHS',
                        help='length of learning rate rampup in the beginning')

    parser.add_argument('--lr-rampdown-epochs', default=None, type=int, metavar='EPOCHS',
                        help='length of learning rate cosine rampdown (>= length of training)')

    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')

    # parse_dict_args turns underscores into dashes, so the dashed spelling must
    # exist; the underscore spelling is kept as an alias. The destination is
    # still args.num_classes.
    parser.add_argument('--num-classes', '--num_classes', default=10, type=int, metavar='N',
                        help='number of classes (default: 10)')

    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='use nesterov momentum', metavar='BOOL')

    parser.add_argument('--pretrained', default=False, dest='pretrained', action='store_true',
                        help='use pre-trained model')

    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')

    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')

    parser.add_argument('--train-subdir', type=str, default='train',
                        help='the subdirectory inside the data directory that contains the training data')

    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')

    return parser


def parse_commandline_args():
    """Return the parser defaults as a namespace.

    An empty argument list is parsed, so ``sys.argv`` is never consulted.
    """
    parser = create_parser()
    return parser.parse_args([])

def parse_dict_args(**kwargs):
    """Parse keyword arguments as if they had been given on the command line.

    One-letter keys become ``-k value``; longer keys become ``--key-name value``
    with underscores turned into dashes. Every value is passed through
    ``str()``. The resulting argument list is logged and then parsed.
    """
    def to_cmdline_kwarg(key, value):
        if len(key) == 1:
            flag = "-{}".format(key)
        else:
            flag = "--{}".format(key.replace("_", "-"))
        return flag, str(value)

    cmdline_args = []
    for key, value in kwargs.items():
        cmdline_args.extend(to_cmdline_kwarg(key, value))

    LOG.info("Using these command line args: %s", " ".join(cmdline_args))

    parser = create_parser()
    return parser.parse_args(cmdline_args)


def str2bool(v):
    """Convert a yes/no style string (case-insensitive) to a bool.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not one of the accepted spellings.
    """
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str2epochs(v):
    """Parse a comma-separated list of epoch numbers.

    Args:
        v: string such as ``"30,60,90"``; the empty string gives ``[]``.

    Returns:
        list[int]: the epochs in the order given.

    Raises:
        argparse.ArgumentTypeError: if an item is not an integer, or if the
            epochs are not strictly increasing.
    """
    try:
        if len(v) == 0:
            epochs = []
        else:
            epochs = [int(string) for string in v.split(",")]
    # Only conversion problems are translated; a bare `except:` would also
    # swallow KeyboardInterrupt and SystemExit.
    except (TypeError, ValueError, AttributeError) as err:
        raise argparse.ArgumentTypeError(
            'Expected comma-separated list of integers, got "{}"'.format(v)) from err
    # Each epoch must be positive and smaller than its successor; a single
    # epoch has no neighbouring pair and therefore always passes.
    if not all(0 < epoch1 < epoch2 for epoch1, epoch2 in zip(epochs[:-1], epochs[1:])):
        raise argparse.ArgumentTypeError(
            'Expected the epochs to be listed in increasing order')
    return epochs


# Replace the args placeholder with the parser defaults and display them as a dict.
args = parse_commandline_args()
vars(args)

### Create/Load model (why does it have 2 fc layers? Presumably for the two-output student described by `--logit-distance-cost`)

In [None]:
import sys

import math
import itertools

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function


def export(fn):
    """Decorator: add ``fn``'s name to ``__all__`` of its defining module.

    ``__all__`` is created when the module does not have one yet. ``fn`` is
    returned unchanged.
    """
    owner = sys.modules[fn.__module__]
    if not hasattr(owner, '__all__'):
        owner.__all__ = [fn.__name__]
    else:
        owner.__all__.append(fn.__name__)
    return fn

def parameter_count(module):
    """Return the total number of elements over all of ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += int(param.numel())
    return total

'''Load cifar_shakeshake26'''
@export
def cifar_shakeshake26(pretrained=False, **kwargs):
    """Build a ResNet32x32 made of ShakeShakeBlocks.

    The network has three stages of 4 blocks each, 96 base channels and
    'shift_conv' downsampling. Extra keyword arguments are forwarded to
    ResNet32x32. No pre-trained weights exist, so ``pretrained`` must be falsy.
    """
    assert not pretrained
    return ResNet32x32(
        ShakeShakeBlock,
        layers=[4, 4, 4],
        channels=96,
        downsample='shift_conv',
        **kwargs
    )



class ResNet32x32(nn.Module):
    """Three-stage residual network with two parallel linear output heads.

    A 3x3 convolution (3 -> 16 channels) feeds three stages built by
    ``_make_layer``; the second and third stage use stride 2. After average
    pooling, the flattened features go through ``fc1`` and ``fc2``, and
    ``forward`` returns both outputs.
    """

    def __init__(self, block, layers, channels, groups=1, num_classes=1000, downsample='basic'):
        """Build the network.

        Args:
            block: block class; must provide ``out_channels(planes, groups)``
                and be constructible as
                ``block(inplanes, planes, groups, stride, downsample)``.
            layers: three block counts, one per stage.
            channels: ``planes`` of the first stage; doubled for each later stage.
            groups: forwarded to ``block``.
            num_classes: output size of both linear heads.
            downsample: 'basic' or 'shift_conv'; see ``_make_layer``.
        """
        super().__init__()
        assert len(layers) == 3
        self.downsample_mode = downsample
        # Channel count entering the next block; updated by _make_layer.
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.layer1 = self._make_layer(block, channels, groups, layers[0])
        self.layer2 = self._make_layer(
            block, channels * 2, groups, layers[1], stride=2)
        self.layer3 = self._make_layer(
            block, channels * 4, groups, layers[2], stride=2)
        # assumes an 8x8 feature map here, i.e. 32x32 inputs halved by the two
        # stride-2 stages -- TODO confirm for other input sizes
        self.avgpool = nn.AvgPool2d(8)
        # Two heads with identical shapes on the same pooled features.
        self.fc1 = nn.Linear(block.out_channels(
            channels * 4, groups), num_classes)
        self.fc2 = nn.Linear(block.out_channels(
            channels * 4, groups), num_classes)

        # Re-initialise every convolution with N(0, sqrt(2 / n)), where
        # n = kernel height * kernel width * out_channels, and reset every
        # BatchNorm2d to weight 1 / bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, groups, blocks, stride=1):
        """Return a Sequential of ``blocks`` blocks; only the first gets ``stride``.

        The first block receives a shortcut projection when the stride is not 1
        or the channel count changes: a 1x1 convolution plus BatchNorm in
        'basic' mode or whenever the stride is 1, otherwise a
        ShiftConvDownsample in 'shift_conv' mode. Any other mode fails the
        assertion.
        """
        downsample = None
        if stride != 1 or self.inplanes != block.out_channels(planes, groups):
            if self.downsample_mode == 'basic' or stride == 1:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, block.out_channels(planes, groups),
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(block.out_channels(planes, groups)),
                )
            elif self.downsample_mode == 'shift_conv':
                downsample = ShiftConvDownsample(in_channels=self.inplanes,
                                                 out_channels=block.out_channels(planes, groups))
            else:
                assert False

        layers = []
        layers.append(block(self.inplanes, planes, groups, stride, downsample))
        # Later blocks in this stage (and the next stage) see the new channel count.
        self.inplanes = block.out_channels(planes, groups)
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and return ``(fc1 output, fc2 output)``."""
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        # Flatten everything but the batch dimension.
        x = x.view(x.size(0), -1)
        return self.fc1(x), self.fc2(x)


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class ShakeShakeBlock(nn.Module):
    """Residual block with two parallel conv branches mixed by ``shake``.

    Each branch applies ReLU -> 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv ->
    BatchNorm. The two branch outputs are combined by ``shake`` and added to
    the (optionally projected) input.
    """

    @classmethod
    def out_channels(cls, planes, groups):
        """Return the number of output channels (``planes``); only groups == 1 is supported."""
        assert groups == 1
        return planes

    def __init__(self, inplanes, planes, groups, stride=1, downsample=None):
        """Create both branches.

        Args:
            inplanes: input channel count.
            planes: output channel count of every convolution.
            groups: must be 1.
            stride: stride of the first convolution of each branch.
            downsample: optional module applied to the input to form the residual.
        """
        super().__init__()
        assert groups == 1
        # Branch a.
        self.conv_a1 = conv3x3(inplanes, planes, stride)
        self.bn_a1 = nn.BatchNorm2d(planes)
        self.conv_a2 = conv3x3(planes, planes)
        self.bn_a2 = nn.BatchNorm2d(planes)

        # Branch b: same structure, separate weights.
        self.conv_b1 = conv3x3(inplanes, planes, stride)
        self.bn_b1 = nn.BatchNorm2d(planes)
        self.conv_b2 = conv3x3(planes, planes)
        self.bn_b2 = nn.BatchNorm2d(planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``residual + shake(branch_a(x), branch_b(x))``."""
        a, b, residual = x, x, x

        # The first ReLU of each branch is not in-place, so x itself stays
        # intact for the other branch and for the residual path.
        a = F.relu(a, inplace=False)
        a = self.conv_a1(a)
        a = self.bn_a1(a)
        a = F.relu(a, inplace=True)
        a = self.conv_a2(a)
        a = self.bn_a2(a)

        b = F.relu(b, inplace=False)
        b = self.conv_b1(b)
        b = self.bn_b1(b)
        b = F.relu(b, inplace=True)
        b = self.conv_b2(b)
        b = self.bn_b2(b)

        # Random mix while training, fixed 50/50 mix otherwise (see Shake.forward).
        ab = shake(a, b, training=self.training)

        if self.downsample is not None:
            residual = self.downsample(x)

        return residual + ab

class Shake(Function):
    """Autograd function mixing two tensors with a per-sample gate.

    NOTE(review): ``forward`` and ``backward`` are declared as classmethods;
    the PyTorch documentation describes them as static methods -- confirm this
    still works on the PyTorch version in use.
    """

    @classmethod
    def forward(cls, ctx, inp1, inp2, training):
        """Return ``inp1 * gate + inp2 * (1 - gate)``.

        ``gate`` has one entry per element of the first dimension and size 1
        along every other dimension. It is drawn uniformly from [0, 1) when
        ``training`` is truthy and is 0.5 otherwise.
        """
        assert inp1.size() == inp2.size()
        # Shape (inp1.size(0), 1, ..., 1): one gate value per sample, broadcast over the rest.
        gate_size = [inp1.size()[0], *itertools.repeat(1, inp1.dim() - 1)]
        gate = inp1.new(*gate_size)
        if training:
            gate.uniform_(0, 1)
        else:
            gate.fill_(0.5)
        return inp1 * gate + inp2 * (1. - gate)

    @classmethod
    def backward(cls, ctx, grad_output):
        """Split ``grad_output`` between the inputs using a freshly drawn gate.

        The gate used here is sampled independently of the one used in
        ``forward``, whatever the ``training`` value was. No gradient is
        returned for ``training``.
        """
        grad_inp1 = grad_inp2 = grad_training = None
        gate_size = [grad_output.size()[0], *itertools.repeat(1,
                                                              grad_output.dim() - 1)]
        gate = grad_output.data.new(*gate_size).uniform_(0, 1).clone()
        if ctx.needs_input_grad[0]:
            grad_inp1 = grad_output * gate
        if ctx.needs_input_grad[1]:
            grad_inp2 = grad_output * (1 - gate)
        assert not ctx.needs_input_grad[2]
        return grad_inp1, grad_inp2, grad_training


def shake(inp1, inp2, training=False):
    """Mix ``inp1`` and ``inp2`` through the ``Shake`` autograd function."""
    mixed = Shake.apply(inp1, inp2, training)
    return mixed


class ShiftConvDownsample(nn.Module):
    """Halve the spatial size using two shifted sub-samplings of the input.

    The input is sampled at every second position, once starting at (0, 0) and
    once at (1, 1). The two halves are concatenated along the channel axis and
    passed through ReLU, a grouped (groups=2) 1x1 convolution and BatchNorm.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        # The convolution sees both sub-sampled copies, hence 2 * in_channels.
        self.conv = nn.Conv2d(2 * in_channels, out_channels, 1, groups=2)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        even_grid = x[:, :, 0::2, 0::2]
        odd_grid = x[:, :, 1::2, 1::2]
        stacked = torch.cat((even_grid, odd_grid), dim=1)
        return self.bn(self.conv(self.relu(stacked)))

    
# '''maybe this is for imagenet,leave it'''

# @export
# class BottleneckBlock(nn.Module):
#     @classmethod
#     def out_channels(cls, planes, groups):
#         if groups > 1:
#             return 2 * planes
#         else:
#             return 4 * planes

#     def __init__(self, inplanes, planes, groups, stride=1, downsample=None):
#         super().__init__()
#         self.relu = nn.ReLU(inplace=True)

#         self.conv_a1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
#         self.bn_a1 = nn.BatchNorm2d(planes)
#         self.conv_a2 = nn.Conv2d(
#             planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups)
#         self.bn_a2 = nn.BatchNorm2d(planes)
#         self.conv_a3 = nn.Conv2d(planes, self.out_channels(
#             planes, groups), kernel_size=1, bias=False)
#         self.bn_a3 = nn.BatchNorm2d(self.out_channels(planes, groups))

#         self.downsample = downsample
#         self.stride = stride

#     def forward(self, x):
#         a, residual = x, x

#         a = self.conv_a1(a)
#         a = self.bn_a1(a)
#         a = self.relu(a)
#         a = self.conv_a2(a)
#         a = self.bn_a2(a)
#         a = self.relu(a)
#         a = self.conv_a3(a)
#         a = self.bn_a3(a)

#         if self.downsample is not None:
#             residual = self.downsample(residual)

#         return self.relu(residual + a)
    
# def resnext152(pretrained=False, **kwargs):
#     assert not pretrained
#     model = ResNet224x224(BottleneckBlock,
#                           layers=[3, 8, 36, 3],
#                           channels=32 * 4,
#                           groups=32,
#                           downsample='basic', **kwargs)
#     return model

# class ResNet224x224(nn.Module):
#     def __init__(self, block, layers, channels, groups=1, num_classes=1000, downsample='basic'):
#         super().__init__()
#         assert len(layers) == 4
#         self.downsample_mode = downsample
#         self.inplanes = 64
#         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
#                                bias=False)
#         self.bn1 = nn.BatchNorm2d(self.inplanes)
#         self.relu = nn.ReLU(inplace=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         self.layer1 = self._make_layer(block, channels, groups, layers[0])
#         self.layer2 = self._make_layer(
#             block, channels * 2, groups, layers[1], stride=2)
#         self.layer3 = self._make_layer(
#             block, channels * 4, groups, layers[2], stride=2)
#         self.layer4 = self._make_layer(
#             block, channels * 8, groups, layers[3], stride=2)
#         self.avgpool = nn.AvgPool2d(7)
#         self.fc1 = nn.Linear(block.out_channels(
#             channels * 8, groups), num_classes)
#         self.fc2 = nn.Linear(block.out_channels(
#             channels * 8, groups), num_classes)

#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#                 m.weight.data.normal_(0, math.sqrt(2. / n))
#             elif isinstance(m, nn.BatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()

#     def _make_layer(self, block, planes, groups, blocks, stride=1):
#         downsample = None
#         if stride != 1 or self.inplanes != block.out_channels(planes, groups):
#             if self.downsample_mode == 'basic' or stride == 1:
#                 downsample = nn.Sequential(
#                     nn.Conv2d(self.inplanes, block.out_channels(planes, groups),
#                               kernel_size=1, stride=stride, bias=False),
#                     nn.BatchNorm2d(block.out_channels(planes, groups)),
#                 )
#             elif self.downsample_mode == 'shift_conv':
#                 downsample = ShiftConvDownsample(in_channels=self.inplanes,
#                                                  out_channels=block.out_channels(planes, groups))
#             else:
#                 assert False

#         layers = []
#         layers.append(block(self.inplanes, planes, groups, stride, downsample))
#         self.inplanes = block.out_channels(planes, groups)
#         for i in range(1, blocks):
#             layers.append(block(self.inplanes, planes, groups))

#         return nn.Sequential(*layers)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.relu(x)
#         x = self.maxpool(x)
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         x = self.layer4(x)
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         return self.fc1(x), self.fc2(x)


'''Parse the model parameters.'''
def parameters_string(module):
    """Format the named parameters of ``module`` as a plain-text table.

    The table has one row per parameter (name, shape written as 'a * b * ...',
    element count) followed by a row with the total element count. It starts
    and ends with an empty line.
    """
    row_format = "{name:<40} {shape:>20} ={total_size:>12,d}"
    named = list(module.named_parameters())

    rows = []
    grand_total = 0
    for name, param in named:
        dims = [str(extent) for extent in param.size()]
        rows.append(row_format.format(
            name=name, shape=" * ".join(dims), total_size=param.numel()))
        grand_total += int(param.numel())

    header = [
        "",
        "List of model parameters:",
        "=========================",
    ]
    footer = [
        "=" * 75,
        row_format.format(name="all parameters", shape="sum of above",
                          total_size=grand_total),
        "",
    ]
    return "\n".join(header + rows + footer)


def create_model(ema=False):
    """Create a cifar_shakeshake26 model on the GPU.

    Args:
        ema: when true, every parameter is detached in place so that it never
            requires gradients (the caller is expected to update it by EMA).

    Returns:
        The model, already moved to CUDA.

    Note:
        ``args.arch`` is only used in the log message; the architecture is
        always ``cifar_shakeshake26``.
    """
    # 'pre-trained ' needs its trailing space; without it the message read
    # "pre-trainedmodel" / "pre-trainedEMA model".
    LOG.info("=> creating {pretrained}{ema}model '{arch}'".format(
            pretrained='pre-trained ' if args.pretrained else '',
            ema='EMA ' if ema else '',
            arch=args.arch))

    model = cifar_shakeshake26(pretrained=args.pretrained, num_classes=args.num_classes).cuda()
    if ema:
        for param in model.parameters():
            # detach_() is the in-place version of detach(): the parameter
            # stops requiring gradients. --zys
            param.detach_()

    return model

# Trainable model first, then the EMA copy whose parameters are detached
# (see create_model); each call builds and initialises its own network.
model = create_model()
ema_model = create_model(ema=True)
# Log the per-parameter table of the trainable model.
LOG.info(parameters_string(model))

In [1]:
from torch import nn
import torchvision

def parameters_string(module):
    """Render a text table of ``module``'s named parameters.

    Each parameter gets a row with its name, its shape ('a * b * ...') and its
    element count; a final row gives the overall element count. The text is
    wrapped in one empty line at each end.
    """
    row_format = "{name:<40} {shape:>20} ={total_size:>12,d}"
    params = list(module.named_parameters())

    def describe(name, param):
        shape_text = " * ".join(map(str, param.size()))
        return row_format.format(name=name, shape=shape_text,
                                 total_size=param.numel())

    body = [describe(name, param) for name, param in params]
    overall = sum(int(param.numel()) for _, param in params)
    summary = row_format.format(name="all parameters", shape="sum of above",
                                total_size=overall)
    return "\n".join([
        "",
        "List of model parameters:",
        "=========================",
        *body,
        "=" * 75,
        summary,
        "",
    ])


# NOTE: this rebinds the global `model` that the previous cell set to the
# Shake-Shake network.
model = torchvision.models.mobilenet_v3_small()
# model = torchvision.models.resnet18()

# Keep every top-level child except the last one
# (presumably the classifier head -- TODO confirm).
pretrain = nn.Sequential(*list(model.children())[:-1])
pretrain


Sequential(
  (0): Sequential(
    (0): ConvBNActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): ConvBNActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (relu): ReLU(inplace=True)
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ConvBNActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_s

In [None]:
import torch
from torch import nn

# Shape check for a stack of transposed convolutions. Per the ConvTranspose2d
# output formula (in - 1) * stride - 2 * padding + kernel_size, the spatial
# size grows 1 -> 4 -> 8 -> 16 -> 32 while the channels go
# 576 -> 1024 -> 512 -> 256 -> 3.
m = nn.Sequential(
    nn.ConvTranspose2d(
        in_channels=576, out_channels=1024, kernel_size=4, stride=1, padding=0
    ),
    nn.ConvTranspose2d(
        in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1
    ),
    nn.ConvTranspose2d(
        in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1
    ),
    nn.ConvTranspose2d(
        in_channels=256, out_channels=3, kernel_size=4, stride=2, padding=1
    )
)

# A batch of 256 feature vectors with 576 channels and 1x1 spatial size.
x = torch.randn(256, 576, 1, 1)
y = m(x)
print(y.shape)

In [None]:
import torch
# CUDA smoke test: allocates a random 30x30 tensor on the 'cuda' device and
# prints its maximum.
print(torch.max(torch.rand((30,30),device='cuda')))

### Load Dataset (cifar10)

In [None]:
import sys
import os

import torch
import torchvision
from torchvision import transforms
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from torch.utils.data.sampler import Sampler


class RandomTranslateWithReflect:
    """Randomly shift an image, filling the exposed border by reflection.

    The horizontal and vertical offsets are drawn independently and uniformly
    from the integers in [-max_translation, max_translation]. The uncovered
    area is filled with mirrored copies of the image.

    NOTE(review): relies on ``np`` (numpy) and ``Image`` (PIL) being in scope;
    neither is imported in this cell's import block -- confirm they are
    imported elsewhere in the notebook.
    """

    def __init__(self, max_translation):
        self.max_translation = max_translation

    def __call__(self, old_image):
        shift_x, shift_y = np.random.randint(-self.max_translation,
                                             self.max_translation + 1,
                                             size=2)
        pad_x, pad_y = abs(shift_x), abs(shift_y)
        width, height = old_image.size

        mirrored_lr = old_image.transpose(Image.FLIP_LEFT_RIGHT)
        mirrored_tb = old_image.transpose(Image.FLIP_TOP_BOTTOM)
        mirrored_both = old_image.transpose(Image.ROTATE_180)

        canvas = Image.new("RGB", (width + 2 * pad_x, height + 2 * pad_y))

        # Positions of the mirrored tiles around the centred original; the
        # offset of one pixel makes each tile overlap the original's edge
        # row/column instead of repeating it.
        left, right = pad_x - width + 1, pad_x + width - 1
        above, below = pad_y - height + 1, pad_y + height - 1

        # Later pastes overwrite earlier ones, so the order is kept as-is.
        placements = [
            (old_image, (pad_x, pad_y)),
            (mirrored_lr, (right, pad_y)),
            (mirrored_lr, (left, pad_y)),
            (mirrored_tb, (pad_x, below)),
            (mirrored_tb, (pad_x, above)),
            (mirrored_both, (left, above)),
            (mirrored_both, (right, above)),
            (mirrored_both, (left, below)),
            (mirrored_both, (right, below)),
        ]
        for tile, corner in placements:
            canvas.paste(tile, corner)

        # Cut a window of the original size, displaced by the drawn offsets.
        crop_box = (pad_x - shift_x,
                    pad_y - shift_y,
                    pad_x + width - shift_x,
                    pad_y + height - shift_y)
        return canvas.crop(crop_box)
    
class TransformTwice:
    """Apply one transform twice to the same input and return both results."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        # Two separate calls, first result first, so a random transform
        # produces two different views.
        first, second = (self.transform(inp) for _ in range(2))
        return first, second

def export(fn):
    """Decorator: add ``fn``'s name to ``__all__`` of its defining module.

    ``__all__`` is created when the module does not have one yet. ``fn`` is
    returned unchanged.
    """
    owner = sys.modules[fn.__module__]
    if not hasattr(owner, '__all__'):
        owner.__all__ = [fn.__name__]
    else:
        owner.__all__.append(fn.__name__)
    return fn
@export
def cifar10():
    """Return the transforms, data directory and class count for CIFAR-10.

    The training transform is wrapped in TransformTwice, so it yields two
    independently augmented views of each image; the evaluation transform only
    converts to a tensor and normalises.
    """
    # presumably the per-channel mean/std of the training images -- TODO
    # confirm (get_mean_std in this notebook computes such statistics)
    channel_stats = dict(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2470,  0.2435,  0.2616])

    augmented = transforms.Compose([
        RandomTranslateWithReflect(4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**channel_stats)
    ])
    train_transformation = TransformTwice(augmented)

    eval_transformation = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**channel_stats)
    ])

    config = {}
    config['train_transformation'] = train_transformation  # returns 2 results per image  --zys
    config['eval_transformation'] = eval_transformation
    config['datadir'] = 'data-local/images/cifar/cifar10/by-image'
    config['num_classes'] = 10
    return config

def assert_exactly_one(lst):
    """Assert that exactly one element of ``lst`` is truthy.

    On failure the assertion message lists every element, comma-separated.
    """
    truthy = [el for el in lst if el]
    assert len(truthy) == 1, ", ".join(map(str, lst))
    
def relabel_dataset(dataset, labels):
    """Overwrite the targets in ``dataset.imgs`` from a filename -> class-name map.

    Every image whose basename appears in ``labels`` gets the class index
    looked up in ``dataset.class_to_idx``; every other image gets
    ``NO_LABEL``. ``dataset.imgs`` is modified in place.

    Args:
        dataset: object exposing ``imgs`` (list of ``(path, target)`` pairs)
            and ``class_to_idx`` (class name -> index).
        labels: mapping from image basename to class name, e.g.
            ``{'45313_airplane.png': 'airplane'}``. It is not modified.

    Returns:
        ``(labeled_idxs, unlabeled_idxs)``: two ascending lists of positions
        in ``dataset.imgs``.

    Raises:
        LookupError: if ``labels`` names files that are not in the dataset.
    """
    # Consume a private copy so the caller's mapping survives the call.
    # Each entry is still used at most once, so a second image with the same
    # basename is treated as unlabeled, as before.
    remaining = dict(labels)
    labeled_idxs = []
    unlabeled_idxs = []

    for idx, (path, _) in enumerate(dataset.imgs):
        filename = os.path.basename(path)
        if filename in remaining:
            label_idx = dataset.class_to_idx[remaining.pop(filename)]
            dataset.imgs[idx] = path, label_idx
            labeled_idxs.append(idx)
        else:
            dataset.imgs[idx] = path, NO_LABEL
            unlabeled_idxs.append(idx)

    if remaining:
        message = "List of unlabeled contains {} unknown files: {}, ..."
        some_missing = ', '.join(list(remaining.keys())[:5])
        raise LookupError(message.format(len(remaining), some_missing))

    return labeled_idxs, unlabeled_idxs



def iterate_once(iterable):
    """Return one random permutation of ``iterable`` (a single shuffled pass)."""
    shuffled = np.random.permutation(iterable)
    return shuffled

def iterate_eternally(indices):
    """Yield elements of ``indices`` forever, reshuffling before every pass.

    Each pass is a fresh ``np.random.permutation`` of ``indices``; the
    permutations are produced lazily, one at a time, as the stream is
    consumed.
    """
    endless_shuffles = (np.random.permutation(indices) for _ in itertools.count())
    return itertools.chain.from_iterable(endless_shuffles)

def grouper(iterable, n):
    """Split ``iterable`` into consecutive tuples of exactly ``n`` items.

    A trailing group with fewer than ``n`` items is dropped:
    grouper('ABCDEFG', 3) --> ABC DEF
    """
    # The same iterator object repeated n times: zip pulls n successive
    # items from it for every tuple it builds.
    shared = iter(iterable)
    chunk_sources = (shared,) * n
    return zip(*chunk_sources)

class TwoStreamBatchSampler(Sampler):
    """Batch sampler that draws every batch from two index sets.

    An 'epoch' is one shuffled pass through the primary indices. During
    that pass the secondary indices are cycled through (reshuffled on every
    cycle) as many times as needed. Each batch is the tuple of
    ``primary_batch_size`` primary indices followed by
    ``secondary_batch_size`` secondary indices.

    In this notebook ``create_data_loaders`` passes the unlabeled indices as
    primary and the labeled indices as secondary, so
    ``secondary_batch_size`` is the number of labeled samples per batch.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        # Whatever is left of the total batch goes to the primary stream.
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_batches = grouper(iterate_once(self.primary_indices),
                                  self.primary_batch_size)
        secondary_batches = grouper(iterate_eternally(self.secondary_indices),
                                    self.secondary_batch_size)
        # zip stops when the (finite) primary stream runs out.
        return (
            primary_batch + secondary_batch
            for primary_batch, secondary_batch in zip(primary_batches, secondary_batches)
        )

    def __len__(self):
        """Number of batches per epoch (full primary batches only)."""
        return len(self.primary_indices) // self.primary_batch_size


def create_data_loaders(train_transformation,
                        eval_transformation,
                        datadir,
                        args):
    """Build the training and evaluation ``DataLoader`` objects.

    Args:
        train_transformation: transform applied to training images.
        eval_transformation: transform applied to evaluation images.
        datadir: dataset root containing ``args.train_subdir`` and
            ``args.eval_subdir``.
        args: run configuration; reads ``train_subdir``, ``eval_subdir``,
            ``exclude_unlabeled``, ``labeled_batch_size``, ``labels``,
            ``batch_size`` and ``workers``.

    Returns:
        ``(train_loader, eval_loader)``.

    Raises:
        AssertionError: unless exactly one of ``args.exclude_unlabeled`` and
            ``args.labeled_batch_size`` is set.
        ValueError: if ``args.labels`` is empty. Both sampling modes need the
            labeled/unlabeled split, which previously surfaced as a
            ``NameError`` further down.
    """
    traindir = os.path.join(datadir, args.train_subdir)
    evaldir = os.path.join(datadir, args.eval_subdir)

    # Either train on labeled data only, or mix labeled and unlabeled
    # samples in every batch -- never both, never neither.
    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

    if not args.labels:
        raise ValueError("args.labels must point to a labels file "
                         "('<filename> <class name>' per line)")

    # One '<filename> <class name>' pair per line; blank lines are skipped.
    with open(args.labels) as f:
        labels = dict(line.split() for line in f.read().splitlines() if line.strip())
    labeled_idxs, unlabeled_idxs = relabel_dataset(dataset, labels)

    if args.exclude_unlabeled:
        # Supervised-only: sample batches from the labeled subset.
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        # Semi-supervised: each batch holds args.labeled_batch_size labeled
        # samples; the rest of the batch is unlabeled.
        batch_sampler = TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    else:
        # Only reachable when assertions are disabled (python -O).
        raise AssertionError("labeled batch size {}".format(args.labeled_batch_size))

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    eval_loader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(evaldir, eval_transformation),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,  # Needs images twice as fast
        pin_memory=True,
        drop_last=False)

    return train_loader, eval_loader

# cifar10() returns a dict with the train/eval transformations, the data
# directory and the number of classes.
dataset_config = cifar10()
# Remove 'num_classes' so that the remaining keys match the keyword
# parameters of create_data_loaders.
num_classes = dataset_config.pop('num_classes')

train_loader, eval_loader = create_data_loaders(**dataset_config, args=args)  # build the CIFAR-10 train/eval loaders
train_loader


### Validation Function

In [None]:
def validate(eval_loader, model, log, global_step, epoch):
    """Evaluate ``model`` on ``eval_loader`` and record the metrics.

    Args:
        eval_loader: iterable of ``(input, target)`` batches.
        model: network to evaluate; it may return a single logits tensor or
            a sequence whose first element is the class logits.
        log: object with a ``record(step, dict)`` method that receives the
            collected meter values, averages and sums.
        global_step: training step stored under the ``'step'`` key.
        epoch: key under which the metrics are recorded in ``log``.

    Returns:
        The running average of the top-1 precision (``meters['top1'].avg``).
    """
    class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL).cuda()
    meters = AverageMeterSet()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # Inference only: without no_grad every forward pass would build an
    # autograd graph that is never used.
    with torch.no_grad():
        for images, target in eval_loader:
            meters.update('data_time', time.time() - end)

            input_var = images.to('cuda')
            target_var = target.to('cuda')

            minibatch_size = len(target_var)
            labeled_minibatch_size = target_var.ne(NO_LABEL).sum()
            assert labeled_minibatch_size > 0
            meters.update('labeled_minibatch_size', labeled_minibatch_size)

            # compute output; like train(), accept single- and multi-head models
            model_out = model(input_var)
            output1 = model_out if isinstance(model_out, torch.Tensor) else model_out[0]
            class_loss = class_criterion(output1, target_var) / minibatch_size

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output1, target_var, topk=(1, 5))
            meters.update('class_loss', class_loss, labeled_minibatch_size)
            meters.update('top1', prec1, labeled_minibatch_size)
            meters.update('error1', 100.0 - prec1, labeled_minibatch_size)
            meters.update('top5', prec5, labeled_minibatch_size)
            meters.update('error5', 100.0 - prec5, labeled_minibatch_size)

            # measure elapsed time
            meters.update('batch_time', time.time() - end)
            end = time.time()

    # float() accepts plain numbers as well as one-element tensors, which a
    # '.3f' format spec does not unless the tensor is 0-dimensional.
    LOG.info(' * Prec@1 {top1:.3f}\tPrec@5 {top5:.3f}'
             .format(top1=float(meters['top1'].avg), top5=float(meters['top5'].avg)))
    log.record(epoch, {
        'step': global_step,
        **meters.values(),
        **meters.averages(),
        **meters.sums()
    })

    return meters['top1'].avg

### Check resume

In [None]:

# Optionally restore training state from the checkpoint file at args.resume.
# NOTE(review): this cell uses `model`, `ema_model`, `optimizer`,
# `eval_loader`, `validation_log` and `ema_validation_log`; the optimizer and
# model cells appear further down in the notebook, so this cell only works
# after those cells have been run -- confirm the intended cell order.
if args.resume:
    assert os.path.isfile(args.resume), "=> no checkpoint found at '{}'".format(args.resume)
    LOG.info("=> loading checkpoint '{}'".format(args.resume))
    # load from args.resume
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.resume)
    # Continue from the stored epoch / step counters and best score.
    args.start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_step']
    best_prec1 = checkpoint['best_prec1']
    # Restore the weights of both networks and the optimizer state.
    model.load_state_dict(checkpoint['state_dict'])
    ema_model.load_state_dict(checkpoint['ema_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    LOG.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # validate the preloaded model (only reached when resuming)
    if args.evaluate:
        LOG.info("Evaluating the primary model:")
        validate(eval_loader, model, validation_log, global_step, args.start_epoch)
        LOG.info("Evaluating the EMA model:")
        validate(eval_loader, ema_model, ema_validation_log, global_step, args.start_epoch)

### Create Optimizer (SGD)

In [None]:
# SGD over the primary model's parameters; the hyper-parameters come from
# the run configuration in `args`.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
    nesterov=args.nesterov,
)
optimizer

### Loss Function

In [None]:
import torch
from torch.nn import functional as F

def softmax_mse_loss(input_logits, target_logits):
    """Squared error between the softmax distributions of two logit tensors.

    Note:
    - The squared differences are summed over all examples and divided by
      the number of classes (dimension 1). Divide by the batch size
      afterwards if you want the mean.
    - ``target_logits`` is not detached here; pass an already detached
      tensor if no gradient should reach the target.
    """
    assert input_logits.size() == target_logits.size()
    input_probs = F.softmax(input_logits, dim=1)
    target_probs = F.softmax(target_logits, dim=1)
    class_count = input_logits.size()[1]
    summed_squared_error = F.mse_loss(input_probs, target_probs, reduction='sum')
    return summed_squared_error / class_count

def softmax_kl_loss(input_logits, target_logits):
    """KL divergence between the softmax distributions of two logit tensors.

    Note:
    - The result is summed over all examples. Divide by the batch size
      afterwards if you want the mean.
    - ``target_logits`` is not detached here; pass an already detached
      tensor if no gradient should reach the target.
    """
    assert input_logits.size() == target_logits.size()
    # F.kl_div takes log-probabilities for the input and probabilities for
    # the target.
    return F.kl_div(
        F.log_softmax(input_logits, dim=1),
        F.softmax(target_logits, dim=1),
        reduction='sum',
    )

def symmetric_mse_loss(input1, input2):
    """Squared error between two tensors, differentiable in both arguments.

    Note:
    - The squared differences are summed over all examples and divided by
      the size of dimension 1. Divide by the batch size afterwards if you
      want the mean.
    - Neither argument is detached, so gradients reach both ``input1`` and
      ``input2``.
    """
    assert input1.size() == input2.size()
    class_count = input1.size()[1]
    difference = input1 - input2
    return (difference ** 2).sum() / class_count

### Ramp-up Definition

In [None]:
import numpy as np


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242

    Returns ``exp(-5 * (1 - t)**2)`` with ``t = current / rampup_length``
    after clipping ``current`` to ``[0, rampup_length]``. A zero
    ``rampup_length`` disables the ramp and always yields 1.0.
    """
    if rampup_length == 0:
        return 1.0

    clipped = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup: ``current / rampup_length``, capped at 1.0.

    Both arguments must be non-negative. Once ``current`` reaches
    ``rampup_length`` (which includes a zero-length ramp) the result is 1.0.
    """
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983

    Goes from 1.0 at ``current == 0`` down to 0.0 at
    ``current == rampdown_length`` along half a cosine period.
    ``current`` must lie in ``[0, rampdown_length]``.
    """
    assert 0 <= current <= rampdown_length
    cosine = np.cos(np.pi * current / rampdown_length)
    return float(.5 * (cosine + 1))

### utils

In [None]:

class AverageMeterSet:
    """A collection of ``AverageMeter`` objects addressed by name.

    Meters are created lazily on the first ``update`` for a name. The
    ``values``/``averages``/``sums``/``counts`` helpers return flat dicts
    whose keys are the meter names with a postfix appended.
    """

    def __init__(self):
        self.meters = {}

    def __getitem__(self, key):
        return self.meters[key]

    def update(self, name, value, n=1):
        """Feed ``value`` with weight ``n`` into the meter called ``name``."""
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        """Reset every meter; the set of names is kept."""
        for meter in self.meters.values():
            meter.reset()

    def _collect(self, attribute, postfix):
        # Shared implementation of the four dict-producing accessors below.
        return {name + postfix: getattr(meter, attribute)
                for name, meter in self.meters.items()}

    def values(self, postfix=''):
        return self._collect('val', postfix)

    def averages(self, postfix='/avg'):
        return self._collect('avg', postfix)

    def sums(self, postfix='/sum'):
        return self._collect('sum', postfix)

    def counts(self, postfix='/count'):
        return self._collect('count', postfix)


class AverageMeter:
    """Tracks the most recent value together with its running weighted mean.

    Attributes: ``val`` (last value), ``sum`` (weighted sum), ``count``
    (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        # Rendered as "<last value> (<running mean>)", both formatted with
        # the given format spec.
        template = "{self.val:{format}} ({self.avg:{format}})"
        return template.format(self=self, format=format)

    
def get_current_consistency_weight(epoch):
    """Weight of the consistency loss at ``epoch``.

    Consistency ramp-up from https://arxiv.org/abs/1610.02242:
    ``args.consistency`` scaled by a sigmoid-shaped ramp over
    ``args.consistency_rampup`` epochs.
    """
    ramp = sigmoid_rampup(epoch, args.consistency_rampup)
    return args.consistency * ramp

### Results

In [None]:
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: ``(batch, num_classes)`` tensor of class scores.
        target: ``(batch,)`` tensor of class indices; entries equal to
            ``NO_LABEL`` never match a prediction and are excluded from the
            denominator.
        topk: the values of k to evaluate.

    Returns:
        A list with one 0-dimensional tensor per k: the percentage of
        labeled samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    # Guard against division by zero for a batch without labeled samples.
    labeled_minibatch_size = max(target.ne(NO_LABEL).sum(), 1e-8)

    # Indices of the maxk best classes per sample, best first.
    _, pred = output.topk(maxk, 1, True, True)
    # Tensor.t() swaps the two dimensions: (batch, maxk) -> (maxk, batch).
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Count hits over the whole (k, batch) slice. Summing along
        # dimension 0 only would leave one value per sample instead of a
        # single count for the batch.
        correct_k = correct[:k].float().sum()
        res.append(correct_k * (100.0 / labeled_minibatch_size))
    return res

### Train function

In [None]:
import numpy as np
from PIL import Image



def adjust_learning_rate(optimizer, epoch, step_in_epoch, total_steps_in_epoch):
    """Set the learning rate of every parameter group for the current step.

    The schedule is evaluated at the fractional epoch
    ``epoch + step_in_epoch / total_steps_in_epoch``:
    a linear warm-up from ``args.initial_lr`` to ``args.lr`` over
    ``args.lr_rampup`` epochs, optionally multiplied by a cosine ramp-down
    over ``args.lr_rampdown_epochs`` epochs.

    Args:
        optimizer: object whose ``param_groups`` dicts receive the new 'lr'.
        epoch: index of the current epoch.
        step_in_epoch: index of the current batch within the epoch.
        total_steps_in_epoch: number of batches per epoch.
    """
    epoch = epoch + step_in_epoch / total_steps_in_epoch

    # LR warm-up to handle large minibatch sizes from https://arxiv.org/abs/1706.02677
    lr = linear_rampup(epoch, args.lr_rampup) * (args.lr - args.initial_lr) + args.initial_lr

    # Cosine LR rampdown from https://arxiv.org/abs/1608.03983 (but one cycle only)
    if args.lr_rampdown_epochs:
        assert args.lr_rampdown_epochs >= args.epochs
        # Use the notebook's own cosine_rampdown, like linear_rampup above,
        # rather than going through a separate `ramps` module.
        lr *= cosine_rampdown(epoch, args.lr_rampdown_epochs)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr



def update_ema_variables(model, ema_model, alpha, global_step):
    """Move the parameters of ``ema_model`` towards those of ``model``.

    Each averaged parameter becomes ``a * ema + (1 - a) * current``, updated
    in place, with ``a = min(1 - 1 / (global_step + 1), alpha)``.
    """
    # Use the true average until the exponential average is more correct
    decay = min(1 - 1 / (global_step + 1), alpha)
    parameter_pairs = zip(ema_model.parameters(), model.parameters())
    for averaged, current in parameter_pairs:
        averaged.data.mul_(decay).add_(current.data, alpha=1 - decay)


def train(train_loader, model, ema_model, optimizer, epoch):
    """Train ``model`` (student) for one epoch and update ``ema_model`` (teacher).

    For every batch the loss is the sum of
    - the classification loss of the student on the labeled targets,
    - the consistency loss between student and teacher outputs, weighted by
      the current consistency ramp-up (only when ``args.consistency`` is set),
    - the residual loss between the student's two outputs (only when
      ``args.logit_distance_cost >= 0``).
    After each optimizer step the module-level ``global_step`` is incremented
    and the teacher weights are moved towards the student weights.

    Args:
        train_loader: iterable of ``((input, ema_input), target)`` batches.
        model: student network, updated by ``optimizer``.
        ema_model: teacher network, updated by ``update_ema_variables``.
        optimizer: optimizer over the student parameters.
        epoch: index of the current epoch.
    """
    global global_step

    # Classification loss between the student output and the labeled targets;
    # unlabeled samples (target == NO_LABEL) are ignored.
    class_criterion = nn.CrossEntropyLoss(ignore_index=NO_LABEL, reduction='sum').cuda()

    # Consistency loss between the student and teacher predictions.
    if args.consistency_type == 'mse':    # mean squared error
        consistency_criterion = softmax_mse_loss
    elif args.consistency_type == 'kl':    # KL divergence
        consistency_criterion = softmax_kl_loss
    else:
        raise AssertionError(args.consistency_type)

    # Loss between the two outputs of a two-headed student.
    residual_logit_criterion = symmetric_mse_loss

    # Collects the latest value and running average of every tracked metric.
    meters = AverageMeterSet()

    # switch to train mode
    model.train()
    ema_model.train()

    end = time.time()
    for i, ((input, ema_input), target) in enumerate(train_loader):
        # measure data loading time
        meters.update('data_time', time.time() - end)

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
        meters.update('lr', optimizer.param_groups[0]['lr'])

        input_var = input.to('cuda')
        ema_input_var = ema_input.to('cuda')
        target_var = target.to('cuda')

        minibatch_size = len(target_var)
        # ne(NO_LABEL) is True for labeled samples, so the sum is the number
        # of labeled samples in this batch.
        labeled_minibatch_size = target_var.ne(NO_LABEL).sum()
        assert labeled_minibatch_size > 0
        meters.update('labeled_minibatch_size', labeled_minibatch_size)

        # The teacher output is only used as a constant target, so no
        # autograd graph is needed for its forward pass.
        with torch.no_grad():
            ema_model_out = ema_model(ema_input_var)
        model_out = model(input_var)

        if isinstance(model_out, torch.Tensor):
            # Single-output model: only valid when the residual loss is disabled.
            assert args.logit_distance_cost < 0
            logit1 = model_out
            ema_logit = ema_model_out
        else:
            assert len(model_out) == 2
            assert len(ema_model_out) == 2
            logit1, logit2 = model_out
            ema_logit, _ = ema_model_out

        # Make sure no gradient can flow into the teacher.
        ema_logit = ema_logit.detach()

        if args.logit_distance_cost >= 0:
            class_logit, cons_logit = logit1, logit2
            res_loss = args.logit_distance_cost * residual_logit_criterion(class_logit, cons_logit) / minibatch_size
            # res_loss is a 0-dim tensor, which cannot be indexed with [0].
            meters.update('res_loss', res_loss.data)
        else:
            # One logit tensor serves both the classification and the consistency loss.
            class_logit, cons_logit = logit1, logit1
            res_loss = 0

        class_loss = class_criterion(class_logit, target_var) / minibatch_size
        meters.update('class_loss', class_loss.data)

        ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
        meters.update('ema_class_loss', ema_class_loss.data)

        if args.consistency:
            # Weight of the consistency loss in the total loss.
            consistency_weight = get_current_consistency_weight(epoch)
            meters.update('cons_weight', consistency_weight)
            consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size
            meters.update('cons_loss', consistency_loss.data)
        else:
            consistency_loss = 0
            meters.update('cons_loss', 0)

        # res_loss is 0 when args.logit_distance_cost < 0.
        loss = class_loss + consistency_loss + res_loss
        meters.update('loss', loss.data)

        # top-1 / top-5 accuracy of the student
        prec1, prec5 = accuracy(class_logit.data, target_var.data, topk=(1, 5))
        meters.update('top1', prec1, labeled_minibatch_size)
        meters.update('error1', 100. - prec1, labeled_minibatch_size)
        meters.update('top5', prec5, labeled_minibatch_size)
        meters.update('error5', 100. - prec5, labeled_minibatch_size)

        # top-1 / top-5 accuracy of the teacher
        ema_prec1, ema_prec5 = accuracy(ema_logit.data, target_var.data, topk=(1, 5))
        meters.update('ema_top1', ema_prec1, labeled_minibatch_size)
        meters.update('ema_error1', 100. - ema_prec1, labeled_minibatch_size)
        meters.update('ema_top5', ema_prec5, labeled_minibatch_size)
        meters.update('ema_error5', 100. - ema_prec5, labeled_minibatch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        # teacher = a * teacher + (1 - a) * student, with a derived from args.ema_decay
        update_ema_variables(model, ema_model, args.ema_decay, global_step)

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

#         if i % args.print_freq == 0:
#             LOG.info(
#                 'Epoch: [{0}][{1}/{2}]\t'
#                 'Time {meters[batch_time]:.3f}\t'
#                 'Data {meters[data_time]:.3f}\t'
#                 'Class {meters[class_loss]:.4f}\t'
#                 'Cons {meters[cons_loss]:.4f}\t'
#                 'Prec@1 {meters[top1]:.3f}\t'
#                 'Prec@5 {meters[top5]:.3f}'.format(
#                     epoch, i, len(train_loader), meters=meters))
#             log.record(epoch + i / len(train_loader), {
#                 'step': global_step,
#                 **meters.values(),
#                 **meters.averages(),
#                 **meters.sums()
#             })
            
  


### Save checkpoint

In [None]:

def save_checkpoint(state, is_best, dirpath, epoch):
    """Save ``state`` as ``checkpoint.<epoch>.ckpt`` inside ``dirpath``.

    When ``is_best`` is true the saved file is additionally copied to
    ``best.ckpt`` in the same directory.
    """
    checkpoint_path = os.path.join(dirpath, 'checkpoint.{}.ckpt'.format(epoch))
    torch.save(state, checkpoint_path)
    LOG.info("--- checkpoint saved to %s ---" % checkpoint_path)

    if not is_best:
        return

    best_path = os.path.join(dirpath, 'best.ckpt')
    shutil.copyfile(checkpoint_path, best_path)
    LOG.info("--- checkpoint copied to %s ---" % best_path)

### Train & Update

In [None]:
for epoch in range(args.start_epoch, args.epochs):
    # Train for one epoch.
    start_time = time.time()
    train(train_loader, model, ema_model, optimizer, epoch)
    LOG.info("--- training epoch in %s seconds ---" % (time.time() - start_time))

    # Evaluate every args.evaluation_epochs epochs. A checkpoint can only be
    # the best one in an epoch that was evaluated; "best" is judged on the
    # EMA model.
    is_best = False
    if args.evaluation_epochs and (epoch + 1) % args.evaluation_epochs == 0:
        start_time = time.time()
        LOG.info("Evaluating the primary model:")
        prec1 = validate(eval_loader, model, validation_log, global_step, epoch + 1)
        LOG.info("Evaluating the EMA model:")
        ema_prec1 = validate(eval_loader, ema_model, ema_validation_log, global_step, epoch + 1)
        LOG.info("--- validation in %s seconds ---" % (time.time() - start_time))
        is_best = ema_prec1 > best_prec1
        best_prec1 = max(ema_prec1, best_prec1)

    # Save a checkpoint every args.checkpoint_epochs epochs.
    if args.checkpoint_epochs and (epoch + 1) % args.checkpoint_epochs == 0:
        training_state = {
            'epoch': epoch + 1,
            'global_step': global_step,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'ema_state_dict': ema_model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }
        save_checkpoint(training_state, is_best, checkpoint_path, epoch + 1)

In [None]:
import torch
from  mean_teacher import architectures

# Load the checkpoint stored in 'SimCLR.pth'; its 'model' entry is used as a
# state dict below.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
SimCLR = torch.load('SimCLR.pth')
# NOTE(review): a bare expression that is not the last line of a cell is not
# displayed, so this line has no visible effect.
SimCLR['model']

# Build the 10-class model on the GPU.
model_params = dict(num_classes=10)
model = architectures.SSL_model(**model_params).cuda()

# strict=False: keys that are missing from or unexpected in the checkpoint
# are ignored instead of raising an error.
model.load_state_dict(SimCLR['model'], strict=False)
model.parameters()