In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Reproducibility: seed torch's RNG up front (and every CUDA device below when a GPU is used).
torch.manual_seed(1)

# Side length (pixels) of the square that input images are resized to.
IMAGE_SIZE = 128

# Pick the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed_all(1)

print(device)

In [None]:
# Test images are read from class sub-folders under ../data/test_image.
# Each image is resized to IMAGE_SIZE x IMAGE_SIZE and then converted to a tensor;
# Compose chains the two transforms in that order.
test_dataset = ImageFolder(
    root='../data/test_image',
    transform=transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ]),
)

# Fixed order (no shuffling) so every run sees the same batches of 10.
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

print(test_dataset)
print(test_loader)

In [None]:
class CNN(nn.Module):
    """Five-stage convolutional network with a 4-layer fully connected head.

    Each conv stage is Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2).
    With padding 1 the convolution keeps the spatial resolution, and only the
    pooling halves it. Five stages therefore turn a (3, 128, 128) input into a
    (128, 4, 4) feature map, which matches the 128 * 4 * 4 input features of
    ``fc1``. The head ends in ``fc4``, which yields 5 outputs per image
    (presumably class scores — confirm against the training setup).

    The attribute names ``layer1``..``layer5`` and ``fc1``..``fc4``, and the
    module order inside each stage, determine the ``state_dict`` keys. Keep them
    unchanged so that saved checkpoints still load.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # 3x3 conv with padding=1 preserves H and W; the 2x2 max-pool halves them.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

        self.layer1 = conv_stage(3, 8)
        self.layer2 = conv_stage(8, 16)
        self.layer3 = conv_stage(16, 32)
        self.layer4 = conv_stage(32, 64)
        self.layer5 = conv_stage(64, 128)

        self.fc1 = nn.Linear(128 * 4 * 4, 128 * 4)
        self.fc2 = nn.Linear(128 * 4, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, 5)

    def forward(self, x):
        # Feature extractor: the five conv stages applied in order.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        # NOTE(review): there is no activation between the fc layers, so fc1..fc4
        # compose into a single affine map. Left as-is because the checkpoint
        # loaded for this model was presumably trained with this exact forward —
        # confirm before adding nonlinearities (that would require retraining).
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = fc(x)
        return x

In [None]:
def plot(x, n=10, ncols=5):
    """Display up to the first ``n`` images of a batch in a grid.

    Args:
        x: batch of image tensors indexed along dim 0. Each item is assumed to
           be channel-first (C, H, W), as produced by ``transforms.ToTensor()``
           — TODO confirm for other callers.
        n: maximum number of images to draw (default 10). If the batch holds
           fewer images, only those are drawn instead of raising IndexError.
        ncols: number of grid columns (default 5). With the defaults the grid
           is 2x5 with figsize (8, 4), matching the original layout.
    """
    count = min(n, len(x))
    nrows = max(1, -(-count // ncols))  # ceil(count / ncols) without importing math
    plt.figure(figsize=(1.6 * ncols, 2 * nrows))
    for i in range(count):
        plt.subplot(nrows, ncols, i + 1)
        # imshow expects (H, W, C): move channels last with permute (the original
        # called the non-existent `permit`). detach/cpu make GPU or
        # grad-tracking tensors plottable.
        plt.imshow(x[i].detach().cpu().permute(1, 2, 0))
        plt.axis('off')
    plt.show()

In [None]:
# Take the first batch from the loader. The loader does not shuffle, so this is
# the same batch on every run: an image tensor batch plus the matching labels.
[test_image, test_label] = next(iter(test_loader))

In [None]:
# Visual sanity check of the batch that is about to be classified.
plot(x=test_image)

In [None]:
# Build the network on the selected device and restore the trained weights.
model = CNN().to(device)
# map_location remaps the stored tensors onto `device`. Without it, a checkpoint
# saved from a GPU fails to load on a CPU-only machine, because torch.load
# tries to restore each tensor onto the device it was saved from.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On PyTorch >= 1.13 consider adding weights_only=True.
model.load_state_dict(torch.load('../data/model.pt', map_location=device))
print(model)

In [None]:
# Inference on the test batch.
# - eval() switches the model out of training mode. This is a no-op for the
#   current layers, but stays correct if dropout or batch-norm is added later.
# - no_grad() skips building the autograd graph, which prediction does not need.
# - The inputs must be on the same device as the model's weights; otherwise the
#   forward pass raises a device-mismatch error whenever device == 'cuda'.
model.eval()
with torch.no_grad():
    pred = model(test_image.to(device))