# In [None]:
import torch
import copy
import os
import matplotlib.pyplot as plt
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

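"""
Trainer that takes a model and data, trains and evaluates it, and saves the
trained model and its predictions.

The inputs are a training DataLoader and a validation DataLoader.
The model does not need to be on the GPU beforehand; it is moved to the given
`device` automatically.
The scheduler is optional.
The default training loss is cross entropy (evaluation always reports
cross-entropy loss).

early_stop ends training when validation accuracy stops improving, using
`tolerance` as the patience.
make_pred computes predictions and writes them to a CSV file under the working
directory.
save_model saves the early-stopped (best) model under ./model_save/.

TODO: split this into separate classes; the current structure is not ideal.
"""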
class train_the_model():
  """Train a PyTorch classifier with accuracy-based early stopping, then save or predict.

  Typical use: construct, call ``train_val_epoch(tolerance)``, then
  ``save_model(name)`` and/or ``make_pred(x, ids, file_name)``. The last two
  both use ``self.model_es``, the best snapshot kept by early stopping.

  Args:
    train_data_loader: iterable of ``(features, labels)`` batches used for
      training. Its ``.dataset`` length is used to average the metrics.
    val_data_loader: same as above, for validation.
    device: device that the model and every batch are moved to.
    epochs: maximum number of training epochs. Early stopping may end sooner.
    model: module to train. It is moved to ``device`` (``nn.Module.to`` works
      in place, so the caller's object is moved too).
    optimizer: optimizer that updates the model. It is expected to have been
      built over ``model.parameters()``; this is not checked here.
    batch_size: stored only. Batching is done by the data loaders.
    scheduler: optional. If given, ``scheduler.step()`` is called after every
      training batch, not once per epoch.
    loss_function: training loss, called as ``loss_function(output, labels)``.
      Evaluation always uses ``F.cross_entropy`` regardless of this argument.
  """

  def __init__(self, train_data_loader, val_data_loader, device, epochs, model, optimizer, batch_size, scheduler = None, loss_function = F.cross_entropy):
    self.train_data_loader = train_data_loader
    self.val_data_loader = val_data_loader
    self.device = device
    self.epochs = epochs
    self.model = model.to(self.device)
    self.optimizer = optimizer
    self.batch_size = batch_size
    # Until early stopping records a better model, the "best" model is a copy of
    # the initial one, so save_model / make_pred always have something to use.
    self.model_es = copy.deepcopy(self.model)
    self.scheduler = scheduler
    self.loss_function = loss_function
    # Early-stopping state. 'start' marks "no validation result seen yet".
    # train_val_epoch resets this state at the beginning of every run.
    self.min_acc = 'start'
    self.min_loss = None
    self.count = 0

  def _train(self):
    """Run one pass over ``train_data_loader``, taking an optimizer step per batch."""
    self.model.train()

    for _feature, _label in self.train_data_loader:
      _feature, _label = _feature.to(self.device), _label.to(self.device)
      self.optimizer.zero_grad()
      output = self.model(_feature)
      loss = self.loss_function(output, _label)
      loss.backward()
      self.optimizer.step()

      # Stepped per batch: a scheduler written for per-epoch steps decays faster.
      if self.scheduler is not None:
        self.scheduler.step()

  def _evaluate_loader(self, loader):
    """Return ``(mean cross-entropy loss, accuracy in percent)`` of the model on ``loader``.

    The caller is responsible for eval mode and ``torch.no_grad()``.
    Raises ZeroDivisionError if ``loader.dataset`` is empty.
    """
    total_loss = 0.0
    correct = 0
    for feat, lab in loader:
      feat, lab = feat.to(self.device), lab.to(self.device)
      output = self.model(feat)
      # Sum per-sample losses so the average is exact even if the last batch is smaller.
      total_loss += F.cross_entropy(output, lab, reduction = "sum").item()
      pred = output.argmax(dim = 1, keepdim = True)
      correct += pred.eq(lab.view_as(pred)).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, 100 * correct / n_samples

  def evaluate(self):
    """Evaluate the current model on both loaders.

    Returns:
      ``(train_loss, train_acc, val_loss, val_acc)``. Losses are mean
      cross-entropy per sample; accuracies are percentages in ``[0, 100]``.
    """
    self.model.eval()
    with torch.no_grad():
      val_loss, val_acc = self._evaluate_loader(self.val_data_loader)
      train_loss, train_acc = self._evaluate_loader(self.train_data_loader)
    return train_loss, train_acc, val_loss, val_acc

  def early_stop(self, val_loss_es, val_acc_es, model, tolerance):
    """Track the best validation accuracy and signal when to stop.

    The first call always records a baseline: it snapshots the model and does
    not count against ``tolerance``. Later calls snapshot the model whenever
    accuracy strictly improves. Otherwise they increment the miss counter.

    Args:
      val_loss_es: validation loss of the current epoch. It is stored
        alongside the best accuracy.
      val_acc_es: validation accuracy of the current epoch.
      model: unused; the snapshot is always taken from ``self.model``. Kept for
        interface compatibility.
      tolerance: number of consecutive non-improving epochs that triggers a stop.

    Returns:
      ``"break"`` when training should stop, otherwise ``None``.
    """
    # Short-circuit keeps the 'start' sentinel from being compared with a float.
    if self.min_acc == 'start' or val_acc_es > self.min_acc:
      self.count = 0
      self.min_loss = val_loss_es
      self.min_acc = val_acc_es
      self.model_es = copy.deepcopy(self.model)
      return None

    self.count += 1
    if self.count >= tolerance:
      print("================================================================================================================================")
      print("================================================================================================================================")
      print("====================================================Learning the Data is Over===================================================")
      print(f"============================final accuracy : {self.min_acc} =========== final loss : {self.min_loss}============================")
      return "break"
    return None

  def train_val_epoch(self, tolerance):
    """Train for up to ``self.epochs`` epochs with early stopping, then plot the history.

    Per-epoch metrics are kept in ``train_loss_list``, ``train_acc_list``,
    ``val_loss_list`` and ``val_acc_list``. The number of epochs actually run
    is kept in ``epoch_num``.

    Args:
      tolerance: patience passed to ``early_stop``.
    """
    self.min_acc = 'start'
    self.min_loss = None
    self.count = 0
    self.val_loss_list = []
    self.val_acc_list = []
    self.train_loss_list = []
    self.train_acc_list = []

    self.epoch_num = 0
    # Epochs are numbered 1..self.epochs inclusive.
    for epoch in range(1, self.epochs + 1):
      start = time.time()
      self.epoch_num = epoch
      self._train()
      train_l, train_ac, val_l, val_ac = self.evaluate()

      self.val_loss_list.append(val_l)
      self.val_acc_list.append(val_ac)
      self.train_loss_list.append(train_l)
      self.train_acc_list.append(train_ac)

      end = time.time()
      # Report every epoch, including the one that triggers early stopping.
      print(f"===================================================={self.epoch_num} epoch 완료====================================================")
      print(f"TRAIN loss : {round(train_l, 4)} accuracy : {round(train_ac, 4)}")
      print(f"VALID loss : {round(val_l, 4)} accuracy : {round(val_ac, 4)}")
      print(f"duration: {round(end - start, 1)}")

      if self.early_stop(val_l, val_ac, self.model, tolerance) == "break":
        break

    self._plot_history()

  def _plot_history(self):
    """Plot per-epoch accuracy (top) and loss (bottom) for train and validation."""
    self.epochs_list = range(1, self.epoch_num + 1)

    plt.figure(figsize = (12, 3))
    plt.subplot(2, 1, 1)
    plt.plot(self.epochs_list, self.train_acc_list, 'b', label='train accuracy')
    plt.plot(self.epochs_list, self.val_acc_list, 'g', label='val accuracy')
    plt.title('Train, Validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(2, 1, 2)
    plt.plot(self.epochs_list, self.train_loss_list, 'b', label='train loss')
    plt.plot(self.epochs_list, self.val_loss_list, 'g', label='val loss')
    plt.title('Train, Validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('loss')
    plt.legend()

    plt.tight_layout()
    plt.show()

  def save_model(self, model_name):
    """Pickle the best (early-stopped) model with ``torch.save`` to ``./model_save/<model_name>.pt``.

    The ``./model_save`` directory is created if it does not exist.
    """
    os.makedirs('./model_save', exist_ok = True)
    torch.save(self.model_es, f'./model_save/{model_name}.pt')

  def make_pred(self, test_data_x, test_data_label, file_name):
    """Predict classes with the best model and write them as a CSV file.

    Args:
      test_data_x: input tensor, passed to the model in a single forward pass.
      test_data_label: iterable of identifiers for the ``ID`` column, one per
        row of ``test_data_x``.
      file_name: appended verbatim to ``os.getcwd() + '/trial'``. No separator
        is inserted, so ``'.csv'`` gives ``<cwd>/trial.csv``, while
        ``'/pred.csv'`` gives ``<cwd>/trial/pred.csv`` (that directory must exist).
    """
    import pandas as pd  # pandas is not among this file's top-level imports

    self.model_es.eval()
    self.test_x = test_data_x.to(self.device)
    self.test_label = test_data_label
    # Inference only: avoid building an autograd graph over the whole test set.
    with torch.no_grad():
      self.pred = self.model_es(self.test_x)
    self.pred_y = self.pred.argmax(dim = 1, keepdim = True).cpu().numpy()

    self.tmp = [a[0] for a in self.pred_y]
    self.tmp2 = list(self.test_label)

    self.trial_df = pd.DataFrame({'ID' : self.tmp2, "Category" : self.tmp})
    self.path = os.getcwd() + '/trial'
    self.trial_df.to_csv(self.path + file_name, index = False)
