# Concolution Neural Network

## DataLoad

In [7]:
import torch
from torch import nn, optim
from torch.utils.data import (Dataset,DataLoader,TensorDataset)
import tqdm

from torchvision.datasets import FashionMNIST
from torchvision import transforms

In [8]:
# 훈련 데이터 가져오기
# PIL -> Tensor
fashion_mnist_train=FashionMNIST("Documents/Pytorch_FristStep/Data", train=True, download=True, transform=transforms.ToTensor())

#검증 데이터 가져오기
fashion_mnist_test=FashionMNIST("Documents/Pytorch_FristStep/Data", train=False, download=True, transform=transforms.ToTensor())

#배치 크기 128인 DataLoader 작성
batch_size=128
train_loader=DataLoader(fashion_mnist_train, batch_size=batch_size, shuffle=True)
test_loader=DataLoader(fashion_mnist_test, batch_size=batch_size, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting Documents/Pytorch_FristStep/Data\FashionMNIST\raw\train-images-idx3-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting Documents/Pytorch_FristStep/Data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting Documents/Pytorch_FristStep/Data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting Documents/Pytorch_FristStep/Data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to Documents/Pytorch_FristStep/Data\FashionMNIST\raw



## CNN

In [12]:
#(N,C,H,W) -> (N, C*H*W) : 합성곱 출력을 MLP에 전달할 때 필요
class FlattenLayer(nn.Module):
    def forward(self,x):
        sizes=x.size()
        return x. view(sizes[0],-1)
    
# 5*5 커널로 32 -> 64개 채널 작성
# BatchNorm2d : 이미지용 배치 정규화
# Droptout2d : 이미지용 Dropout
conv_net= nn.Sequential(
    nn.Conv2d(1,32,5),    # in_channel, out_channel, kernel_size
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.Dropout2d(0.25),
    
    nn.Conv2d(32,64,5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Dropout2d(0.25),
    
    FlattenLayer()
)
    
# 합성곱에 의한 이미지 크기가 어떤지 더미 데이터를 넣어 확인
test_input=torch.ones(1,1,28,28)
conv_output_size=conv_net(test_input).size()[-1]

#2층 MLP
mlp=nn.Sequential(
    nn.Linear(conv_output_size, 200),
    nn.ReLU(),
    nn.BatchNorm1d(200),
    nn.Dropout(0.25),
    nn.Linear(200,10)
)

#최종 cnn
net=nn.Sequential(
    conv_net,
    mlp
)


In [19]:
# 평가용 함수
def eval_net(net,data_loader,device="cpu"):
        #Dropout, barhnorm 유효화
        net.eval()
        ys=[]
        ypreds=[]
        for x, y in data_loader:
            x=x.to(device)
            y=y.to(device)
            
            #확률이 가장 큰 분류 예측, 자동미분에 필요한 처리는 off로 설정해 불필요한 계산 제한
        with torch.no_grad():
            _,y_pred=net(x).max(1)
            ys.append(y)
            ypreds.append(y_pred)
        
        #미니 배치 단위의 예측 결과를 하나로 묶음
        ys=torch.cat(ys)
        ypreds=torch.cat(ypreds)
             
        #accuracy
        acc=(ys==ypreds).float().sum()/len(ys)
        return acc.item()
        
# 훈련용 함수
def train_net(net,train_loader,test_loader,optimizer_cls=optim.Adam,loss_fn=nn.CrossEntropyLoss(), n_iter=10, device="cpu"):
    train_losses=[]
    train_acc=[]
    val_acc=[]
    
    optim=optimizer_cls(net.parameters())
    for e in range(n_iter):
        running_loss=0.0
        
        #신경망을 훈련 모드로 설정
        net.train()
        n=0
        n_acc=0
        
        for i,(xx,yy) in tqdm.tqdm(enumerate(train_loader),total=len(train_loader)):
            xx=xx.to(device)
            yy=yy.to(device)
            
            h=net(xx)
            loss=loss_fn(h,yy)
            optim.zero_grad()
            
            loss.backward()
            optim.step()
            
            running_loss+=loss.item()
            n += len(xx)
            
            _,y_pred=h.max(1)
            n_acc += (yy==y_pred).float().sum().item()
        train_losses.append(running_loss /i)
        
        #훈련 데이터 예측 정확도
        train_acc.append(n_acc/n)
        
        #검증 데이터 예측 정화곧
        val_acc.append(eval_net(net,test_loader,device))
        #epoch 결과 표시
        print(e,train_losses[-1],train_acc[-1],val_acc[-1],flush=True)

            

In [20]:
net.to("cuda:0")

train_net(net,train_loader,test_loader,n_iter=20,device="cuda:0")

100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:12<00:00, 36.21it/s]


0 0.32850918161053944 0.88015 0.6875


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:12<00:00, 36.29it/s]


1 0.28915331644825953 0.8941666666666667 0.8125


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.82it/s]


2 0.26432550825879103 0.9022333333333333 0.625


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.95it/s]


3 0.24579757291218665 0.9093166666666667 1.0


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.75it/s]


4 0.23483772123726004 0.9131 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.41it/s]


5 0.22270567823424298 0.9176833333333333 0.875


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.75it/s]


6 0.2149942409661081 0.9186666666666666 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.96it/s]


7 0.20874602218660024 0.92345 0.75


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:12<00:00, 36.20it/s]


8 0.20108485209126759 0.9252333333333334 1.0


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.03it/s]


9 0.19256968415764153 0.9273166666666667 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.12it/s]


10 0.1908406559537109 0.9295166666666667 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.54it/s]


11 0.18042628715435663 0.9318833333333333 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.19it/s]


12 0.18018527215935737 0.9318666666666666 0.875


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.76it/s]


13 0.17298098939319706 0.9340833333333334 0.8125


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 34.50it/s]


14 0.16857698664833337 0.9358 0.875


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.40it/s]


15 0.1637427164798873 0.9383666666666667 0.875


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:12<00:00, 36.12it/s]


16 0.16067435391820395 0.9401833333333334 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.71it/s]


17 0.15704900685411233 0.9403 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:13<00:00, 35.23it/s]


18 0.15109130816581923 0.9438 0.9375


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:12<00:00, 36.35it/s]


19 0.15224995823083526 0.9420166666666666 0.9375
