In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random

import torchvision
from torchvision import transforms

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
print('cuda 사용 가능 : {}'.format(torch.cuda.is_available()))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'현재 device : {device}')

cuda 사용 가능 : False
현재 device : cpu


In [3]:
torch.manual_seed(777)

if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [16]:
trans = transforms.Compose([
    transforms.Resize([64, 64])
])

train_data = torchvision.datasets.ImageFolder(root='animal/origin/', transform=trans) 

In [17]:
for num, value in enumerate(train_data):
    data, label = value
    print(num, data, label)
    
    if label == 0:
        data.save('animal/train/bear/{}_{}.jpg'.format(num, label))
    elif label == 1:
        data.save('animal/train/cat/{}_{}.jpg'.format(num, label))
    elif label == 2:
        data.save('animal/train/dog/{}_{}.jpg'.format(num, label))
    elif label == 3:
        data.save('animal/train/fox/{}_{}.jpg'.format(num, label))
    elif label == 4:
        data.save('animal/train/rabbit/{}_{}.jpg'.format(num, label))

0 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC399063D0> 0
1 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906730> 0
2 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906250> 0
3 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906280> 0
4 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906250> 0
5 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC399063D0> 0
6 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906820> 0
7 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906250> 0
8 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906640> 0
9 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906730> 0
10 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906280> 0
11 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906640> 0
12 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557100> 0
13 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557880> 0
14 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557280> 0
15 <P

130 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC399067F0> 1
131 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906280> 1
132 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906250> 1
133 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC399067F0> 1
134 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906820> 1
135 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906250> 1
136 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC399063D0> 1
137 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906640> 1
138 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906730> 1
139 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906280> 1
140 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906250> 1
141 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906820> 1
142 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC399063D0> 1
143 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906280> 1
144 <PIL.Image.Image image mode=RGB size=64x64 a

293 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC39906730> 3
294 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557880> 3
295 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557160> 3
296 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557850> 3
297 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC595573A0> 4
298 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557880> 4
299 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557280> 4
300 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557100> 4
301 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557850> 4
302 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC595577C0> 4
303 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC595573A0> 4
304 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557100> 4
305 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC595573A0> 4
306 <PIL.Image.Image image mode=RGB size=64x64 at 0x7FFC59557880> 4
307 <PIL.Image.Image image mode=RGB size=64x64 a

In [29]:
trans = transforms.Compose([
    transforms.Resize([64, 64]),
    transforms.ToTensor()
])

train_data = torchvision.datasets.ImageFolder(root='animal/train/', transform=trans) 

In [51]:
data_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)

In [71]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.conv_layer2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
    
        
        self.linear_layer3 = nn.Sequential(
            nn.Linear(2704, 120), 
            nn.ReLU(),
            nn.Linear(120, 5)
        )
        
    def forward(self, x):
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = out.view(out.shape[0], -1)
        print(out.shape)
        out = self.linear_layer3(out)
        return out

In [72]:
net = CNN().to(device)
test_input = torch.Tensor(3, 3, 64, 64).to(device)
output = net(test_input)

torch.Size([3, 2704])


In [74]:
net

CNN(
  (conv_layer1): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv_layer2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (linear_layer3): Sequential(
    (0): Linear(in_features=2704, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=5, bias=True)
  )
)

In [73]:
optimizer = optim.Adam(net.parameters(), lr=1e-6)
loss_func = nn.CrossEntropyLoss().to(device)