## Imports

In [1]:
import sys
import os
import argparse
import time
import random
import math
import numpy as np
from scipy.special import logsumexp

import torch
import torchvision
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as tfs

from tensorboardX import SummaryWriter


from utils import kNN, AverageMeter, py_softmax

## Training parameters

In [2]:
# --- data ---
datadir = "CIFAR-10"  # dataset root handed to the CIFAR10Instance datasets

# --- optimization ---
lamb = 10            # SK lambda-parameter
nopts = 400          # number of SK-optimizations
epochs = 400         # number of epochs
momentum = 0.9       # sgd momentum
exp = './cifar_exp'  # experiments results dir

# --- other ---
devc = '0'           # cuda device
batch_size = 128
lr = 0.03            # learning rate given to the optimizer
alr = 0.03           # starting learning rate used by adjust_learning_rate

knn_dim = 4096       # feature width used for kNN evaluation and the SK feature buffer
best_acc = 0         # best test accuracy
start_epoch = 0      # start from epoch 0 or last checkpoint epoch


In [3]:
# Pick the compute device: the configured CUDA device when one is available,
# otherwise the CPU. torch.cuda.current_device() is only queried on the CUDA
# path because it raises on machines without a usable GPU.
if torch.cuda.is_available():
    device = torch.device('cuda:' + devc)
    print(f"GPU device: {torch.cuda.current_device()}")
else:
    device = torch.device('cpu')
    print("CUDA is not available, running on CPU")

GPU device: 0


## Model parameters (AlexNet in this case)

In [4]:
hc = 10     # number of heads
ncl = 128   # number of clusters

# one output size per head
numc = [ncl for _ in range(hc)]

# AlexNet feature layouts, two versions. A tuple describes a conv layer as
# (number of filters, kernel size, stride, pad); 'M' stands for a max-pooling layer.
CFG = dict(
    big=[
        (96, 11, 4, 2), 'M',
        (256, 5, 1, 2), 'M',
        (384, 3, 1, 1), (384, 3, 1, 1), (256, 3, 1, 1), 'M',
    ],
    small=[
        (64, 11, 4, 2), 'M',
        (192, 5, 1, 2), 'M',
        (384, 3, 1, 1), (256, 3, 1, 1), (256, 3, 1, 1), 'M',
    ],
)

## Data Preparation

In [5]:
# sanity check: show what the dataset directory contains
print(os.listdir(datadir))

['cifar-10-python.tar.gz', 'test', 'train', '.ipynb_checkpoints', 'cifar-10-batches-py']


In [6]:
class CIFAR10Instance(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose items also carry their dataset index.

    Each item is a tuple ``(image, target, index)``; the index lets the
    training code look up the per-sample entry in the self-label tensor.

    Args:
        root: dataset root directory, forwarded to ``torchvision.datasets.CIFAR10``.
        train: selects the training split when true, the test split otherwise.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the target.
        download: forwarded to the parent class (previously accepted but ignored).
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(CIFAR10Instance, self).__init__(root=root,
                                              train=train,
                                              transform=transform,
                                              target_transform=target_transform,
                                              download=download)

    def __getitem__(self, index):
        """Return ``(image, target, index)`` for the sample at ``index``."""
        image, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        image = Image.fromarray(image)

        # Reuse one name for the (optionally transformed) image so the return
        # statement is valid even when no transform is configured.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target, index

In [7]:
# Training-time augmentation pipeline.
transform_train = tfs.Compose(
    [
        tfs.Resize(256),
        tfs.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        tfs.ColorJitter(0.4, 0.4, 0.4, 0.4),
        tfs.RandomGrayscale(p=0.2),
        tfs.RandomHorizontalFlip(),
        tfs.ToTensor(),
        tfs.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Deterministic evaluation pipeline: resize, center crop, same normalization.
transform_test = tfs.Compose(
    [
        tfs.Resize(256),
        tfs.CenterCrop(224),
        tfs.ToTensor(),
        tfs.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

In [8]:
# Training split: augmented images, reshuffled by the loader.
trainset = CIFAR10Instance(
    root=datadir,
    train=True,
    download=False,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# Test split: deterministic transform, fixed order.
testset = CIFAR10Instance(
    root=datadir,
    train=False,
    download=False,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

In [9]:
# number of training samples (trainloader.dataset is trainset)
N = len(trainset)
print(N)

50000


## Model, AlexNet

In [10]:
class AlexNet(nn.Module):
    """AlexNet-style network with one or several linear output heads.

    Args:
        features: module applied to the input first; its output is reshaped to
            ``256 * 6 * 6`` values per sample.
        num_classes: sequence of output sizes, one entry per head. With a single
            entry the head is stored as ``top_layer``; otherwise the heads are
            stored as ``top_layer0``, ``top_layer1``, ... and ``top_layer`` is None.
        init: when true, ``_initialize_weights`` is run after construction.
    """

    def __init__(self, features, num_classes, init=True):
        super(AlexNet, self).__init__()
        self.features = features
        fc_layers = [
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
        ]
        self.classifier = nn.Sequential(*fc_layers)
        self.headcount = len(num_classes)
        self.return_features = False
        if self.headcount == 1:
            self.top_layer = nn.Linear(4096, num_classes[0])
        else:
            for head_idx, n_out in enumerate(num_classes):
                setattr(self, "top_layer%d" % head_idx, nn.Linear(4096, n_out))
            # with top_layer unset, headcount alone decides what forward() returns
            self.top_layer = None
        if init:
            self._initialize_weights()

    def forward(self, x):
        """Run the network.

        Returns the flattened ``features`` output when ``return_features`` is set;
        otherwise a list with one tensor per head when ``headcount`` is not 1, or
        a single tensor (the ``top_layer`` output, or the classifier output when
        ``top_layer`` is falsy) when ``headcount`` is 1.
        """
        flat = self.features(x)
        flat = flat.view(flat.size(0), 256 * 6 * 6)
        if self.return_features:
            return flat
        hidden = self.classifier(flat)
        if self.headcount != 1:
            return [getattr(self, "top_layer%d" % h)(hidden) for h in range(self.headcount)]
        if self.top_layer:  # skipped when no single head exists
            hidden = self.top_layer(hidden)
        return hidden

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers of every sub-module."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                std = math.sqrt(2. / fan)
                # filled one output channel at a time
                for channel in range(module.out_channels):
                    module.weight.data[channel].normal_(0, std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()


def make_layers_features(cfg, input_dim, bn):
    """Build the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: sequence whose entries are either ``'M'`` (a 3x3, stride-2 max-pool)
            or an indexable ``(filters, kernel size, stride, pad)`` conv spec.
        input_dim: number of input channels of the first convolution.
        bn: when true, each convolution is followed by ``BatchNorm2d``.

    Returns:
        ``nn.Sequential`` with the layers in ``cfg`` order; every convolution is
        followed by an in-place ReLU.
    """
    modules = []
    channels = input_dim
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=3, stride=2))
            continue
        out_channels = spec[0]
        modules.append(
            nn.Conv2d(channels, out_channels, kernel_size=spec[1], stride=spec[2], padding=spec[3])
        )
        if bn:
            modules.append(nn.BatchNorm2d(out_channels))
        modules.append(nn.ReLU(inplace=True))
        # the next convolution consumes this layer's output channels
        channels = out_channels
    return nn.Sequential(*modules)


def alexnet(bn=True, num_classes=[1000], init=True, size='big'):
    """Construct an ``AlexNet`` for 3-channel input.

    Args:
        bn: whether the feature extractor uses batch normalisation.
        num_classes: sequence of output sizes, one per head (only read, never modified).
        init: forwarded to ``AlexNet`` to trigger its weight initialisation.
        size: key into ``CFG`` selecting the feature layout ('big' or 'small').
    """
    feature_extractor = make_layers_features(CFG[size], 3, bn=bn)
    return AlexNet(feature_extractor, num_classes, init)

## Sinkhorn-Knopp optimization

In [11]:
def optimize_L_sk(PS):
    """Derive new labels from a score matrix by alternating row/column rescaling.

    Args:
        PS: 2-D numpy array of shape (N, K). The caller in this notebook passes
            softmax outputs. NOTE: the array is modified in place -- ``PS.T`` is a
            view, so ``PS **= lamb`` raises the caller's array to the power of
            the global ``lamb``.

    Returns:
        ``(cost, selflabels)``: ``cost`` is a float computed from the log of the
        selected entries, ``selflabels`` is a ``torch.LongTensor`` of length N
        (argmax over the K axis of the rescaled matrix) moved to the global ``device``.
    """
    N, K = PS.shape
    tt = time.time()
    PS = PS.T  # now it is K x N
    r = np.ones((K, 1)) / K
    c = np.ones((N, 1)) / N
    PS **= lamb  # K x N
    inv_K = 1. / K
    inv_N = 1. / N
    err = 1e3
    _counter = 0
    # Alternate the row (r) and column (c) scaling updates; the stopping error is
    # only re-evaluated on every 10th iteration.
    while err > 1e-2:
        r = inv_K / (PS @ c)  # (KxN)@(N,1) = K x 1
        c_new = inv_N / (r.T @ PS).T  # ((1,K)@(KxN)).t() = N x 1
        if _counter % 10 == 0:
            err = np.nansum(np.abs(c / c_new - 1))
        c = c_new
        _counter += 1
    print("error: ", err, 'step ', _counter, flush=True)  # " nonneg: ", sum(I), flush=True)
    # inplace calculations.
    # Scale the N axis by c and the K axis by r (the transposes only switch which
    # axis the 1-D factor broadcasts over).
    PS *= np.squeeze(c)
    PS = PS.T
    PS *= np.squeeze(r)
    PS = PS.T
    argmaxes = np.nanargmax(PS, 0)  # size N
    newL = torch.LongTensor(argmaxes)
    selflabels = newL.to(device)
    # Undo the two scalings in reverse order, again in place.
    PS = PS.T
    PS /= np.squeeze(r)
    PS = PS.T
    PS /= np.squeeze(c)
    # Fancy indexing returns a copy, so the in-place log below leaves PS untouched.
    sol = PS[argmaxes, np.arange(N)]
    np.log(sol, sol)
    cost = -(1. / lamb) * np.nansum(sol) / N
    print('cost: ', cost, flush=True)
    print('opt took {0:.2f}min, {1:4d}iters'.format(((time.time() - tt) / 60.), _counter), flush=True)
    return cost, selflabels

def opt_sk(model, selflabels_in, epoch):
    """Recompute self-labels by running the model over the training set.

    Relies on the notebook globals ``hc``, ``ncl``, ``knn_dim``, ``trainloader``
    and ``device``.

    Single-head case (``hc == 1``): the softmax of the model output is collected
    for every sample and passed to ``optimize_L_sk``; the new labels are returned.

    Multi-head case: the model output (width ``knn_dim``) is collected for every
    sample, then only head ``epoch % hc`` is updated: its linear layer is applied
    in numpy, followed by ``py_softmax`` and ``optimize_L_sk``.

    Args:
        model: network to evaluate; the caller is expected to have put it into
            the output mode this function needs.
        selflabels_in: current labels; in the multi-head case row ``epoch % hc``
            is overwritten in place.
        epoch: current epoch, used to choose the head to update.

    Returns:
        The new label tensor (single head) or ``selflabels_in`` with one updated
        row (multi head).
    """
    num_samples = len(trainloader.dataset)
    single_head = hc == 1
    if single_head:
        # `ncl` is the notebook-level cluster count (there is no `args` object here).
        PS = np.zeros((num_samples, ncl))
    else:
        PS_pre = np.zeros((num_samples, knn_dim))

    for data, _, _selected in trainloader:
        data = data.to(device)
        if single_head:
            p = nn.functional.softmax(model(data), 1)
            PS[_selected, :] = p.detach().cpu().numpy()
        else:
            p = model(data)
            PS_pre[_selected, :] = p.detach().cpu().numpy()

    if single_head:
        _, selflabels = optimize_L_sk(PS)
        return selflabels

    nh = epoch % hc  # heads are updated round-robin, one per call
    print("computing head %s " % nh, end="\r", flush=True)
    tl = getattr(model, "top_layer%d" % nh)
    # do the forward pass of the selected head in numpy; detach() is required
    # because .numpy() refuses tensors that require grad (e.g. when the model
    # already lives on the CPU and .cpu() returns the parameter itself)
    weight = tl.weight.detach().cpu().numpy()
    bias = tl.bias.detach().cpu().numpy()
    PS = py_softmax(PS_pre @ weight.T + bias, 1)
    _, head_labels = optimize_L_sk(PS)
    selflabels_in[nh] = head_labels
    return selflabels_in


## Training utils

In [12]:
def adjust_learning_rate(optimizer, epoch):
    """Set the optimizer learning rate for ``epoch`` from a step schedule.

    A schedule exists only when the global ``epochs`` is 200, 400, 800 or 1600;
    for any other value the optimizer is left untouched. Before the schedule's
    start epoch the rate is the global ``alr``; from the start epoch on it is
    ``alr * 0.1 ** ((epoch - start) // step)`` and the value is printed.
    """
    # total epochs -> (start epoch of the decay formula, epochs per further /10 step)
    schedules = {
        200: (80, 40),     # i.e. 120, 160
        400: (160, 80),    # i.e. 240,320
        800: (320, 160),   # i.e. 480, 640
        1600: (640, 320),
    }
    if epochs not in schedules:
        return
    decay_start, decay_every = schedules[epochs]

    lr = alr
    if epoch >= decay_start:
        lr = alr * (0.1 ** ((epoch - decay_start) // decay_every))
        print(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [13]:
def feature_return_switch(model, bool=True):
    """
    switch the model between single-output mode and multi-head output mode
        if True: sets model.headcount to 1
        if False: sets model.headcount back to the global hc
    The flag is also stored on the model as `return_feature`.

    NOTE(review): AlexNet.forward checks `return_features` (plural), which is
    not the attribute written here. In this notebook the switch therefore works
    through `headcount` alone: with headcount 1 and `top_layer` set to None,
    forward returns the classifier output rather than the flattened conv
    features. Confirm the intended behaviour before renaming the attribute.
    """
    model.headcount = 1 if bool else hc
    model.return_feature = bool

In [14]:
def train(epoch, selflabels):
    """Train the model for one epoch on the current self-labels.

    Uses the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``trainloader``, ``optimize_times``, ``writer``, ``name``, ``hc`` and ``device``.

    Before each batch, if the number of samples processed so far has reached the
    last entry of ``optimize_times``, that entry is popped and the labels are
    recomputed with ``opt_sk`` (without gradient tracking).

    Args:
        epoch: index of the epoch being trained.
        selflabels: current label tensor, indexed by sample index (single head)
            or by ``[head, sample index]`` (multi head).

    Returns:
        The label tensor in use at the end of the epoch.
    """
    print('\nEpoch: %d' % epoch)
    print(name)
    adjust_learning_rate(optimizer, epoch)
    train_loss = AverageMeter()
    data_time = AverageMeter()
    batch_time = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for batch_idx, (inputs, _targets, indexes) in enumerate(trainloader):
        niter = epoch * len(trainloader) + batch_idx
        if niter * trainloader.batch_size >= optimize_times[-1]:
            with torch.no_grad():
                _ = optimize_times.pop()
                if hc > 1:
                    feature_return_switch(model, True)
                selflabels = opt_sk(model, selflabels, epoch)
                if hc > 1:
                    feature_return_switch(model, False)
        data_time.update(time.time() - end)
        # the dataset targets are not used for training, only the sample indexes
        inputs, indexes = inputs.to(device), indexes.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        if hc == 1:
            loss = criterion(outputs, selflabels[indexes])
        else:
            # average the per-head losses
            loss = torch.mean(torch.stack([criterion(outputs[h],
                                                     selflabels[h, indexes]) for h in range(hc)]))

        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 10 == 0:
            print('Epoch: [{}][{}/{}]'
                  'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
                epoch, batch_idx, len(trainloader), batch_time=batch_time, data_time=data_time, train_loss=train_loss))
            # Global step counted in samples seen. The loader's real batch size is
            # used so steps of consecutive epochs never overlap (a hard-coded 512
            # exceeded the epoch length after about 98 batches).
            samples_seen = batch_idx * trainloader.batch_size + epoch * len(trainloader.dataset)
            writer.add_scalar("loss", loss.item(), samples_seen)
    return selflabels

## Model initialization

In [15]:
# 'big' AlexNet with batch norm and one output head per entry of numc
model = alexnet(bn=True, num_classes=numc, init=True, size='big')

In [16]:
# Sample-count thresholds at which train() re-optimizes the labels: nopts values
# spread linearly from (epochs + 1.0001) * N down to 0. The order is descending
# because train() reads and pops the thresholds from the end of the list.
optimize_times = ((epochs + 1.0001) * N * np.linspace(0, 1, nopts)[::-1]).tolist()
# Extra leading threshold of (epochs + 10) * N samples, placed before all others.
optimize_times.insert(0, (epochs + 10) * N)
print('We will optimize L at epochs:', [np.round(1.0 * t / N, 2) for t in optimize_times], flush=True)

We will optimize L at epochs: [410.0, 401.0, 400.0, 398.99, 397.99, 396.98, 395.98, 394.97, 393.97, 392.96, 391.95, 390.95, 389.94, 388.94, 387.93, 386.93, 385.92, 384.92, 383.91, 382.91, 381.9, 380.9, 379.89, 378.89, 377.88, 376.88, 375.87, 374.87, 373.86, 372.86, 371.85, 370.85, 369.84, 368.84, 367.83, 366.83, 365.82, 364.82, 363.81, 362.81, 361.8, 360.8, 359.79, 358.79, 357.78, 356.78, 355.77, 354.77, 353.76, 352.76, 351.75, 350.75, 349.74, 348.74, 347.73, 346.73, 345.72, 344.72, 343.71, 342.71, 341.7, 340.7, 339.69, 338.69, 337.68, 336.68, 335.67, 334.67, 333.66, 332.66, 331.65, 330.65, 329.64, 328.64, 327.63, 326.63, 325.62, 324.62, 323.61, 322.61, 321.6, 320.6, 319.59, 318.59, 317.58, 316.58, 315.57, 314.57, 313.56, 312.56, 311.55, 310.55, 309.54, 308.54, 307.53, 306.53, 305.52, 304.52, 303.51, 302.51, 301.5, 300.5, 299.49, 298.49, 297.48, 296.48, 295.47, 294.47, 293.46, 292.46, 291.45, 290.45, 289.44, 288.44, 287.43, 286.43, 285.42, 284.42, 283.41, 282.41, 281.4, 280.4, 279.39, 

In [17]:
# init selflabels randomly: cluster ids are assigned cyclically over the samples
# (sample i -> i % number of clusters) and then shuffled.
if hc == 1:
    selflabels = np.random.permutation((np.arange(N) % ncl).astype(np.int32))
else:
    selflabels = np.zeros((hc, N), dtype=np.int32)
    for nh in range(hc):
        # each head gets its own independent shuffle
        selflabels[nh] = np.random.permutation((np.arange(N) % numc[nh]).astype(np.int32))
selflabels = torch.LongTensor(selflabels).to(device)

In [18]:
# place the model on the selected device, then set up SGD and the loss
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=5e-4,
)

In [19]:
# run name: printed by train() every epoch and used as the log sub-directory
name = "AlexNetfromthepaper"
writer = SummaryWriter('./runs/cifar10/{}'.format(name))

## Training! 
Takes a couple of minutes per epoch

In [None]:
def _save_checkpoint(filename, acc, epoch):
    """Save model/optimizer state, accuracy, epoch and current labels to `exp`/filename."""
    print('Saving..')
    state = {
        'net': model.state_dict(),
        'acc': acc,
        'epoch': epoch,
        'opt': optimizer.state_dict(),
        'L': selflabels,
    }
    # makedirs also creates missing parent directories and tolerates an existing one
    os.makedirs(exp, exist_ok=True)
    torch.save(state, '%s/%s' % (exp, filename))


for epoch in range(start_epoch, start_epoch + epochs):
    selflabels = train(epoch, selflabels)

    # kNN evaluation with the model switched to single-output mode
    feature_return_switch(model, True)
    acc = kNN(model, trainloader, testloader, K=10, sigma=0.1, dim=knn_dim)
    feature_return_switch(model, False)
    writer.add_scalar("accuracy kNN", acc, epoch)

    if acc > best_acc:
        _save_checkpoint('best_ckpt.t7', acc, epoch)
        best_acc = acc
    if epoch % 100 == 0:
        _save_checkpoint('ep%s.t7' % epoch, acc, epoch)
    if epoch % 50 == 0:
        # periodic sweep over several kNN settings; kept in its own variable so
        # the scalar `acc` of this epoch is not overwritten by a list
        feature_return_switch(model, True)
        knn_accs = kNN(model, trainloader, testloader, K=[50, 10],
                       sigma=[0.1, 0.5], dim=knn_dim, use_pca=True)
        knn_settings = [(num_nn, sig) for num_nn in [50, 10] for sig in [0.1, 0.5]]
        for i, (num_nn, sig) in enumerate(knn_settings):
            writer.add_scalar('knn%s-%s' % (num_nn, sig), knn_accs[i], epoch)
        feature_return_switch(model, False)
    print('best accuracy: {:.2f}'.format(best_acc * 100))

# Re-evaluate the best checkpoint once training is finished.
checkpoint = torch.load('%s' % exp + '/best_ckpt.t7')
model.load_state_dict(checkpoint['net'])
feature_return_switch(model, True)
acc = kNN(model, trainloader, testloader, K=10, sigma=0.1, dim=knn_dim, use_pca=True)


Epoch: 0
AlexNetfromthepaper
error:  0.003979012589635067 step  31
cost:  4.0347929817170884
opt took 0.01min,   31iters
Epoch: [0][0/391]Time: 63.655 (63.655) Data: 63.556 (63.556) Loss: 4.9827 (4.9827)
Epoch: [0][10/391]Time: 0.122 (5.895) Data: 0.017 (5.792) Loss: 4.9202 (4.9415)
Epoch: [0][20/391]Time: 0.123 (3.146) Data: 0.018 (3.043) Loss: 4.9069 (4.9248)
Epoch: [0][30/391]Time: 0.125 (2.177) Data: 0.021 (2.073) Loss: 4.8777 (4.9112)
Epoch: [0][40/391]Time: 0.124 (1.687) Data: 0.019 (1.583) Loss: 4.8658 (4.9020)
Epoch: [0][50/391]Time: 0.289 (1.393) Data: 0.186 (1.289) Loss: 4.8722 (4.8953)
Epoch: [0][60/391]Time: 0.121 (1.193) Data: 0.016 (1.088) Loss: 4.8640 (4.8903)
Epoch: [0][70/391]Time: 0.120 (1.052) Data: 0.016 (0.948) Loss: 4.8634 (4.8867)
Epoch: [0][80/391]Time: 0.120 (0.942) Data: 0.018 (0.838) Loss: 4.8623 (4.8838)
Epoch: [0][90/391]Time: 0.332 (0.860) Data: 0.230 (0.756) Loss: 4.8681 (4.8813)
Epoch: [0][100/391]Time: 0.119 (0.789) Data: 0.016 (0.686) Loss: 4.8623 (4.

Epoch: [2][140/391]Time: 0.111 (0.562) Data: 0.010 (0.459) Loss: 4.8326 (4.8483)
Epoch: [2][150/391]Time: 0.307 (0.534) Data: 0.201 (0.432) Loss: 4.8446 (4.8482)
Epoch: [2][160/391]Time: 0.112 (0.511) Data: 0.011 (0.409) Loss: 4.8386 (4.8480)
Epoch: [2][170/391]Time: 0.113 (0.490) Data: 0.011 (0.388) Loss: 4.8491 (4.8479)
Epoch: [2][180/391]Time: 0.116 (0.472) Data: 0.014 (0.369) Loss: 4.8436 (4.8479)
Epoch: [2][190/391]Time: 0.202 (0.454) Data: 0.098 (0.351) Loss: 4.8447 (4.8479)
Epoch: [2][200/391]Time: 0.113 (0.437) Data: 0.011 (0.335) Loss: 4.8428 (4.8478)
Epoch: [2][210/391]Time: 0.113 (0.422) Data: 0.011 (0.320) Loss: 4.8465 (4.8477)
Epoch: [2][220/391]Time: 0.115 (0.409) Data: 0.012 (0.307) Loss: 4.8430 (4.8477)
Epoch: [2][230/391]Time: 0.313 (0.400) Data: 0.210 (0.297) Loss: 4.8456 (4.8476)
Epoch: [2][240/391]Time: 0.122 (0.389) Data: 0.017 (0.287) Loss: 4.8499 (4.8475)
Epoch: [2][250/391]Time: 0.112 (0.380) Data: 0.011 (0.277) Loss: 4.8453 (4.8475)
Epoch: [2][260/391]Time: 0.1