In [None]:
import common

import os
import time
import random

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

from dotenv import load_dotenv
from datetime import datetime
from zoneinfo import ZoneInfo
import wandb

In [None]:
# Log in to Weights & Biases and start this experiment's run via the project helper.
# NOTE(review): presumably this calls wandb.init(), since the config cell below calls
# wandb.config.update(...) and the end of the notebook calls wandb.finish() — confirm in `common`.
# NOTE(review): the run name is 'SKF_' + the model id, hard-coded here; it duplicates
# `model_name` from the config cell, so keep the two in sync when switching models.
common.wandb_login_init('SKF_tiny_vit_21m_384.dist_in22k_ft_in1k')

In [None]:
# --- Runtime device: use the GPU when torch can see one, otherwise fall back to CPU ---
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Data location ---
# NOTE(review): data_path is not passed to any `common` helper in this notebook — confirm it is still needed.
data_path = 'datasets_fin/'

# --- Model: a timm model identifier (alternatives tried: 'resnet50', 'efficientnet-b0', ...) ---
model_name = 'tiny_vit_21m_384.dist_in22k_ft_in1k'

# --- Training hyper-parameters ---
trn_img_size = tst_img_size = 384   # train / test input resolution
LR = 1e-3                           # learning rate handed to the optimizer factory in `common`
EPOCHS = 5                          # epochs trained per fold
FOLDS = 2                           # number of cross-validation folds
BATCH_SIZE = 32
num_workers = 12                    # loader worker count handed to `common`
augment_ratio = 200                 # passed through to `common`; meaning defined there

# Record the hyper-parameters on the active wandb run (same keys, same order as always).
wandb.config.update(dict(
    learning_rate=LR,
    architecture=model_name,
    dataset="custom-dataset",
    epochs=EPOCHS,
    folds=FOLDS,
    batch_size=BATCH_SIZE,
    train_image_size=trn_img_size,
    test_image_size=tst_img_size,
    num_workers=num_workers,
    augment_ratio=augment_ratio,
))

In [None]:
# Fix the random seed through the project helper so runs are repeatable
# (which RNGs are seeded is decided inside common.set_seed).
SEED = 42
common.set_seed(SEED)

In [None]:
# Show the selected torch device (cuda when available, otherwise cpu) as this cell's output.
device

### Load Data & K-Fold Training
* 폴드별 학습/검증 CSV 파일을 먼저 만든 뒤, 각 폴드마다 학습·검증 데이터셋과 로더, 모델을 준비해 학습합니다.

In [None]:
# First write the per-fold CSV files used for training and validation
# (the split is driven by SEED and FOLDS).
# They are kept on disk so a later session can resume training on the same splits.
common.generate_fold_train_valid_csv_files(SEED, FOLDS)

In [None]:
# K-fold training. For every fold, `common` builds fresh loaders / model / optimizer / loss,
# which are then trained for epochs 1..EPOCHS (end_epoch_exclusive is one past the last epoch).
# Checkpoint saving and train/valid evaluation are switched on via the flags passed below.
for fold in range(1, FOLDS + 1):  # folds are 1-based: 1..FOLDS
    print(f"Fold {fold}/{FOLDS}")

    # Release the previous fold's objects before building the next fold's. The right-hand
    # side below is evaluated before `supplies` is rebound, so without this line the old
    # model/optimizer/loaders stay referenced while the new ones are allocated, and peak
    # (GPU) memory can hold two folds' worth of state. After the loop, `supplies` still
    # holds the last fold's objects.
    supplies = None

    supplies = common.get_supplies_for_train_and_valid_with_fold(
        seed=SEED,
        model_name=model_name,
        lr=LR,
        batch_size=BATCH_SIZE,
        num_workers=num_workers,
        fold=fold,
        folds=FOLDS,
        augment_ratio=augment_ratio,
        trn_img_size=trn_img_size,
        tst_img_size=tst_img_size,
        device=device,
    )

    common.train_with_start_end_epoch(
        seed=SEED,
        tst_img_size=tst_img_size,
        batch_size=BATCH_SIZE,
        start_epoch_inclusive=1,
        end_epoch_exclusive=EPOCHS + 1,
        augment_ratio=augment_ratio,
        trn_loader=supplies['trn_loader'],
        val_loader=supplies['val_loader'],
        model=supplies['model'],
        model_name=model_name,
        optimizer=supplies['optimizer'],
        loss_fn=supplies['loss_fn'],
        device=device,
        is_save_model_checkpoint=True,
        is_evaluate_train_valid=True,
        fold=fold,
        folds=FOLDS,
    )

In [None]:
# End the wandb run.
wandb.finish()

In [None]:
# Deliberate stop: "Run All" halts here with this message, so the manual
# checkpoint-resume cells below only ever run when executed by hand.
import sys

raise SystemExit('아래 셀은 수동으로 실행하기 위해서 여기서 실행 멈춤.')

# 수동으로 특정 체크포인트부터 이어서 학습하기 위한 부분


In [None]:
# Manual resume, step 1: load a saved checkpoint into the model and optimizer.
# NOTE(review): `model` and `optimizer` are not defined at notebook top level. The training loop
# only exposes them as supplies['model'] / supplies['optimizer']. Build `supplies` for the fold
# being resumed with common.get_supplies_for_train_and_valid_with_fold (as in the training loop)
# and pass those entries instead.
# NOTE(review): the filename below names resnet34, but this notebook trains `model_name`
# ('tiny_vit_21m_384.dist_in22k_ft_in1k'). Point it at a checkpoint written by this run.
# TODO: confirm the filename format used by common.train_with_start_end_epoch.
# checkpoint = common.load_model_checkpoint("checkpoint-resnet34_seed_42_epoch_0_isFull_False.pt", model, optimizer, device)

In [None]:
# Manual resume, step 2: continue training from the epoch after the one stored in the checkpoint.
# NOTE(review): trn_loader, val_loader, model, optimizer and loss_fn are not defined at notebook
# top level. The training loop only exposes them through the `supplies` dict
# (supplies['trn_loader'], supplies['val_loader'], supplies['model'], supplies['optimizer'],
# supplies['loss_fn']).
# NOTE(review): the live call in the training loop also passes fold=... and folds=FOLDS, which
# this call omits. TODO: confirm whether train_with_start_end_epoch needs them when resuming.
# NOTE(review): end_epoch_exclusive is an exclusive bound, so next_epoch + 2 presumably trains
# two more epochs (next_epoch and next_epoch + 1).
# next_epoch = checkpoint['epoch'] + 1

# common.train_with_start_end_epoch(seed = checkpoint['seed'],
#                            tst_img_size = checkpoint['tst_img_size'],
#                            batch_size = checkpoint['batch_size'],
#                            start_epoch_inclusive = next_epoch, 
#                            end_epoch_exclusive = next_epoch + 2, 
#                            augment_ratio = augment_ratio,
#                            trn_loader = trn_loader,
#                            val_loader = val_loader,
#                            model = model,
#                            model_name = model_name,
#                            optimizer = optimizer,
#                            loss_fn = loss_fn,
#                            device = device,
#                            is_save_model_checkpoint = True,
#                            is_evaluate_train_valid = True)