In [1]:
import torch
from torch import nn
from utils import MLP

In [88]:
# Scratch check of two manual autoregressive decoding steps on a tiny decoder (d_model=4).
# NOTE: constructor order is kept as-is so parameter initialisation consumes the RNG
# in the same sequence as before.
dec = nn.TransformerDecoderLayer(d_model=4, nhead=2, dim_feedforward=50,
                                 batch_first=True, dropout=0.0)
obs_enc = MLP(2, 10, output_size=4, unsqueeze=False)

# The three heads below are only constructed in this cell, never applied.
# They are built with input_size=10 while the decoder above uses d_model=4.
dec_mu = MLP(1, 12, input_size=10,
             # torch.tanh matches the activation used in SeqGaussMixPosterior.
             last_activation=lambda y: torch.tanh(y) * 10,
             unsqueeze=False)
dec_sigma = MLP(1, 12, input_size=10,
                last_activation=torch.exp,
                unsqueeze=False)
dec_mixture = MLP(1, 12, input_size=10,
                  last_activation=torch.exp,
                  unsqueeze=False)

obs = torch.rand(1, 1, 1)
start = torch.rand(1, 1, 4)

obs_embed = obs_enc(obs)
print(start)

# Step 1: the target sequence is just the start token, the memory is the embedded observation.
output_1 = dec(start.expand(obs_embed.shape), obs_embed)
print(output_1)

# Step 2: append the (detached) last decoded position to the start token and decode again.
output_2 = dec(torch.cat([start, output_1[:, -1:, :].detach()], dim=-2), obs_embed)
print(output_2)

# Cell value: the detached token that was fed back in step 2.
output_1[:, -1:, :].detach()

tensor([[[0.1716, 0.7469, 0.4589, 0.9850]]])
tensor([[[-0.9002,  0.5805, -1.0237,  1.3434]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[[-1.3024,  0.0461, -0.2415,  1.4977],
         [-1.0205,  0.5959, -0.9087,  1.3333]]],
       grad_fn=<NativeLayerNormBackward0>)


tensor([[[-0.9002,  0.5805, -1.0237,  1.3434]]])

In [9]:
# Shape check: joining two (11, 1, 10) tensors along the middle axis.
obs = torch.rand(11, 1, 10)
test = torch.rand(11, 1, 10)
# For 3-D tensors dim=1 is the same axis as dim=-2.
torch.cat((obs, test), dim=1).shape

# Earlier sketch, kept for reference only (not executed):
#   enc_obs = obs_enc(obs).unsqueeze(1)   # noted shape: torch.Size([11, 1, 10])
#   x = torch.rand(11, 1, 10)             # batch, seq=1 '<start>', features
#   dec_emb = dec(x, enc_obs)             # noted shape: 11, 1, 10
#   dec_mu(dec_emb).shape

torch.Size([11, 2, 10])

In [114]:
class SeqGaussMixPosterior(nn.Module):
    """Decode Gaussian-mixture parameters one component at a time from an observation.

    The observation is embedded once and used as the memory of a single
    ``nn.TransformerDecoderLayer``. Starting from a learnable start token, the
    decoder is run ``num_components`` times; after each run the last decoded
    position is turned into one ``mu``, one ``sigma`` and one mixture weight,
    and is appended (detached) to the decoder input for the next run.

    Args:
        num_components: number of decoding steps, i.e. how many mixture
            components are emitted along dim 1 of every output. Defaults to 2.

    Raises:
        ValueError: if ``num_components`` is smaller than 1.
    """

    def __init__(self, num_components=2):
        super().__init__()
        if num_components < 1:
            raise ValueError("num_components must be at least 1")
        self.num_components = num_components

        # NOTE(review): MLP comes from the project's utils module; the meaning of
        # its positional arguments is not visible here -- confirm against utils.MLP.
        self.embed_obs = MLP(2, 10, output_size=10, unsqueeze=False)
        # tanh(y) * 10 keeps the mean inside (-10, 10), assuming last_activation
        # is applied to the head's final output.
        self.decode_mu = MLP(1, 12, input_size=10,
                             last_activation=lambda y: torch.tanh(y) * 10,
                             unsqueeze=False)
        # exp makes the head output strictly positive (same assumption as above).
        self.decode_sigma = MLP(1, 12, input_size=10,
                                last_activation=torch.exp,
                                unsqueeze=False)
        # Also exp: positive weights, not normalised across components.
        self.decode_mixture_prob = MLP(1, 12, input_size=10,
                                       last_activation=torch.exp,
                                       unsqueeze=False)
        self.trans_dec = nn.TransformerDecoderLayer(d_model=10, nhead=2, dim_feedforward=20,
                                                    batch_first=True, dropout=0.0)

        # Learnable start token, broadcast to the embedded observation's shape in forward().
        self.start = torch.nn.Parameter(torch.rand(1, 1, 10))

    def forward(self, obs):
        """Run the autoregressive decoding loop.

        Args:
            obs: observation tensor passed to ``embed_obs``; the notebook feeds
                tensors of shape (batch, 1, 1).

        Returns:
            dict with keys ``'mixture_probs'``, ``'mus'`` and ``'sigmas'``; each
            value is the per-step head outputs concatenated along dim 1.
        """
        obs_embed = self.embed_obs(obs)
        so_far_decoded = self.start.expand(obs_embed.shape)
        all_mu, all_sigma, all_mix_p = [], [], []
        for _ in range(self.num_components):
            # No target mask is passed: the whole prefix is re-decoded each step
            # and only the final position is consumed.
            newest = self.trans_dec(so_far_decoded, obs_embed)[:, -1:, :]
            all_mu.append(self.decode_mu(newest))
            all_sigma.append(self.decode_sigma(newest))
            all_mix_p.append(self.decode_mixture_prob(newest))
            # Detach only the fed-back token. Detaching the whole concatenation
            # would also cut the learnable start token out of the graph for
            # every later step.
            so_far_decoded = torch.cat([so_far_decoded, newest.detach()], dim=-2)
        return {'mixture_probs': torch.cat(all_mix_p, dim=1),
                'mus': torch.cat(all_mu, dim=1),
                'sigmas': torch.cat(all_sigma, dim=1)}
        

In [117]:
# Smoke test: build an untrained posterior and push a random batch of two observations through it.
q_z = SeqGaussMixPosterior()
obs = torch.rand(2, 1, 1)
out = q_z(obs)['mixture_probs']

# Only the final expression of a cell is rendered, so the module repr is what gets shown.
q_z

SeqGaussMixPosterior(
  (embed_obs): MLP(
    (network): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1, out_features=10, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=10, out_features=10, bias=True)
        (1): ReLU()
      )
      (2): Sequential(
        (0): Linear(in_features=10, out_features=10, bias=True)
      )
    )
  )
  (decode_mu): MLP(
    (network): Sequential(
      (0): Sequential(
        (0): Linear(in_features=10, out_features=12, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=12, out_features=1, bias=True)
      )
    )
  )
  (decode_sigma): MLP(
    (network): Sequential(
      (0): Sequential(
        (0): Linear(in_features=10, out_features=12, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=12, out_features=1, bias=True)
      )
    )
  )
  (decode_mixture_prob): MLP(
    (network): Sequential(
     

In [132]:
def gen_batch(batch_size=100):
    """Draw a synthetic regression batch whose targets depend on the sign of the observation.

    Observations are sampled from a standard normal. For every sample the
    targets are chosen by whether the observation is ``<= 0``:

    ==============  ==========  =========
    target          obs <= 0    obs > 0
    ==============  ==========  =========
    mu (comp. 1)    1.0         -1.0
    mu (comp. 2)    2.0         -2.0
    mix_p (comp.1)  0.4         0.6
    mix_p (comp.2)  0.3         0.7
    ==============  ==========  =========

    ``sigma`` is 1 everywhere.

    Args:
        batch_size: number of samples to draw. Defaults to 100.

    Returns:
        Tuple ``(obs, mu, sigma, mix_p)`` with shapes ``(batch_size, 1, 1)``,
        ``(batch_size, 2, 1)``, ``(batch_size, 2, 1)`` and ``(batch_size, 2, 1)``.
    """
    obs = torch.randn(batch_size, 1, 1)
    non_positive = obs <= 0

    def by_sign(if_non_positive, otherwise):
        # Per-sample constant selected by the sign mask; same shape and dtype as obs.
        return torch.where(non_positive,
                           torch.full_like(obs, if_non_positive),
                           torch.full_like(obs, otherwise))

    mu = torch.cat([by_sign(1.0, -1.0), by_sign(2.0, -2.0)], dim=1)
    sigma = torch.ones(batch_size, 2, 1)
    mix_p = torch.cat([by_sign(0.4, 0.6), by_sign(0.3, 0.7)], dim=1)
    return obs, mu, sigma, mix_p

In [140]:
# Fit a freshly initialised SeqGaussMixPosterior to the synthetic targets from gen_batch().
q_z_given_obs = SeqGaussMixPosterior()
num_iterations = 1000
# Adam over all model parameters; no optimizer hyperparameters are overridden.
optim = torch.optim.Adam(q_z_given_obs.parameters())
# One detached loss tensor is appended per iteration.
losses = []
for _ in range(num_iterations):
    # A new random batch is drawn on every iteration.
    obs, mu, sigma, mix_p = gen_batch()
    out = q_z_given_obs(obs)
    # Each head is regressed onto its target with MSE; the three terms are summed unweighted.
    loss_mu = nn.functional.mse_loss(out['mus'], mu)
    loss_p = nn.functional.mse_loss(out['mixture_probs'], mix_p)
    loss_sig = nn.functional.mse_loss(out['sigmas'], sigma)
    loss = loss_mu + loss_p + loss_sig
    # Detach before storing so the history does not hold on to the autograd graph.
    losses.append(loss.detach())
    loss.backward()
    
    # Gradients are cleared right after the update, so the next backward() starts clean.
    optim.step()
    optim.zero_grad()

In [141]:
# Training was successful on this toy example: probe the fitted model with one
# positive and one negative observation and print the predicted parameters.
for probe in (1.0, -1.0):
    print(q_z_given_obs(torch.tensor([[[probe]]])))

{'mixture_probs': tensor([[[0.6144],
         [0.6933]]], grad_fn=<CatBackward0>), 'mus': tensor([[[-1.0640],
         [-2.0548]]], grad_fn=<CatBackward0>), 'sigmas': tensor([[[0.9997],
         [0.9975]]], grad_fn=<CatBackward0>)}
{'mixture_probs': tensor([[[0.3534],
         [0.3280]]], grad_fn=<CatBackward0>), 'mus': tensor([[[0.9773],
         [2.0022]]], grad_fn=<CatBackward0>), 'sigmas': tensor([[[1.0018],
         [1.0028]]], grad_fn=<CatBackward0>)}
