In [None]:
import torch
import numpy as np
from torch import nn
from os.path import join
from torchvision import models, transforms
from PIL import Image

In [None]:
# ResNet-18 built without pretrained weights and with a 100-way output head.
resnet18 = models.resnet18(pretrained=False, num_classes=100)

In [None]:
def print_type(m, module_types=nn.Module):
    """Print ``m`` if it is an instance of ``module_types``.

    Intended as a callback for ``nn.Module.apply``, which calls it once per
    sub-module.

    Parameters:
        m             --  object to inspect (a module when used via ``apply``)
        module_types  --  class or tuple of classes to match; the default
                          ``nn.Module`` matches every module

    NOTE(review): the original condition was missing (``if :`` is a syntax
    error). Narrow ``module_types`` (e.g. ``nn.Conv2d``) if only one layer
    type was meant to be printed.
    """
    if isinstance(m, module_types):
        print(m)

# Run print_type on every sub-module of the network (and on the network itself).
resnet18.apply(print_type)

# Display a Sequential made of a fresh ResNet-18's direct children followed by Flatten.
# NOTE(review): Flatten is appended after the last child; confirm that is the intended position.
nn.Sequential(*models.resnet18().children(), nn.Flatten())

In [None]:
mean = [0.485,0.456,0.406] # the `mean` argument configured in the data loader
std = [0.229,0.224,0.225]  # the `std` argument configured in the data loader

def tensor2img(input_image, imtype=np.uint8, mean=None, std=None):
    """Convert an image tensor to a numpy image array, undoing normalisation.

    Parameters:
        input_image (tensor) --  input image; axis 0 is treated as the channel axis
        imtype (type)        --  numpy dtype of the returned array
        mean (sequence)      --  per-channel mean to add back; None means 0
        std (sequence)       --  per-channel std to multiply back; None means 1

    Returns:
        * torch.Tensor input: a new array with axes reordered by (1, 2, 0),
          scaled by 255, clipped to [0, 255] and cast to ``imtype``. The input
          tensor is never modified.
        * numpy input: the array cast to ``imtype`` with no other processing.
        * anything else: returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        # numpy input is passed through; only the dtype is changed
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    # .numpy() can share memory with a CPU float tensor; copy so the in-place
    # channel writes below cannot alter the caller's tensor.
    image_numpy = input_image.detach().cpu().float().numpy().copy()
    if image_numpy.shape[0] == 1:  # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))

    # Undo the normalisation only when statistics were supplied; a missing
    # mean/std behaves as 0/1 instead of crashing on len(None).
    if mean is not None or std is not None:
        n_channels = len(mean) if mean is not None else len(std)
        channel_mean = mean if mean is not None else [0.0] * n_channels
        channel_std = std if std is not None else [1.0] * n_channels
        for i in range(n_channels):
            image_numpy[i] = image_numpy[i] * channel_std[i] + channel_mean[i]

    # Undo ToTensor(): [0, 1] -> [0, 255]. Clip first so out-of-range values
    # saturate instead of wrapping around in the integer cast.
    image_numpy = np.clip(image_numpy * 255, 0, 255)
    # (channels, height, width) -> (height, width, channels)
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    return image_numpy.astype(imtype)

In [None]:
# Code in this file is adapted from
# https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
# https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
# https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
import logging
import random

import numpy as np
import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image

# Module-level logger named after the current module.
logger = logging.getLogger(__name__)

# Divisor used by _float_parameter/_int_parameter to scale a magnitude `v` onto `max_v`.
PARAMETER_MAX = 10


def AutoContrast(img, **kwarg):
    """Return ``PIL.ImageOps.autocontrast(img)``; extra keyword arguments are ignored."""
    result = PIL.ImageOps.autocontrast(img)
    return result


def Brightness(img, v, max_v, bias=0):
    """Enhance brightness by the factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)


def Color(img, v, max_v, bias=0):
    """Enhance colour by the factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(factor)


def Contrast(img, v, max_v, bias=0):
    """Enhance contrast by the factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)


def Cutout(img, v, max_v, bias=0):
    """Cut out a square whose side is a fraction of the image's shorter edge.

    The fraction is ``_float_parameter(v, max_v) + bias``. A magnitude of 0
    returns ``img`` untouched; otherwise the work is delegated to CutoutAbs.
    """
    if v != 0:
        fraction = _float_parameter(v, max_v) + bias
        side = int(fraction * min(img.size))
        return CutoutAbs(img, side)
    return img


def CutoutAbs(img, v=40, **kwarg):
    """Return a copy of ``img`` with a gray rectangle drawn at a random position.

    Parameters:
        img  --  PIL image; it is copied, not modified
        v    --  nominal side length of the rectangle in pixels

    A centre point is sampled uniformly over the image. The top-left corner is
    clamped at 0 and the bottom-right corner at the image size. Extra keyword
    arguments are ignored.
    """
    width, height = img.size
    centre_x = np.random.uniform(0, width)
    centre_y = np.random.uniform(0, height)
    left = int(max(0, centre_x - v / 2.))
    top = int(max(0, centre_y - v / 2.))
    right = int(min(width, left + v))
    bottom = int(min(height, top + v))
    patched = img.copy()
    # gray fill
    PIL.ImageDraw.Draw(patched).rectangle((left, top, right, bottom), (127, 127, 127))
    return patched


def Equalize(img, **kwarg):
    """Return ``PIL.ImageOps.equalize(img)``; extra keyword arguments are ignored."""
    result = PIL.ImageOps.equalize(img)
    return result


def Identity(img, **kwarg):
    """No-op augmentation: hand back ``img`` itself and ignore any keyword arguments."""
    unchanged = img
    return unchanged


def Invert(img, **kwarg):
    """Return ``PIL.ImageOps.invert(img)``; extra keyword arguments are ignored."""
    result = PIL.ImageOps.invert(img)
    return result


def Posterize(img, v, max_v, bias=0):
    """Call ``PIL.ImageOps.posterize`` with the level ``_int_parameter(v, max_v) + bias``."""
    level = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, level)


def Rotate(img, v, max_v, bias=0):
    """Rotate ``img`` by ``_int_parameter(v, max_v) + bias``, sign flipped with probability 0.5."""
    angle = _int_parameter(v, max_v) + bias
    angle = -angle if random.random() < 0.5 else angle
    return img.rotate(angle)


def Sharpness(img, v, max_v, bias=0):
    """Enhance sharpness by the factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)


def ShearX(img, v, max_v, bias=0):
    """Affine-transform ``img`` with the coefficients ``(1, s, 0, 0, 1, 0)``.

    ``s`` is ``_float_parameter(v, max_v) + bias`` with its sign flipped with
    probability 0.5.
    """
    shear = _float_parameter(v, max_v) + bias
    shear = -shear if random.random() < 0.5 else shear
    coefficients = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coefficients)


def ShearY(img, v, max_v, bias=0):
    """Affine-transform ``img`` with the coefficients ``(1, 0, 0, s, 1, 0)``.

    ``s`` is ``_float_parameter(v, max_v) + bias`` with its sign flipped with
    probability 0.5.
    """
    shear = _float_parameter(v, max_v) + bias
    shear = -shear if random.random() < 0.5 else shear
    coefficients = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coefficients)


def Solarize(img, v, max_v, bias=0):
    """Call ``PIL.ImageOps.solarize`` with threshold ``256 - (_int_parameter(v, max_v) + bias)``."""
    amount = _int_parameter(v, max_v) + bias
    threshold = 256 - amount
    return PIL.ImageOps.solarize(img, threshold)


def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Shift all pixel values by a random-signed offset, then solarize.

    Parameters:
        img        --  PIL image
        v, max_v   --  magnitude and its scale, combined by ``_int_parameter``
        bias       --  constant added to the scaled magnitude
        threshold  --  threshold passed to ``PIL.ImageOps.solarize``

    The offset is ``_int_parameter(v, max_v) + bias``, negated with probability
    0.5. The shifted values are clipped to [0, 255] before converting back to
    an image.
    """
    offset = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        offset = -offset
    # Builtin int replaces the np.int alias, which newer NumPy releases no
    # longer provide. The wider integer type keeps the addition from wrapping
    # before the clip.
    shifted = np.array(img).astype(int) + offset
    shifted = np.clip(shifted, 0, 255).astype(np.uint8)
    img = Image.fromarray(shifted)
    return PIL.ImageOps.solarize(img, threshold)


def TranslateX(img, v, max_v, bias=0):
    """Affine-transform ``img`` with the coefficients ``(1, 0, dx, 0, 1, 0)``.

    ``dx`` is a random-signed fraction (``_float_parameter(v, max_v) + bias``)
    of ``img.size[0]``, truncated to an int.
    """
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    shift = int(fraction * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, shift, 0, 1, 0))


def TranslateY(img, v, max_v, bias=0):
    """Affine-transform ``img`` with the coefficients ``(1, 0, 0, 0, 1, dy)``.

    ``dy`` is a random-signed fraction (``_float_parameter(v, max_v) + bias``)
    of ``img.size[1]``, truncated to an int.
    """
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    shift = int(fraction * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, shift))


def _float_parameter(v, max_v):
    """Scale magnitude ``v`` to ``float(v) * max_v / PARAMETER_MAX``."""
    scaled = float(v) * max_v
    return scaled / PARAMETER_MAX


def _int_parameter(v, max_v):
    """Scale magnitude ``v`` to ``v * max_v / PARAMETER_MAX`` and truncate to an int."""
    scaled = v * max_v
    return int(scaled / PARAMETER_MAX)


def fixmatch_augment_pool():
    """Return the augmentation pool labelled "FixMatch paper".

    Each entry is an ``(op, max_v, bias)`` tuple. ``None`` is used for both
    values on ops that take no magnitude.
    """
    # FixMatch paper
    return [
        (AutoContrast, None, None),
        (Brightness, 0.9, 0.05),
        (Color, 0.9, 0.05),
        (Contrast, 0.9, 0.05),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 0.9, 0.05),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (TranslateX, 0.3, 0),
        (TranslateY, 0.3, 0),
    ]


def my_augment_pool():
    """Return the experimental ("Test") augmentation pool.

    Each entry is an ``(op, max_v, bias)`` tuple. ``None`` is used for both
    values on ops that take no magnitude.
    """
    # Test
    return [
        (AutoContrast, None, None),
        (Brightness, 1.8, 0.1),
        (Color, 1.8, 0.1),
        (Contrast, 1.8, 0.1),
        (Cutout, 0.2, 0),
        (Equalize, None, None),
        (Invert, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 1.8, 0.1),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (SolarizeAdd, 110, 0),
        (TranslateX, 0.45, 0),
        (TranslateY, 0.45, 0),
    ]


class RandAugmentPC(object):
    """Randomly apply ops from ``my_augment_pool()`` at fixed magnitude, then cut out.

    Parameters:
        n  --  number of ops drawn (with replacement) per call; must be >= 1
        m  --  magnitude passed to every op as ``v``; must be in [1, 10]
    """

    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n, self.m = n, m
        self.augment_pool = my_augment_pool()

    def __call__(self, img):
        """Augment ``img`` and return the result."""
        for op, max_v, bias in random.choices(self.augment_pool, k=self.n):
            # Each drawn op gets its own apply-probability, sampled from [0.2, 0.8).
            prob = np.random.uniform(0.2, 0.8)
            if random.random() + prob < 1:
                continue
            img = op(img, v=self.m, max_v=max_v, bias=bias)
        # A 16-pixel cutout is always applied last.
        return CutoutAbs(img, int(32*0.5))


class RandAugmentMC(object):
    """Cut out a patch, then randomly apply ops from ``fixmatch_augment_pool()``.

    Parameters:
        n            --  number of ops drawn (with replacement) per call; must be >= 1
        m            --  exclusive upper bound of the random magnitude; must be in [1, 10]
        cutout_size  --  side length in pixels handed to CutoutAbs (default 112,
                         the previously hard-coded ``int(224*0.5)``)
    """

    def __init__(self, n=2, m=10, cutout_size=int(224*0.5)):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.cutout_size = cutout_size
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        """Augment ``img`` and return the result."""
        ops = random.choices(self.augment_pool, k=self.n)
        # The cutout happens before the sampled ops are applied.
        img = CutoutAbs(img, self.cutout_size)
        for op, max_v, bias in ops:
            # randint's upper bound is exclusive, so m == 1 would be an empty
            # range and raise; use magnitude 1 in that case.
            v = np.random.randint(1, self.m) if self.m > 1 else 1
            # Each sampled op is applied with probability 0.5.
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        return img

In [None]:
# Preview pipeline: resize, random 224x224 crop, cutout with CutoutAbs' default size,
# random horizontal flip, then random rotation of up to 25 degrees.
transform_train = transforms.Compose([
    transforms.Resize(size=256),
    transforms.RandomCrop(size=(224, 224)),
    CutoutAbs,
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25),
])

def show_imgs(img_path, transform):
    """Load an image, apply ``transform``, show the result and return it.

    Parameters:
        img_path   --  path of the image file to open
        transform  --  callable applied to the RGB PIL image; its result must
                       provide a ``show()`` method

    Returns:
        The transformed image.
    """
    # The context manager closes the file handle as soon as convert() has
    # produced the RGB image, instead of leaving it open until garbage collection.
    with Image.open(img_path) as source:
        img_ori = source.convert('RGB')
    img_trans = transform(img_ori)
    # img_ori.show()
    img_trans.show()
    return img_trans

# NOTE(review): absolute, machine-specific dataset location; adjust before running elsewhere.
data_root = '/home/waa/Data/miniImageNet/'
# Sample training images to preview; only img_path7 is used by the loop below.
img_path1 = join(data_root, 'train/Americanegret/n02009912_10212.JPEG')
img_path2 = join(data_root, 'train/Americanegret/n02009912_9056.JPEG')
img_path3 = join(data_root, 'train/artichoke/n07718747_1554.JPEG')
img_path4 = join(data_root, 'train/badger/n02447366_2475.JPEG')
img_path5 = join(data_root, 'train/birdhouse/n02843684_2113.JPEG')
img_path6 = join(data_root, 'train/bottlecap/n02877765_5096.JPEG')
img_path7 = join(data_root, 'train/crocodile/n01697457_9993.JPEG')
# Show 32 independently augmented versions of the same image (each call invokes .show()).
for i in range(32):
    img_trans = show_imgs(img_path7, transform=transform_train)

In [None]:
import torch
import torch.nn.functional as F
# Toy batch: two rows of three class scores, with integer class targets.
inputs = torch.tensor([[0.1474, 0.6745, 0.8948],
                       [0.8524, 0.2278, 0.6476]])
# targets = torch.tensor([[1., 0., 0.], [0., 1., 0.]])
targets = torch.tensor([0, 1])
# One weight per class.
weight = torch.tensor([1.0, 3.0, 7.0])

# Use the GPU when one is present; otherwise stay on the CPU so the cell still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = inputs.to(device)
targets = targets.to(device)
weight = weight.to(device)

# Per-sample (unreduced) cross-entropy.
ori_a = F.cross_entropy(inputs, targets, reduction='none')
print(weight.device)
# Pick the weight of each sample's target class by indexing with the targets.
N_weight = weight[targets]
print(N_weight.device)
print(ori_a)
print(ori_a * N_weight)
print(torch.exp(ori_a))

In [None]:
# Rebuild a fresh ResNet-18 (no pretrained weights, 100-way head) for the experiments below.
resnet18 = models.resnet18(pretrained=False, num_classes=100)

In [None]:
def freeze(model, unfreeze_keys):
    """Freeze every parameter of ``model`` except those matching ``unfreeze_keys``.

    Parameters:
        model          --  module whose parameters are updated in place
        unfreeze_keys  --  iterable of substrings; a parameter stays trainable
                           when any of them occurs in its name

    Returns:
        The same ``model`` object, for chaining.
    """
    # Iterate the model that was passed in, not the notebook-global network.
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in unfreeze_keys)
    return model

In [None]:
# Keep only parameters whose name contains 'fc' trainable; freeze() returns the object it was given.
model = freeze(resnet18, unfreeze_keys=['fc'])

In [None]:
# list(model.named_parameters())
# Summarise each parameter as (name, shape, requires_grad) instead of dumping
# every weight tensor into the cell output.
[(name, tuple(param.shape), param.requires_grad) for name, param in resnet18.named_parameters()]

In [None]:
# Make every parameter trainable again (this undoes any earlier freeze) and list the names.
for name, parameter in resnet18.named_parameters():
    parameter.requires_grad = True
    print(name)

In [None]:
# Show only the shape of each parameter; listing the tensors themselves floods the output.
[tuple(param.shape) for param in resnet18.parameters()]

In [None]:
# Print the class name of every module yielded by layer4's last block.
for module in resnet18.layer4[-1].modules():
    print(type(module).__name__)

In [None]:
# Display the last entry of layer4.
resnet18.layer4[-1]

In [None]:
# Display layer4's last block chained with a Sequential of the model's final two children.
nn.Sequential(
    resnet18.layer4[-1],
    nn.Sequential(*[*resnet18.children()][-2:]),
)

In [None]:
# Wrap the model's last direct child in its own Sequential.
# NOTE(review): the name suggests a pooling layer, but the slice selects the final child; confirm.
res_pool = nn.Sequential(*[*resnet18.children()][-1:])

In [None]:
# Random tensor of shape (1, 3, 224, 224).
x = torch.randn((1, 3, 224, 224))

In [None]:
# Overwrite the first entry of res_pool with a new Linear(128, 64) layer.
# NOTE(review): confirm 128 input features matches whatever is fed into res_pool.
res_pool[0] = nn.Linear(128, 64)

In [None]:
# Display res_pool.
res_pool

In [None]:
# Display the full resnet18 module.
resnet18