In [16]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
from torch.utils import data
from tqdm import tqdm
from IPython import display
import matplotlib.pyplot as plt
import utils

In [None]:
class Inception(nn.Module):
    """GoogLeNet-style Inception block: four parallel paths concatenated on dim 1.

    Args:
        in_channels: number of input channels.
        c1: output channels of path 1 (1x1 conv).
        c2: pair ``(reduce, out)`` for path 2 (1x1 conv -> 3x3 conv).
        c3: pair ``(reduce, out)`` for path 3 (1x1 conv -> 5x5 conv).
        c4: output channels of path 4 (3x3 max-pool -> 1x1 conv).
        **kwargs: forwarded unchanged to ``nn.Module.__init__``.

    Every path uses stride 1 with "same" padding, so the output keeps the
    input's spatial size and has ``c1 + c2[1] + c3[1] + c4`` channels.
    Every path ends in a ReLU, so the output is non-negative.
    """

    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):
        super().__init__(**kwargs)
        # Path 1: 1x1 convolution. Its ReLU is applied in forward() so this stays a
        # bare Conv2d and the state_dict keys remain conv1.weight / conv1.bias.
        self.conv1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # Path 2: 1x1 channel reduction, then a 3x3 convolution.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels, c2[0], kernel_size=1), nn.ReLU(),
            nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1), nn.ReLU()
        )
        # Path 3: 1x1 channel reduction, then a 5x5 convolution.
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels, c3[0], kernel_size=1), nn.ReLU(),
            nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2), nn.ReLU()
        )
        # Path 4: 3x3 max-pool, then a 1x1 projection.
        # NOTE(review): the ReLU directly after the pool is not part of the usual
        # GoogLeNet design. It is the identity when the input is already
        # non-negative (e.g. the output of a previous ReLU). It is kept so that
        # Sequential indices (conv4.2.*) and existing checkpoints stay valid.
        self.conv4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels, c4, kernel_size=1), nn.ReLU()
        )

    def forward(self, x):
        """Run all four paths on ``x`` and concatenate them along dim 1 (channels)."""
        return torch.cat((
            F.relu(self.conv1(x)),  # path 1 previously skipped its activation
            self.conv2(x),
            self.conv3(x),
            self.conv4(x),
        ), dim=1)


In [None]:
def _evaluate_accuracy(net, data_iter, device):
    """Return the fraction of correctly classified samples in ``data_iter`` (NaN if empty)."""
    net.eval()  # inference mode for layers such as dropout / batch-norm
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            correct += (net(X).argmax(dim=1) == y).sum().item()
            total += y.shape[0]
    return correct / total if total else float('nan')


def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """Train ``net`` with SGD and cross-entropy, printing per-epoch metrics.

    Every ``nn.Linear`` / ``nn.Conv2d`` weight is re-initialised with
    Xavier-uniform first, so each call trains from scratch.

    Args:
        net: model to train; moved to ``device``.
        train_iter: iterable of ``(X, y)`` mini-batches used for optimisation.
        test_iter: iterable of ``(X, y)`` mini-batches used to report accuracy
            after each epoch.
        num_epochs: number of passes over ``train_iter``.
        lr: SGD learning rate.
        device: device to train on (``torch.device`` or string such as ``'cpu'``).
    """
    def init_weight(m):
        if type(m) in (nn.Linear, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weight)
    net.to(device)
    print('training on', device)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    for epoch in tqdm(range(num_epochs)):
        net.train()
        loss_sum, n_samples = 0.0, 0
        for X, y in train_iter:
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            l = loss(net(X), y)
            l.backward()
            optimizer.step()
            # CrossEntropyLoss averages over the batch by default. Re-weight by
            # batch size so the epoch value is a true per-sample mean, even when
            # the last batch is short.
            loss_sum += l.item() * y.shape[0]
            n_samples += y.shape[0]
        # NaN instead of a NameError/ZeroDivisionError when train_iter is empty.
        train_loss = loss_sum / n_samples if n_samples else float('nan')
        test_acc = _evaluate_accuracy(net, test_iter, device)
        print(f'epoch {epoch + 1}: train loss {train_loss:.4f}, test acc {test_acc:.3f}')