In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
!pip install torcheval -q
!pip install wandb -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.2/179.2 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m34.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.0/190.0 kB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m241.0/241.0 kB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [1]:
import os, sys, pickle
import math
import pandas as pd
import numpy as np
from collections import OrderedDict

from glob import glob
from zipfile import ZipFile
from tqdm import tqdm

# torch
import torch
from torch import nn
from torch.utils.data import DataLoader
from torcheval.metrics.functional import multiclass_accuracy, multiclass_f1_score

from dotenv import load_dotenv

load_dotenv()
container = os.environ.get('SSLRv2_container') # 경로 설정
wandb_login_key = os.environ.get('wandb_login_key')

modules_path = f'{container}/modules/'
sys.path.append(modules_path)

from Callbacks import ModelCheckpoint, EarlyStopping, validation_step_kp

import wandb
tqdm.pandas()

torch.set_default_dtype(torch.float64)

import warnings
warnings.filterwarnings("ignore")

import logging
logging.getLogger().setLevel(logging.ERROR)

---
# 압축파일 해제 및 불필요파일 삭제

In [5]:
# 영상에서 미리 추출한 keypoint 파일(pkl)의 압축파일
train_zip = f'{container}/data/train_keypoints(main).zip'
val_zip = f'{container}/data/Validation_keypoints.zip'

In [6]:
ZipFile(train_zip).extractall()

In [7]:
ZipFile(val_zip).extractall()

In [8]:
# filename2frame: {파일명 : 프레임 수}
with open(f'{container}/data/temp/filename2frame(main).pkl', 'rb') as f:
    train_name2frame, val_name2frame = pickle.load(f)

# 잘못 저장된 파일 확인(아무 프레임도 담겨있지 않은 깡통 파일)
train_zeros = [name for name, frame in train_name2frame.items() if frame == 0]
val_zeros = [name for name, frame in val_name2frame.items() if frame == 0]

# 디렉토리 내 pkl 개수
train_pkls = glob(f'{container}/train_keypoints(main)/**/*.pkl', recursive=True)
val_pkls = glob(f'{container}/Validation_keypoints(modified)/**/*.pkl', recursive=True)

# {이름:경로}
train_name2path = {path.split('/')[-1].rstrip('.pkl'):path for path in train_pkls}
val_name2path = {path.split('/')[-1].rstrip('.pkl'):path for path in val_pkls}

len(train_pkls), len(val_pkls)

(15837, 32000)

In [9]:
# 잘못 저장된 깡통 파일이 있다면 삭제
for path in train_zeros:
    if path in train_name2frame and path in train_name2path:
        os.remove(train_name2path[path])

for path in val_zeros:
    if path in val_name2frame and path in val_name2path:
        os.remove(val_name2path[path])

In [10]:
# 디렉토리 내 pkl 개수(깡통 파일 제거 후)
train_pkls = glob(f'{container}/train_keypoints(main)/**/*.pkl', recursive=True)
val_pkls = glob(f'{container}/Validation_keypoints(modified)/**/*.pkl', recursive=True)

len(train_pkls), len(val_pkls)

(15837, 31971)

---

# Augmentation

In [32]:
class CustomAugmentation:
    """ input data(프레임별 키포인트 좌표값)에 augmentation을 적용하는 코드입니다. """
    
    def __init__(self, mode: str = 'train', num_augmentation: int = 10, num_landmarks: int = 23, maxlen: int = 50):
        """
        Args:
            mode: 'train' | 'val'
            num_augmentation: 데이터 증강 배수
            num_landmarks: 사용할 랜드마크 수 (23: pose(상반신) | 33: pose | 65: holistic(하반신 제거) | 75: holistic)
            maxlen: 시퀀스 최대길이 지정
        """
        self.mode = mode
        self.num_augmentation = num_augmentation
        self.num_mask_range = range(0, 15)
        self.num_landmarks = num_landmarks
        self.maxlen = maxlen
        self.coor_dim = 2 # 2(x, y 좌표) | 3(x, y, z 좌표. 미구현)


    def act(self, x: np.ndarray, y: np.ndarray, start_frame: int) -> (torch.tensor, torch.tensor):
        """ self.num_augmentation 수만큼 데이터를 증강한 뒤 torch.tensor로 변환합니다.
        Args: 
            x (numpy.ndarray): X data 
            y (numpy.ndarray): y data 
            start_frame (int): 프레임 내 시작 동작 정보(label에서 파싱한 정보)
        
        Returns: (X_tensor, y_tensor)
        """
        self.start_frame = start_frame
        if self.mode == 'train':
            transformed_list = [] # self.num_augmentation만큼 증강한 x의 리스트
            for _ in range(self.num_augmentation):
                transformed = self.transform(x)
                transformed = torch.tensor(transformed)
                transformed_list.append(transformed)

            label_list = [y] * self.num_augmentation # self.num_augmentation만큼 증간한 y의 리스트

            X_data = torch.stack(transformed_list)
            y_data = torch.tensor(label_list).long()
            return X_data, y_data

        elif self.mode == 'val':
            X_data = self.transform(x)
            X_data = torch.tensor(X_data)
            return torch.tensor(X_data), torch.tensor(y).long()



    def transform(self, x: np.ndarray):
        """ x data에 적용하는 transformation을 정의합니다.
        Args:
            x: x data
        
        Returns: x (transformation 적용 후)
        """
        if self.mode == 'train':
            x = self.distance_normalization(x)
            x = self.center_crop(x)
            x = self.skip_sampling(x)
            return x

        elif self.mode == 'val':
            x = self.distance_normalization(x)
            x = self.center_crop(x)
            x = self.skip_sampling(x)
            return x


    def distance_normalization(self, x: np.ndarray) -> np.ndarray:
        """ 수어를 위한 키포인트 정규화기법입니다.
        Args:
            x: x data
        
        Functions:
            euclidean: 유클리디안 거리(L2 distance) 계산
            distance: 각 지역별로 별도의 정규화 계산
            mm: MinMaxScaler
        
        Returns: x data (정규화된 좌표)
            
        Reference: https://www.mdpi.com/1424-8220/23/6/3231
        
        """
        x = x.reshape((-1, self.num_landmarks, self.coor_dim))
        """ holistic 상반신 좌표 설명
            0~22 : 상반신
                [Section1] 0~10 : 얼굴
                    [Reference Point] 0: 코(ref)
                    
                [Section2] 11, 13, 15 : 오른쪽 팔
                    [Reference Point] 13 : 오른쪽 팔꿈치(ref)
                    
                [Section3] 12, 14, 16 : 왼쪽 팔
                    [Reference Point] 14 : 왼쪽 팔꿈치(ref)
                    
                [Section4] 15 ~ 22 : 양손 (버리기)
        """
        # Center
        center              = np.mean(x, axis=1)

        # Reference points
        nose_ref            = x[:, 0]
        right_elbow_ref     = x[:, 13]
        left_elbow_ref      = x[:, 14]
        left_wrist_ref      = x[:, 23]
        right_wrist_ref     = x[:, 44]

        # Partition of landmarks
        face_section        = x[:, :11] # 얼굴
        right_arm_section   = x[:, 11:16:2] # 오른팔
        left_arm_section    = x[:, 12:17:2] # 왼팔
        two_hands_section   = x[:, 15:23] # 버릴 좌표 (pose의 양 손)
        right_hand_section  = x[:, 44:64] # 오른손 (rhands의 손)
        left_hand_section   = x[:, 23:44] # 왼손 (lhands의 손)


        # 대표 거리 구하기: center - reference point
        euclidean = lambda ax, ay: np.sqrt(((center[:, 0] - ax)**2 + (center[:, 1] - ay)**2) + 1e-9)

        d_nose          = euclidean(nose_ref[:, 0], nose_ref[:, 1])
        d_right_elbow   = euclidean(right_elbow_ref[:, 0], right_elbow_ref[:, 1])
        d_left_elbow    = euclidean(left_elbow_ref[:, 0], left_elbow_ref[:, 1])
        d_right_wrist   = euclidean(right_wrist_ref[:, 0], right_wrist_ref[:, 1])
        d_left_wrist    = euclidean(left_wrist_ref[:, 0], left_wrist_ref[:, 1])


        # Normalized Distance
        distance = lambda section, d_ref: (section - center.reshape(-1, 1, 2)) / (d_ref.reshape(-1, 1, 1) + 1e-9)

        d_face       = distance(face_section, d_nose)
        d_right_arm  = distance(right_arm_section, d_right_elbow)
        d_left_arm   = distance(left_arm_section, d_left_elbow)
        d_right_hand = distance(right_hand_section, d_right_wrist)
        d_left_hand  = distance(left_hand_section, d_left_wrist)


        # Rescale: MinMaxScaler
        def mm(arr):
            if len(arr) == 0:
                print(arr)
            arr_x, arr_y = arr[..., 0], arr[..., 1]

            mm_x = (arr_x - arr_x.min()) / (arr_x.max() - arr_x.min() + 1e-9)
            mm_y = (arr_y - arr_y.min()) / (arr_y.max() - arr_y.min() + 1e-9)

            mm_x = np.expand_dims(mm_x, axis=-1)
            mm_y = np.expand_dims(mm_y, axis=-1)

            return np.concatenate([mm_x, mm_y], axis=-1)

        mm_face = mm(d_face)
        mm_right_arm = mm(d_right_arm)
        mm_left_arm = mm(d_left_arm)
        mm_right_hand = mm(d_right_hand)
        mm_left_hand = mm(d_left_hand)


        # 병합 및 reshape
        result = np.concatenate([mm_face, mm_right_arm, mm_left_arm, mm_right_hand, mm_left_hand], axis=1)
        return result.reshape(-1, (self.num_landmarks-7) * self.coor_dim)



    def radian_and_distance(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x: x data
        
        Returns: x data (정규화된 좌표)
        """
        # reshape
        x = x.reshape((-1, self.num_landmarks, self.coor_dim))

        # 중심점 추출: 양 어깨의 중심점
        shoulders = x[:, 11:13, :].reshape((-1, 2, 2)) # 11: 왼쪽 어깨, 12: 오른쪽 어깨
        center = np.mean(shoulders, axis=1, keepdims=True)

        # 중심점 대비 x, y 변화량 추출
        xy_deltas = np.subtract(x, center)
        x_delta = xy_deltas[..., 0]
        y_delta = xy_deltas[..., 1]

        # 라디안 각도 구하기 (arctan2)
        theta_radian = np.arctan2(y_delta, x_delta)

        # 중심점 대비 거리 구하기 (맨해튼 거리(L1))
        manhattan_distance = np.abs(x_delta) + np.abs(y_delta)

        # 거리 MinMaxScaling
        min_value = manhattan_distance.min()
        max_value = manhattan_distance.max()

        mm = (manhattan_distance - min_value) / (max_value - min_value)

        result = np.concatenate([theta_radian.reshape((-1, self.num_landmarks, 1)), mm.reshape((-1, self.num_landmarks, 1))], axis=-1)
        return result.reshape(-1, self.num_landmarks * self.coor_dim)
    
    
    
    def delta_embedding_from_frame0(self, x: np.ndarray) -> np.ndarray:
        """ frame 0을 기준으로 각 프레임값을 빼는 delta embedding 기법입니다.
        Args:
             x: x data
             
        Returns: x data (정규화된 좌표)
        """
        fr_0 = x[0] # 0번째 프레임
        delta = x - fr_0
        eps = np.zeros(shape=(1, (self.num_landmarks-7) * self.coor_dim)) + 1e-5

        return np.concatenate([eps, delta], axis=0)



    def center_crop(self, x):
        """ 전체 프레임 중 self.maxlen에 맞춰 중심 프레임만 슬라이싱
        Args:
            x: x data
        
        Returns: x data (정규화된 좌표) 
        """
        # setting
        self.num_crop = self.maxlen

        # set index
        start_idx = self.start_frame - 5 # min값이 7이어서 5를 뺌(약간의 여유)
        end_idx = start_idx + self.num_crop

        zero_pad = np.zeros((100, (self.num_landmarks-7) * self.coor_dim))
        x = np.concatenate([x, zero_pad], axis=0)

        # crop
        return x[start_idx:end_idx, :]



    def skip_sampling(self, x: np.ndarray) -> np.ndarray:
        """ 특정 프레임을 임의로 배제합니다.
        Args: 
            x: x data
        
        Returns: x (skip sampling 적용된 프레임별 좌표값)
        """
        length = len(x)
        random_num_mask = np.random.choice(self.num_mask_range)
        mask_idx = np.random.choice(length, random_num_mask, replace=False)
        x = [x[i] if not i in mask_idx else np.zeros((self.num_landmarks-7) * self.coor_dim) for i in range(length)]

        return x


    def eps_projection(self, x: np.ndarray) -> np.ndarray:
        """ 각 값에 임의의 epsilon을 더해 jittering을 적용합니다.
        Args:
            x: x data
        
        Returns: x (epsiolon이 더해진 좌표값)
        """
        eps = 5e-5

        array_shape = x.shape
        eps_array = np.random.uniform(-eps, eps, array_shape)

        return x + eps_array

# Dataset

In [33]:
class KeypointDS(torch.utils.data.Dataset):
    """ Custom Dataset """
    
    def __init__(self, mode='train', num_augmentation=5, device=None):
        """
        Args:
            mode: 'train' | 'val'
            num_augmentation: 데이터 증강 배수
            device: torch.device 객체
        """
        super().__init__()
        self.medical = False # True: 의료용어 | False: 의료용어 + 일상어

        self.mode = mode
        self.num_augmentation = num_augmentation
        self.device = device
        self.maxlen = 70 # 최대 시퀀스 길이
        self.word2idx = self.load_word2idx() # {단어 : 라벨}
        self.num_classes = len(self.word2idx) # 클래스(단어) 수
        self.name2frame = self.load_name2frame() # {파일명 : 프레임 수}
        self.df, self.filename2label_df, self.filename2start  = self.load_df()
        self.filename2path, self.filename2label, self.filename_list = self.get_kp_info()

        # self.num_landmarks = 23 # pose(상반신)
        # self.num_landmarks = 33 # pose
        self.num_landmarks = 65 # holistic(하반신 제거)
        # self.num_landmarks = 75 # holistic
        self.CA = CustomAugmentation(self.mode, self.num_augmentation, self.num_landmarks, self.maxlen) # 데이터 증강


    def __len__(self):
        return len(self.filename_list)


    def __getitem__(self, idx):
        filename = self.filename_list[idx] # 파일명
        filepath = self.filename2path[filename] # 파일경로
        label = self.filename2label[filename] # 라벨 (y data)
        start_frame = self.filename2start[filename] # 동작이 시작하는 프레임 정보

        
        with open(filepath, 'rb') as f:
            # 하반신의 x, y 좌표에 대한 컬럼명
            legs = ['x23', 'y23', 'z23', 'x24', 'y24', 'z24', 'x25', 'y25', 'z25', 'x26',
                    'y26', 'z26', 'x27', 'y27', 'z27', 'x28', 'y28', 'z28', 'x29', 'y29',
                    'z29', 'x30', 'y30', 'z30', 'x31', 'y31', 'z31', 'x32', 'y32', 'z32']
            
            # x, y, z 좌표값으로 구성된 데이터프레임 - 마지막 3개 컬럼은 좌표와 무관한 정보(파일명, 프레임 수 등)
            keypoints = pickle.load(f).iloc[:, :-3]

            if self.num_landmarks == 23: # pose - 상반신만 불러오는 경우 -> 23 * 3 = 69
                keypoints = keypoints.iloc[:, :69]
            elif self.num_landmarks == 33: # pose 모두 불러오는 경우
                keypoints = keypoints.iloc[:, :99]
            elif self.num_landmarks == 65: # holistic에서 하반신만 버릴 때 (위의 legs를 제거할 때)
                keypoints = keypoints.drop(legs, axis=1, inplace=False)
            else:
                keypoints = keypoints

            keypoints = keypoints.replace('', 0.0).dropna(axis=0) # 빈 값은 0.0으로 대체하고, nan값은 drop
            keypoints = keypoints.to_numpy().astype(np.float64) 
            keypoints = keypoints.reshape(-1, self.num_landmarks, 3)[..., :2] # x, y, z -> x, y

            try:
                X_tensor, y_tensor = self.CA.act(keypoints, label, start_frame)
                return X_tensor.to(self.device), y_tensor.to(self.device)
            except Exception as e:
                print(filepath)
                raise e



    def load_df(self):
        """ 파일명, 단어(클래스) 등의 정보가 들어있는 데이터프레임 불러오기 """
        if self.mode == 'train':
            csv_path = f'{container}/data/train_df(checked_main).csv'
        elif self.mode == 'val':
            csv_path = f'{container}/data/validation_data(clean_main).csv'

        # 불러오기
        raw_df = pd.read_csv(csv_path)


        # 라벨 정보 심기
        raw_df['label'] = raw_df['word'].apply(lambda word: self.word2idx[word] if word in self.word2idx else '<UNK>') # 라벨을 붙이되, 없으면 <UNK> 부여
        raw_df = raw_df[raw_df['label'] != '<UNK>'] # <UNK>가 아닌 데이터만 골라내기
        raw_df = raw_df[(raw_df['data_type'] == 'WORD') & (raw_df['category'] <= 1)] # 0: 의료용어, 1: 일상어, 2+: 무관한 단어

        # 프레임 정보 심기
        raw_df['frame'] = raw_df['filename'].apply(lambda filename: self.name2frame[filename] if filename in self.name2frame else -1)
        raw_df['start_frame'] = raw_df['start'].apply(lambda t: int(t*30))

        self.raw_df = raw_df
        df = raw_df

        filename2label_df = df.set_index(keys='filename')['label'].to_dict()
        name2start = df.set_index('filename')['start_frame'].to_dict()

        return df, filename2label_df, name2start


    def load_word2idx(self):
        """ 사전에 구성한 word2idx 불러오기 """
        if self.medical:
            with open(f'{container}/data/medical_word2idx.pkl', 'rb') as f: #
                word2idx = pickle.load(f)
        else:
            with open(f'{container}/data/integrated_word2idx.pkl', 'rb') as f: #
                word2idx = pickle.load(f)

        return word2idx


    def get_kp_info(self):
        if self.mode == 'train':
            if self.medical:
                keypoints_path = f'{container}/data/train_keypoints(medical)'
            else:
                keypoints_path = f'{container}/train_keypoints(main)'
        elif self.mode == 'val':
            keypoints_path = f'{container}/Validation_keypoints(modified)'


        # 파일명 추출
        keypoints_list = glob(f"{keypoints_path}/**/*.pkl", recursive=True)
        filename2path = {path.split('/')[-1].rstrip('.mp4.pkl').rstrip('.pkl'):path for path in keypoints_list}

        filename2label = {}
        self.err_list = []

        # 파일명에서 확장자명(.mp4)를 제거하고 filename2label에 추가
        for filename in filename2path:
            filename = filename.rstrip('.mp4')
            try:
                filename2label[filename] = self.filename2label_df[filename]
            except:
                self.err_list.append(filename)

        filename_list = list(filename2label.keys())
        # filename_list = [name for name in filename_list if name.endswith('F')] # 정면 영상만 쓰고 싶을 떄
        
        return filename2path, filename2label, filename_list

    
    def load_name2frame(self) -> dict:
        """ 딕셔너리 {파일명 : 프레임 수} 불러오기 """
        with open(f'{container}/data/temp/filename2frame(main).pkl', 'rb') as f:
            train_name2frame, val_name2frame = pickle.load(f)

        if self.mode == 'train': return train_name2frame
        elif self.mode == 'val': return val_name2frame



def my_collate_fn(samples, is_graph=True):
    """ augmentation하면서 흐트러진 batch의 차원을 정렬한다. """

    X_collate = torch.stack([sample[0] for sample in samples])
    y_collate = torch.stack([sample[1] for sample in samples])

    # 차원 재정렬
    X_collate = torch.reshape(X_collate, (-1, train_ds.maxlen, (train_ds.num_landmarks-7) * 2))
    y_collate = torch.reshape(y_collate, (-1,))
    return (X_collate, y_collate)

In [34]:
device = torch.device('cuda')

train_ds = KeypointDS('train', 1, device)
val_ds = KeypointDS('val', 1, device)
display(len(train_ds), len(val_ds))

15837

1946

# Model

In [35]:
class PositionalEncoding(nn.Module):
    """ 정석 PositionalEncoding 코드
    Reference: https://ysg2997.tistory.com/11 
    """

    def __init__(self, dim_model, max_len, device):
        """
        Args:
            dim_model: 
            max_len: 시퀀스 최대길이
            device: torch.device 객체
        """
        super().__init__()
        self.device =device
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model, device=device)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)

        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)

        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0)
        self.register_buffer("pos_encoding", pos_encoding)


    def forward(self, token_embedding: torch.tensor) -> torch.tensor:

        #  positional encoding
        return token_embedding + self.pos_encoding[:token_embedding.size(0), :]



class Model(torch.nn.Module):
    """ SSLRv2
    처음에 의료용어(72개 클래스)만 가지고 사전학습 진행.
    이후 class TLModel에 의료용어로 prerained weight를 적용하여 전이학습 잔행  
    """
    def __init__(self, num_classes, transformer_dropout, maxlen, device):
        """
        Args:
            num_classes: 클래스 수
            transformer_dropout: dropout 적용비율
            maxlen: 시퀀스 최대길이
            device: torch.device 객체
        """
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        self.num_landmarks = train_ds.num_landmarks
        if self.num_landmarks == 65:
            self.num_landmarks = 58

        self.embedding = nn.Sequential(OrderedDict([
            ('Linear_Embedding', nn.Linear(in_features=self.num_landmarks * 2, out_features=512, device=device)),
            ('Positional_Encoding', PositionalEncoding(dim_model=512, max_len=maxlen, device=device))
        ]))
        # self.embedding = nn.Linear(in_features=66, out_features=512, device=device)
        self.backbone = nn.Sequential(OrderedDict([
            ('BN_1', nn.BatchNorm1d(num_features=maxlen + 1, eps=1e-5, momentum=0.1, affine=True, device=device)),
            ('Transformer_1', nn.TransformerEncoderLayer(d_model=512, nhead=4, dropout=transformer_dropout,
                                                         batch_first=True, device=device)),
            ('BN_2', nn.BatchNorm1d(num_features=maxlen + 1, eps=1e-5, momentum=0.1, affine=True, device=device)),
            ('Transformer_2', nn.TransformerEncoderLayer(d_model=512, nhead=4, dropout=transformer_dropout,
                                                         batch_first=True, device=device)),
            ('BN_3', nn.BatchNorm1d(num_features=maxlen + 1, eps=1e-5, momentum=0.1, affine=True, device=device)),
            ('Transformer_3', nn.TransformerEncoderLayer(d_model=512, nhead=4, dropout=transformer_dropout,
                                                         batch_first=True, device=device)),
            ('BN_4', nn.BatchNorm1d(num_features=maxlen + 1, eps=1e-5, momentum=0.1, affine=True, device=device)),
            ('Transformer_4', nn.TransformerEncoderLayer(d_model=512, nhead=4, dropout=transformer_dropout,
                                                         batch_first=True, device=device)),
            ('BN_5', nn.BatchNorm1d(num_features=maxlen + 1, eps=1e-5, momentum=0.1, affine=True, device=device)),
            ('Transformer_5', nn.TransformerEncoderLayer(d_model=512, nhead=4, dropout=transformer_dropout,
                                                         batch_first=True, device=device)),


        ]))
        self.classifier = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(in_features=512, out_features=self.num_classes, device=device)),
            ('Softmax', nn.Softmax(dim=-1))
        ]))

    def forward(self, x):
        x = self.embedding(x)

        ##### cls token #####
        batch_size = x.size(0)
        emb_dim = x.size(-1)
        cls_token = torch.ones((batch_size, 1, emb_dim), device=self.device)
        x = torch.cat([cls_token, x], dim=1)
        ##### cls token #####

        output = self.backbone(x)
        return self.classifier(output[:, 0])



class TLModel(torch.nn.Module):
    """ 
        의료용어로 사전학습한 class Model의 weight를 가져와 전이학습 진행. 
        클래스 수 변경에 따른 분류층 변경.
        별도의 freeze는 없음.
     """
    def __init__(self, num_classes, transformer_dropout, maxlen, device):
        super().__init__()
        self.device = device
        self.num_classes = num_classes

        self.pretrained = Model(76, transformer_dropout, maxlen, device)
        weight_pt = f'{container}/main/ckpts/medical(ref)/Epoch60(f1_0.901).pt' # 사전학습된 의료용어 weight
        self.pretrained.load_state_dict(torch.load(weight_pt, map_location=device), strict=True)
        self.pretrained.classifier = nn.Identity() # nn.Identity()를 통해 classifier를 빈 레이어로 변경

        self.classifier = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(in_features=512, out_features=num_classes, device=device)),
            ('Softmax', nn.Softmax(dim=-1))
        ]))

    def forward(self, x):
        x = self.pretrained(x)
        return self.classifier(x)


# Configuration

In [36]:
# Path
"""
save_path: 체크포인트 저장경로 설정
"""

save_path = os.path.join(container, 'main/ckpts/main(ref)')

device = torch.device('cuda')
nb_epochs = 500 # Epoch
batch_size = 128 
learning_rate = 1e-4
patience = 5 # Patience for EarlyStopping
start_from_epoch = 1 
transformer_dropout = 0.25
max_norm = 3 # Gradient Clipping

# Dataset, DataLoader
train_ds = KeypointDS('train', 5, device)
val_ds = KeypointDS('val', 5, device)
train_loader = DataLoader(train_ds, batch_size, shuffle=False, collate_fn=my_collate_fn)
val_loader = DataLoader(val_ds, batch_size, shuffle=False, collate_fn=my_collate_fn)

num_classes = train_ds.num_classes
maxlen = train_ds.maxlen

# Model
# model = Model(num_classes, transformer_dropout, maxlen, device) # class Model 의료용어 pretrain
model = TLModel(num_classes, transformer_dropout, maxlen, device) # class TLModel 의료용어 + 일상어 전이학습

optimizer = torch.optim.AdamW(model.parameters(), learning_rate) 
loss_fn = torch.nn.CrossEntropyLoss()


# Callbacks
es = EarlyStopping('val_loss', patience, start_from_epoch)
chk = ModelCheckpoint('val_loss', save_path, save_best_only=False)


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', # 낮은 값이 좋은 값
                                                       factor=0.5, # lr * factor
                                                       patience=3,
                                                       # cooldown=3, # 변경 후 쉬어갈 스텝(epoch)
                                                       threshold=1e-3, # 이 이상으로 개선 안되면 patience 카운트
                                                       verbose=True)

# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5, verbose=True)

''' os.environ
현재 ipynb 파일의 경로로 잡아주세요.
'''
os.environ['WANDB_NOTEBOOK_NAME'] = os.path.join(container, 'main/main(ref).ipynb')

wandb.login(key=wandb_login_key)

run = wandb.init(
          project='delta_keypoint',
           name='main(ref, colab)',
           tags=['holistic', 'noHip', 'aug5', 'dropout0.25', 'dim512', 'replace0', '5stack', 'skip-sampling', 'selective_crop', 'no_eps', 'reference'],
           config={
                'epochs' : nb_epochs,
                'learning_rate' : learning_rate,
                'batch_size' : batch_size,
                'transformer_dropout' : transformer_dropout,
                'optimizer' : optimizer,
                'patience' : patience,
                'start_from_epoch' : start_from_epoch,
                'num_classes' : num_classes,
                'scheduler' : scheduler,
                'max_norm' : max_norm,
                })

run.save()



VBox(children=(Label(value='0.002 MB of 0.012 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.210389…

True

# Run

In [None]:
regression = False
step = 0

for epoch in range(nb_epochs):
    model.train()
    train_loss_step, train_accuracy_step, train_f1_score_step = [], [], []

    train_loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{nb_epochs}]",
                      position=0, leave=True,
                      bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining}] {postfix}'
                      )

    for X_batch, y_batch in train_loop:
        # y_batch = y_batch.to(device)
        optimizer.zero_grad()

        # H(x) 계산
        y_pred = model(X_batch)

        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) # clipping by 5
        optimizer.step()

        # step 결과
        accuracy = multiclass_accuracy(y_pred, y_batch)
        f1_score = multiclass_f1_score(y_pred, y_batch, num_classes=num_classes, average='macro')
        # accuracy = sequential_accuracy(y_pred, y_batch)

        train_loss_step.append(loss)
        train_accuracy_step.append(accuracy)
        train_f1_score_step.append(f1_score)

        msg = f'Train Loss: {torch.tensor(train_loss_step).mean():.6f}  Train Accuracy: {torch.tensor(train_accuracy_step).mean():.6f}  Train f1_score: {torch.tensor(train_f1_score_step).mean():.6f}'

        train_loop.set_postfix_str(msg)

        step += 1


    # Epoch의 Train 결과
    train_loss_epoch = torch.tensor(train_loss_step).mean()
    train_accuracy_epoch = torch.tensor(train_accuracy_step).mean()
    train_f1_score_epoch = torch.tensor(train_f1_score_step).mean()


    # Epoch에서의 Validation 결과
    val_scores = validation_step_kp(model, val_loader, device, loss_fn, epoch, num_classes)


    # Epoch마다 체크포인트 저장
    if (epoch+1) % 5 == 0:
        if regression:
            val_MSE_log = round(float(val_scores.val_loss))
            torch.save(model.state_dict(), f"{save_path}/TL_Epoch{epoch+1}(f1_{val_MSE_log}).pt")
        else:
            val_f1_log = round(float(val_scores.val_f1), 3)
            torch.save(model.state_dict(), f"{save_path}/Epoch{epoch+1}(f1_{val_f1_log}).pt")
    # 로그 남기기
    wandb.log({
        "Loss/Train" : train_loss_epoch,
        "Loss/Validation" : val_scores.val_loss,
        "Accuracy/Train" : train_accuracy_epoch,
        "Accuracy/Validation" : val_scores.val_accuracy,
        "f1_score/Train" : train_f1_score_epoch,
        "f1_score/Validation" : val_scores.val_f1
    }, step=epoch)


    # Scheduler
    scheduler.step(val_scores.val_loss)

    # Callbacks
    monitor_metric= val_scores.val_loss
    ## ModelCheckpoint
    chk.monitoring(model, monitor_metric)
    ## EarlyStopping
    FLAG = es.monitoring(epoch, monitor_metric)
    if not FLAG: break

Epoch [1/500]: 100%|██████████|124/124 [09:39<00:00] , Train Loss: 5.148005  Train Accuracy: 0.171610  Train f1_score: 0.095235
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 5.002808  val_accuracy: 0.314153  val_f1_score: 0.206368


Epoch [2/500]: 100%|██████████|124/124 [09:38<00:00] , Train Loss: 4.915575  Train Accuracy: 0.406249  Train f1_score: 0.268989
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.853809  val_accuracy: 0.457069  val_f1_score: 0.337634


Epoch [3/500]: 100%|██████████|124/124 [09:40<00:00] , Train Loss: 4.808020  Train Accuracy: 0.506535  Train f1_score: 0.358101
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.769494  val_accuracy: 0.540527  val_f1_score: 0.421534


Epoch [4/500]: 100%|██████████|124/124 [09:37<00:00] , Train Loss: 4.725611  Train Accuracy: 0.589988  Train f1_score: 0.435882
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.707607  val_accuracy: 0.599008  val_f1_score: 0.485376


Epoch [5/500]: 100%|██████████|124/124 [09:39<00:00] , Train Loss: 4.660120  Train Accuracy: 0.655889  Train f1_score: 0.505064
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.651136  val_accuracy: 0.656588  val_f1_score: 0.548225


Epoch [6/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.617174  Train Accuracy: 0.694849  Train f1_score: 0.547367
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.626579  val_accuracy: 0.677471  val_f1_score: 0.567090


Epoch [7/500]: 100%|██████████|124/124 [09:44<00:00] , Train Loss: 4.593936  Train Accuracy: 0.712624  Train f1_score: 0.568148
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.614963  val_accuracy: 0.693096  val_f1_score: 0.586639


Epoch [8/500]: 100%|██████████|124/124 [09:40<00:00] , Train Loss: 4.575114  Train Accuracy: 0.729103  Train f1_score: 0.586211
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.603106  val_accuracy: 0.703350  val_f1_score: 0.597207


Epoch [9/500]: 100%|██████████|124/124 [09:41<00:00] , Train Loss: 4.558556  Train Accuracy: 0.745975  Train f1_score: 0.605576
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.581294  val_accuracy: 0.724346  val_f1_score: 0.612911


Epoch [10/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.542777  Train Accuracy: 0.762312  Train f1_score: 0.625692
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.564841  val_accuracy: 0.742375  val_f1_score: 0.641894


Epoch [11/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.523594  Train Accuracy: 0.781734  Train f1_score: 0.650115
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.547700  val_accuracy: 0.753606  val_f1_score: 0.656649


Epoch [12/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.496206  Train Accuracy: 0.807173  Train f1_score: 0.681983
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.518812  val_accuracy: 0.786771  val_f1_score: 0.698579


Epoch [13/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.479166  Train Accuracy: 0.824188  Train f1_score: 0.703949
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.518634  val_accuracy: 0.785344  val_f1_score: 0.697633


Epoch [14/500]: 100%|██████████|124/124 [09:41<00:00] , Train Loss: 4.464382  Train Accuracy: 0.837403  Train f1_score: 0.722441
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.507961  val_accuracy: 0.796988  val_f1_score: 0.715767


Epoch [15/500]: 100%|██████████|124/124 [09:41<00:00] , Train Loss: 4.458582  Train Accuracy: 0.841548  Train f1_score: 0.729306
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.540730  val_accuracy: 0.761869  val_f1_score: 0.671498


Epoch [16/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.459431  Train Accuracy: 0.841364  Train f1_score: 0.727898
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.486832  val_accuracy: 0.811185  val_f1_score: 0.722885


Epoch [17/500]: 100%|██████████|124/124 [09:43<00:00] , Train Loss: 4.443061  Train Accuracy: 0.855214  Train f1_score: 0.749797
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.476290  val_accuracy: 0.822416  val_f1_score: 0.737811


Epoch [18/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.435757  Train Accuracy: 0.862542  Train f1_score: 0.760938
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.474566  val_accuracy: 0.827299  val_f1_score: 0.750219


Epoch [19/500]: 100%|██████████|124/124 [09:41<00:00] , Train Loss: 4.427047  Train Accuracy: 0.871195  Train f1_score: 0.774638
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.464128  val_accuracy: 0.839956  val_f1_score: 0.762126


Epoch [20/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.420712  Train Accuracy: 0.878238  Train f1_score: 0.784820
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.474615  val_accuracy: 0.825346  val_f1_score: 0.741037


Epoch [21/500]: 100%|██████████|124/124 [09:43<00:00] , Train Loss: 4.413493  Train Accuracy: 0.883996  Train f1_score: 0.790988
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.454913  val_accuracy: 0.843374  val_f1_score: 0.765145


Epoch [22/500]: 100%|██████████|124/124 [09:41<00:00] , Train Loss: 4.408271  Train Accuracy: 0.888395  Train f1_score: 0.799072
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.453596  val_accuracy: 0.849271  val_f1_score: 0.775139


Epoch [23/500]: 100%|██████████|124/124 [09:45<00:00] , Train Loss: 4.404613  Train Accuracy: 0.892648  Train f1_score: 0.805609
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.478075  val_accuracy: 0.824820  val_f1_score: 0.746851


Epoch [24/500]: 100%|██████████|124/124 [09:39<00:00] , Train Loss: 4.404525  Train Accuracy: 0.892406  Train f1_score: 0.805877
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.458698  val_accuracy: 0.838529  val_f1_score: 0.760689


Epoch [25/500]: 100%|██████████|124/124 [09:41<00:00] , Train Loss: 4.399902  Train Accuracy: 0.898054  Train f1_score: 0.815927
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.455627  val_accuracy: 0.845891  val_f1_score: 0.761357
Epoch 00025: reducing learning rate of group 0 to 5.0000e-05.


Epoch [26/500]: 100%|██████████|124/124 [09:40<00:00] , Train Loss: 4.385907  Train Accuracy: 0.909330  Train f1_score: 0.836227
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.431540  val_accuracy: 0.864896  val_f1_score: 0.791450


Epoch [27/500]: 100%|██████████|124/124 [09:39<00:00] , Train Loss: 4.379708  Train Accuracy: 0.915045  Train f1_score: 0.847254
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.429104  val_accuracy: 0.867788  val_f1_score: 0.797400


Epoch [28/500]: 100%|██████████|124/124 [09:37<00:00] , Train Loss: 4.377519  Train Accuracy: 0.916699  Train f1_score: 0.849727
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.426683  val_accuracy: 0.870718  val_f1_score: 0.804722


Epoch [29/500]: 100%|██████████|124/124 [09:40<00:00] , Train Loss: 4.375611  Train Accuracy: 0.918132  Train f1_score: 0.851081
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.424035  val_accuracy: 0.872183  val_f1_score: 0.804348


Epoch [30/500]: 100%|██████████|124/124 [09:37<00:00] , Train Loss: 4.375300  Train Accuracy: 0.918523  Train f1_score: 0.852279
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.425721  val_accuracy: 0.872671  val_f1_score: 0.801853


Epoch [31/500]: 100%|██████████|124/124 [09:38<00:00] , Train Loss: 4.374900  Train Accuracy: 0.918812  Train f1_score: 0.852769
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.424519  val_accuracy: 0.874624  val_f1_score: 0.804847


Epoch [32/500]: 100%|██████████|124/124 [09:42<00:00] , Train Loss: 4.374763  Train Accuracy: 0.919312  Train f1_score: 0.852073
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.429320  val_accuracy: 0.869742  val_f1_score: 0.799330
Epoch 00032: reducing learning rate of group 0 to 2.5000e-05.


Epoch [33/500]: 100%|██████████|124/124 [09:41<00:00] , Train Loss: 4.368952  Train Accuracy: 0.924725  Train f1_score: 0.862729
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.410876  val_accuracy: 0.887320  val_f1_score: 0.824240


Epoch [34/500]: 100%|██████████|124/124 [09:40<00:00] , Train Loss: 4.365438  Train Accuracy: 0.928079  Train f1_score: 0.866370
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.414316  val_accuracy: 0.885367  val_f1_score: 0.818912


Epoch [35/500]: 100%|██████████|124/124 [09:38<00:00] , Train Loss: 4.364310  Train Accuracy: 0.929016  Train f1_score: 0.868998
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.414059  val_accuracy: 0.883902  val_f1_score: 0.817072


Epoch [36/500]: 100%|██████████|124/124 [09:37<00:00] , Train Loss: 4.363203  Train Accuracy: 0.930114  Train f1_score: 0.870099
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.417411  val_accuracy: 0.879995  val_f1_score: 0.812008


Epoch [37/500]: 100%|██████████|124/124 [09:40<00:00] , Train Loss: 4.362938  Train Accuracy: 0.930215  Train f1_score: 0.872166
Validation: 100%|██████████|16/16 [00:44<00:00] 


              val_loss: 4.416650  val_accuracy: 0.879507  val_f1_score: 0.813640
Epoch 00037: reducing learning rate of group 0 to 1.2500e-05.


Epoch [38/500]: 100%|██████████|124/124 [09:40<00:00] , Train Loss: 4.362018  Train Accuracy: 0.930915  Train f1_score: 0.873212
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.412660  val_accuracy: 0.885367  val_f1_score: 0.819915


Epoch [39/500]: 100%|██████████|124/124 [09:46<00:00] , Train Loss: 4.361810  Train Accuracy: 0.930996  Train f1_score: 0.871260
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.410675  val_accuracy: 0.887320  val_f1_score: 0.821297


Epoch [40/500]: 100%|██████████|124/124 [09:46<00:00] , Train Loss: 4.361338  Train Accuracy: 0.931374  Train f1_score: 0.872518
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.409857  val_accuracy: 0.886343  val_f1_score: 0.821689


Epoch [41/500]: 100%|██████████|124/124 [09:46<00:00] , Train Loss: 4.361188  Train Accuracy: 0.931575  Train f1_score: 0.871411
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.410610  val_accuracy: 0.886831  val_f1_score: 0.822446
Epoch 00041: reducing learning rate of group 0 to 6.2500e-06.


Epoch [42/500]: 100%|██████████|124/124 [09:43<00:00] , Train Loss: 4.360539  Train Accuracy: 0.932105  Train f1_score: 0.873277
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.405820  val_accuracy: 0.889761  val_f1_score: 0.827518


Epoch [43/500]: 100%|██████████|124/124 [09:46<00:00] , Train Loss: 4.357706  Train Accuracy: 0.935646  Train f1_score: 0.879893
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.406380  val_accuracy: 0.890249  val_f1_score: 0.830260


Epoch [44/500]: 100%|██████████|124/124 [09:46<00:00] , Train Loss: 4.356804  Train Accuracy: 0.936452  Train f1_score: 0.881469
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.404229  val_accuracy: 0.897085  val_f1_score: 0.837469


Epoch [45/500]: 100%|██████████|124/124 [09:48<00:00] , Train Loss: 4.356504  Train Accuracy: 0.936641  Train f1_score: 0.882652
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.404179  val_accuracy: 0.891714  val_f1_score: 0.830135


Epoch [46/500]: 100%|██████████|124/124 [09:45<00:00] , Train Loss: 4.356296  Train Accuracy: 0.936674  Train f1_score: 0.882033
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.404626  val_accuracy: 0.892691  val_f1_score: 0.831847
Epoch 00046: reducing learning rate of group 0 to 3.1250e-06.


Epoch [47/500]: 100%|██████████|124/124 [09:46<00:00] , Train Loss: 4.356083  Train Accuracy: 0.936830  Train f1_score: 0.883630
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.404199  val_accuracy: 0.893179  val_f1_score: 0.832985


Epoch [48/500]: 100%|██████████|124/124 [09:45<00:00] , Train Loss: 4.356003  Train Accuracy: 0.936780  Train f1_score: 0.882945
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.405287  val_accuracy: 0.894644  val_f1_score: 0.833098


Epoch [49/500]: 100%|██████████|124/124 [09:45<00:00] , Train Loss: 4.355933  Train Accuracy: 0.936817  Train f1_score: 0.883753
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.402907  val_accuracy: 0.895620  val_f1_score: 0.834753


Epoch [50/500]: 100%|██████████|124/124 [09:48<00:00] , Train Loss: 4.355847  Train Accuracy: 0.936956  Train f1_score: 0.883155
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.404411  val_accuracy: 0.892691  val_f1_score: 0.832928
Epoch 00050: reducing learning rate of group 0 to 1.5625e-06.


Epoch [51/500]: 100%|██████████|124/124 [09:46<00:00] , Train Loss: 4.355773  Train Accuracy: 0.936956  Train f1_score: 0.883923
Validation: 100%|██████████|16/16 [00:45<00:00] 


              val_loss: 4.403815  val_accuracy: 0.893667  val_f1_score: 0.831434


Epoch [52/500]:  30%|██▉       |37/124 [02:57<06:54] , Train Loss: 4.356466  Train Accuracy: 0.936233  Train f1_score: 0.885361