# NiN
## 主要贡献：
LeNet、AlexNet和VGG都有一个共同的设计模式：通过一系列的卷积层与汇聚层来提取空间结构特征；然后通过全连接层对特征的表征进行处理。 AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。

但是卷积层简单的展平再后接全连接层可能会完全放弃表征的空间结构，损失空间信息。NiN提供了一个非常简单的解决方案：在每个像素的通道上分别使用多层感知机 ($ 1 \times 1$ 卷积)，相当于对前一层的所有feature map（输入通道）的线性组合（所以可以看作是全连接），再加上ReLU提供非线性变换。


NiN和AlexNet之间的一个显著区别是**NiN完全取消了全连接层**。 相反，NiN使用一个NiN块，其**输出通道数等于标签类别的数量。最后放一个全局平均汇聚层（global average pooling layer）**，每个通道经池化后成为一个特征。
NiN设计的一个优点是，它显著减少了模型所需参数的数量。然而，在实践中，这种设计有时会增加训练模型的时间。

## $1 \times 1$ 卷积的作用
1. 取代全连接层：不损失空间信息，同时显著减少参数量
2. 在卷积层中起到调整通道数的作用。

![](./img/nin.svg#pic_center)

In [4]:
import torch
from torch import nn
from torchsummary import summary

def nin_blocks(in_channels,out_channels,kernel_size,strides,padding):
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),
        nn.ReLU(),
        nn.Conv2d(out_channels,out_channels,kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels,out_channels,kernel_size=1),
        nn.ReLU(),    
    )

class NiN(nn.Module):
    def __init__(self, in_channels = 1, num_classes = 10):
        super(NiN,self).__init__()
        self.backbone = nn.Sequential(
            nin_blocks(in_channels,96,kernel_size = 11,strides = 4,padding = 0),
            nn.MaxPool2d(3,stride = 2),
            nin_blocks(96, 256, kernel_size = 5,strides = 1,padding = 2),
            nn.MaxPool2d(3,stride = 2),
            nin_blocks(256, 384, kernel_size = 3,strides = 1,padding = 1),
            nn.MaxPool2d(3,stride = 2),
            nn.Dropout(p = 0.5),
            nin_blocks(384,num_classes,kernel_size = 3, strides = 1, padding = 1),
            # 实现全局平均池化，参数为output_size,tuple类型
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten()
        )
    
    def forward(self,x):
        x = self.backbone(x)
        return x



# net = NiN(3,10)
net = NiN()
summary(net,(1,224,224),device="cpu")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 54, 54]          11,712
              ReLU-2           [-1, 96, 54, 54]               0
            Conv2d-3           [-1, 96, 54, 54]           9,312
              ReLU-4           [-1, 96, 54, 54]               0
            Conv2d-5           [-1, 96, 54, 54]           9,312
              ReLU-6           [-1, 96, 54, 54]               0
         MaxPool2d-7           [-1, 96, 26, 26]               0
            Conv2d-8          [-1, 256, 26, 26]         614,656
              ReLU-9          [-1, 256, 26, 26]               0
           Conv2d-10          [-1, 256, 26, 26]          65,792
             ReLU-11          [-1, 256, 26, 26]               0
           Conv2d-12          [-1, 256, 26, 26]          65,792
             ReLU-13          [-1, 256, 26, 26]               0
        MaxPool2d-14          [-1, 256,

In [5]:

X = torch.rand(size=(1, 1, 224, 224))
for layer in list(net.children())[0]:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])


In [8]:
from torch.utils import data
import torchvision
from torchvision import transforms

# 超参数
batch_size = 128
num_epochs = 10
lr =0.001

trans = transforms.Compose({
    transforms.Resize(224),
    transforms.ToTensor()
})

mnist_train = torchvision.datasets.FashionMNIST('../../../DataSets/',train = True ,transform= trans,download=False)
mnist_test = torchvision.datasets.FashionMNIST('../../../DataSets/',train= False, transform= trans, download=False)

train_loader = data.DataLoader(mnist_train,batch_size,shuffle= True)
test_loader = data.DataLoader(mnist_test,shuffle= False,num_workers=4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
optimzer = torch.optim.SGD(net.parameters(),lr = lr)
loss = nn.CrossEntropyLoss()

# 权重初始化
def init_weight(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight) 
net.apply(init_weight)



def test(net,test_loader,loss,device):
    net.eval()
    num_correct = 0
    test_loss = 0
    with torch.no_grad():
        for X,y in test_loader:
            X,y = X.to(device),y.to(device)
            y_hat = net(X)
            l = loss(y_hat,y)
            test_loss += l.item()
            pred = torch.argmax(y_hat,dim=1)
            num_correct += torch.eq(pred,y).sum().item()
    return test_loss/len(test_loader.dataset), num_correct/len(test_loader.dataset)

def train(num_epochs,net,train_loader,loss,optimzer,device):

    print('Training start:')
    for epoch in range(num_epochs):
        net.train()
        num_correct = 0
        train_loss = 0
        for batch_idx, (X,y) in enumerate(train_loader):
            optimzer.zero_grad()
            X,y = X.to(device),y.to(device)
            y_hat = net(X)
            l = loss(y_hat,y)
            l.backward()
            optimzer.step()
            pred = torch.argmax(y_hat,dim = 1)
            num_correct += torch.eq(pred,y).sum().item()
            train_loss += l.item()
            # print('Batch {}: train loss {:.4f},lr {}'.format(batch_idx,l.item(),optimzer.param_groups[0]['lr']))
        train_loss = train_loss / (len(train_loader.dataset)/batch_size)
        train_acc = num_correct / len(train_loader.dataset)
        test_loss,test_acc = test(net,test_loader,loss,device)
        print('Epoch {}: train accuracy = {:.4f},train loss = {:.4f}; test accuracy = {:.4f}, test loss = {:.4f}'.format(epoch,train_acc,train_loss,test_acc,test_loss))

train(num_epochs, net,train_loader,loss,optimzer,device)

Training start:
Epoch 0: train accuracy = 0.1000,train loss = 2.3057; test accuracy = 0.1000, test loss = 2.3043
Epoch 1: train accuracy = 0.1000,train loss = 2.3054; test accuracy = 0.1000, test loss = 2.3041
Epoch 2: train accuracy = 0.1000,train loss = 2.3052; test accuracy = 0.1000, test loss = 2.3039
Epoch 3: train accuracy = 0.1000,train loss = 2.3050; test accuracy = 0.1000, test loss = 2.3037
Epoch 4: train accuracy = 0.1000,train loss = 2.3049; test accuracy = 0.1000, test loss = 2.3036
Epoch 5: train accuracy = 0.1000,train loss = 2.3047; test accuracy = 0.1000, test loss = 2.3034
Epoch 6: train accuracy = 0.1000,train loss = 2.3046; test accuracy = 0.1000, test loss = 2.3033
Epoch 7: train accuracy = 0.1000,train loss = 2.3045; test accuracy = 0.1000, test loss = 2.3032
Epoch 8: train accuracy = 0.1000,train loss = 2.3044; test accuracy = 0.1000, test loss = 2.3031
Epoch 9: train accuracy = 0.1000,train loss = 2.3044; test accuracy = 0.1000, test loss = 2.3031
