In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms

In [2]:
mnist_set = datasets.MNIST(root="./mnist_data", train=True, transform=transforms.ToTensor(), download=True)
mnist_test_set = datasets.MNIST(root="./mnist_data", train=False, transform=transforms.ToTensor(), download=True)

In [3]:
mnist_loader = torch.utils.data.DataLoader(mnist_set, 36)
mnist_test_loader = torch.utils.data.DataLoader(mnist_test_set, 36)

In [6]:
class Selection(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.c_in = in_channels
        self.counter = 0
        self.clf = nn.Linear(in_features, 10)
    
    def forward(self, x):
        if self.counter == -1:
            out = self.clf(x)

        elif self.counter + 10 <= self.c_in:
            out = x[:, self.counter: self.counter+10]
            self.counter += 10
        
        else:
            out = x[:, -10:]
            self.counter = -1
        
        return out

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        c_in = 28 * 28
        c_out= 512
        linears = []
        # clfs = []
        selectors = []
        for _ in range(5):
            linears += [nn.Linear(c_in, c_out)]
            # clfs += [nn.Linear(c_out, 10)]
            selectors += [Selection(c_out)]
            c_in, cout = c_out, c_out // 2
        
        self.clf = nn.Linear(c_out, 10)
        
        self.linears = nn.ModuleList(linears)
        self.selectors = nn.ModuleList(selectors)
        # self.clfs = nn.ModuleList(clfs)
        self.act = nn.LeakyReLU(0.2)
    
    def forward(self, x):
        features = [x]
        outputs = []
        for l, s in zip(self.linears, self.selectors):
            feature = l(features[-1])
            # perm = torch.randperm(feature.size(1))
            outputs += [s(feature)]
            features += [self.act(feature)]
        outputs += [self.clf(feature)]
        return outputs

net = MLP().cuda()
for name, p in net.named_parameters():
    print(name, p.size())
    torch.nn.init.constant_(p, 0.0)
    print(p.data.sum())
criterion = torch.nn.CrossEntropyLoss()
adam = torch.optim.Adam(net.parameters(), lr=3e-4)

clf.weight torch.Size([10, 512])
tensor(0.)
clf.bias torch.Size([10])
tensor(0.)
linears.0.weight torch.Size([512, 784])
tensor(0.)
linears.0.bias torch.Size([512])
tensor(0.)
linears.1.weight torch.Size([512, 512])
tensor(0.)
linears.1.bias torch.Size([512])
tensor(0.)
linears.2.weight torch.Size([512, 512])
tensor(0.)
linears.2.bias torch.Size([512])
tensor(0.)
linears.3.weight torch.Size([512, 512])
tensor(0.)
linears.3.bias torch.Size([512])
tensor(0.)
linears.4.weight torch.Size([512, 512])
tensor(0.)
linears.4.bias torch.Size([512])
tensor(0.)


In [7]:
def train_one_epoch():
    for x, y in mnist_loader:
        x, y = x.cuda(), y.cuda()
        adam.zero_grad()
        x = torch.flatten(x, start_dim=1)
        y_hat_list = net(x)
        loss = 0.0

        for y_hat in y_hat_list:
            loss += criterion(y_hat, y)
    
        loss.backward()
        adam.step()
    print("loss: {:.4f}".format(loss.item()))

for i in range(3):
    # print("training epoch {}".format(i))
    train_one_epoch()



loss: 1.0552
loss: 0.6074
loss: 0.5476


In [8]:
def acc(logit, target):
    predication = torch.argmax(logit, dim=1)
    return (predication == target).to(torch.float)

In [9]:
res = None
for x, y in mnist_test_loader:
    net.eval()
    with torch.no_grad():
        x = torch.flatten(x, start_dim=1)
        if res is None:
            res = acc(net(x)[-1], y)
        else:
            res = torch.cat([res, acc(net(x)[-1], y)])
    
print(res.mean())

    

0.9293279611378265


In [13]:
for name, p in net.named_parameters():
    print(name, p.data.sum(), p.data.mean(), (p.data < 1e-4).to(torch.float).mean())

clf.weight tensor(-4.2497) tensor(-0.0008) tensor(0.6049)
clf.bias tensor(0.0217) tensor(0.0022) tensor(0.4000)
linears.0.weight tensor(-2496.5786) tensor(-0.0062) tensor(0.6172)
linears.0.bias tensor(-6.8054) tensor(-0.0133) tensor(0.5820)
linears.1.weight tensor(-269.0327) tensor(-0.0010) tensor(0.5506)
linears.1.bias tensor(1.9512) tensor(0.0038) tensor(0.5059)
linears.2.weight tensor(-376.4302) tensor(-0.0014) tensor(0.5540)
linears.2.bias tensor(7.0401) tensor(0.0138) tensor(0.3125)
linears.3.weight tensor(-1322.2778) tensor(-0.0050) tensor(0.7080)
linears.3.bias tensor(-3.1884) tensor(-0.0062) tensor(0.7402)
linears.4.weight tensor(-367.5648) tensor(-0.0014) tensor(0.6515)
linears.4.bias tensor(3.2873) tensor(0.0064) tensor(0.3398)
