In [83]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torch.utils.model_zoo as model_zoo

from easydict import EasyDict

import os
from glob import glob
import pickle
import random
import math

import cv2
import numpy as np
import pandas as pd

In [84]:
idx2large = EasyDict({
    '0':'구이',
    '1':'국',
    '2':'기타',
    '3':'김치',
    '4':'나물',
    '5':'떡',
    '6':'만두',
    '7':'면',
    '8':'무침',
    '9':'밥',
    '10':'볶음',
    '11':'쌈',
    '12':'음청류',
    '13':'장',
    '14':'장아찌',
    '15':'적',
    '16':'전',
    '17':'전골',
    '18':'조림',
    '19':'죽',
    '20':'찌개',
    '21':'찜',
    '22':'탕',
    '23':'튀김',
    '24':'한과',
    '25':'해물',
    '26':'회'
})

In [85]:
large2idx = {large[-1]:idx for idx, large in enumerate(idx2large.items())}

In [86]:
class2idx = {'구이/갈비구이': 0, '구이/갈치구이': 1, '구이/고등어구이': 2, '구이/곱창구이': 3, '구이/닭갈비': 4, '구이/더덕구이': 5, '구이/떡갈비': 6, '구이/불고기': 7, '구이/삼겹살': 8, '구이/장어구이': 9, '구이/조개구이': 10, '구이/조기구이': 11, '구이/황태구이': 12, '구이/훈제오리': 13, '국/계란국': 14, '국/떡국_만두국': 15, '국/무국': 16, '국/미역국': 17, '국/북엇국': 18, '국/시래기국': 19, '국/육개장': 20, '국/콩나물국': 21, '기타/과메기': 22, '기타/양념치킨': 23, '기타/젓갈': 24, '기타/콩자반': 25, '기타/편육': 26, '기타/피자': 27, '기타/후라이드치킨': 28, '김치/갓김치': 29, '김치/깍두기': 30, '김치/나박김치': 31, '김치/무생채': 32, '김치/배추김치': 33, '김치/백김치': 34, '김치/부추김치': 35, '김치/열무김치': 36, '김치/오이소박이': 37, '김치/총각김치': 38, '김치/파김치': 39, '나물/가지볶음': 40, '나물/고사리나물': 41, '나물/미역줄기볶음': 42, '나물/숙주나물': 43, '나물/시금치나물': 44, '나물/애호박볶음': 45, '떡/경단': 46, '떡/꿀떡': 47, '떡/송편': 48, '만두/만두': 49, '면/라면': 50, '면/막국수': 51, '면/물냉면': 52, '면/비빔냉면': 53, '면/수제비': 54, '면/열무국수': 55, '면/잔치국수': 56, '면/짜장면': 57, '면/짬뽕': 58, '면/쫄면': 59, '면/칼국수': 60, '면/콩국수': 61, '무침/꽈리고추무침': 62, '무침/도라지무침': 63, '무침/도토리묵': 64, '무침/잡채': 65, '무침/콩나물무침': 66, '무침/홍어무침': 67, '무침/회무침': 68, '밥/김밥': 69, '밥/김치볶음밥': 70, '밥/누룽지': 71, '밥/비빔밥': 72, '밥/새우볶음밥': 73, '밥/알밥': 74, '밥/유부초밥': 75, '밥/잡곡밥': 76, '밥/주먹밥': 77, '볶음/감자채볶음': 78, '볶음/건새우볶음': 79, '볶음/고추장진미채볶음': 80, '볶음/두부김치': 81, '볶음/떡볶이': 82, '볶음/라볶이': 83, '볶음/멸치볶음': 84, '볶음/소세지볶음': 85, '볶음/어묵볶음': 86, '볶음/오징어채볶음': 87, '볶음/제육볶음': 88, '볶음/주꾸미볶음': 89, '쌈/보쌈': 90, '음청류/수정과': 91, '음청류/식혜': 92, '장/간장게장': 93, '장/양념게장': 94, '장아찌/깻잎장아찌': 95, '적/떡꼬치': 96, '전/감자전': 97, '전/계란말이': 98, '전/계란후라이': 99, '전/김치전': 100, '전/동그랑땡': 101, '전/생선전': 102, '전/파전': 103, '전/호박전': 104, '전골/곱창전골': 105, '조림/갈치조림': 106, '조림/감자조림': 107, '조림/고등어조림': 108, '조림/꽁치조림': 109, '조림/두부조림': 110, '조림/땅콩조림': 111, '조림/메추리알장조림': 112, '조림/연근조림': 113, '조림/우엉조림': 114, '조림/장조림': 115, '조림/코다리조림': 116, '죽/전복죽': 117, '죽/호박죽': 118, '찌개/김치찌개': 119, '찌개/닭계장': 120, '찌개/동태찌개': 121, '찌개/된장찌개': 122, '찌개/순두부찌개': 123, '찜/갈비찜': 124, '찜/계란찜': 125, '찜/김치찜': 126, '찜/꼬막찜': 127, '찜/닭볶음탕': 128, '찜/수육': 129, '찜/순대': 130, '찜/족발': 131, '찜/찜닭': 132, '찜/해물찜': 133, '탕/갈비탕': 134, '탕/감자탕': 135, '탕/곰탕_설렁탕': 136, '탕/매운탕': 137, '탕/삼계탕': 138, '탕/추어탕': 139, '튀김/고추튀김': 140, '튀김/새우튀김': 141, '튀김/오징어튀김': 142, '한과/약과': 143, '한과/약식': 144, '한과/한과': 145, '해물/멍게': 146, '해물/산낙지': 147, '회/물회': 148, '회/육회': 149}
idx2class = {v:k.split('/')[-1] for k,v in class2idx.items()}

In [87]:
def convert_to_large_id(s_id) :
  '''
  input : small_id
  output : large_id
  '''
  if 0<=s_id and s_id<=13 :
    return 0
  elif 14<=s_id and s_id<=21 : 
    return 1
  elif 22<=s_id and s_id<=28 :
    return 2
  elif 29<=s_id and s_id<=39 :
    return 3
  elif 40<=s_id and s_id<=45 :
    return 4
  elif 46<=s_id and s_id<=48 :
    return 5
  elif s_id == 49 : 
    return 6
  elif 50<=s_id and s_id<=61 :
    return 7
  elif 62<=s_id and s_id<=68 :
    return 8
  elif 69<=s_id and s_id<=77 :
    return 9
  elif 78<=s_id and s_id<=89 :
    return 10
  elif s_id==90 :
    return 11
  elif 91<=s_id and s_id<=92 : 
    return 12
  elif 93<=s_id and s_id<=94 :
    return 13
  elif s_id==95 :
    return 14
  elif s_id==96 :
    return 15
  elif 97<=s_id and s_id<=104 :
    return 16
  elif s_id==105 :
    return 17
  elif 106<=s_id and s_id<=116 :
    return 18
  elif 117<=s_id and s_id<=118 :
    return 19
  elif 119<=s_id and s_id<=123 :
    return 20
  elif 124<=s_id and s_id<=133 :
    return 21
  elif 134<=s_id and s_id<=139 :
    return 22
  elif 140<=s_id and s_id<=142 :
    return 23
  elif 143<=s_id and s_id<=145 :
    return 24
  elif 146<=s_id and s_id<=147 :
    return 25
  else :
    return 26
  return o

In [88]:
correct_id = dict()
l = [0,14,22,29,40,46,49,50,62,69,78,90,91,93,95,96,97,105,106,117,119,124,134,140,143,146,148]
for idx, i in enumerate(large2idx) :
  correct_id[i] = l[idx]

In [89]:
# 데이터가 저장된 pkl path를 로드하고
# target_large_category 에 맞는 데이터만 뽑는다
class Food_DataSet(torch.utils.data.Dataset) : 
    def __init__(self, args, is_validated) :
        self.args = args
        if is_validated == True:
          self.args.mode = 'val'
          data_path = args.val_path
        else :
          data_path = args.data_path
        
        large_category = large2idx[args.category]
        self.data = []

        with open(data_path, 'rb') as f:
          total_data = pickle.load(f)
          for d, i in total_data : 
            idx = convert_to_large_id(i)
            if idx != large_category :
              continue
            sub = l[idx]
            self.data.append((d, i - sub))
        
    def __len__(self) : 
        return len(self.data)
        
    def __getitem__(self, i) :
        return self.data[i]


In [90]:
# 데이터가 저장된 pkl path를 로드하고
# target_large_category 에 맞는 데이터만 뽑는다
class Food_DataSet(torch.utils.data.Dataset) : 
    def __init__(self, args, is_validated) :
        self.args = args
        if is_validated == True:
          self.args.mode = 'val'
          data_path = args.val_path
        else :
          data_path = args.data_path
        
        large_category = large2idx[args.category]
        self.data = []

        with open(data_path, 'rb') as f:
          total_data = pickle.load(f)
          for d, i in total_data : 
            idx = convert_to_large_id(i)
            if idx != large_category :
              continue
            sub = l[idx]
            self.data.append((d, i - sub))
        
    def __len__(self) : 
        return len(self.data)
        
    def __getitem__(self, i) :
        return self.data[i]


In [91]:
class Food_DataLoader(object) :
    def __init__(self, args, is_validated=False, df=None):
        super().__init__()

        if df is None:
            dataset = Food_DataSet(args, is_validated)
        else:
            dataset = INFERENCE_Dataset(args, df)

        if args.mode == 'train':
            self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
        else:
            self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
        
        self.data_iter = self.data_loader.__iter__()
    
    def next_batch(self):
        try:
            batch = self.data_iter.__next__()
        except StopIteration:
            self.data_iter = self.data_loader.__iter__()
            batch = self.data_iter.__next__()
        
        return batch

In [92]:
# model
class CNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=6, stride=2, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=6, stride=2, padding=2)
        self.fc1 = nn.Linear(512, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.n_class)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(x.size(0), -1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [93]:
print(os.getcwd())

C:\Users\82108\Downloads\Data


In [95]:
# trainer
class Trainer():
  def __init__(self, args):
    self.args = args
    
    if self.args.mode != 'inference':
        self.loader = Food_DataLoader(args, is_validated=False)
    self.model = CNN(args)
    if self.args.device == 'cuda':
        self.model.cuda()

    self.criterion = nn.CrossEntropyLoss()
    self.optim = torch.optim.Adam(self.model.parameters())
      
  def train(self):
    # train
    total_step = len(self.loader.data_loader)
    val_loader = Food_DataLoader(self.args, is_validated=True)
    opt_epoch = 1
    min_val_loss = 1e9

    for epoch in range(self.args.epoch):
      self.model.train()
      total_loss = 0

      for i in range(len(self.loader.data_loader)):
          feature, label = self.loader.next_batch()
          feature = torch.tensor(feature).to(device=self.args.device, dtype=torch.float)
          label = torch.tensor(label).to(device=self.args.device)

          pred = self.model(feature)
          loss = self.criterion(pred, label)
          
          self.optim.zero_grad()
          loss.backward()
          self.optim.step()

          total_loss += loss.item()
          if (i + 1) % 1000 == 0:
              print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, self.args.epoch, i + 1, total_step, loss.item()))

      #if (epoch+1)%50 == 0:
      val_loss = self.validate(val_loader)
      print ('Validation Loss: {:.20f}'.format(val_loss))
      if val_loss < min_val_loss:
        val_loss = min_val_loss
        opt_epoch = epoch+1
        
      torch.save(self.model.state_dict(), 'bbokum\\checkpoint_'+str(epoch+1)+'.pt')

    print("Best Epoch : {}".format(opt_epoch))

  def validate(self, val_loader):
      total_loss = 0
      total_step = len(val_loader.data_loader)

      self.model.eval()
      with torch.no_grad():
          for i in range(len(val_loader.data_loader)):
              feature, label = val_loader.next_batch()
              feature = torch.tensor(feature).to(device=self.args.device, dtype=torch.float)
              label = torch.tensor(label).to(device=self.args.device)

              pred = self.model(feature)
              loss = self.criterion(pred, label)
              total_loss += loss.item()

              if (i + 1) % 100 == 0:
                  print ('Validation Step [{}/{}], Loss: {:.4f}'.format(i + 1, total_step, loss.item()))

      return total_loss/total_step


  def test(self):
      # test
      if self.args.device == 'cuda':
          self.model.load_state_dict(torch.load(self.args.ckpt))
      else:
          self.model.load_state_dict(torch.load(self.args.ckpt, map_location=torch.device('cpu')))
      
      self.model.eval()

      pred_list = list()
      label_list = list()

      with torch.no_grad():
          for i in range(len(self.loader.data_loader)):
              feature, label = self.loader.next_batch()
              feature = torch.tensor(feature).to(device=self.args.device, dtype=torch.float)
              label = torch.tensor(label).to(device=self.args.device)

              pred = self.model(feature)
              pred = F.softmax(pred, dim=1)
              pred = torch.argmax(pred, dim=1)
              
              pred_list.extend(pred.tolist())
              label_list.extend(label.tolist())

      acc = accuracy_score(label_list, pred_list)
      print(acc)
  
  def inference(self, df):
      # inference
      if self.args.device == 'cuda':
          self.model.load_state_dict(torch.load(self.args.ckpt))
      else:
          self.model.load_state_dict(torch.load(self.args.ckpt, map_location=torch.device('cpu')))
      self.model.eval()
      
      loader = Food_DataLoader(self.args, df=df)

      pred_list = list()
      label_list = list()

      with torch.no_grad():
          for i in range(len(loader.data_loader)):
              feature = loader.next_batch()
              feature = torch.tensor(feature).to(device=self.args.device, dtype=torch.float)

              pred = self.model(feature)
              pred = F.softmax(pred, dim=1)
              pred = torch.argmax(pred, dim=1)
              
              pred_list.extend(pred.tolist())

      return pred_list

In [99]:
def get_args():
    args = EasyDict({
        "epoch":100,
        "batch_size":10,
        "mode":'train',
        "ckpt":1,
        "device":'cuda',
        "category":'볶음',
        "n_class":12,
        "data_path":'train.pkl',
        "val_path":'val.pkl'
    })
    return args
def get_args2() :
  args = EasyDict({
      "mode":'test',
      "batch_size":1,
      "ckpt":'bbokum\\checkpoint_final.pt',
      "device":'cuda',
      "category":'볶음',
      "n_class":12,
  })
  return args

In [100]:
args = get_args()
trainer = Trainer(args)

if args.mode == 'train':
    trainer.train()
else:
    trainer.test()

  feature = torch.tensor(feature).to(device=self.args.device, dtype=torch.float)
  label = torch.tensor(label).to(device=self.args.device)
  feature = torch.tensor(feature).to(device=self.args.device, dtype=torch.float)
  label = torch.tensor(label).to(device=self.args.device)


Validation Step [100/164], Loss: 2.4973
Validation Loss: 2.48468226630513244757
Validation Step [100/164], Loss: 2.4985
Validation Loss: 2.46616266558809993370
Validation Step [100/164], Loss: 2.4142
Validation Loss: 2.37509409101997936986
Validation Step [100/164], Loss: 2.3114
Validation Loss: 2.28048945490906884714
Validation Step [100/164], Loss: 2.5310
Validation Loss: 2.26188282486869063348
Validation Step [100/164], Loss: 2.4446
Validation Loss: 2.25216786163609183902
Validation Step [100/164], Loss: 2.4818
Validation Loss: 2.23921628696162544614
Validation Step [100/164], Loss: 2.2206
Validation Loss: 2.24123116818869982936
Validation Step [100/164], Loss: 2.2648
Validation Loss: 2.24069168654883776881
Validation Step [100/164], Loss: 2.2193
Validation Loss: 2.34920832078631347528
Validation Step [100/164], Loss: 2.2440
Validation Loss: 2.36519997803176318385
Validation Step [100/164], Loss: 2.3351
Validation Loss: 2.35265379053790413266
Validation Step [100/164], Loss: 2.0737


In [98]:
args = get_args2()
trainer = Trainer(args)

if args.mode == 'train':
    trainer.train()
else:
    trainer.test()

AttributeError: 'EasyDict' object has no attribute 'data_path'