In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import pdb
import torchvision.utils as vutils
from pprint import pprint as pp
import numpy as np
from scipy import linalg

print('#GPU: ', torch.cuda.device_count())
print('PyTorch Version:', torch.__version__)
np.set_printoptions(precision=2)

#GPU:  8
PyTorch Version: 1.6.0+cu92


In [2]:
def toeplitz_1_ch(kernel, input_size):
    # shapes
    k_h, k_w = kernel.shape
    i_h, i_w = input_size
    o_h, o_w = i_h-k_h+1, i_w-k_w+1

    # construct 1d conv toeplitz matrices for each row of the kernel
    toeplitz = []
    for r in range(k_h):
        toeplitz.append(linalg.toeplitz(c=(kernel[r,0], *np.zeros(i_w-k_w)), r=(*kernel[r], *np.zeros(i_w-k_w))) ) 

    # construct toeplitz matrix of toeplitz matrices (just for padding=0)
    h_blocks, w_blocks = o_h, i_h
    h_block, w_block = toeplitz[0].shape

    W_conv = np.zeros((h_blocks, h_block, w_blocks, w_block))

    for i, B in enumerate(toeplitz):
        for j in range(o_h):
            W_conv[j, :, i+j, :] = B

    W_conv.shape = (h_blocks*h_block, w_blocks*w_block)

    return W_conv

def toeplitz_mult_ch(kernel, input_size):
    """Compute toeplitz matrix for 2d conv with multiple in and out channels.
    Args:
        kernel: shape=(n_out, n_in, H_k, W_k)
        input_size: (n_in, H_i, W_i)"""

    kernel_size = kernel.shape
    output_size = (kernel_size[0], input_size[1] - (kernel_size[2]-1), input_size[2] - (kernel_size[3]-1))
    print('==> output_size', output_size)
    T = np.zeros((output_size[0], int(np.prod(output_size[1:])), input_size[0], int(np.prod(input_size[1:]))))
    print('==> T', T.shape)

    for i,ks in enumerate(kernel):  # loop over output channel
        for j,k in enumerate(ks):  # loop over input channel
            T_k = toeplitz_1_ch(k, input_size[1:])
            T[i, :, j, :] = T_k

    T.shape = (np.prod(output_size), np.prod(input_size))

    return T

k = np.random.randn(1*1*2*2).reshape((1,1,2,2))
i = np.random.randn(1,1,4,4)

ref = F.conv2d(torch.tensor(i), torch.tensor(k), padding=0)
print('Ref Shape:', ref.shape)

i_pad = F.pad(torch.tensor(i), pad=[0,0,0,0])

print('--->', list(i_pad.size())[1:])
T = toeplitz_mult_ch(k, list(i_pad.size())[1:])
pp(T)
print('Weight Matrix, T Shape:', T.shape)
# ### Do conv as matrix (weight, T) * vector (input)
# out = T.dot(i_pad.numpy().flatten())
# print('Out Shape:', out.shape)


# out_reshape = out.reshape(1,4,7,7)
# print('Out_reshape Shape:', out_reshape.shape)


# print(np.sum((out_reshape - ref)**2))

out = torch.matmul(torch.tensor(T)[None,:,:], i_pad.reshape(i_pad.size(0), -1)[:,:,None])
print('Out Shape:', out.size())

out_reshape = out.reshape(1,1,3,3)
print('Out_reshape Shape:', out_reshape.size())

print((ref-out_reshape).abs().max())

Ref Shape: torch.Size([1, 1, 3, 3])
---> [1, 4, 4]
==> output_size (1, 3, 3)
==> T (1, 9, 1, 16)
array([[ 1.73,  0.11,  0.  ,  0.  , -0.24, -0.2 ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  1.73,  0.11,  0.  ,  0.  , -0.24, -0.2 ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  1.73,  0.11,  0.  ,  0.  , -0.24, -0.2 ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  1.73,  0.11,  0.  ,  0.  , -0.24,
        -0.2 ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  1.73,  0.11,  0.  ,  0.  ,
        -0.24, -0.2 ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  1.73,  0.11,  0.  ,
         0.  , -0.24, -0.2 ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  1.73,
         0.11,  0.  ,  0.  , -0.24, -0.2 ,  0.  ,  0.  ],


In [3]:
class InvConv(nn.Module):
    def __init__(self):
        super(InvConv, self).__init__()
        self.if_init = False

    def init_layer(self, old_layer):
        self.old_weight = old_layer.weight.detach().cpu().clone().numpy()
        self.inp_padding = old_layer.padding
        
        self.inp_size = list(old_layer.inp_size)[1:] # inp_size was obtained by hook
        
        
        if old_layer.padding[-1]>0:
            self.inp_size[-1] += 2*old_layer.padding[-1]
        if old_layer.padding[-2]>0:
            self.inp_size[-2] += 2*old_layer.padding[-2]
            
        print('inp_size:', self.inp_size)
        print('old weight', self.old_weight.shape)
        self.old_weight_matrix = toeplitz_mult_ch(
            self.old_weight, self.inp_size)
        self.old_weight_matrix = torch.tensor(self.old_weight_matrix).float()
        
        
        self.if_init = True
        
    def forward(self, y, if_inv=True):
        '''
        Problem: 
            1. the converted weight matrix is super large. This matrix may consum ~150G memory. Please do NOT use cuda. 
        '''
        
        assert self.if_init

        # remember to transpose the old_weight_matrix
        self.old_weight_matrix = self.old_weight_matrix.to(y.device)
            
        if if_inv:
            '''
            inversion of convolution
            '''
            
            out = torch.matmul(self.old_weight_matrix.t()[None,:,:], y.view(y.size(0), -1)[:,:,None])

            # reshape the output
            out = out.view([out.size(0)] + self.inp_size)

            # un-padding
            if self.inp_padding[0] == 0 and self.inp_padding[1]==0:
                out_unpadding = out
            else:
                out_unpadding = out[:,:,self.inp_padding[0]:-self.inp_padding[0],self.inp_padding[1]:-self.inp_padding[1]]

            return out_unpadding
        else:
            '''
            standard convolution. However, we implement the convolution with pure matrix(weights)-vector(input) multiplication.
            
            '''
            x = y
            x_pad = F.pad(x, pad=[self.inp_padding[0],self.inp_padding[0],self.inp_padding[1],self.inp_padding[1]])
            print('### x_pad', x_pad.view(x_pad.size(0), -1)[:,:,None].size())
            out = torch.matmul(self.old_weight_matrix[None,:,:], x_pad.view(x_pad.size(0), -1)[:,:,None])
            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv#torch.nn.Conv2d
            out_size = int( (self.inp_size[-1] - 1 * (self.old_weight.shape[-1] -1) - 1)  / 1. + 1)
            out = out.view([out.size(0), self.old_weight.shape[0]]+[out_size]*2)
            return out

    
if 1:
    # verify correction of our implementation
    inp = torch.randn(2,3,4,4)
    conv = nn.Conv2d(3, 5, 3, stride=1, padding=1, bias=False)
    conv.inp_size = [2,3,4,4]
    inv_conv = InvConv()
    inv_conv.init_layer(conv)
    
    nn_ref = conv(inp)
    print('nn Ref shape:', ref.shape)
    f_ref = F.conv2d(inp, conv.weight, stride=1, padding=1)
    out = inv_conv(inp, if_inv=False)
    
    print( (nn_ref-out).abs().max())
    print( (f_ref-out).abs().max())
    print(nn_ref)
    print(out)

inp_size: [3, 6, 6]
old weight (5, 3, 3, 3)
==> output_size (5, 4, 4)
==> T (5, 16, 3, 36)
nn Ref shape: torch.Size([1, 1, 3, 3])
### x_pad torch.Size([2, 108, 1])
tensor(3.5763e-07, grad_fn=<MaxBackward1>)
tensor(3.5763e-07, grad_fn=<MaxBackward1>)
tensor([[[[ 0.6956, -0.1893,  0.5630,  0.2856],
          [ 0.1645,  0.5064, -0.2672, -0.2958],
          [ 0.1333,  0.6983,  0.2453,  0.0849],
          [ 0.7977,  0.0585,  0.8783, -0.2751]],

         [[-0.0515, -0.2804, -0.3157,  0.5948],
          [-0.5185,  0.1907,  1.2124, -0.3263],
          [-0.3546, -0.2959,  0.1206,  0.1996],
          [ 0.2462,  0.6708, -0.7042,  0.5976]],

         [[ 0.4417,  0.5273,  0.4430,  0.5128],
          [ 0.2436, -0.8946, -0.0797,  0.1172],
          [ 0.5668,  0.3461,  0.1307,  0.2035],
          [ 0.1604,  0.5501,  0.5731,  0.0167]],

         [[-0.5004,  0.0868,  0.6181,  0.0332],
          [-0.5029, -0.3809,  0.5123,  0.0351],
          [-0.4592, -0.1736, -0.5052,  1.1377],
          [-0.0333, -0.5

In [4]:
# input = torch.arange(16).reshape(1, 1, 4, 4).float()
# print('\ninput', input)

# downsample = nn.Conv2d(1, 1, 2, stride=2, padding=0, bias=False)
# downsample.weight.data.copy_(torch.tensor([[[[0.,1.],[3.,2.]]]]))

# print('\nkernel:', downsample.weight.data)

# h = downsample(input)
# print('\nh:', h.data)

# upsample = nn.ConvTranspose2d(1, 1, 2, stride=2, padding=0, bias=False)
# upsample.weight.data.copy_(torch.tensor([[[[1.,1.],[1.,1.]]]])/6.)

# output = upsample(h)
# print('\noutput:', output.data)

In [5]:
# input = torch.arange(32).reshape(1, 2, 4, 4).float()
# print('\ninput', input)

# downsample = nn.Conv2d(2, 3, 2, stride=2, padding=0, bias=False)
# downsample.weight.data.copy_(torch.arange(24).float().reshape(3,2,2,2))

# print('\nweight:', downsample.weight.data)
# print('\nweight size:', downsample.weight.data.size())

# h = downsample(input)
# print('\nh:', h.data)

# upsample = nn.ConvTranspose2d(3, 2, 2, stride=2, padding=0, bias=False)
# print(upsample.weight.data.size())
# upsample.weight.data.copy_(torch.ones(3,2,2,2))

# output = upsample(h)
# print('\noutput:', output.data)

In [6]:
def loss_fn_kd(outputs, teacher_outputs):
    """
    Compute the knowledge-distillation (KD) loss given outputs
    """
    T = 3.0
    kld_loss = nn.KLDivLoss()(F.log_softmax(outputs / T, dim=1), F.softmax(teacher_outputs / T, dim=1))
    KD_loss = kld_loss * T * T
    return KD_loss



class InvertNet(nn.Module):
    def __init__(self, n_input_ch=1):
        super(InvertNet, self).__init__()
        # self.Tconv1 = nn.ConvTranspose2d(32, n_input_ch, 3, 1, bias=False)
        # self.Tconv2 = nn.ConvTranspose2d(64, 32, 3, 1, bias=False)
        self.Tconv1 = InvConv()
        self.Tconv2 = InvConv()
        self.Tfc1 = nn.Linear(128, int(64*7*7), bias=False)
        self.Tfc2 = nn.Linear(10, 128, bias=False)
        self.unpool = nn.MaxUnpool2d(2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') #bilinear, nearest

    def forward(self, output, ind1, ind2):
        # going backward in the model
        # assume we don't synthesize anything that was lost by nonlinarities, like relu.
        # later we should add some noise like negative values to pre relu forexample
        # we cheat here by using ind input, it can be randomly assigned or use a predefined pattern
        # pdb.set_trace()
        x = self.Tfc2(output)
        # torch.Size([64, 128])
        x = self.Tfc1(x)
        # torch.Size([64, 9216])
        #unflatten
        x = x.reshape([64, 64, 7, 7])
        if 1:
            x = self.unpool(x, indices=ind2, output_size=torch.Size([64, 64, 14, 14]))
        else:
            x = self.upsample(x)
        # torch.Size([64, 64, 24, 24])
        x = self.Tconv2(x)
        # torch.Size([64, 32, 26, 26])
        x = self.unpool(x, indices=ind1, output_size=torch.Size([64, 32, 28, 28]))
        x = self.Tconv1(x)
        # torch.Size([64, 1, 28, 28])
        return x

class Net(nn.Module):
    def __init__(self, n_input_ch=3):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(n_input_ch, 32, 3, 1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1, bias=False)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(int(64*7*7), 128, bias=False)
        self.fc2 = nn.Linear(128, 10, bias=False)
        self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)

    def forward(self, x):
        NO_DROP = True
        # pdb.set_trace()
        # x.shape : torch.Size([64, 1, 28, 28])
        x = self.conv1(x)
        # x.shape : torch.Size([64, 32, 28, 28])
        x = F.relu(x)
        
        x, ind1 = self.pool(x)
        
        x = self.conv2(x)
        # x.shape : torch.Size([64, 64, 14, 14])
        x = F.relu(x)
        # x, ind = F.max_pool2d(x, 2)
        x, ind2 = self.pool(x)
        # x.shape : torch.Size([64, 64, 7, 7])
        # ind.indeces : torch.Size([64, 64, 7, 7])
        # 0..22:2, 48..70:2 etc
        if not NO_DROP:
            x = self.dropout1(x)
        x = torch.flatten(x, 1)
        # x.shape torch.Size([64, 9216])
        x = self.fc1(x)
        # x.shape torch.Size([64, 128])
        x = F.relu(x)
        if not NO_DROP:
            x = self.dropout2(x)
        x = self.fc2(x)
        # x.shape torch.Size([64, 10])
        features = x
        output = F.log_softmax(x, dim=1)
        return output, (features, ind1, ind2)


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output, _ = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('1_Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, _ = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\n1_Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=1, metavar='N',
                    help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                    help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                    help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=True,
                    help='disables CUDA training')
parser.add_argument('--dry-run', action='store_true', default=False,
                    help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
                    help='For Saving the current Model')
parser.add_argument('--input_noise', action='store_true', default=False,
                    help='learn distillation model from noise and not inverted images')
parser.add_argument('--cifar', action='store_true', default=False,
                    help='Use CIFAR10 dataset, if not set then use MNIST')
parser.add_argument('--train_inv_model', action='store_true', default=False,
                    help='Train inv_model so that output of it is correctly classified with model')
args = parser.parse_args('')
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'batch_size': args.batch_size}
if use_cuda:
    kwargs.update({'num_workers': 1,
                    'pin_memory': True,
                    'shuffle': True},
                    )

CIFAR = args.cifar
if CIFAR:
    n_input_ch = 3
else:
    n_input_ch = 1

if not args.cifar:
    #MNIST
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                        transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                        transform=transform)
else:
    #CIFAR10
    transform = transforms.Compose(
        [transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
    dataset1 = datasets.CIFAR10('../data', train=True, download=True,
                                transform=transform)
    dataset2 = datasets.CIFAR10('../data', train=False,
                                transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

model = Net(n_input_ch=n_input_ch).to(device)

# torch.nn.init.normal_(model.conv1.weight) 
# torch.nn.init.normal_(model.conv2.weight) 
# torch.nn.init.normal_(model.fc1.weight) 
# torch.nn.init.normal_(model.fc2.weight) 


optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()

# if args.save_model:
#     torch.save(model.state_dict(), "mnist_cnn.pt")
inv_model = InvertNet(n_input_ch=n_input_ch).to(device)
# pdb.set_trace()

model.conv1.inp_size = [64, 1, 28, 28]
model.conv2.inp_size = [64, 32, 14, 14]

# inv_model.Tconv1.weight.data = torch.ones_like(model.conv1.weight.data) / model.conv1.weight.data.sum(dim=[2,3], keepdim=True)
# inv_model.Tconv2.weight.data = torch.ones_like(model.conv2.weight.data) / model.conv2.weight.data.sum(dim=[2,3], keepdim=True) 
inv_model.Tconv1.init_layer(model.conv1)
inv_model.Tconv2.init_layer(model.conv2)

# inv_model.Tfc1.weight.data = model.fc1.weight.data.T.contiguous().clone()
# inv_model.Tfc2.weight.data = model.fc2.weight.data.T.contiguous().clone()
inv_model.Tfc1.weight.data = torch.pinverse(model.fc1.weight.data.T).T
inv_model.Tfc2.weight.data = torch.pinverse(model.fc2.weight.data.T).T
print("--------------------------------------")
print("runnning one batch inversion with direct inversion")


TRAIN_ON_INVERTED = False
# NOISE_INPUT = True
NOISE_INPUT = args.input_noise
if TRAIN_ON_INVERTED:
    print("printing with target class labels for inverted images")
    print("will be tested on INVERTED images")
else:
    print("training with KD loss")
    print("will be tested on ORIGINAL images")

print("NOISE_INPUT: ", NOISE_INPUT)


print("**************************************")
print('training a inv_model to correctly excite the model')
if args.train_inv_model:
    # optimizer_inv_model = optim.Adam(inv_model.parameters(), lr=1e-4)
    optimizer_inv_model = optim.Adadelta(inv_model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer_inv_model, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        ###### training on inverted images, KD will work better
        inv_model.train()
        model.train()
        correct_forward_inverse= 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer_inv_model.zero_grad()
            model.zero_grad()

            if data.shape[0] != 64:
                print("hardcoded batch size 64 only")
                continue

            output_teacher, (features_teacher, ind) = model(data)
            #generate input
            inv_input = inv_model(features_teacher.detach(), ind)

            output, (features, _) = model(inv_input)
            # loss = F.nll_loss(output, target)
            loss = loss_fn_kd(features, features_teacher)

            loss.backward()
            optimizer_inv_model.step()

            correct_forward_inverse += (output.argmax(dim=1) == target).sum() / float(target.numel())

            if batch_idx % args.log_interval == 0:
                print('2a_Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tCorrect forward of inverted {:.1f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                            100. * batch_idx / len(train_loader), loss.item(), 100. * correct_forward_inverse / float(batch_idx+1) ))
        ### testing:
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    output, (features, ind1, ind2) = model(data)

    inv_input = inv_model(features, ind1, ind2)
    if NOISE_INPUT:
        # generate random input by perturbing all pixels
        idx = torch.randperm(data.nelement())
        inv_input = data.view(-1)[idx].view(data.size())
    # pdb.set_trace()
    # import numpy as np
    # data_display = np.concatenate((data.data.cpu().numpy(),  inv_input.data.cpu().numpy()), 3)
    data_display = torch.cat((data,  inv_input), 0)
    print("displaying original input-inverted input pairs")
    vutils.save_image(data_display,'orig_inverted_ones.png', normalize=True, scale_each=True, nrow=int(8))

    break

assert 1==0

print("--------------------------------------")
print("training calssifier on INVERTED images")


model_on_inv = Net(n_input_ch=n_input_ch).to(device)
optimizer = optim.Adadelta(model_on_inv.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
    ###### training on inverted images, KD will work better
    model_on_inv.train()
    model.train()
    correct_forward_inverse= 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        model.zero_grad()
        inv_model.zero_grad()
        model_on_inv.zero_grad()

        if data.shape[0] != 64:
            print("hardcoded batch size 64 only")
            continue

        output_teacher, (features_teacher, ind) = model(data)
        #generate input
        inv_input = inv_model(features_teacher, ind)

        if NOISE_INPUT:
            #generate random input by perturbing all pixels
            idx = torch.randperm(data.nelement())
            inv_input = data.view(-1)[idx].view(data.size())

        output, (features, _) = model_on_inv(inv_input)
        if TRAIN_ON_INVERTED:
            loss = F.nll_loss(output, target)
        else:
            output_teacher, (features_teacher, ind) = model(inv_input)
            # Correctly classified images: (output_teacher.argmax(dim=1) == target).sum() / float(target.numel())
            correct_forward_inverse += (output_teacher.argmax(dim=1) == target).sum() / float(target.numel())
            loss = loss_fn_kd(features, features_teacher)

        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('3_Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tCorrect forward of inverted {:.1f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.item(), 100. * correct_forward_inverse / float(batch_idx+1) ))
    ### testing:
    model_on_inv.eval()
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            if data.shape[0] != 64:
                print("hardcoded batch size 64 only")
                continue

            if TRAIN_ON_INVERTED:
                output_teacher, (features, ind) = model(data)
                inv_input = inv_model(features, ind)
                output, _ = model_on_inv(inv_input)
                test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            else:
                output_teacher, (features_teacher, _) = model(data)
                output, (features, _) = model_on_inv(data)
                test_loss += loss_fn_kd(features, features_teacher)

            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\n3_Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    # %98 accuracy



    scheduler.step()



1_Test set: Average loss: 0.0467, Accuracy: 9840/10000 (98%)

inp_size: [1, 30, 30]
old weight (32, 1, 3, 3)
==> output_size (32, 28, 28)
==> T (32, 784, 1, 900)
inp_size: [32, 16, 16]
old weight (64, 32, 3, 3)
==> output_size (64, 14, 14)
==> T (64, 196, 32, 256)
--------------------------------------
runnning one batch inversion with direct inversion
training with KD loss
will be tested on ORIGINAL images
NOISE_INPUT:  False
**************************************
training a inv_model to correctly excite the model
displaying original input-inverted input pairs


AssertionError: 