In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from matplotlib import pyplot as plt
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from scinet import SciNet
import pandas as pd
from scinet_utils import target_loss 
from loader import build_dataloader
import torch.optim.lr_scheduler as lr_scheduler
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [20]:
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import pandas as pd
import json
import os
import numpy as np
from scinet import SciNet
from scinet_utils import target_loss
from loader import build_dataloader

def generate_data(size, size_, t_max):
    t = np.linspace(0, t_max, size)
    min_fr, max_fr = 0.01, 100
    fr = np.linspace(min_fr, max_fr, size_)
    start_st, end_st = 0.01, 100
    st = np.logspace(np.log10(start_st), np.log10(end_st), size_, endpoint=True)

    def f(t, st, fr):
        return st**2 * fr * (1 - t/st - np.exp(-t/st))

    data = []
    for st_ in st:
        for fr_ in fr:
            example = list(f(t, st_, fr_))
            t_pred = np.random.uniform(0, t_max)
            pred = f(t_pred, st_, fr_)
            example.extend([fr_, st_, t_pred, pred])
            data.append(example)

    columns = [str(i) for i in range(size)]
    columns.extend(["fr", "st", "t_pred", "pred"])
    df = pd.DataFrame(data, columns=columns)
    return df

def train_sci_net(scinet, dataloader, optimizer, scheduler, beta, N_EPOCHS, device):
    hist_error = []
    hist_kl = []
    hist_loss = []

    for epoch in range(N_EPOCHS):
        epoch_error = []
        epoch_kl = []
        epoch_loss = []
        for minibatch in dataloader:
            time_series, fr, st, question, answer = (
                minibatch['time_series'].to(device) / 5,
                minibatch['fr'].to(device) / 5,
                minibatch['st'].to(device) / 5,
                minibatch['question'].to(device) / 5,
                minibatch['answer'].to(device) / 5
            )
            inputs = torch.cat((time_series, question.view(-1, 1)), 1)
            outputs = answer

            optimizer.zero_grad()
            pred = scinet.forward(inputs)
            loss_ = target_loss(pred, outputs)
            kl = beta * scinet.kl_loss
            loss = loss_ + kl
            loss.backward()
            optimizer.step()
            error = torch.mean(torch.sqrt((pred[:, 0] - outputs)**2)).detach().cpu().numpy()
            epoch_error.append(float(error))
            epoch_kl.append(float(kl.data.detach().cpu().numpy()))
            epoch_loss.append(float(loss_.data.detach().cpu().numpy()))

        hist_error.append(np.mean(epoch_error))
        hist_loss.append(np.mean(epoch_loss))
        hist_kl.append(np.mean(epoch_kl))

        before_lr = optimizer.param_groups[0]["lr"]
        scheduler.step()
        after_lr = optimizer.param_groups[0]["lr"]
        print("Epoch %d: SGD lr %.6f -> %.6f" % (epoch+1, before_lr, after_lr))
        print("Epoch %d -- loss %f, RMS error %f, KL %f" % (epoch+1, hist_loss[-1], hist_error[-1], hist_kl[-1]))

    return hist_error, hist_kl, hist_loss


In [21]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sizes = [25]
N_EPOCHS = 50
size_ = 200
t_max = 5
data_file = "data.csv"
for size in sizes: 
    df = generate_data(size, size_, t_max)
    df.to_csv(data_file)

    scinet = SciNet(size, 1, 3, 100).to(device)  # Move the model to the GPU
    dataloader = build_dataloader(size=size, batch_size=128)

    SAVE_PATH = f"saved_models/scinet1-{size}epoch{N_EPOCHS}.dat"
    optimizer = optim.Adam(scinet.parameters(), lr=0.001)
    beta = 0.5
    scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.009, total_iters=N_EPOCHS)

    hist_error = []
    hist_kl = []
    hist_loss = []

    for epoch in range(N_EPOCHS):
        epoch_error = []
        epoch_kl = []
        epoch_loss = []
        for minibatch in dataloader:
            time_series, fr, st, question, answer = (
                minibatch['time_series'].to(device) / 5,
                minibatch['fr'].to(device) / 5,
                minibatch['st'].to(device) / 5,
                minibatch['question'].to(device) / 5,
                minibatch['answer'].to(device) / 5
            )
            inputs = torch.cat((time_series, question.view(-1, 1)), 1)
            outputs = answer

            optimizer.zero_grad()
            input, pred = scinet.forward(inputs)
            loss_ = target_loss(pred, outputs)
            kl = beta * scinet.kl_loss
            loss = loss_ + kl
            loss.backward()
            optimizer.step()
            error = torch.mean(torch.sqrt((pred[:, 0] - outputs)**2)).detach().cpu().numpy()
            epoch_error.append(float(error))
            epoch_kl.append(float(kl.data.detach().cpu().numpy()))
            epoch_loss.append(float(loss_.data.detach().cpu().numpy()))

        hist_error.append(np.mean(epoch_error))
        hist_loss.append(np.mean(epoch_loss))
        hist_kl.append(np.mean(epoch_kl))

        before_lr = optimizer.param_groups[0]["lr"]
        scheduler.step()
        after_lr = optimizer.param_groups[0]["lr"]
        print("Epoch %d: SGD lr %.6f -> %.6f" % (epoch+1, before_lr, after_lr))
        print("Epoch %d -- loss %f, RMS error %f, KL %f" % (epoch+1, hist_loss[-1], hist_error[-1], hist_kl[-1]))


    torch.save(scinet.state_dict(), SAVE_PATH)
    print(f"Model saved to {SAVE_PATH}")

Epoch 1: SGD lr 0.001000 -> 0.000980
Epoch 1 -- loss 87192.357004, RMS error 14.192720, KL 30315.707145
Epoch 2: SGD lr 0.000980 -> 0.000960
Epoch 2 -- loss 71414.306790, RMS error 12.934764, KL 899.540144
Epoch 3: SGD lr 0.000960 -> 0.000941
Epoch 3 -- loss 29297.325136, RMS error 7.641603, KL 411.668688
Epoch 4: SGD lr 0.000941 -> 0.000921
Epoch 4 -- loss 6621.528093, RMS error 3.508350, KL 331.298824
Epoch 5: SGD lr 0.000921 -> 0.000901
Epoch 5 -- loss 3076.450972, RMS error 2.311101, KL 271.402239
Epoch 6: SGD lr 0.000901 -> 0.000881
Epoch 6 -- loss 1771.622371, RMS error 1.761615, KL 221.288146
Epoch 7: SGD lr 0.000881 -> 0.000861
Epoch 7 -- loss 1119.256538, RMS error 1.412735, KL 191.921584
Epoch 8: SGD lr 0.000861 -> 0.000841
Epoch 8 -- loss 763.089544, RMS error 1.201446, KL 168.596913
Epoch 9: SGD lr 0.000841 -> 0.000822
Epoch 9 -- loss 545.230520, RMS error 1.038492, KL 152.727588
Epoch 10: SGD lr 0.000822 -> 0.000802
Epoch 10 -- loss 425.998118, RMS error 0.932242, KL 139.3

In [22]:
dataset = torch.Tensor().to(device)
for minibatch in dataloader:
    time_series, fr, st, question, answer = (
        minibatch['time_series'].to(device) / 5,
        minibatch['fr'].to(device) / 5,
        minibatch['st'].to(device) / 5,
        minibatch['question'].to(device) / 5,
        minibatch['answer'].to(device) / 5
    )
    inputs = torch.cat((time_series, question.view(-1, 1)), 1)
    outputs = answer

    optimizer.zero_grad()
    input, pred = scinet.forward(inputs)
    data = torch.cat((input, pred), 1)
    error = torch.mean(torch.sqrt((pred[:, 0] - outputs)**2)).detach().cpu().numpy()
    dataset = torch.cat((dataset, data), 0)

In [23]:
df = pd.DataFrame(dataset.detach().cpu().numpy())
df.to_csv("scinet_output.csv")