In [1]:
import numpy as np
import torch 
from torch import nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision.datasets import CIFAR10
from datetime import datetime

In [2]:
def conv2_bn_relu(in_channel,out_channel,kernel,stride=1,padding=0):
    """Build a convolution + batch-normalization + ReLU stack.

    Args:
        in_channel: number of input channels of the convolution.
        out_channel: number of output channels (also the BatchNorm width).
        kernel: convolution kernel size.
        stride: convolution stride (default 1).
        padding: convolution zero-padding (default 0).

    Returns:
        An ``nn.Sequential`` holding, in order, ``Conv2d``, ``BatchNorm2d``
        (with ``eps=1e-3``) and ``ReLU``.
    """
    conv=nn.Conv2d(in_channel,out_channel,kernel,stride,padding)
    norm=nn.BatchNorm2d(out_channel,eps=1e-3)
    activation=nn.ReLU()
    return nn.Sequential(conv,norm,activation)

In [3]:
class inception(nn.Module):
    """Four-branch inception module whose branch outputs are concatenated on the channel axis.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel1_1: output channels of the 1x1 branch.
        out_channel2_1: channels after the 1x1 reduction in the 3x3 branch.
        out_channel2_3: output channels of the 3x3 branch.
        out_channel3_1: channels after the 1x1 reduction in the 5x5 branch.
        out_channel3_5: output channels of the 5x5 branch.
        out_channel4_1: output channels of the 1x1 conv after max pooling.

    The output therefore has
    ``out_channel1_1 + out_channel2_3 + out_channel3_5 + out_channel4_1`` channels.
    """
    def __init__(self,in_channel,out_channel1_1,out_channel2_1,out_channel2_3,out_channel3_1,out_channel3_5,out_channel4_1):
        super(inception,self).__init__()

        # Branch 1: a single 1x1 convolution.
        self.branch1x1=conv2_bn_relu(in_channel,out_channel1_1,1)

        # Branch 2: 1x1 reduction, then a 3x3 convolution (padding=1).
        reduce3x3=conv2_bn_relu(in_channel,out_channel2_1,1)
        conv3x3=conv2_bn_relu(out_channel2_1,out_channel2_3,3,padding=1)
        self.branch3x3=nn.Sequential(reduce3x3,conv3x3)

        # Branch 3: 1x1 reduction, then a 5x5 convolution (padding=2).
        reduce5x5=conv2_bn_relu(in_channel,out_channel3_1,1)
        conv5x5=conv2_bn_relu(out_channel3_1,out_channel3_5,5,padding=2)
        self.branch5x5=nn.Sequential(reduce5x5,conv5x5)

        # Branch 4: 3x3 max pooling (stride 1, padding 1), then a 1x1 convolution.
        pool=nn.MaxPool2d(3,stride=1,padding=1)
        project=conv2_bn_relu(in_channel,out_channel4_1,1)
        self.branchpool=nn.Sequential(pool,project)

    def forward(self,x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        branches=(self.branch1x1,self.branch3x3,self.branch5x5,self.branchpool)
        features=tuple(branch(x) for branch in branches)
        return torch.cat(features,dim=1)

In [4]:
class googlenet(nn.Module):
    """GoogLeNet-style classifier built from ``conv2_bn_relu`` stems and ``inception`` modules.

    Args:
        in_channel: number of channels of the input image tensor.
        num_class: output width of the final linear layer (number of classes).
        verbose: when True, ``forward`` prints the tensor shape after each block.

    The network ends in global average pooling, so the classifier always sees
    a 1024-wide vector as long as the input is large enough to survive the
    strided stem and the four max-pooling stages.
    """
    def __init__(self,in_channel,num_class,verbose=False):
        super(googlenet,self).__init__()
        self.verbose=verbose
        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        self.block1=nn.Sequential(
            conv2_bn_relu(in_channel,out_channel=64,kernel=7,stride=2,padding=3),
            nn.MaxPool2d(3,2)
        )
        self.block2=nn.Sequential(
            conv2_bn_relu(64,64,kernel=1),
            conv2_bn_relu(64,192,kernel=3,padding=1),
            nn.MaxPool2d(3,2)
        )
        # Each inception's output width is the sum of its four branch widths
        # and must equal the next module's in_channel.
        self.block3=nn.Sequential(
            inception(192,64,96,128,16,32,32),      # -> 256
            inception(256,128,128,192,32,96,64),    # -> 480
            nn.MaxPool2d(3,2)
        )
        self.block4=nn.Sequential(
            inception(480,192,96,208,16,48,64),     # -> 512
            inception(512,160,112,224,24,64,64),    # -> 512
            inception(512,128,128,256,24,64,64),    # -> 512
            inception(512,112,144,288,32,64,64),    # -> 528
            inception(528,256,160,320,32,128,128),  # -> 832
            nn.MaxPool2d(3,2)
        )
        self.block5=nn.Sequential(
            inception(832,256,160,320,32,128,128),  # -> 832
            # The 3x3-reduce width is 192 in the GoogLeNet table; it was
            # previously mistyped as 182. The output width is unaffected.
            inception(832,384,192,384,48,128,128),  # -> 1024
            # Global average pooling: same result as AvgPool2d(2) on a 2x2 map,
            # but also collapses larger maps to 1x1 so the classifier still fits.
            nn.AdaptiveAvgPool2d(1)
        )
        self.classify=nn.Linear(1024,num_class)

    def forward(self,x):
        """Return class scores of shape (batch, num_class) for the image batch ``x``."""
        stages=(self.block1,self.block2,self.block3,self.block4,self.block5)
        for index,stage in enumerate(stages,1):
            x=stage(x)
            if self.verbose:
                print('block {} output:{}'.format(index,x.shape))
        # Flatten (batch, 1024, 1, 1) to (batch, 1024) for the linear classifier.
        x=x.view(x.size(0),-1)
        x=self.classify(x)
        return x

# The model that is trained below: 3 input channels, 10 output classes.
net=googlenet(in_channel=3,num_class=10)

In [5]:
# Shape check: send one random 3x96x96 input through a verbose copy of the
# network so that the output shape of every block is printed.
testnet=googlenet(in_channel=3,num_class=10,verbose=True)
test_x=Variable(torch.rand(1,3,96,96))
test_y=testnet(test_x)
print('output:{}'.format(test_y.size()))

block 1 output:torch.Size([1, 64, 23, 23])
block 2 output:torch.Size([1, 192, 11, 11])
block 3 output:torch.Size([1, 480, 5, 5])
block 4 output:torch.Size([1, 832, 2, 2])
block 5 output:torch.Size([1, 1024, 1, 1])
output:torch.Size([1, 10])


In [6]:
def data_tf(x):
    """Resize an image to 96x96 and convert it to a normalized CHW float tensor.

    ``x`` must provide ``resize(size, resample)`` and be convertible with
    ``np.array`` into a height x width x channel array (presumably a PIL
    image, in which case resample code 2 selects bilinear filtering).

    Returns:
        A float32 ``torch.Tensor`` in channel-first order whose values are
        ``(pixel / 255 - 0.5) / 0.5``.
    """
    resized=x.resize((96,96),2)
    pixels=np.array(resized,dtype='float32')/255
    centered=(pixels-0.5)/0.5
    # Move the channel axis first: HWC -> CHW.
    channel_first=centered.transpose((2,0,1))
    return torch.from_numpy(channel_first)

In [7]:
def data_tf2(x):
    """Resize an image to 96x96, convert it with ``ToTensor`` and normalize it.

    ``x`` must provide ``resize(size, resample)`` and be accepted by
    ``transforms.ToTensor`` (presumably a PIL image).

    Returns:
        The tensor produced by ``ToTensor`` after applying ``(t - 0.5) / 0.5``.
    """
    to_tensor=transforms.ToTensor()
    tensor=to_tensor(x.resize((96,96),2))
    return (tensor-0.5)/0.5

In [8]:
# CIFAR-10 train and test splits read from ./data; every image goes through data_tf2.
train_set=CIFAR10(root='./data',train=True,transform=data_tf2)
test_set=CIFAR10(root='./data',train=False,transform=data_tf2)
# Show the first transformed training image as a sanity check.
train_set[0][0]

tensor([[[-0.5373, -0.5373, -0.5765,  ...,  0.1686,  0.1608,  0.1608],
         [-0.5373, -0.5373, -0.5765,  ...,  0.1686,  0.1608,  0.1608],
         [-0.6471, -0.6471, -0.6863,  ...,  0.0980,  0.0902,  0.0902],
         ...,
         [ 0.3961,  0.3961,  0.3725,  ..., -0.0824, -0.1373, -0.1373],
         [ 0.3882,  0.3882,  0.3647,  ...,  0.0353, -0.0353, -0.0353],
         [ 0.3882,  0.3882,  0.3647,  ...,  0.0353, -0.0353, -0.0353]],

        [[-0.5137, -0.5137, -0.5529,  ..., -0.0275, -0.0275, -0.0275],
         [-0.5137, -0.5137, -0.5529,  ..., -0.0275, -0.0275, -0.0275],
         [-0.6235, -0.6235, -0.6706,  ..., -0.1294, -0.1216, -0.1216],
         ...,
         [ 0.1137,  0.1137,  0.0745,  ..., -0.3255, -0.3804, -0.3804],
         [ 0.1294,  0.1294,  0.0902,  ..., -0.2078, -0.2784, -0.2784],
         [ 0.1294,  0.1294,  0.0902,  ..., -0.2078, -0.2784, -0.2784]],

        [[-0.5059, -0.5059, -0.5529,  ..., -0.1922, -0.1922, -0.1922],
         [-0.5059, -0.5059, -0.5529,  ..., -0

In [9]:
from torch.utils.data import DataLoader
# Training batches of 64 are reshuffled on every pass; test batches of 128 keep dataset order.
train_data = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_data = DataLoader(dataset=test_set, batch_size=128, shuffle=False)

In [10]:
# Cross-entropy loss, optimized with plain stochastic gradient descent at learning rate 0.1.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.1)

In [11]:
def get_acc(output, label):
    """Return the fraction of rows in ``output`` whose arg-max along dim 1 equals ``label``.

    Args:
        output: score tensor; dim 0 indexes samples, dim 1 indexes classes.
        label: tensor of target class indices, one per sample.
    """
    pred_label = output.max(1)[1]
    hits = pred_label.eq(label).sum().item()
    return hits / output.shape[0]

def set_learning_rate(optimizer, lr):
    """Overwrite the ``'lr'`` entry of every parameter group of ``optimizer`` with ``lr``."""
    for group in optimizer.param_groups:
        group.update({'lr': lr})

In [13]:
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train ``net`` for ``num_epochs`` epochs and print one summary line per epoch.

    Args:
        net: the model to train; it is moved to the GPU when one is available,
            otherwise it stays on the CPU.
        train_data: iterable of ``(images, labels)`` batches supporting ``len()``.
        valid_data: iterable of ``(images, labels)`` batches supporting ``len()``,
            or ``None`` to skip validation.
        num_epochs: number of passes over ``train_data``.
        optimizer: optimizer stepping ``net``'s parameters. Its learning rate is
            set to 1e-2 at the start of epoch 15.
        criterion: loss function called as ``criterion(output, label)``.

    Returns:
        None. Per-epoch mean loss / accuracy (averaged over batches) and the
        elapsed time are printed.
    """
    # Use the GPU when present; otherwise run on the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    prev_time = datetime.now()
    net = net.to(device)
    for epoch in range(num_epochs):
        if epoch == 15:
            set_learning_rate(optimizer, 1e-2)
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)
            label = label.to(device)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        # The reported time is taken before validation starts.
        cur_time = datetime.now()
        # total_seconds() keeps whole days, which timedelta.seconds would drop.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)

        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # Variable(..., volatile=True) is ignored by PyTorch >= 0.4, so
            # gradient tracking has to be disabled explicitly; otherwise every
            # validation batch builds an autograd graph and wastes memory.
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)
                    label = label.to(device)

                    output = net(im)
                    loss = criterion(output, label)

                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)

In [14]:
# Train for 5 epochs, using the CIFAR-10 test split as validation data.
train(net, train_data, valid_data=test_data, num_epochs=5,
      optimizer=optimizer, criterion=criterion)



RuntimeError: CUDA out of memory. Tried to allocate 72.00 MiB (GPU 0; 2.00 GiB total capacity; 1.16 GiB already allocated; 33.39 MiB free; 27.06 MiB cached)