In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from researchlib.single_import import *

In [3]:
# Build both MNIST splits; the boolean presumably selects the training split
# (True) vs. the test split (False) -- confirm against researchlib's MNIST.
train_dataset, test_dataset = MNIST(True), MNIST(False)

In [4]:
# Raw image arrays for each split.
train_x = train_dataset.data.numpy()
test_x = test_dataset.data.numpy()

# Two target encodings per split: one-hot over the 10 digit classes, and the
# integer labels. Presumably the one-hot targets pair with the 'bce' loss and
# the integer labels with 'nll' in the Runner cell below -- verify.
train_y2 = to_one_hot(train_dataset.targets, 10).numpy()
test_y2 = to_one_hot(test_dataset.targets, 10).numpy()
train_y = train_dataset.targets.numpy()
test_y = test_dataset.targets.numpy()

In [5]:
# Insert a channel axis after the sample axis (N, H, W) -> (N, 1, H, W) so the
# images match nn.Conv2d(1, ...); each loader yields both target encodings.
train_data = FromNumpy(train_x[:, None], [train_y2, train_y], batch_size=256)
test_data = FromNumpy(test_x[:, None], [test_y2, test_y], batch_size=256)

In [6]:
class Auxiliary(nn.Module):
    """Side-branch wrapper around a module ``f``.

    ``forward`` evaluates ``f`` on its input and keeps the result in
    ``self.store``, then returns the input itself unchanged, so the main
    computation continues as if this layer were an identity.
    """

    def __init__(self, f):
        super().__init__()
        self.f = f
        # Output of ``f`` from the most recent forward call; None before any call.
        self.store = None

    def forward(self, x):
        branch_out = self.f(x)
        self.store = branch_out
        # Pass-through: downstream layers see the original input.
        return x

In [7]:
class Reg(nn.Module):
    """Wrap a layer ``f`` and expose a tensor for a group-wise regularizer.

    After each forward pass ``self.reg_store`` holds the tensor selected by
    ``get`` and ``self.reg_group`` carries the group name. Presumably the
    Runner collects ``reg_store`` tensors by group and passes them to the
    matching ``reg_fn`` entry -- verify against researchlib.

    Args:
        f: the wrapped module. For ``get='weight'`` it must expose ``.weight``.
        group: regularizer group name.
        get: which tensor to store -- ``'weight'`` (``f.weight``), ``'out'``
            (the output of ``f``) or ``'in'`` (the input to ``f``).

    Raises:
        ValueError: if ``get`` is not one of the accepted values.
    """

    _SOURCES = ('weight', 'out', 'in')

    def __init__(self, f, group, get='weight'):
        super().__init__()
        # Fail fast: an unknown source used to be accepted silently and
        # reg_store would then never be populated.
        if get not in self._SOURCES:
            raise ValueError(
                f"get must be one of {self._SOURCES}, got {get!r}"
            )
        self.f = f
        self.get = get
        # Tensor captured on the last forward pass; None until the first call.
        self.reg_store = None
        self.reg_group = group

    def forward(self, x):
        out = self.f(x)
        if self.get == 'weight':
            self.reg_store = self.f.weight
        elif self.get == 'out':
            self.reg_store = out
        elif self.get == 'in':
            # Previously written to a misspelled attribute (reg_stroe), which
            # left reg_store as None in 'in' mode.
            self.reg_store = x
        return out
    

In [8]:
class Orthogonal(nn.Module):
    """Soft orthogonality penalty between two 2-D weight matrices.

    Returns ``alpha * mean(|w1 @ w2^T|)``: the mean absolute inner product
    between every row of ``w1`` and every row of ``w2``, scaled by ``alpha``.
    The penalty is zero when each row of ``w1`` is orthogonal to each row of ``w2``.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        # Weight of the penalty relative to the task loss.
        self.alpha = alpha

    def forward(self, w1, w2):
        # Row-by-row inner products; .t() requires 2-D inputs.
        cross = w1 @ w2.t()
        return torch.abs(cross).mean() * self.alpha

In [9]:
def _classifier_head(last_activation):
    """Return the layer list for one classifier head over the flattened features.

    Both heads share this layout and differ only in the final activation.
    The 4*4*50 input width matches a 28x28 image through the trunk below:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, with 50 channels.
    """
    return [
        # Each group name ('orth', 'orth2') tags one layer in each head;
        # presumably Runner feeds each pair to reg_fn[group] -- verify.
        Reg(nn.Linear(4 * 4 * 50, 128), 'orth'),
        nn.ReLU(),
        Reg(nn.Linear(128, 10), 'orth2'),
        last_activation,
    ]


# Shared conv trunk, then an Auxiliary sigmoid head (its output is stored on
# the module, not passed on), then the main log-softmax head.
model = builder([
    nn.Conv2d(1, 20, 5, 1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(20, 50, 5, 1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    Flatten(),
    Auxiliary(builder(_classifier_head(nn.Sigmoid()))),
    *_classifier_head(nn.LogSoftmax(1)),
])

In [10]:
# Two losses for the two heads: presumably 'bce' pairs with the sigmoid
# auxiliary head / one-hot targets and 'nll' with the log-softmax head /
# integer labels -- confirm against researchlib's Runner.
runner = Runner(
    model,
    train_data,
    test_data,
    'adam',
    ['bce', 'nll'],
    # One orthogonality penalty per Reg group name used in the model.
    reg_fn={
        'orth': Orthogonal(),
        'orth2': Orthogonal(),
    },
    fp16=False,
)

KeyboardInterrupt: 

In [None]:
# First training phase: 3 epochs with the 'sc' cycle schedule
# (schedule semantics defined by researchlib -- confirm there).
runner.fit(
    3,
    cycle='sc',
)

In [None]:
# Second training phase: 6 more epochs with the 'cycle' schedule
# (schedule semantics defined by researchlib -- confirm there).
runner.fit(
    6,
    cycle='cycle',
)

In [None]:
# Rebind `model` to the runner's copy (presumably the trained one after the
# fit cells) so the inspection cell below uses it.
model = runner.model

In [None]:
import matplotlib.pyplot as plt

# Take the first image of the first test batch, keeping a batch axis of size 1.
data = next(iter(test_data))
x = data[0][0]
x = x[None, :, :, :]
# NOTE(review): assumes batches are (images, one-hot targets, integer labels),
# matching the FromNumpy call above, so data[2][0] is the true label.
print(data[2][0])

# Swap the two spatial axes (transposes the digit) -- presumably to probe how
# both heads respond to an input unlike the training images.
x = x.permute(0, 1, 3, 2)

# Follow the model's device instead of hard-coding .cuda(); inference only,
# so skip autograd bookkeeping.
device = next(model.parameters()).device
with torch.no_grad():
    out1 = model(x.to(device))
    out2 = get_aux_out(model)[0]

# The main head ends in nn.LogSoftmax, so out1 holds log-probabilities (<= 0).
# Exponentiate before plotting on a [0, 1] axis or combining with the sigmoid
# scores; the original plotted/multiplied the raw log-probabilities.
probs = out1.exp()
combined = probs * out2
print(probs)
print(out2)
print(combined)

n_classes = probs.shape[1]
fig, arr = plt.subplots(1, 4, figsize=(14, 5))
arr[0].set_title('Input')
arr[0].imshow(x[0, 0].cpu().numpy(), cmap='gray')
panels = [
    ('Softmax', probs),
    ('Sigmoid', out2),
    ('Softmax * Sigmoid', combined),
]
# Same bar-chart layout for every score panel.
for ax, (title, values) in zip(arr[1:], panels):
    ax.set_title(title)
    ax.set_xticks(range(n_classes))
    ax.set_yticks([i / 10 for i in range(11)])
    ax.set_ylim(0, 1)
    ax.bar(range(n_classes), values.detach().cpu().numpy()[0])
plt.show()
