In [None]:
import torch
import torch.nn as nn
import numpy as np
import csv
from datetime import datetime as dt
import matplotlib.pyplot as plt
import os

def load_csv(file):
    """Read a raw sensor CSV into four parallel lists.

    Each non-blank row must have exactly four fields: an ISO-8601 timestamp,
    temperature, humidity and power. The timestamp is reduced to a fractional
    hour of day (e.g. 13:30 -> 13.5); its date part is discarded.

    Args:
        file: Path to the CSV file. No header row is expected (a header would
            fail timestamp parsing).

    Returns:
        Tuple ``(datetime, temp, humd, power)`` of equal-length lists of floats,
        where ``datetime`` holds the fractional hours.

    Raises:
        ValueError: if a non-blank row does not have four fields, or a field
            cannot be parsed as a timestamp / float.
    """
    hours, temp, humd, power = [], [], [], []
    # The csv module requires newline='' so that quoted newlines and \r\n
    # line endings are handled by the reader rather than by text-mode I/O.
    with open(file, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                # Blank lines (e.g. a trailing empty line) yield []; skip them
                # instead of failing on the 4-way unpack below.
                continue
            d, t, h, p = row
            stamp = dt.fromisoformat(d)
            hours.append(stamp.hour + stamp.minute / 60)
            temp.append(float(t))
            humd.append(float(h))
            power.append(float(p))
    return (hours, temp, humd, power)

# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (just slower) on machines without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
class Model(nn.Module):
    """LSTM sequence classifier emitting one probability per input sequence.

    Input is batch-first with 3 features per timestep, i.e. shape
    (batch, seq_len, 3). Only the top LSTM layer's output at the final timestep
    is used. It passes through two linear layers (128 -> 64 -> 1) and a sigmoid,
    so the output has shape (batch, 1) with values in (0, 1).
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names are the state_dict keys: keep them stable so saved
        # checkpoints keep loading. Creation order is also kept, so seeded
        # initialisation is unchanged.
        self.rnn = nn.LSTM(
            input_size=3,
            hidden_size=128,
            num_layers=2,
            bias=False,
            batch_first=True,
        )
        self.linear1 = nn.Linear(128, 64, bias=True)
        self.linear2 = nn.Linear(64, 1, bias=True)
        # Despite the name, this is the output sigmoid, not a normalisation layer.
        self.norm = nn.Sigmoid()

    def forward(self, x):
        seq_out, _state = self.rnn(x)
        last_step = seq_out[:, -1, :]
        # NOTE(review): there is no activation between linear1 and linear2, so
        # the two layers compose to a single affine map. Adding one would change
        # the outputs of already-trained checkpoints, so it is left as-is.
        logits = self.linear2(self.linear1(last_step))
        return self.norm(logits)

# Build the network once and move its parameters onto DEVICE for the evaluation cells.
model = Model()
model = model.to(DEVICE)

In [None]:
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    """Map-style dataset over a pre-windowed numpy array.

    ``raw_data`` is indexed as ``raw_data[index, :, ...]``. On the last axis, all
    columns except the last are input features, and the last column supplies the
    target. Presumably shape (num_windows, window_len, num_features + 1) — TODO
    confirm against the code that produced the array.

    Args:
        raw_data: indexable array (e.g. a numpy ndarray) of windows.
        device: device every returned tensor is moved to.
    """

    def __init__(self, raw_data, device='cpu') -> None:
        super().__init__()
        self.raw_data = raw_data
        self.device = device

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, index):
        """Return ``(x, y)``: float32 features and a 1-element float32 target."""
        features = self.raw_data[index, :, :-1]
        # The target is the last column's value at the window's first timestep.
        target_column = self.raw_data[index, :, -1]
        first_target = target_column[0]
        x = torch.tensor(features).float()
        y = torch.tensor(first_target).float().unsqueeze(-1)
        return x.to(self.device), y.to(self.device)

# Load the pre-windowed evaluation array and serve it as a single batch that
# holds every sample (batch_size equals the dataset length).
raw_data = np.load('dataset_all.npy')
dataset = MyDataset(raw_data=raw_data, device=DEVICE)
dataloader = DataLoader(dataset=dataset, batch_size=len(dataset))

In [None]:
from torchmetrics.functional.classification import binary_accuracy, binary_precision, binary_recall, binary_f1_score

@torch.no_grad()
def validate(model, dataloader):
    """Compute binary classification metrics for ``model`` over ``dataloader``.

    Puts ``model`` in eval mode (and leaves it there) and runs every batch
    without gradient tracking. It then scores all predictions against all
    targets at once, so the result does not depend on how the data is batched.

    Args:
        model: module mapping a batch ``x`` to per-sample scores shaped like ``y``.
        dataloader: iterable of ``(x, y)`` batches.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` of scalar tensors computed on CPU.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    pred_batches, target_batches = [], []
    for x, y in dataloader:
        # Transfer each batch to the CPU once, rather than once per metric.
        pred_batches.append(model(x).cpu())
        target_batches.append(y.cpu())
    if not pred_batches:
        raise ValueError("dataloader yielded no batches")
    preds = torch.cat(pred_batches)
    targets = torch.cat(target_batches)
    accuracy = binary_accuracy(preds, targets)
    precision = binary_precision(preds, targets)
    recall = binary_recall(preds, targets)
    f1 = binary_f1_score(preds, targets)
    return accuracy, precision, recall, f1

In [None]:
# Collect every file under models/ in sorted order, so the index picked from the
# metric lists maps to the same checkpoint on every run (os.walk order is
# filesystem-dependent).
model_files = sorted(
    os.path.join(root, name)
    for root, _dirs, names in os.walk('models')
    for name in names
)

metric_acc = []
metric_prec = []
metrix_rcl = []  # recall scores (name kept unchanged in case other cells use it)
metric_f1 = []

for path in model_files:
    # map_location loads the checkpoint tensors straight onto DEVICE, so weights
    # saved from a GPU also load when DEVICE is the CPU.
    # NOTE(review): torch.load unpickles the file — only point this at trusted checkpoints.
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    model.eval()
    acc, prec, rcl, f1 = validate(model, dataloader)
    metric_acc.append(acc)
    metric_prec.append(prec)
    metrix_rcl.append(rcl)
    metric_f1.append(f1)

In [None]:
# Checkpoint with the highest F1 score: display its score next to its file path.
idx = np.argmax(metric_f1)
best_f1 = metric_f1[idx]
best_path = model_files[idx]
best_f1, best_path

In [None]:
# Re-load one checkpoint on the CPU and slide it over a single day of raw data.
model = Model()
# The model stays on the CPU here, so map GPU-saved weights to the CPU as well.
# NOTE(review): this path is hard-coded rather than taken from the F1-based selection.
model.load_state_dict(torch.load('models/1300_model.pt', map_location='cpu'))
model.eval()

file = "zuraach_ail2/2023-07-09_raw.csv"
date, temp, humd, power = load_csv(file)

# Per-feature divisors applied before inference. NOTE(review): presumably these
# are the scales used when the training windows were built — confirm.
HOUR_SCALE = 24
TEMP_SCALE = 34.46
HUMD_SCALE = 54.88
POWER_SCALE = 1.8
# A window is labelled 1 when its first normalised power reading is >= this value.
LABEL_THRESHOLD = 0.7

dn = np.array(date) / HOUR_SCALE
tn = np.array(temp) / TEMP_SCALE
hn = np.array(humd) / HUMD_SCALE
pn = np.array(power) / POWER_SCALE

dn_list = dn.tolist()
tn_list = tn.tolist()
hn_list = hn.tolist()
pn_list = pn.tolist()

WINDOW = 7

preds = []
labels = []
pred_date = []

# Every full WINDOW-long span (stride 1) yields one prediction and one label.
with torch.no_grad():
    for i in range(len(dn_list) - (WINDOW - 1)):
        span = slice(i, i + WINDOW)
        # (WINDOW, 3) matrix with columns hour, temperature, humidity.
        data_T = np.array([dn_list[span], tn_list[span], hn_list[span]]).T
        # The label is taken from the first timestep of the window.
        labels.append(1 if pn_list[i] >= LABEL_THRESHOLD else 0)
        data_tensor = torch.tensor(data_T).float().unsqueeze(0)
        preds.append(model(data_tensor)[0].item())
        # Each prediction is plotted at the mean hour of its window.
        pred_date.append(np.mean(date[span]))

In [None]:
# Overlay the normalised inputs, the per-window label and the model's prediction.
fig, ax = plt.subplots(figsize=(15, 7))
ax.plot(date, tn, alpha=0.5, label='temp')
ax.plot(date, hn, alpha=0.5, label='humd')
ax.plot(date, pn, alpha=0.7, label='power')
ax.plot(pred_date, labels, alpha=0.5, label='label', c='pink')
ax.plot(pred_date, preds, alpha=1, label='pred', c='r')
ax.set_xlim(0, 24)
ax.set_ylim(0, 1.3)
ax.set(
    title=f"Model prediction vs. label — {file}",
    xlabel="hour of day",
    ylabel="normalised value",
)
ax.legend()
plt.show()