In [1]:
# Colab setup cell.
# NOTE(review): of these packages only `tokenizers` (pulled in as a dependency of
# `transformers`, per the install log) is imported by the cells shown here;
# `kobert-transformers` and `sentencepiece` look unused. Consider `%pip install`
# with pinned versions (e.g. transformers==4.3.2) for a reproducible environment.
! pip install transformers
! pip install kobert-transformers
! pip install sentencepiece

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/98/87/ef312eef26f5cecd8b17ae9654cdd8d1fae1eb6dbd87257d6d73c128a4d0/transformers-4.3.2-py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 5.6MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/fd/5b/44baae602e0a30bcc53fbdbc60bd940c15e143d252d658dfdefce736ece5/tokenizers-0.10.1-cp36-cp36m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 20.0MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 35.4MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=aea61eba5fd

In [2]:
from tokenizers import Tokenizer
import numpy as np
import re
import torchtext
import torch
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from pandas import DataFrame as df
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm
import time
from google.colab import drive
import os
# Mount Google Drive and move into the project folder. An absolute path is used
# so that re-running this cell works: a relative './gdrive/...' only resolves
# while the working directory is the mount point's parent.
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)
os.chdir(os.path.join(GDRIVE_MOUNT_POINT, 'My Drive', '기상청'))

Mounted at /content/gdrive


In [3]:
# Data load: the pre-split train / validation / test DataFrames.
# NOTE: pd.read_pickle can execute arbitrary code -- only load trusted files.
train_data, val_data, test_data = (
    pd.read_pickle(pickle_path)
    for pickle_path in ('./data/train_data', './data/val_data', './data/test_data')
)

In [4]:
# Replace every run of digits in the text with the [NUM] special token.
_DIGIT_RUN = re.compile('[0-9]+')
for _frame in (train_data, val_data, test_data):
    _frame['Total'] = _frame['Total'].apply(lambda text: _DIGIT_RUN.sub('[NUM]', text))

In [5]:
# Build the tokenizer training corpus: one sentence (text split on '.') per line.
# The `with` block closes the file, so every buffered line is flushed to disk
# before `tokenizer.train` reads it in a later cell.
with open('./train_tokenizer.txt', 'w', encoding='utf-8') as train_tokenizer_text:
    for document in train_data['Total']:
        for sentence in document.split('.'):
            train_tokenizer_text.write(sentence.strip() + '\n')

# WordPieceTokenizer Train

In [6]:
from tokenizers import BertWordPieceTokenizer

# Settings for training a WordPiece tokenizer on the corpus file written above.
corpus_file = ['train_tokenizer.txt']   # corpus path(s) passed to tokenizer.train
vocab_size = 8000                       # target vocabulary size
limit_alphabet = 6000                   # passed to tokenizer.train as limit_alphabet
min_frequency = 3                       # minimum occurrence count passed as min_frequency
special_tokens = [
    '[CLS]', '[SEP]', '[BOS]', '[EOS]', '[UNK]', '[PAD]', '[MASK]', '[NUM]',
]
output_path = 'hugging_%d' % (vocab_size)  # NOTE(review): not referenced by any cell shown

tokenizer = BertWordPieceTokenizer()

In [7]:
# Train the WordPiece tokenizer on the corpus, then save it as JSON so later
# cells can reload it with Tokenizer.from_file.
tokenizer_train_kwargs = dict(
    files=corpus_file,
    vocab_size=vocab_size,
    min_frequency=min_frequency,   # minimum occurrence count (3)
    limit_alphabet=limit_alphabet,
    show_progress=True,
    special_tokens=special_tokens,
)
tokenizer.train(**tokenizer_train_kwargs)

tokenizer.save("./%s_%s.json" % (str('bert_tokenizer'), vocab_size))

# 적정 수준의 seq len을 정하기 위한 EDA
seq len은 2048 수준으로 상정 (피해 데이터의 약 97%가 2046 토큰 미만)  

In [8]:
# Encode each full text and record its token count for the length EDA below.
train_data['X'] = train_data['Total'].map(lambda text: tokenizer.encode(text).ids)
train_data['len'] = train_data['X'].map(len)

In [9]:
np.percentile(train_data['len'].to_numpy(), q=[50, 75, 95, 99])  # all docs: 628., 1053., 2053., 3586.12

array([ 628.  , 1053.  , 2053.  , 3586.12])

In [10]:
# Token-length percentiles over damage (damage == 1) documents only.
damage_len = train_data.loc[train_data['damage'] == 1, 'len']
np.percentile(damage_len.to_numpy(), q=[50, 75, 95, 99])  # 681., 1019., 1651.1, 2853.42
# -> proceed with a sequence length of about 2048

array([ 681.  , 1019.  , 1651.1 , 2853.42])

In [11]:
# Share of damage documents shorter than 2046 tokens
# (presumably 2048 minus room for the [BOS]/[EOS] tokens added later -- TODO confirm).
damage_lengths = train_data.loc[train_data['damage'] == 1, 'len']
(damage_lengths < 2046).sum() / len(damage_lengths)  # ~97.3%

0.9734693877551021

# VAE용 데이터 만들기

In [12]:
class Config(dict):
    '''
    dict whose keys can also be read and written as attributes (config.seq_len).

    A missing attribute raises AttributeError (not KeyError), so hasattr(),
    getattr(obj, name, default), copy.deepcopy and pickle behave normally.
    '''
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    # Attribute assignment stores a dict item (config.x = 1 -> config['x'] = 1).
    __setattr__ = dict.__setitem__


config = Config({'n_layers': 2, 'n_head': 4, 'bidirectional':True,'embedding_dim': 128, 'd_model': 128, 'latent_dim': 64, 'hidden_dim': 128*4, 'seq_len': 2048, 'batch_size': 8, 'dropout': 0.1, 'max_len': 9999,'n_vocab':8000})

In [13]:
# Reload the tokenizer trained above from its saved JSON file.
tokenizer_path = "./bert_tokenizer_%s.json" % (config.n_vocab)
tokenizer = Tokenizer.from_file(tokenizer_path)

In [14]:
# Cache the ids of the special tokens in config. Tokenizer.token_to_id returns
# None for a token that is not in the vocabulary; fail here instead of letting
# None reach padding() or nn.Embedding(padding_idx=...) later.
for config_key, special_token in (
    ('padding_idx', '[PAD]'),
    ('eos_idx', '[EOS]'),
    ('bos_idx', '[BOS]'),
    ('unk_idx', '[UNK]'),
    ('mask_idx', '[MASK]'),
    ('cls_idx', '[CLS]'),
    ('num_idx', '[NUM]'),
):
    token_id = tokenizer.token_to_id(special_token)
    if token_id is None:
        raise KeyError('special token %s is missing from the tokenizer vocabulary' % special_token)
    config[config_key] = token_id

In [15]:
# Keep only the damage (damage == 1) rows. `.copy()` makes each result an
# independent DataFrame, so adding the 'tokenized' column below does not write
# into a slice of the original frame (pandas' SettingWithCopyWarning).
train_data = train_data.loc[train_data['damage'] == 1, :].copy()
val_data = val_data.loc[val_data['damage'] == 1, :].copy()
test_data = test_data.loc[test_data['damage'] == 1, :].copy()


def _encode_with_bos_eos(text):
    '''Token ids of '[BOS]' + text + '[EOS]'.'''
    return tokenizer.encode('[BOS]' + text + '[EOS]').ids


for _frame in (train_data, val_data, test_data):
    _frame['tokenized'] = _frame['Total'].apply(_encode_with_bos_eos)

In [16]:
def padding(sentence, seq_len=None, padding_idx=None, eos_idx=None):
    '''
    Return a new token-id list of length exactly `seq_len`.

    - longer than seq_len : truncated to seq_len-1 ids and terminated with eos_idx
    - shorter             : right-padded with padding_idx
    - equal               : returned as a copy

    seq_len / padding_idx / eos_idx default to config.seq_len /
    config.padding_idx / config.eos_idx. The input list is not modified.
    '''
    seq_len = config.seq_len if seq_len is None else seq_len
    padding_idx = config.padding_idx if padding_idx is None else padding_idx
    eos_idx = config.eos_idx if eos_idx is None else eos_idx

    if len(sentence) > seq_len:
        return list(sentence[:seq_len - 1]) + [eos_idx]
    return list(sentence) + [padding_idx] * (seq_len - len(sentence))

In [17]:
# Pad / truncate every token-id list to config.seq_len.
for _frame in (train_data, val_data, test_data):
    _frame['src'] = _frame['tokenized'].map(padding)

In [18]:
# Number of damage samples per split (train 980 / val 136 / test 183).
for _split in (train_data, val_data, test_data):
    print(len(_split))

980
136
183


In [19]:
# Wrap the padded id lists as long-integer TensorDatasets.
train_data_set, val_data_set, test_data_set = (
    TensorDataset(torch.tensor(frame['src'].tolist(), dtype=torch.long))
    for frame in (train_data, val_data, test_data)
)
# Shuffle the training batches every epoch (SGD should not see a fixed batch
# order); validation/test keep a fixed order so their losses are comparable.
train_loader = DataLoader(train_data_set, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(val_data_set, batch_size=config.batch_size)
test_loader = DataLoader(test_data_set, batch_size=config.batch_size)

## models and tools

In [20]:
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

def preprocessing(data, config, device, which):
    '''
    Build encoder input and decoder input/target from one DataLoader batch.

    data   : sequence whose first element is a (bs, seq_len) id tensor laid out
             as [BOS] sentence [EOS] [PAD]...
    config : needs cls_idx, eos_idx, padding_idx
    device : target device
    which  : model type (unused; kept for a uniform call signature)

    Returns (batch_size, src, tgt_input, tgt_output); the tensors are
    (bs, seq_len) and long:
      src        : [CLS] sentence [PAD]...  ([BOS] dropped, [EOS] -> [PAD])
      tgt_input  : [BOS] sentence [PAD]...  ([EOS] -> [PAD])
      tgt_output : sentence [EOS] [PAD]...  (shifted left by one, [PAD] appended)
    '''
    ids = data[0].to(device).long()  # single host->device transfer
    batch_size = ids.size(0)
    # Explicit dtype: older PyTorch releases did not infer an integer dtype
    # from an integer fill value.
    cls_column = torch.full((batch_size, 1), config.cls_idx, dtype=torch.long, device=device)
    pad_column = torch.full((batch_size, 1), config.padding_idx, dtype=torch.long, device=device)

    body = ids[:, 1:]  # everything after [BOS]
    src = torch.cat([cls_column, body.masked_fill(body.eq(config.eos_idx), config.padding_idx)], dim=-1)
    tgt_input = ids.masked_fill(ids.eq(config.eos_idx), config.padding_idx)
    tgt_output = torch.cat([body, pad_column], dim=-1)
    return batch_size, src, tgt_input, tgt_output

def ce_kl(decoder_output, input, mu, log_var, kl_annealing, config):
    '''
    Summed reconstruction cross-entropy and summed Gaussian KL divergence.

    decoder_output : (bs, seq_len, n_vocab) logits
    input          : (bs, seq_len) target ids; positions equal to
                     config.padding_idx are ignored
    mu, log_var    : (bs, latent_dim) posterior parameters
    kl_annealing   : unused here; callers apply the KL weight themselves

    Both values are summed over the batch (not averaged).
    Returns (CE, KL) as scalar tensors.
    '''
    # F.cross_entropy wants the class axis second: (bs, n_vocab, seq_len).
    logits = decoder_output.transpose(1, 2)
    reconstruction = F.cross_entropy(logits, input, ignore_index=config.padding_idx, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, I)) per sample, then summed over the batch.
    kl_terms = log_var + 1. - log_var.exp() - mu.pow(2)
    divergence = -(0.5 * kl_terms.sum(-1)).sum()
    return reconstruction, divergence
    
def calculate_loss(model,data_loader,config,device,which):
    '''
    Per-sample average reconstruction (CE) and KL over a whole DataLoader.

    Switches the model to eval mode (and leaves it there) and runs without
    gradients. Returns (avg_ce, avg_kl) as Python floats.
    '''
    ce_sum = 0.
    kl_sum = 0.
    n_samples = 0
    with torch.no_grad():
        model.eval()
        for batch in data_loader:
            bs, src, tgt_input, tgt_output = preprocessing(batch, config, device, which)
            output, mu, log_var = model(src, tgt_input)
            CE, KL = ce_kl(output, tgt_output, mu, log_var, 1., config)
            ce_sum += CE.item()
            kl_sum += KL.item()
            n_samples += bs
    return ce_sum / n_samples, kl_sum / n_samples

def linear_anneal_function(step, total_step):
    '''KL weight growing linearly from 0 at step 1 and capped at 1.'''
    progress = (step - 1) / total_step
    return progress if progress < 1 else 1

def cycle_anneal_function(step, total_step, n_cycle = 4, ratio = 0.5):
    '''
    Cyclical KL-annealing weight.

    The `total_step` steps are split into `n_cycle` equal cycles. Within each
    cycle the weight rises linearly from 0 to 1 over the first `ratio`
    fraction and then stays at 1. `step` is 1-based.

    The same (possibly fractional) cycle length is used for both the modulo
    and the normalisation, so there are exactly `n_cycle` cycles even when
    total_step is not a multiple of n_cycle.
    '''
    period = total_step / n_cycle
    t = ((step - 1) % period) / period
    return t / ratio if t <= ratio else 1

def total_loss(decoder_output, input, mu, log_var, kl_annealing, config):
    '''
    Batch-averaged VAE loss.

    decoder_output : (bs, seq_len, n_vocab) logits
    input          : (bs, seq_len) target ids; config.padding_idx is ignored
    mu, log_var    : (bs, latent_dim)
    kl_annealing   : weight applied to the KL term

    Returns (CE + kl_annealing * KL, CE, KL), with CE and KL summed over
    tokens/latent dims and divided by the batch size.
    '''
    n = decoder_output.size(0)
    logits = decoder_output.transpose(1, 2)  # (bs, n_vocab, seq_len)
    ce_sum = F.cross_entropy(logits, input, ignore_index=config.padding_idx, reduction='sum')
    kl_sum = -(0.5 * (log_var + 1. - log_var.exp() - mu.pow(2)).sum(-1)).sum()
    CE = ce_sum / n
    KL = kl_sum / n
    return CE + kl_annealing * KL, CE, KL

class PositionalEncoding(nn.Module):
    '''
    Adds fixed sinusoidal position encodings to a (seq_len, batch, d_model)
    input and applies dropout.

    config needs d_model, max_len and dropout. The table is registered as a
    non-persistent buffer: it follows module.to(device) but is not written to
    state_dict, so saved checkpoints keep the same keys. Odd d_model is
    supported.
    '''
    def __init__(self, config):
        super(PositionalEncoding, self).__init__()
        self.config = config
        self.dropout = nn.Dropout(p=self.config.dropout)
        pe = torch.zeros(config.max_len, config.d_model)
        position = torch.arange(0, config.max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, config.d_model, 2).float() * (-math.log(10000.0) / config.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: config.d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(1), persistent=False)

    def forward(self, x):
        '''
        x : (seq_len, batch, d_model). Returns dropout(x + pe[:seq_len]).
        '''
        pe = self.pe[:x.size(0)].to(x.device)  # no-op when already on x's device
        return self.dropout(x + pe)

class Transformer_Encoder_Base(nn.Module): # currently in use
    '''
    Encoder of the Transformer VAE: maps a batch of [CLS]-prefixed id sequences
    to a latent sample z and its Gaussian parameters (mu, log_var).

    The final hidden state at position 0 (the [CLS] slot) is projected to
    2 * latent_dim values, which are split into mu and log_var.
    '''
    def __init__(self, config):
        # config needs d_model, n_head, hidden_dim, dropout, n_vocab,
        # padding_idx, n_layers, latent_dim (plus what PositionalEncoding reads).
        # NOTE(review): construction order and attribute names are load-bearing:
        # they fix the random initialisation under a seed and the state_dict keys.
        super().__init__()
        self.config=config
        encoderlayer = nn.TransformerEncoderLayer(self.config.d_model,self.config.n_head,self.config.hidden_dim,self.config.dropout,activation='gelu')
        self.embedding = nn.Embedding(self.config.n_vocab,self.config.d_model,padding_idx=self.config.padding_idx)
        self.positional_embedding = PositionalEncoding(self.config)
        self.Encoder = nn.TransformerEncoder(encoderlayer,self.config.n_layers)
        self.dmodel2latent = nn.Linear(self.config.d_model, 2*self.config.latent_dim)

    def subsquent_mask(self,src):
        '''
        Additive causal mask (not used by this encoder's forward).

        src : (src_seq_len, batch size) -- only its length is used
        returns : (src_seq_len, src_seq_len) float tensor with -1e9 above the
                  diagonal (positions to mask) and 0 elsewhere
        '''
        s=src.size(0)
        mask=torch.triu(torch.ones((s,s)),1)==1
        mask=mask.float().masked_fill(mask,-1e9) # fill the True (future) positions with -1e9
        return mask
 
    def padding_mask(self,src):
        '''
        Key-padding mask for positions whose id equals config.padding_idx.

        src : (seq_len, batch size)
        returns : (batch size, seq_len) bool tensor, True at padding positions
                  (passed as src_key_padding_mask so those keys are ignored)
        '''
        out=src.eq(self.config.padding_idx)
        return out.T 
 
    def reparametrization(self,mu,log_var):
        '''
        Reparameterisation trick: z = mu + exp(0.5 * log_var) * e, e ~ N(0, I).

        mu, log_var : (batch size, latent_dim)
        returns z : (batch size, latent_dim); one sample per input, i.e. L = 1
        in the Monte Carlo estimate of E_q(z|x)[log p(x|z)] from the VAE paper.
        '''
        sigma=torch.exp(0.5*log_var) # batch size, latent_dim
        e=torch.randn_like(sigma,device=sigma.device) # batch size, latent_dim
        z=mu+sigma*e # z shape : batch size, latent_dim
        return z
 
    def forward(self, src):
        '''
        src : (batch size, seq_len) ids laid out as [CLS] sentence [PAD]...
        returns (z, mu, log_var), each (batch size, latent_dim)
        '''
        # src : sentence
        device=src.device
        src = src.T # batch size, seq len  -> seq len, batch size
        src_key_padding_mask = self.padding_mask(src).to(device)
        
        # encoder
        # src: [CLS] sentence
        s1 = self.embedding(src)
        src_out = s1 * math.sqrt(self.config.d_model)
        src_out = self.positional_embedding(src_out) # seq len, batch size, d model
        out = self.Encoder.forward(src_out, src_key_padding_mask=src_key_padding_mask) # seq len, batch size, d_model
        # out shape : (seq len, batch size, d model)
        
        # latent z from the hidden state at position 0 ([CLS])
        CLS = out[0,:,:] # batch_size, d_model
        # reshape to (batch size, latent_dim, 2): mu takes the even output units, log_var the odd ones
        z_para=self.dmodel2latent.forward(CLS).unsqueeze(-1).reshape(-1,self.config.latent_dim,2) # batch size, 2*latent_dim -> batch size, latent_dim, 2
        mu,log_var=z_para[:,:,0],z_para[:,:,1] # batch size, latent_dim
        z=self.reparametrization(mu,log_var) # batch size, latent_dim
        return z, mu, log_var

class Transformer_Decoder_Base(nn.Module):
    '''
    Decoder of the Transformer VAE.

    A causally-masked nn.TransformerEncoder stack reads the shifted target
    ([BOS] sentence); the latent z, projected to d_model, is then added to
    every position before the final projection to vocabulary logits.
    '''
    def __init__(self, config):
        super().__init__()
        self.config=config
        # Construction order and attribute names are kept as-is: they fix the
        # random initialisation under a seed and the state_dict keys.
        encoderlayer = nn.TransformerEncoderLayer(self.config.d_model,self.config.n_head,self.config.hidden_dim,self.config.dropout,activation='gelu')
        self.embedding = nn.Embedding(self.config.n_vocab,self.config.d_model,padding_idx=self.config.padding_idx)
        self.positional_embedding = PositionalEncoding(self.config)
        self.Decoder = nn.TransformerEncoder(encoderlayer,self.config.n_layers)
        self.latent2dmodel = nn.Linear(self.config.latent_dim,self.config.d_model)
        self.fc = nn.Linear(self.config.d_model,self.config.n_vocab)

    def subsquent_mask(self,src):
        '''
        Additive causal mask.

        src : (seq_len, batch size) -- only its length is used
        returns : (seq_len, seq_len) float tensor, -1e9 above the diagonal
                  (future positions) and 0 elsewhere
        '''
        s=src.size(0)
        mask=torch.triu(torch.ones((s,s)),1)==1
        mask=mask.float().masked_fill(mask,-1e9)  # True (future) positions -> -1e9
        return mask

    def padding_mask(self,src):
        '''
        Key-padding mask: (seq_len, batch size) ids -> (batch size, seq_len)
        bool tensor, True where the id equals config.padding_idx.
        '''
        out=src.eq(self.config.padding_idx)
        return out.T

    def forward(self, encoder_output, tgt):
        '''
        encoder_output : (bs, latent_dim) latent z
        tgt            : (bs, tgt_len) token ids, [BOS] sentence [PAD]...
        returns        : (bs, tgt_len, n_vocab) logits

        tgt_len need not equal config.seq_len: z is broadcast over the actual
        target length.
        '''
        device=tgt.device
        tgt = tgt.T  # (tgt_len, bs)
        tgt_key_padding_mask = self.padding_mask(tgt).to(device)
        tgt_mask = self.subsquent_mask(tgt).to(device)
        # latent z -> (1, bs, d_model); broadcasts over every target position
        memory_z = self.latent2dmodel(encoder_output).unsqueeze(0)

        tgt_out = self.embedding(tgt) * math.sqrt(self.config.d_model)  # (tgt_len, bs, d_model)
        tgt_out = self.positional_embedding(tgt_out)
        tgt_output = self.Decoder(tgt_out, mask=tgt_mask, src_key_padding_mask=tgt_key_padding_mask)

        output = self.fc(tgt_output + memory_z)  # (tgt_len, bs, n_vocab)
        return output.transpose(0,1)
 
class VAE_Transformer(nn.Module):
    '''
    Transformer VAE: Transformer_Encoder_Base encodes a [CLS]-prefixed sentence
    into (z, mu, log_var); Transformer_Decoder_Base reconstructs the sentence
    from z and the shifted target.
    '''
    def __init__(self, config):
        super().__init__()
        self.config=config
        self.encoder = Transformer_Encoder_Base(config)
        self.decoder = Transformer_Decoder_Base(config)

    def reparametrization(self,mu,log_var):
        '''
        z = mu + exp(0.5 * log_var) * e with e ~ N(0, I); all (bs, latent_dim).
        forward() uses the encoder's own reparametrization; kept for API compatibility.
        '''
        sigma=torch.exp(0.5*log_var)
        e=torch.randn_like(sigma,device=sigma.device)
        z=mu+sigma*e
        return z

    def forward(self, src, tgt):
        '''
        src : (bs, seq_len) [CLS] sentence ...
        tgt : (bs, seq_len) [BOS] sentence ...
        returns (logits (bs, seq_len, n_vocab), mu, log_var)
        '''
        memory_z, mu, log_var = self.encoder.forward(src)
        output = self.decoder.forward(memory_z,tgt)
        return output, mu, log_var

    def generation(self, method, device, beam_width = 5, alpha = 0.5, beta = 1, top_k = 10, temperature = 1.0):
        '''
        Generate sequences from z ~ N(0, I) of shape (config.batch_size, latent_dim).

        method :
          'greedy'     -> (bs, seq_len) LongTensor, argmax token at every step
          'top_k'      -> (bs, seq_len) LongTensor, sampled among the `top_k`
                          most likely tokens after dividing logits by `temperature`
          'beamsearch' -> for each batch element, a list with one entry per step
                          holding the current `beam_width` beams as [ids, score]
        alpha, beta : beam search weights each token log-probability by
                      ((1 + step) / (1 + beta)) ** alpha
        Every method runs config.seq_len - 1 steps after [BOS]; [EOS] does not
        stop generation. The model is put in eval mode (dropout off) while
        generating and restored to its previous mode afterwards.
        Raises ValueError for an unknown method.
        '''
        if method not in ('greedy', 'top_k', 'beamsearch'):
            raise ValueError("unknown generation method: %r" % (method,))
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                z = torch.randn((self.config.batch_size, self.config.latent_dim), device=device)
                if method == 'greedy':
                    return self._generate_greedy(z, device)
                if method == 'top_k':
                    return self._generate_top_k(z, device, top_k, temperature)
                return self._generate_beam(z, device, beam_width, alpha, beta)
        finally:
            self.train(was_training)

    def _prefix_tgt(self, prefix, device):
        '''Right-pad a (rows, length) id prefix with config.padding_idx to (rows, config.seq_len).'''
        tgt = torch.full((prefix.size(0), self.config.seq_len), self.config.padding_idx, dtype=torch.long, device=device)
        tgt[:, :prefix.size(1)] = prefix
        return tgt

    def _generate_greedy(self, z, device):
        '''Greedy decoding: append the argmax token at each step.'''
        result = torch.full((z.size(0), 1), self.config.bos_idx, dtype=torch.long, device=device)
        for length in range(self.config.seq_len - 1):  # result ends with seq_len ids incl. [BOS]
            output = self.decoder(z, self._prefix_tgt(result, device))  # bs, seq_len, n_vocab
            next_token = output[:, length, :].argmax(-1, keepdim=True)  # bs, 1
            result = torch.cat([result, next_token], dim=1)
        return result

    def _generate_top_k(self, z, device, top_k, temperature):
        '''Top-k sampling with temperature; top_k None / <= 0 / >= n_vocab disables filtering.'''
        result = torch.full((z.size(0), 1), self.config.bos_idx, dtype=torch.long, device=device)
        for length in range(self.config.seq_len - 1):
            output = self.decoder(z, self._prefix_tgt(result, device))  # bs, seq_len, n_vocab
            logits = output[:, length, :] / temperature  # bs, n_vocab
            if top_k is not None and 0 < top_k < logits.size(-1):
                kth_best = torch.topk(logits, top_k, dim=-1).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float('-inf'))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)  # bs, 1
            result = torch.cat([result, next_token], dim=1)
        return result

    def _generate_beam(self, z, device, beam_width, alpha, beta):
        '''Beam search run separately for each batch element.'''
        Batch_sequences = []
        for b in range(z.size(0)):
            z_b = z[b].unsqueeze(0)
            sequences = [[[self.config.bos_idx], 0.0]]  # beams: [ids, score]
            batch_sequences = []
            for length in range(self.config.seq_len - 1):
                penalty = ((1 + length) / (1 + beta)) ** alpha
                all_candidates = []
                for seq, score in sequences:
                    prefix = torch.tensor([seq], dtype=torch.long, device=device)
                    output = self.decoder(z_b, self._prefix_tgt(prefix, device))  # 1, seq_len, n_vocab
                    # log_softmax avoids math.log(0) on probabilities that underflow.
                    log_probs = F.log_softmax(output[0, length, :], dim=-1)
                    n_tokens = log_probs.size(-1)
                    # With a positive weight the best `beam_width` extensions overall are
                    # among each beam's `beam_width` most likely tokens, so only those
                    # are expanded instead of all n_vocab candidates.
                    k = min(beam_width, n_tokens) if penalty > 0 else n_tokens
                    top_lp, top_ids = torch.topk(log_probs, k)
                    for lp, j in zip(top_lp.tolist(), top_ids.tolist()):
                        all_candidates.append([seq + [j], score + penalty * lp])
                sequences = sorted(all_candidates, key=lambda c: c[1], reverse=True)[:beam_width]
                batch_sequences.append(sequences)
            Batch_sequences.append(batch_sequences)
        return Batch_sequences

In [21]:
def train(which, ds, epochs, annealing_strategy, device, k):
    '''
    Train a VAE on the global `train_loader`, early-stopping on `val_loader`.

    which : model type; only 'transformer' (VAE_Transformer) is supported
    ds : dataset label, used only in the output file name
    epochs : maximum number of epochs
    annealing_strategy : 'linear', 'cycle', or a number used as a constant KL weight
    device : torch device (e.g. 'cuda' / 'cpu')
    k : 'linear' -> epochs until the KL weight reaches 1;
        'cycle'  -> epochs per annealing cycle (n_cycle = epochs // k, at least 1);
        also the restart period T_0 of the cosine LR scheduler

    Side effects: writes './epoch_<description>.csv' (per-epoch train/val
    losses and KL weight) and saves the state_dict with the lowest validation
    reconstruction error to './min_recon_model_epoch<N>'.
    Returns the per-epoch DataFrame. Raises ValueError for an unsupported `which`.
    '''
    if which != 'transformer':
        raise ValueError("unsupported model type: %r" % (which,))
    model = VAE_Transformer(config)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=k, T_mult=1, eta_min=1e-7, last_epoch=-1, verbose=False)
    description = 'dataset_%s__model_%s__anneal_%s(%s)'%(ds, which, annealing_strategy, k)
    patience = 4  # consecutive non-improving epochs tolerated before stopping

    step = 0
    CE_loss, KL_loss, kl_anneal = [], [], []   # per-epoch train averages / last KL weight
    VAL_CE_loss, VAL_KL_loss = [], []          # per-epoch validation averages
    check_count = 0
    check_model = None
    min_value_recon = None
    min_value_epoch = None
    for epoch in tqdm(range(1,1+epochs),desc='epoch',mininterval=60):
        # ------------------------- train -------------------------
        model.train()
        Total_cross_entropy = 0.
        Total_kl = 0.
        num = 0
        for data in train_loader:
            step += 1
            optimizer.zero_grad()
            # src : batch size, seq len
            batch_size,src,tgt_input,tgt_output = preprocessing(data, config, device,which)
            if annealing_strategy == 'linear':
                anneal = linear_anneal_function(step, len(train_loader)*k)
            elif annealing_strategy == 'cycle':
                # max(1, ...) avoids a zero cycle count when epochs < k.
                anneal = cycle_anneal_function(step, len(train_loader)*epochs, n_cycle = max(1, epochs//k), ratio = 0.5)
            else:
                anneal = annealing_strategy
            output,mu,log_var = model(src,tgt_input)
            CE,KL = ce_kl(output, tgt_output,mu,log_var,anneal,config)
            loss = CE + KL * anneal
            loss.backward()
            optimizer.step()
            Total_cross_entropy += CE.item()
            Total_kl += KL.item()
            num += batch_size
        CE_loss.append(Total_cross_entropy/num)
        KL_loss.append(Total_kl/num)
        kl_anneal.append(anneal)  # KL weight at the epoch's last step
        scheduler.step()

        # ----------------------- validation -----------------------
        val_ce, val_kl = calculate_loss(model, val_loader, config, device, which)
        VAL_CE_loss.append(val_ce)
        VAL_KL_loss.append(val_kl)

        # Keep the weights with the lowest validation reconstruction error and
        # stop after `patience` consecutive epochs without improvement.
        if min_value_recon is None or val_ce < min_value_recon:
            min_value_epoch = epoch
            min_value_recon = val_ce
            check_model = copy.deepcopy(model.state_dict())
            check_count = 0
        elif check_count == patience:
            print('early stop')
            break
        else:
            check_count += 1

        if epoch % 5 ==0:
            print('-------------------------TRAIN----------------------------')
            print("| Epoch : %d | |" % (epoch)) 
            print("| Reconstruction Error : %.3f | KL divergence : %.3f | KL annealing : %.3f |"%(Total_cross_entropy/num, Total_kl/num, anneal))
            print('----------------------------------------------------------')    
            print('--------------------------VAL-----------------------------')
            print("| Reconstruction Error : %.3f | KL divergence : %.3f |"%(val_ce, val_kl))
            print('----------------------------------------------------------')    
    result = df({
        'train_recon': CE_loss,
        'train_kl': KL_loss,
        'train_kl_anneal': kl_anneal,
        'val_recon': VAL_CE_loss,
        'val_kl': VAL_KL_loss,
    })
    result.to_csv('./epoch_%s.csv'%(description))
    if check_model is not None:  # None only when no epoch ran
        torch.save(check_model, './min_recon_model_epoch%d'%min_value_epoch)
    return result

# VAE Train

In [22]:
# Work inside ./data_augmentation. Skip the chdir when already there so that
# re-running this cell does not fail looking for a nested directory.
if os.path.basename(os.path.normpath(os.getcwd())) != 'data_augmentation':
    os.chdir('./data_augmentation')

In [23]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
k = 10
# (model type, dataset label used in file names, epochs, KL annealing strategy, k)
experiment = [('transformer','ptb',100,'cycle',k)]
torch.cuda.empty_cache()
for ex in experiment:
    # One output directory per experiment; reuse it if it already exists
    # instead of swallowing every mkdir error with a bare except.
    run_dir = './lr_1e-3_%s'%str(ex)
    os.makedirs(run_dir, exist_ok=True)
    os.chdir(run_dir)
    which, ds, epochs, annealing_strategy, k = ex
    now = time.time()
    try:
        train(which, ds, epochs, annealing_strategy, device, k)
    finally:
        # Always return to the parent directory, even if training fails, so
        # the next experiment (or a re-run of this cell) starts from the same place.
        torch.cuda.empty_cache()
        os.chdir('..')
    print(time.time()-now)

epoch:   3%|▎         | 3/100 [01:40<43:50, 27.12s/it]

-------------------------TRAIN----------------------------
| Epoch : 5 | |
| Reconstruction Error : 4361.824 | KL divergence : 28.997 | KL annealing : 0.998 |
----------------------------------------------------------
--------------------------VAL-----------------------------
| Reconstruction Error : 4811.388 | KL divergence : 29.311 |
----------------------------------------------------------


epoch:   9%|▉         | 9/100 [04:20<41:05, 27.09s/it]

-------------------------TRAIN----------------------------
| Epoch : 10 | |
| Reconstruction Error : 3950.135 | KL divergence : 26.501 | KL annealing : 1.000 |
----------------------------------------------------------
--------------------------VAL-----------------------------
| Reconstruction Error : 4623.431 | KL divergence : 28.666 |
----------------------------------------------------------


epoch:  15%|█▌        | 15/100 [06:46<38:20, 27.07s/it]

-------------------------TRAIN----------------------------
| Epoch : 15 | |
| Reconstruction Error : 3576.857 | KL divergence : 27.931 | KL annealing : 0.998 |
----------------------------------------------------------
--------------------------VAL-----------------------------
| Reconstruction Error : 4425.353 | KL divergence : 29.028 |
----------------------------------------------------------


epoch:  18%|█▊        | 18/100 [08:20<36:59, 27.06s/it]

-------------------------TRAIN----------------------------
| Epoch : 20 | |
| Reconstruction Error : 3348.786 | KL divergence : 25.570 | KL annealing : 1.000 |
----------------------------------------------------------
--------------------------VAL-----------------------------
| Reconstruction Error : 4380.991 | KL divergence : 28.043 |
----------------------------------------------------------


epoch:  21%|██        | 21/100 [09:40<35:37, 27.06s/it]

early stop
656.0571076869965
