In [1]:
from google.colab import drive
drive.mount("/gdrive", force_remount=True)

Mounted at /gdrive


# CIFAR10

In [2]:
import os
import numpy as np
from sklearn.metrics import accuracy_score
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)
import torch.optim as optim
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

class CIFAR10(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
# 데이터 읽기 함수
def load_dataset(config):

  transform = transforms.Compose(
      [transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)
  trainloader = torch.utils.data.DataLoader(trainset, batch_size=config['batch_size'],
                                            shuffle=True)

  testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
  testloader = torch.utils.data.DataLoader(testset, batch_size=config['batch_size'],
                                          shuffle=False)
  
  return (trainloader, testloader)

In [4]:
# 모델 평가 결과 계산을 위해 텐서를 리스트로 변환하는 함수
def tensor2list(input_tensor):
    return input_tensor.cpu().detach().numpy().tolist()

# 평가 수행 함수
def do_test(model, test_dataloader):

  # 평가 모드 셋팅
  model.eval()

  # Batch 별로 예측값과 정답을 저장할 리스트 초기화
  predicts, golds = [], []
  
  with torch.no_grad():

    for step, batch in enumerate(test_dataloader):
  
      # .cuda()를 통해 메모리에 업로드
      batch = tuple(t.cuda() for t in batch)

      input_features, labels = batch
      hypothesis = model(input_features)

      # ont-hot 표현으로 변경
      logits = torch.argmax(hypothesis,-1)

      x = tensor2list(logits)
      y = tensor2list(labels)

      # 예측값과 정답을 리스트에 추가
      predicts.extend(x)
      golds.extend(y)
    
    print("PRED=",predicts)
    print("GOLD=",golds)
    print("Accuracy= {0:f}\n".format(accuracy_score(golds, predicts)))

# 모델 평가 함수
def test(config):

  model =  CIFAR10(config).cuda()

  # 저장된 모델 가중치 로드
  model.load_state_dict(torch.load(config["output_dir"]))

  (_, test_dataloader) = load_dataset(config)
  
  do_test(model, test_dataloader)

In [11]:
# 모델 학습 함수
def train(config):
  model = CIFAR10(config).cuda()

  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=config["learn_rate"])

  (trainloader, _) = load_dataset(config)

  for epoch in range(config['epoch']):  
      running_loss = 0.0
      for i, data in enumerate(trainloader, 0):
        
        data = tuple(t.cuda() for t in data)
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i % 100 == 99):
          print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
          running_loss = 0.0
  print('Finished Training')
  PATH = '/gdrive/My Drive/Colab Notebooks/cnn/mnist/cifar_net.pth'
  torch.save(model.state_dict(), PATH)

# Train

In [12]:
if(__name__=="__main__"):
  
    config = {"mode": "train",
              "model_name":"epoch_{0:d}.pt".format(10),
              "output_dir":"/gdrive/My Drive/Colab Notebooks/cnn/mnist/cifar_net.pth",
              "learn_rate":0.001,
              "batch_size":32,
              "epoch":10,
              }

    if(config["mode"] == "train"):
        train(config)
    else:
        test(config)

Files already downloaded and verified
Files already downloaded and verified
[1,   100] loss: 0.106
[1,   200] loss: 0.094
[1,   300] loss: 0.088
[1,   400] loss: 0.085
[1,   500] loss: 0.082
[1,   600] loss: 0.079
[1,   700] loss: 0.076
[1,   800] loss: 0.076
[1,   900] loss: 0.076
[1,  1000] loss: 0.073
[1,  1100] loss: 0.073
[1,  1200] loss: 0.073
[1,  1300] loss: 0.073
[1,  1400] loss: 0.071
[1,  1500] loss: 0.073
[2,   100] loss: 0.068
[2,   200] loss: 0.069
[2,   300] loss: 0.068
[2,   400] loss: 0.067
[2,   500] loss: 0.065
[2,   600] loss: 0.066
[2,   700] loss: 0.065
[2,   800] loss: 0.065
[2,   900] loss: 0.067
[2,  1000] loss: 0.065
[2,  1100] loss: 0.062
[2,  1200] loss: 0.062
[2,  1300] loss: 0.061
[2,  1400] loss: 0.064
[2,  1500] loss: 0.063
[3,   100] loss: 0.061
[3,   200] loss: 0.060
[3,   300] loss: 0.060
[3,   400] loss: 0.057
[3,   500] loss: 0.060
[3,   600] loss: 0.059
[3,   700] loss: 0.059
[3,   800] loss: 0.060
[3,   900] loss: 0.059
[3,  1000] loss: 0.058
[3, 

# Test

In [14]:
test(config)

Files already downloaded and verified
Files already downloaded and verified
PRED= [5, 8, 0, 0, 4, 6, 1, 6, 5, 1, 0, 9, 6, 7, 9, 8, 5, 5, 8, 6, 7, 2, 2, 9, 4, 2, 4, 4, 9, 6, 6, 5, 4, 6, 9, 9, 4, 1, 9, 5, 4, 6, 7, 6, 0, 9, 5, 7, 7, 6, 2, 0, 6, 3, 8, 8, 5, 5, 5, 3, 7, 5, 6, 0, 6, 2, 1, 0, 5, 7, 4, 6, 8, 8, 7, 2, 0, 3, 4, 8, 8, 1, 1, 7, 2, 7, 2, 4, 8, 9, 0, 2, 8, 6, 4, 6, 6, 0, 0, 4, 7, 5, 6, 3, 1, 1, 3, 6, 8, 5, 4, 0, 6, 2, 9, 7, 0, 4, 3, 5, 0, 5, 9, 2, 8, 1, 8, 3, 3, 2, 4, 1, 1, 9, 1, 4, 9, 7, 6, 0, 6, 5, 6, 3, 8, 4, 6, 7, 5, 5, 8, 1, 6, 0, 2, 5, 3, 9, 5, 4, 2, 1, 9, 6, 0, 7, 8, 6, 7, 0, 9, 7, 8, 2, 9, 9, 6, 7, 5, 9, 0, 7, 6, 2, 4, 8, 6, 3, 7, 0, 6, 8, 9, 9, 7, 4, 8, 3, 7, 8, 3, 9, 8, 7, 1, 3, 0, 5, 7, 9, 6, 9, 3, 7, 1, 3, 7, 9, 3, 4, 7, 6, 9, 4, 5, 9, 3, 2, 3, 6, 5, 1, 5, 8, 8, 0, 4, 7, 3, 5, 1, 1, 1, 9, 0, 2, 1, 8, 2, 4, 5, 5, 9, 9, 2, 8, 6, 1, 1, 1, 8, 9, 4, 3, 8, 8, 2, 4, 7, 2, 4, 3, 6, 5, 8, 2, 7, 6, 7, 5, 9, 1, 6, 1, 9, 9, 1, 8, 7, 9, 1, 2, 6, 9, 5, 2, 6, 0, 0, 6, 6, 6, 5, 0, 6, 1,