In [1]:
!pip install -q timm pytorch-metric-learning
import numpy as np
import argparse
import time
import os
import random
import shutil
import time
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.optim import lr_scheduler
import torch.optim as optim
from pytorch_metric_learning import losses
from torch.cuda import amp
from tqdm.notebook import tqdm
from torch.autograd import Variable
from collections import defaultdict, OrderedDict
import copy
import cv2
import random
import math
import re
from PIL import Image, ImageOps, ImageEnhance, ImageChops
import PIL
import numpy as np


_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

_FILL = (128, 128, 128)

_LEVEL_DENOM = 10.  # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments

_HPARAMS_DEFAULT = dict(
    translate_const=250,
    img_mean=_FILL,
)

_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
    interpolation = kwargs.pop('resample', Image.BILINEAR)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    else:
        return interpolation


def _check_args_tf(kwargs):
    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
        kwargs.pop('fillcolor')
    kwargs['resample'] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs['resample'])


def auto_contrast(img, **__):
    return ImageOps.autocontrast(img)


def invert(img, **__):
    return ImageOps.invert(img)


def equalize(img, **__):
    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    lut = []
    for i in range(256):
        if i < thresh:
            lut.append(min(255, i + add))
        else:
            lut.append(i)
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB" and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
    else:
        return img


def posterize(img, bits_to_keep, **__):
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value"""
    return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level, _hparams):
    # range [-30, 30]
    level = (level / _LEVEL_DENOM) * 30.
    level = _randomly_negate(level)
    return level,


def _enhance_level_to_arg(level, _hparams):
    # range [0.1, 1.9]
    return (level / _LEVEL_DENOM) * 1.8 + 0.1,


def _enhance_increasing_level_to_arg(level, _hparams):
    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
    # range [0.1, 1.9] if level <= _LEVEL_DENOM
    level = (level / _LEVEL_DENOM) * .9
    level = max(0.1, 1.0 + _randomly_negate(level))  # keep it >= 0.1
    return level,


def _shear_level_to_arg(level, _hparams):
    # range [-0.3, 0.3]
    level = (level / _LEVEL_DENOM) * 0.3
    level = _randomly_negate(level)
    return level,


def _translate_abs_level_to_arg(level, hparams):
    translate_const = hparams['translate_const']
    level = (level / _LEVEL_DENOM) * float(translate_const)
    level = _randomly_negate(level)
    return level,


def _translate_rel_level_to_arg(level, hparams):
    # default range [-0.45, 0.45]
    translate_pct = hparams.get('translate_pct', 0.45)
    level = (level / _LEVEL_DENOM) * translate_pct
    level = _randomly_negate(level)
    return level,


def _posterize_level_to_arg(level, _hparams):
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    # intensity/severity of augmentation decreases with level
    return int((level / _LEVEL_DENOM) * 4),


def _posterize_increasing_level_to_arg(level, hparams):
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image',
    # intensity/severity of augmentation increases with level
    return 4 - _posterize_level_to_arg(level, hparams)[0],


def _posterize_original_level_to_arg(level, _hparams):
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    # intensity/severity of augmentation decreases with level
    return int((level / _LEVEL_DENOM) * 4) + 4,


def _solarize_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation decreases with level
    return int((level / _LEVEL_DENOM) * 256),


def _solarize_increasing_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation increases with level
    return 256 - _solarize_level_to_arg(level, _hparams)[0],


def _solarize_add_level_to_arg(level, _hparams):
    # range [0, 110]
    return int((level / _LEVEL_DENOM) * 110),


LEVEL_TO_ARG = {
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,
    'Rotate': _rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    'Posterize': _posterize_level_to_arg,
    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,
    'Solarize': _solarize_level_to_arg,
    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,
    'Color': _enhance_level_to_arg,
    'ColorIncreasing': _enhance_increasing_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'ContrastIncreasing': _enhance_increasing_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'SharpnessIncreasing': _enhance_increasing_level_to_arg,
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,
}


NAME_TO_OP = {
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'PosterizeOriginal': posterize,
    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,
}


class AugmentOp:

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = dict(
            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
        )

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        # If magnitude_std is inf, we sample magnitude from a uniform distribution
        self.magnitude_std = self.hparams.get('magnitude_std', 0)
        self.magnitude_max = self.hparams.get('magnitude_max', None)

    def __call__(self, img):
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std > 0:
            # magnitude randomization enabled
            if self.magnitude_std == float('inf'):
                magnitude = random.uniform(0, magnitude)
            elif self.magnitude_std > 0:
                magnitude = random.gauss(magnitude, self.magnitude_std)
        # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
        # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
        upper_bound = self.magnitude_max or _LEVEL_DENOM
        magnitude = max(0., min(magnitude, upper_bound))
        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
        return self.aug_fn(img, *level_args, **self.kwargs)

    def __repr__(self):
        fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
        fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
        if self.magnitude_max is not None:
            fs += f', mmax={self.magnitude_max}'
        fs += ')'
        return fs


def auto_augment_policy_v0(hparams):
    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_v0r(hparams):
    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
    # in Google research implementation (number of bits discarded increases with magnitude)
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_original(hparams):
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    policy = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_originalr(hparams):
    # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
    policy = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy(name='v0', hparams=None):
    hparams = hparams or _HPARAMS_DEFAULT
    if name == 'original':
        return auto_augment_policy_original(hparams)
    elif name == 'originalr':
        return auto_augment_policy_originalr(hparams)
    elif name == 'v0':
        return auto_augment_policy_v0(hparams)
    elif name == 'v0r':
        return auto_augment_policy_v0r(hparams)
    else:
        assert False, 'Unknown AA policy (%s)' % name


class AutoAugment:

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        sub_policy = random.choice(self.policy)
        for op in sub_policy:
            img = op(img)
        return img

    def __repr__(self):
        fs = self.__class__.__name__ + f'(policy='
        for p in self.policy:
            fs += '\n\t['
            fs += ', '.join([str(op) for op in p])
            fs += ']'
        fs += ')'
        return fs


def auto_augment_transform(config_str, hparams):
    """
    Create a AutoAugment transform
    :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
    The remaining sections, not order sepecific determine
        'mstd' -  float std deviation of magnitude noise applied
    Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
    :return: A PyTorch compatible Transform
    """
    config = config_str.split('-')
    policy_name = config[0]
    config = config[1:]
    for c in config:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        else:
            assert False, 'Unknown AutoAugment config section'
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)


_RAND_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'Posterize',
    'Solarize',
    'SolarizeAdd',
    'Color',
    'Contrast',
    'Brightness',
    'Sharpness',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    #'Cutout'  # NOTE I've implement this as random erasing separately
]


_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'SolarizeAdd',
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    #'Cutout'  # NOTE I've implement this as random erasing separately
]



# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
    'Rotate': 0.3,
    'ShearX': 0.2,
    'ShearY': 0.2,
    'TranslateXRel': 0.1,
    'TranslateYRel': 0.1,
    'Color': .025,
    'Sharpness': 0.025,
    'AutoContrast': 0.025,
    'Solarize': .005,
    'SolarizeAdd': .005,
    'Contrast': .005,
    'Brightness': .005,
    'Equalize': .005,
    'Posterize': 0,
    'Invert': 0,
}


def _select_rand_weights(weight_idx=0, transforms=None):
    transforms = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    rand_weights = _RAND_CHOICE_WEIGHTS_0
    probs = [rand_weights[k] for k in transforms]
    probs /= np.sum(probs)
    return probs


def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [AugmentOp(
        name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]


class RandAugment:
    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
        for op in ops:
            img = op(img)
        return img

    def __repr__(self):
        fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
        for op in self.ops:
            fs += f'\n\t{op}'
        fs += ')'
        return fs


def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform
    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order sepecific determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
        'mstd' -  float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
        'mmax' - set upper bound for magnitude to something other than default of  _LEVEL_DENOM (10)
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
    :return: A PyTorch compatible Transform
    """
    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split('-')
    assert config[0] == 'rand'
    config = config[1:]
    for c in config:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param / randomization of magnitude values
            mstd = float(val)
            if mstd > 100:
                # use uniform sampling in 0 to magnitude if mstd is > 100
                mstd = float('inf')
            hparams.setdefault('magnitude_std', mstd)
        elif key == 'mmax':
            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
            hparams.setdefault('magnitude_max', int(val))
        elif key == 'inc':
            if bool(val):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            assert False, 'Unknown RandAugment config section'
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)


_AUGMIX_TRANSFORMS = [
    'AutoContrast',
    'ColorIncreasing',  # not in paper
    'ContrastIncreasing',  # not in paper
    'BrightnessIncreasing',  # not in paper
    'SharpnessIncreasing',  # not in paper
    'Equalize',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
]


def augmix_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _AUGMIX_TRANSFORMS
    return [AugmentOp(
        name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]


class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781
    """
    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    def _calc_blended_weights(self, ws, m):
        ws = ws * m
        cump = 1.
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        # This is my first crack and implementing a slightly faster mixed augmentation. Instead
        # of accumulating the mix for each chain in a Numpy array and then blending with original,
        # it recomputes the blending coefficients and applies one PIL image blend per chain.
        # TODO the results appear in the right ballpark but they differ by more than rounding.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img_orig  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        # This is a literal adaptation of the paper/official implementation without normalizations and
        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
        # typical augmentation transforms, could use a GPU / Kornia implementation.
        img_shape = img.size[0], img.size[1], len(img.getbands())
        mixed = np.zeros(img_shape, dtype=np.float32)
        for mw in mixing_weights:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed

    def __repr__(self):
        fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
        for op in self.ops:
            fs += f'\n\t{op}'
        fs += ')'
        return fs


def augment_and_mix_transform(config_str, hparams):
    """ Create AugMix PyTorch transform
    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order sepecific determine
        'm' - integer magnitude (severity) of augmentation mix (default: 3)
        'w' - integer width of augmentation chain (default: 3)
        'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
        'mstd' -  float std deviation of magnitude noise applied (default: 0)
    Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
    :param hparams: Other hparams (kwargs) for the Augmentation transforms
    :return: A PyTorch compatible Transform
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    config = config_str.split('-')
    assert config[0] == 'augmix'
    config = config[1:]
    for c in config:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            blended = bool(val)
        else:
            assert False, 'Unknown AugMix config section'
    hparams.setdefault('magnitude_std', float('inf'))  # default to uniform sampling (if not set via mstd arg)
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)



In [2]:
import random

import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
import numpy as np
import torch
from torchvision.transforms.transforms import Compose

random_mirror = True


def ShearX(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def TranslateXAbs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert 0 <= v <= 10
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateYAbs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert 0 <= v <= 10
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def Rotate(img, v):  # [-30, 30]
    assert -30 <= v <= 30
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.rotate(v)


def AutoContrast(img, _):
    return PIL.ImageOps.autocontrast(img)


def Invert(img, _):
    return PIL.ImageOps.invert(img)


def Equalize(img, _):
    return PIL.ImageOps.equalize(img)


def Flip(img, _):  # not from the paper
    return PIL.ImageOps.mirror(img)


def Solarize(img, v):  # [0, 256]
    assert 0 <= v <= 256
    return PIL.ImageOps.solarize(img, v)


def Posterize(img, v):  # [4, 8]
    assert 4 <= v <= 8
    v = int(v)
    return PIL.ImageOps.posterize(img, v)


def Posterize2(img, v):  # [0, 4]
    assert 0 <= v <= 4
    v = int(v)
    return PIL.ImageOps.posterize(img, v)


def Contrast(img, v):  # [0.1,1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Color(img, v):  # [0.1,1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Color(img).enhance(v)


def Brightness(img, v):  # [0.1,1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Sharpness(img, v):  # [0.1,1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Sharpness(img).enhance(v)


def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
    assert 0.0 <= v <= 0.2
    if v <= 0.:
        return img

    v = v * img.size[0]
    return CutoutAbs(img, v)


def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
    # assert 0 <= v <= 20
    if v < 0:
        return img
    w, h = img.size
    x0 = np.random.uniform(w)
    y0 = np.random.uniform(h)

    x0 = int(max(0, x0 - v / 2.))
    y0 = int(max(0, y0 - v / 2.))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    xy = (x0, y0, x1, y1)
    color = (125, 123, 114)
    # color = (0, 0, 0)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img


def SamplePairing(imgs):  # [0, 0.4]
    def f(img1, v):
        i = np.random.choice(len(imgs))
        img2 = PIL.Image.fromarray(imgs[i])
        return PIL.Image.blend(img1, img2, v)

    return f


def augment_list(for_autoaug=True):  # 16 oeprations and their ranges
    l = [
        (ShearX, -0.3, 0.3),  # 0
        (ShearY, -0.3, 0.3),  # 1
        (TranslateX, -0.45, 0.45),  # 2
        (TranslateY, -0.45, 0.45),  # 3
        (Rotate, -30, 30),  # 4
        (AutoContrast, 0, 1),  # 5
        (Invert, 0, 1),  # 6
        (Equalize, 0, 1),  # 7
        (Solarize, 0, 256),  # 8
        (Posterize, 4, 8),  # 9
        (Contrast, 0.1, 1.9),  # 10
        (Color, 0.1, 1.9),  # 11
        (Brightness, 0.1, 1.9),  # 12
        (Sharpness, 0.1, 1.9),  # 13
        (Cutout, 0, 0.2),  # 14
        # (SamplePairing(imgs), 0, 0.4),  # 15
    ]
    if for_autoaug:
        l += [
            (CutoutAbs, 0, 20),  # compatible with auto-augment
            (Posterize2, 0, 4),  # 9
            (TranslateXAbs, 0, 10),  # 9
            (TranslateYAbs, 0, 10),  # 9
        ]
    return l


augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}


def get_augment(name):
    return augment_dict[name]


def apply_augment(img, name, level):
    augment_fn, low, high = get_augment(name)
    return augment_fn(img.copy(), level * (high - low) + low)


class Lighting(object):
    """Lighting noise(AlexNet - style PCA - based noise)"""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = torch.Tensor(eigval)
        self.eigvec = torch.Tensor(eigvec)

    def __call__(self, img):
        if self.alphastd == 0:
            return img

        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        rgb = self.eigvec.type_as(img).clone() \
            .mul(alpha.view(1, 3).expand(3, 3)) \
            .mul(self.eigval.view(1, 3).expand(3, 3)) \
            .sum(1).squeeze()

        return img.add(rgb.view(3, 1, 1).expand_as(img))

In [3]:
def fa_resnet50_rimagenet():
    p = [[["ShearY", 0.14143816458479197, 0.513124791615952], ["Sharpness", 0.9290316227291179, 0.9788406212603302]],
         [["Color", 0.21502874228385338, 0.3698477943880306], ["TranslateY", 0.49865058747734736, 0.4352676987103321]],
         [["Brightness", 0.6603452126485386, 0.6990174510500261], ["Cutout", 0.7742953773992511, 0.8362550883640804]],
         [["Posterize", 0.5188375788270497, 0.9863648925446865],
          ["TranslateY", 0.8365230108655313, 0.6000972236440252]],
         [["ShearY", 0.9714994964711299, 0.2563663552809896], ["Equalize", 0.8987567223581153, 0.1181761775609772]],
         [["Sharpness", 0.14346409304565366, 0.5342189791746006],
          ["Sharpness", 0.1219714162835897, 0.44746801278319975]],
         [["TranslateX", 0.08089260772173967, 0.028011721602479833],
          ["TranslateX", 0.34767877352421406, 0.45131294688688794]],
         [["Brightness", 0.9191164585327378, 0.5143232242627864], ["Color", 0.9235247849934283, 0.30604586249462173]],
         [["Contrast", 0.4584173187505879, 0.40314219914942756], ["Rotate", 0.550289356406774, 0.38419022293237126]],
         [["Posterize", 0.37046156420799325, 0.052693291117634544], ["Cutout", 0.7597581409366909, 0.7535799791937421]],
         [["Color", 0.42583964114658746, 0.6776641859552079], ["ShearY", 0.2864805671096011, 0.07580175477739545]],
         [["Brightness", 0.5065952125552232, 0.5508640233704984],
          ["Brightness", 0.4760021616081475, 0.3544313318097987]],
         [["Posterize", 0.5169630851995185, 0.9466018906715961], ["Posterize", 0.5390336503396841, 0.1171015788193209]],
         [["Posterize", 0.41153170909576176, 0.7213063942615204], ["Rotate", 0.6232230424824348, 0.7291984098675746]],
         [["Color", 0.06704687234714028, 0.5278429246040438], ["Sharpness", 0.9146652195810183, 0.4581415618941407]],
         [["ShearX", 0.22404644446773492, 0.6508620171913467],
          ["Brightness", 0.06421961538672451, 0.06859528721039095]],
         [["Rotate", 0.29864103693134797, 0.5244313199644495], ["Sharpness", 0.4006161706584276, 0.5203708477368657]],
         [["AutoContrast", 0.5748186910788027, 0.8185482599354216],
          ["Posterize", 0.9571441684265188, 0.1921474117448481]],
         [["ShearY", 0.5214786760436251, 0.8375629059785009], ["Invert", 0.6872393349333636, 0.9307694335024579]],
         [["Contrast", 0.47219838080793364, 0.8228524484275648],
          ["TranslateY", 0.7435518856840543, 0.5888865560614439]],
         [["Posterize", 0.10773482839638836, 0.6597021018893648], ["Contrast", 0.5218466423129691, 0.562985661685268]],
         [["Rotate", 0.4401753067886466, 0.055198255925702475], ["Rotate", 0.3702153509335602, 0.5821574425474759]],
         [["TranslateY", 0.6714729117832363, 0.7145542887432927],
          ["Equalize", 0.0023263758097700205, 0.25837341854887885]],
         [["Cutout", 0.3159707561240235, 0.19539664199170742], ["TranslateY", 0.8702824829864558, 0.5832348977243467]],
         [["AutoContrast", 0.24800812729140026, 0.08017301277245716],
          ["Brightness", 0.5775505849482201, 0.4905904775616114]],
         [["Color", 0.4143517886294533, 0.8445937742921498], ["ShearY", 0.28688910858536587, 0.17539366839474402]],
         [["Brightness", 0.6341134194059947, 0.43683815933640435],
          ["Brightness", 0.3362277685899835, 0.4612826163288225]],
         [["Sharpness", 0.4504035748829761, 0.6698294470467474],
          ["Posterize", 0.9610055612671645, 0.21070714173174876]],
         [["Posterize", 0.19490421920029832, 0.7235798208354267], ["Rotate", 0.8675551331308305, 0.46335565746433094]],
         [["Color", 0.35097958351003306, 0.42199181561523186], ["Invert", 0.914112788087429, 0.44775583211984815]],
         [["Cutout", 0.223575616055454, 0.6328591417299063], ["TranslateY", 0.09269465212259387, 0.5101073959070608]],
         [["Rotate", 0.3315734525975911, 0.9983593458299167], ["Sharpness", 0.12245416662856974, 0.6258689139914664]],
         [["ShearY", 0.696116760180471, 0.6317805202283014], ["Color", 0.847501151593963, 0.4440116609830195]],
         [["Solarize", 0.24945891607225948, 0.7651150206105561], ["Cutout", 0.7229677092930331, 0.12674657348602494]],
         [["TranslateX", 0.43461945065713675, 0.06476571036747841], ["Color", 0.6139316940180952, 0.7376264330632316]],
         [["Invert", 0.1933003530637138, 0.4497819016184308], ["Invert", 0.18391634069983653, 0.3199769100951113]],
         [["Color", 0.20418296626476137, 0.36785101882029814], ["Posterize", 0.624658293920083, 0.8390081535735991]],
         [["Sharpness", 0.5864963540530814, 0.586672446690273], ["Posterize", 0.1980280647652339, 0.222114611452575]],
         [["Invert", 0.3543654961628104, 0.5146369635250309], ["Equalize", 0.40751271919434434, 0.4325310837291978]],
         [["ShearY", 0.22602859359451877, 0.13137880879778158], ["Posterize", 0.7475029061591305, 0.803900538461099]],
         [["Sharpness", 0.12426276165599924, 0.5965912716602046], ["Invert", 0.22603903038966913, 0.4346802001255868]],
         [["TranslateY", 0.010307035630661765, 0.16577665156754046],
          ["Posterize", 0.4114319141395257, 0.829872913683949]],
         [["TranslateY", 0.9353069865746215, 0.5327821671247214], ["Color", 0.16990443486261103, 0.38794866007484197]],
         [["Cutout", 0.1028174322829021, 0.3955952903458266], ["ShearY", 0.4311995281335693, 0.48024695395374734]],
         [["Posterize", 0.1800334334284686, 0.0548749478418862],
          ["Brightness", 0.7545808536793187, 0.7699080551646432]],
         [["Color", 0.48695305373084197, 0.6674269768464615], ["ShearY", 0.4306032279086781, 0.06057690550239343]],
         [["Brightness", 0.4919399683825053, 0.677338905806407],
          ["Brightness", 0.24112708387760828, 0.42761103121157656]],
         [["Posterize", 0.4434818644882532, 0.9489450593207714],
          ["Posterize", 0.40957675116385955, 0.015664946759584186]],
         [["Posterize", 0.41307949855153797, 0.6843276552020272], ["Rotate", 0.8003545094091291, 0.7002300783416026]],
         [["Color", 0.7038570031770905, 0.4697612983649519], ["Sharpness", 0.9700016496081002, 0.25185103545948884]],
         [["AutoContrast", 0.714641656154856, 0.7962423001719023],
          ["Sharpness", 0.2410097684093468, 0.5919171048019731]],
         [["TranslateX", 0.8101567644494714, 0.7156447005337443], ["Solarize", 0.5634727831229329, 0.8875158446846]],
         [["Sharpness", 0.5335258857303261, 0.364743126378182], ["Color", 0.453280875871377, 0.5621962714743068]],
         [["Cutout", 0.7423678127672542, 0.7726370777867049], ["Invert", 0.2806161382641934, 0.6021111986900146]],
         [["TranslateY", 0.15190341320343761, 0.3860373175487939], ["Cutout", 0.9980805818665679, 0.05332384819400854]],
         [["Posterize", 0.36518675678786605, 0.2935819027397963],
          ["TranslateX", 0.26586180351840005, 0.303641300745208]],
         [["Brightness", 0.19994509744377761, 0.90813953707639], ["Equalize", 0.8447217761297836, 0.3449396603478335]],
         [["Sharpness", 0.9294773669936768, 0.999713346583839], ["Brightness", 0.1359744825665662, 0.1658489221872924]],
         [["TranslateX", 0.11456529257659381, 0.9063795878367734],
          ["Equalize", 0.017438134319894553, 0.15776887259743755]],
         [["ShearX", 0.9833726383270114, 0.5688194948373335], ["Equalize", 0.04975615490994345, 0.8078130016227757]],
         [["Brightness", 0.2654654830488695, 0.8989789725280538],
          ["TranslateX", 0.3681535065952329, 0.36433345713161036]],
         [["Rotate", 0.04956524209892327, 0.5371942433238247], ["ShearY", 0.0005527499145153714, 0.56082571605602]],
         [["Rotate", 0.7918337108932019, 0.5906896260060501], ["Posterize", 0.8223967034091191, 0.450216998388943]],
         [["Color", 0.43595106766978337, 0.5253013785221605], ["Sharpness", 0.9169421073531799, 0.8439997639348893]],
         [["TranslateY", 0.20052300197155504, 0.8202662448307549],
          ["Sharpness", 0.2875792108435686, 0.6997181624527842]],
         [["Color", 0.10568089980973616, 0.3349467065132249], ["Brightness", 0.13070947282207768, 0.5757725013960775]],
         [["AutoContrast", 0.3749999712869779, 0.6665578760607657],
          ["Brightness", 0.8101178402610292, 0.23271946112218125]],
         [["Color", 0.6473605933679651, 0.7903409763232029], ["ShearX", 0.588080941572581, 0.27223524148254086]],
         [["Cutout", 0.46293361616697304, 0.7107761001833921],
          ["AutoContrast", 0.3063766931658412, 0.8026114219854579]],
         [["Brightness", 0.7884854981520251, 0.5503669863113797],
          ["Brightness", 0.5832456158675261, 0.5840349298921661]],
         [["Solarize", 0.4157539625058916, 0.9161905834309929], ["Sharpness", 0.30628197221802017, 0.5386291658995193]],
         [["Sharpness", 0.03329610069672856, 0.17066672983670506], ["Invert", 0.9900547302690527, 0.6276238841220477]],
         [["Solarize", 0.551015648982762, 0.6937104775938737], ["Color", 0.8838491591064375, 0.31596634380795385]],
         [["AutoContrast", 0.16224182418148447, 0.6068227969351896],
          ["Sharpness", 0.9599468096118623, 0.4885289719905087]],
         [["TranslateY", 0.06576432526133724, 0.6899544605400214],
          ["Posterize", 0.2177096480169678, 0.9949164789616582]], [["Solarize", 0.529820544480292, 0.7576047224165541],
                                                                   ["Sharpness", 0.027047878909321643,
                                                                    0.45425231553970685]],
         [["Sharpness", 0.9102526010473146, 0.8311987141993857], ["Invert", 0.5191838751826638, 0.6906136644742229]],
         [["Solarize", 0.4762773516008588, 0.7703654263842423], ["Color", 0.8048437792602289, 0.4741523094238038]],
         [["Sharpness", 0.7095055508594206, 0.7047344238075169], ["Sharpness", 0.5059623654132546, 0.6127255499234886]],
         [["TranslateY", 0.02150725921966186, 0.3515764519224378],
          ["Posterize", 0.12482170119714735, 0.7829851754051393]],
         [["Color", 0.7983830079184816, 0.6964694521670339], ["Brightness", 0.3666527856286296, 0.16093151636495978]],
         [["AutoContrast", 0.6724982375829505, 0.536777706678488],
          ["Sharpness", 0.43091754837597646, 0.7363240924241439]],
         [["Brightness", 0.2889770401966227, 0.4556557902380539],
          ["Sharpness", 0.8805303296690755, 0.6262218017754902]],
         [["Sharpness", 0.5341939854581068, 0.6697109101429343], ["Rotate", 0.6806606655137529, 0.4896914517968317]],
         [["Sharpness", 0.5690509737059344, 0.32790632371915096],
          ["Posterize", 0.7951894258661069, 0.08377850335209162]],
         [["Color", 0.6124132978216081, 0.5756485920709012], ["Brightness", 0.33053544654445344, 0.23321841707002083]],
         [["TranslateX", 0.0654795026615917, 0.5227246924310244], ["ShearX", 0.2932320531132063, 0.6732066478183716]],
         [["Cutout", 0.6226071187083615, 0.01009274433736012], ["ShearX", 0.7176799968189801, 0.3758780240463811]],
         [["Rotate", 0.18172339508029314, 0.18099184896819184], ["ShearY", 0.7862658331645667, 0.295658135767252]],
         [["Contrast", 0.4156099177015862, 0.7015784500878446], ["Sharpness", 0.6454135310009, 0.32335858947955287]],
         [["Color", 0.6215885089922037, 0.6882673235388836], ["Brightness", 0.3539881732605379, 0.39486736455795496]],
         [["Invert", 0.8164816716866418, 0.7238192000817796], ["Sharpness", 0.3876355847343607, 0.9870077619731956]],
         [["Brightness", 0.1875628712629315, 0.5068115936257], ["Sharpness", 0.8732419122060423, 0.5028019258530066]],
         [["Sharpness", 0.6140734993408259, 0.6458239834366959], ["Rotate", 0.5250107862824867, 0.533419456933602]],
         [["Sharpness", 0.5710893143725344, 0.15551651073007305], ["ShearY", 0.6548487860151722, 0.021365083044319146]],
         [["Color", 0.7610250354649954, 0.9084452893074055], ["Brightness", 0.6934611792619156, 0.4108071412071374]],
         [["ShearY", 0.07512550098923898, 0.32923768385754293], ["ShearY", 0.2559588911696498, 0.7082337365398496]],
         [["Cutout", 0.5401319018926146, 0.004750568603408445], ["ShearX", 0.7473354415031975, 0.34472481968368773]],
         [["Rotate", 0.02284154583679092, 0.1353450082435801], ["ShearY", 0.8192458031684238, 0.2811653613473772]],
         [["Contrast", 0.21142896718139154, 0.7230739568811746],
          ["Sharpness", 0.6902690582665707, 0.13488436112901683]],
         [["Posterize", 0.21701219600958138, 0.5900695769640687], ["Rotate", 0.7541095031505971, 0.5341162375286219]],
         [["Posterize", 0.5772853064792737, 0.45808311743269936],
          ["Brightness", 0.14366050177823675, 0.4644871239446629]],
         [["Cutout", 0.8951718842805059, 0.4970074074310499], ["Equalize", 0.3863835903119882, 0.9986531042150006]],
         [["Equalize", 0.039411354473938925, 0.7475477254908457],
          ["Sharpness", 0.8741966378291861, 0.7304822679596362]],
         [["Solarize", 0.4908704265218634, 0.5160677350249471], ["Color", 0.24961813832742435, 0.09362352627360726]],
         [["Rotate", 7.870457075154214e-05, 0.8086950025500952],
          ["Solarize", 0.10200484521793163, 0.12312889222989265]],
         [["Contrast", 0.8052564975559727, 0.3403813036543645], ["Solarize", 0.7690158533600184, 0.8234626822018851]],
         [["AutoContrast", 0.680362728854513, 0.9415320040873628],
          ["TranslateY", 0.5305871824686941, 0.8030609611614028]],
         [["Cutout", 0.1748050257378294, 0.06565343731910589], ["TranslateX", 0.1812738872339903, 0.6254461448344308]],
         [["Brightness", 0.4230502644722749, 0.3346463682905031], ["ShearX", 0.19107198973659312, 0.6715789128604919]],
         [["ShearX", 0.1706528684548394, 0.7816570201200446], ["TranslateX", 0.494545185948171, 0.4710810058360291]],
         [["TranslateX", 0.42356251508933324, 0.23865307292867322],
          ["TranslateX", 0.24407503619326745, 0.6013778508137331]],
         [["AutoContrast", 0.7719512185744232, 0.3107905373009763],
          ["ShearY", 0.49448082925617176, 0.5777951230577671]],
         [["Cutout", 0.13026983827940525, 0.30120438757485657], ["Brightness", 0.8857896834516185, 0.7731541459513939]],
         [["AutoContrast", 0.6422800349197934, 0.38637401090264556],
          ["TranslateX", 0.25085431400995084, 0.3170642592664873]],
         [["Sharpness", 0.22336654455367122, 0.4137774852324138], ["ShearY", 0.22446851054920894, 0.518341735882535]],
         [["Color", 0.2597579403253848, 0.7289643913060193], ["Sharpness", 0.5227416670468619, 0.9239943674030637]],
         [["Cutout", 0.6835337711563527, 0.24777620448593812],
          ["AutoContrast", 0.37260245353051846, 0.4840361183247263]],
         [["Posterize", 0.32756602788628375, 0.21185124493743707],
          ["ShearX", 0.25431504951763967, 0.19585996561416225]],
         [["AutoContrast", 0.07930627591849979, 0.5719381348340309],
          ["AutoContrast", 0.335512380071304, 0.4208050118308541]],
         [["Rotate", 0.2924360268257798, 0.5317629242879337], ["Sharpness", 0.4531050021499891, 0.4102650087199528]],
         [["Equalize", 0.5908862210984079, 0.468742362277498], ["Brightness", 0.08571766548550425, 0.5629320703375056]],
         [["Cutout", 0.52751122383816, 0.7287774744737556], ["Equalize", 0.28721628275296274, 0.8075179887475786]],
         [["AutoContrast", 0.24208377391366226, 0.34616549409607644],
          ["TranslateX", 0.17454707403766834, 0.5278055700078459]],
         [["Brightness", 0.5511881924749478, 0.999638675514418], ["Equalize", 0.14076197797220913, 0.2573030693317552]],
         [["ShearX", 0.668731433926434, 0.7564253049646743], ["Color", 0.63235486543845, 0.43954436063340785]],
         [["ShearX", 0.40511960873276237, 0.5710419512142979], ["Contrast", 0.9256769948746423, 0.7461350716211649]],
         [["Cutout", 0.9995917204023061, 0.22908419326246265], ["TranslateX", 0.5440902956629469, 0.9965570051216295]],
         [["Color", 0.22552987172228894, 0.4514558960849747], ["Sharpness", 0.638058150559443, 0.9987829481002615]],
         [["Contrast", 0.5362775837534763, 0.7052133185951871], ["ShearY", 0.220369845547023, 0.7593922994775721]],
         [["ShearX", 0.0317785822935219, 0.775536785253455], ["TranslateX", 0.7939510227015061, 0.5355620618496535]],
         [["Cutout", 0.46027969917602196, 0.31561199122527517], ["Color", 0.06154066467629451, 0.5384660000729091]],
         [["Sharpness", 0.7205483743301113, 0.552222392539886], ["Posterize", 0.5146496404711752, 0.9224333144307473]],
         [["ShearX", 0.00014547730356910538, 0.3553954298642108], ["TranslateY", 0.9625736029090676, 0.57403418640424]],
         [["Posterize", 0.9199917903297341, 0.6690259107633706],
          ["Posterize", 0.0932558110217602, 0.22279303372106138]],
         [["Invert", 0.25401453476874863, 0.3354329544078385], ["Posterize", 0.1832673201325652, 0.4304718799821412]],
         [["TranslateY", 0.02084122674367607, 0.12826181437197323], ["ShearY", 0.655862534043703, 0.3838330909470975]],
         [["Contrast", 0.35231797644104523, 0.3379356652070079], ["Cutout", 0.19685599014304822, 0.1254328595280942]],
         [["Sharpness", 0.18795594984191433, 0.09488678946484895], ["ShearX", 0.33332876790679306, 0.633523782574133]],
         [["Cutout", 0.28267175940290246, 0.7901991550267817], ["Contrast", 0.021200195312951198, 0.4733128702798515]],
         [["ShearX", 0.966231043411256, 0.7700673327786812], ["TranslateX", 0.7102390777763321, 0.12161245817120675]],
         [["Cutout", 0.5183324259533826, 0.30766086003013055], ["Color", 0.48399078150128927, 0.4967477809069189]],
         [["Sharpness", 0.8160855187385873, 0.47937658961644], ["Posterize", 0.46360395447862535, 0.7685454058155061]],
         [["ShearX", 0.10173571421694395, 0.3987290690178754], ["TranslateY", 0.8939980277379345, 0.5669994143735713]],
         [["Posterize", 0.6768089584801844, 0.7113149244621721],
          ["Posterize", 0.054896856043358935, 0.3660837250743921]],
         [["AutoContrast", 0.5915576211896306, 0.33607718177676493],
          ["Contrast", 0.3809408206617828, 0.5712201773913784]],
         [["AutoContrast", 0.012321347472748323, 0.06379072432796573],
          ["Rotate", 0.0017964439160045656, 0.7598026295973337]],
         [["Contrast", 0.6007100085192627, 0.36171972473370206], ["Invert", 0.09553573684975913, 0.12218510774295901]],
         [["AutoContrast", 0.32848604643836266, 0.2619457656206414],
          ["Invert", 0.27082113532501784, 0.9967965642293485]],
         [["AutoContrast", 0.6156282120903395, 0.9422706516080884],
          ["Sharpness", 0.4215509247379262, 0.4063347716503587]],
         [["Solarize", 0.25059210436331264, 0.7215305521159305], ["Invert", 0.1654465185253614, 0.9605851884186778]],
         [["AutoContrast", 0.4464438610980994, 0.685334175815482], ["Cutout", 0.24358625461158645, 0.4699066834058694]],
         [["Rotate", 0.5931657741857909, 0.6813978655574067], ["AutoContrast", 0.9259100547738681, 0.4903201223870492]],
         [["Color", 0.8203976071280751, 0.9777824466585101], ["Posterize", 0.4620669369254169, 0.2738895968716055]],
         [["Contrast", 0.13754352055786848, 0.3369433962088463],
          ["Posterize", 0.48371187792441916, 0.025718004361451302]],
         [["Rotate", 0.5208233630704999, 0.1760188899913535], ["TranslateX", 0.49753461392937226, 0.4142935276250922]],
         [["Cutout", 0.5967418240931212, 0.8028675552639539], ["Cutout", 0.20021854152659121, 0.19426330549590076]],
         [["ShearY", 0.549583567386676, 0.6601326640171705], ["Cutout", 0.6111813470383047, 0.4141935587984994]],
         [["Brightness", 0.6354891977535064, 0.31591459747846745],
          ["AutoContrast", 0.7853952208711621, 0.6555861906702081]],
         [["AutoContrast", 0.7333725370546154, 0.9919410576081586], ["Cutout", 0.9984177877923588, 0.2938253683694291]],
         [["Color", 0.33219296307742263, 0.6378995578424113],
          ["AutoContrast", 0.15432820754183288, 0.7897899838932103]],
         [["Contrast", 0.5905289460222578, 0.8158577207653422], ["Cutout", 0.3980284381203051, 0.43030531250317217]],
         [["TranslateX", 0.452093693346745, 0.5251475931559115], ["Rotate", 0.991422504871258, 0.4556503729269001]],
         [["Color", 0.04560406292983776, 0.061574671308480766],
          ["Brightness", 0.05161079440128734, 0.6718398142425688]],
         [["Contrast", 0.02913302416506853, 0.14402056093217708], ["Rotate", 0.7306930378774588, 0.47088249057922094]],
         [["Solarize", 0.3283072384190169, 0.82680847744367], ["Invert", 0.21632614168418854, 0.8792241691482687]],
         [["Equalize", 0.4860808352478527, 0.9440534949023064], ["Cutout", 0.31395897639184694, 0.41805859306017523]],
         [["Rotate", 0.2816043232522335, 0.5451282807926706], ["Color", 0.7388520447173302, 0.7706503658143311]],
         [["Color", 0.9342776719536201, 0.9039981381514299], ["Rotate", 0.6646389177840164, 0.5147917008383647]],
         [["Cutout", 0.08929430082050335, 0.22416445996932374], ["Posterize", 0.454485751267457, 0.500958345348237]],
         [["TranslateX", 0.14674201106374488, 0.7018633472428202],
          ["Sharpness", 0.6128796723832848, 0.743535235614809]],
         [["TranslateX", 0.5189900164469432, 0.6491132403587601],
          ["Contrast", 0.26309555778227806, 0.5976857969656114]],
         [["Solarize", 0.23569808291972655, 0.3315781686591778], ["ShearY", 0.07292078937544964, 0.7460326987587573]],
         [["ShearY", 0.7090542757477153, 0.5246437008439621], ["Sharpness", 0.9666919148538443, 0.4841687888767071]],
         [["Solarize", 0.3486952615189488, 0.7012877201721799], ["Invert", 0.1933387967311534, 0.9535472742828175]],
         [["AutoContrast", 0.5393460721514914, 0.6924005011697713],
          ["Cutout", 0.16988156769247176, 0.3667207571712882]],
         [["Rotate", 0.5815329514554719, 0.5390406879316949], ["AutoContrast", 0.7370538341589625, 0.7708822194197815]],
         [["Color", 0.8463701017918459, 0.9893491045831084], ["Invert", 0.06537367901579016, 0.5238468509941635]],
         [["Contrast", 0.8099771812443645, 0.39371603893945184],
          ["Posterize", 0.38273629875646487, 0.46493786058573966]],
         [["Color", 0.11164686537114032, 0.6771450570033168], ["Posterize", 0.27921361289661406, 0.7214300893597819]],
         [["Contrast", 0.5958265906571906, 0.5963959447666958], ["Sharpness", 0.2640889223630885, 0.3365870842641453]],
         [["Color", 0.255634146724125, 0.5610029792926452], ["ShearY", 0.7476893976084721, 0.36613194760395557]],
         [["ShearX", 0.2167581882130063, 0.022978065071245002], ["TranslateX", 0.1686864409720319, 0.4919575435512007]],
         [["Solarize", 0.10702753776284957, 0.3954707963684698], ["Contrast", 0.7256100635368403, 0.48845259655719686]],
         [["Sharpness", 0.6165615058519549, 0.2624079463213861], ["ShearX", 0.3804820351860919, 0.4738994677544202]],
         [["TranslateX", 0.18066394808448177, 0.8174509422318228],
          ["Solarize", 0.07964569396290502, 0.45495935736800974]],
         [["Sharpness", 0.2741884021129658, 0.9311045302358317], ["Cutout", 0.0009101326429323388, 0.5932102256756948]],
         [["Rotate", 0.8501796375826188, 0.5092564038282137], ["Brightness", 0.6520146983999912, 0.724091283316938]],
         [["Brightness", 0.10079744898900078, 0.7644088017429471],
          ["AutoContrast", 0.33540215138213575, 0.1487538541758792]],
         [["ShearY", 0.10632545944757177, 0.9565164562996977], ["Rotate", 0.275833816849538, 0.6200731548023757]],
         [["Color", 0.6749819274397422, 0.41042188598168844],
          ["AutoContrast", 0.22396590966461932, 0.5048018491863738]],
         [["Equalize", 0.5044277111650255, 0.2649182381110667],
          ["Brightness", 0.35715133289571355, 0.8653260893016869]],
         [["Cutout", 0.49083594426355326, 0.5602781291093129], ["Posterize", 0.721795488514384, 0.5525847430754974]],
         [["Sharpness", 0.5081835448947317, 0.7453323423804428],
          ["TranslateX", 0.11511932212234266, 0.4337766796030984]],
         [["Solarize", 0.3817050641766593, 0.6879004573473403], ["Invert", 0.0015041436267447528, 0.9793134066888262]],
         [["AutoContrast", 0.5107410439697935, 0.8276720355454423],
          ["Cutout", 0.2786270701864015, 0.43993387208414564]],
         [["Rotate", 0.6711202569428987, 0.6342930903972932], ["Posterize", 0.802820231163559, 0.42770002619222053]],
         [["Color", 0.9426854321337312, 0.9055431782458764], ["AutoContrast", 0.3556422423506799, 0.2773922428787449]],
         [["Contrast", 0.10318991257659992, 0.30841372533347416],
          ["Posterize", 0.4202264962677853, 0.05060395018085634]],
         [["Invert", 0.549305630337048, 0.886056156681853], ["Cutout", 0.9314157033373055, 0.3485836940307909]],
         [["ShearX", 0.5642891775895684, 0.16427372934801418], ["Invert", 0.228741164726475, 0.5066345406806475]],
         [["ShearY", 0.5813123201003086, 0.33474363490586106], ["Equalize", 0.11803439432255824, 0.8583936440614798]],
         [["Sharpness", 0.1642809706111211, 0.6958675237301609], ["ShearY", 0.5989560762277414, 0.6194018060415276]],
         [["Rotate", 0.05092104774529638, 0.9358045394527796], ["Cutout", 0.6443254331615441, 0.28548414658857657]],
         [["Brightness", 0.6986036769232594, 0.9618046340942727],
          ["Sharpness", 0.5564490243465492, 0.6295231286085622]],
         [["Brightness", 0.42725649792574105, 0.17628028916784244],
          ["Equalize", 0.4425109360966546, 0.6392872650036018]],
         [["ShearY", 0.5758622795525444, 0.8773349286588288], ["ShearX", 0.038525646435423666, 0.8755366512394268]],
         [["Sharpness", 0.3704459924265827, 0.9236361456197351], ["Color", 0.6379842432311235, 0.4548767717224531]],
         [["Contrast", 0.1619523824549347, 0.4506528800882731],
          ["AutoContrast", 0.34513874426188385, 0.3580290330996726]],
         [["Contrast", 0.728699731513527, 0.6932238009822878], ["Brightness", 0.8602917375630352, 0.5341445123280423]],
         [["Equalize", 0.3574552353044203, 0.16814745124536548], ["Rotate", 0.24191717169379262, 0.3279497108179034]],
         [["ShearY", 0.8567478695576244, 0.37746117240238164], ["ShearX", 0.9654125389830487, 0.9283047610798827]],
         [["ShearY", 0.4339052480582405, 0.5394548246617406], ["Cutout", 0.5070570647967001, 0.7846286976687882]],
         [["AutoContrast", 0.021620100406875065, 0.44425839772845227],
          ["AutoContrast", 0.33978157614075183, 0.47716564815092244]],
         [["Contrast", 0.9727600659025666, 0.6651758819229426],
          ["Brightness", 0.9893133904996626, 0.39176397622636105]],
         [["Equalize", 0.283428620586305, 0.18727922861893637], ["Rotate", 0.3556063466797136, 0.3722839913107821]],
         [["ShearY", 0.7276172841941864, 0.4834188516302227], ["ShearX", 0.010783217950465884, 0.9756458772142235]],
         [["ShearY", 0.2901753295101581, 0.5684700238749064], ["Cutout", 0.655585564610337, 0.9490071307790201]],
         [["AutoContrast", 0.008507193981450278, 0.4881150103902877],
          ["AutoContrast", 0.6561989723231185, 0.3715071329838596]],
         [["Contrast", 0.7702505530948414, 0.6961371266519999], ["Brightness", 0.9953051630261895, 0.3861962467326121]],
         [["Equalize", 0.2805270012472756, 0.17715406116880994], ["Rotate", 0.3111256593947474, 0.15824352183820073]],
         [["Brightness", 0.9888680802094193, 0.4856236485253163], ["ShearX", 0.022370252047332284, 0.9284975906226682]],
         [["ShearY", 0.4065719044318099, 0.7468528006921563],
          ["AutoContrast", 0.19494427109708126, 0.8613186475174786]],
         [["AutoContrast", 0.023296727279367765, 0.9170949567425306],
          ["AutoContrast", 0.11663051100921168, 0.7908646792175343]],
         [["AutoContrast", 0.7335191671571732, 0.4958357308292425], ["Color", 0.7964964008349845, 0.4977687544324929]],
         [["ShearX", 0.19905221600021472, 0.3033081933150046], ["Equalize", 0.9383410219319321, 0.3224669877230161]],
         [["ShearX", 0.8265450331466404, 0.6509091423603757], ["Sharpness", 0.7134181178748723, 0.6472835976443643]],
         [["ShearY", 0.46962439525486044, 0.223433110541722], ["Rotate", 0.7749806946212373, 0.5337060376916906]],
         [["Posterize", 0.1652499695106796, 0.04860659068586126],
          ["Brightness", 0.6644577712782511, 0.4144528269429337]],
         [["TranslateY", 0.6220449565731829, 0.4917495676722932],
          ["Posterize", 0.6255000355409635, 0.8374266890984867]],
         [["AutoContrast", 0.4887160797052227, 0.7106426020530529],
          ["Sharpness", 0.7684218571497236, 0.43678474722954763]],
         [["Invert", 0.13178101535845366, 0.8301141976359813], ["Color", 0.002820877424219378, 0.49444413062487075]],
         [["TranslateX", 0.9920683666478188, 0.5862245842588877],
          ["Posterize", 0.5536357075855376, 0.5454300367281468]],
         [["Brightness", 0.8150181219663427, 0.1411060258870707], ["Sharpness", 0.8548823004164599, 0.77008691072314]],
         [["Brightness", 0.9580478020413399, 0.7198667636628974], ["ShearY", 0.8431585033377366, 0.38750016565010803]],
         [["Solarize", 0.2331505347152334, 0.25754361489084787], ["TranslateY", 0.447431373734262, 0.5782399531772253]],
         [["TranslateY", 0.8904927998691309, 0.25872872455072315],
          ["AutoContrast", 0.7129888139716263, 0.7161603231650524]],
         [["ShearY", 0.6336216800247362, 0.5247508616674911], ["Cutout", 0.9167315119726633, 0.2060557387978919]],
         [["ShearX", 0.001661782345968199, 0.3682225725445044], ["Solarize", 0.12303352043754572, 0.5014989548584458]],
         [["Brightness", 0.9723625105116246, 0.6555444729681099], ["Contrast", 0.5539208721135375, 0.7819973409318487]],
         [["Equalize", 0.3262607499912611, 0.0006745572802121513],
          ["Contrast", 0.35341551623767103, 0.36814689398886347]],
         [["ShearY", 0.7478539900243613, 0.37322078030129185], ["TranslateX", 0.41558847793529247, 0.7394615158544118]],
         [["Invert", 0.13735541232529067, 0.5536403864332143], ["Cutout", 0.5109718190377135, 0.0447509485253679]],
         [["AutoContrast", 0.09403602327274725, 0.5909250807862687], ["ShearY", 0.53234060616395, 0.5316981359469398]],
         [["ShearX", 0.5651922367876323, 0.6794110241313183], ["Posterize", 0.7431624856363638, 0.7896861463783287]],
         [["Brightness", 0.30949179379286806, 0.7650569096019195],
          ["Sharpness", 0.5461629122105034, 0.6814369444005866]],
         [["Sharpness", 0.28459340191768434, 0.7802208350806028], ["Rotate", 0.15097973114238117, 0.5259683294104645]],
         [["ShearX", 0.6430803693700531, 0.9333735880102375], ["Contrast", 0.7522209520030653, 0.18831747966185058]],
         [["Contrast", 0.4219455937915647, 0.29949769435499646], ["Color", 0.6925322933509542, 0.8095523885795443]],
         [["ShearX", 0.23553236193043048, 0.17966207900468323],
          ["AutoContrast", 0.9039700567886262, 0.21983629944639108]],
         [["ShearX", 0.19256223146671514, 0.31200739880443584], ["Sharpness", 0.31962196883294713, 0.6828107668550425]],
         [["Cutout", 0.5947690279080912, 0.21728220253899178], ["Rotate", 0.6757188879871141, 0.489460599679474]],
         [["ShearY", 0.18365897125470526, 0.3988571115918058], ["Brightness", 0.7727489489504, 0.4790369956329955]],
         [["Contrast", 0.7090301084131432, 0.5178303607560537], ["ShearX", 0.16749258277688506, 0.33061773301592356]],
         [["ShearX", 0.3706690885419934, 0.38510677124319415],
          ["AutoContrast", 0.8288356276501032, 0.16556487668770264]],
         [["TranslateY", 0.16758043046445614, 0.30127092823893986],
          ["Brightness", 0.5194636577132354, 0.6225165310621702]],
         [["Cutout", 0.6087289363049726, 0.10439287037803044], ["Rotate", 0.7503452083033819, 0.7425316019981433]],
         [["ShearY", 0.24347189588329932, 0.5554979486672325], ["Brightness", 0.9468115239174161, 0.6132449358023568]],
         [["Brightness", 0.7144508395807994, 0.4610594769966929], ["ShearX", 0.16466683833092968, 0.3382903812375781]],
         [["Sharpness", 0.27743648684265465, 0.17200038071656915], ["Color", 0.47404262107546236, 0.7868991675614725]],
         [["Sharpness", 0.8603993513633618, 0.324604728411791], ["TranslateX", 0.3331597130403763, 0.9369586812977804]],
         [["Color", 0.1535813630595832, 0.4700116846558207], ["Color", 0.5435647971896318, 0.7639291483525243]],
         [["Brightness", 0.21486188101947656, 0.039347277341450576],
          ["Cutout", 0.7069526940684954, 0.39273934115015696]],
         [["ShearY", 0.7267130888840517, 0.6310800726389485], ["AutoContrast", 0.662163190824139, 0.31948540372237766]],
         [["ShearX", 0.5123132117185981, 0.1981015909438834],
          ["AutoContrast", 0.9009347363863067, 0.26790399126924036]],
         [["Brightness", 0.24245061453231648, 0.2673478678291436], ["ShearX", 0.31707976089283946, 0.6800582845544948]],
         [["Cutout", 0.9257780138367764, 0.03972673526848819], ["Rotate", 0.6807858944518548, 0.46974332280612097]],
         [["ShearY", 0.1543443071262312, 0.6051682587030671], ["Brightness", 0.9758203119828304, 0.4941406868162414]],
         [["Contrast", 0.07578049236491124, 0.38953819133407647], ["ShearX", 0.20194918288164293, 0.4141510791947318]],
         [["Color", 0.27826402243792286, 0.43517491081531157],
          ["AutoContrast", 0.6159269026143263, 0.2021846783488046]],
         [["AutoContrast", 0.5039377966534692, 0.19241507605941105],
          ["Invert", 0.5563931144385394, 0.7069728937319112]],
         [["Sharpness", 0.19031632433810566, 0.26310171056096743], ["Color", 0.4724537593175573, 0.6715201448387876]],
         [["ShearY", 0.2280910467786642, 0.33340559088059313], ["ShearY", 0.8858560034869303, 0.2598627441471076]],
         [["ShearY", 0.07291814128021593, 0.5819462692986321], ["Cutout", 0.27605696060512147, 0.9693427371868695]],
         [["Posterize", 0.4249871586563321, 0.8256952014328607],
          ["Posterize", 0.005907466926447169, 0.8081353382152597]],
         [["Brightness", 0.9071305290601128, 0.4781196213717954],
          ["Posterize", 0.8996214311439275, 0.5540717376630279]],
         [["Brightness", 0.06560728936236392, 0.9920627849065685],
          ["TranslateX", 0.04530789794044952, 0.5318568944702607]],
         [["TranslateX", 0.6800263601084814, 0.4611536772507228], ["Rotate", 0.7245888375283157, 0.0914772551375381]],
         [["Sharpness", 0.879556061897963, 0.42272481462067535],
          ["TranslateX", 0.4600350422524085, 0.5742175429334919]],
         [["AutoContrast", 0.5005776243176145, 0.22597121331684505],
          ["Invert", 0.10763286370369299, 0.6841782704962373]], [["Sharpness", 0.7422908472000116, 0.6850324203882405],
                                                                 ["TranslateX", 0.3832914614128403,
                                                                  0.34798646673324896]],
         [["ShearY", 0.31939465302679326, 0.8792088167639516], ["Brightness", 0.4093604352811235, 0.21055483197261338]],
         [["AutoContrast", 0.7447595860998638, 0.19280222555998586],
          ["TranslateY", 0.317754779431227, 0.9983454520593591]],
         [["Equalize", 0.27706973689750847, 0.6447455020660622], ["Contrast", 0.5626579126863761, 0.7920049962776781]],
         [["Rotate", 0.13064369451773816, 0.1495367590684905], ["Sharpness", 0.24893941981801215, 0.6295943894521504]],
         [["ShearX", 0.6856269993063254, 0.5167938584189854], ["Sharpness", 0.24835352574609537, 0.9990550493102627]],
         [["AutoContrast", 0.461654115871693, 0.43097388896245004], ["Cutout", 0.366359682416437, 0.08011826474215511]],
         [["AutoContrast", 0.993892672935951, 0.2403608711236933], ["ShearX", 0.6620817870694181, 0.1744814077869482]],
         [["ShearY", 0.6396747719986443, 0.15031017143644265], ["Brightness", 0.9451954879495629, 0.26490678840264714]],
         [["Color", 0.19311480787397262, 0.15712300697448575], ["Posterize", 0.05391448762015258, 0.6943963643155474]],
         [["Sharpness", 0.6199669674684085, 0.5412492335319072], ["Invert", 0.14086213450149815, 0.2611850277919339]],
         [["Posterize", 0.5533129268803405, 0.5332478159319912], ["ShearX", 0.48956244029096635, 0.09223930853562916]],
         [["ShearY", 0.05871590849449765, 0.19549715278943228],
          ["TranslateY", 0.7208521362741379, 0.36414003004659434]],
         [["ShearY", 0.7316263417917531, 0.0629747985768501], ["Contrast", 0.036359793501448245, 0.48658745414898386]],
         [["Rotate", 0.3301497610942963, 0.5686622043085637], ["ShearX", 0.40581487555676843, 0.5866127743850192]],
         [["ShearX", 0.6679039628249283, 0.5292270693200821], ["Sharpness", 0.25901391739310703, 0.9778360586541461]],
         [["AutoContrast", 0.27373222012596854, 0.14456771405730712],
          ["Contrast", 0.3877220783523938, 0.7965158941894336]],
         [["Solarize", 0.29440905483979096, 0.06071633809388455],
          ["Equalize", 0.5246736285116214, 0.37575084834661976]],
         [["TranslateY", 0.2191269464520395, 0.7444942293988484],
          ["Posterize", 0.3840878524812771, 0.31812671711741247]],
         [["Solarize", 0.25159267140731356, 0.5833264622559661],
          ["Brightness", 0.07552262572348738, 0.33210648549288435]],
         [["AutoContrast", 0.9770099298399954, 0.46421915310428197],
          ["AutoContrast", 0.04707358934642503, 0.24922048012183493]],
         [["Cutout", 0.5379685806621965, 0.02038212605928355], ["Brightness", 0.5900728303717965, 0.28807872931416956]],
         [["Sharpness", 0.11596624872886108, 0.6086947716949325],
          ["AutoContrast", 0.34876470059667525, 0.22707897759730578]],
         [["Contrast", 0.276545513135698, 0.8822580384226156], ["Rotate", 0.04874027684061846, 0.6722214281612163]],
         [["ShearY", 0.595839851757025, 0.4389866852785822], ["Equalize", 0.5225492356128832, 0.2735290854063459]],
         [["Sharpness", 0.9918029636732927, 0.9919926583216121],
          ["Sharpness", 0.03672376137997366, 0.5563865980047012]],
         [["AutoContrast", 0.34169589759999847, 0.16419911552645738],
          ["Invert", 0.32995953043129234, 0.15073174739720568]],
         [["Posterize", 0.04600255098477292, 0.2632612790075844],
          ["TranslateY", 0.7852153329831825, 0.6990722310191976]],
         [["AutoContrast", 0.4414653815356372, 0.2657468780017082],
          ["Posterize", 0.30647061536763337, 0.3688222724948656]],
         [["Contrast", 0.4239361091421837, 0.6076562806342001], ["Cutout", 0.5780707784165284, 0.05361325256745192]],
         [["Sharpness", 0.7657895907855394, 0.9842407321667671], ["Sharpness", 0.5416352696151596, 0.6773681575200902]],
         [["AutoContrast", 0.13967381098331305, 0.10787258006315015],
          ["Posterize", 0.5019536507897069, 0.9881978222469807]],
         [["Brightness", 0.030528346448984903, 0.31562058762552847],
          ["TranslateY", 0.0843808140595676, 0.21019213305350526]],
         [["AutoContrast", 0.6934579165006736, 0.2530484168209199],
          ["Rotate", 0.0005751408130693636, 0.43790043943210005]],
         [["TranslateX", 0.611258547664328, 0.25465240215894935],
          ["Sharpness", 0.5001446909868196, 0.36102204109889413]],
         [["Contrast", 0.8995127327150193, 0.5493190695343996], ["Brightness", 0.242708780669213, 0.5461116653329015]],
         [["AutoContrast", 0.3751825351022747, 0.16845985803896962],
          ["Cutout", 0.25201103287363663, 0.0005893331783358435]],
         [["ShearX", 0.1518985779435941, 0.14768180777304504], ["Color", 0.85133530274324, 0.4006641163378305]],
         [["TranslateX", 0.5489668255504668, 0.4694591826554948], ["Rotate", 0.1917354490155893, 0.39993269385802177]],
         [["ShearY", 0.6689267479532809, 0.34304285013663577], ["Equalize", 0.24133154048883143, 0.279324043138247]],
         [["Contrast", 0.3412544002099494, 0.20217358823930232], ["Color", 0.8606984790510235, 0.14305503544676373]],
         [["Cutout", 0.21656155695311988, 0.5240101349572595], ["Brightness", 0.14109877717636352, 0.2016827341210295]],
         [["Sharpness", 0.24764371218833872, 0.19655480259925423],
          ["Posterize", 0.19460398862039913, 0.4975414350200679]],
         [["Brightness", 0.6071850094982323, 0.7270716448607151], ["Solarize", 0.111786402398499, 0.6325641684614275]],
         [["Contrast", 0.44772949532200856, 0.44267502710695955],
          ["AutoContrast", 0.360117506402693, 0.2623958228760273]],
         [["Sharpness", 0.8888131688583053, 0.936897400764746], ["Sharpness", 0.16080674198274894, 0.5681119841445879]],
         [["AutoContrast", 0.8004456226590612, 0.1788600469525269],
          ["Brightness", 0.24832285390647374, 0.02755350284841604]],
         [["ShearY", 0.06910320102646594, 0.26076407321544054], ["Contrast", 0.8633703022354964, 0.38968514704043056]],
         [["AutoContrast", 0.42306251382780613, 0.6883260271268138],
          ["Rotate", 0.3938724346852023, 0.16740881249086037]],
         [["Contrast", 0.2725343884286728, 0.6468194318074759], ["Sharpness", 0.32238942646494745, 0.6721149242783824]],
         [["AutoContrast", 0.942093919956842, 0.14675331481712853],
          ["Posterize", 0.5406276708262192, 0.683901182218153]],
         [["Cutout", 0.5386811894643584, 0.04498833938429728], ["Posterize", 0.17007257321724775, 0.45761177118620633]],
         [["Contrast", 0.13599408935104654, 0.53282738083886], ["Solarize", 0.26941667995081114, 0.20958261079465895]],
         [["Color", 0.6600788518606634, 0.9522228302165842], ["Invert", 0.0542722262516899, 0.5152431169321683]],
         [["Contrast", 0.5328934819727553, 0.2376220512388278], ["Posterize", 0.04890422575781711, 0.3182233123739474]],
         [["AutoContrast", 0.9289628064340965, 0.2976678437448435], ["Color", 0.20936893798507963, 0.9649612821434217]],
         [["Cutout", 0.9019423698575457, 0.24002036989728096],
          ["Brightness", 0.48734445615892974, 0.047660899809176316]],
         [["Sharpness", 0.09347824275711591, 0.01358686275590612],
          ["Posterize", 0.9248539660538934, 0.4064232632650468]],
         [["Brightness", 0.46575675383704634, 0.6280194775484345],
          ["Invert", 0.17276207634499413, 0.21263495428839635]],
         [["Brightness", 0.7238014711679732, 0.6178946027258592],
          ["Equalize", 0.3815496086340364, 0.07301281068847276]],
         [["Contrast", 0.754557393588416, 0.895332753570098], ["Color", 0.32709957750707447, 0.8425486003491515]],
         [["Rotate", 0.43406698081696576, 0.28628263254953723],
          ["TranslateY", 0.43949548709125374, 0.15927082198238685]],
         [["Brightness", 0.0015838339831640708, 0.09341692553352654],
          ["AutoContrast", 0.9113966907329718, 0.8345900469751112]],
         [["ShearY", 0.46698796308585017, 0.6150701348176804], ["Invert", 0.14894062704815722, 0.2778388046184728]],
         [["Color", 0.30360499169455957, 0.995713092016834], ["Contrast", 0.2597016288524961, 0.8654420870658932]],
         [["Brightness", 0.9661642031891435, 0.7322006407169436],
          ["TranslateY", 0.4393502786333408, 0.33934762664274265]],
         [["Color", 0.9323638351992302, 0.912776309755293], ["Brightness", 0.1618274755371618, 0.23485741708056307]],
         [["Color", 0.2216470771158821, 0.3359240197334976], ["Sharpness", 0.6328691811471494, 0.6298393874452548]],
         [["Solarize", 0.4772769142265505, 0.7073470698713035], ["ShearY", 0.2656114148206966, 0.31343097010487253]],
         [["Solarize", 0.3839017339304234, 0.5985505779429036],
          ["Equalize", 0.002412059429196589, 0.06637506181196245]],
         [["Contrast", 0.12751196553017863, 0.46980311434237976],
          ["Sharpness", 0.3467487455865491, 0.4054907610444406]],
         [["AutoContrast", 0.9321813669127206, 0.31328471589533274],
          ["Rotate", 0.05801738717432747, 0.36035756254444273]],
         [["TranslateX", 0.52092390458353, 0.5261722561643886], ["Contrast", 0.17836804476171306, 0.39354333443158535]],
         [["Posterize", 0.5458100909925713, 0.49447244994482603],
          ["Brightness", 0.7372536822363605, 0.5303409097463796]],
         [["Solarize", 0.1913974941725724, 0.5582966653986761], ["Equalize", 0.020733669175727026, 0.9377467166472878]],
         [["Equalize", 0.16265732137763889, 0.5206282340874929], ["Sharpness", 0.2421533133595281, 0.506389065871883]],
         [["AutoContrast", 0.9787324801448523, 0.24815051941486466],
          ["Rotate", 0.2423487151245957, 0.6456493129745148]], [["TranslateX", 0.6809867726670327, 0.6949687002397612],
                                                                ["Contrast", 0.16125673359747458, 0.7582679978218987]],
         [["Posterize", 0.8212000950994955, 0.5225012157831872],
          ["Brightness", 0.8824891856626245, 0.4499216779709508]],
         [["Solarize", 0.12061313332505218, 0.5319371283368052], ["Equalize", 0.04120865969945108, 0.8179402157299602]],
         [["Rotate", 0.11278256686005855, 0.4022686554165438], ["ShearX", 0.2983451019112792, 0.42782525461812604]],
         [["ShearY", 0.8847385513289983, 0.5429227024179573], ["Rotate", 0.21316428726607445, 0.6712120087528564]],
         [["TranslateX", 0.46448081241068717, 0.4746090648963252],
          ["Brightness", 0.19973580961271142, 0.49252862676553605]],
         [["Posterize", 0.49664100539481526, 0.4460713166484651],
          ["Brightness", 0.6629559985581529, 0.35192346529003693]],
         [["Color", 0.22710733249173676, 0.37943185764616194], ["ShearX", 0.015809774971472595, 0.8472080190835669]],
         [["Contrast", 0.4187366322381491, 0.21621979869256666],
          ["AutoContrast", 0.7631045030367304, 0.44965231251615134]],
         [["Sharpness", 0.47240637876720515, 0.8080091811749525], ["Cutout", 0.2853425420104144, 0.6669811510150936]],
         [["Posterize", 0.7830320527127324, 0.2727062685529881], ["Solarize", 0.527834000867504, 0.20098218845222998]],
         [["Contrast", 0.366380535288225, 0.39766001659663075], ["Cutout", 0.8708808878088891, 0.20669525734273086]],
         [["ShearX", 0.6815427281122932, 0.6146858582671569], ["AutoContrast", 0.28330622372053493, 0.931352024154997]],
         [["AutoContrast", 0.8668174463154519, 0.39961453880632863],
          ["AutoContrast", 0.5718557712359253, 0.6337062930797239]],
         [["ShearY", 0.8923152519411871, 0.02480062504737446], ["Cutout", 0.14954159341231515, 0.1422219808492364]],
         [["Rotate", 0.3733718175355636, 0.3861928572224287], ["Sharpness", 0.5651126520194574, 0.6091103847442831]],
         [["Posterize", 0.8891714191922857, 0.29600154265251016],
          ["TranslateY", 0.7865351723963945, 0.5664998548985523]],
         [["TranslateX", 0.9298214806998273, 0.729856565052017],
          ["AutoContrast", 0.26349082482341846, 0.9638882609038888]],
         [["Sharpness", 0.8387378377527128, 0.42146721129032494],
          ["AutoContrast", 0.9860522000876452, 0.4200699464169384]],
         [["ShearY", 0.019609159303115145, 0.37197835936879514], ["Cutout", 0.22199340461754258, 0.015932573201085848]],
         [["Rotate", 0.43871085583928443, 0.3283504258860078], ["Sharpness", 0.6077702068037776, 0.6830305349618742]],
         [["Contrast", 0.6160211756538094, 0.32029451083389626], ["Cutout", 0.8037631428427006, 0.4025688837399259]],
         [["TranslateY", 0.051637820936985435, 0.6908417834391846],
          ["Sharpness", 0.7602756948473368, 0.4927111506643095]],
         [["Rotate", 0.4973618638052235, 0.45931479729281227], ["TranslateY", 0.04701789716427618, 0.9408779705948676]],
         [["Rotate", 0.5214194592768602, 0.8371249272013652], ["Solarize", 0.17734812472813338, 0.045020798970228315]],
         [["ShearX", 0.7457999920079351, 0.19025612553075893], ["Sharpness", 0.5994846101703786, 0.5665094068864229]],
         [["Contrast", 0.6172655452900769, 0.7811432139704904], ["Cutout", 0.09915620454670282, 0.3963692287596121]],
         [["TranslateX", 0.2650112299235817, 0.7377261946165307],
          ["AutoContrast", 0.5019539734059677, 0.26905046992024506]],
         [["Contrast", 0.6646299821370135, 0.41667784809592945], ["Cutout", 0.9698457154992128, 0.15429001887703997]],
         [["Sharpness", 0.9467079029475773, 0.44906457469098204], ["Cutout", 0.30036908747917396, 0.4766149689663106]],
         [["Equalize", 0.6667517691051055, 0.5014839828447363], ["Solarize", 0.4127890336820831, 0.9578274770236529]],
         [["Cutout", 0.6447384874120834, 0.2868806107728985], ["Cutout", 0.4800990488106021, 0.4757538246206956]],
         [["Solarize", 0.12560195032363236, 0.5557473475801568],
          ["Equalize", 0.019957161871490228, 0.5556797187823773]],
         [["Contrast", 0.12607637375759484, 0.4300633627435161],
          ["Sharpness", 0.3437273670109087, 0.40493203127714417]],
         [["AutoContrast", 0.884353334807183, 0.5880138314357569], ["Rotate", 0.9846032404597116, 0.3591877296622974]],
         [["TranslateX", 0.6862295865975581, 0.5307482119690076],
          ["Contrast", 0.19439251187251982, 0.3999195825722808]],
         [["Posterize", 0.4187641835025246, 0.5008988942651585],
          ["Brightness", 0.6665805605402482, 0.3853288204214253]],
         [["Posterize", 0.4507470690013903, 0.4232437206624681],
          ["TranslateX", 0.6054107416317659, 0.38123828040922203]],
         [["AutoContrast", 0.29562338573283276, 0.35608605102687474],
          ["TranslateX", 0.909954785390274, 0.20098894888066549]],
         [["Contrast", 0.6015278411777212, 0.6049140992035096], ["Cutout", 0.47178713636517855, 0.5333747244651914]],
         [["TranslateX", 0.490851976691112, 0.3829593925141144], ["Sharpness", 0.2716675173824095, 0.5131696240367152]],
         [["Posterize", 0.4190558294646337, 0.39316689077269873], ["Rotate", 0.5018526072725914, 0.295712490156129]],
         [["AutoContrast", 0.29624715560691617, 0.10937329832409388],
          ["Posterize", 0.8770505275992637, 0.43117765012206943]],
         [["Rotate", 0.6649970092751698, 0.47767131373391974], ["ShearX", 0.6257923540490786, 0.6643337040198358]],
         [["Sharpness", 0.5553620705849509, 0.8467799429696928], ["Cutout", 0.9006185811918932, 0.3537270716262]],
         [["ShearY", 0.0007619678283789788, 0.9494591850536303], ["Invert", 0.24267733654007673, 0.7851608409575828]],
         [["Contrast", 0.9730916198112872, 0.404670123321921], ["Sharpness", 0.5923587793251186, 0.7405792404430281]],
         [["Cutout", 0.07393909593373034, 0.44569630026328344], ["TranslateX", 0.2460593252211425, 0.4817527814541055]],
         [["Brightness", 0.31058654119340867, 0.7043749950260936], ["ShearX", 0.7632161538947713, 0.8043681264908555]],
         [["AutoContrast", 0.4352334371415373, 0.6377550087204297],
          ["Rotate", 0.2892714673415678, 0.49521052050510556]],
         [["Equalize", 0.509071051375276, 0.7352913414974414], ["ShearX", 0.5099959429711828, 0.7071566714593619]],
         [["Posterize", 0.9540506532512889, 0.8498853304461906], ["ShearY", 0.28199061357155397, 0.3161715627214629]],
         [["Posterize", 0.6740855359097433, 0.684004694936616], ["Posterize", 0.6816720350737863, 0.9654766942980918]],
         [["Solarize", 0.7149344531717328, 0.42212789795181643], ["Brightness", 0.686601460864528, 0.4263050070610551]],
         [["Cutout", 0.49577164991501, 0.08394890892056037], ["Rotate", 0.5810369852730606, 0.3320732965776973]],
         [["TranslateY", 0.1793755480490623, 0.6006520265468684],
          ["Brightness", 0.3769016576438939, 0.7190746300828186]],
         [["TranslateX", 0.7226363597757153, 0.3847027238123509],
          ["Brightness", 0.7641713191794035, 0.36234003077512544]],
         [["TranslateY", 0.1211227055347106, 0.6693523474608023],
          ["Brightness", 0.13011180247738063, 0.5126647617294864]],
         [["Equalize", 0.1501070550869129, 0.0038548909451806557],
          ["Posterize", 0.8266535939653881, 0.5502199643499207]], [["Sharpness", 0.550624117428359, 0.2023044586648523],
                                                                   ["Brightness", 0.06291556314780017,
                                                                    0.7832635398703937]],
         [["Color", 0.3701578205508141, 0.9051537973590863], ["Contrast", 0.5763972727739397, 0.4905511239739898]],
         [["Rotate", 0.7678527224046323, 0.6723066265307555], ["Solarize", 0.31458533097383207, 0.38329324335154524]],
         [["Brightness", 0.292050127929522, 0.7047582807953063], ["ShearX", 0.040541891910333805, 0.06639328601282746]],
         [["TranslateY", 0.4293891393238555, 0.6608516902234284],
          ["Sharpness", 0.7794685477624004, 0.5168044063408147]],
         [["Color", 0.3682450402286552, 0.17274523597220048], ["ShearY", 0.3936056470397763, 0.5702597289866161]],
         [["Equalize", 0.43436990310624657, 0.9207072627823626], ["Contrast", 0.7608688260846083, 0.4759023148841439]],
         [["Brightness", 0.7926088966143935, 0.8270093925674497], ["ShearY", 0.4924174064969461, 0.47424347505831244]],
         [["Contrast", 0.043917555279430476, 0.15861903591675125], ["ShearX", 0.30439480405505853, 0.1682659341098064]],
         [["TranslateY", 0.5598255583454538, 0.721352536005039], ["Posterize", 0.9700921973303752, 0.6882015184440126]],
         [["AutoContrast", 0.3620887415037668, 0.5958176322317132],
          ["TranslateX", 0.14213781552733287, 0.6230799786459947]],
         [["Color", 0.490366889723972, 0.9863152892045195], ["Color", 0.817792262022319, 0.6755656429452775]],
         [["Brightness", 0.7030707021937771, 0.254633187122679], ["Color", 0.13977318232688843, 0.16378180123959793]],
         [["AutoContrast", 0.2933247831326118, 0.6283663376211102],
          ["Sharpness", 0.85430478154147, 0.9753613184208796]],
         [["Rotate", 0.6674299955457268, 0.48571208708018976], ["Contrast", 0.47491370175907016, 0.6401079552479657]],
         [["Sharpness", 0.37589579644127863, 0.8475131989077025],
          ["TranslateY", 0.9985149867598191, 0.057815729375099975]],
         [["Equalize", 0.0017194373841596389, 0.7888361311461602], ["Contrast", 0.6779293670669408, 0.796851411454113]],
         [["TranslateY", 0.3296782119072306, 0.39765117357271834],
          ["Sharpness", 0.5890554357001884, 0.6318339473765834]],
         [["Posterize", 0.25423810893163856, 0.5400430289894207],
          ["Sharpness", 0.9273643918988342, 0.6480913470982622]],
         [["Cutout", 0.850219975768305, 0.4169812455601289], ["Solarize", 0.5418755745870089, 0.5679666650495466]],
         [["Brightness", 0.008881361977310959, 0.9282562314720516],
          ["TranslateY", 0.7736066471553994, 0.20041167606029642]],
         [["Brightness", 0.05382537581401925, 0.6405265501035952],
          ["Contrast", 0.30484329473639593, 0.5449338155734242]],
         [["Color", 0.613257119787967, 0.4541503912724138], ["Brightness", 0.9061572524724674, 0.4030159294447347]],
         [["Brightness", 0.02739111568942537, 0.006028056532326534],
          ["ShearX", 0.17276751958646486, 0.05967365780621859]],
         [["TranslateY", 0.4376298213047888, 0.7691816164456199],
          ["Sharpness", 0.8162292718857824, 0.6054926462265117]],
         [["Color", 0.37963069679121214, 0.5946919433483344], ["Posterize", 0.08485417284005387, 0.5663580913231766]],
         [["Equalize", 0.49785780226818316, 0.9999137109183761], ["Sharpness", 0.7685879484682496, 0.6260846154212211]],
         [["AutoContrast", 0.4190931409670763, 0.2374852525139795],
          ["Posterize", 0.8797422264608563, 0.3184738541692057]],
         [["Rotate", 0.7307269024632872, 0.41523609600701106], ["ShearX", 0.6166685870692289, 0.647133807748274]],
         [["Sharpness", 0.5633713231039904, 0.8276694754755876], ["Cutout", 0.8329340776895764, 0.42656043027424073]],
         [["ShearY", 0.14934828370884312, 0.8622510773680372], ["Invert", 0.25925989086863277, 0.8813283584888576]],
         [["Contrast", 0.9457071292265932, 0.43228655518614034], ["Sharpness", 0.8485316947644338, 0.7590298998732413]],
         [["AutoContrast", 0.8386103589399184, 0.5859583131318076],
          ["Solarize", 0.466758711343543, 0.9956215363818983]],
         [["Rotate", 0.9387133710926467, 0.19180564509396503], ["Rotate", 0.5558247609706255, 0.04321698692007105]],
         [["ShearX", 0.3608716600695567, 0.15206159451532864], ["TranslateX", 0.47295292905710146, 0.5290760596129888]],
         [["TranslateX", 0.8357685981547495, 0.5991305115727084],
          ["Posterize", 0.5362929404188211, 0.34398525441943373]],
         [["ShearY", 0.6751984031632811, 0.6066293622133011], ["Contrast", 0.4122723990263818, 0.4062467515095566]],
         [["Color", 0.7515349936021702, 0.5122124665429213], ["Contrast", 0.03190514292904123, 0.22903520154660545]],
         [["Contrast", 0.5448962625054385, 0.38655673938910545],
          ["AutoContrast", 0.4867400684894492, 0.3433111101096984]],
         [["Rotate", 0.0008372434310827959, 0.28599951781141714],
          ["Equalize", 0.37113686925530087, 0.5243929348114981]],
         [["Color", 0.720054993488857, 0.2010177651701808], ["TranslateX", 0.23036196506059398, 0.11152764304368781]],
         [["Cutout", 0.859134208332423, 0.6727345740185254], ["ShearY", 0.02159833505865088, 0.46390076266538544]],
         [["Sharpness", 0.3428232157391428, 0.4067874527486514],
          ["Brightness", 0.5409415136577347, 0.3698432231874003]],
         [["Solarize", 0.27303978936454776, 0.9832186173589548], ["ShearY", 0.08831127213044043, 0.4681870331149774]],
         [["TranslateY", 0.2909309268736869, 0.4059460811623174],
          ["Sharpness", 0.6425125139803729, 0.20275737203293587]],
         [["Contrast", 0.32167626214661627, 0.28636162794046977], ["Invert", 0.4712405253509603, 0.7934644799163176]],
         [["Color", 0.867993060896951, 0.96574321666213], ["Color", 0.02233897320328512, 0.44478933557303063]],
         [["AutoContrast", 0.1841254751814967, 0.2779992148017741], ["Color", 0.3586283093530607, 0.3696246850445087]],
         [["Posterize", 0.2052935984046965, 0.16796913860308244], ["ShearX", 0.4807226832843722, 0.11296747254563266]],
         [["Cutout", 0.2016411266364791, 0.2765295444084803], ["Brightness", 0.3054112810424313, 0.695924264931216]],
         [["Rotate", 0.8405872184910479, 0.5434142541450815], ["Cutout", 0.4493615138203356, 0.893453735250007]],
         [["Contrast", 0.8433310507685494, 0.4915423577963278], ["ShearX", 0.22567799557913246, 0.20129892537008834]],
         [["Contrast", 0.045954277103674224, 0.5043900167190442], ["Cutout", 0.5552992473054611, 0.14436447810888237]],
         [["AutoContrast", 0.7719296115130478, 0.4440417544621306],
          ["Sharpness", 0.13992809206158283, 0.7988278670709781]],
         [["Color", 0.7838574233513952, 0.5971351401625151], ["TranslateY", 0.13562290583925385, 0.2253039635819158]],
         [["Cutout", 0.24870301109385806, 0.6937886690381568], ["TranslateY", 0.4033400068952813, 0.06253378991880915]],
         [["TranslateX", 0.0036059390486775644, 0.5234723884081843],
          ["Solarize", 0.42724862530733526, 0.8697702564187633]],
         [["Equalize", 0.5446026737834311, 0.9367992979112202], ["ShearY", 0.5943478903735789, 0.42345889214100046]],
         [["ShearX", 0.18611885697957506, 0.7320849092947314], ["ShearX", 0.3796416430900566, 0.03817761920009881]],
         [["Posterize", 0.37636778506979124, 0.26807924785236537],
          ["Brightness", 0.4317372554383255, 0.5473346211870932]],
         [["Brightness", 0.8100436240916665, 0.3817612088285007],
          ["Brightness", 0.4193974619003253, 0.9685902764026623]],
         [["Contrast", 0.701776402197012, 0.6612786008858009], ["Color", 0.19882787177960912, 0.17275597188875483]],
         [["Color", 0.9538303302832989, 0.48362384535228686], ["ShearY", 0.2179980837345602, 0.37027290936457313]],
         [["TranslateY", 0.6068028691503798, 0.3919346523454841], ["Cutout", 0.8228303342563138, 0.18372280287814613]],
         [["Equalize", 0.016416758802906828, 0.642838949194916], ["Cutout", 0.5761717838655257, 0.7600661153497648]],
         [["Color", 0.9417761826818639, 0.9916074035986558], ["Equalize", 0.2524209308597042, 0.6373703468715077]],
         [["Brightness", 0.75512589439513, 0.6155072321007569], ["Contrast", 0.32413476940254515, 0.4194739830159837]],
         [["Sharpness", 0.3339450765586968, 0.9973297539194967],
          ["AutoContrast", 0.6523930242124429, 0.1053482471037186]],
         [["ShearX", 0.2961391955838801, 0.9870036064904368], ["ShearY", 0.18705025965909403, 0.4550895821154484]],
         [["TranslateY", 0.36956447983807883, 0.36371471767143543],
          ["Sharpness", 0.6860051967688487, 0.2850190720087796]],
         [["Cutout", 0.13017742151902967, 0.47316674150067195], ["Invert", 0.28923829959551883, 0.9295585654924601]],
         [["Contrast", 0.7302368472279086, 0.7178974949876642],
          ["TranslateY", 0.12589674152030433, 0.7485392909494947]],
         [["Color", 0.6474693117772619, 0.5518269515590674], ["Contrast", 0.24643004970708016, 0.3435581358079418]],
         [["Contrast", 0.5650327855750835, 0.4843031798040887], ["Brightness", 0.3526684005761239, 0.3005305004600969]],
         [["Rotate", 0.09822284968122225, 0.13172798244520356], ["Equalize", 0.38135066977857157, 0.5135129123554154]],
         [["Contrast", 0.5902590645585712, 0.2196062383730596], ["ShearY", 0.14188379126120954, 0.1582612142182743]],
         [["Cutout", 0.8529913814417812, 0.89734031211874], ["Color", 0.07293767043078672, 0.32577659205278897]],
         [["Equalize", 0.21401668971453247, 0.040015259500028266], ["ShearY", 0.5126400895338797, 0.4726484828276388]],
         [["Brightness", 0.8269430025954498, 0.9678362841865166], ["ShearY", 0.17142069814830432, 0.4726727848289514]],
         [["Brightness", 0.699707089334018, 0.2795501395789335], ["ShearX", 0.5308818178242845, 0.10581814221896294]],
         [["Equalize", 0.32519644258946145, 0.15763390340309183],
          ["TranslateX", 0.6149090364414208, 0.7454832565718259]],
         [["AutoContrast", 0.5404508567155423, 0.7472387762067986],
          ["Equalize", 0.05649876539221024, 0.5628180219887216]]]
    return p


class FastAugmentation(object):
    def __init__(self, policies=fa_resnet50_rimagenet()):
        self.policies = policies

    def __call__(self, img):
        for _ in range(1):
            policy = random.choice(self.policies)
            for name, pr, level in policy:
                if random.random() > pr:
                    continue
                img = apply_augment(img, name, level)
        return img

In [4]:
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)

def window_partition(x, window_size: int):
    """
    
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size: int, H: int, W: int):
    """
    
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SwinTransformerLayer(nn.Module):
    # Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c, num_heads, window_size=7, shift_size=0, 
                mlp_ratio = 4, qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
                act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        if num_heads > 10:
            drop_path = 0.1
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(c)
        self.attn = WindowAttention(
            c, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(c)
        mlp_hidden_dim = int(c * mlp_ratio)
        self.mlp = Mlp(in_features=c, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        
    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = ( (0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, torch.tensor(-100.0)).masked_fill(attn_mask == 0, torch.tensor(0.0))
        return attn_mask

    def forward(self, x):
        b, c, w, h = x.shape
        x = x.permute(0, 3, 2, 1).contiguous() # [b,h,w,c]

        attn_mask = self.create_mask(x, h, w) # [nW, Mh*Mw, Mh*Mw]
        shortcut = x
        x = self.norm1(x)
        
        pad_l = pad_t = 0
        pad_r = (self.window_size - w % self.window_size) % self.window_size
        pad_b = (self.window_size - h % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, hp, wp, _ = x.shape

        if self.shift_size > 0:
            # print(f"shift size: {self.shift_size}")
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None
        
        x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # [nW*B, Mh*Mw, C]

        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, hp, wp)  # [B, H', W', C]
        
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        
        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        
        x = x.permute(0, 3, 2, 1).contiguous()
        return x # (b, self.c2, w, h)

class SwinTransformerBlock(nn.Module):
    def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)

        self.window_size = window_size
        self.shift_size = window_size // 2
        self.tr = nn.Sequential(*(SwinTransformerLayer(c2, num_heads=num_heads, window_size=window_size,  shift_size=0 if (i % 2 == 0) else self.shift_size ) for i in range(num_layers)))

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        x = self.tr(x)
        return x

In [5]:
'''import torch
import numpy as np
import math
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['mesnest_s','mesnest_m','mesnest_l','mesnest_xl']

class Mish_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.tanh(F.softplus(i))
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
  
        v = 1. + i.exp()
        h = v.log() 
        grad_gh = 1./h.cosh().pow_(2) 

        # Note that grad_hv * grad_vx = sigmoid(x)
        #grad_hv = 1./v  
        #grad_vx = i.exp()
        grad_hx = i.sigmoid()

        grad_gx = grad_gh *  grad_hx #grad_hv * grad_vx 
        
        grad_f =  torch.tanh(F.softplus(i)) + i * grad_gx 
        
        return grad_output * grad_f 


class Mish(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        pass
    def forward(self, input_tensor):
        return Mish_func.apply(input_tensor)

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups
    
    # reshape
    x = x.view(batchsize, groups, 
        channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

class SplitAttn(nn.Module):
    def __init__(self, c1, c2=None, kernel_size=3, stride=1, padding=None,
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=Mish, norm_layer=None, drop_block=None, **kwargs):
        super(SplitAttn, self).__init__()
        c2 = c2 or c1
        self.radix = radix
        self.drop_block = drop_block
        mid_chs = c2 * radix
        if rd_channels is None:
            attn_chs = _make_divisible(c1 * radix * rd_ratio, min_value=32, divisor=rd_divisor)
        else:
            attn_chs = rd_channels * radix

        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
           c1, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        #self.conv = GhostConv(c1,mid_chs,k=kernel_size,s=stride,g=groups*radix)
        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
        self.act0 = act_layer()
        self.fc1 = nn.Conv2d(c2, attn_chs, 1, groups=groups)
        #self.fc1 = GhostConv(c2, attn_chs, 1, 1)
        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
        self.act1 = act_layer()
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        #self.fc2 = GhostConv(attn_chs, mid_chs, 1, 1, act=False)
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn0(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act0(x)

        B, RC, H, W = x.shape
        if self.radix > 1:
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_gap = self.fc1(x_gap)
        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)

        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
        if self.radix > 1:
            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
        else:
            out = x * x_attn
        output = out.contiguous()
        return output

class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(oup, _make_divisible(inp // reduction, 8)),
                Mish(),
                nn.Linear(_make_divisible(inp // reduction, 8), oup),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class LayerNorm(nn.Module):
    """ LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
        super().__init__()
        n = len(k)  # number of convolutions
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * n
            a = np.eye(n + 1, n, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList(
            [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = Mish()

    def forward(self, x):
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = Mish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))

class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = MixConv2d(c_,c_,k=(5,7,9),s=1,equal_ch=True)
        #self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

    def forward(self, x):
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)
    
class MSConv(nn.Module):
    def __init__(self, c1, c2, stride = 1, use_se = 0, expand_ratio=2):
        super(MSConv, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(c1 * expand_ratio)
        self.identity = stride == 1 and c1 == c2
        self.stride = stride
        if not self.identity:
            if stride > 1:
                self.down = nn.Sequential(nn.MaxPool2d(3,2,1),
                                          nn.Conv2d(c1,c2,1,1,0,bias=False))
            else:
                self.down = nn.Conv2d(c1,c2,1,1,0,bias=False)
            
        
        if use_se:
            self.conv = nn.Sequential(
                GhostConv(c1, hidden_dim, 1, 1),
                #mixed conv
                MixConv2d(hidden_dim,hidden_dim,k=(3,5,7,9),s=stride,equal_ch=True),
                #se
                SELayer(c1, hidden_dim),
                #pw-linear
                GhostConv(hidden_dim, c2, 1, 1, act=False),
            )
        else:
            self.conv = nn.Sequential(
                SPPF(c1,c1),
                #mixconv
                MixConv2d(c1,hidden_dim,k=(3,5,7,9),s=1,equal_ch=True),
                #split attention - radix = 2 - SK-Unit
                SplitAttn(c1=hidden_dim,c2=c2,stride=stride,groups=32,radix=2,
                          rd_ratio=0.25, act_layer=Mish, norm_layer=nn.BatchNorm2d),
            )


    def forward(self, x):
        if self.identity:
            out = x + self.conv(x)
        else:
            out = self.down(x) + self.conv(x)
        return channel_shuffle(out,2)

class MSBlock(nn.Module):
    def __init__(self, c1, c2, stride, use_se, radix=2, expand_ratio=1):
        super(MSBlock, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(c1 * expand_ratio)
        self.identity = stride == 1 and c1 == c2
        self.stride = stride
        if not self.identity:
            if stride > 1:
                self.down = nn.Sequential(nn.MaxPool2d(3,2,1),
                                          nn.Conv2d(c1,c2,1,1,0,bias=False))
            else:
                self.down = nn.Conv2d(c1,c2,1,1,0,bias=False)
            
        
        if use_se:
            #self.conv = C3MSGhost(c1,c2)
            self.conv = SwinTransformerBlock(c1,c2,c1//32,1)
        else:    
            self.conv = nn.Sequential(
                SPPF(c1,c1),
                #mixconv
                #MixConv2d(c1,hidden_dim,k=(3,5,7,9),s=1,equal_ch=True),
                #split attention - radix = 2 - SK-Unit
                SplitAttn(c1=c1,c2=c2,stride=stride,groups=32,radix=2,
                          rd_ratio=0.25, act_layer=Mish, norm_layer=LayerNorm),
            )
            

    def forward(self, x):
        if self.identity:
            out = x + self.conv(x)
        else:
            out = self.down(x) + self.conv(x)
        return channel_shuffle(out,2)
        #return out

class SPPF(nn.Module):
    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            y1 = self.m(x)
            y2 = self.m(y1)
            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
        
class BottleneckMSGhost(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.25):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        #self.cv3 = MixConv2d(c_,c2,k=(3,5,7,9),s=1,equal_ch=True)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.sppf = SPPF(c_,c_)
        #self.cv3 = Conv(c_, c_hidden, 1, 1)
        #self.msconv_sk = MSConv(c_hidden,c_hidden,1,0,1)
        self.msconv_se = MSConv(c2,c2,1,1,1)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.msconv_se(self.cv2(self.sppf(self.cv1(x)))) if self.add else self.msconv_se(self.cv2(self.sppf(self.cv1(x))))
        
class BottleneckMS(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.25):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        #self.cv3 = MixConv2d(c_,c2,k=(3,5,7,9),s=1,equal_ch=True)
        self.cv3 = Conv(c_, c2, 3, 1, g=g)
        self.sppf = SPPF(c_,c_)
        #self.cv3 = Conv(c_, c_hidden, 1, 1)
        #self.msconv_sk = MSConv(c_hidden,c_hidden,1,0,1)
        self.msconv_sk = MSConv(c2,c2,1,0,1)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.msconv_sk(self.cv3(self.sppf(self.cv1(x)))) if self.add else self.msconv_sk(self.cv3(self.sppf(self.cv1(x))))        
        
class C3MS(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, splatt_groups=32):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.swin = SwinTransformerBlock(c_, c_, c_//32, n)
        self.m = nn.Sequential(*(BottleneckMS(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.swin(self.cv2(x))), dim=1))        
    
class C3MSGhost(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.swin = SwinTransformerBlock(c_, c_, c_//32, n)
        self.m = nn.Sequential(*(BottleneckMSGhost(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.swin(self.cv2(x))), dim=1))   
        
        

def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

class MESNeSt(nn.Module):
    def __init__(self, cfgs, num_classes=100, width_mult=1.):
        super(MESNeSt, self).__init__()
        self.cfgs = cfgs
        # building first layer
        input_channel = _make_divisible(64 * width_mult, 8)
        layers = [conv_3x3_bn(3, input_channel, 2)]
        # building inverted residual blocks
        block = MSBlock
        for t, c, n, s, use_se, radix in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            if input_channel != output_channel:
                layers.append(MixConv2d(input_channel,output_channel,k=(3,5,7,9),s=2))
                input_channel = output_channel
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, use_se, radix, t))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)
        # building last several layers
        output_channel = _make_divisible(1792 * width_mult, 8) if width_mult > 1.0 else 1792
        self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)

        #self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()
                
def mesnest_n(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-N model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1, 64,   4, 1, 1, 0],
        [2, 128,  4, 1, 1, 0],
        [4, 256,  2, 1, 1, 2],
        [4, 256,  5, 1, 0, 2],
        [4, 512,  2, 1, 1, 2],
        [4, 512,  5, 1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)
            

def mesnest_s(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-S model
    """
    cfgs = [
        # t, c, n, s, Swin,radix,
        [1,  64,  4, 1, 1, 0],
        [2,  128,  4, 1, 1, 0],
        [4,  256,  8, 1, 0, 2],
        [4,  256,  4, 1, 1, 2],
        [4, 512,  8, 1, 0, 2],
        [4, 512,  3, 1, 1, 2],
        [6, 1024,  1, 1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)

def mesnest_m(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-S model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1,  64,  4, 1, 1, 0],
        [2,  128,  4, 1, 1, 0],
        [4,  256,  15, 1, 0, 2],
        [4,  256,  10, 1, 1, 2],
        [4, 512,  15, 1, 0, 2],
        [4, 512,  5, 1, 1, 2],
        [6, 1024,  5, 1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)

def mesnest_l(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-L model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1,  64,   2, 1, 1, 0],
        [2,  128,  2, 1, 1, 0],
        [4,  256,  25, 1, 0, 2],
        [4,  256,  15, 1, 1, 2],
        [4, 512,   30, 1, 0, 2],
        [4, 512,   15, 1, 1, 2],
        [6, 1024,  10, 1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)


def mesnest_x(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-X model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1,  64,   2,  1, 1, 0],
        [2,  128,  2,  1, 1, 0],
        [4,  256,  25, 1, 0, 2],
        [4,  256,  20, 1, 1, 2],
        [4, 512,   25, 1, 0, 2],
        [4, 512,   18, 1, 1, 2],
        [6, 1024,  10,  1, 0, 2],
        [4, 1024,  3,  1, 1, 2],
        [6, 2048,  4,  1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs) '''



In [6]:
import torch
import numpy as np
import math
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['mesnest_s','mesnest_m','mesnest_l','mesnest_xl']

class Mish_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.tanh(F.softplus(i))
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
  
        v = 1. + i.exp()
        h = v.log() 
        grad_gh = 1./h.cosh().pow_(2) 

        # Note that grad_hv * grad_vx = sigmoid(x)
        #grad_hv = 1./v  
        #grad_vx = i.exp()
        grad_hx = i.sigmoid()

        grad_gx = grad_gh *  grad_hx #grad_hv * grad_vx 
        
        grad_f =  torch.tanh(F.softplus(i)) + i * grad_gx 
        
        return grad_output * grad_f 


class Mish(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        pass
    def forward(self, input_tensor):
        return Mish_func.apply(input_tensor)

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups
    
    # reshape
    x = x.view(batchsize, groups, 
        channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

class SplitAttn(nn.Module):
    def __init__(self, c1, c2=None, kernel_size=3, stride=1, padding=None,
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=Mish, norm_layer=None, drop_block=None, **kwargs):
        super(SplitAttn, self).__init__()
        c2 = c2 or c1
        self.radix = radix
        self.drop_block = drop_block
        mid_chs = c2 * radix
        if rd_channels is None:
            attn_chs = _make_divisible(c1 * radix * rd_ratio, min_value=32, divisor=rd_divisor)
        else:
            attn_chs = rd_channels * radix

        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
           c1, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        #self.conv = GhostConv(c1,mid_chs,k=kernel_size,s=stride,g=groups*radix)
        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
        self.act0 = act_layer()
        self.fc1 = nn.Conv2d(c2, attn_chs, 1, groups=groups)
        #self.fc1 = GhostConv(c2, attn_chs, 1, 1)
        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
        self.act1 = act_layer()
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        #self.fc2 = GhostConv(attn_chs, mid_chs, 1, 1, act=False)
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn0(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act0(x)

        B, RC, H, W = x.shape
        if self.radix > 1:
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_gap = self.fc1(x_gap)
        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)

        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
        if self.radix > 1:
            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
        else:
            out = x * x_attn
        output = out.contiguous()
        return output

class MixConvSwin(nn.Module):
    def __init__(self, inp, oup, downsample=False, expansion=2):
        super(MixConvSwin, self).__init__()
        self.downsample = downsample
        stride = 1 if self.downsample == False else 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
        
        self.conv = nn.Sequential(
                MixConv2d(inp,hidden_dim,k=(3,5,7,9),downsample=self.downsample),
                SwinTransformerBlock(hidden_dim,hidden_dim,hidden_dim//32,1),
                MixConv2d(hidden_dim,oup,k=(3,5,7,9),act=False)
            )
        

    def forward(self, x):
        if self.downsample:
            out = self.proj(self.pool(x)) + self.conv(x)  
        else:
            out = x + self.conv(x)
        return channel_shuffle(out,2)
    
class MixGhostConvSPPSPlatt(nn.Module):
    def __init__(self, inp, oup, downsample=False, expansion=4):
        super(MixGhostConvSPPSPlatt, self).__init__()
        self.downsample = downsample
        stride = 1 if self.downsample == False else 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        self.conv = nn.Sequential(
                #SPPF(inp,inp),
                MixConv2d(inp,hidden_dim,k=(3,5,7,9),downsample=self.downsample),
                SplitAttn(hidden_dim,hidden_dim,stride=1,groups=32,radix=2,
                          rd_ratio=0.25, act_layer=nn.Mish, norm_layer=LayerNorm),
                MixConv2d(hidden_dim,oup,k=(3,5,7,9),act=False)
            )

    def forward(self, x):
        if self.downsample:
            out = self.proj(self.pool(x)) + self.conv(x)  
        else:
            out = x + self.conv(x)
        return channel_shuffle(out,2)
    
class LayerNorm(nn.Module):
    """ LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(3,5,7,9), s=1, equal_ch=True, downsample=False, act=True):  # ch_in, ch_out, kernel, stride, ch_strategy
        super().__init__()
        n = len(k)  # number of convolutions
        if downsample:
            s = 2
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * n
            a = np.eye(n + 1, n, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList(
            [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Mish(inplace=True) if act is True else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Mish(inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))

class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = MixConv2d(c_,c_,k=(3,5,7,9),s=1,equal_ch=True)
        #self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

    def forward(self, x):
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)
    
class SPPF(nn.Module):
    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            y1 = self.m(x)
            y2 = self.m(y1)
            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
        
def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp


class CoAtNet(nn.Module):
    def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=1000, block_types=['C', 'T', 'T', 'T']):
        super().__init__()
        ih, iw = image_size
        block = {'C': MixGhostConvSPPSPlatt, 'T': MixConvSwin}

        self.s0 = self._make_layer(
            MixConv2d, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        self.s1 = self._make_layer(
            block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s2 = self._make_layer(
            block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s3 = self._make_layer(
            block[block_types[2]], channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        self.s4 = self._make_layer(
            block[block_types[3]], channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.s0(x)
        x = self.s1(x)
        x = self.s2(x)
        x = self.s3(x)
        x = self.s4(x)

        x = self.pool(x).view(-1, x.shape[1])
        x = self.fc(x)
        return x

    def _make_layer(self, block, inp, oup, depth, image_size):
        layers = nn.ModuleList([])
        for i in range(depth):
            if i == 0:
                layers.append(block(inp, oup, downsample=True))
            else:
                layers.append(block(oup, oup))
        return nn.Sequential(*layers)


def coatnest_0():
    num_blocks = [2, 2, 3, 5, 2]            # L
    channels = [64, 96, 192, 384, 768]      # D
    return CoAtNet((256, 256), 3, num_blocks, channels, num_classes=100)


def coatnest_1():
    num_blocks = [2, 2, 6, 14, 2]           # L
    channels = [64, 96, 192, 384, 768]      # D
    return CoAtNet((256, 256), 3, num_blocks, channels, num_classes=100)


def coatnest_2():
    num_blocks = [2, 2, 6, 14, 2]           # L
    channels = [128, 128, 256, 512, 1026]   # D
    return CoAtNet((256, 256), 3, num_blocks, channels, num_classes=100)


def coatnest_3():
    num_blocks = [2, 2, 6, 14, 2]           # L
    channels = [192, 192, 384, 768, 1536]   # D
    return CoAtNet((256, 256), 3, num_blocks, channels, num_classes=100)


def coatnest_4():
    num_blocks = [2, 2, 12, 28, 2]          # L
    channels = [192, 192, 384, 768, 1536]   # D
    return CoAtNet((256, 256), 3, num_blocks, channels, num_classes=100)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [7]:
class MSBlock(nn.Module):
    def __init__(self, c1, c2, stride, use_se, radix=2, expand_ratio=1):
        super(MSBlock, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(c1 * expand_ratio)
        self.identity = stride == 1 and c1 == c2
        self.stride = stride
        if not self.identity:
            if stride > 1:
                self.down = nn.Sequential(nn.MaxPool2d(3,2,1),
                                          nn.Conv2d(c1,c2,1,1,0,bias=False))
            else:
                self.down = nn.Conv2d(c1,c2,1,1,0,bias=False)
            
        
        if use_se:
            #self.conv = C3MSGhost(c1,c2)
            self.conv = SwinTransformerBlock(c1,c2,c1//32,1)
        else:    
            self.conv = nn.Sequential(
                SPPF(c1,c1),
                SplitAttn(c1=c1,c2=c2,stride=stride,groups=32,radix=2,
                          rd_ratio=0.25, act_layer=Mish, norm_layer=LayerNorm),
            )
            
    def forward(self, x):
        if self.identity:
            out = x + self.conv(x)
        else:
            out = self.down(x) + self.conv(x)
        return channel_shuffle(out,2)

def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

class MESNeSt(nn.Module):
    def __init__(self, cfgs, num_classes=100, width_mult=1.):
        super(MESNeSt, self).__init__()
        self.cfgs = cfgs
        # building first layer
        input_channel = _make_divisible(64 * width_mult, 8)
        layers = [conv_3x3_bn(3, input_channel, 2)]
        # building inverted residual blocks
        block = MSBlock
        for t, c, n, s, use_se, radix in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            if input_channel != output_channel:
                layers.append(MixConv2d(input_channel,output_channel,k=(3,5,7,9),s=2))
                input_channel = output_channel
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, use_se, radix, t))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)
        # building last several layers
        output_channel = _make_divisible(1792 * width_mult, 8) if width_mult > 1.0 else 1792
        self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)
    
    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()
                
def mesnest_n(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-N model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1, 64,   4, 1, 1, 0],
        [2, 128,  4, 1, 1, 0],
        [4, 256,  2, 1, 1, 2],
        [4, 256,  5, 1, 0, 2],
        [4, 512,  2, 1, 1, 2],
        [4, 512,  5, 1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)

def mesnest_s(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-S model
    """
    cfgs = [
        # t, c, n, s, Swin,radix,
        [1,  64,  4, 1, 1, 0],
        [2,  128,  4, 1, 1, 0],
        [4,  256,  8, 1, 0, 2],
        [4,  256,  6, 1, 0, 2],
        [4, 512,  4, 1, 1, 2],
        #[4, 512,  3, 1, 0, 2],
        [6, 768,  1, 1, 1, 2],
    ]
    return MESNeSt(cfgs, **kwargs)

def mesnest_m(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-S model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1,  64,  4, 1, 1, 0],
        [2,  128,  4, 1, 1, 0],
        [4,  256,  15, 1, 0, 2],
        [4,  256,  10, 1, 1, 2],
        [4, 512,  15, 1, 0, 2],
        [4, 512,  5, 1, 1, 2],
        [6, 1024,  5, 1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)

def mesnest_l(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-L model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1,  64,   2, 1, 1, 0],
        [2,  128,  2, 1, 1, 0],
        [4,  256,  25, 1, 0, 2],
        [4,  256,  15, 1, 1, 2],
        [4, 512,   30, 1, 0, 2],
        [4, 512,   15, 1, 1, 2],
        [6, 1024,  10, 1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)

def mesnest_x(**kwargs):
    """
    Constructs a MixConv-Swin-Attention-NeSt-X model
    """
    cfgs = [
        # t, c, n, s, SE,radix,splatt_groups
        [1,  64,   2,  1, 1, 0],
        [2,  128,  2,  1, 1, 0],
        [4,  256,  25, 1, 0, 2],
        [4,  256,  20, 1, 1, 2],
        [4, 512,   25, 1, 0, 2],
        [4, 512,   18, 1, 1, 2],
        [6, 1024,  10,  1, 0, 2],
        [4, 1024,  3,  1, 1, 2],
        [6, 2048,  4,  1, 0, 2],
    ]
    return MESNeSt(cfgs, **kwargs)

In [8]:
model = mesnest_s()
get_n_params(model)

25227028

In [9]:
'''import torch
import numpy as np
import math
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['mesnest_s','mesnest_m','mesnest_l','mesnest_xl']

class Mish_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.tanh(F.softplus(i))
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
  
        v = 1. + i.exp()
        h = v.log() 
        grad_gh = 1./h.cosh().pow_(2) 

        # Note that grad_hv * grad_vx = sigmoid(x)
        #grad_hv = 1./v  
        #grad_vx = i.exp()
        grad_hx = i.sigmoid()

        grad_gx = grad_gh *  grad_hx #grad_hv * grad_vx 
        
        grad_f =  torch.tanh(F.softplus(i)) + i * grad_gx 
        
        return grad_output * grad_f 


class Mish(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        pass
    def forward(self, input_tensor):
        return Mish_func.apply(input_tensor)

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups
    
    # reshape
    x = x.view(batchsize, groups, 
        channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

class SplitAttn(nn.Module):
    def __init__(self, c1, c2=None, kernel_size=3, stride=1, padding=None,
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=Mish, norm_layer=None, drop_block=None, **kwargs):
        super(SplitAttn, self).__init__()
        c2 = c2 or c1
        self.radix = radix
        self.drop_block = drop_block
        mid_chs = c2 * radix
        if rd_channels is None:
            attn_chs = _make_divisible(c1 * radix * rd_ratio, min_value=32, divisor=rd_divisor)
        else:
            attn_chs = rd_channels * radix

        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
           c1, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        #self.conv = GhostConv(c1,mid_chs,k=kernel_size,s=stride,g=groups*radix)
        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
        self.act0 = act_layer()
        self.fc1 = nn.Conv2d(c2, attn_chs, 1, groups=groups)
        #self.fc1 = GhostConv(c2, attn_chs, 1, 1)
        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
        self.act1 = act_layer()
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        #self.fc2 = GhostConv(attn_chs, mid_chs, 1, 1, act=False)
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn0(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act0(x)

        B, RC, H, W = x.shape
        if self.radix > 1:
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_gap = self.fc1(x_gap)
        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)

        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
        if self.radix > 1:
            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
        else:
            out = x * x_attn
        output = out.contiguous()
        return output

class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(oup, _make_divisible(inp // reduction, 8)),
                Mish(),
                nn.Linear(_make_divisible(inp // reduction, 8), oup),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class LayerNorm(nn.Module):
    """ LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
        super().__init__()
        n = len(k)  # number of convolutions
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * n
            a = np.eye(n + 1, n, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList(
            [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = Mish()

    def forward(self, x):
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = Mish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))

class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = MixConv2d(c_,c_,k=(3,5,7,9),s=1,equal_ch=True)
        #self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

    def forward(self, x):
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)

class MSBlock(nn.Module):
    def __init__(self, c1, c2, stride, use_se, radix=2, expand_ratio=1):
        super(MSBlock, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(c1 * expand_ratio)
        self.identity = stride == 1 and c1 == c2
        self.stride = stride
        if not self.identity:
            if stride > 1:
                self.down = nn.Sequential(nn.MaxPool2d(3,2,1),
                                          nn.Conv2d(c1,c2,1,1,0,bias=False))
            else:
                self.down = nn.Conv2d(c1,c2,1,1,0,bias=False)
            
        
        if use_se:
            self.conv = nn.Sequential(
                #pw
                GhostConv(c1, hidden_dim, 1, 1),
                #mixed conv
                MixConv2d(hidden_dim,hidden_dim,k=(3,5,7,9),s=stride,equal_ch=True),
                #se
                SELayer(c1, hidden_dim),
                #linear
                GhostConv(hidden_dim, c2, 1, 1, act=False)
            )
        else:
            self.conv = nn.Sequential(
                #mixconv
                MixConv2d(c1,hidden_dim,k=(3,5,7,9),s=1,equal_ch=True),
                #split attention - radix = 2 - SK-Unit,
                SplitAttn(c1=hidden_dim,c2=c2,stride=stride,groups=1,radix=radix,
                          rd_ratio=0.25, act_layer=Mish, norm_layer=nn.BatchNorm2d),
            )

    def forward(self, x):
        if self.identity:
            out = x + self.conv(x)
        else:
            out = self.down(x) + self.conv(x)
        #return channel_shuffle(out,2)
        return out

def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        Mish()
    )

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

class MESNeSt(nn.Module):
    def __init__(self, cfgs, num_classes=100, width_mult=1.):
        super(MESNeSt, self).__init__()
        self.cfgs = cfgs
        # building first layer
        input_channel = _make_divisible(24 * width_mult, 8)
        layers = [conv_3x3_bn(3, input_channel, 2)]
        # building inverted residual blocks
        block = MSBlock
        for t, c, n, s, use_se, radix in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, use_se, radix, t))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)
        # building last several layers
        output_channel = _make_divisible(1792 * width_mult, 8) if width_mult > 1.0 else 1792
        self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()

def mesnest_s(**kwargs):
    """
    Constructs a Mixconv-Efficient-Shuffle-NeSt-S model
    """
    cfgs = [
        # t, c, n, s, SE, radix
        [1,  24,  2, 1, 0, 2],
        [4,  48,  4, 2, 0, 2],
        [4,  64,  4, 2, 0, 2],
        [4, 128,  6, 2, 1, 1],
        [6, 160,  9, 1, 1, 1],
        [6, 256, 15, 2, 1, 1],
    ]
    return MESNeSt(cfgs, **kwargs)


def mesnest_m(**kwargs):
    """
    Constructs a Mixconv-Efficient-Shuffle-NeSt-M model
    """
    cfgs = [
        # t, c, n, s, SE, radix
        [1,  24,  3, 1, 0, 2],
        [4,  48,  5, 2, 0, 2],
        [4,  80,  5, 2, 0, 2],
        [4, 160,  7, 2, 1, 1],
        [6, 176, 14, 1, 1, 1],
        [6, 304, 18, 2, 1, 1],
        [6, 512,  5, 1, 1, 1],
    ]
    return MESNeSt(cfgs, **kwargs)


def mesnest_l(**kwargs):
    """
    Constructs a Mixconv-Efficient-Shuffle-NeSt-L model
    """
    cfgs = [
        # t, c, n, s, SE, radix
        [1,  32,  4, 1, 0, 2],
        [4,  64,  7, 2, 0, 2],
        [4,  96,  7, 2, 0, 2],
        [4, 192, 10, 2, 1, 1],
        [6, 224, 19, 1, 1, 1],
        [6, 384, 25, 2, 1, 1],
        [6, 640,  7, 1, 1, 1],
    ]
    return MESNeSt(cfgs, **kwargs)


def mesnest_xl(**kwargs):
    """
    Constructs a Mixconv-Efficient-Shuffle-NeSt-XL model
    """
    cfgs = [
        # t, c, n, s, SE, radix
        [1,  32,  4, 1, 0, 2],
        [4,  64,  8, 2, 0, 2],
        [4,  96,  8, 2, 0, 2],
        [4, 192, 16, 2, 1, 1],
        [6, 256, 24, 1, 1, 1],
        [6, 512, 32, 2, 1, 1],
        [6, 640,  8, 1, 1, 1],
    ]
    return MESNeSt(cfgs, **kwargs) '''

'import torch\nimport numpy as np\nimport math\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n__all__ = [\'mesnest_s\',\'mesnest_m\',\'mesnest_l\',\'mesnest_xl\']\n\nclass Mish_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, i):\n        result = i * torch.tanh(F.softplus(i))\n        ctx.save_for_backward(i)\n        return result\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        i = ctx.saved_tensors[0]\n  \n        v = 1. + i.exp()\n        h = v.log() \n        grad_gh = 1./h.cosh().pow_(2) \n\n        # Note that grad_hv * grad_vx = sigmoid(x)\n        #grad_hv = 1./v  \n        #grad_vx = i.exp()\n        grad_hx = i.sigmoid()\n\n        grad_gx = grad_gh *  grad_hx #grad_hv * grad_vx \n        \n        grad_f =  torch.tanh(F.softplus(i)) + i * grad_gx \n        \n        return grad_output * grad_f \n\n\nclass Mish(nn.Module):\n    def __init__(self, **kwargs):\n        super().__init__()\n        pass\n    def forward(self, 

In [10]:
"""
Creates a EfficientNetV2 Model as defined in:
Mingxing Tan, Quoc V. Le. (2021). 
EfficientNetV2: Smaller Models and Faster Training
arXiv preprint arXiv:2104.00298.
import from https://github.com/d-li14/mobilenetv2.pytorch
"""

import torch
import torch.nn as nn
import math

#__all__ = ['effnetv2_s', 'effnetv2_m', 'effnetv2_l', 'effnetv2_xl']


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


# SiLU (Swish) activation function
if hasattr(nn, 'SiLU'):
    SiLU = nn.SiLU
else:
    # For compatibility with old PyTorch versions
    class SiLU(nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)

 
class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(oup, _make_divisible(inp // reduction, 8)),
                SiLU(),
                nn.Linear(_make_divisible(inp // reduction, 8), oup),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        SiLU()
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        SiLU()
    )


class MBConv(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, use_se):
        super(MBConv, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup
        if use_se:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                SiLU(),
                SELayer(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # fused
                nn.Conv2d(inp, hidden_dim, 3, stride, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )


    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)


class EffNetV2(nn.Module):
    def __init__(self, cfgs, num_classes=1000, width_mult=1.):
        super(EffNetV2, self).__init__()
        self.cfgs = cfgs

        # building first layer
        input_channel = _make_divisible(24 * width_mult, 8)
        layers = [conv_3x3_bn(3, input_channel, 2)]
        # building inverted residual blocks
        block = MBConv
        for t, c, n, s, use_se in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, t, use_se))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)
        # building last several layers
        output_channel = _make_divisible(1792 * width_mult, 8) if width_mult > 1.0 else 1792
        self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()


def effnetv2_s(**kwargs):
    """
    Constructs a EfficientNetV2-S model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  24,  2, 1, 0],
        [4,  48,  4, 2, 0],
        [4,  64,  4, 2, 0],
        [4, 128,  6, 2, 1],
        [6, 160,  9, 1, 1],
        [6, 256, 15, 2, 1],
    ]
    return EffNetV2(cfgs, **kwargs)


def effnetv2_m(**kwargs):
    """
    Constructs a EfficientNetV2-M model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  24,  3, 1, 0],
        [4,  48,  5, 2, 0],
        [4,  80,  5, 2, 0],
        [4, 160,  7, 2, 1],
        [6, 176, 14, 1, 1],
        [6, 304, 18, 2, 1],
        [6, 512,  5, 1, 1],
    ]
    return EffNetV2(cfgs, **kwargs)


def effnetv2_l(**kwargs):
    """
    Constructs a EfficientNetV2-L model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  32,  4, 1, 0],
        [4,  64,  7, 2, 0],
        [4,  96,  7, 2, 0],
        [4, 192, 10, 2, 1],
        [6, 224, 19, 1, 1],
        [6, 384, 25, 2, 1],
        [6, 640,  7, 1, 1],
    ]
    return EffNetV2(cfgs, **kwargs)


def effnetv2_xl(**kwargs):
    """
    Constructs a EfficientNetV2-XL model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  32,  4, 1, 0],
        [4,  64,  8, 2, 0],
        [4,  96,  8, 2, 0],
        [4, 192, 16, 2, 1],
        [6, 256, 24, 1, 1],
        [6, 512, 32, 2, 1],
        [6, 640,  8, 1, 1],
    ]
    return EffNetV2(cfgs, **kwargs)

In [11]:
class CFG:
    cifar_100_traindir = '../input/cifar-100-128-train-val-test/cifar100-128/train'
    cifar_100_valdir = '../input/cifar-100-128-train-val-test/cifar100-128/val'
    cifar_100_testdir = '../input/cifar-100-128-train-val-test/cifar100-128/test'
    seed = 42
    img_size = 256
    scheduler = 'CosineAnnealingLR'
    cycle_len = 15
    T_max = 10
    lr = 1e-4
    min_lr = 1e-6
    batch_size = 15
    weight_decay = 1e-6
    num_epochs = 11
    num_classes = 100
    n_accumulate = 4
    temperature = 0.1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
def set_seed(seed = 42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(CFG.seed)

In [12]:
from collections import OrderedDict
from tqdm import tqdm
import numpy as np

normalize = transforms.Normalize(mean=[0.4914 , 0.48216, 0.44653], std=[0.24703, 0.24349, 0.26159])
        
transform_train = transforms.Compose([
            transforms.Resize(256,interpolation=3),
            transforms.CenterCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(.2,.2,.2),
            #auto_augment_transform('original-mstd0.5',{}),
            rand_augment_transform('rand-m15-n3-mstd0.5',{}),
            transforms.ToTensor(),
            normalize,
])    

transform_train.transforms.insert(0, FastAugmentation())
    
train_dataset = datasets.ImageFolder(
        CFG.cifar_100_traindir,
        transform_train
    )

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=CFG.batch_size, shuffle=True,
        num_workers=2, pin_memory=True, sampler=None)

val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(CFG.cifar_100_valdir, transforms.Compose([
            transforms.Resize(256, interpolation=3),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=CFG.batch_size, shuffle=False, #Always False
        num_workers=2, pin_memory=True)

test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(CFG.cifar_100_testdir, transforms.Compose([
            transforms.Resize(256, interpolation=3),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=1, shuffle=False, #Always False
        num_workers=2, pin_memory=True)

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
    
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
    
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def train(train_loader, model, criterion, optimizer, scheduler, epoch, step):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(tqdm(train_loader)):
        # measure data loading time
        data_time.update(time.time() - end)

    
        images = images.to(CFG.device)
        target = target.to(CFG.device)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (step + 1) % CFG.n_accumulate == 0:
            scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        
    progress.display(len(train_loader))
    return top1.avg, losses.avg
    
def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(CFG.device)
            target = target.to(CFG.device)

            # compute output
            output = model(images)
            loss = criterion(output, target)
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            #if i % args.print_freq == 0:
                #progress.display(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} Val-Loss {losses.avg:.3f}'
              .format(top1=top1, top5=top5, losses=losses))
    return top1.avg , losses.avg
        
def train_model(model,train_loader,val_loader,criterion,optimizer, scheduler, num_epochs):
    best_loss = np.Inf
    best_acc = 0

    for step, epoch in enumerate(range(1,num_epochs+1)):
        top1,loss = train(train_loader,model,criterion,optimizer,scheduler,epoch,step)
        val_acc_top1, val_loss = validate(val_loader,model,criterion)
        PATH = f'coatnest_0_epoch_{epoch}_loss_{loss}_acc_{top1}.pt'
        if (val_acc_top1 > best_acc) and (val_loss < best_loss):
            best_loss = val_loss
            best_acc = val_acc_top1
            PATH = f'coatnest_0_epoch_{epoch}_val_loss_{best_loss}_val_acc_{best_acc}.pt'
        
        torch.save(model.state_dict(),PATH)

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [13]:
model = mesnest_s()
model = model.to(CFG.device)
#model.load_state_dict(torch.load('../input/mesnest-cifar-weights/cifar_10_mesnest_m_epoch_3_val_loss_0.158_val_acc_94.68.pt'))

In [14]:
#get_n_params(model)

In [15]:
!pip install timm
import timm
from timm.loss import LabelSmoothingCrossEntropy



In [16]:
criterion = LabelSmoothingCrossEntropy()
optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr)

In [17]:
train_model(model,train_loader,val_loader,criterion,optimizer,scheduler,CFG.num_epochs)

100%|██████████| 3164/3164 [53:59<00:00,  1.02s/it]

Epoch: [1][3164/3164]	Time  2.170 ( 1.024)	Data  0.001 ( 0.003)	Loss 3.8178e+00 (4.1508e+00)	Acc@1  16.67 (  8.73)	Acc@5  50.00 ( 26.66)





 * Acc@1 25.240 Acc@5 57.800 Val-Loss 3.247


100%|██████████| 3164/3164 [53:42<00:00,  1.02s/it]

Epoch: [2][3164/3164]	Time  0.465 ( 1.019)	Data  0.001 ( 0.002)	Loss 3.9385e+00 (3.5417e+00)	Acc@1   0.00 ( 20.79)	Acc@5  50.00 ( 48.12)





 * Acc@1 38.040 Acc@5 71.480 Val-Loss 2.804


100%|██████████| 3164/3164 [53:51<00:00,  1.02s/it]

Epoch: [3][3164/3164]	Time  0.464 ( 1.021)	Data  0.004 ( 0.002)	Loss 3.1484e+00 (3.1893e+00)	Acc@1  16.67 ( 29.02)	Acc@5  66.67 ( 59.27)





 * Acc@1 45.120 Acc@5 79.200 Val-Loss 2.521


100%|██████████| 3164/3164 [53:56<00:00,  1.02s/it]

Epoch: [4][3164/3164]	Time  0.465 ( 1.023)	Data  0.001 ( 0.002)	Loss 3.0808e+00 (2.8725e+00)	Acc@1  50.00 ( 37.52)	Acc@5  50.00 ( 68.22)





 * Acc@1 54.040 Acc@5 84.000 Val-Loss 2.285


100%|██████████| 3164/3164 [53:55<00:00,  1.02s/it]

Epoch: [5][3164/3164]	Time  0.490 ( 1.022)	Data  0.002 ( 0.002)	Loss 1.8695e+00 (2.7606e+00)	Acc@1  66.67 ( 40.57)	Acc@5 100.00 ( 71.11)





 * Acc@1 54.040 Acc@5 85.400 Val-Loss 2.238


100%|██████████| 3164/3164 [53:51<00:00,  1.02s/it]

Epoch: [6][3164/3164]	Time  0.460 ( 1.021)	Data  0.001 ( 0.002)	Loss 2.1866e+00 (2.6538e+00)	Acc@1  50.00 ( 43.41)	Acc@5 100.00 ( 73.82)





 * Acc@1 57.640 Acc@5 86.840 Val-Loss 2.126


100%|██████████| 3164/3164 [53:53<00:00,  1.02s/it]

Epoch: [7][3164/3164]	Time  0.461 ( 1.022)	Data  0.001 ( 0.002)	Loss 2.9470e+00 (2.5552e+00)	Acc@1  50.00 ( 46.35)	Acc@5  66.67 ( 75.73)





 * Acc@1 61.800 Acc@5 88.360 Val-Loss 2.039


100%|██████████| 3164/3164 [54:00<00:00,  1.02s/it]

Epoch: [8][3164/3164]	Time  0.460 ( 1.024)	Data  0.001 ( 0.002)	Loss 2.4079e+00 (2.4265e+00)	Acc@1  33.33 ( 50.33)	Acc@5  83.33 ( 78.76)





 * Acc@1 62.720 Acc@5 89.200 Val-Loss 1.992


100%|██████████| 3164/3164 [53:50<00:00,  1.02s/it]

Epoch: [9][3164/3164]	Time  0.461 ( 1.021)	Data  0.001 ( 0.002)	Loss 2.0248e+00 (2.2723e+00)	Acc@1  33.33 ( 54.82)	Acc@5  83.33 ( 82.19)





 * Acc@1 66.960 Acc@5 91.240 Val-Loss 1.865


100%|██████████| 3164/3164 [53:51<00:00,  1.02s/it]

Epoch: [10][3164/3164]	Time  0.490 ( 1.021)	Data  0.002 ( 0.002)	Loss 2.9206e+00 (2.2196e+00)	Acc@1  33.33 ( 56.49)	Acc@5  66.67 ( 83.31)





 * Acc@1 67.240 Acc@5 91.720 Val-Loss 1.837


100%|██████████| 3164/3164 [53:51<00:00,  1.02s/it]

Epoch: [11][3164/3164]	Time  0.459 ( 1.021)	Data  0.001 ( 0.002)	Loss 2.4892e+00 (2.1920e+00)	Acc@1  33.33 ( 57.19)	Acc@5 100.00 ( 83.92)





 * Acc@1 67.680 Acc@5 91.160 Val-Loss 1.825


In [18]:
#model.load_state_dict(torch.load('./mesnest_s_epoch_1_loss_0.19724411217171792_acc_94.71281433105469.pt'))

In [19]:
def testing(test_loader,model):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    
    for i,(images,labels) in enumerate(test_loader):
        images = images.to(CFG.device)
        labels = labels.to(CFG.device)
        
        outputs = model(images)
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
        
    print(f'Top 1 Test Accuracy - {top1.avg}')
    print(f'Top 5 Test Accuracy - {top5.avg}')
    
testing(test_loader,model)

Top 1 Test Accuracy - 67.38999938964844
Top 5 Test Accuracy - 90.77999877929688
