In [11]:
import torch
import torch.nn as nn
import numpy as np
import time
import math
from torch.utils.data import DataLoader
from matplotlib import pyplot

In [20]:
torch.manual_seed(0)
np.random.seed(0)

# S is the source sequence length
# T is the target sequence length
# N is the batch size
# E is the feature number

#src = torch.rand((10, 32, 512)) # (S,N,E) 
#tgt = torch.rand((20, 32, 512)) # (T,N,E)
#out = transformer_model(src, tgt)

input_window = 100 # number of input steps
output_window = 1 # number of prediction steps, in this model its fixed to one
batch_size = 64 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Надо посмотреть размерности в forward

In [21]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()       
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        #pe.requires_grad = False
        self.register_buffer('pe', pe)

    def forward(self, x):
#         print(x.size())
#         print(x)
#         print((self.pe[:x.size(0), :]).size())
#         print((self.pe[:x.size(0), :]))
        
#         print(x + self.pe[:x.size(0), :])
        return x + self.pe[:x.size(0), :]
          

class TransAm(nn.Module):
    def __init__(self,feature_size=250,num_layers=1,dropout=0.1):
        super(TransAm, self).__init__()
        self.model_type = 'Transformer'
        
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(feature_size)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)        
        self.decoder = nn.Linear(feature_size,1)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1    
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self,src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
        
        #print(src.size())
        src = self.pos_encoder(src)
        #print(src.size())

        output = self.transformer_encoder(src,self.src_mask)#, self.src_mask)
        #print(output.size())
        
        output = self.decoder(output)
        #print(output.size())
        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

In [22]:
# if window is 100 and prediction step is 1
# in -> [0..99]
# target -> [1..100]
def create_inout_sequences(input_data, tw):
    inout_seq = []
    L = len(input_data)
    for i in range(L-tw):
        train_seq = input_data[i:i+tw]
        train_label = input_data[i+output_window:i+tw+output_window]
        inout_seq.append((train_seq ,train_label))
    return torch.FloatTensor(inout_seq)

def get_data():
    # construct a littel toy dataset
    time        = np.arange(0, 400, 0.1)    
    amplitude   = np.sin(time) + np.sin(time*0.05) +np.sin(time*0.12) *np.random.normal(-0.2, 0.2, len(time))

    
    from sklearn.preprocessing import MinMaxScaler
    
    #loading weather data from a file
    #from pandas import read_csv
    #series = read_csv('daily-min-temperatures.csv', header=0, index_col=0, parse_dates=True, squeeze=True)
    
    # looks like normalizing input values curtial for the model
    scaler = MinMaxScaler(feature_range=(-1, 1)) 
    #amplitude = scaler.fit_transform(series.to_numpy().reshape(-1, 1)).reshape(-1)
    amplitude = scaler.fit_transform(amplitude.reshape(-1, 1)).reshape(-1)
    
    
    sampels = 2600
    train_data = amplitude[:sampels]
    test_data = amplitude[sampels:]

    # convert our train data into a pytorch train tensor
    #train_tensor = torch.FloatTensor(train_data).view(-1)
    # todo: add comment.. 
    train_sequence = create_inout_sequences(train_data,input_window)
    train_sequence = train_sequence[:-output_window] #todo: fix hack? -> din't think this through, looks like the last n sequences are to short, so I just remove them. Hackety Hack.. 

    #test_data = torch.FloatTensor(test_data).view(-1) 
    test_data = create_inout_sequences(test_data,input_window)
    test_data = test_data[:-output_window] #todo: fix hack?

    return train_sequence.to(device),test_data.to(device)

def get_batch(source, i,batch_size):
    seq_len = min(batch_size, len(source) - 1 - i)
    data = source[i:i+seq_len]    
    input = torch.stack(torch.stack([item[0] for item in data]).chunk(input_window,1)) # 1 is feature size
    target = torch.stack(torch.stack([item[1] for item in data]).chunk(input_window,1))
    return input, target

In [43]:
def train(train_data):
    model.train() # Turn on the train mode \o/
    total_loss = 0.
    start_time = time.time()

    for batch, i in enumerate(range(0, len(train_data) - 1, batch_size)):
        data, targets = get_batch(train_data, i,batch_size)
        #print(data.size())
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.7)
        optimizer.step()

        total_loss += loss.item()
        log_interval = int(len(train_data) / batch_size / 5)
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.6f} | {:5.2f} ms | '
                  'loss {:5.5f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // batch_size, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

In [40]:
def plot_and_loss(eval_model, data_source,epoch):
    eval_model.eval() 
    total_loss = 0.
    test_result = torch.Tensor(0)    
    truth = torch.Tensor(0)
    with torch.no_grad():
        for i in range(0, len(data_source) - 1):
            data, target = get_batch(data_source, i,1)
            output = eval_model(data)            
            total_loss += criterion(output, target).item()
            test_result = torch.cat((test_result, output[-1].view(-1).cpu()), 0)
            truth = torch.cat((truth, target[-1].view(-1).cpu()), 0)
            
    #test_result = test_result.cpu().numpy() -> no need to detach stuff.. 
    len(test_result)

    pyplot.plot(test_result,color="red")
    pyplot.plot(truth[:500],color="blue")
    pyplot.plot(test_result-truth,color="green")
    pyplot.grid(True, which='both')
    pyplot.axhline(y=0, color='k')
    pyplot.savefig('graph/transformer-epoch%d.png'%epoch)
    pyplot.close()
    
    return total_loss / i

In [41]:
def predict_future(eval_model, data_source,steps):
    eval_model.eval() 
    total_loss = 0.
    test_result = torch.Tensor(0)    
    truth = torch.Tensor(0)
    data, _ = get_batch(data_source, 0,1)
    with torch.no_grad():
        for i in range(0, steps):            
            output = eval_model(data[-input_window:])                        
            data = torch.cat((data, output[-1:]))
            
    data = data.cpu().view(-1)
    
    # I used this plot to visualize if the model pics up any long therm struccture within the data. 
    pyplot.plot(data,color="red")       
    pyplot.plot(data[:input_window],color="blue")    
    pyplot.grid(True, which='both')
    pyplot.axhline(y=0, color='k')
    pyplot.savefig('graph/transformer-future%d.png'%steps)
    pyplot.close()
        

def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    eval_batch_size = 1000
    with torch.no_grad():
        for i in range(0, len(data_source) - 1, eval_batch_size):
            data, targets = get_batch(data_source, i,eval_batch_size)
            output = eval_model(data)            
            total_loss += len(data[0])* criterion(output, targets).cpu().item()
    return total_loss / len(data_source)

In [8]:
data = np.loadtxt("./lienard_intermittency.dat")

train = data[:, 1][:45000]
test =  data[:, 1][45000:]

mean = train.mean()
std = train.std()
#sigmastest = np.abs(test - mean)
train_norm = (train - mean)/std
test_norm = (test - mean)/std

#train_norm = torch.FloatTensor(train_norm).view(-1)
#test_norm = torch.FloatTensor(test_norm).view(-1)

In [9]:
train_window = 20

def create_inout_sequences(input_data, tw):
    inout_seq = []
    L = len(input_data)
    for i in range(L-tw):
        train_seq = input_data[i:i+tw]
        train_label = input_data[i+1:i+tw+1]
        inout_seq.append((train_seq ,train_label))
    return inout_seq

train_inout_seq = create_inout_sequences(train_norm, train_window)
test_inout_seq = create_inout_sequences(test_norm, train_window)

In [13]:
train_dataloader = DataLoader(train_inout_seq, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_inout_seq, batch_size=batch_size, shuffle=True, drop_last=True)

In [37]:
train_window = 20

def create_inout_sequences(input_data, tw):
    inout_seq = []
    L = len(input_data)
    for i in range(L-tw):
        train_seq = input_data[i:i+tw]
        train_label = input_data[i+1:i+tw+1]
        inout_seq.append((train_seq ,train_label))
    return torch.FloatTensor(inout_seq)

train_data = create_inout_sequences(train_norm, train_window).to(device)
val_data = create_inout_sequences(test_norm, train_window).to(device)

  return torch.FloatTensor(inout_seq)


In [22]:
a = torch.zeros((20, 10, 1))
b = torch.ones((20, 1, 250))
(a + b).size()

torch.Size([20, 10, 250])

In [44]:
#train_data, val_data = get_data()
model = TransAm().to(device)

criterion = nn.MSELoss()
lr = 0.005 
#optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

best_val_loss = float("inf")
epochs = 100 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_data)
    
    if(epoch % 10 is 0):
        val_loss = plot_and_loss(model, val_data,epoch)
        predict_future(model, val_data,200)
    else:
        val_loss = evaluate(model, val_data)
   
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.5f} | valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    #if val_loss < best_val_loss:
    #    best_val_loss = val_loss
    #    best_model = model

    scheduler.step() 

  if(epoch % 10 is 0):


| epoch   1 |   140/  702 batches | lr 0.005000 |  7.80 ms | loss 3.52265 | ppl    33.87
| epoch   1 |   280/  702 batches | lr 0.005000 |  6.62 ms | loss 0.58351 | ppl     1.79
| epoch   1 |   420/  702 batches | lr 0.005000 |  6.15 ms | loss 0.58048 | ppl     1.79
| epoch   1 |   560/  702 batches | lr 0.005000 |  6.59 ms | loss 0.73506 | ppl     2.09
| epoch   1 |   700/  702 batches | lr 0.005000 |  6.35 ms | loss 0.56849 | ppl     1.77
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  4.96s | valid loss 0.70474 | valid ppl     2.02
-----------------------------------------------------------------------------------------
| epoch   2 |   140/  702 batches | lr 0.004513 |  6.18 ms | loss 0.71913 | ppl     2.05
| epoch   2 |   280/  702 batches | lr 0.004513 |  6.06 ms | loss 0.50016 | ppl     1.65
| epoch   2 |   420/  702 batches | lr 0.004513 |  6.06 ms | loss 0.53332 | ppl     1.70
| epoch   2 |   560/  702 batche



| epoch  11 |   140/  702 batches | lr 0.002844 |  6.45 ms | loss 0.48524 | ppl     1.62
| epoch  11 |   280/  702 batches | lr 0.002844 |  6.18 ms | loss 0.33598 | ppl     1.40
| epoch  11 |   420/  702 batches | lr 0.002844 |  6.14 ms | loss 0.33182 | ppl     1.39
| epoch  11 |   560/  702 batches | lr 0.002844 |  6.15 ms | loss 0.40270 | ppl     1.50
| epoch  11 |   700/  702 batches | lr 0.002844 |  6.10 ms | loss 0.39848 | ppl     1.49
-----------------------------------------------------------------------------------------
| end of epoch  11 | time:  4.49s | valid loss 0.43808 | valid ppl     1.55
-----------------------------------------------------------------------------------------
| epoch  12 |   140/  702 batches | lr 0.002702 |  6.16 ms | loss 0.44953 | ppl     1.57
| epoch  12 |   280/  702 batches | lr 0.002702 |  6.10 ms | loss 0.30206 | ppl     1.35
| epoch  12 |   420/  702 batches | lr 0.002702 |  6.10 ms | loss 0.29288 | ppl     1.34
| epoch  12 |   560/  702 batche



| epoch  21 |   140/  702 batches | lr 0.001703 |  6.64 ms | loss 0.41749 | ppl     1.52
| epoch  21 |   280/  702 batches | lr 0.001703 |  6.19 ms | loss 0.26264 | ppl     1.30
| epoch  21 |   420/  702 batches | lr 0.001703 |  6.15 ms | loss 0.26568 | ppl     1.30
| epoch  21 |   560/  702 batches | lr 0.001703 |  6.10 ms | loss 0.33965 | ppl     1.40
| epoch  21 |   700/  702 batches | lr 0.001703 |  6.10 ms | loss 0.30048 | ppl     1.35
-----------------------------------------------------------------------------------------
| end of epoch  21 | time:  4.51s | valid loss 0.37159 | valid ppl     1.45
-----------------------------------------------------------------------------------------
| epoch  22 |   140/  702 batches | lr 0.001618 |  6.27 ms | loss 0.38861 | ppl     1.47
| epoch  22 |   280/  702 batches | lr 0.001618 |  6.67 ms | loss 0.25617 | ppl     1.29
| epoch  22 |   420/  702 batches | lr 0.001618 |  6.10 ms | loss 0.24117 | ppl     1.27
| epoch  22 |   560/  702 batche



| epoch  31 |   140/  702 batches | lr 0.001020 |  6.44 ms | loss 0.34911 | ppl     1.42
| epoch  31 |   280/  702 batches | lr 0.001020 |  6.07 ms | loss 0.26047 | ppl     1.30
| epoch  31 |   420/  702 batches | lr 0.001020 |  6.11 ms | loss 0.25749 | ppl     1.29
| epoch  31 |   560/  702 batches | lr 0.001020 |  6.05 ms | loss 0.34484 | ppl     1.41
| epoch  31 |   700/  702 batches | lr 0.001020 |  6.11 ms | loss 0.32774 | ppl     1.39
-----------------------------------------------------------------------------------------
| end of epoch  31 | time:  4.46s | valid loss 0.34133 | valid ppl     1.41
-----------------------------------------------------------------------------------------
| epoch  32 |   140/  702 batches | lr 0.000969 |  7.08 ms | loss 0.39999 | ppl     1.49
| epoch  32 |   280/  702 batches | lr 0.000969 |  6.66 ms | loss 0.24822 | ppl     1.28
| epoch  32 |   420/  702 batches | lr 0.000969 |  6.31 ms | loss 0.24873 | ppl     1.28
| epoch  32 |   560/  702 batche



| epoch  41 |   140/  702 batches | lr 0.000610 |  6.48 ms | loss 0.35635 | ppl     1.43
| epoch  41 |   280/  702 batches | lr 0.000610 |  6.14 ms | loss 0.21639 | ppl     1.24
| epoch  41 |   420/  702 batches | lr 0.000610 |  6.27 ms | loss 0.20589 | ppl     1.23
| epoch  41 |   560/  702 batches | lr 0.000610 |  6.16 ms | loss 0.28478 | ppl     1.33
| epoch  41 |   700/  702 batches | lr 0.000610 |  6.09 ms | loss 0.24688 | ppl     1.28
-----------------------------------------------------------------------------------------
| end of epoch  41 | time:  4.51s | valid loss 0.26347 | valid ppl     1.30
-----------------------------------------------------------------------------------------
| epoch  42 |   140/  702 batches | lr 0.000580 |  6.11 ms | loss 0.30436 | ppl     1.36
| epoch  42 |   280/  702 batches | lr 0.000580 |  6.11 ms | loss 0.20290 | ppl     1.22
| epoch  42 |   420/  702 batches | lr 0.000580 |  6.05 ms | loss 0.20646 | ppl     1.23
| epoch  42 |   560/  702 batche



| epoch  51 |   140/  702 batches | lr 0.000365 |  6.72 ms | loss 0.31987 | ppl     1.38
| epoch  51 |   280/  702 batches | lr 0.000365 |  6.33 ms | loss 0.20104 | ppl     1.22
| epoch  51 |   420/  702 batches | lr 0.000365 |  6.14 ms | loss 0.18547 | ppl     1.20
| epoch  51 |   560/  702 batches | lr 0.000365 |  6.28 ms | loss 0.24405 | ppl     1.28
| epoch  51 |   700/  702 batches | lr 0.000365 |  6.15 ms | loss 0.23526 | ppl     1.27
-----------------------------------------------------------------------------------------
| end of epoch  51 | time:  4.58s | valid loss 0.28361 | valid ppl     1.33
-----------------------------------------------------------------------------------------
| epoch  52 |   140/  702 batches | lr 0.000347 |  6.20 ms | loss 0.29270 | ppl     1.34
| epoch  52 |   280/  702 batches | lr 0.000347 |  6.06 ms | loss 0.18187 | ppl     1.20
| epoch  52 |   420/  702 batches | lr 0.000347 |  6.10 ms | loss 0.17580 | ppl     1.19
| epoch  52 |   560/  702 batche



| epoch  61 |   140/  702 batches | lr 0.000219 |  6.50 ms | loss 0.27898 | ppl     1.32
| epoch  61 |   280/  702 batches | lr 0.000219 |  6.10 ms | loss 0.17307 | ppl     1.19
| epoch  61 |   420/  702 batches | lr 0.000219 |  6.06 ms | loss 0.17209 | ppl     1.19
| epoch  61 |   560/  702 batches | lr 0.000219 |  6.80 ms | loss 0.24353 | ppl     1.28
| epoch  61 |   700/  702 batches | lr 0.000219 |  6.23 ms | loss 0.22136 | ppl     1.25
-----------------------------------------------------------------------------------------
| end of epoch  61 | time:  4.59s | valid loss 0.22872 | valid ppl     1.26
-----------------------------------------------------------------------------------------
| epoch  62 |   140/  702 batches | lr 0.000208 |  6.61 ms | loss 0.27555 | ppl     1.32
| epoch  62 |   280/  702 batches | lr 0.000208 |  6.57 ms | loss 0.17730 | ppl     1.19
| epoch  62 |   420/  702 batches | lr 0.000208 |  6.49 ms | loss 0.17105 | ppl     1.19
| epoch  62 |   560/  702 batche



| epoch  71 |   140/  702 batches | lr 0.000131 | 17.75 ms | loss 0.27103 | ppl     1.31
| epoch  71 |   280/  702 batches | lr 0.000131 | 20.20 ms | loss 0.16643 | ppl     1.18
| epoch  71 |   420/  702 batches | lr 0.000131 |  9.93 ms | loss 0.16668 | ppl     1.18
| epoch  71 |   560/  702 batches | lr 0.000131 |  7.97 ms | loss 0.23945 | ppl     1.27
| epoch  71 |   700/  702 batches | lr 0.000131 |  8.24 ms | loss 0.22138 | ppl     1.25
-----------------------------------------------------------------------------------------
| end of epoch  71 | time:  9.16s | valid loss 0.22658 | valid ppl     1.25
-----------------------------------------------------------------------------------------
| epoch  72 |   140/  702 batches | lr 0.000124 |  8.04 ms | loss 0.27638 | ppl     1.32
| epoch  72 |   280/  702 batches | lr 0.000124 |  7.92 ms | loss 0.17061 | ppl     1.19
| epoch  72 |   420/  702 batches | lr 0.000124 |  8.75 ms | loss 0.16461 | ppl     1.18
| epoch  72 |   560/  702 batche



| epoch  81 |   140/  702 batches | lr 0.000078 | 15.97 ms | loss 0.27526 | ppl     1.32
| epoch  81 |   280/  702 batches | lr 0.000078 | 15.25 ms | loss 0.17191 | ppl     1.19
| epoch  81 |   420/  702 batches | lr 0.000078 | 15.80 ms | loss 0.16762 | ppl     1.18
| epoch  81 |   560/  702 batches | lr 0.000078 | 15.59 ms | loss 0.23939 | ppl     1.27
| epoch  81 |   700/  702 batches | lr 0.000078 | 15.51 ms | loss 0.22068 | ppl     1.25
-----------------------------------------------------------------------------------------
| end of epoch  81 | time: 11.26s | valid loss 0.22521 | valid ppl     1.25
-----------------------------------------------------------------------------------------
| epoch  82 |   140/  702 batches | lr 0.000075 | 15.69 ms | loss 0.27125 | ppl     1.31
| epoch  82 |   280/  702 batches | lr 0.000075 | 16.07 ms | loss 0.16500 | ppl     1.18
| epoch  82 |   420/  702 batches | lr 0.000075 | 15.96 ms | loss 0.16202 | ppl     1.18
| epoch  82 |   560/  702 batche



| epoch  91 |   140/  702 batches | lr 0.000047 | 15.58 ms | loss 0.26191 | ppl     1.30
| epoch  91 |   280/  702 batches | lr 0.000047 | 14.98 ms | loss 0.16369 | ppl     1.18
| epoch  91 |   420/  702 batches | lr 0.000047 | 15.86 ms | loss 0.16032 | ppl     1.17
| epoch  91 |   560/  702 batches | lr 0.000047 | 15.57 ms | loss 0.22567 | ppl     1.25
| epoch  91 |   700/  702 batches | lr 0.000047 | 15.12 ms | loss 0.21205 | ppl     1.24
-----------------------------------------------------------------------------------------
| end of epoch  91 | time: 11.12s | valid loss 0.21811 | valid ppl     1.24
-----------------------------------------------------------------------------------------
| epoch  92 |   140/  702 batches | lr 0.000045 | 15.48 ms | loss 0.25981 | ppl     1.30
| epoch  92 |   280/  702 batches | lr 0.000045 | 15.59 ms | loss 0.16307 | ppl     1.18
| epoch  92 |   420/  702 batches | lr 0.000045 | 15.23 ms | loss 0.15904 | ppl     1.17
| epoch  92 |   560/  702 batche

In [21]:
#train_data, val_data = get_data()
model = TransAm().to(device)

criterion = nn.MSELoss()
lr = 0.005 
#optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

best_val_loss = float("inf")
epochs = 100 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_data)
    
    if(epoch % 10 is 0):
        val_loss = plot_and_loss(model, val_data,epoch)
        predict_future(model, val_data,200)
    else:
        val_loss = evaluate(model, val_data)
   
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.5f} | valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    #if val_loss < best_val_loss:
    #    best_val_loss = val_loss
    #    best_model = model

    scheduler.step() 

  if(epoch % 10 is 0):


| epoch   1 |   899/ 4498 batches | lr 0.005000 |  4.31 ms | loss 1.38168 | ppl     3.98
| epoch   1 |  1798/ 4498 batches | lr 0.005000 |  4.08 ms | loss 0.59456 | ppl     1.81
| epoch   1 |  2697/ 4498 batches | lr 0.005000 |  4.06 ms | loss 0.60184 | ppl     1.83
| epoch   1 |  3596/ 4498 batches | lr 0.005000 |  4.06 ms | loss 0.68608 | ppl     1.99
| epoch   1 |  4495/ 4498 batches | lr 0.005000 |  4.08 ms | loss 0.59812 | ppl     1.82
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 18.63s | valid loss 0.75296 | valid ppl     2.12
-----------------------------------------------------------------------------------------
| epoch   2 |   899/ 4498 batches | lr 0.004513 |  4.09 ms | loss 0.69621 | ppl     2.01
| epoch   2 |  1798/ 4498 batches | lr 0.004513 |  4.08 ms | loss 0.53782 | ppl     1.71
| epoch   2 |  2697/ 4498 batches | lr 0.004513 |  4.00 ms | loss 0.55057 | ppl     1.73
| epoch   2 |  3596/ 4498 batche



| epoch  11 |   899/ 4498 batches | lr 0.002844 |  4.17 ms | loss 0.61945 | ppl     1.86
| epoch  11 |  1798/ 4498 batches | lr 0.002844 |  4.15 ms | loss 0.39745 | ppl     1.49
| epoch  11 |  2697/ 4498 batches | lr 0.002844 |  4.09 ms | loss 0.42578 | ppl     1.53
| epoch  11 |  3596/ 4498 batches | lr 0.002844 |  4.10 ms | loss 0.53863 | ppl     1.71
| epoch  11 |  4495/ 4498 batches | lr 0.002844 |  4.12 ms | loss 0.47559 | ppl     1.61
-----------------------------------------------------------------------------------------
| end of epoch  11 | time: 18.69s | valid loss 0.54011 | valid ppl     1.72
-----------------------------------------------------------------------------------------
| epoch  12 |   899/ 4498 batches | lr 0.002702 |  4.17 ms | loss 0.58466 | ppl     1.79
| epoch  12 |  1798/ 4498 batches | lr 0.002702 |  4.10 ms | loss 0.39209 | ppl     1.48
| epoch  12 |  2697/ 4498 batches | lr 0.002702 |  4.12 ms | loss 0.36910 | ppl     1.45
| epoch  12 |  3596/ 4498 batche



| epoch  21 |   899/ 4498 batches | lr 0.001703 |  4.18 ms | loss 0.49665 | ppl     1.64
| epoch  21 |  1798/ 4498 batches | lr 0.001703 |  4.15 ms | loss 0.31371 | ppl     1.37
| epoch  21 |  2697/ 4498 batches | lr 0.001703 |  4.14 ms | loss 0.31329 | ppl     1.37
| epoch  21 |  3596/ 4498 batches | lr 0.001703 |  4.14 ms | loss 0.40982 | ppl     1.51
| epoch  21 |  4495/ 4498 batches | lr 0.001703 |  4.10 ms | loss 0.42362 | ppl     1.53
-----------------------------------------------------------------------------------------
| end of epoch  21 | time: 18.74s | valid loss 0.45894 | valid ppl     1.58
-----------------------------------------------------------------------------------------
| epoch  22 |   899/ 4498 batches | lr 0.001618 |  4.09 ms | loss 0.47633 | ppl     1.61
| epoch  22 |  1798/ 4498 batches | lr 0.001618 |  4.06 ms | loss 0.34483 | ppl     1.41
| epoch  22 |  2697/ 4498 batches | lr 0.001618 |  4.08 ms | loss 0.30845 | ppl     1.36
| epoch  22 |  3596/ 4498 batche



| epoch  31 |   899/ 4498 batches | lr 0.001020 |  4.16 ms | loss 0.38189 | ppl     1.47
| epoch  31 |  1798/ 4498 batches | lr 0.001020 |  4.23 ms | loss 0.27253 | ppl     1.31
| epoch  31 |  2697/ 4498 batches | lr 0.001020 |  4.29 ms | loss 0.28673 | ppl     1.33
| epoch  31 |  3596/ 4498 batches | lr 0.001020 |  4.16 ms | loss 0.38147 | ppl     1.46
| epoch  31 |  4495/ 4498 batches | lr 0.001020 |  4.15 ms | loss 0.34100 | ppl     1.41
-----------------------------------------------------------------------------------------
| end of epoch  31 | time: 18.98s | valid loss 0.48520 | valid ppl     1.62
-----------------------------------------------------------------------------------------
| epoch  32 |   899/ 4498 batches | lr 0.000969 |  4.16 ms | loss 0.39921 | ppl     1.49
| epoch  32 |  1798/ 4498 batches | lr 0.000969 |  4.16 ms | loss 0.27472 | ppl     1.32
| epoch  32 |  2697/ 4498 batches | lr 0.000969 |  4.26 ms | loss 0.28007 | ppl     1.32
| epoch  32 |  3596/ 4498 batche



| epoch  41 |   899/ 4498 batches | lr 0.000610 |  4.03 ms | loss 0.38893 | ppl     1.48
| epoch  41 |  1798/ 4498 batches | lr 0.000610 |  4.01 ms | loss 0.22889 | ppl     1.26
| epoch  41 |  2697/ 4498 batches | lr 0.000610 |  4.01 ms | loss 0.24754 | ppl     1.28
| epoch  41 |  3596/ 4498 batches | lr 0.000610 |  4.00 ms | loss 0.34759 | ppl     1.42
| epoch  41 |  4495/ 4498 batches | lr 0.000610 |  3.99 ms | loss 0.31793 | ppl     1.37
-----------------------------------------------------------------------------------------
| end of epoch  41 | time: 18.13s | valid loss 0.40940 | valid ppl     1.51
-----------------------------------------------------------------------------------------
| epoch  42 |   899/ 4498 batches | lr 0.000580 |  3.99 ms | loss 0.34645 | ppl     1.41
| epoch  42 |  1798/ 4498 batches | lr 0.000580 |  3.99 ms | loss 0.24980 | ppl     1.28
| epoch  42 |  2697/ 4498 batches | lr 0.000580 |  3.99 ms | loss 0.26208 | ppl     1.30
| epoch  42 |  3596/ 4498 batche



| epoch  51 |   899/ 4498 batches | lr 0.000365 |  4.07 ms | loss 0.37179 | ppl     1.45
| epoch  51 |  1798/ 4498 batches | lr 0.000365 |  3.98 ms | loss 0.23075 | ppl     1.26
| epoch  51 |  2697/ 4498 batches | lr 0.000365 |  3.99 ms | loss 0.24058 | ppl     1.27
| epoch  51 |  3596/ 4498 batches | lr 0.000365 |  3.99 ms | loss 0.32787 | ppl     1.39
| epoch  51 |  4495/ 4498 batches | lr 0.000365 |  3.98 ms | loss 0.28626 | ppl     1.33
-----------------------------------------------------------------------------------------
| end of epoch  51 | time: 18.11s | valid loss 0.39210 | valid ppl     1.48
-----------------------------------------------------------------------------------------
| epoch  52 |   899/ 4498 batches | lr 0.000347 |  4.00 ms | loss 0.36014 | ppl     1.43
| epoch  52 |  1798/ 4498 batches | lr 0.000347 |  4.20 ms | loss 0.20877 | ppl     1.23
| epoch  52 |  2697/ 4498 batches | lr 0.000347 |  4.13 ms | loss 0.20924 | ppl     1.23
| epoch  52 |  3596/ 4498 batche



| epoch  61 |   899/ 4498 batches | lr 0.000219 |  4.18 ms | loss 0.31576 | ppl     1.37
| epoch  61 |  1798/ 4498 batches | lr 0.000219 |  4.16 ms | loss 0.19089 | ppl     1.21
| epoch  61 |  2697/ 4498 batches | lr 0.000219 |  4.12 ms | loss 0.18596 | ppl     1.20
| epoch  61 |  3596/ 4498 batches | lr 0.000219 |  4.12 ms | loss 0.27302 | ppl     1.31
| epoch  61 |  4495/ 4498 batches | lr 0.000219 |  4.11 ms | loss 0.24271 | ppl     1.27
-----------------------------------------------------------------------------------------
| end of epoch  61 | time: 18.71s | valid loss 0.36514 | valid ppl     1.44
-----------------------------------------------------------------------------------------
| epoch  62 |   899/ 4498 batches | lr 0.000208 |  4.11 ms | loss 0.31383 | ppl     1.37
| epoch  62 |  1798/ 4498 batches | lr 0.000208 |  4.10 ms | loss 0.19473 | ppl     1.21
| epoch  62 |  2697/ 4498 batches | lr 0.000208 |  4.10 ms | loss 0.18523 | ppl     1.20
| epoch  62 |  3596/ 4498 batche



| epoch  71 |   899/ 4498 batches | lr 0.000131 |  4.13 ms | loss 0.28516 | ppl     1.33
| epoch  71 |  1798/ 4498 batches | lr 0.000131 |  4.11 ms | loss 0.17877 | ppl     1.20
| epoch  71 |  2697/ 4498 batches | lr 0.000131 |  4.11 ms | loss 0.16220 | ppl     1.18
| epoch  71 |  3596/ 4498 batches | lr 0.000131 |  4.11 ms | loss 0.24344 | ppl     1.28
| epoch  71 |  4495/ 4498 batches | lr 0.000131 |  4.11 ms | loss 0.21505 | ppl     1.24
-----------------------------------------------------------------------------------------
| end of epoch  71 | time: 18.63s | valid loss 0.30828 | valid ppl     1.36
-----------------------------------------------------------------------------------------
| epoch  72 |   899/ 4498 batches | lr 0.000124 |  4.11 ms | loss 0.27719 | ppl     1.32
| epoch  72 |  1798/ 4498 batches | lr 0.000124 |  4.08 ms | loss 0.16980 | ppl     1.19
| epoch  72 |  2697/ 4498 batches | lr 0.000124 |  4.07 ms | loss 0.15812 | ppl     1.17
| epoch  72 |  3596/ 4498 batche



| epoch  81 |   899/ 4498 batches | lr 0.000078 |  4.10 ms | loss 0.25158 | ppl     1.29
| epoch  81 |  1798/ 4498 batches | lr 0.000078 |  4.06 ms | loss 0.15445 | ppl     1.17
| epoch  81 |  2697/ 4498 batches | lr 0.000078 |  4.08 ms | loss 0.15021 | ppl     1.16
| epoch  81 |  3596/ 4498 batches | lr 0.000078 |  4.08 ms | loss 0.22638 | ppl     1.25
| epoch  81 |  4495/ 4498 batches | lr 0.000078 |  4.07 ms | loss 0.20324 | ppl     1.23
-----------------------------------------------------------------------------------------
| end of epoch  81 | time: 18.45s | valid loss 0.29583 | valid ppl     1.34
-----------------------------------------------------------------------------------------
| epoch  82 |   899/ 4498 batches | lr 0.000075 |  4.07 ms | loss 0.24997 | ppl     1.28
| epoch  82 |  1798/ 4498 batches | lr 0.000075 |  4.08 ms | loss 0.15459 | ppl     1.17
| epoch  82 |  2697/ 4498 batches | lr 0.000075 |  4.07 ms | loss 0.15035 | ppl     1.16
| epoch  82 |  3596/ 4498 batche



| epoch  91 |   899/ 4498 batches | lr 0.000047 |  4.12 ms | loss 0.25246 | ppl     1.29
| epoch  91 |  1798/ 4498 batches | lr 0.000047 |  4.12 ms | loss 0.14774 | ppl     1.16
| epoch  91 |  2697/ 4498 batches | lr 0.000047 |  4.12 ms | loss 0.14376 | ppl     1.15
| epoch  91 |  3596/ 4498 batches | lr 0.000047 |  4.11 ms | loss 0.21821 | ppl     1.24
| epoch  91 |  4495/ 4498 batches | lr 0.000047 |  4.11 ms | loss 0.19418 | ppl     1.21
-----------------------------------------------------------------------------------------
| end of epoch  91 | time: 18.62s | valid loss 0.25982 | valid ppl     1.30
-----------------------------------------------------------------------------------------
| epoch  92 |   899/ 4498 batches | lr 0.000045 |  4.13 ms | loss 0.24004 | ppl     1.27
| epoch  92 |  1798/ 4498 batches | lr 0.000045 |  4.09 ms | loss 0.14874 | ppl     1.16
| epoch  92 |  2697/ 4498 batches | lr 0.000045 |  4.07 ms | loss 0.13822 | ppl     1.15
| epoch  92 |  3596/ 4498 batche

In [22]:
torch.save(model.state_dict(), 'transformer')