In [3]:
from torchdiffeq import odeint_adjoint as odeint
import numpy as np
from matplotlib import pyplot as plt
from numpy import linalg as la
from scipy import stats
from scipy import spatial as sp
from scipy import integrate as int
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm


#import Networkx as net
import Plot3D as plot3d
import dataframe as dataframe

In [17]:
class Simple_FeedforwardNN(nn.Module):
    def __init__(self, input_dim, hidden_layers, depth, output_dim, activation_func=nn.Tanh()):
        super(Simple_FeedforwardNN, self).__init__()

        layers = []
        previous_depth = input_dim
        for _ in range(hidden_layers):
            layers.append(nn.Linear(previous_depth, depth))
            layers.append(activation_func)
            previous_depth = depth

        layers.append(nn.Linear(depth, output_dim))
        self.network = nn.Sequential(*layers)

        #Setting initial weights
        for m in self.network.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.2)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, x):
            out = self.network(x)
            return out

In [28]:
1*(0+1)

1

In [29]:
class Piecewise_Auto_Network(Simple_FeedforwardNN):
    def __init__(self, input_dim, hidden_layers, depth, output_dim, domain, num_breakpoints, activation_func=nn.Tanh()):
        super().__init__(input_dim, hidden_layers, depth, output_dim*(num_breakpoints+1), activation_func)

        self.depth              = depth #Depth is number of neurons per hidden layer
        self.hiddenlayers       = hidden_layers #Number of internal hidden layers
        self.variables          = input_dim
        self.num_breakpoint     = num_breakpoints
        self.domain             = domain
        self.num_A              = (num_breakpoints+1)*output_dim
        self.breakpoints        = torch.linspace(domain[0], domain[1], num_breakpoints+2) 
        self.break_params       = nn.Parameter(self.breakpoints)

    def forward(self, t, x):
        out = self.network(x)

        #Getting indicator function
        #index = torch.where(self.break_params < t)[0].shape[0]
        #indicator = torch.zeros(self.num_A); 
        #indicator[index-1] = 1


       # final_output = out*indicator
        return out

In [30]:
model = Simple_FeedforwardNN(input_dim=1, hidden_layers=2, depth=5, output_dim=1, activation_func=nn.Tanh())

In [38]:
model = Piecewise_Auto_Network(input_dim=1, hidden_layers=2, depth=5, output_dim=1, domain = [0, 10], num_breakpoints=0, activation_func=nn.Tanh())

In [40]:
model(t=torch.tensor([2.0]), x=torch.tensor([2.0]))

tensor([-0.1228], grad_fn=<ViewBackward0>)

In [41]:
odeint(func=model, y0 = torch.tensor([2.0]), t = torch.arange(0, 10, 0.01), method='rk4', options={'step_size': 0.01})

tensor([[2.0000],
        [1.9988],
        [1.9975],
        [1.9963],
        [1.9951],
        [1.9939],
        [1.9926],
        [1.9914],
        [1.9902],
        [1.9890],
        [1.9878],
        [1.9865],
        [1.9853],
        [1.9841],
        [1.9829],
        [1.9817],
        [1.9804],
        [1.9792],
        [1.9780],
        [1.9768],
        [1.9756],
        [1.9744],
        [1.9731],
        [1.9719],
        [1.9707],
        [1.9695],
        [1.9683],
        [1.9671],
        [1.9659],
        [1.9646],
        [1.9634],
        [1.9622],
        [1.9610],
        [1.9598],
        [1.9586],
        [1.9574],
        [1.9562],
        [1.9550],
        [1.9538],
        [1.9526],
        [1.9514],
        [1.9502],
        [1.9490],
        [1.9478],
        [1.9466],
        [1.9454],
        [1.9442],
        [1.9430],
        [1.9418],
        [1.9406],
        [1.9394],
        [1.9382],
        [1.9370],
        [1.9358],
        [1.9346],
        [1