In [None]:
import copy
import time
import torch as th
from torch.nn.modules.loss import CrossEntropyLoss
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import my
import lenet

In [None]:
class Args:
    """Empty attribute bag that holds the notebook's hyper-parameters."""


args = Args()

# master_gpu: CUDA device index (a negative value turns CUDA off in the setup cell).
# n_epochs:   number of passes over the training loader in the baseline training loop.
vars(args).update(master_gpu=0, n_epochs=25)

In [None]:
# Use CUDA only when a non-negative device index was requested AND a CUDA device is
# actually present, so the notebook still runs end-to-end on a CPU-only machine.
cuda = args.master_gpu >= 0 and th.cuda.is_available()
if cuda:
    th.cuda.set_device(args.master_gpu)

# Label remapping passed to my.load_mnist; the empty dict is the setting in use.
# The commented alternatives look like {(lo, hi): class_id} groupings of digits --
# presumably half-open digit ranges; verify against my.load_mnist.
labelling = {}
# labelling = {(0, 1) : 1, (1, 10) : 0}
# labelling = {(0, 1) : 0, (1, 4) : 1, (4, 7) : 2, (7, 10) : 3}
# NOTE(review): the keyword is spelled `rbg` -- confirm it matches my.load_mnist's signature.
train_x, train_y, test_x, test_y = my.load_mnist(labelling, rbg=True)
# train_x, train_y, test_x, test_y = my.load_mnist(labelling, rbg=False)

# Training batches are shuffled and the trailing partial batch is dropped.
train_loader = DataLoader(TensorDataset(train_x, train_y), 64, shuffle=True, drop_last=True)
# Evaluation must see every test example: with drop_last=True the final partial batch
# (everything past the last full 4096-sized batch) was silently excluded from all
# reported test metrics.
test_loader = DataLoader(TensorDataset(test_x, test_y), 4096)

# Number of classes inferred from the label range (assumes contiguous integer labels).
n_classes = int(train_y.max() - train_y.min() + 1)

In [None]:
def nd_f_beta(y_bar, y):
    """Call my.nd_f_beta with the notebook-level n_classes filled in."""
    return my.nd_f_beta(y_bar, y, n_classes)


def nd_precision(y_bar, y):
    """Call my.nd_precision with the notebook-level n_classes filled in."""
    return my.nd_precision(y_bar, y, n_classes)


def nd_recall(y_bar, y):
    """Call my.nd_recall with the notebook-level n_classes filled in."""
    return my.nd_recall(y_bar, y, n_classes)


# Metric callables handed to my.global_stats, in this order:
# accuracy, f_beta, precision, recall.
stats = (
    my.accuracy,
    nd_f_beta,
    nd_precision,
    nd_recall,
)

In [None]:
# Baseline: train the classifier with plain cross-entropy and report test metrics
# after every epoch.
# c = my.MLP((784, n_classes), None)
c = lenet.LeNet(1, n_classes, 28)
if cuda:
    c.cuda()
optim = Adam(c.parameters(), lr=0.001)
# The loss module is stateless, so build it once instead of once per batch.
criterion = CrossEntropyLoss()
for i in range(args.n_epochs):
    for x, y in train_loader:
        if cuda:
            x, y = x.cuda(), y.cuda()
        ce = criterion(c(x), y)
        optim.zero_grad()
        ce.backward()
        optim.step()
    # `stats` is ordered (accuracy, f_beta, precision, recall). The results must be
    # unpacked in that same order -- previously they were unpacked as
    # (accuracy, precision, recall, f1), which printed f_beta as "precision",
    # precision as "recall" and recall as "f1".
    # Assumes my.global_stats returns one value per entry of `stats`, in order.
    accuracy, f1, precision, recall = my.global_stats(c, test_loader, stats)
    print('[epoch %d]accuracy: %f; precision: %f; recall: %f; f1: %f' % (i + 1, accuracy, precision, recall, f1))

# Algorithm

Let $c$ be a classifier and $D=\{(X_1, y_1),...,(X_N, y_N)\}$ be the set of training data. In order to minimize $L(c, D)$, where $L$ is a non-decomposable loss function, we introduce $L_\theta$, a parameterized approximation of $L(c, D)$, and update $c$ as follows:

1. Compute $\delta = L(c, D)-L(\bar{c},D)$, where $\bar{c}$ is obtained by stochastically perturbing the parameters of $c$

2. Randomly sample $K$ subsets, $D_1, ..., D_K$, of $D$ (these subsets may vary in cardinality)

3. Minimize $(\delta - \frac1K \sum_{i = 1}^K \delta_i)^2$ with respect to $\theta$, where $\delta_i = L_\theta(c, D_i) - L_\theta(\bar{c}, D_i)$

4. Repeat 1, 2, and 3 several times until $L_\theta$ becomes a satisfactory approximation of $L$ near $c$

5. Randomly sample $K'$ subsets, $D_1, ..., D_{K'}$, of $D$ and let $c \leftarrow c - \alpha \sum_{i = 1}^{K'} \frac{\partial L_\theta}{\partial c} (c, D_i)$, where $\alpha$ is a positive learning rate

In [None]:
# Hyper-parameters for the critic-based procedure.
# NOTE: `args` is re-bound to a brand-new object here, so the attributes set on the
# earlier one (master_gpu, n_epochs) are no longer reachable through `args`.
args = Args()
vars(args).update(
    actor_iterations=25,    # max classifier updates per outer iteration
    c_batch_size=8,         # extra sampled subsets drawn for each classifier update
    critic_batch_size=8,    # sampled subsets used to fit the critic
    critic_iterations=25,   # passes over all perturbations when fitting the critic
    n_iterations=1000,      # outer iterations
    n_perturbations=100,    # perturbed classifier copies per outer iteration
    sample_size=50,         # examples per sampled subset
    std=1e-1,               # perturbation scale; also the per-parameter movement limit
    tau=1e-1,               # temperature in the perturbation weights exp(t ** 2 / tau)
)

# Timing lines are printed only when `verbose == 0`; None keeps them off.
verbose = None

In [None]:
def forward(c, xy):
    """Build the critic input for one sampled subset.

    `xy` is an (inputs, labels) pair. The labels are one-hot encoded with
    `my.onehot`, the classifier outputs are soft-maxed along dim 1, the two are
    concatenated along dim 1 and the result is flattened into a single row
    (shape (1, -1)).
    """
    inputs, labels = xy
    targets = my.onehot(labels, n_classes)
    probabilities = F.softmax(c(inputs), dim=1)
    paired = th.cat((targets, probabilities), dim=1)
    return paired.view(1, -1)

def L_global(c, loader):
    """Evaluate `my.nd_f_beta` for classifier `c` over `loader` via `my.global_stats`."""

    def f_beta(y_bar, y):
        return my.nd_f_beta(y_bar, y, n_classes)

    return my.global_stats(c, loader, f_beta)

In [None]:
# Dataset used for subset sampling, plus an unshuffled large-batch loader over the same
# data for full-set evaluation.
train_set = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_set, batch_size=8192)

# Fresh classifier for the critic-driven procedure.
# c = my.MLP((784, n_classes), None)
c = lenet.LeNet(1, n_classes, 28)

# Critic: my.RN over `sample_size` objects with 2 * n_classes features each, which is
# the per-example row width produced by forward() (one-hot label + soft-max output).
# The meaning of the three layer tuples is defined by my.RN -- see its signature.
critic_layers_a = ()
critic_layers_b = (4 * n_classes, 64, 256)
critic_layers_c = (256, 64, 1)
critic = my.RN(
    args.sample_size,
    2 * n_classes,
    critic_layers_a,
    critic_layers_b,
    critic_layers_c,
    F.relu,
    triu=True,
)

if cuda:
    c.cuda()
    critic.cuda()

# NOTE(review): only `eps` is overridden for the classifier optimizer; the learning rate
# stays at Adam's default. Confirm that `eps=1e-3` (rather than `lr`) is intended.
c_optim = Adam(c.parameters(), eps=1e-3)
critic_optim = Adam(critic.parameters())

print('initial f1: %f' % L_global(c, test_loader))

In [None]:
# Main loop of the critic-based procedure. Each outer iteration:
#   1. scores the classifier and `n_perturbations` perturbed copies on the training set,
#   2. fits the critic so that critic(c) - critic(c_bar) matches the measured score gap,
#   3. updates the classifier to increase the critic's output, stopping early once any
#      parameter has moved further than `args.std` from its value at the start of step 3,
#   4. records the test score.
# `hist` collects one dict per outer iteration ('L_bar_list', 'w_list', 'f1').
hist = []
for i in range(args.n_iterations):
    hist.append({})
    # Optional checkpointing of model/optimizer state (disabled).
#     hist[-1]['critic_state_dict'] = copy.deepcopy(my.state_dict_gpu2cpu(critic.state_dict()))
#     hist[-1]['critic_optim_state_dict'] = my.optim_state_dict_gpu2cpu(critic_optim.state_dict())
#     hist[-1]['c_state_dict'] = copy.deepcopy(my.state_dict_gpu2cpu(c.state_dict()))
#     hist[-1]['c_optim_state_dict'] = my.optim_state_dict_gpu2cpu(c_optim.state_dict())

    if verbose == 0:
        t0 = time.time()

    # --- 1. Score the current classifier and its perturbed copies -------------------
    L_c = L_global(c, train_loader)
    c_bar_list = []  # perturbed copies of c
    L_bar_list = []  # their scores on the training set
    t_list = []      # score gaps L(c) - L(c_bar): the critic's regression targets
    for j in range(args.n_perturbations):
        c_bar_list.append(my.perturb(c, args.std))
        L_bar = L_global(c_bar_list[-1], train_loader)
        L_bar_list.append(L_bar)
        t_list.append(L_c - L_bar)
    # Weight each perturbation by exp(gap ** 2 / tau), normalised to sum to one, so
    # perturbations with a larger score gap dominate the critic's loss.
    w_list = [th.exp(t ** 2 / args.tau) for t in t_list]
    z = sum(w_list)
    w_list = [w / z for w in w_list]

    hist[-1]['L_bar_list'] = L_bar_list
    hist[-1]['w_list'] = w_list

    if verbose == 0:
        t1 = time.time()
        print('[iteration %d]t1 - t0: %f' % (i + 1, t1 - t0))

    # --- 2. Fit the critic -----------------------------------------------------------
    s_critic = my.sample(train_set, args.sample_size, args.critic_batch_size, cuda)
    # These tensors are constant inputs to the critic, so compute them without
    # recording autograd graphs (previously the graphs for c and every perturbed copy
    # were built and then thrown away by .detach()).
    with th.no_grad():
        y = th.cat([forward(c, xy) for xy in s_critic], 0)
        y_bar_list = [th.cat([forward(c_bar, xy) for xy in s_critic], 0) for c_bar in c_bar_list]
    for j in range(args.critic_iterations):
        for y_bar, t, w in zip(y_bar_list, t_list, w_list):
            # critic(y) is recomputed for every step because the critic has just changed.
            delta = critic(y) - critic(y_bar)
            mse = w * th.sum((t - delta) ** 2)
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()
#     assert not my.module_isnan(critic)

    if verbose == 0:
        t2 = time.time()
        print('[iteration %d]t2 - t1: %f' % (i + 1, t2 - t1))

    # --- 3. Update the classifier against the critic ---------------------------------
    # Detached snapshot of the parameters, used only to measure how far c has moved.
    c_param = [p.detach().clone() for p in c.parameters()]
    for j in range(args.actor_iterations):
        # Reuse the subsets the critic was fitted on and add freshly sampled ones.
        s = s_critic + my.sample(train_set, args.sample_size, args.c_batch_size, cuda)
        y_bar = th.cat([forward(c, xy) for xy in s], 0)
        # The critic estimates a score (higher is better), hence the negation.
        objective = -th.mean(critic(y_bar))
        c_optim.zero_grad()
        objective.backward()
        c_optim.step()
        # Stop once any parameter has left the region (max-abs change <= std) in which
        # the critic was fitted; the comparison itself needs no gradient tracking.
        with th.no_grad():
            moved = any(float(th.max(th.abs(p - q))) > args.std for p, q in zip(c_param, c.parameters()))
        if moved:
            break
#     assert not my.module_isnan(c)

    if verbose == 0:
        t3 = time.time()
        print('[iteration %d]t3 - t2: %f' % (i + 1, t3 - t2))

    # --- 4. Track the test score -------------------------------------------------------
    f1 = L_global(c, test_loader)
    hist[-1]['f1'] = float(f1)

    # Reporting interval is currently 1, i.e. every iteration is printed.
    if (i + 1) % 1 == 0:
        print('[iteration %d]f1: %f' % (i + 1, f1))