In [2]:
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

from functiona import *

In [2]:
# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU.

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [3]:

# Preprocessing pipelines. Both convert each image to a float tensor and normalize every
# channel with mean 0.5 / std 0.5.
# transform1 additionally flattens each image into a 1-D vector (3072 values, see the
# shape check output further down) for the fully connected Net.
# NOTE(review): a lambda cannot be pickled, so this pipeline breaks with
# DataLoader(num_workers>0) under the 'spawn' start method; use a named function then.
# transform2 keeps the (C, H, W) image layout for the CNN.
transform1 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
    transforms.Lambda(lambda img: img.view(-1)),
])

transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

# Directory where CIFAR-10 is stored / downloaded.
data_root = './hidden_data/'


def _cifar10(train, transform):
    """Return the CIFAR-10 train (``train=True``) or test split with ``transform`` applied.

    ``download=True`` fetches the files into ``data_root`` if they are not already there.
    """
    return datasets.CIFAR10(root=data_root, train=train, download=True, transform=transform)


# Flattened-vector datasets (for Net)
train_set1 = _cifar10(True, transform1)
test_set1 = _cifar10(False, transform1)

# (C, H, W) image datasets (for CNN)
train_set2 = _cifar10(True, transform2)
test_set2 = _cifar10(False, transform2)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Mini-batch loaders. Training loaders are shuffled; test loaders keep a fixed order.
batch_size = 100


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader that uses the notebook-wide ``batch_size``."""
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)


# Flattened-vector loaders (for Net)
train_loader1 = _make_loader(train_set1, shuffle=True)
test_loader1 = _make_loader(test_set1, shuffle=False)

# (C, H, W) image loaders (for CNN)
train_loader2 = _make_loader(train_set2, shuffle=True)
test_loader2 = _make_loader(test_set2, shuffle=False)

In [14]:
# Sanity-check the loaders by pulling the first mini-batch from each training loader.
images1, labels1 = next(iter(train_loader1))
images2, labels2 = next(iter(train_loader2))

print(images1[0].shape)  # one flattened sample
print(images2.shape)     # the whole image mini-batch

torch.Size([3072])
torch.Size([100, 3, 32, 32])


In [6]:
# Human-readable CIFAR-10 class names, in label order (label i -> classes[i]).
classes = tuple(
    "plane car bird cat deer dog frog horse ship truck".split()
)


In [15]:
# Layer sizes for the models.
n_input = images1[0].numel()   # number of features in one flattened sample
n_output = len(classes)        # one output logit per class
n_hidden = 128                 # width of the hidden fully connected layer

# Show the chosen sizes
print(f'n_input: {n_input}  n_hidden: {n_hidden} n_output: {n_output}')

n_input: 3072  n_hidden: 128 n_output: 10


In [16]:
# Fully connected baseline model: Linear -> ReLU -> Linear.

class Net(nn.Module):
    """Two-layer perceptron that maps ``n_input`` features to ``n_output`` logits.

    Args:
        n_input: size of each (flattened) input vector.
        n_output: number of output logits.
        n_hidden: width of the single hidden layer.
    """

    def __init__(self, n_input, n_output, n_hidden):
        super().__init__()
        self.l1 = nn.Linear(n_input, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_output)
        # inplace=True lets ReLU overwrite l1's output instead of allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return raw (un-normalized) scores for the batch ``x``."""
        return self.l2(self.relu(self.l1(x)))

In [18]:

class CNN(nn.Module):
    """Small convolutional classifier.

    Architecture: conv3x3(in_channels->32) -> ReLU -> conv3x3(32->32) -> ReLU ->
    maxpool 2x2 -> flatten -> Linear(->n_hidden) -> ReLU -> Linear(->n_output).

    Args:
        n_output: number of output logits.
        n_hidden: width of the hidden fully connected layer.
        in_channels: number of channels in the input images (default 3).
        image_size: height/width of the square input images (default 32). Used to
            size the first Linear layer; with the defaults this is 32 * 14 * 14 = 6272.

    Raises:
        ValueError: if ``image_size`` is too small to survive the two unpadded
            3x3 convolutions followed by the 2x2 max-pool.
    """

    def __init__(self, n_output, n_hidden, in_channels=3, image_size=32):
        super().__init__()
        n_filters = 32
        kernel_size = 3

        # Each unpadded 3x3 conv shrinks the side by (kernel_size - 1) = 2; the 2x2
        # max-pool then halves it (rounding down).
        pooled_size = (image_size - 2 * (kernel_size - 1)) // 2
        if pooled_size < 1:
            raise ValueError(
                f'image_size={image_size} is too small for this architecture'
            )
        n_flat = n_filters * pooled_size * pooled_size

        self.conv1 = nn.Conv2d(in_channels, n_filters, kernel_size)
        self.conv2 = nn.Conv2d(n_filters, n_filters, kernel_size)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(n_flat, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_output)

        # The Sequentials reuse the modules above, so parameters are shared and
        # state_dict contains both the attribute keys (conv1.*, l1.*, ...) and the
        # features.* / classifier.* keys.
        self.features = nn.Sequential(
            self.conv1,
            self.relu,
            self.conv2,
            self.relu,
            self.maxpool,
        )

        self.classifier = nn.Sequential(
            self.l1,
            self.relu,
            self.l2,
        )

    def forward(self, x):
        """Return raw scores for a batch ``x`` of shape (N, in_channels, image_size, image_size)."""
        x = self.features(x)
        x = self.flatten(x)
        return self.classifier(x)

In [19]:
# Model, loss, and optimizer setup (plain SGD, learning rate 0.01).
# Note: the training cell below rebuilds all of these, so this cell only matters on its own.
lr = 0.01
net = CNN(n_output, n_hidden).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr)

In [21]:
# Train the CNN on the image-shaped (C, H, W) loaders and report the elapsed time.
SEED = 123

# perf_counter is a monotonic, high-resolution clock intended for measuring durations.
start = time.perf_counter()

# Seed PyTorch's RNG (used for weight initialization and train-loader shuffling) so runs
# repeat. torch.seed() would instead re-seed with a non-deterministic value, which defeats
# reproducibility. GPU kernels (e.g. cuDNN) may still cause small run-to-run differences.
torch.manual_seed(SEED)

net = CNN(n_output, n_hidden).to(device)

criterion = nn.CrossEntropyLoss()

lr = 0.01
optimizer = optim.SGD(net.parameters(), lr=lr)

num_epochs = 50

# Empty history with 5 columns. Presumably fit() appends one row per epoch
# (epoch, loss, acc, val_loss, val_acc); TODO confirm against functiona.fit.
history2 = np.zeros((0, 5))

history2 = fit(net, optimizer, criterion, num_epochs, train_loader2, test_loader2, device, history2)

end = time.perf_counter()

print(f'掛かった時間:{end - start}')

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [1/50], loss: 2.03958 acc: 0.26982 val_loss: 1.82522, val_acc: 0.36500


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [2/50], loss: 1.75475 acc: 0.38134 val_loss: 1.68791, val_acc: 0.40470


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [3/50], loss: 1.60647 acc: 0.43212 val_loss: 1.56175, val_acc: 0.44300


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [4/50], loss: 1.49994 acc: 0.46664 val_loss: 1.46021, val_acc: 0.47320


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [5/50], loss: 1.41445 acc: 0.49482 val_loss: 1.37923, val_acc: 0.51150


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [6/50], loss: 1.34517 acc: 0.52300 val_loss: 1.31376, val_acc: 0.52960


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [7/50], loss: 1.28099 acc: 0.54540 val_loss: 1.26717, val_acc: 0.55110


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [8/50], loss: 1.22714 acc: 0.56604 val_loss: 1.22557, val_acc: 0.56590


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [9/50], loss: 1.17503 acc: 0.58452 val_loss: 1.20100, val_acc: 0.57820


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [10/50], loss: 1.12753 acc: 0.60576 val_loss: 1.16908, val_acc: 0.58340


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [11/50], loss: 1.08206 acc: 0.61818 val_loss: 1.13982, val_acc: 0.59630


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [12/50], loss: 1.04133 acc: 0.63462 val_loss: 1.10672, val_acc: 0.60630


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [13/50], loss: 0.99911 acc: 0.64712 val_loss: 1.10522, val_acc: 0.60940


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [14/50], loss: 0.95977 acc: 0.66390 val_loss: 1.07992, val_acc: 0.61500


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [15/50], loss: 0.91871 acc: 0.67962 val_loss: 1.07355, val_acc: 0.62620


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [16/50], loss: 0.87950 acc: 0.69240 val_loss: 1.03760, val_acc: 0.63690


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [17/50], loss: 0.84005 acc: 0.70842 val_loss: 1.02338, val_acc: 0.64190


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [18/50], loss: 0.80278 acc: 0.71886 val_loss: 1.00128, val_acc: 0.65090


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [19/50], loss: 0.76700 acc: 0.73428 val_loss: 1.04534, val_acc: 0.64170


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [20/50], loss: 0.73172 acc: 0.74628 val_loss: 1.02709, val_acc: 0.65170


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [21/50], loss: 0.69541 acc: 0.75874 val_loss: 1.00096, val_acc: 0.66080


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [22/50], loss: 0.66148 acc: 0.77092 val_loss: 1.00562, val_acc: 0.65730


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [23/50], loss: 0.62362 acc: 0.78598 val_loss: 1.02304, val_acc: 0.65590


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [24/50], loss: 0.59434 acc: 0.79608 val_loss: 1.02965, val_acc: 0.65900


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [25/50], loss: 0.55965 acc: 0.80674 val_loss: 1.07054, val_acc: 0.65300


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [26/50], loss: 0.52742 acc: 0.81938 val_loss: 1.04197, val_acc: 0.66430


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [27/50], loss: 0.49215 acc: 0.83162 val_loss: 1.07669, val_acc: 0.66040


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [28/50], loss: 0.46258 acc: 0.84344 val_loss: 1.08285, val_acc: 0.65840


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [29/50], loss: 0.42769 acc: 0.85664 val_loss: 1.13788, val_acc: 0.65420


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [30/50], loss: 0.39781 acc: 0.86632 val_loss: 1.12513, val_acc: 0.66230


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [31/50], loss: 0.36058 acc: 0.88038 val_loss: 1.18185, val_acc: 0.65810


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [32/50], loss: 0.33558 acc: 0.88886 val_loss: 1.18088, val_acc: 0.66270


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [33/50], loss: 0.30174 acc: 0.90160 val_loss: 1.29966, val_acc: 0.64690


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [34/50], loss: 0.27619 acc: 0.90980 val_loss: 1.29539, val_acc: 0.65720


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [35/50], loss: 0.24463 acc: 0.92262 val_loss: 1.36510, val_acc: 0.64800


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [36/50], loss: 0.22534 acc: 0.92954 val_loss: 1.36862, val_acc: 0.65250


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [37/50], loss: 0.19781 acc: 0.94084 val_loss: 1.39879, val_acc: 0.64750


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [38/50], loss: 0.17485 acc: 0.94862 val_loss: 1.44220, val_acc: 0.65250


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [39/50], loss: 0.14686 acc: 0.95986 val_loss: 1.46963, val_acc: 0.65800


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [40/50], loss: 0.12801 acc: 0.96534 val_loss: 1.62811, val_acc: 0.64710


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [41/50], loss: 0.10834 acc: 0.97336 val_loss: 1.55153, val_acc: 0.66360


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [42/50], loss: 0.09178 acc: 0.97908 val_loss: 1.60153, val_acc: 0.66000


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [43/50], loss: 0.08136 acc: 0.98290 val_loss: 1.68660, val_acc: 0.65550


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [44/50], loss: 0.05372 acc: 0.99272 val_loss: 1.75865, val_acc: 0.65470


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [45/50], loss: 0.04431 acc: 0.99462 val_loss: 1.78581, val_acc: 0.65540


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [46/50], loss: 0.03374 acc: 0.99700 val_loss: 1.82584, val_acc: 0.66030


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [47/50], loss: 0.02574 acc: 0.99876 val_loss: 1.87905, val_acc: 0.66020


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [48/50], loss: 0.02081 acc: 0.99926 val_loss: 1.92505, val_acc: 0.66170


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [49/50], loss: 0.01807 acc: 0.99920 val_loss: 1.95827, val_acc: 0.65650


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [50/50], loss: 0.01483 acc: 0.99960 val_loss: 1.98318, val_acc: 0.65890
掛かった時間:770.2849364280701


環境別の実行時間（上の CNN 学習セル）:
・CPU（MacBook Pro, Intel Core i5）: 2763.59 秒（約46分）
・GPU（Google Colab, NVIDIA T4）: 921.69 秒（約15分）
・GPU（手元のPC, NVIDIA RTX 3080）: 770.28 秒（約13分）