In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
from torchvision import datasets

dataset = datasets.MNIST('./data', train=True, download=True)

x0 = dataset.data.reshape(60000,1,28,28) / 255.0
y01 = dataset.targets
y02 = dataset.targets
for i in range(len(y02)):
    if (y02[i] in [6, 8, 9]):
        y02[i] = 1
    else:
        y02[i] = 0

# 層が分岐するネットワーク

In [5]:
class MyCNN2(nn.Module):
    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(1, 20, 5)
        self.pool1 = nn.MaxPool2d(2)
        self.cn2 = nn.Conv2d(20, 50, 5)
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(3200, 10)
        self.fc2 = nn.Linear(3200, 2)
        
    def forward(self, x):
        x = F.relu(self.cn1(x))
        x = self.pool1(x)
        x = F.relu(self.cn2(x))
        x = self.dropout(x)
        x = x.reshape(len(x), -1)
        return x

In [8]:
model = MyCNN2().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [10]:
n = len(y01)
bs = 200

model.train()
for j in range(10):
    idx = np.random.permutation(n)
    for i in range(0, n, bs):
        x = x0[idx[i:(i+bs) if (i+bs) < n else n]].to(device)
        y1 = y01[idx[i:(i+bs) if (i+bs) < n else n]].to(device) 
        y2 = y02[idx[i:(i+bs) if (i+bs) < n else n]].to(device) 
        cnnx = model(x)
        out1 = model.fc(cnnx)
        out2 = model.fc2(cnnx)
        loss1 = criterion(out1, y1)
        loss2 = criterion(out2, y2)
        loss = loss1 + loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    outfile = 'cnn2-' + str(j) + '.model'
    torch.save(model.state_dict(), outfile)
    print(outfile, 'saved')

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


cnn2-0.model saved
cnn2-1.model saved
cnn2-2.model saved
cnn2-3.model saved
cnn2-4.model saved
cnn2-5.model saved
cnn2-6.model saved
cnn2-7.model saved
cnn2-8.model saved
cnn2-9.model saved


In [11]:
dataset = datasets.MNIST('./data', train=False, download=True)
xt = dataset.data.reshape(10000,1,28,28) / 255.0
yans1 = dataset.targets
yans2 = dataset.targets
for i in range(len(yans2)):
    if (yans2[i] in [6, 8, 9]):
        yans2[i] = 1
    else:
        yans2[i] = 0

In [13]:
model = MyCNN2()
model.load_state_dict(torch.load('cnn2-9.model'))
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [14]:
model.eval()
with torch.no_grad():
    cnnx = model(xt)
    out1 = model.fc(cnnx)
    out2 = model.fc2(cnnx)
    ans1 = torch.argmax(out1, 1)
    ans2 = torch.argmax(out2, 1)    
    print(((yans1 == ans1).sum().float()/len(ans1)).item())
    print(((yans2 == ans2).sum().float()/len(ans2)).item())  

0.9916999936103821
0.9914000034332275


# 複数のモデルの混在

In [15]:
class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(1, 20, 5)
        self.pool1 = nn.MaxPool2d(2)
        self.cn2 = nn.Conv2d(20, 50, 5)
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(3200, 10)
        
    def forward(self, x):
        x = F.relu(self.cn1(x))
        x = self.pool1(x)
        x = F.relu(self.cn2(x))
        x = self.dropout(x)
        x = x.view(len(x), -1)
        return self.fc(x)