## 1 import

In [34]:
import os
import logging
import argparse

from tqdm import tqdm
import torch
import torch.optim as optim
import numpy as np



### 1.1 log

In [35]:
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s: %(filename)s[%(lineno)d]: %(message)s",
    datefmt="%m-%d %H:%M:%S")
logger = logging.getLogger()
logger.setLevel(logging.INFO)
wlog = logger.info

## 2 hyper-parameter

In [36]:
parser = argparse.ArgumentParser("Diffusion Recovery Likelihood")
parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "svhn", "cifar100"])
parser.add_argument("--data_root", type=str, default="./data")
# optimization
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--n_epochs", type=int, default=250)
parser.add_argument("--warmup_iters", type=int, default=1000, help="number of iters to linearly increase learning rate, if -1 then no warmmup")
parser.add_argument("--sigma", type=float, default=3e-2, help="stddev of gaussian noise to add to input, .03 works but .1 is more stable")
parser.add_argument("--weight_decay", type=float, default=5e-4)

parser.add_argument("--width", type=int, default=10, help="WRN width parameter")
parser.add_argument("--depth", type=int, default=28, help="WRN depth parameter")
parser.add_argument("--model", type=str, default='wrnte', help='wrnte, wrntesn')
parser.add_argument("--norm", type=str, default=None, choices=[None, "none", "batch", "instance", "layer", "act", 'td'], help="norm to add to weights, none works fine")

parser.add_argument("--n_valid", type=int, default=0)
parser.add_argument("--log_dir", type=str, default='./runs/DRL')
parser.add_argument("--resume", type=str, default=None)

parser.add_argument("--novis", action="store_true", help="")
parser.add_argument("--debug", action="store_true", help="")
parser.add_argument("--exp_name", type=str, default="DIS", help="exp name, for description")
parser.add_argument("--gpu-id", type=str, default="0")
# args = parser.parse_args()
args, unparsed = parser.parse_known_args()

In [37]:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

In [38]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args.seed = 1
args.device = device

In [39]:
# hyper-parameters of diffusion recovery likelihood
args.img_sz = 32
args.num_diffusion_timesteps = 6
args.num_timesteps = 6
args.opt = 'sgd'
args.ma_decay = 0.999
args.noise_scale = 1.0
args.mcmc_num_steps = 30
args.mcmc_step_size_b_square = 2e-4
args.debug = False

args.pid = os.getpid()

In [40]:
logfile = os.path.join('runs', 'DRL', "%d_%s.log" % (args.pid, args.model))
args.log_dir = 'runs/DRL/%d' % args.pid
print("log dir is %s" % args.log_dir)
os.makedirs(args.log_dir, exist_ok=True)
fh = logging.FileHandler(logfile, mode='w')
fh.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d]: %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)

log dir is runs/DRL/777


## 3 util

### 3.1 数据获取

In [41]:
import os
import torch
import torch as t
from torch.utils.data import DataLoader, Dataset
import numpy as np

class DataSubset(Dataset):
    def __init__(self, base_dataset, inds=None, size=-1):
        self.base_dataset = base_dataset
        if inds is None:
            inds = np.random.choice(list(range(len(base_dataset))), size, replace=False)
        self.inds = inds

    def __getitem__(self, index):
        base_ind = self.inds[index]
        return self.base_dataset[base_ind]

    def __len__(self):
        return len(self.inds)

In [42]:
def sqrt(x):
    return int(t.sqrt(t.Tensor([x])))


def plot(p, x):
    return tv.utils.save_image(t.clamp(x, -1, 1), p, normalize=True, nrow=sqrt(x.size(0)))


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def save_checkpoint(state, save, epoch):
    if not os.path.exists(save):
        os.makedirs(save)
    filename = os.path.join(save, 'checkpt-%04d.pth' % epoch)
    torch.save(state, filename)

In [43]:
def cycle(loader):
    while True:
        for data in loader:
            yield data

In [44]:
import torchvision.transforms as tr
import torchvision as tv

def get_data(args):
    mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    if args.dataset == "svhn":
        transform_train = tr.Compose(
            [tr.Pad(4, padding_mode="reflect"),
             tr.RandomCrop(32),
             tr.ToTensor(),
             tr.Normalize(mean, std),
             # lambda x: x + args.sigma * t.randn_like(x)
             ]
        )
        transform_px = tr.Compose(
            [tr.ToTensor(),
             tr.Normalize(mean, std),
             ]
        )
    else:
        transform_train = tr.Compose(
            [tr.Pad(4, padding_mode="reflect"),
             tr.RandomCrop(32),
             tr.RandomHorizontalFlip(),
             tr.ToTensor(),
             tr.Normalize(mean, std),
             # lambda x: x + args.sigma * t.randn_like(x)
             ]
        )
        transform_px = tr.Compose(
            [tr.RandomHorizontalFlip(),
             tr.ToTensor(),
             tr.Normalize(mean, std),
             ]
        )
    transform_test = tr.Compose(
        [tr.ToTensor(),
         tr.Normalize(mean, std),
         ]
    )
    def dataset_fn(train, transform):
        if args.dataset == "cifar10":
            args.n_classes = 10
            return tv.datasets.CIFAR10(root=args.data_root, transform=transform, download=True, train=train)
        elif args.dataset == "cifar100":
            args.n_classes = 100
            return tv.datasets.CIFAR100(root=args.data_root, transform=transform, download=True, train=train)
        else:
            args.n_classes = 10
            return tv.datasets.SVHN(root=args.data_root, transform=transform, download=True, split="train" if train else "test")

    # get all training inds
    full_train = dataset_fn(True, transform_train)
    all_inds = list(range(len(full_train)))
    # set seed
    np.random.seed(args.seed)
    # shuffle
    np.random.shuffle(all_inds)

    # 缩减训练集大小，保留10%的数据
    reduction_ratio = 0.2
    train_size = int(len(all_inds) * reduction_ratio)
    reduced_train_inds = all_inds[:train_size]

    # 验证集的划分
    if args.n_valid > args.n_classes:
        valid_inds, train_inds = reduced_train_inds[:args.n_valid], reduced_train_inds[args.n_valid:]
    else:
        valid_inds, train_inds = [], reduced_train_inds


    # # seperate out validation set
    # if args.n_valid > args.n_classes:
    #     valid_inds, train_inds = all_inds[:args.n_valid], all_inds[args.n_valid:]
    # else:
    #     valid_inds, train_inds = [], all_inds

    train_inds = np.array(train_inds)
    train_labeled_inds = train_inds

    dset_train = DataSubset(dataset_fn(True, transform_px), inds=train_inds)
    dset_train_labeled = DataSubset(dataset_fn(True, transform_train), inds=train_labeled_inds)
    dset_valid = DataSubset(dataset_fn(True, transform_test), inds=valid_inds)

    num_workers = 0 if args.debug else 4
    dload_train = DataLoader(dset_train, batch_size=args.batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    label_bs = 128
    dload_train_labeled = DataLoader(dset_train_labeled, batch_size=label_bs, shuffle=True, num_workers=num_workers, drop_last=True)
    dload_train = cycle(dload_train)
    dset_test = dataset_fn(False, transform_test)
    dload_valid = DataLoader(dset_valid, batch_size=100, shuffle=False, num_workers=num_workers, drop_last=False)
    dload_test = DataLoader(dset_test, batch_size=100, shuffle=False, num_workers=num_workers, drop_last=False)
    return dload_train, dload_train_labeled, dload_valid, dload_test

In [45]:
class EMA:
    def __init__(self, mu):
        self.mu = mu
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def __call__(self, name, x):
        assert name in self.shadow
        new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average

## 3 模型

In [46]:
# from models.wideresnet_te import Wide_ResNet as WResNet
# from models.wrn_te_sn import Wide_ResNet as WResNetSN

In [47]:
#norm
import torch
import torch.nn as nn
import torch.nn.init as init

class ConditionalInstanceNorm2dPlus(nn.Module):
    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 3)
            self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
            self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, 2 * num_features)
            self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)

        if self.bias:
            gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
            h = h + means[..., None, None] * alpha[..., None, None]
            out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
        else:
            gamma, alpha = self.embed(y).chunk(2, dim=-1)
            h = h + means[..., None, None] * alpha[..., None, None]
            out = gamma.view(-1, self.num_features, 1, 1) * h
        return out


class ConditionalActNorm(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.embed = nn.Embedding(num_classes, num_features * 2)
        self.embed.weight.data.zero_()
        self.init = False

    def forward(self, x, y):
        if self.init:
            scale, bias = self.embed(y).chunk(2, dim=-1)
            return x * scale[:, :, None, None] + bias[:, :, None, None]
        else:
            m, v = torch.mean(x, dim=(0, 2, 3)), torch.var(x, dim=(0, 2, 3))
            std = torch.sqrt(v + 1e-5)
            scale_init = 1. / std
            bias_init = -1. * m / std
            self.embed.weight.data[:, :self.num_features] = scale_init[None].repeat(self.num_classes, 1)
            self.embed.weight.data[:, self.num_features:] = bias_init[None].repeat(self.num_classes, 1)
            self.init = True
            return self(x, y)


logabs = lambda x: torch.log(torch.abs(x))


class ActNorm(nn.Module):
    def __init__(self, in_channel, logdet=True):
        super().__init__()

        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
        self.logdet = logdet

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input):
        _, _, height, width = input.shape

        if self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        log_abs = logabs(self.scale)

        logdet = height * width * torch.sum(log_abs)

        if self.logdet:
            return self.scale * (input + self.loc), logdet

        else:
            return self.scale * (input + self.loc)

    def reverse(self, output):
        return output / self.scale - self.loc


class ContinuousConditionalActNorm(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        del num_classes
        self.num_features = num_features
        self.embed = nn.Sequential(nn.Linear(1, 256),
                                   nn.ELU(inplace=True),
                                   nn.Linear(256, 256),
                                   nn.ELU(inplace=True),
                                   nn.Linear(256, self.num_features*2),
                                   )

    def forward(self, x, y):
        scale, bias = self.embed(y.unsqueeze(-1)).chunk(2, dim=-1)
        return x * scale[:, :, None, None] + bias[:, :, None, None]


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


def get_norm(n_filters, norm, T=6):
    if norm is None or norm.lower() == 'none':
        return Identity()
    elif norm == "batch":
        return nn.BatchNorm2d(n_filters, momentum=0.9)
    elif norm == "instance":
        return nn.InstanceNorm2d(n_filters, affine=True)
    elif norm == "layer":
        return nn.GroupNorm(1, n_filters)
    elif norm == "act":
        return ActNorm(n_filters, False)
    else:
        return Identity()

### basic wideresnet

In [48]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)

"""
def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform(m.weight, gain=np.sqrt(2))
        init.constant(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        init.constant(m.weight, 1)
        init.constant(m.bias, 0)


class wide_basic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1, norm=None, leak=.2):
        super(wide_basic, self).__init__()
        self.norm = norm
        self.lrelu = nn.LeakyReLU(leak)
        self.bn1 = get_norm(in_planes, norm)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = Identity() if dropout_rate == 0.0 else nn.Dropout(p=dropout_rate)
        self.bn2 = get_norm(planes, norm)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        out = self.bn1(x)
        out = self.dropout(self.conv1(self.lrelu(out)))
        out = self.bn2(out)
        out = self.conv2(self.lrelu(out))
        out += self.shortcut(x)

        return out


class Wide_ResNet(nn.Module):
    def __init__(self, depth, widen_factor, num_classes=10, input_channels=3,
                 sum_pool=False, norm=None, leak=.2, dropout_rate=0.0):
        super(Wide_ResNet, self).__init__()
        self.leak = leak
        self.in_planes = 16
        self.sum_pool = sum_pool
        self.norm = norm
        self.lrelu = nn.LeakyReLU(leak)
        self.n_classes = num_classes

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor

        print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(input_channels, nStages[0])
        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1, leak=leak)
        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2, leak=leak)
        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2, leak=leak)
        self.bn1 = get_norm(nStages[3], self.norm)
        self.last_dim = nStages[3]
        self.linear = nn.Linear(nStages[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride, leak=0.2):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride, leak=leak, norm=self.norm))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x, logits=False, feature=True):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.lrelu(self.bn1(out))
        if self.sum_pool:
            out = out.view(out.size(0), out.size(1), -1).sum(2)
        else:
            if self.n_classes > 100:
                out = F.adaptive_avg_pool2d(out, 1)
            else:
                out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        if logits:
            out = self.linear(out)
        return out
"""

"\ndef conv_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        init.xavier_uniform(m.weight, gain=np.sqrt(2))\n        init.constant(m.bias, 0)\n    elif classname.find('BatchNorm') != -1:\n        init.constant(m.weight, 1)\n        init.constant(m.bias, 0)\n\n\nclass wide_basic(nn.Module):\n    def __init__(self, in_planes, planes, dropout_rate, stride=1, norm=None, leak=.2):\n        super(wide_basic, self).__init__()\n        self.norm = norm\n        self.lrelu = nn.LeakyReLU(leak)\n        self.bn1 = get_norm(in_planes, norm)\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)\n        self.dropout = Identity() if dropout_rate == 0.0 else nn.Dropout(p=dropout_rate)\n        self.bn2 = get_norm(planes, norm)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != planes:\n      

### te utils

In [49]:
import math
import torch
import torch.nn.functional as F


def get_timestep_embedding(timesteps, embedding_dim: int):
    """
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == torch.int32

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(0, half_dim) * -emb).to(timesteps.device)
    emb = torch.matmul(1.0 * timesteps.reshape(-1, 1), emb.reshape(1, -1))
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        # emb = torch.cat([emb, torch.zeros([num_embeddings, 1])], axis=1)
        emb = F.pad(emb, [0, 1, 0, 0])
    assert list(emb.shape) == [timesteps.shape[0], embedding_dim]
    return emb

### model te

In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from .wideresnet import get_norm, conv3x3, Identity
# from .te_utils import get_timestep_embedding


class wide_basic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1, norm=None, leak=.2):
        super(wide_basic, self).__init__()
        self.norm = norm
        self.lrelu = nn.LeakyReLU(leak)
        self.bn1 = get_norm(in_planes, norm)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = Identity() if dropout_rate == 0.0 else nn.Dropout(p=dropout_rate)
        self.bn2 = get_norm(planes, norm)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        self.temb_dense = nn.Linear(512, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        x, temb = x
        out = self.bn1(x)
        out = self.conv1(self.lrelu(out))
        if temb is not None:
            # add in timestep embedding
            temp_o = self.lrelu(self.temb_dense(temb))
            b, l = temp_o.shape
            out += temp_o.view(b, l, 1, 1)

        out = self.dropout(out)
        out = self.bn2(out)
        out = self.conv2(self.lrelu(out))
        out += self.shortcut(x)

        return out, temb


class Wide_ResNet(nn.Module):
    def __init__(self, depth, widen_factor, num_classes=10, input_channels=3, sum_pool=False, norm=None, leak=.2, dropout_rate=0.0):
        super(Wide_ResNet, self).__init__()
        self.leak = leak
        self.in_planes = 16
        self.sum_pool = sum_pool
        self.norm = norm
        self.lrelu = nn.LeakyReLU(leak)
        self.n_classes = num_classes

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor

        print('| Wide-Resnet %dx%d, time embedding' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.layer_one_out = None
        self.conv1 = conv3x3(input_channels, nStages[0])
        # self.layer_one = self.conv1
        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1, leak=leak)
        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2, leak=leak)
        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2, leak=leak)
        self.bn1 = get_norm(nStages[3], self.norm)
        self.last_dim = nStages[3]
        self.linear = nn.Linear(nStages[3], num_classes)
        self.temb_dense_0 = nn.Linear(128, 512)
        self.temb_dense_1 = nn.Linear(512, 512)
        self.temb_dense_2 = nn.Linear(512, nStages[3])

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride, leak=0.2):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride, leak=leak, norm=self.norm))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x, t, logits=False, feature=True):
        out = self.conv1(x)
        assert x.dtype == torch.float32
        if isinstance(t, int) or len(t.shape) == 0:
            t = torch.ones(x.shape[0], dtype=torch.int64, device=x.device) * t
        temb = get_timestep_embedding(t, 128)
        temb = self.temb_dense_0(temb)
        temb = self.temb_dense_1(self.lrelu(temb))

        out, _ = self.layer1([out, temb])
        out, _ = self.layer2([out, temb])
        out, _ = self.layer3([out, temb])
        out = self.lrelu(self.bn1(out))
        if self.sum_pool:
            out = out.view(out.size(0), out.size(1), -1).sum(2)
        else:
            if self.n_classes > 100:
                out = F.adaptive_avg_pool2d(out, 1)
            else:
                out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        temb = self.temb_dense_2(self.lrelu(temb))
        out *= temb
        if logits:
            out = self.linear(out)
        return out

#### test

In [51]:
# ch_mult = (1, 2, 2, 2)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# net = Wide_ResNet(28, 10, norm='batch', dropout_rate=0).to(device)
# x = torch.randn([64, 3, 32, 32]).to(device)
# t = torch.randint(size=[64], high=6).to(device)
# output = net(x, t)
# print(output.shape)

## 5  recovery likelihood

In [52]:
def get_sigma_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
    """
    Get the noise level schedule
    :param beta_start: begin noise level
    :param beta_end: end noise level
    :param num_diffusion_timesteps: number of timesteps
    :return:
    -- sigmas: sigma_{t+1}, scaling parameter of epsilon_{t+1}
    -- a_s: sqrt(1 - sigma_{t+1}^2), scaling parameter of x_t
    """
    betas = np.linspace(beta_start, beta_end, 1000, dtype=np.float64)
    betas = np.append(betas, 1.)
    assert isinstance(betas, np.ndarray)
    betas = betas.astype(np.float64)
    assert (betas > 0).all() and (betas <= 1).all()
    sqrt_alphas = np.sqrt(1. - betas)
    temp = np.concatenate([np.arange(num_diffusion_timesteps) * (1000 // ((num_diffusion_timesteps - 1) * 2)), [999]])
    idx = temp.astype(np.int32)
    a_s = np.concatenate(
        [[np.prod(sqrt_alphas[: idx[0] + 1])],
         np.asarray([np.prod(sqrt_alphas[idx[i - 1] + 1: idx[i] + 1]) for i in np.arange(1, len(idx))])])
    sigma = np.sqrt(1 - a_s ** 2)

    return sigma, a_s

In [53]:
def unsorted_segment_mean(values, index, num_segments):
    ones = torch.ones_like(values)
    sums = torch.zeros(num_segments, device=values.device).scatter_add_(0, index, values)
    counts = torch.zeros(num_segments, device=values.device).scatter_add_(0, index, ones)
    return sums / counts


In [54]:
class RecoveryLikelihood(nn.Module):
    def __init__(self, model, args):
        super(RecoveryLikelihood, self).__init__()
        self.args = args
        self.num_timesteps = args.num_diffusion_timesteps

        sigmas, a_s = get_sigma_schedule(beta_start=0.0001, beta_end=0.02, num_diffusion_timesteps=self.num_timesteps)
        self.sigmas = torch.FloatTensor(sigmas).to(args.device)
        self.a_s = torch.FloatTensor(a_s).to(args.device)

        self.a_s_cum = torch.FloatTensor(np.cumprod(a_s)).to(args.device)
        self.sigmas_cum = torch.sqrt(1 - self.a_s_cum ** 2)
        self.a_s_prev = self.a_s.clone()
        self.a_s_prev[-1] = 1
        self.is_recovery = torch.ones(self.num_timesteps + 1).to(args.device)
        self.is_recovery[-1] = 0
        self.device = args.device

        # self.net = net_res_temb2(name='net', ch=128, ch_mult=ch_mult, num_res_blocks=self.args.num_res_blocks, attn_resolutions=(16,))
        # self.net = Wide_ResNet(28, 10, norm='batch', dropout_rate=0).to(args.device)
        self.net = model

    @staticmethod
    def _extract(a, t, x_shape, device):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        if isinstance(t, int) or len(t.shape) == 0:
            t = torch.ones(x_shape[0], dtype=torch.int64, device=device) * t
        bs, = t.shape
        assert x_shape[0] == bs
        out = a[t]
        # out = tf.gather(tf.convert_to_tensor(a, dtype=tf.float32), t)
        # print(out.shape, t.shape, bs)
        assert list(out.shape) == [bs]
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        x_t = self._extract(self.a_s_cum, t, x_start.shape, self.args.device) * x_start + \
              self._extract(self.sigmas_cum, t, x_start.shape, self.args.device) * noise

        return x_t

    def q_sample_pairs(self, x_start, t):
        """
        Generate a pair of disturbed images for training
        :param x_start: x_0
        :param t: time step t
        :return: x_t, x_{t+1}
        """
        noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t)
        x_t_plus_one = self._extract(self.a_s, t + 1, x_start.shape, self.args.device) * x_t + \
                       self._extract(self.sigmas, t + 1, x_start.shape, self.args.device) * noise

        return x_t, x_t_plus_one

    def q_sample_progressive(self, x_0):
        """
        Generate a full sequence of disturbed images
        """
        x_preds = []
        for t in range(self.num_timesteps + 1):
            t_now = torch.ones([x_0.shape[0]], dtype=torch.int32, device=self.args.device) * t
            x = self.q_sample(x_0, t_now)
            x_preds.append(x)
        x_preds = torch.stack(x_preds, axis=0)

        return x_preds

    # === Training loss ===
    def training_losses(self, x_pos, x_neg, t):
        """
        Training loss calculation
        """
        a_s = self._extract(self.a_s_prev, t + 1, x_pos.shape, self.args.device)
        y_pos = a_s * x_pos
        y_neg = a_s * x_neg
        pos_f = self.net(y_pos, t).sum(dim=1)
        neg_f = self.net(y_neg, t).sum(dim=1)
        loss = - (pos_f - neg_f)

        loss_scale = 1.0 / (self.sigmas[t + 1] / self.sigmas[1])
        loss = loss_scale * loss

        # loss_ts = torch.math.unsorted_segment_mean(torch.abs(loss), t, self.num_timesteps)
        loss_ts = unsorted_segment_mean(torch.abs(loss), t, self.num_timesteps).detach()
        f_ts = unsorted_segment_mean(pos_f, t, self.num_timesteps).detach()

        return loss.mean(), loss_ts, f_ts

    def log_prob(self, y, t, tilde_x, b0, sigma, is_recovery):
        logits = self.net(y, t)

        return logits.sum(dim=1) / torch.reshape(b0, [-1]) - torch.sum((y - tilde_x) ** 2 / 2 / sigma ** 2 * is_recovery, dim=[1, 2, 3])

    def grad_f(self, y, t, tilde_x, b0, sigma, is_recovery):
        log_p_y = self.log_prob(y, t, tilde_x, b0, sigma, is_recovery)
        grad_y = torch.autograd.grad(log_p_y.sum(), [y], retain_graph=True)[0]
        # grad_y = torch.clamp(grad_y, -1, 1)
        return grad_y, log_p_y

    # === Sampling ===
    def p_sample_langevin(self, tilde_x, t):
        """
        Langevin sampling function
        """
        sigma = self._extract(self.sigmas, t + 1, tilde_x.shape, self.args.device)
        sigma_cum = self._extract(self.sigmas_cum, t, tilde_x.shape, self.args.device)
        is_recovery = self._extract(self.is_recovery, t + 1, tilde_x.shape, self.args.device)
        a_s = self._extract(self.a_s_prev, t + 1, tilde_x.shape, self.args.device)

        c_t_square = sigma_cum / self.sigmas_cum[0]
        step_size_square = c_t_square * self.args.mcmc_step_size_b_square * sigma ** 2

        # y = torch.identity(tilde_x)
        y = torch.autograd.Variable(tilde_x, requires_grad=True).to(self.args.device)
        is_accepted_summary = torch.zeros(y.shape[0], dtype=torch.float32, device=self.args.device)

        grad_y, log_p_y = self.grad_f(y, t, tilde_x, step_size_square, sigma, is_recovery)

        for _ in range(self.args.mcmc_num_steps):
            noise = torch.randn_like(y)
            y_new = y + 0.5 * step_size_square * grad_y + torch.sqrt(step_size_square) * noise * self.args.noise_scale

            grad_y_new, log_p_y_new = self.grad_f(y_new, t, tilde_x, step_size_square, sigma, is_recovery)
            y, grad_y, log_p_y = y_new, grad_y_new, log_p_y_new

        is_accepted_summary = 1.0 * is_accepted_summary / self.args.mcmc_num_steps
        is_accepted_summary = torch.mean(is_accepted_summary)

        x = y / a_s

        values = torch.norm(torch.reshape(x, [x.shape[0], -1]) - torch.reshape(tilde_x, [tilde_x.shape[0], -1]), dim=1)
        disp = unsorted_segment_mean(values, t, self.num_timesteps)
        return x, disp, is_accepted_summary

    def p_sample_progressive(self, noise):
        """
        Sample a sequence of images with the sequence of noise levels
        """
        num = noise.shape[0]
        x_neg_t = noise
        x_neg = torch.zeros([self.args.num_diffusion_timesteps, num, 3, self.args.img_sz, self.args.img_sz], device=self.device)
        x_neg = torch.cat([x_neg, torch.unsqueeze(noise, axis=0)], dim=0)
        is_accepted_summary = 0.

        for t in range(self.args.num_diffusion_timesteps - 1, -1, -1):
            t_v = torch.tensor(t).to(self.device)

            x_neg_t, _, is_accepted = self.p_sample_langevin(x_neg_t, t_v)
            is_accepted_summary = is_accepted_summary + is_accepted
            x_neg_t = torch.reshape(x_neg_t, [num, 3, self.args.img_sz, self.args.img_sz])
            insert_mask = t == torch.arange(self.args.num_diffusion_timesteps + 1, device=self.device)
            insert_mask = torch.reshape(insert_mask, [-1, *([1] * len(noise.shape))])
            x_neg = insert_mask * torch.unsqueeze(x_neg_t, axis=0) + (~ insert_mask) * x_neg
        is_accepted_summary = is_accepted_summary / self.args.num_diffusion_timesteps * 1.0
        return x_neg, is_accepted_summary



## 6 buffer

## **7** 训练类

In [55]:
def generate_from_scratch(diff_model, num_images, epoch, arg):
    noise = torch.randn(size=[25, 3, 32, 32]).to(arg.device)
    num_batch = int(np.ceil(num_images / 25))

    buffer = []
    for k in tqdm(range(num_batch)):
        x_neg, _ = diff_model.p_sample_progressive(noise)
        x_neg = x_neg[0].detach()
        buffer.append(x_neg)
        if k == 0 or k == 10:
            print(x_neg)
            plot('{}/img_{}_{:>06d}.png'.format(arg.log_dir, epoch, k), x_neg)
    buffer = torch.cat(buffer)
    return buffer

In [56]:
from Task.eval_buffer import eval_is_fid

class DRLTrainer:
    def __init__(self, arg):
        self.args = arg
        self.epoch_loss = 0
        self.diffusion = None
        self.diffusion_ema = None
        self.ema = None
        self.f_p = None

    def get_pred_by_freq(self, x, last=False):
        include_xpred_freq = max(1, self.args.num_diffusion_timesteps // 10)
        idx = torch.LongTensor(np.arange(self.args.num_diffusion_timesteps // include_xpred_freq + 1) * include_xpred_freq)
        if last:
            idx[-1] = idx[-1] - 1
        return x[idx]

    def train_a_batch(self, x, optimizer, epoch, i):
        t = torch.randint(size=[x.shape[0]], high=self.args.num_timesteps, device=device)

        x_pos, x_neg = self.diffusion.q_sample_pairs(x, t)
        x_neg, disp, is_accepted = self.diffusion.p_sample_langevin(x_neg, t)
        loss, _, f_p = self.diffusion.training_losses(x_pos, x_neg, t)
        self.epoch_loss += loss.item()
        self.f_p += f_p
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # self.ema.apply(self.diffusion)
        if i % 10 == 0:
            wlog('%d, [epoch:%d, iter:%d] Loss: %.03f' % (self.args.pid, epoch, i, self.epoch_loss / (i + 1)))
            f_ts = self.get_pred_by_freq(self.f_p, last=True)
            f_ts = ", ".join(["".join(str(np.around(aa.cpu().numpy(), 3))) for aa in f_ts])
            wlog('Positive Energy T ={:s}'.format(f_ts))
        return loss.item()

    def train(self):
        arg = self.args

        arg.n_classes = 10
        # datasets
        label_loader, train_loader, valid_loader, test_loader = get_data(arg)

        # model
        if arg.model == "wrnte":
            net =Wide_ResNet(arg.depth, arg.width, num_classes=arg.n_classes, norm=None)
            net_ema = Wide_ResNet(arg.depth, arg.width, num_classes=arg.n_classes, norm=None)
        else:
            net = Wide_ResNet(arg.depth, arg.width, num_classes=arg.n_classes, norm=None)
            net_ema = Wide_ResNet(arg.depth, arg.width, num_classes=arg.n_classes, norm=None)


        print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in net.parameters()) / 1e6))
        if arg.resume is not None:
            pt = torch.load(args.resume)
            net.load_state_dict(pt)

        net = net.to(args.device)

        if args.lr > 0.01:
            optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=5e-4, momentum=0.9)
        else:
            optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=[.9, .999], weight_decay=5e-4)


        self.diffusion = RecoveryLikelihood(net, arg)

        self.ema = EMA(mu=arg.ma_decay)

        best_acc = 0
        cur_iter = 0
        cur_lr = arg.lr

        for epoch in range(args.n_epochs):
            self.epoch_loss = 0
            self.f_p = 0
            if epoch in [120, 160, 200, 230]:
                cur_lr *= 0.2
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.2
                wlog("Learning rate decay %f" % cur_lr)

            net.train()
            for i, data in tqdm(enumerate(train_loader)):
                if cur_iter <= args.warmup_iters:
                    lr = args.lr * cur_iter / float(args.warmup_iters)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr

                x_p_d, _ = label_loader.__next__()
                x_p_d = x_p_d.to(device)

                loss = self.train_a_batch(x_p_d, optimizer, epoch, i)
                cur_iter += 1
                if abs(loss) > 1000:
                    print("diverge Epoch {} iter {}".format(epoch, i))
                    return

            metrics = {}
            i += 1  # in case of debugging
            if epoch % 5 == 0:
                buffer = generate_from_scratch(self.diffusion, 1000, epoch=epoch, arg=arg)
                inc_score, std, fid = eval_is_fid(buffer, arg, eval='all')
                wlog("Inception score of {} with std of {}".format(inc_score, std))
                wlog("FID of score {}".format(fid))
                metrics['Gen/IS'] = inc_score
                metrics['Gen/FID'] = fid
                metrics['Loss/EBM'] = self.epoch_loss / i
                metrics['Loss/T0'] = self.f_p[0] / i
                metrics['Loss/T2'] = self.f_p[2] / i
                metrics['Loss/T4'] = self.f_p[4] / i
                metrics['Loss/T6'] = self.f_p[5] / i

            torch.save(net.state_dict(), "./runs/DRL/%s/%s_last.pth" % (str(self.args.pid), str(arg.model)))
            if epoch % 10 == 0:
                torch.save(net.state_dict(), "./runs/DRL/%s/%s_%d.pth" % (str(self.args.pid), str(arg.model), epoch))


In [57]:
trainer = DRLTrainer(args)
trainer.train()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 28x10, time embedding
| Wide-Resnet 28x10, time embedding
Model parameters: 39.43M


0it [00:00, ?it/s]INFO:root:777, [epoch:0, iter:0] Loss: -0.016
INFO:root:Positive Energy T =-0.216, -0.178, -0.151, -0.113, -0.126, -0.215, -0.215
10it [03:28, 20.85s/it]


KeyboardInterrupt: 

In [37]:
args.lr

0.0001