# 1 参数设定

In [4]:
import torch
import numpy as np


class Config():
    """Hyper-parameters and runtime settings shared by the data pipeline and the model.

    All values are plain instance attributes so a notebook can tweak them after
    construction. ``learning_rate`` and ``num_epochs`` are aliases of ``lr`` and
    ``epoch``: the training code reads the long names, and exposing them as
    properties keeps both spellings in sync when either one is reassigned.
    """

    def __init__(self):
        self.mode = 'train'  # 'train' or 'test'

        # Dataset description.
        # NOTE: the attribute name `data_lable` (sic) is kept as-is because
        # the loading code reads it under this spelling.
        self.data_lable = 'pig'
        self.data_path = './sketch_datas/'
        self.data_labels_list = ['ambulance', 'apple', 'bear', 'bicycle', 'bird',
                                 'bus', 'cat', 'foot', 'owl', 'pig']
        self.N_max = 0  # maximum sequence length; overwritten once the data is loaded

        # Encoder settings.
        self.encoder_dim = 256  # hidden size of the encoder LSTM (per direction)
        self.z_dim = 128  # dimension of the latent vector Z

        # Decoder settings.
        self.decoder_dim = 512  # hidden size of the decoder LSTM
        self.M = 20  # number of components in the Gaussian mixture

        # Training device: prefer CUDA, then Apple MPS, then CPU.
        # `torch.backends.mps` does not exist on older torch builds, so probe
        # for it instead of assuming it is there.
        mps_backend = getattr(torch.backends, 'mps', None)
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif mps_backend is not None and mps_backend.is_available():
            self.device = 'mps'
        else:
            self.device = 'cpu'

        self.batch_size = 16
        self.lr = 0.001
        self.epoch = 40
        self.W_kl = 0.5  # weight of the KL term in the total loss
        self.temperature = 0.4  # temperature parameter (original note: Gumbel-softmax)
        self.grad_clip = 1.0  # max gradient norm passed to clip_grad_norm_ by the training loop

    @property
    def learning_rate(self):
        """Alias of ``lr``."""
        return self.lr

    @learning_rate.setter
    def learning_rate(self, value):
        self.lr = value

    @property
    def num_epochs(self):
        """Alias of ``epoch``."""
        return self.epoch

    @num_epochs.setter
    def num_epochs(self, value):
        self.epoch = value
        





# 2 数据
读入数据，并对数据进行预处理
数据集由配置文件决定

In [5]:
import random
from re import split

def get_max(arr):
    """Return the size of ``arr`` along axis 1.

    The loading code uses this as the per-sample sequence length of the
    dataset array.
    """
    dims = arr.shape
    return dims[1]
    

# Split a dataset along axis 0 according to relative weights (e.g. 7:1:2).
def split_data(arr, ratios):
    """Cut ``arr`` into consecutive chunks whose sizes follow ``ratios``.

    Each chunk size is the weighted share of ``arr.shape[0]`` truncated to an
    integer; the last chunk absorbs the rounding remainder so that the chunks
    together cover every row exactly once, in the original order.

    Returns a list with one slice of ``arr`` per entry in ``ratios``.
    """
    n_rows = arr.shape[0]
    sizes = [int(n_rows * weight / sum(ratios)) for weight in ratios]
    # Whatever the truncation left over goes to the final chunk.
    sizes[-1] = n_rows - sum(sizes[:-1])

    # Turn the sizes into running boundaries: [0, s0, s0+s1, ...].
    bounds = [0]
    for size in sizes:
        bounds.append(bounds[-1] + size)

    return [arr[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]

# Shuffle a dataset and group it into fixed-size batches.
def split_in_batch(arr, config):
    """Return ``arr`` regrouped as ``(num_batches, batch_size, *arr.shape[1:])``.

    Only the first ``num_batches * batch_size`` rows are used; the trailing
    rows that do not fill a whole batch are dropped. The kept rows are
    shuffled along axis 0 (on a copy, so ``arr`` itself is not modified)
    before being grouped.

    Args:
        arr: numpy array whose first axis indexes samples.
        config: object exposing ``batch_size``.

    Returns:
        A numpy array of batches. When ``arr`` has fewer rows than
        ``batch_size`` the result is empty with shape
        ``(0, batch_size, *arr.shape[1:])``.
    """
    batch_size = config.batch_size
    batch_num = arr.shape[0] // batch_size
    usable_rows = batch_num * batch_size  # rows that fit into whole batches

    shuffled_arr = arr[:usable_rows].copy()
    np.random.shuffle(shuffled_arr)

    # Reshaping groups consecutive rows exactly like splitting into
    # `batch_num` equal parts and stacking them, but it also works when
    # batch_num == 0, where np.split would raise.
    return shuffled_arr.reshape((batch_num, batch_size) + arr.shape[1:])


In [11]:
# Notebook cell: load the dataset named in the config, split it into
# train/val/test and group every split into batches.
config = Config() # initialise the run configuration


# The dataset file is <data_path><data_lable>.npy
f_path = config.data_path + config.data_lable + '.npy'
f = np.load(f_path)
config.N_max = get_max(f) # size of axis 1 = maximum sequence length (original note: the data is already padded)
train_data, val_data, test_data = split_data(f,[7,1,2])

# Each split is shuffled and regrouped into batches of config.batch_size.
train_data = split_in_batch(train_data, config)
val_data = split_in_batch(val_data, config)
test_data = split_in_batch(test_data, config)


# Report what was loaded.
print(config.data_lable, 'dataset is loaded')
print('N_max:',config.N_max,', batch_size:',config.batch_size)
print('train_data:',train_data.shape)
print('val_data:',val_data.shape)
print('test_data:',test_data.shape)


pig dataset is loaded
N_max: 151 , batch_size: 16
train_data: (3062, 16, 151, 5)
val_data: (437, 16, 151, 5)
test_data: (875, 16, 151, 5)


# 3 模型搭建

In [7]:
import torch.nn as nn
import torch

# 2 

class encoderRNN(nn.Module):
    """Bidirectional LSTM encoder that maps a sequence to a latent sample.

    The final hidden states of both directions are concatenated and projected
    to a mean and a log-variance-like vector; the latent ``Z`` is drawn with
    the reparameterisation ``Z = mu + exp(sigma_hat / 2) * noise``.
    """

    def __init__(self, config):
        super(encoderRNN, self).__init__()
        # Remembered so forward() can size the initial states without a config.
        self.encoder_dim = config.encoder_dim
        self.lstm = nn.LSTM(5, config.encoder_dim, bidirectional=True)  # bidirectional LSTM
        self.fc_mu = nn.Linear(config.encoder_dim * 2, config.z_dim)
        self.fc_sigma = nn.Linear(config.encoder_dim * 2, config.z_dim)

    def forward(self, input, batchsize, config=None):
        """Encode a batch of sequences.

        Args:
            input: tensor of shape ``[seq_len, batchsize, 5]`` (the LSTM is
                built without ``batch_first``, so time is the first axis).
            batchsize: number of sequences in the batch.
            config: optional and unused; accepted so that both the
                two-argument and the three-argument call styles work.

        Returns:
            ``(Z, mu, sigma_hat)``, each of shape ``[batchsize, z_dim]``.
            ``sigma_hat`` is the raw projection, i.e. ``sigma = exp(sigma_hat / 2)``.
        """
        input = input.float()

        # Zero initial states, created on the same device/dtype as the input
        # (2 = one state per direction).
        h0 = input.new_zeros(2, batchsize, self.encoder_dim)
        c0 = input.new_zeros(2, batchsize, self.encoder_dim)

        _, (hidden, _) = self.lstm(input, (h0, c0))

        # (2, batch_size, hidden_size) -> (batch_size, 2 * hidden_size)
        hidden_cat = torch.cat([hidden[0], hidden[1]], dim=1)

        mu = self.fc_mu(hidden_cat)
        sigma_hat = self.fc_sigma(hidden_cat)
        sigma = torch.exp(sigma_hat / 2.0)

        # Reparameterisation trick with standard Gaussian noise.
        Z = mu + sigma * torch.randn_like(mu)

        return Z, mu, sigma_hat

# 2
class decoderRNN(nn.Module):
    """LSTM decoder that emits Gaussian-mixture and pen-state parameters per step.

    For every time step the decoder produces ``6 * M`` mixture parameters
    (weights, means, standard deviations and correlation of ``M`` bivariate
    Gaussians) plus 3 pen-state logits.
    """

    def __init__(self, config):
        super(decoderRNN, self).__init__()
        self.N_max = config.N_max  # kept for callers; forward() no longer depends on it
        self.M = config.M
        self.decoder_dim = config.decoder_dim

        # A full nn.LSTM (not an LSTMCell) because forward() consumes a whole
        # [seq_len, batch, features] sequence in one call.
        self.lstm = nn.LSTM(config.z_dim + 5, config.decoder_dim)
        # Projects z to the initial (hidden, cell) pair.
        self.fc_hc = nn.Linear(config.z_dim, 2 * config.decoder_dim)

        # Output heads: 6M mixture parameters + 3 pen-state logits.
        self.fc_output = nn.Linear(config.decoder_dim, 6 * config.M)
        self.fc_q = nn.Linear(config.decoder_dim, 3)

    def forward(self, config, input, z):
        """Decode a sequence conditioned on the latent vector ``z``.

        Args:
            config: unused; the sizes captured at construction are used
                instead. Kept for call compatibility.
            input: tensor of shape ``[seq_len, batch_size, 5 + z_dim]``.
            z: tensor of shape ``[batch_size, z_dim]``; initialises the LSTM state.

        Returns:
            ``(pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell)``.
            The six mixture tensors have shape ``[seq_len, batch_size, M]``:
            ``pi`` is softmax-normalised over the last axis, ``sigma_*`` are
            positive (``exp``), ``rho_xy`` lies in (-1, 1) (``tanh``).
            ``q`` holds the raw (un-normalised) pen-state logits of shape
            ``[seq_len, batch_size, 3]``. ``hidden`` and ``cell`` are the final
            LSTM states of shape ``[1, batch_size, decoder_dim]``.
        """
        # [batch_size, 2 * decoder_dim] -> two [batch_size, decoder_dim] halves
        h0, c0 = torch.split(torch.tanh(self.fc_hc(z)), self.decoder_dim, dim=1)
        # nn.LSTM expects states shaped [num_layers, batch_size, decoder_dim].
        hidden_cell = (h0.unsqueeze(0).contiguous(), c0.unsqueeze(0).contiguous())

        output, (hidden, cell) = self.lstm(input.float(), hidden_cell)

        # The 6M head splits into six [seq_len, batch_size, M] chunks along
        # the feature axis, so no further reshaping is needed.
        pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy = torch.split(
            self.fc_output(output), self.M, dim=2)
        q = self.fc_q(output)

        # Map the raw outputs onto their valid ranges.
        pi = nn.functional.softmax(pi, dim=-1)
        sigma_x = torch.exp(sigma_x)
        sigma_y = torch.exp(sigma_y)
        rho_xy = torch.tanh(rho_xy)

        return pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell



# 4 训练设定

In [8]:
from tqdm import tqdm
from torch import optim

# 1
class model(nn.Module):
    """Sequence VAE wrapper: owns the encoder/decoder pair, their optimisers
    and the training loop.

    Data assumptions (not verifiable from this class alone - TODO confirm
    against the preprocessing): every step is a 5-vector
    ``(dx, dy, p1, p2, p3)`` with a one-hot pen state, and padding rows carry
    ``p3 = 1`` so they can be masked out of the offset loss.
    """

    def __init__(self, config):
        super(model, self).__init__()
        self.encoder = encoderRNN(config).to(config.device)
        self.decoder = decoderRNN(config).to(config.device)

        learning_rate = self._config_value(config, ('lr', 'learning_rate'))
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=learning_rate)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=learning_rate)

    @staticmethod
    def _config_value(config, names, default=None):
        """Return the first attribute in ``names`` that ``config`` defines.

        Lets the class work with either spelling of a setting (for example
        ``lr`` / ``learning_rate``). Falls back to ``default``; raises
        ``AttributeError`` when nothing is found and no default is given.
        """
        for name in names:
            value = getattr(config, name, None)
            if value is not None:
                return value
        if default is None:
            raise AttributeError('config defines none of: ' + ', '.join(names))
        return default

    # KL loss
    def KL_loss(self, mu, sigma, config):
        """KL divergence between N(mu, exp(sigma)) and N(0, I).

        ``sigma`` is the encoder's raw ``sigma_hat`` output (treated as a log
        variance by this formula). The sum is normalised by
        ``config.z_dim * config.batch_size``.
        """
        return -0.5 * torch.sum(1 + sigma - mu.pow(2) - torch.exp(sigma)) / float(config.z_dim * config.batch_size)

    # Reconstruction loss
    def RC_loss(self, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, target):
        """Negative log-likelihood of ``target`` under the decoder outputs.

        Two terms, summed and averaged over ``seq_len * batch_size``:
        the mixture-of-bivariate-Gaussians NLL of the ``(dx, dy)`` offsets,
        counted only where ``p3 == 0``, and the cross-entropy between the
        pen-state logits ``q`` and the one-hot pen state, counted at every step.

        Args:
            pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy: ``[seq_len, batch, M]``
                mixture parameters (``pi`` normalised, ``sigma_*`` positive).
            q: ``[seq_len, batch, 3]`` un-normalised pen-state logits.
            target: ``[seq_len, batch, 5]`` float tensor of ``(dx, dy, p1, p2, p3)``.

        Returns:
            Scalar tensor.
        """
        import math

        eps = 1e-8
        dx = target[..., 0:1]  # trailing axis kept so it broadcasts over the M components
        dy = target[..., 1:2]
        pen_state = target[..., 2:5]

        norm_x = (dx - mu_x) / sigma_x
        norm_y = (dy - mu_y) / sigma_y
        # tanh can saturate to +-1 in float32; clamp to keep the log/division finite.
        one_minus_rho2 = (1.0 - rho_xy ** 2).clamp_min(1e-6)
        z_term = norm_x ** 2 + norm_y ** 2 - 2.0 * rho_xy * norm_x * norm_y

        # Log-density of each bivariate Gaussian component.
        log_pdf = (-z_term / (2.0 * one_minus_rho2)
                   - math.log(2.0 * math.pi)
                   - torch.log(sigma_x) - torch.log(sigma_y)
                   - 0.5 * torch.log(one_minus_rho2))
        # Mixture likelihood in log space for numerical stability.
        log_mixture = torch.logsumexp(torch.log(pi + eps) + log_pdf, dim=-1)

        # 1 on steps whose p3 flag is 0, 0 on end-of-sequence/padding steps.
        offset_mask = 1.0 - target[..., 4]
        stroke_loss = -(offset_mask * log_mixture).sum()
        pen_loss = -(pen_state * torch.log_softmax(q, dim=-1)).sum()

        return (stroke_loss + pen_loss) / float(target.shape[0] * target.shape[1])

    @staticmethod
    def _make_target(seq):
        """Append an end-of-sequence step ``(0, 0, 0, 0, 1)`` to ``seq``.

        ``seq`` is ``[seq_len, batch, 5]``; the result is
        ``[seq_len + 1, batch, 5]`` and lines up step-for-step with decoder
        inputs that were prefixed with the start token.
        """
        eos = seq.new_tensor([0, 0, 0, 0, 1]).expand(1, seq.shape[1], 5)
        return torch.cat([seq, eos], dim=0)

    def _compute_losses(self, config, batch):
        """Run encoder and decoder on one batch and return ``(total, kl, rc)``.

        ``batch`` is batch-first, ``[batch_size, seq_len, 5]``, as produced by
        iterating over the batched dataset array; it is transposed to the
        sequence-first layout the LSTMs expect.
        """
        seq = batch.float().transpose(0, 1).contiguous()  # [seq_len, batch_size, 5]
        steps, batch_size = seq.shape[0], seq.shape[1]

        # Encoder
        z, mu, sigma_hat = self.encoder(seq, batch_size, config)

        # Decoder inputs: start token s0 followed by the sequence, each step
        # concatenated with the latent vector.
        s0 = seq.new_tensor([0, 0, 1, 0, 0]).expand(1, batch_size, 5)
        decoder_steps = torch.cat([s0, seq], dim=0)  # [seq_len + 1, batch_size, 5]
        z_stack = z.unsqueeze(0).expand(steps + 1, -1, -1)  # [seq_len + 1, batch_size, z_dim]
        decoder_inputs = torch.cat([decoder_steps, z_stack], dim=2)

        pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.decoder(config, decoder_inputs, z)

        kl_loss = config.W_kl * self.KL_loss(mu, sigma_hat, config)
        r_loss = self.RC_loss(pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q,
                              self._make_target(seq))
        return kl_loss + r_loss, kl_loss, r_loss

    # Save the model (encoder and decoder stored separately)
    def save(self, config, epoch):
        """Placeholder, not implemented yet: currently a no-op."""
        # TODO: persist the encoder and decoder state dicts.
        pass

    # Load the model (encoder and decoder read separately)
    def load(self, config, epoch):
        """Placeholder, not implemented yet: currently a no-op."""
        # TODO: restore the encoder and decoder state dicts.
        pass

    def train(self, config=True, train_data=None, val_data=None):
        """Train the model, or toggle training mode.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` and
        parent modules call with a bool. A bool first argument is therefore
        forwarded to the base implementation so those calls keep working.

        Args:
            config: run configuration (or a bool, see above).
            train_data: numpy array ``[num_batches, batch_size, seq_len, 5]``.
            val_data: numpy array with the same layout.

        Returns:
            A dict of loss histories: per-batch ``train_loss`` / ``train_LKL`` /
            ``train_LR`` and per-validated-epoch ``val_loss`` / ``val_LKL`` /
            ``val_LR``.
        """
        if isinstance(config, bool):
            return super().train(config)

        train_loss_history = []
        train_LKL_history = []
        train_LR_history = []
        val_loss_history = []
        val_LKL_history = []
        val_LR_history = []

        # Cast before moving: np.load may yield float64, which the MPS
        # backend cannot hold.
        train_data = torch.from_numpy(train_data).float().to(config.device)
        val_data = torch.from_numpy(val_data).float().to(config.device)

        num_epochs = self._config_value(config, ('epoch', 'num_epochs'))
        grad_clip = self._config_value(config, ('grad_clip',), default=1.0)
        best_val_loss = float('inf')

        for epoch in tqdm(range(num_epochs)):
            # Training phase
            self.encoder.train()
            self.decoder.train()

            for train_batch in train_data:
                train_loss, kl_loss, r_loss = self._compute_losses(config, train_batch)

                train_loss_history.append(train_loss.item())
                train_LKL_history.append(kl_loss.item())
                train_LR_history.append(r_loss.item())

                # Back-propagation
                self.encoder_optimizer.zero_grad()
                self.decoder_optimizer.zero_grad()
                train_loss.backward()
                nn.utils.clip_grad_norm_(self.encoder.parameters(), grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(), grad_clip)
                self.encoder_optimizer.step()
                self.decoder_optimizer.step()

            # Validation phase: once per epoch, skipped during the first
            # tenth of training.
            if epoch > num_epochs / 10 and len(val_data) > 0:
                self.encoder.eval()
                self.decoder.eval()

                total = kl_total = rc_total = 0.0
                with torch.no_grad():
                    for val_batch in val_data:
                        val_loss, val_kl, val_rc = self._compute_losses(config, val_batch)
                        total += val_loss.item()
                        kl_total += val_kl.item()
                        rc_total += val_rc.item()

                mean_val_loss = total / len(val_data)
                val_loss_history.append(mean_val_loss)
                val_LKL_history.append(kl_total / len(val_data))
                val_LR_history.append(rc_total / len(val_data))

                # Keep the best model seen so far.
                if mean_val_loss < best_val_loss:
                    best_val_loss = mean_val_loss
                    self.save(config, epoch)

        return {
            'train_loss': train_loss_history,
            'train_LKL': train_LKL_history,
            'train_LR': train_LR_history,
            'val_loss': val_loss_history,
            'val_LKL': val_LKL_history,
            'val_LR': val_LR_history,
        }

    def generate(self, config, test_data):
        """Placeholder for conditional sample generation; not implemented yet."""
        raise NotImplementedError('model.generate is not implemented yet')

    # Evaluate on the test set (show generated results)
    def test(self, config, test_data):
        """Produce final (conditional) samples from ``test_data`` via ``generate``."""
        self.generate(config, test_data)





# 5 训练

# 6 展示