# Auto-encoder

In [None]:
import configparser
import numpy as np
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import sys
import torch
import torch.nn as nn

from datetime import datetime
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from src.dataset import BirdsongDataset
from src.network import AutoEncoder

In [None]:
# Load the settings file found two directories above the working directory.
settingsFile = Path.cwd().parent.parent / 'setting' / 'config.ini'
config = configparser.ConfigParser()
config.read(str(settingsFile))

# Training hyper-parameters, all read from the [Model] section.
modelSection = config['Model']
EPOCHS = modelSection.getint('Epochs')
BATCH_SIZE = modelSection.getint('BatchSize')
LEARNING_RATE = modelSection.getfloat('LearningRate')
EARLY_STOP = modelSection.getint('EarlyStop')

# Seed torch's random number generator with a fixed value.
torch.manual_seed(42)

# Run on the configured CUDA device when CUDA is available, otherwise on CPU.
# The device index is only looked up in the config when CUDA is present.
cudaAvailable = torch.cuda.is_available()
DEVICE = torch.device(
  f'cuda:{modelSection["Autoencoder_Device"]}' if cudaAvailable else 'cpu'
)
if cudaAvailable:
  # Turn on cuDNN benchmark mode for GPU runs only.
  torch.backends.cudnn.benchmark = True

In [None]:
# Output files, both under the `model` directory two levels above the working
# directory: the resumable training checkpoint and the encoder weights, the
# latter stamped with today's date (YYYYMMDD).
modelDir = Path.cwd().parent.parent / 'model'
dateStamp = datetime.now().strftime("%Y%m%d")
earlyStatusPath = modelDir / 'AE_CheckPoint.tar'
encoderFilePath = modelDir / f'AE{dateStamp}_encoder.pth'

# Reconstruction criterion: mean squared error, placed on the target device.
criterion = nn.MSELoss().to(DEVICE)

# Build the network on the target device first, then hand its parameters to
# the SGD optimizer (momentum 0.9, learning rate from the config).
model = AutoEncoder()
model = model.to(DEVICE)
optimizer = torch.optim.SGD(
  model.parameters(),
  lr=LEARNING_RATE,
  momentum=0.9,
)

In [None]:
# Resume from the checkpoint written at the end of every epoch if one exists;
# otherwise start a fresh run.
if earlyStatusPath.exists():
  # DEVICE is already a torch.device, so it can be passed as map_location
  # directly to remap the stored tensors onto the device used for this run.
  ck = torch.load(earlyStatusPath, map_location=DEVICE)
  model.load_state_dict(ck['model_state_dict'])
  optimizer.load_state_dict(ck['optimizer_state_dict'])
  # The checkpoint records the index of the epoch that had just FINISHED when
  # it was saved, so training must continue with the following epoch;
  # restarting at the stored index would repeat an already completed epoch.
  curEpoch = ck['current_epoch'] + 1
  bestLoss = ck['best_loss']
else:
  curEpoch = 0
  # `np.Inf` was removed in NumPy 2.0; the builtin float infinity behaves the
  # same in the `validationLoss < bestLoss` comparison.
  bestLoss = float('inf')

In [None]:
# Both CSV files live in data/tmp, two directories above the working directory.
dataDir = Path.cwd().parent.parent.joinpath('data', 'tmp')

# Training batches are reshuffled every epoch.
aeTrainDataloader = DataLoader(
  BirdsongDataset(dataDir.joinpath('ae-train.csv'), needAugment=False, needLabel=False),
  batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)

# Validation only measures the loss, so shuffling brings no benefit. A fixed
# order keeps the batch composition (including the last, possibly smaller,
# batch) identical across epochs, which makes the batch-averaged validation
# loss comparable from one epoch to the next.
aeValidateDataloader = DataLoader(
  BirdsongDataset(dataDir.joinpath('ae-validate.csv'), needAugment=False, needLabel=False),
  batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True
)

## Training

In [None]:
# Main loop: for every remaining epoch, train on the training set, measure the
# loss on the validation set, keep the encoder weights of the best epoch so
# far, and write a checkpoint from which the run can be resumed.
# NOTE(review): EARLY_STOP is read from the config but is not used by any
# visible code; no early-stopping logic is applied here. Confirm intent.
for epoch in tqdm(range(curEpoch, EPOCHS)):
  # Train
  model.train()
  trainingLoss = 0.0
  # Each batch is a pair; only the first element (the inputs) is used.
  for inputs, _ in tqdm(aeTrainDataloader):
    inputs = inputs.to(DEVICE)
    optimizer.zero_grad()
    # The model returns a pair; the second element is compared with the
    # inputs, i.e. it is treated as the reconstruction.
    _, outputs = model(inputs)
    loss = criterion(outputs, inputs)
    loss.backward()
    optimizer.step()
    trainingLoss += loss.item()
  # Mean of the per-batch losses.
  trainingLoss /= len(aeTrainDataloader)

  # Validate
  model.eval()
  validationLoss = 0.0
  with torch.no_grad():
    for inputs, _ in tqdm(aeValidateDataloader):
      inputs = inputs.to(DEVICE)
      _, outputs = model(inputs)
      loss = criterion(outputs, inputs)
      validationLoss += loss.item()
  validationLoss /= len(aeValidateDataloader)

  # Check loss: keep only the encoder weights of the best epoch so far.
  if validationLoss < bestLoss:
    bestLoss = validationLoss
    # `state_dict` (previously misspelled `stat_dict`, which raised
    # AttributeError the first time the validation loss improved).
    torch.save(model.encoder.state_dict(), encoderFilePath)

  # Save the resumable training status; `current_epoch` is the index of the
  # epoch that has just finished.
  torch.save({
    'current_epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_loss': bestLoss,
  }, earlyStatusPath)

  # Print results
  print(f"""
    >> [{epoch + 1} / {EPOCHS}]
    >> {"Best Loss :":>16} {bestLoss}
    >> {"Current Train Loss :":>16} {trainingLoss:6f}
    >> {"Current Validate Loss :":>16} {validationLoss:6f}
  """)