In [1]:
import torch
import pandas as pd
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset 
BASE_DIR = "P:/pet ML"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
y_tr = torch.load(f'{BASE_DIR}/data/tensor/y_tr.pt')
y_stats = torch.load(f'{BASE_DIR}/data/tensor/y_stats.pt')
X_drop = torch.load(f'{BASE_DIR}/data/tensor/X_drop.pt')

In [5]:
input_feat = len(X_drop[0, :])

In [6]:
class model_nn(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(input_feat, 256)
        self.layer2 = nn.Linear(256, 128)
        self.layer3 = nn.Linear(128, 64)
        self.layer4 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)

        x = self.layer2(x)
        x = self.relu(x)

        x = self.layer3(x)
        x = self.relu(x)

        x = self.layer4(x)
        return x

In [9]:
model = model_nn().to(device)
opt = Adam(model.parameters(), lr=0.001)
EPOCH = 50
loss_fn = nn.MSELoss()
train_ds = TensorDataset(X_drop, y_tr)
train_loader = DataLoader(dataset=train_ds, batch_size=256, shuffle=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=10, factor=0.5)

In [10]:
model.train()
loss_list = []
for epoch in range(EPOCH):
    print(f'EPOCH {epoch+1} start')
    epoch_loss = 0
    for batch_X, batch_y in train_loader:
        opt.zero_grad()
        pred = model(batch_X)
        loss = loss_fn(pred.squeeze(), batch_y)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
                
    epoch_loss /= len(train_loader)
    loss_list.append(epoch_loss)
    scheduler.step(epoch_loss)
    print(f'loss {epoch+1}: {epoch_loss:.6f}, lr: {opt.param_groups[0]["lr"]:.5f}')
    if (len(loss_list)) > 2:
        if loss_list[-2] - loss_list[-1] < 0.00001:
            print(f'EARLY STOPPING HERE!!!')
            break

EPOCH 1 start
loss 1: 0.158606, lr: 0.00100
EPOCH 2 start
loss 2: 0.154691, lr: 0.00100
EPOCH 3 start
loss 3: 0.154189, lr: 0.00100
EPOCH 4 start
loss 4: 0.153854, lr: 0.00100
EPOCH 5 start
loss 5: 0.153649, lr: 0.00100
EPOCH 6 start
loss 6: 0.153466, lr: 0.00100
EPOCH 7 start
loss 7: 0.153305, lr: 0.00100
EPOCH 8 start
loss 8: 0.153224, lr: 0.00100
EPOCH 9 start
loss 9: 0.153079, lr: 0.00100
EPOCH 10 start
loss 10: 0.153059, lr: 0.00100
EPOCH 11 start
loss 11: 0.152886, lr: 0.00100
EPOCH 12 start
loss 12: 0.152862, lr: 0.00100
EPOCH 13 start
loss 13: 0.152720, lr: 0.00100
EPOCH 14 start
loss 14: 0.152624, lr: 0.00100
EPOCH 15 start
loss 15: 0.152531, lr: 0.00100
EPOCH 16 start
loss 16: 0.152462, lr: 0.00100
EPOCH 17 start
loss 17: 0.152385, lr: 0.00100
EPOCH 18 start
loss 18: 0.152368, lr: 0.00100
EPOCH 19 start
loss 19: 0.152274, lr: 0.00100
EPOCH 20 start
loss 20: 0.152214, lr: 0.00100
EPOCH 21 start
loss 21: 0.152167, lr: 0.00100
EPOCH 22 start
loss 22: 0.152094, lr: 0.00100
EPOCH 

In [12]:
torch.save(model.state_dict(), f'{BASE_DIR}/data/model/model_weight.pt')

In [None]:
model.eval()
pred_norm = model()
pred_real = pred_norm * y_stats['std'] + y_stats['mean']