<a href="https://colab.research.google.com/github/LeeJZh/A_Network_That_Trains_At_Zero_Init/blob/master/network_trains_at_zero_init.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms

In [2]:
# Both splits are stored under the same root and use the same ToTensor
# transform; the only difference between them is the `train` flag.
_mnist_common = dict(root="./mnist_data", transform=transforms.ToTensor(), download=True)
mnist_set = datasets.MNIST(train=True, **_mnist_common)
mnist_test_set = datasets.MNIST(train=False, **_mnist_common)

In [3]:
# Mini-batches of 36 samples for both splits. No shuffle flag is passed, so
# batches are drawn in the loader's default order.
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_set, batch_size=36)
mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test_set, batch_size=36)

In [4]:
class Selection(nn.Module):
    """Parameter-free module that returns a rotating column window of its input.

    Each call to ``forward`` returns ``width`` consecutive columns (dim 1) of
    ``x`` and then advances an internal offset by ``width``. Once the next
    window would run past ``in_channels``, the call returns the *last*
    ``width`` columns instead and the offset is reset to 0.

    The offset (``self.counter``) is plain Python state: it advances on every
    call, in training and evaluation mode alike, and is not part of the
    module's parameters or buffers.

    Args:
        in_channels: Number of columns (size of dim 1) of the expected input.
        width: Number of columns returned per call. Defaults to 10, the value
            that was previously hard-coded.

    Raises:
        ValueError: If ``width`` is smaller than 1.
    """

    def __init__(self, in_channels, width=10):
        super().__init__()
        if width < 1:
            raise ValueError("width must be a positive integer, got {}".format(width))
        self.c_in = in_channels
        self.width = width
        self.counter = 0

    def forward(self, x):
        """Return the current window of ``x`` and advance (or reset) the offset.

        Args:
            x: Tensor indexed as ``x[:, a:b]``, i.e. with at least two dims.

        Returns:
            A slice of ``x`` along dim 1 (a view, not a copy).
        """
        stop = self.counter + self.width
        if stop <= self.c_in:
            # Regular case: the window still fits inside the feature vector.
            out = x[:, self.counter:stop]
            self.counter = stop
        else:
            # Window would overrun: fall back to the trailing columns and
            # start over from the beginning on the next call.
            out = x[:, -self.width:]
            self.counter = 0
        return out

class MLP(nn.Module):
    """Five-layer fully connected network with one auxiliary output per layer.

    Structure built in ``__init__``:
      * ``linears``   - five ``nn.Linear`` layers: 784 -> 512, then 512 -> 512.
      * ``selectors`` - one ``Selection(512)`` per linear layer; each picks a
        window of that layer's pre-activation as an auxiliary output.
      * ``clf``       - final ``nn.Linear(512, 10)`` classifier.
      * ``act``       - ``LeakyReLU(0.2)`` applied between linear layers.

    ``forward`` returns a list of six tensors: the five auxiliary outputs in
    layer order, followed by the classifier output.
    """

    def __init__(self):
        super().__init__()
        c_in = 28 * 28
        c_out = 512
        linears = []
        selectors = []
        for _ in range(5):
            linears.append(nn.Linear(c_in, c_out))
            selectors.append(Selection(c_out))
            # The hidden width is constant: every layer after the first maps
            # c_out -> c_out.
            c_in = c_out

        # `clf` is registered before `linears` on purpose: this keeps the
        # order of named_parameters() / state_dict() entries unchanged.
        self.clf = nn.Linear(c_out, 10)
        self.linears = nn.ModuleList(linears)
        self.selectors = nn.ModuleList(selectors)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Run the network and collect auxiliary and final outputs.

        Args:
            x: Input passed to the first linear layer (in_features = 784).

        Returns:
            list: ``[aux_1, ..., aux_5, clf_out]``.
        """
        hidden = x
        outputs = []
        for linear, selector in zip(self.linears, self.selectors):
            pre_activation = linear(hidden)
            # Auxiliary outputs are taken before the non-linearity.
            outputs.append(selector(pre_activation))
            hidden = self.act(pre_activation)
        # The classifier consumes the last layer's pre-activation (not the
        # activated value), matching the behaviour the model was trained with.
        outputs.append(self.clf(pre_activation))
        return outputs

# Place the model on the GPU when one is available and fall back to the CPU
# otherwise, so this cell does not raise on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MLP().to(device)

# Overwrite every parameter (weights and biases) with exactly zero, printing
# each parameter's name/shape and its sum afterwards as a sanity check.
for name, p in net.named_parameters():
    print(name, p.size())
    torch.nn.init.constant_(p, 0.0)
    print(p.data.sum())

criterion = torch.nn.CrossEntropyLoss()
adam = torch.optim.Adam(net.parameters(), lr=3e-4)

clf.weight torch.Size([10, 512])
tensor(0., device='cuda:0')
clf.bias torch.Size([10])
tensor(0., device='cuda:0')
linears.0.weight torch.Size([512, 784])
tensor(0., device='cuda:0')
linears.0.bias torch.Size([512])
tensor(0., device='cuda:0')
linears.1.weight torch.Size([512, 512])
tensor(0., device='cuda:0')
linears.1.bias torch.Size([512])
tensor(0., device='cuda:0')
linears.2.weight torch.Size([512, 512])
tensor(0., device='cuda:0')
linears.2.bias torch.Size([512])
tensor(0., device='cuda:0')
linears.3.weight torch.Size([512, 512])
tensor(0., device='cuda:0')
linears.3.bias torch.Size([512])
tensor(0., device='cuda:0')
linears.4.weight torch.Size([512, 512])
tensor(0., device='cuda:0')
linears.4.bias torch.Size([512])
tensor(0., device='cuda:0')


In [5]:
def train_one_epoch(alpha=0.1):
    """Run one optimisation pass over ``mnist_loader``.

    For each mini-batch the loss is the cross-entropy of the final classifier
    output plus ``alpha`` times the cross-entropy of every auxiliary output
    returned by ``net``. After the pass, the loss of the *last* mini-batch is
    printed (it is not an average over the epoch).

    Args:
        alpha: Weight applied to each auxiliary loss term. Defaults to 0.1.

    Relies on the notebook globals ``net``, ``adam``, ``criterion`` and
    ``mnist_loader``.
    """
    # Follow whatever device the model lives on instead of assuming CUDA.
    device = next(net.parameters()).device

    loss = None
    for x, y in mnist_loader:
        x, y = x.to(device), y.to(device)
        adam.zero_grad()
        x = torch.flatten(x, start_dim=1)

        *aux_logits, final_logits = net(x)
        # Auxiliary terms are accumulated first, then the main term is added.
        loss = sum(criterion(logits, y) * alpha for logits in aux_logits)
        loss = loss + criterion(final_logits, y)

        loss.backward()
        adam.step()

    # Nothing to report if the loader produced no batches.
    if loss is not None:
        print("loss: {:.4f}".format(loss.item()))

# Thirty passes over the training data; every call prints one "loss: ..." line.
for epoch in range(30):
    train_one_epoch()



loss: 0.3279
loss: 0.2293
loss: 0.1672
loss: 0.1580
loss: 0.1275
loss: 0.1695
loss: 0.1167
loss: 0.0977
loss: 0.0810
loss: 0.0973
loss: 0.0779
loss: 0.0915
loss: 0.1180
loss: 0.0866
loss: 0.1281
loss: 0.0878
loss: 0.4143
loss: 0.0718
loss: 0.0800
loss: 0.0707
loss: 0.1012
loss: 0.0523
loss: 0.0526
loss: 0.0519
loss: 0.0503
loss: 0.0589
loss: 0.0481
loss: 0.0481
loss: 0.0661
loss: 0.0580


In [6]:
def acc(logit, target):
    """Per-sample correctness of the arg-max prediction.

    Args:
        logit: Score tensor; the predicted class is the arg-max over dim 1.
        target: Tensor of reference labels compared element-wise with the
            predictions.

    Returns:
        Float tensor holding 1.0 where prediction equals target, else 0.0.
    """
    return torch.argmax(logit, dim=1).eq(target).float()

In [7]:
# Test-set accuracy of the final classifier output (last element of net(x)).
# Per-batch results are collected in a list and concatenated once at the end,
# instead of re-concatenating the growing tensor on every batch.
device = next(net.parameters()).device  # follow the model's device
net.eval()
batch_hits = []
with torch.no_grad():
    for x, y in mnist_test_loader:
        x, y = x.to(device), y.to(device)
        x = torch.flatten(x, start_dim=1)
        batch_hits.append(acc(net(x)[-1], y))

# `res` holds one 0/1 entry per test sample; its mean is the accuracy.
res = torch.cat(batch_hits, dim=0)

print(res.mean().item())

    

0.9728999733924866


In [8]:
# Per-parameter summary after training: sum, mean, and the fraction of
# entries that are smaller than 1e-4.
# NOTE(review): the comparison has no abs(), so every negative entry is
# counted as well. If the intent is "fraction of near-zero weights", this
# should probably be `values.abs() < 1e-4` - confirm before relying on it.
for name, p in net.named_parameters():
    values = p.data
    frac_below = (values < 1e-4).float().mean()
    print(name, values.sum(), values.mean(), frac_below)

clf.weight tensor(-10.9191, device='cuda:0') tensor(-0.0021, device='cuda:0') tensor(0.7193, device='cuda:0')
clf.bias tensor(-0.2071, device='cuda:0') tensor(-0.0207, device='cuda:0') tensor(0.5000, device='cuda:0')
linears.0.weight tensor(-7966.9590, device='cuda:0') tensor(-0.0198, device='cuda:0') tensor(0.6098, device='cuda:0')
linears.0.bias tensor(-2.9725, device='cuda:0') tensor(-0.0058, device='cuda:0') tensor(0.5352, device='cuda:0')
linears.1.weight tensor(-920.2904, device='cuda:0') tensor(-0.0035, device='cuda:0') tensor(0.4960, device='cuda:0')
linears.1.bias tensor(-23.1721, device='cuda:0') tensor(-0.0453, device='cuda:0') tensor(0.7402, device='cuda:0')
linears.2.weight tensor(-3057.7368, device='cuda:0') tensor(-0.0117, device='cuda:0') tensor(0.5878, device='cuda:0')
linears.2.bias tensor(-37.7117, device='cuda:0') tensor(-0.0737, device='cuda:0') tensor(0.8711, device='cuda:0')
linears.3.weight tensor(-658.0060, device='cuda:0') tensor(-0.0025, device='cuda:0') tens