In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to library source are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from researchlib.single_import import *

In [3]:
# Build the two MNIST dataset objects. The positional flag presumably selects
# the training split when True -- confirm against researchlib's MNIST helper.
train_dataset, test_dataset = (MNIST(is_train) for is_train in (True, False))

In [4]:
# Image data as NumPy arrays.
train_x = train_dataset.data.numpy()
test_x = test_dataset.data.numpy()

# One-hot targets are computed from the label tensors (to_one_hot is given the
# tensor and 10, presumably the number of classes) and then moved to NumPy.
train_y2 = to_one_hot(train_dataset.targets, 10).numpy()
test_y2 = to_one_hot(test_dataset.targets, 10).numpy()

# Plain label arrays, converted to NumPy as well.
train_y = train_dataset.targets.numpy()
test_y = test_dataset.targets.numpy()

In [5]:
# Insert a singleton axis at position 1 of the image arrays (the model's first
# layer is nn.Conv2d(1, ...)) and pair each split with both label encodings:
# one-hot first, plain labels second.
train_data = FromNumpy(
    train_x[:, None],
    [train_y2, train_y],
    batch_size=256,
)
test_data = FromNumpy(
    test_x[:, None],
    [test_y2, test_y],
    batch_size=256,
)

In [6]:
class Auxiliary(nn.Module):
    """Pass-through layer that evaluates a side branch on its input.

    The wrapped callable ``f`` is applied to every input and its result is
    kept on ``self.store``; the input itself is returned untouched, so the
    surrounding network continues as if this layer were not there.
    """

    def __init__(self, f):
        """Wrap ``f``; ``store`` stays ``None`` until the first forward call."""
        super().__init__()
        self.store = None
        self.f = f

    def forward(self, x):
        """Run ``f`` on ``x``, remember the result, and hand ``x`` back as-is."""
        side_output = self.f(x)
        self.store = side_output
        return x

In [7]:
class Reg(nn.Module):
    """Wrap a layer and expose one tensor from it for a named regulariser.

    Each forward call applies the wrapped layer ``f`` and records a tensor on
    ``self.reg_store``. Which tensor is recorded depends on ``get``:

    * ``'weight'`` -- ``f.weight`` (so ``f`` must have a ``weight`` attribute),
    * ``'out'``    -- the output of ``f``,
    * ``'in'``     -- the input passed to ``forward``.

    Any other value of ``get`` leaves ``reg_store`` unchanged. ``group`` is
    kept on ``self.reg_group`` as the tag identifying this layer's group.
    """

    def __init__(self, f, group, get='weight'):
        """Store the wrapped layer, its group tag and the capture mode."""
        super().__init__()
        self.f = f
        self.get = get
        self.reg_store = None
        self.reg_group = group

    def forward(self, x):
        """Apply ``f`` to ``x``, update ``reg_store``, and return ``f``'s output."""
        out = self.f(x)
        if self.get == 'weight':
            self.reg_store = self.f.weight
        elif self.get == 'out':
            self.reg_store = out
        elif self.get == 'in':
            # Previously assigned to a misspelled attribute ("reg_stroe"), so
            # the input was never visible through reg_store.
            self.reg_store = x
        return out
    

In [8]:
class Orthogonal(nn.Module):
    """Penalty on the overlap between the rows of two 2-D weight matrices.

    The value is ``alpha`` times the mean absolute entry of ``w1 @ w2.T``.
    It is zero when every row of ``w1`` is orthogonal to every row of ``w2``.
    """

    def __init__(self, alpha=0.1):
        """Keep the scaling factor ``alpha`` applied to the penalty."""
        super().__init__()
        self.alpha = alpha

    def forward(self, w1, w2):
        """Return ``alpha * mean(|w1 @ w2.T|)`` as a scalar tensor."""
        gram = torch.matmul(w1, w2.t())
        mean_overlap = torch.mean(torch.abs(gram))
        return self.alpha * mean_overlap

In [9]:
# Convolutional trunk. With 28x28 inputs (assumed here -- confirm) the two
# 5x5 conv + 2x2 pool stages leave a 50 x 4 x 4 map, hence 4*4*50 below.
trunk = [
    nn.Conv2d(1, 20, 5, 1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(20, 50, 5, 1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    Flatten(),
]

# Side head ending in a sigmoid. Auxiliary keeps this head's output on its
# `store` attribute and forwards the flattened features unchanged.
aux_head = Auxiliary(builder([
    Reg(nn.Linear(4*4*50, 128), 'orth'),
    nn.ReLU(),
    Reg(nn.Linear(128, 10), 'orth2'),
    nn.Sigmoid(),
]))

# Main head ending in log-softmax. Its Reg layers use the same group names
# ('orth', 'orth2') as the corresponding layers of the side head.
main_head = [
    Reg(nn.Linear(4*4*50, 128), 'orth'),
    nn.ReLU(),
    Reg(nn.Linear(128, 10), 'orth2'),
    nn.LogSoftmax(1),
]

model = builder(trunk + [aux_head] + main_head)

In [10]:
# One Orthogonal penalty per Reg group name used in the model.
regularizers = {'orth': Orthogonal(), 'orth2': Orthogonal()}

# Two loss names are given, matching the two label arrays per batch
# (presumably 'bce' for the one-hot targets and 'nll' for the plain labels --
# confirm how Runner pairs losses with outputs and targets).
runner = Runner(
    model,
    train_data,
    test_data,
    'adam',
    ['bce', 'nll'],
    reg_fn=regularizers,
    fp16=False,
)

In [11]:
# First training run with cycle='sc'; the recorded output below shows three
# epochs being executed.
runner.fit(3, cycle='sc')

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

  Epochs    train_loss    train_acc      val_loss      val_acc    
    1*        0.6079        0.9048        0.0537        0.9860    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    2         0.0445        0.9894        0.0475        0.9876    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    3         0.0332        0.9925        0.0442        0.9883    



In [12]:
# Second training run with cycle='cycle'. Per the recorded output, the outer
# progress bar reports 21 total epochs for this call (not 6), and the run
# ended with a KeyboardInterrupt after epoch 18 had been logged.
runner.fit(6, cycle='cycle')

HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

  Epochs    train_loss    train_acc      val_loss      val_acc    
    1         0.0603        0.9854        0.0363        0.9904    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    2         0.0486        0.9877        0.0390        0.9900    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    3         0.0153        0.9962        0.0312        0.9920    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    4         0.0363        0.9908        0.0619        0.9879    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    5         0.0157        0.9958        0.0307        0.9910    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    6         0.0038        0.9992        0.0281        0.9922    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    7         0.0269        0.9931        0.0395        0.9889    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    8         0.0123        0.9965        0.0362        0.9913    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    9         0.0033        0.9991        0.0330        0.9929    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    10        0.0006        0.9999        0.0326        0.9937    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    11        0.0247        0.9935        0.0524        0.9899    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    12        0.0107        0.9968        0.0456        0.9910    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    13        0.0046        0.9988        0.0369        0.9926    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    14        0.0009        0.9998        0.0355        0.9929    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    15        0.0001        1.0000        0.0359        0.9931    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    16        0.0185        0.9958        0.0567        0.9892    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    17        0.0070        0.9981        0.0510        0.9918    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

    18        0.0052        0.9986        0.0485        0.9928    


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Rebind `model` to the runner's `model` attribute before inspecting it.
model = runner.model

In [None]:
import matplotlib.pyplot as plt

# Take the first sample of the first test batch and restore the batch axis.
data = next(iter(test_data))
x = data[0][0]
x = x[None, :, :, :]
print(data[2][0])

# Swap the two spatial axes, i.e. feed the model a transposed image.
x = x.permute(0, 1, 3, 2)

# The main head ends in nn.LogSoftmax, so the raw model output holds
# log-probabilities (all <= 0). Exponentiate to get the softmax probabilities
# that the plot titles, the 0..1 axis and the product with the sigmoid head
# all assume.
out1 = model(x.cuda()).exp()
# Read the side-head output only after the forward pass above has run.
out2 = get_aux_out(model)[0]
combined = out1 * out2
print(out1)
print(out2)
print(combined)

fig, arr = plt.subplots(1, 4, figsize=(14, 5))
arr[0].set_title('Input')
arr[0].imshow(x.cpu().numpy()[0][0], cmap='gray')

# The three score panels share the same axis setup; only title and data differ.
panels = zip(
    arr[1:],
    ('Softmax', 'Sigmoid', 'Softmax * Sigmoid'),
    (out1, out2, combined),
)
for ax, title, scores in panels:
    ax.set_title(title)
    ax.set_xticks(range(10))
    ax.set_yticks([i/10 for i in range(10)])
    ax.set_ylim(0, 1)
    ax.bar(range(10), scores.detach().cpu().numpy()[0])
plt.show()
