In [9]:
import torch
import torch.nn as nn
from torchdiffeq import odeint
from tqdm import tqdm

In [11]:
device = "cuda"

In [31]:
class Latent_ODE(nn.Module):
    def __init__(self, latent_dim=4, obs_dim=2, nhidden=20, rhidden = 20, aug = False, aug_dim = 2):
        super(Latent_ODE, self).__init__()
        self.aug = aug
        self.aug_dim = aug_dim
        if self.aug:
            self.rec = RecognitionRNN(latent_dim, obs_dim+aug_dim, rhidden)
        else:
            self.rec = RecognitionRNN(latent_dim, obs_dim, rhidden)
    
        self.func = LatentODEfunc(latent_dim, nhidden)
        self.dec = LatentODEDecoder(latent_dim, obs_dim, nhidden)
        
    def forward(self, xx, output_length):
        time_steps = torch.arange(0, output_length, 0.01).float().to(device)[:output_length]#torch.linspace(0, 59, 60).float().to(device)[:output_length]
        if self.aug:
            aug_ten = torch.zeros(xx.shape[0], xx.shape[1], self.aug_dim).float().to(device)
            xx = torch.cat([xx, aug_ten], dim = -1)
#         print(xx.shape)
#         print(torch.flip(xx, [1]).shape)
        z0 = self.rec.forward(torch.flip(xx, [1]))
        pred_z = odeint(self.func, z0, time_steps).permute(1, 0, 2)
        out = self.dec(pred_z)

        return out  
    
class LatentODEfunc(nn.Module):
    def __init__(self, latent_dim=4, nhidden=20):
        super(LatentODEfunc, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, nhidden),
            nn.ELU(),
            nn.Linear(nhidden, nhidden),
            nn.ELU(),
            nn.Linear(nhidden, nhidden),
            nn.ELU(),
            nn.Linear(nhidden, nhidden),
            nn.ELU(),
            nn.Linear(nhidden, latent_dim)
        )
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.model(x)
        return out
    
class RecognitionRNN(nn.Module):
    def __init__(self, latent_dim=4, obs_dim=2, nhidden=25):
        super(RecognitionRNN, self).__init__()
        self.nhidden = nhidden
        self.model = nn.GRU(obs_dim, nhidden, batch_first = True)
        self.linear = nn.Linear(nhidden, latent_dim)

    def forward(self, x):
        #h0 = torch.zeros(1, x.shape[0], self.nhidden).to(device)
        output, hn = self.model(x)#, h0
        return self.linear(hn[0])
    
class LatentODEDecoder(nn.Module):
    def __init__(self, latent_dim=4, obs_dim=2, nhidden=20):
        super(LatentODEDecoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, nhidden),
            nn.ReLU(),
            nn.Linear(nhidden, obs_dim)
        )
        
    def forward(self, z):
        out = self.model(z)
        return out


In [32]:
model = Latent_ODE(latent_dim = 64, obs_dim = 3, nhidden = 128, rhidden = 128, aug = False)

In [28]:
y_exact=torch.load("../../5S_191111_3cmp_torch.pt")
#first row (0) k, then (1) q, then (2) u
#y_exact=y_exact[0:5,:]
print(y_exact.shape)
# y_exact = y_exact.permute(2, 0, 1)
print(y_exact.shape)

torch.Size([3, 71, 31])
torch.Size([3, 71, 31])


In [23]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
#optimizer = torch.optim.SGD(tdnu.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size= 100, gamma=0.95) #0.95 #step_size=1

In [24]:
num_epochs = 100
loss_fun = torch.nn.MSELoss()
min_loss = 10

In [33]:
history=[]
tqdm_epochs = tqdm(range(num_epochs))
for e in tqdm_epochs:
    y_approx = model(y_exact[:, :65, :], 3)
    loss = loss_fun(y_approx, y_exact[:,0:65,:])
    loss_history.append(loss.item())
    if loss.item() < min_loss:
        best_model = model
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()
    history.append(loss.item())
    tqdm_epochs.set_postfix({'loss': torch.sum(loss).item()}) 
    
    scheduler.step()

  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([3, 65, 31])
torch.Size([3, 65, 31])





RuntimeError: input.size(-1) must be equal to input_size. Expected 3, got 31