In [None]:
import random
import numpy as np
import torch
import time
from torch.optim import Adam

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN settings.

    Args:
        seed: Integer seed applied to every random number generator.

    Note:
        Besides selecting deterministic kernels and disabling autotuning,
        cuDNN is switched off entirely (``enabled = False``), trading GPU
        speed for reproducibility.
    """
    # Host-side generators.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators: the current device, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic kernels, no benchmark autotuning, and disabled outright.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that reseeds ``random`` and NumPy from torch.

    Takes the torch initial seed of the calling process (inside a DataLoader
    worker, that worker's seed), truncates it to 32 bits (the range NumPy's
    legacy seeding accepts), and uses it for both Python's ``random`` and NumPy.

    Args:
        worker_id: Index of the worker process; unused, but required by the
            ``worker_init_fn`` signature.
    """
    derived = torch.initial_seed() % (1 << 32)
    for reseed in (random.seed, np.random.seed):
        reseed(derived)

def get_optimizer(cfg, model):
    """Build the optimizer selected by ``cfg.hparam.train.optimizer``.

    Args:
        cfg: Config object exposing ``hparam.train.optimizer`` (optimizer name)
            and ``hparam.train.lr`` (learning rate).
        model: Module whose ``parameters()`` are optimized.

    Returns:
        The constructed optimizer (currently only ``'Adam'`` is supported).

    Raises:
        ValueError: If the configured optimizer name is not supported.
            Previously this case silently returned ``None``, which only
            surfaced later as an ``AttributeError`` in the training loop.
    """
    name = cfg.hparam.train.optimizer
    if name == 'Adam':
        return Adam(
            params=model.parameters(),
            lr=cfg.hparam.train.lr,
        )
    raise ValueError(f"Unsupported optimizer: {name!r}")

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: Module to train; it is put into training mode first.
        train_loader: Iterable yielding ``(images, labels)`` batches. It does
            not need to support ``len()``; batches are counted while iterating.
        criterion: Callable ``criterion(outputs, labels)`` returning a
            single-element loss tensor.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Target passed to ``Tensor.to`` for every batch.

    Returns:
        ``(avg_loss, elapsed)``: the unweighted mean of the per-batch losses
        (``nan`` when the loader yields no batches) and the elapsed time of
        the epoch in seconds.
    """
    print('start train')
    # perf_counter is monotonic, so it is not affected by system clock changes.
    start = time.perf_counter()
    model.train()

    total_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Mean of per-batch losses; an empty loader has no defined average.
    avg_loss = total_loss / num_batches if num_batches else float('nan')

    elapsed = time.perf_counter() - start
    print('finish train')

    return avg_loss, elapsed


# Evaluation pass (e.g. on the validation set).
def evaluate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Args:
        model: Module to evaluate; it is put into eval mode first.
        loader: Iterable yielding ``(images, labels)`` batches. It does not
            need to support ``len()``; batches are counted while iterating.
        criterion: Callable ``criterion(outputs, labels)`` returning a
            single-element loss tensor.
        device: Target passed to ``Tensor.to`` for every batch.

    Returns:
        ``(avg_loss, accuracy, elapsed)``:

        - ``avg_loss``: the unweighted mean of the per-batch losses.
        - ``accuracy``: the fraction of samples whose ``argmax`` over dim 1
          equals the label.
        - ``elapsed``: the elapsed time in seconds.

        ``avg_loss`` and ``accuracy`` are ``nan`` when the loader yields no data.
    """
    print('start evaluate')
    # perf_counter is monotonic, so it is not affected by system clock changes.
    start = time.perf_counter()
    model.eval()

    correct = 0
    total = 0
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            num_batches += 1

            # Class index the model picked for each image.
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Guard both ratios: an empty loader would otherwise divide by zero.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    accuracy = correct / total if total else float('nan')

    elapsed = time.perf_counter() - start
    print('finish evaluate')

    return avg_loss, accuracy, elapsed
