In [None]:
import torch
import torch.nn as nn

# Model dimensions.
nRoi = 264  # features per time step (last dim of the input series)
nLat = 50   # latent dimension per time step
nMem = 3    # number of preceding time steps included in each window
nHid = 100  # hidden width of the encoder / decoder MLPs

# Activation applied between hidden layers. Note: ``nn.ReLU`` is a module
# *class* -- calling ``nn.ReLU(t)`` constructs a module instead of applying
# ReLU to ``t`` -- so bind the functional form instead.
lu = torch.relu

def KLDiv(mu, logvar):
    """Return KL( N(mu, exp(logvar)) || N(0, I) ), summed over all elements.

    Args:
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        A scalar tensor ``0.5 * sum(mu**2 + exp(logvar) - 1 - logvar)``.
    """
    return 0.5 * torch.sum(mu.pow(2) + torch.exp(logvar) - 1 - logvar)

def getLatent(mu, logvar):
    """Draw ``z ~ N(mu, exp(logvar))`` with the reparameterisation trick.

    The sample is written as ``mu + eps * std`` with ``eps ~ N(0, 1)``, so
    gradients flow back into ``mu`` and ``logvar``. ``torch.normal`` does not
    propagate gradients to its mean/std arguments.

    Args:
        mu: Means.
        logvar: Log-variances, same shape as ``mu``.

    Returns:
        A tensor shaped like ``mu``.
    """
    std = torch.exp(0.5 * logvar)  # variance = exp(logvar) -> std = exp(logvar / 2)
    eps = torch.randn_like(std)
    return mu + eps * std

def makeWindows(ts, nRoi, mem=None):
    """Pair every time step with the ``mem`` steps that precede it.

    The series is zero-padded with ``mem`` frames at the start so that every
    time step, including the first ones, gets a full window.

    Args:
        ts: Tensor of shape ``(nB, nt, nRoi)``, with time on dim 1.
        nRoi: Size of the last (feature) dimension of ``ts``.
        mem: Number of preceding steps per window. Defaults to the
            module-level ``nMem``, read at call time.

    Returns:
        Tensor of shape ``(nB * nt * (mem + 1), nRoi)``. Rows come in
        consecutive groups of ``mem + 1`` frames (oldest first), one group per
        ``(batch, time)`` pair. ``.reshape(-1, (mem + 1) * nRoi)`` therefore
        gives one flattened window per row.
    """
    if mem is None:
        mem = nMem
    nB, nt = ts.shape[0], ts.shape[1]
    # Pad along the time axis on ts's own device/dtype; a bare torch.zeros
    # would land on the CPU and fail to concatenate with a CUDA input.
    padded = torch.cat([ts.new_zeros(nB, mem, nRoi), ts], dim=1)
    # idx[t, k] = t + k selects padded frames t..t+mem, i.e. original steps
    # t-mem..t. This also works for nt == 0, where idx is empty.
    idx = (torch.arange(nt, device=ts.device).unsqueeze(1)
           + torch.arange(mem + 1, device=ts.device).unsqueeze(0))
    wins = padded[:, idx, :]  # (nB, nt, mem + 1, nRoi)
    return wins.reshape(-1, nRoi)

class VAE(nn.Module):
    """Windowed variational autoencoder over multichannel time series.

    Each time step is encoded together with the ``nMem`` steps before it
    (zero-padded at the start of the series). Each output frame is decoded
    from the matching window of ``nMem + 1`` latent vectors.

    Args:
        device: Device the parameters are placed on. Defaults to ``"cuda"``,
            matching the previously hard-coded placement.
    """

    def __init__(self, device="cuda"):
        super().__init__()
        # Learnable latent mean / log-variance vectors (not used by forward).
        self.mu = nn.Parameter(torch.randn(nLat))
        self.logvar = nn.Parameter(torch.randn(nLat))
        self.dec1 = nn.Linear((nMem + 1) * nLat, nHid)
        self.dec2 = nn.Linear(nHid, nHid)
        self.dec3 = nn.Linear(nHid, nRoi)
        self.enc1 = nn.Linear((nMem + 1) * nRoi, nHid)
        self.enc2 = nn.Linear(nHid, nHid)
        self.enc3 = nn.Linear(nHid, nLat * 2)
        self.to(device)

    def forward(self, x):
        """Reconstruct ``x`` and return the KL term.

        Args:
            x: Tensor of shape ``(nB, nt, nRoi)``.

        Returns:
            ``(recon, kl)``. ``recon`` has shape ``(nB, nt, nRoi)``. ``kl`` is
            ``KLDiv(mu, logvar)`` over the per-time-step posterior parameters.
        """
        nB, nt = x.shape[0], x.shape[1]
        # One flattened window of past + current frames per row:
        # (nB * nt, (nMem + 1) * nRoi), which is what enc1 consumes.
        h = makeWindows(x, nRoi).reshape(-1, (nMem + 1) * nRoi)
        h = lu(self.enc1(h))
        h = lu(self.enc2(h))
        h = self.enc3(h)
        mu, logvar = h[:, :nLat], h[:, nLat:]
        z = getLatent(mu, logvar)  # (nB * nt, nLat)
        # Window the latents the same way so dec1 sees (nMem + 1) * nLat features.
        z = makeWindows(z.reshape(nB, nt, nLat), nLat)
        recon = self.decoder(z)  # (nB * nt, nRoi)
        return recon.reshape(nB, nt, nRoi), KLDiv(mu, logvar)

    def decoder(self, x):
        """Decode windows of latent vectors into frames.

        Args:
            x: Windowed latents containing ``N * (nMem + 1) * nLat`` elements,
                either as ``makeWindows(z, nLat)`` returns them
                (``(N * (nMem + 1), nLat)``) or already flattened
                (``(N, (nMem + 1) * nLat)``).

        Returns:
            Tensor of shape ``(N, nRoi)``: one frame per window.
        """
        x = x.reshape(-1, (nMem + 1) * nLat)
        x = lu(self.dec1(x))
        x = lu(self.dec2(x))
        return self.dec3(x)

    def sample(self, mu, logvar):
        """Draw latents from ``N(mu, exp(logvar))`` and decode them.

        Args:
            mu: Latent means. Presumably ``(nB, nt, nLat)``: makeWindows
                indexes time on dim 1, so a 3-D tensor is expected. TODO confirm.
            logvar: Latent log-variances, same shape as ``mu``.

        Returns:
            Tensor of shape ``(nB * nt, nRoi)``.
        """
        z = getLatent(mu, logvar)
        z = makeWindows(z, nLat)
        return self.decoder(z)