   # Library


In [None]:
! pip install sentencepiece einops wandb torch-summary icecream -qq

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from pprint import pprint
from nltk.tokenize import word_tokenize as en_tokenizer
import sentencepiece as spm
import urllib.request
import csv
import numpy as np
from einops import rearrange, reduce, repeat
from torch.cuda import amp
from tqdm import tqdm
import wandb
import time
import copy
from collections import defaultdict
from sklearn.metrics import mean_squared_error
import joblib
import gc
import os
from icecream import ic
from sklearn.model_selection import train_test_split
import os




In [None]:
# ---- Global hyper-parameters and runtime configuration for this notebook ----

# Number of rows in the embedding tables and width of the final projection.
VOCAB_SIZE = 10000
# Sequence length used when building the shifted decoder target.
SEQ_LEN = 60


# Token id treated as padding by the masks and ignored by the loss.
PAD_IDX = 0
# NOTE(review): BOS_IDX / SOS_IDX are not referenced anywhere in the visible code.
# "BOS" and "SOS" normally both mean start-of-sequence, so SOS_IDX = 3 is
# presumably the end-of-sequence id -- confirm against the tokenizer before use.
BOS_IDX = 2
SOS_IDX = 3



# Execution environment tag (not read by any visible code).
# ENV = 'COLAB'
ENV = 'KAGGLE'
# ENV = 'SYSTEM'

# Option for Mixed Precision
FP16 = True
# FP16 = False

# N: number of stacked encoder layers and decoder layers.
N = 2
# Model width (embedding size and attention hidden size).
HIDDEN_DIM = 256
# Number of attention heads.
NUM_HEAD = 8 
# Width of the feed-forward network's inner layer.
INNER_DIM = 512
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0


# Snapshot of the hyper-parameters, passed to wandb.init as the run config.
CONFIG = {
    'VOCAB_SIZE': VOCAB_SIZE,
    'SEQ_LEN': SEQ_LEN,
    'N': N,
    'HIDDEN_DIM': HIDDEN_DIM,
    'NUM_HEAD': NUM_HEAD,
    'INNER_DIM': INNER_DIM,
    'BATCH_SIZE': BATCH_SIZE,
    'WEIGHT_DECAY' : WEIGHT_DECAY,
    'LEARNING_RATE' : LEARNING_RATE,
}


# Pick the device only once: re-running this cell keeps an already defined `device`.
if 'device' not in globals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')


In [None]:
# Start the Weights & Biases run that tracks this training session.
import wandb
import os
# To run in offline mode instead, use the two commented lines below.

# os.environ["WANDB_MODE"] = "dryrun"
# wandb.init(project="Transformer_bible", entity="jiwon7258")


# Online mode: metrics and files are synced while training runs.
os.environ["WANDB_MODE"] = "online"
wandb.init(project="Transformer_bible", entity="jiwon7258", config = CONFIG, job_type = 'train')
# Encode the main hyper-parameters in the run name so runs are easy to tell apart.
wandb.run.name = f"train_{VOCAB_SIZE}_{SEQ_LEN}_{N}_{HIDDEN_DIM}_{INNER_DIM}"


# NOTE(review): this Artifact object is never logged in the visible code and the
# name `dataset` is rebound by the `use_artifact` call in the next cell.
dataset = wandb.Artifact(f'bible-dataset_{VOCAB_SIZE}_{SEQ_LEN}', type='dataset')



  # Load

In [None]:
# Declare the latest version of the tokenized dataset artifact as an input of this run.
dataset = wandb.run.use_artifact(f'bible-dataset_{VOCAB_SIZE}_{SEQ_LEN}:latest')

# Download the artifact's contents; `artifact_dir` is the local directory holding them.
artifact_dir = dataset.download()
    


In [None]:
# or Train / Valid Data
src_train_path = os.path.join(artifact_dir,'src_train.pkl')
src_valid_path = os.path.join(artifact_dir,'src_valid.pkl')
trg_train_path = os.path.join(artifact_dir,'trg_train.pkl')
trg_valid_path = os.path.join(artifact_dir,'trg_valid.pkl')

src_train = joblib.load(src_train_path)
src_valid = joblib.load(src_valid_path)
trg_train = joblib.load(trg_train_path)
trg_valid = joblib.load(trg_valid_path)



   # Create Dataset

In [None]:
class TrainDataset(Dataset):
    """Map-style dataset yielding (source, decoder input, decoder target) id tensors.

    The decoder target is the decoder input shifted left by one position with a
    single padding id appended, so it always has the same length as the input
    (teacher forcing).

    Args:
        src_data: indexable collection of source token-id sequences.
        trg_data: indexable collection of target token-id sequences, aligned
            with ``src_data``.
        pad_idx: id appended to the shifted target (defaults to 0, the value
            that was previously hard-coded).
    """

    def __init__(self, src_data, trg_data, pad_idx=0):
        super().__init__()

        assert len(src_data) == len(trg_data), \
            "src_data and trg_data must contain the same number of sequences"

        self.src_data = src_data
        self.trg_data = trg_data
        self.pad_idx = pad_idx

    def __len__(self):
        return len(self.src_data)

    @staticmethod
    def _to_long(ids):
        # Convert straight to int64. The previous float32 round trip
        # (torch.Tensor(...).long()) corrupts ids above 2**24.
        return torch.tensor(np.asarray(ids), dtype=torch.long)

    def __getitem__(self, idx):
        src = self.src_data[idx]
        trg_input = self.trg_data[idx]

        # Shift by one and pad the tail. Slicing with [1:] keeps the target the
        # same length as the decoder input without relying on a global length.
        trg_output = np.pad(
            np.asarray(trg_input)[1:], (0, 1), 'constant', constant_values=self.pad_idx)

        # each tensor: (seq_len,)
        return self._to_long(src), self._to_long(trg_input), self._to_long(trg_output)

# Training split: reshuffled every epoch; pinned memory is requested for the batches.
train_dataset = TrainDataset(src_train, trg_train)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle= True, pin_memory=True)




In [None]:
class ValidDataset(Dataset):
    """Map-style validation dataset yielding (source, decoder input, decoder target).

    The decoder target is the decoder input shifted left by one position with a
    single padding id appended, so it always has the same length as the input.

    Args:
        src_data: indexable collection of source token-id sequences.
        trg_data: indexable collection of target token-id sequences, aligned
            with ``src_data``.
        pad_idx: id appended to the shifted target (defaults to 0, the value
            that was previously hard-coded).
    """

    def __init__(self, src_data, trg_data, pad_idx=0):
        super().__init__()

        assert len(src_data) == len(trg_data), \
            "src_data and trg_data must contain the same number of sequences"

        self.src_data = src_data
        self.trg_data = trg_data
        self.pad_idx = pad_idx

    def __len__(self):
        return len(self.src_data)

    @staticmethod
    def _to_long(ids):
        # Convert straight to int64. The previous float32 round trip
        # (torch.Tensor(...).long()) corrupts ids above 2**24.
        return torch.tensor(np.asarray(ids), dtype=torch.long)

    def __getitem__(self, idx):
        src = self.src_data[idx]
        trg_input = self.trg_data[idx]

        # Shift by one and pad the tail. Slicing with [1:] keeps the target the
        # same length as the decoder input without relying on a global length.
        trg_output = np.pad(
            np.asarray(trg_input)[1:], (0, 1), 'constant', constant_values=self.pad_idx)

        # each tensor: (seq_len,)
        return self._to_long(src), self._to_long(trg_input), self._to_long(trg_output)

# Validation split: iterated in a fixed order (no shuffling).
valid_dataset = ValidDataset(src_valid, trg_valid)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle= False, pin_memory=True)




   # Transformer

   ## Mask Function

In [None]:
def makeMask(tensor, option: str, pad_idx=None) -> torch.Tensor:
    '''
    Build an attention mask from a batch of token ids.

    Masking always acts along the KEY axis of Q @ K^T, so `tensor` must be the
    id sequence that the keys are computed from.

    Input
    - tensor
        shape (bs, k_len), token ids

    Args
    - option
        'padding'   : hide padded key positions
        'lookahead' : hide padded key positions AND future positions
                      (assumes the query length equals the key length)
    - pad_idx
        id treated as padding; defaults to the module-level PAD_IDX

    Output
    - float tensor of 1s (attend) and 0s (masked), on the same device as `tensor`
        'padding'   : shape (bs, 1, 1, k_len)
        'lookahead' : shape (bs, 1, k_len, k_len)
      The singleton axes exist so the mask broadcasts over heads / queries.

    Raises
    - ValueError if `option` is not one of the two supported values.

    Example for ids [a, b, c, d, PAD, PAD, PAD, PAD]
        padding row : [1., 1., 1., 1., 0., 0., 0., 0.]
        lookahead   : lower-triangular matrix whose columns 4..7 are all zero
    '''
    if pad_idx is None:
        pad_idx = PAD_IDX

    if option == 'padding':
        # Compare against the scalar directly: no temporary tensor and no
        # dependence on a global device -- the mask lives where `tensor` lives.
        keep = (tensor != pad_idx).float()
        # (bs, k_len) -> (bs, 1, 1, k_len)
        return keep[:, None, None, :]

    if option == 'lookahead':
        padding_mask = makeMask(tensor, 'padding', pad_idx)
        seq_len = tensor.shape[1]

        # Lower-triangular matrix: query i may only look at keys 0..i.
        causal = torch.tril(torch.ones(
            seq_len, seq_len, dtype=padding_mask.dtype, device=padding_mask.device))

        # Broadcasting (seq_len, seq_len) * (bs, 1, 1, seq_len) combines both
        # constraints into a (bs, 1, seq_len, seq_len) mask.
        return causal * padding_mask

    # Previously an unknown option fell through to `return mask` and failed
    # with an unhelpful UnboundLocalError.
    raise ValueError(f"option must be 'padding' or 'lookahead', got {option!r}")



In [None]:
# test = torch.Tensor([[1,2,3,4,5,6,0,0,0,0]])
# ic(test.shape)
# test1 = makeMask(test, option = 'padding')
# test2 = makeMask(test, option = 'lookahead')
# ic(test1.shape)
# ic(test2.shape)



   ## Multihead Attention

In [None]:
class Multiheadattention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hidden_dim: model width (d_model, 512 in the paper).
        num_head: number of attention heads (8 in the paper); must divide
            ``hidden_dim``.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``num_head``.

    Note:
        Earlier revisions computed ``self.scale`` from an empty tensor and never
        applied it, so the scores were not divided by sqrt(head_dim). Weights
        trained with that version will behave differently under this one.
    """

    def __init__(self, hidden_dim: int, num_head: int):
        super().__init__()

        if hidden_dim % num_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_head ({num_head})")

        # embedding_dim, d_model, 512 in paper
        self.hidden_dim = hidden_dim
        # 8 in paper
        self.num_head = num_head
        # head_dim, d_key, d_query, d_value, 64 in paper (= 512 / 8)
        self.head_dim = hidden_dim // num_head
        # sqrt(d_k); a plain float, so it needs no device placement
        self.scale = self.head_dim ** 0.5

        self.fcQ = nn.Linear(hidden_dim, hidden_dim)
        self.fcK = nn.Linear(hidden_dim, hidden_dim)
        self.fcV = nn.Linear(hidden_dim, hidden_dim)
        self.fcOut = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(0.1)

    def _split_heads(self, x):
        """(bs, seq_len, hidden_dim) -> (bs, num_head, seq_len, head_dim)."""
        bs, seq_len, _ = x.shape
        return x.reshape(bs, seq_len, self.num_head, self.head_dim).transpose(1, 2)

    def forward(self, srcQ, srcK, srcV, mask=None):
        """Attend from ``srcQ`` over ``srcK`` / ``srcV``.

        Args:
            srcQ: (bs, q_len, hidden_dim)
            srcK: (bs, k_len, hidden_dim)
            srcV: (bs, k_len, hidden_dim)
            mask: optional tensor broadcastable to (bs, num_head, q_len, k_len);
                positions where it equals 0 are hidden. Typically
                (bs, 1, 1, k_len) for padding or (bs, 1, q_len, k_len) for
                look-ahead.

        Returns:
            Tensor of shape (bs, q_len, hidden_dim).
        """
        ##### SCALED DOT PRODUCT ATTENTION ######
        Q = self._split_heads(self.fcQ(srcQ))
        K = self._split_heads(self.fcK(srcK))
        V = self._split_heads(self.fcV(srcV))

        # attention_energy : (bs, num_head, q_len, k_len), scaled by sqrt(d_k)
        attention_energy = torch.matmul(Q, K.transpose(-1, -2)) / self.scale

        if mask is not None:
            # -1e4 (rather than -inf / -1e9) stays representable in float16.
            attention_energy = attention_energy.masked_fill(mask == 0, -1e+4)

        attention = torch.softmax(attention_energy, dim=-1)

        # result : (bs, num_head, q_len, head_dim)
        result = torch.matmul(self.dropout(attention), V)
        ##### END OF SCALED DOT PRODUCT ATTENTION ######

        # CONCAT heads: (bs, q_len, hidden_dim)
        bs, _, q_len, _ = result.shape
        result = result.transpose(1, 2).reshape(bs, q_len, self.hidden_dim)

        # LINEAR
        return self.fcOut(result)





In [None]:
# # TEST CODE #
# bs = 1
# seq_len = 10
# hidden_dim = 10
# # src = torch.Tensor([[   2, 1568,  955,  612,  221,   64,   20, 8905,  928, 8768,  167, 8841,
# #          3834,    9, 1687,   41, 7661,  562, 9073, 5204, 8794, 8931,   14, 8823,
# #          5616, 1289, 8793, 2477,  438,   27, 8783,   14, 8905,  534,  235,  204,
# #          9037, 8745, 9040, 6942,   47, 8738,    3,    0,    0,    0,    0,    0,
# #             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0]])
# src = torch.Tensor([[1,2,3,4,5,6,0,0,0,0]])

# padding_mask = makeMask(src, option = 'padding')
# ic(padding_mask.shape)
# lookahead_mask = makeMask(src, option = 'lookahead')
# ic(lookahead_mask.shape)


# test_Q = torch.randn((bs,2,10))
# test_K = rearrange(src, 'seq_len hidden_dim -> 1 seq_len hidden_dim')
# test_K = repeat(src, 'seq_len hidden_dim -> 1 (10 seq_len) hidden_dim')
# ic(test_Q.shape)
# ic(test_K.shape)
# test_layer = Multiheadattention(hidden_dim=hidden_dim, num_head =2)
# ic(test_layer(srcQ = test_Q, srcK = test_K, srcV = test_K, mask = padding_mask).shape)




   ## Positionwise Feedforward Network

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hidden_dim: input / output width (512 in the paper).
        inner_dim: width of the hidden layer (2048 in the paper).
    """

    def __init__ (self, hidden_dim, inner_dim):
        super().__init__()

        # 512 in paper
        self.hidden_dim = hidden_dim
        # 2048 in paper
        self.inner_dim = inner_dim

        self.fc1 = nn.Linear(hidden_dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input):
        """Map (..., hidden_dim) to (..., hidden_dim)."""
        hidden = self.relu(self.fc1(input))
        # Dropout must act on the ACTIVATED tensor. The previous code passed the
        # pre-ReLU tensor to dropout, which threw the non-linearity away.
        hidden = self.dropout(hidden)
        return self.fc2(hidden)





   ## Encoder Layer

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer output goes through dropout, is added to the sub-layer input
    (residual connection) and is then layer-normalized.
    """

    def __init__(self, hidden_dim, num_head, inner_dim):
        super().__init__()

        self.hidden_dim, self.num_head, self.inner_dim = hidden_dim, num_head, inner_dim

        self.multiheadattention = Multiheadattention(hidden_dim, num_head)
        self.ffn = FFN(hidden_dim, inner_dim)
        self.layerNorm1 = nn.LayerNorm(hidden_dim)
        self.layerNorm2 = nn.LayerNorm(hidden_dim)

        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)

    def forward(self, input, mask = None):
        """Transform `input` (bs, seq_len, hidden_dim) into a tensor of the same shape.

        `mask` is forwarded to the attention module; the encoder only ever
        passes a padding mask here.
        """
        # Sub-layer 1: self-attention, with queries, keys and values all from `input`.
        attended = self.multiheadattention(srcQ=input, srcK=input, srcV=input, mask=mask)
        hidden = self.layerNorm1(input + self.dropout1(attended))

        # Sub-layer 2: position-wise feed-forward network.
        transformed = self.dropout2(self.ffn(hidden))
        return self.layerNorm2(hidden + transformed)






   ## Encoder Architecture

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: token + learned positional embeddings, then N layers.

    Args:
        N: number of stacked encoder layers.
        hidden_dim: model width.
        num_head: number of attention heads per layer.
        inner_dim: hidden width of each feed-forward network.
        max_length: number of positions in the positional embedding table;
            inputs longer than this are rejected.
    """

    def __init__ (self, N, hidden_dim, num_head, inner_dim,max_length=100):
        super().__init__()

        # N : number of encoder layer repeated
        self.N = N
        self.hidden_dim = hidden_dim
        self.num_head = num_head
        self.inner_dim = inner_dim

        # Use the same padding id as the masks so the two can never drift apart.
        self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=hidden_dim, padding_idx=PAD_IDX)
        self.pos_embedding = nn.Embedding(max_length, hidden_dim)
        self.enc_layers = nn.ModuleList([EncoderLayer(hidden_dim, num_head, inner_dim) for _ in range(N)])

        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input):
        """Encode a batch of token ids.

        Args:
            input: (bs, seq_len) token ids.

        Returns:
            Tensor of shape (bs, seq_len, hidden_dim).

        Raises:
            ValueError: if seq_len exceeds the positional embedding table.
        """
        seq_len = input.shape[1]
        if seq_len > self.pos_embedding.num_embeddings:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length "
                f"{self.pos_embedding.num_embeddings}")

        mask = makeMask(input, option='padding')

        # Build positions on the input's own device instead of a global one.
        # Shape (1, seq_len) broadcasts over the batch when added below.
        pos = torch.arange(seq_len, device=input.device).unsqueeze(0)

        # Token + positional embedding, with dropout applied exactly once
        # (it used to be applied twice in a row, compounding the drop rate).
        output = self.dropout(self.embedding(input) + self.pos_embedding(pos))
        # output : (bs, seq_len, hidden_dim)

        # N encoder layers
        for layer in self.enc_layers:
            output = layer(output, mask)

        # output : (bs, seq_len, hidden_dim)
        return output





   ## Decoder Layer

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, FFN.

    Each sub-layer output goes through dropout, is added to the sub-layer input
    (residual connection) and is then layer-normalized.
    """

    def __init__(self, hidden_dim, num_head, inner_dim):
        super().__init__()
        self.hidden_dim, self.num_head, self.inner_dim = hidden_dim, num_head, inner_dim

        self.multiheadattention1 = Multiheadattention(hidden_dim, num_head)
        self.layerNorm1 = nn.LayerNorm(hidden_dim)
        self.multiheadattention2 = Multiheadattention(hidden_dim, num_head)
        self.layerNorm2 = nn.LayerNorm(hidden_dim)
        self.ffn = FFN(hidden_dim, inner_dim)
        self.layerNorm3 = nn.LayerNorm(hidden_dim)

        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
        self.dropout3 = nn.Dropout(p=0.1)

    def forward(self, input, enc_output, paddingMask, lookaheadMask):
        """Run one decoder block.

        input      : (bs, seq_len, hidden_dim) decoder-side activations
        enc_output : (bs, seq_len, hidden_dim) encoder output
        paddingMask is applied to the encoder-decoder attention and
        lookaheadMask to the decoder self-attention.
        """
        # Sub-layer 1: masked self-attention over the decoder sequence.
        self_attended = self.dropout1(
            self.multiheadattention1(input, input, input, lookaheadMask))
        hidden = self.layerNorm1(self_attended + input)

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_attended = self.dropout2(
            self.multiheadattention2(hidden, enc_output, enc_output, paddingMask))
        hidden = self.layerNorm2(cross_attended + hidden)

        # Sub-layer 3: position-wise feed-forward network.
        transformed = self.dropout3(self.ffn(hidden))
        return self.layerNorm3(hidden + transformed)





   ## Decoder Architecture

In [None]:
class Decoder(nn.Module):
    """Transformer decoder: embeddings, N decoder layers and a vocabulary projection.

    Args:
        N: number of stacked decoder layers.
        hidden_dim: model width.
        num_head: number of attention heads per layer.
        inner_dim: hidden width of each feed-forward network.
        max_length: number of positions in the positional embedding table;
            inputs longer than this are rejected.

    Note:
        Earlier revisions created ``pos_embedding`` but never added it to the
        token embeddings, so the decoder received no explicit position signal.
        Checkpoints trained with that version hold an untrained
        ``pos_embedding`` and should be retrained or fine-tuned.
    """

    def __init__ (self, N, hidden_dim, num_head, inner_dim, max_length=100):
        super().__init__()

        # N : number of decoder layer repeated
        self.N = N
        self.hidden_dim = hidden_dim
        self.num_head = num_head
        self.inner_dim = inner_dim

        self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=hidden_dim, padding_idx=0)
        self.pos_embedding = nn.Embedding(max_length, hidden_dim)

        self.dec_layers = nn.ModuleList([DecoderLayer(hidden_dim, num_head, inner_dim) for _ in range(N)])

        self.dropout = nn.Dropout(p=0.1)

        self.finalFc = nn.Linear(hidden_dim, VOCAB_SIZE)

    def forward(self, input, enc_src, enc_output):
        """Decode with teacher forcing.

        Args:
            input: (bs, seq_len) decoder input token ids.
            enc_src: (bs, seq_len) encoder input token ids, used only to build
                the padding mask for encoder-decoder attention.
            enc_output: (bs, seq_len, hidden_dim) encoder output.

        Returns:
            Tuple ``(logits, output)``: ``logits`` has shape
            (bs, seq_len, VOCAB_SIZE) and ``output`` is the int64 argmax over
            the vocabulary, shape (bs, seq_len).

        Raises:
            ValueError: if seq_len exceeds the positional embedding table.
        """
        seq_len = input.shape[1]
        if seq_len > self.pos_embedding.num_embeddings:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length "
                f"{self.pos_embedding.num_embeddings}")

        lookaheadMask = makeMask(input, option= 'lookahead')
        paddingMask = makeMask(enc_src, option = 'padding')

        # Token + learned positional embedding, mirroring the encoder.
        # pos has shape (1, seq_len) and broadcasts over the batch.
        pos = torch.arange(seq_len, device=input.device).unsqueeze(0)
        output = self.dropout(self.embedding(input) + self.pos_embedding(pos))
        # output : (bs, seq_len, hidden_dim)

        # N decoder layers
        for layer in self.dec_layers:
            output = layer(output, enc_output, paddingMask, lookaheadMask)
        # output : (bs, seq_len, hidden_dim)

        logits = self.finalFc(output)
        # logits : (bs, seq_len, VOCAB_SIZE)

        # softmax is monotonic, so argmax over the raw logits picks the same
        # token without materializing the probabilities.
        output = torch.argmax(logits, dim = -1)
        # output : (bs, seq_len), dtype=int64

        return logits, output





   ## Transformer Model

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from `Encoder` and `Decoder`."""

    def __init__(self, N = 2, hidden_dim = 256, num_head = 8, inner_dim = 512):
        super().__init__()
        # Both stacks are built with the same four hyper-parameters.
        shared_args = (N, hidden_dim, num_head, inner_dim)
        self.encoder = Encoder(*shared_args)
        self.decoder = Decoder(*shared_args)

    def forward(self, enc_src, dec_src):
        """Run the encoder on `enc_src`, then the decoder on `dec_src`.

        enc_src : (bs, seq_len) source token ids
        dec_src : (bs, seq_len) decoder input token ids

        Returns the decoder's pair: logits (bs, seq_len, VOCAB_SIZE) and the
        predicted ids.
        """
        # memory : (bs, seq_len, hidden_dim)
        memory = self.encoder(enc_src)
        scores, predicted_ids = self.decoder(dec_src, enc_src, memory)
        return scores, predicted_ids



   # Model Train

   # Create Model

In [None]:
# Build the model from the global hyper-parameters and move it to the selected device.
model = Transformer(N, HIDDEN_DIM, NUM_HEAD, INNER_DIM).to(device)
# Turn off icecream debug output for the rest of the notebook.
ic.disable()



   # Check Model Structure

In [None]:
# Print a layer-by-layer summary of the model for two integer inputs of length SEQ_LEN.
from torchsummary import summary
# NOTE(review): test1 / test2 are created but not passed to summary() below.
test1 = torch.randint(low = 0, high = 1000, size = (SEQ_LEN,))
test2 = torch.randint(low = 0, high = 1000, size = (SEQ_LEN,))
summary(model, [(SEQ_LEN,), (SEQ_LEN,)], dtypes = [torch.int, torch.int])




   # Weight Init

In [None]:
# Xavier-initialize every weight matrix except the LayerNorm scales
# (biases keep their default initialization).
for name, param in model.named_parameters():
    if 'weight' in name and 'layerNorm' not in name:
        torch.nn.init.xavier_uniform_(param)

# The loop above also overwrites the embedding rows reserved for padding, which
# nn.Embedding had initialized to zero. Restore them so padded positions embed
# to the zero vector again (their gradient is always zero, so they stay there).
with torch.no_grad():
    for module in model.modules():
        if isinstance(module, nn.Embedding) and module.padding_idx is not None:
            module.weight[module.padding_idx].zero_()

   # Optimizer

In [None]:
# Adam over all model parameters, using the global learning rate and weight decay.
optimizer = torch.optim.Adam(params = model.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY)




   # Loss Function

In [None]:
def criterion(logits: torch.Tensor, targets: torch.Tensor, ignore_index=None):
    """Token-level cross-entropy averaged over the non-ignored positions.

    Args:
        logits: (..., num_classes) unnormalized scores, e.g.
            (bs, seq_len, VOCAB_SIZE).
        targets: integer class ids with the shape of ``logits`` minus its last
            axis, e.g. (bs, seq_len).
        ignore_index: target id excluded from the loss; defaults to the
            module-level PAD_IDX.

    Returns:
        Scalar loss tensor.
    """
    if ignore_index is None:
        ignore_index = PAD_IDX

    # Read the class count from the logits instead of a global, and use
    # reshape so non-contiguous inputs (which .view rejects) also work.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=ignore_index)




   # Training Function

In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Train `model` for one pass over `dataloader`.

    Args:
        model: called as ``model(enc_src=..., dec_src=...)`` and expected to
            return ``(logits, predicted_ids)``.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler stepped after EVERY batch, or None.
        dataloader: yields ``(src, trg_input, trg_output)`` batches.
        device: device the batches are moved to.
        epoch: epoch number, shown in the progress bar only.

    Returns:
        Tuple ``(epoch_loss, accuracy)``: the sample-weighted mean loss and the
        mean of the per-batch token accuracies.

    Note:
        The accuracy compares every position, including padded ones that the
        loss ignores, so it is bounded by the fraction of non-padding tokens.
    """
    # switch to training mode
    model.train()

    # Mixed precision, following the PyTorch AMP examples:
    # https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples
    # NOTE: the scaler is re-created each epoch, so its loss scale restarts too.
    scaler = amp.GradScaler() if FP16 else None

    dataset_size = 0
    running_loss = 0.0
    accuracy = 0.0
    # Defined up front so an empty dataloader returns 0.0 instead of crashing.
    epoch_loss = 0.0

    bar = tqdm(enumerate(dataloader), total=len(dataloader))

    for step, (src, trg_input, trg_output) in bar:
        src = src.to(device)
        trg_input = trg_input.to(device)
        trg_output = trg_output.to(device)

        batch_size = src.shape[0]

        if FP16:
            # Only the forward pass and the loss belong inside autocast.
            with amp.autocast(enabled=True):
                logits, output = model(enc_src=src, dec_src=trg_input)
                loss = criterion(logits, trg_output)

            # Backward on the scaled loss to avoid fp16 gradient underflow.
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm. This gives
            # the FP16 path the same clipping as the full-precision path.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            # scaler.step() skips optimizer.step() if the gradients contain
            # infs or NaNs; update() then adjusts the scale for the next step.
            scaler.step(optimizer)
            scaler.update()

        else:
            logits, output = model(enc_src=src, dec_src=trg_input)
            loss = criterion(logits, trg_output)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

        # logits (bs, seq_len, VOCAB_SIZE)
        # trg_output (bs, seq_len)

        # zero the parameter gradients
        optimizer.zero_grad()

        # change learning rate by Scheduler
        if scheduler is not None:
            scheduler.step()

        # loss.item() is the batch-average loss as a Python float; multiply by
        # batch_size to accumulate a per-sample sum.
        running_loss += loss.item() * batch_size
        batch_accuracy = (output == trg_output).float().mean().item()
        accuracy += batch_accuracy

        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        # Plain division: the np.float alias was removed in NumPy 1.24.
        bar.set_postfix(
            Epoch=epoch, Train_Loss=epoch_loss, LR=optimizer.param_groups[0]["lr"],
            accuracy=accuracy / (step + 1)
        )

    # Guard against a zero-length dataloader.
    accuracy /= max(len(dataloader), 1)
    # Garbage Collector
    gc.collect()

    return epoch_loss, accuracy


   # Validation Function

In [None]:
@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch):
    """Evaluate `model` on `dataloader` without computing gradients.

    Args:
        model: called as ``model(enc_src=..., dec_src=...)`` and expected to
            return ``(logits, predicted_ids)``.
        dataloader: yields ``(src, trg_input, trg_output)`` batches.
        device: device the batches are moved to.
        epoch: epoch number, shown in the progress bar only.

    Returns:
        Tuple ``(val_loss, accuracy)``: the sample-weighted mean loss and the
        mean of the per-batch token accuracies.

    Note:
        The accuracy compares every position, including padded ones that the
        loss ignores, so it is bounded by the fraction of non-padding tokens.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    accuracy = 0.0
    # Defined up front so an empty dataloader returns 0.0 instead of crashing.
    val_loss = 0.0

    bar = tqdm(enumerate(dataloader), total=len(dataloader))

    for step, (src, trg_input, trg_output) in bar:
        src = src.to(device)
        trg_input = trg_input.to(device)
        trg_output = trg_output.to(device)

        batch_size = src.shape[0]

        logits, output = model(enc_src = src, dec_src = trg_input)
        loss = criterion(logits, trg_output)

        running_loss += loss.item() * batch_size
        dataset_size += batch_size

        # running loss, refreshed every batch for the live progress bar
        val_loss = running_loss / dataset_size
        batch_accuracy = (output == trg_output).float().mean().item()
        accuracy += batch_accuracy

        # The learning rate is no longer displayed: it was read from a global
        # `optimizer` that is not a parameter of this function. Plain division
        # replaces the np.float alias removed in NumPy 1.24.
        bar.set_postfix(
            Epoch=epoch, Valid_Loss=val_loss, accuracy=accuracy / (step + 1)
        )

    # Guard against a zero-length dataloader.
    accuracy /= max(len(dataloader), 1)

    gc.collect()

    return val_loss, accuracy




In [None]:
def run_training(
    model,
    optimizer,
    scheduler,
    device,
    num_epochs,
    metric_prefix="",
    file_prefix="",
    early_stopping=True,
    early_stopping_step=10,
):
    """Run the train/validate loop with checkpointing and early stopping.

    Uses the module-level `train_dataloader` and `valid_dataloader`. Each epoch
    the metrics are appended to a history dict and logged to W&B. Whenever the
    validation loss does not get worse, the weights are saved to two files
    (the loss-tagged one is also uploaded to W&B). With `early_stopping`
    enabled, training stops once more than `early_stopping_step` consecutive
    epochs fail to improve.

    Returns `(model, history)`, where `model` has the best weights loaded.
    """
    # Let W&B record gradients automatically
    wandb.watch(model, log_freq=100)

    if torch.cuda.is_available():
        print("[INFO] Using GPU:{}\n".format(torch.cuda.get_device_name()))

    started_at = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    history = defaultdict(list)
    stale_epochs = 0

    # Metric names are fixed for the whole run, so build them once.
    train_loss_key = f"{metric_prefix}Train Loss"
    train_acc_key = f"{metric_prefix}Train Accuracy"
    valid_loss_key = f"{metric_prefix}Valid Loss"
    valid_acc_key = f"{metric_prefix}Valid Accuracy"

    # one training pass and one validation pass per epoch
    for epoch in range(1, num_epochs + 1):
        gc.collect()

        train_epoch_loss, train_accuracy = train_one_epoch(
            model, optimizer, scheduler,
            dataloader=train_dataloader, device=device, epoch=epoch,
        )
        val_loss, val_accuracy = valid_one_epoch(
            model, valid_dataloader, device=device, epoch=epoch
        )

        history[train_loss_key].append(train_epoch_loss)
        history[train_acc_key].append(train_accuracy)
        history[valid_loss_key].append(val_loss)
        history[valid_acc_key].append(val_accuracy)

        # Log the metrics
        wandb.log(
            {
                train_loss_key: train_epoch_loss,
                valid_loss_key: val_loss,
                train_acc_key: train_accuracy,
                valid_acc_key: val_accuracy,
            }
        )

        print(f"Valid Loss : {val_loss}")

        if not (val_loss <= best_loss):
            # No improvement this epoch: count it towards early stopping.
            if early_stopping:
                stale_epochs += 1
                if stale_epochs > early_stopping_step:
                    break
            continue

        # Improvement (or tie): reset patience and checkpoint.
        stale_epochs = 0
        print(
            f"Validation Loss improved( {best_loss} ---> {val_loss}  )"
        )

        best_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())

        checkpoint_path = "{}epoch{:.0f}_Loss{:.4f}.bin".format(file_prefix, epoch, best_loss)
        torch.save(model.state_dict(), checkpoint_path)
        torch.save(model.state_dict(), f"{file_prefix}best_{epoch}epoch.bin")
        # Upload the loss-tagged checkpoint from the current directory
        wandb.save(checkpoint_path)

        print("Model Saved")

    time_elapsed = time.time() - started_at
    hours = time_elapsed // 3600
    minutes = (time_elapsed % 3600) // 60
    seconds = (time_elapsed % 3600) % 60
    print("Training complete in {:.0f}h {:.0f}m {:.0f}s".format(hours, minutes, seconds))
    print("Best Loss: {:.4f}".format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, history




In [None]:
# wandb.restore('epoch2_Loss7.8590.bin',
#               run_path='jiwon7258/Transformer_enko/12t3zwp8',
#               root='./')
# model.load_state_dict(torch.load('epoch2_Loss7.8590.bin'))


 # Run training

In [None]:
# Launch training for up to 2000 epochs with early stopping.
# NOTE(review): train_one_epoch steps the scheduler once per batch, so T_max=100
# is counted in batches, not epochs -- confirm this is intended.
# NOTE(review): the returned (model, history) pair is discarded here; run_training
# loads the best weights into `model` in place before returning.
run_training(
    model = model,
    optimizer = optimizer,
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100, eta_min=1e-5),
    device = device,
    num_epochs = 2000,
    metric_prefix="",
    file_prefix="",
    early_stopping=True,
    early_stopping_step=10,
)

In [None]:
# Save the model's current weights locally and upload the file to the W&B run.
torch.save(model.state_dict(), 'final.bin')
wandb.save('final.bin')

In [None]:
# Mark the W&B run as finished.
wandb.finish()