In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

### nn.Sequential

In [10]:
class my_network(nn.Module):
    def __init__(self):
        super(my_network, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, 5),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64,3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,128,3),
            nn.Conv2d(128, 256, 3),
            nn.MaxPool2d(2)
        )
        
        self.layer2_1 = nn.Sequential(
            nn.Conv2d(256, 512, 7, 1, 2),
            nn.Conv2d(512, 64, 1),
            nn.MaxPool2d(2)
        )
        
        self.layer2_2 = nn.Sequential(
            nn.Conv2d(256, 512, 5, 1, 1),
            nn.Conv2d(512, 64, 1),
            nn.MaxPool2d(2)
        )
        
        self.layer2_3 = nn.Sequential(
            nn.Conv2d(256, 512, 3),
            nn.Conv2d(512, 64, 1),
            nn.MaxPool2d(2)
        )
        
        self.fc = nn.Sequential(
            nn.Linear(3*64*4*4, 1024),
            nn.ReLU(True),
            nn.Linear(1024,10)
        )
        
    def forward(self, x):
        print(x.data.shape)
        x= self.layer1(x)
        x1 = self.layer2_1(x)
        print("x1", x1.shape)
        x2 = self.layer2_2(x)
        print("x2", x2.shape)
        x3 = self.layer2_3(x)
        print("x3", x3.shape)
        x= torch.cat((x1,x2,x3), dim=1)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        
        return x

In [11]:
a = torch.rand(1, 3, 32, 32)
a.shape

torch.Size([1, 3, 32, 32])

In [12]:
net = my_network()
output = net(a)

print(output.shape)

torch.Size([1, 3, 32, 32])
x1 torch.Size([1, 64, 4, 4])
x2 torch.Size([1, 64, 4, 4])
x3 torch.Size([1, 64, 4, 4])
torch.Size([1, 10])


In [13]:
print(output)

tensor([[ 0.0086, -0.0039,  0.0216, -0.0073,  0.0301, -0.0197,  0.0177, -0.0261,
          0.0055, -0.0140]], grad_fn=<AddmmBackward>)


### 모델 저장하기

In [14]:
torch.save(net.state_dict(), './my_model.pth')

### 모델 불러오기

In [15]:
model = my_network()  #모델 구조 정의
model.load_state_dict(torch.load('./my_model.pth'))

In [16]:
output2 = model(a)
print(output2)

torch.Size([1, 3, 32, 32])
x1 torch.Size([1, 64, 4, 4])
x2 torch.Size([1, 64, 4, 4])
x3 torch.Size([1, 64, 4, 4])
tensor([[ 0.0086, -0.0039,  0.0216, -0.0073,  0.0301, -0.0197,  0.0177, -0.0261,
          0.0055, -0.0140]], grad_fn=<AddmmBackward>)
