train_with_sampling.py
from model import Transformer
import torch
import logging
from plot import *
from helpers import *
from joblib import load
import math, random

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s %(message)s",
                    datefmt="[%Y-%m-%d %H:%M:%S]")
logger = logging.getLogger(__name__)
def flip_from_probability(p):
    return True if random.random() < p else False
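
# Illustration only (not part of the original script): the training loop below
# decays the probability of feeding the true value with an inverse-sigmoid
# schedule, v = k / (k + exp(epoch / k)). A minimal sketch of how that
# probability evolves, assuming a hypothetical k = 60:
def _sampling_schedule_demo(k=60, epochs=(0, 100, 300, 600)):
    for epoch in epochs:
        v = k / (k + math.exp(epoch / k))
        print(f"epoch {epoch:4d}: P(use true value) = {v:.3f}")
        # k=60 -> 0.984, 0.919, 0.288, 0.003: teacher forcing fades out over training.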

def transformer(dataloader, EPOCH, k, frequency, path_to_save_model, path_to_save_loss, path_to_save_predictions,
                device):

    device = torch.device(device)
    model = Transformer().double().to(device)
    optimizer = torch.optim.Adam(model.parameters())
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=200)
    criterion = torch.nn.MSELoss()
    best_model = ""
    min_train_loss = float('inf')

    for epoch in range(EPOCH + 1):

        train_loss = 0
        val_loss = 0

        ## TRAIN -- TEACHER FORCING
        model.train()
        for index_in, index_tar, _input, target, sensor_number in dataloader:

            # Shape of _input : [batch, input_length, feature]
            # Desired input for model : [input_length, batch, feature]
            optimizer.zero_grad()
            src = _input.permute(1,0,2).double().to(device)[:-1,:,:] # torch.Size([24, 1, 7])
            target = _input.permute(1,0,2).double().to(device)[1:,:,:] # src shifted by 1; overwrites the dataloader's target.
            sampled_src = src[:1, :, :] # t0, torch.Size([1, 1, 7])
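
            # Scheduled sampling (in the spirit of Bengio et al., 2015): grow the
            # input one step at a time, choosing at each step between the ground
            # truth and the model's own last prediction via a biased coin flip
            # whose P(true value) decays with the epoch (see flip_from_probability).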
            for i in range(len(target)-1):

                prediction = model(sampled_src, device) # torch.Size([1xw, 1, 1])

                """
                # to update model at every step
                # loss = criterion(prediction, target[:i+1,:,:1])
                # loss.backward()
                # optimizer.step()
                """

                if i < 24: # One day of context: enough data to make inferences about cycles.
                    prob_true_val = True
                else:
                    ## coin flip: the probability of using the true value depends on the epoch and decays over time;
                    ## starts with over 95 % probability of the true value for each flip in epoch 0.
                    v = k/(k+math.exp(epoch/k))
                    prob_true_val = flip_from_probability(v)

                if prob_true_val: ## using true value as next value
                    sampled_src = torch.cat((sampled_src.detach(), src[i+1, :, :].unsqueeze(0).detach()))
                else: ## using prediction as next value
                    positional_encodings_new_val = src[i+1,:,1:].unsqueeze(0)
                    predicted_humidity = torch.cat((prediction[-1,:,:].unsqueeze(0), positional_encodings_new_val),
                                                   dim=2)
                    sampled_src = torch.cat((sampled_src.detach(), predicted_humidity.detach()))
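
            # Gradient note: every tensor concatenated into sampled_src is detached,
            # so the backward pass below flows only through the final forward call
            # of the inner loop, not through earlier sampling steps.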
"""To update model after each sequence"""
loss = criterion(target[:-1,:,0].unsqueeze(-1), prediction)
loss.backward()
optimizer.step()
train_loss += loss.detach().item()
        if train_loss < min_train_loss:
            torch.save(model.state_dict(), path_to_save_model + f"best_train_{epoch}.pth")
            torch.save(optimizer.state_dict(), path_to_save_model + f"optimizer_{epoch}.pth")
            min_train_loss = train_loss
            best_model = f"best_train_{epoch}.pth"

        if epoch % 10 == 0: # Plot 1-Step Predictions
            logger.info(f"Epoch: {epoch}, Training loss: {train_loss}")
            scaler = load('scalar_item.joblib')
            sampled_src_humidity = scaler.inverse_transform(sampled_src[:,:,0].cpu()) # [seq_len, 1]
            src_humidity = scaler.inverse_transform(src[:,:,0].cpu()) # [seq_len, 1]
            target_humidity = scaler.inverse_transform(target[:,:,0].cpu()) # [seq_len, 1]
            prediction_humidity = scaler.inverse_transform(prediction[:,:,0].detach().cpu().numpy()) # [seq_len, 1]
            plot_training_3(epoch, path_to_save_predictions, src_humidity, sampled_src_humidity, prediction_humidity,
                            sensor_number, index_in, index_tar)

        train_loss /= len(dataloader)
        log_loss(train_loss, path_to_save_loss, train=True)

    plot_loss(path_to_save_loss, train=True)
    return best_model
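
# Minimal usage sketch (not part of the original file): the dataloader is
# assumed to yield (index_in, index_tar, _input, target, sensor_number)
# batches as consumed above; the names and hyperparameter values below are
# hypothetical placeholders, not the repo's actual configuration.
#
#   best_model = transformer(train_dataloader, EPOCH=1000, k=60, frequency=100,
#                            path_to_save_model="save_model/",
#                            path_to_save_loss="save_loss/",
#                            path_to_save_predictions="save_predictions/",
#                            device="cpu")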