<a href="https://colab.research.google.com/github/ErinZhang1998/sketch_collection/blob/master/primitive_selector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install ftfy, regex and tqdm, then OpenAI CLIP straight from its GitHub repository (no version pin).
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

In [None]:
# wandb is imported further down for experiment logging.
# NOTE(review): svgwrite / CairoSVG are not imported in the visible cells — presumably
# needed by the sketch_collection modules; confirm before removing.
! pip install svgwrite 
! pip install CairoSVG
! pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting svgwrite
  Downloading svgwrite-1.4.2-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 4.9 MB/s 
[?25hInstalling collected packages: svgwrite
Successfully installed svgwrite-1.4.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting CairoSVG
  Downloading CairoSVG-2.5.2-py3-none-any.whl (45 kB)
[K     |████████████████████████████████| 45 kB 2.4 MB/s 
Collecting cssselect2
  Downloading cssselect2-0.6.0-py3-none-any.whl (15 kB)
Collecting cairocffi
  Downloading cairocffi-1.3.0.tar.gz (88 kB)
[K     |████████████████████████████████| 88 kB 3.4 MB/s 
Building wheels for collected packages: cairocffi
  Building wheel for cairocffi (setup.py) ... [?25l[?25hdone
  Created wheel for cairocffi: filename=cairocffi-1.3.0-py3-none-any.whl size=89668 sha256=76113ad9ae2a615679d7986ae42777e1837d5d07be1bc

In [1]:
from google.colab import drive

# Mount Google Drive at /content/gdrive; every data and checkpoint path below lives under it.
drive.mount('/content/gdrive')


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
%cd gdrive/MyDrive

/content/gdrive/MyDrive


In [None]:
# !git clone https://github.com/ErinZhang1998/sketch_collection.git

In [None]:
!ls

'Colab Notebooks'	      pleasant-tree-10.pt
'Copy of 简历修改建议.gdoc'   primitive_selector_training_data
 doodler_model_checkpoint     sketch_collection
'Getting started.pdf'	      wandb


In [3]:
import sys

# Make the cloned sketch_collection checkout importable
# (presumably where `read_datasets` lives — TODO confirm).
# Guarded so that re-running the cell does not pile up duplicate sys.path entries.
path_to_module = '/content/gdrive/MyDrive/sketch_collection'
if path_to_module not in sys.path:
    sys.path.append(path_to_module)

In [None]:
# %mkdir doodler_model_checkpoint

In [4]:
import pandas as pd
import read_datasets as rd
import numpy as np 
import pickle
from collections import defaultdict
import wandb 
wandb.login()
import argparse

import os

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.distributions import Normal, OneHotCategorical

[34m[1mwandb[0m: Currently logged in as: [33merinz[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Load the combined dataset. The context manager closes the file handle that a bare
# pickle.load(open(...)) would leak.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
with open("/content/gdrive/MyDrive/primitive_selector_training_data/july_13_all.pkl", "rb") as all_data_file:
    all_data = pickle.load(all_data_file)
# df = pd.DataFrame.from_dict(all_data, orient='index')
df = pd.DataFrame(all_data)

# Training with LSTM

In [6]:
def preprocess_dataset_language(path):
    """Build a word -> index vocabulary from a pickled dataset.

    Parameters
    ----------
    path : str
        Pickle file holding an iterable of records; each record must provide a
        ``'processed'`` string containing the description text.

    Returns
    -------
    collections.defaultdict
        Token -> integer id, assigned in first-seen order. ``"<pad>"`` is
        always 0 and ``"<unk>"`` is always 1. Tokens are produced by
        lower-casing, stripping and splitting the description on single spaces.

        The returned mapping keeps its default factory, so *indexing* an
        unseen word registers it with a fresh id. Test membership first
        (``word in vocab``) when the vocabulary size must stay fixed.
    """
    # NOTE(review): unpickling can execute arbitrary code — only load trusted files.
    with open(path, "rb") as f:
        data_raw = pickle.load(f)

    q2i = defaultdict(lambda: len(q2i))
    # Looking a key up is what assigns its id, so touch the special tokens
    # first to reserve ids 0 and 1.
    for special in ("<pad>", "<unk>"):
        q2i[special]

    for info in data_raw:
        for token in info['processed'].lower().strip().split(" "):
            q2i[token]
    return q2i

In [None]:
# Vocabulary over the combined ("all") dataset file.
# NOTE(review): later visible cells build and use `vocab` (from the train split) instead;
# `q2i` is not referenced again in the visible notebook.
q2i = preprocess_dataset_language(
    "/content/gdrive/MyDrive/primitive_selector_training_data/july_13_all.pkl"
)

In [7]:
def collate_primitivedataset(seq_list):
    """Collate dataset samples into a batch ordered longest description first.

    Parameters
    ----------
    seq_list : sequence of (description, primitive_type, affine_params)
        Samples as produced by ``PrimitiveDataset.__getitem__``.

    Returns
    -------
    tuple
        ``(descriptions, primitive_types, affine_params)`` where
        ``descriptions`` is a list of the (variable-length) description
        tensors and the other two are the per-sample tensors stacked along a
        new leading batch axis, all in the same length-descending order.
    """
    descs, types, affines = zip(*seq_list)

    # Stable sort on descending length: equally long samples keep their
    # original relative order.
    order = sorted(range(len(descs)), key=lambda idx: len(descs[idx]), reverse=True)

    def reorder(items):
        return [items[idx] for idx in order]

    # (N, 1) (N, num_transformation_params)
    return reorder(descs), torch.stack(reorder(types)), torch.stack(reorder(affines))


### Dataset

In [8]:
class PrimitiveDataset(Dataset):
    """Samples of (description token ids, primitive type, affine parameters).

    Each record of the pickled dataset must provide ``'processed'`` (the
    description text), ``'primitive_type'`` and ``'M'`` (an array exposing
    ``reshape``).
    """

    def __init__(self, path, vocab, num_transformation_params, image_size=256.):
        """
        Parameters
        ----------
        path : str
            Pickle file with an indexable collection of records.
        vocab : mapping
            Word -> integer id; words that are not keys are dropped.
        num_transformation_params : int
            Stored on the instance. The active ``__getitem__`` does NOT use it:
            it always returns the two entries at flat indices 2 and 5 of ``M``.
        image_size : float
            Stored on the instance; not used by the methods of this class.
        """
        super().__init__()
        self.path = path
        self.image_size = image_size

        # Context manager so the file handle is released once the data is in memory.
        # NOTE(review): unpickling can execute arbitrary code — only load trusted files.
        with open(self.path, "rb") as f:
            self.data_raw = pickle.load(f)

        self.vocab = vocab
        # Keys view: O(1) membership tests without triggering a defaultdict factory.
        self.vocab_keys = vocab.keys()
        self.original_image_size = 256.
        self.num_transformation_params = num_transformation_params

    def __len__(self):
        """Number of records in the pickled dataset."""
        return len(self.data_raw)

    def __getitem__(self, index):
        """Return ``(description_t, primitive_type, affine_params)`` for one record.

        Returns
        -------
        description_t : torch.LongTensor, shape (num_known_words,)
            Vocabulary ids of the space-separated, lower-cased words that are
            present in the vocabulary.
        primitive_type : torch.LongTensor, shape (1,)
        affine_params : torch.FloatTensor, shape (2,)
            Entries 2 and 5 of the flattened ``M`` (presumably the translation
            terms of a row-major affine matrix — TODO confirm).

        Raises
        ------
        RuntimeError
            If the record has no ``'M'`` entry.
        """
        # Process language input
        info = self.data_raw[index]
        description = info['processed']
        description_t = [self.vocab[x.lower()] for x in description.split(" ") if x.lower() in self.vocab_keys]
        description_t = torch.from_numpy(np.array(description_t)).long()

        # Process M (num_transformation_params,) and type
        primitive_type = torch.from_numpy(np.array([info['primitive_type']])).long()
        if 'M' not in info:
            # The original bare `raise` also surfaced as a RuntimeError, just
            # without any hint about which record was at fault.
            raise RuntimeError(f"Record {index} in {self.path} has no 'M' (affine matrix) entry")

        # affine_params = info['M'].reshape(-1,)[:self.num_transformation_params]
        affine_params = info['M'].reshape(-1,)[[2,5]]
        affine_params = torch.FloatTensor(affine_params)
        return description_t, primitive_type, affine_params


### Model

In [56]:
class PrimitiveSelector(nn.Module):
    """LSTM text encoder with a primitive-type head and one GMM head per parameter.

    Expected ``hp`` attributes: ``vocab_size``, ``word_embed_dim``,
    ``lstm_output_dim``, ``lstm_layers``, ``lstm_drop_prob``,
    ``num_primitives``, ``num_transformation_params`` and ``M`` (number of
    mixture components per transformation parameter).
    """

    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        self.embed = nn.Embedding(hp.vocab_size, hp.word_embed_dim)
        self.lstm = nn.LSTM(
            input_size = hp.word_embed_dim,
            hidden_size = hp.lstm_output_dim,
            num_layers = hp.lstm_layers,
            dropout = hp.lstm_drop_prob,
        )

        self.primitive_fc = nn.Linear(hp.lstm_output_dim, hp.num_primitives)
        # Not used by the active forward pass; kept so existing attribute access still works.
        self.num_normal_param = 3
        # Per parameter: M means followed by M raw (pre-activation) scales.
        self.gmm_network = nn.Linear(hp.lstm_output_dim, hp.num_transformation_params * 2 * 1 * hp.M)
        # Per parameter: M unnormalized mixture logits.
        self.pi_network = nn.Linear(hp.lstm_output_dim, hp.num_transformation_params * hp.M)

    def forward(self, question):
        """Encode a batch of descriptions and predict type logits and mixtures.

        Parameters
        ----------
        question : torch.nn.utils.rnn.PackedSequence
            Packed batch of word-id sequences.

        Returns
        -------
        prim_pred : torch.Tensor, shape (N, num_primitives)
            Unnormalized primitive-type logits.
        normal_dists : list of torch.distributions.Normal
            One per transformation parameter; ``loc`` and ``scale`` have shape
            (N, M, 1).
        pi_dists : list of torch.distributions.OneHotCategorical
            One per transformation parameter, over the M mixture components.
        """
        # Embed the packed token ids in place instead of padding and re-packing.
        # This gives the same result for length-sorted batches and also works
        # for batches packed with enforce_sorted=False, because the sort
        # indices are carried over.
        embedded = rnn.PackedSequence(
            self.embed(question.data),
            question.batch_sizes,
            question.sorted_indices,
            question.unsorted_indices,
        )
        _, (hidden, _) = self.lstm(embedded, None)
        seq_last_layer = hidden[-1] # N x lstm_output_dim
        prim_pred = self.primitive_fc(seq_last_layer)

        batch = seq_last_layer.shape[0]
        num_params, num_mix = self.hp.num_transformation_params, self.hp.M
        # Column layout per parameter is [mean_0..mean_{M-1}, sd_0..sd_{M-1}].
        params = self.gmm_network(seq_last_layer).view(batch, num_params, 2, num_mix)
        pis = self.pi_network(seq_last_layer).view(batch, num_params, num_mix)

        normal_dists, pi_dists = [], []
        for i in range(num_params):
            mean = params[:, i, 0, :].unsqueeze(-1)   # (N, M, 1)
            raw_sd = params[:, i, 1, :].unsqueeze(-1) # (N, M, 1)
            # elu(x) + 1 is strictly positive; the epsilon keeps the scale away from 0.
            normal_dists.append(Normal(mean, F.elu(raw_sd) + 1 + 1e-7))
            pi_dists.append(OneHotCategorical(logits=pis[:, i, :]))

        return prim_pred, normal_dists, pi_dists

    # def forward(self, question): # question: PackedSequence 
    #     seq_tensor, seq_lengths = rnn.pad_packed_sequence(question, batch_first=True)               
    #     embedded_seq_tensor = self.embed(seq_tensor)
    #     seq_packed = rnn.pack_padded_sequence(
    #         torch.transpose(embedded_seq_tensor,0,1), 
    #         seq_lengths)
    #     _, (hidden,_) = self.lstm(seq_packed, None)
    #     seq_last_layer = hidden[-1] # N x hidden_embed_dim
    #     # print(hidden.shape, seq_last_layer.shape) #torch.Size([32, 512])
    #     prim_pred = self.primitive_fc(seq_last_layer) 
    #     # N x (self.num_normal_param * M * num_transformation_params)
    #     prim_param_pred = self.affine_fc(seq_last_layer) 
    #     # print(prim_pred.shape, prim_param_pred.shape) torch.Size([32, 5]) torch.Size([32, 36])
    #     # [N x (num_normal_param * M)]
    #     each_prim_param = torch.split(prim_param_pred, self.num_normal_param * self.hp.M, 1) 
    #     # print([x.shape for x in each_prim_param]) [torch.Size([32, 6])]
    #     pi_list = [] # length num_transformation_params
    #     mu_list = []
    #     sigma_list = []
    #     for y in each_prim_param: # N x (num_normal_param * M)
    #         params = torch.split(y, self.num_normal_param, 1) # N x self.num_normal_param
    #         # print([x.shape for x in params]) # [torch.Size([32, 3])]
    #         params_mixture = torch.stack(params) # M x N x self.num_normal_param
    #         # print(params_mixture.shape) # torch.Size([2, 32, 3])

    #         pi, mu, sigma = torch.split(params_mixture, 1, 2) # M x N x 1
    #         pi = F.softmax(pi.transpose(0,1).squeeze(), dim=-1) # N x M
    #         mu = mu.transpose(0,1).squeeze().contiguous()
    #         sigma = torch.exp(sigma.transpose(0,1).squeeze())
            
    #         pi_list.append(pi)
    #         mu_list.append(mu)
    #         sigma_list.append(sigma)
        
    #     return prim_pred, pi_list, mu_list, sigma_list

### Trainer

In [53]:
class Trainer():
    """Trains a ``PrimitiveSelector`` on primitive-type and affine-parameter targets.

    The optimized objective is the cross-entropy on the primitive type plus
    the sum over transformation parameters of the mixture negative
    log-likelihood computed by :meth:`loss`.
    """

    def __init__(self, train_dataset, val_dataset, hp, args):
        """Set up logging, the checkpoint folder, data loaders, model and optimizer.

        Parameters
        ----------
        train_dataset, val_dataset : torch.utils.data.Dataset
            Datasets compatible with ``collate_primitivedataset``.
        hp : object
            Hyperparameters; reads ``batch_size``, ``lr``, ``weight_decay``,
            ``print_every``, ``save_every`` and ``M`` (plus whatever
            ``PrimitiveSelector`` needs). ``hp.__dict__`` is sent to wandb.
        args : object
            Run settings; reads ``enable_wandb``, ``wandb_project_name``,
            ``wandb_project_entity``, ``save_root_folder``, ``num_workers``,
            ``start_epoch`` and ``num_epochs``.
        """
        self.hp = hp
        self.args = args

        if args.enable_wandb:
            wandb.init(project=args.wandb_project_name, entity=args.wandb_project_entity, config=hp.__dict__)

        self.enable_wandb = args.enable_wandb and wandb.run is not None
        if self.enable_wandb:
            self.run_name = wandb.run.name
        else:
            # No wandb run to name the checkpoints after: use a timestamp.
            import datetime
            self.run_name = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')

        self.save_folder = os.path.join(args.save_root_folder, self.run_name)
        # makedirs also creates a missing save_root_folder, which os.mkdir would not.
        os.makedirs(self.save_folder, exist_ok=True)

        self.train_dataset_loader = DataLoader(
            train_dataset,
            batch_size=hp.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=collate_primitivedataset)
        self.val_dataset_loader = DataLoader(
            val_dataset,
            batch_size=hp.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_primitivedataset)

        # Stays the string "cuda" on GPU runtimes; falls back to CPU instead of crashing.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = PrimitiveSelector(hp).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=hp.lr, weight_decay=hp.weight_decay)
        self.ce_loss = nn.CrossEntropyLoss()

    def make_target(self, affine_paramss):
        """Create ground truth for training transformation parameters by stacking M copies of each parameter

        Not used by the active training path (kept for the legacy ``log_losses``).

        Parameters
        ----------
        affine_paramss : torch.Tensor
            (N, num_transformation_params)

        Returns
        -------
        list of torch.Tensor
            GT for calculating log likelihood loss, each has shape (N, M)
            list of size num_transformation_params
        """
        return [
            torch.stack([affine_paramss[:,i]] * self.hp.M, 1) for i in range(affine_paramss.shape[1])
        ]

    def normal_pdf(self, x, mu, sigma):
        """Calculate univariate normal pdf for GMM

        Parameters
        ----------
        x : torch.Tensor
            (N, M)
        mu : torch.Tensor
            (N, M)
            predicted GMM means
        sigma : torch.Tensor
            (N, M)
            predicted GMM standard deviation

        Returns
        -------
        pdf : torch.Tensor
            (N, M)
            predicted probability
        """
        z = ( (x - mu) / sigma ) ** 2
        exp = torch.exp(-z / 2.0)
        norm = np.sqrt(2.0 * np.pi) * sigma
        pdf = exp / norm
        return pdf

    def log_losses(self, params_gt_list, pi_list, mu_list, sigma_list, epoch):
        """Legacy mixture loss computed from explicit pi / mu / sigma tensors.

        Not called by :meth:`train`, which uses :meth:`loss`. Unlike
        :meth:`loss`, this SUMS the negative log-likelihood over the batch.

        Args:
            params_gt_list : list of torch.Tensor (N, M)
                a list of GT transformation parameters
            pi_list : list of torch.Tensor (N, M)
                weights for combining the normal pdf in GMM
            mu_list : list of torch.Tensor (N, M)
                mean of GMM
            sigma_list : list of torch.Tensor (N, M)
                standard deviation of GMM
            epoch : int
                only used to dump debug tensors when it equals 1
        Returns:
            losses : list of scalars
                each scalar is the log loss across the entire batch for one transformation parameter
        """
        losses = []
        for param_idx,(param,pi,mu,sigma) in enumerate(zip(params_gt_list, pi_list, mu_list, sigma_list)):
            pdf = self.normal_pdf(param, mu, sigma)
            gmm_pdf = torch.sum(pi * pdf, 1)
            # The 1e-5 floor keeps log() finite when the mixture density underflows.
            log_prob = torch.log(1e-5 + gmm_pdf) # (N,)
            loss = -torch.sum(log_prob)
            losses.append(loss)

            if epoch == 1:
                #print(f"{param_idx} pi: ",pi)
                print(f"{param_idx} param: ",param)
                print(f"{param_idx} mu: ",mu)
                print(f"{param_idx} sigma: ",sigma)
                print(f"{param_idx} pdf: ", pdf)
                print(f"{param_idx} gmm: ", gmm_pdf)

        return losses

    def loss(self, y_list, normal_dists, pi_dists, epoch):
        """Mean mixture negative log-likelihood, one scalar per parameter.

        Parameters
        ----------
        y_list : list of torch.Tensor
            Targets, each of shape (N, 1).
        normal_dists : list of torch.distributions.Normal
            Component distributions whose ``loc`` has shape (N, M, 1).
        pi_dists : list of torch.distributions.OneHotCategorical
            Mixture weights over the M components.
        epoch : int
            Unused; kept for signature compatibility.

        Returns
        -------
        list of torch.Tensor
            One 0-dim loss per transformation parameter (batch mean).
        """
        losses = []
        for y, normal_dist, pi_dist in zip(y_list, normal_dists, pi_dists):
            # Broadcast the target against every mixture component.
            ys = y.unsqueeze(1).expand_as(normal_dist.loc)
            loglik = torch.sum(normal_dist.log_prob(ys), dim=2) # (N, M)
            # log sum_m pi_m * N(y | mu_m, sigma_m), evaluated stably in log space.
            nll = -torch.logsumexp(pi_dist.logits + loglik, dim=1)
            losses.append(nll.mean())
        return losses

    def train(self):
        """Run the training loop for ``args.num_epochs`` epochs.

        Logs to wandb every step when enabled, otherwise prints every
        ``hp.print_every`` steps. A checkpoint is written every
        ``hp.save_every`` steps (including step 0).
        """
        self.model.train()
        step = 0
        last_epoch = self.args.start_epoch + self.args.num_epochs
        for epoch in range(self.args.start_epoch, last_epoch):

            for description_ts, primitive_types, affine_paramss in self.train_dataset_loader:
                description_ts = rnn.pack_sequence(description_ts).to(self.device)
                # (N, 1) -> (N,). A bare squeeze() would also drop the batch axis
                # of a final batch of size 1 and break the cross-entropy call.
                primitive_types = primitive_types.reshape(-1).to(self.device)
                affine_paramss = affine_paramss.to(self.device)
                params_gt_list = [
                    affine_paramss[:,i].view(-1,1) for i in range(affine_paramss.shape[1])
                ]

                prim_pred, normal_dists, pi_dists = self.model(description_ts)

                self.optimizer.zero_grad()

                cel = self.ce_loss(prim_pred, primitive_types)
                lls = self.loss(params_gt_list, normal_dists, pi_dists, epoch)
                total_ll = torch.stack(lls).sum()
                total_lls = total_ll + cel

                wandb_dict = {'prim_type_loss' : cel.item(), 'total_param_loss' : total_ll.item()}
                for idx, ll in enumerate(lls):
                    wandb_dict[f'param_{idx}_loss'] = ll.item()
                wandb_dict['total_loss'] = total_lls.item()

                total_lls.backward()
                self.optimizer.step()

                if self.enable_wandb:
                    wandb.log(wandb_dict, step=step)
                else:
                    if step % self.hp.print_every == 0:
                        print_s = [f"Epoch {epoch} Iter {step}: "]
                        for k,v in wandb_dict.items():
                            print_s.append(f"{k} : {v}")
                        print(" | ".join(print_s))

                if step % self.hp.save_every == 0:
                    self.save_model(step)

                step += 1

    # TODO: evaluation is not implemented; `val_dataset_loader` is built but never consumed.

    def save_model(self, step):
        """Write the model weights and the step counter to ``<save_folder>/<step>.pt``."""
        torch_path_name = os.path.join(self.save_folder, f"{step}.pt")

        torch.save({
            'iteration' : step,
            'model_state_dict': self.model.state_dict(),
        }, torch_path_name)


### main

In [23]:
class HParams():
    """Model and optimization hyperparameters.

    Every default below can be overridden by keyword, e.g.
    ``HParams(M=3, vocab_size=len(vocab))``; ``HParams()`` keeps the original
    defaults. All values are plain instance attributes, so ``hp.__dict__``
    (sent to wandb as the run config) contains exactly these fields.

    Raises
    ------
    TypeError
        If an override names a hyperparameter that does not exist.
    """
    def __init__(self, **overrides):
        self.word_embed_dim = 128      # embedding width / LSTM input size
        self.lstm_output_dim = 512     # LSTM hidden size
        self.lstm_layers = 2
        self.lstm_drop_prob = 0.4      # dropout passed to nn.LSTM
        self.num_primitives = 5        # number of primitive-type classes
        self.num_transformation_params = 2  # one mixture head per parameter
        self.vocab_size = None         # must be set before building the model
        self.M = 2                     # mixture components per parameter
        self.weight_decay = 0.0
        self.batch_size = 32
        self.lr = 0.001

        self.save_every = 500          # checkpoint interval, in steps
        self.print_every = 100         # console logging interval, in steps

        # Reject typos instead of silently creating an unused attribute.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown hyperparameter(s): {', '.join(unknown)}")
        vars(self).update(overrides)

class CommandParams():
    """Run-level settings: logging, epochs, and input/output locations.

    Parameters
    ----------
    drive_root : str
        Folder that contains ``doodler_model_checkpoint`` and
        ``primitive_selector_training_data``. Defaults to the mounted Colab
        Drive, so ``CommandParams()`` yields the original paths; pass another
        root to run outside Colab.
    """
    def __init__(self, drive_root="/content/gdrive/MyDrive"):
        data_dir = f"{drive_root}/primitive_selector_training_data"

        self.enable_wandb = False
        self.start_epoch = 0
        self.num_epochs = 50
        self.num_workers = 0
        self.wandb_project_name = "doodler-draw"
        self.wandb_project_entity = "erinz"
        self.save_root_folder = f"{drive_root}/doodler_model_checkpoint"
        self.train_file = f"{data_dir}/july_13_train.pkl"
        self.val_file = f"{data_dir}/july_13_val.pkl"
        self.test_file = f"{data_dir}/july_13_test.pkl"
        # self.word_file = "/content/gdrive/MyDrive/primitive_selector_training_data"

def get_args():
    """Return a fresh ``CommandParams`` holding the default run settings."""
    return CommandParams()

In [24]:
# Default run settings and model hyperparameters; hp.vocab_size is still None at this point.
args = get_args()
hp = HParams()

In [26]:
# The vocabulary is built from the training split only.
vocab = preprocess_dataset_language(args.train_file)

In [27]:
# The embedding table in PrimitiveSelector is sized from hp.vocab_size, so set it before building the model.
hp.vocab_size = len(vocab)

In [16]:
train_dataset = PrimitiveDataset(args.train_file, vocab, hp.num_transformation_params)
# The validation set must come from the validation split; this previously read args.test_file.
val_dataset = PrimitiveDataset(args.val_file, vocab, hp.num_transformation_params)

In [57]:
# Builds the data loaders, model, optimizer and the per-run checkpoint folder (see Trainer.__init__).
trainer = Trainer(train_dataset, val_dataset, hp, args)

In [58]:
# Runs args.num_epochs epochs; with wandb disabled, losses are printed every hp.print_every steps.
trainer.train()

Epoch 0 Iter 0:  | prim_type_loss : 1.6124851703643799 | total_param_loss : 12546.390625 | param_0_loss : 6198.1083984375 | param_1_loss : 6348.28173828125 | total_loss : 12548.0029296875
Epoch 0 Iter 100:  | prim_type_loss : 1.4718098640441895 | total_param_loss : 47.326602935791016 | param_0_loss : 22.864288330078125 | param_1_loss : 24.46231460571289 | total_loss : 48.79841232299805
Epoch 0 Iter 200:  | prim_type_loss : 1.3185399770736694 | total_param_loss : 28.102401733398438 | param_0_loss : 13.026363372802734 | param_1_loss : 15.076038360595703 | total_loss : 29.420942306518555
Epoch 1 Iter 300:  | prim_type_loss : 1.4346470832824707 | total_param_loss : 17.4107608795166 | param_0_loss : 9.459742546081543 | param_1_loss : 7.951018810272217 | total_loss : 18.845407485961914
Epoch 1 Iter 400:  | prim_type_loss : 1.1037076711654663 | total_param_loss : 13.532354354858398 | param_0_loss : 6.784598350524902 | param_1_loss : 6.747755527496338 | total_loss : 14.636061668395996
Epoch 1 

# CLIP

In [None]:
import clip
import os

# Fall back to CPU so this cell also runs on a runtime without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False) #Must set jit=False for training

root_folder = '/content/gdrive/MyDrive'
torch_path_name = os.path.join(root_folder, "pleasant-tree-10.pt")
# map_location lets a checkpoint saved from GPU tensors load on whichever device is active.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(torch_path_name, map_location=device)
print(checkpoint.keys())

# Overwrite the pretrained CLIP weights with the fine-tuned ones stored in the checkpoint.
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
class TextDataset(Dataset):
    """Serves the entries of a DataFrame's ``'raw'`` column one item at a time."""

    def __init__(self, df):
        # Materialize the column once as a plain Python list.
        raw_column = df['raw']
        self.text = raw_column.to_list()

    def __len__(self):
        """Number of entries taken from the ``'raw'`` column."""
        n_items = len(self.text)
        return n_items

    def __getitem__(self, index):
        """Return the entry stored at position ``index``."""
        item = self.text[index]
        return item
# Wrap the `raw` column of `df`; shuffle=False keeps the batches in DataFrame row order.
text_dataset = TextDataset(df)
text_loader = DataLoader(text_dataset, batch_size = 16, shuffle=False) 

In [None]:
model.eval()
word_features = []

# Tokenize and encode each batch of raw descriptions with CLIP's text encoder,
# without tracking gradients.
with torch.no_grad():
    for list_txt in text_loader:
        texts = clip.tokenize(list_txt).to(device)
        feat = model.encode_text(texts)
        # One tensor per batch is collected as returned; this cell neither
        # concatenates the batches nor saves them to disk.
        word_features.append(feat)

# Pretrained text embeddings

In [None]:
class PrimitiveDataset(Dataset):
    """Samples of (precomputed word features, primitive type, affine parameters).

    Variant of the dataset that reads per-description word features from a
    second pickle instead of mapping words through a vocabulary.
    """

    def __init__(self, path, word_path, vocab, num_transformation_params, image_size=256.):
        """
        Parameters
        ----------
        path : str
            Pickle file with an indexable collection of records; each needs
            ``'primitive_type'`` and ``'M'``.
        word_path : str
            Pickle file with an indexable collection holding, per record
            index, that record's word features.
        vocab : any
            Unused in this variant; kept for signature compatibility.
        num_transformation_params : int
            Number of leading entries of the flattened ``M`` to return.
        image_size : float
            Unused in this variant; kept for signature compatibility.
        """
        super().__init__()
        self.path = path

        # NOTE(review): unpickling can execute arbitrary code — only load trusted files.
        with open(self.path, "rb") as f:
            self.data_raw = pickle.load(f)

        # Fixed: "rb" was being passed to pickle.load() instead of open(), so the
        # file was opened in text mode and the call raised a TypeError.
        with open(word_path, "rb") as f:
            self.word_data = pickle.load(f)

        # self.vocab = vocab
        # self.vocab_keys = vocab.keys()
        # self.original_image_size = 256.
        self.num_transformation_params = num_transformation_params

    def __len__(self):
        """Number of records in the pickled dataset."""
        return len(self.data_raw)

    def __getitem__(self, index):
        """Return ``(description_t, primitive_type, affine_params)`` for one record.

        Raises
        ------
        RuntimeError
            If the record has no ``'M'`` entry.
        """
        # Process language input
        info = self.data_raw[index]
        word_arr = self.word_data[index] # (len, word_dim)
        description_t = torch.FloatTensor(word_arr)

        # Process M (num_transformation_params,) and type
        primitive_type = torch.from_numpy(np.array([info['primitive_type']])).long()
        if 'M' not in info:
            # The original bare `raise` also surfaced as a RuntimeError, just
            # without any hint about which record was at fault.
            raise RuntimeError(f"Record {index} in {self.path} has no 'M' (affine matrix) entry")

        affine_params = info['M'].reshape(-1,)[:self.num_transformation_params]
        affine_params = torch.FloatTensor(affine_params)
        return description_t, primitive_type, affine_params

In [None]:
def collate_primitivedataset(seq_list):
    """Batch samples with the longest description first.

    Behaves exactly like the ``collate_primitivedataset`` defined earlier in
    this notebook; re-running this cell simply rebinds the name.

    Parameters
    ----------
    seq_list : sequence of (description, primitive_type, affine_params)

    Returns
    -------
    tuple
        A list of description tensors, the stacked primitive types and the
        stacked affine parameters, all in the same length-descending order.
    """
    descs, types, affines = zip(*seq_list)
    lengths = [len(desc) for desc in descs]

    # Indices ordered by decreasing description length (stable for ties).
    by_length = sorted(range(len(lengths)), key=lambda pos: lengths[pos], reverse=True)

    ordered_descs = []
    ordered_types = []
    ordered_affines = []
    for pos in by_length:
        ordered_descs.append(descs[pos])
        ordered_types.append(types[pos])
        ordered_affines.append(affines[pos])

    # (N, 1) (N, num_transformation_params)
    return ordered_descs, torch.stack(ordered_types), torch.stack(ordered_affines)


In [None]:
class PrimitiveSelector(nn.Module):
    """LSTM over precomputed word features with a type head and a GMM head.

    Variant without an embedding layer: the input sequences already contain
    feature vectors of width ``hp.word_embed_dim``.

    Expected ``hp`` attributes: ``word_embed_dim``, ``lstm_output_dim``,
    ``lstm_layers``, ``lstm_drop_prob``, ``num_primitives``,
    ``num_transformation_params`` and ``M`` (mixture components per parameter).
    """

    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        # self.embed = nn.Embedding(hp.vocab_size, hp.word_embed_dim)
        self.lstm = nn.LSTM(
            input_size = hp.word_embed_dim,
            hidden_size = hp.lstm_output_dim,
            num_layers = hp.lstm_layers,
            dropout = hp.lstm_drop_prob,
        )

        self.primitive_fc = nn.Linear(hp.lstm_output_dim, hp.num_primitives)
        # Each mixture component is described by (pi, mu, sigma).
        self.num_normal_param = 3
        self.affine_fc = nn.Linear(hp.lstm_output_dim, self.num_normal_param * hp.M * hp.num_transformation_params)

    def forward(self, question): # question: PackedSequence
        """Predict primitive-type logits and per-parameter mixture parameters.

        Parameters
        ----------
        question : torch.nn.utils.rnn.PackedSequence
            Packed batch of float feature sequences.

        Returns
        -------
        prim_pred : torch.Tensor, shape (N, num_primitives)
        pi_list, mu_list, sigma_list : list of torch.Tensor
            One (N, M) tensor per transformation parameter: softmax-normalized
            mixture weights, component means, and strictly positive standard
            deviations (exp of the raw output).
        """
        # The LSTM returns (output, (h_n, c_n)). Unpack h_n explicitly: indexing
        # the tuple with [-1] would pick c_n for all layers instead.
        _, (hidden, _) = self.lstm(question, None)
        seq_last_layer = hidden[-1] # N x lstm_output_dim

        prim_pred = self.primitive_fc(seq_last_layer)
        prim_param_pred = self.affine_fc(seq_last_layer) # N x (self.num_normal_param * M * num_transformation_params)

        # Column layout: parameter-major, then component, then (pi, mu, sigma).
        mixture = prim_param_pred.view(
            prim_param_pred.shape[0],
            self.hp.num_transformation_params,
            self.hp.M,
            self.num_normal_param,
        )

        pi_list = [] # length num_transformation_params
        mu_list = []
        sigma_list = []
        for i in range(self.hp.num_transformation_params):
            # Explicit indexing keeps the (N, M) shape even when N == 1 or
            # M == 1, which a bare squeeze() would collapse.
            pi_list.append(F.softmax(mixture[:, i, :, 0], dim=-1))
            mu_list.append(mixture[:, i, :, 1].contiguous())
            sigma_list.append(torch.exp(mixture[:, i, :, 2]))

        return prim_pred, pi_list, mu_list, sigma_list
