In [None]:
# 多任务学习：在不同任务之间学习共性以及差异性，能够提高建模的质量以及效率。
# Hard Parameter Sharing：底层是共享的隐藏层，学习各个任务的共同模式，上层用一些特定的全连接层学习特定任务模式
# Soft Parameter Sharing：底层不使用共享的shared bottom，而是有多个tower，给不同的tower分配不同的权重。

# MMOE、ESSM、CGC、PLE、kuaishouEBR、AITM
# DSIN、AutoInt

In [1]:
import gc
# 加载数据集2
# 数据集：ml-100k

import os
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# 加载数据
ratings = np.array([[int(x) for x in line.strip().split('\t')[:3]] for line in open('./data/ml-100k/ua.base','r').read().strip().split('\n')], dtype=np.int32)
ratings[:,-1] = (ratings[:,-1] - 0)/(max(ratings[:,-1]) - 0)
occupation_dict = {'administrator':0, 'artist':1, 'doctor':2, 'educator':3, 'engineer':4, 'entertainment':5, 'executive':6, 'healthcare':7, 'homemaker':8, 'lawyer':9, 'librarian':10, 'marketing':11, 'none':12, 'other':13, 'programmer':14, 'retired':15, 'salesman':16, 'scientist':17, 'student':18, 'technician':19, 'writer':20}
gender_dict={'M':1,'F':0}
user_info = {}
for line in open('./data/ml-100k/u.user','r', encoding='utf-8').read().strip().split('\n'):
    phs = line.strip().split('|')
    user_info[int(phs[0])] = [int(phs[1]), gender_dict[phs[2]], occupation_dict[phs[3]]]
item_info = {}
for line in open('./data/ml-100k/u.item','r', encoding='ISO-8859-1').read().strip().split('\n'):
    phs = line.strip().split('|')
    item_info[int(phs[0])] = phs[5:]
# print(len(user_info[list(user_info.keys())[0]]), len(item_info[list(item_info.keys())[0]]))
# print(data.shape)
num_users = len(user_info)
num_items = len(item_info)
num_features = 22

In [11]:
# ESSM:
# 一个输入到两个相对独立网络，网络输出的概率按照其任务级联关系相乘。
# soft sharing

import os, gc
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Module, Parameter, MSELoss
from torch.utils.data import Dataset, DataLoader, TensorDataset 
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
from torch.nn import Module, Sequential, ReLU, Dropout, Sigmoid

# category: [1,2] + [all]
# number: [0] + []
number_feature_data = MinMaxScaler().fit_transform(np.array([[user_info[u][0]]  for u, i, r in ratings], dtype=np.float32))
category_feature_data = np.array([user_info[u][1:] + item_info[i] for u, i, r in ratings], dtype=np.int32)
data = np.concatenate([number_feature_data, category_feature_data, ratings[:,-1:]], axis=-1)
num_number_features = number_feature_data.shape[-1]
num_category_features = category_feature_data.shape[-1]
num_features = data.shape[-1] - 1
category_feature_vals = {}
for i in range(num_number_features, num_features):
    category_feature_vals[i] = sorted(list(set(list(data[:, i]))))
    for rid in range(data.shape[0]):
        data[rid, i] = category_feature_vals[i].index(data[rid, i])
# print(len(user_info[list(user_info.keys())[0]]), len(item_info[list(item_info.keys())[0]]))
# print(data.shape)

num_users = len(user_info)
num_items = len(item_info)
device = torch.device("cuda:0" if torch.cuda.is_available() else ('mps:0' if torch.backends.mps.is_available() else "cpu"))
batch_size = 100
num_epochs = 10
embedding_dim = 8 # sparse feature embedding dim
X_train, X_test, y_train, y_test = train_test_split(data[:,:-1], data[:,-1], test_size=0.4, random_state=0)
train_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train>1.8).float(), torch.from_numpy(y_train>3.8).float()), batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test>1.8).float(), torch.from_numpy(y_test>3.8).float()), batch_size=batch_size, shuffle=False, pin_memory=True)

class ESSM(Module):
    def __init__(self, dense_feature_cols:[(int,int)], sparse_feature_cols:[(int,int)], sparse_feature_embedding_dim, 
                 num_task:int, dnn_layer_dims:list[int], dnn_dropout=0.):
        super(ESSM, self).__init__()
        self.dense_feature_cols, self.sparse_feature_cols, self.sparse_feature_embedding_dim = dense_feature_cols, sparse_feature_cols, sparse_feature_embedding_dim
        self.num_task, self.dnn_layer_dims = num_task, dnn_layer_dims
        # sparse feature embedding dict
        self.embed_layers = nn.ModuleDict({'embed_' + str(i): nn.Embedding(num_embeddings=valcount, embedding_dim=sparse_feature_embedding_dim) for i, valcount in sparse_feature_cols})
        self.input_dim = len(dense_feature_cols) + len(sparse_feature_cols) * sparse_feature_embedding_dim
        # dnn for each task
        self.dnn_nets = []
        for i in range(num_task):
            net = nn.Sequential()
            pre_layer_dim = self.input_dim
            for layer_dim in dnn_layer_dims:
                net.append(nn.Linear(pre_layer_dim, layer_dim))
                net.append(nn.BatchNorm1d(layer_dim))
                net.append(nn.Dropout(dnn_dropout))
                pre_layer_dim = layer_dim
            net.append(nn.Linear(dnn_layer_dims[-1], 1))
            net.append(nn.Sigmoid())
            self.dnn_nets.append(net)
    def forward(self, x):
        dense_input = x[:, :len(self.dense_feature_cols)]
        sparse_embeds = torch.cat([self.embed_layers['embed_' + str(i)](x[:, i].long()) for i in range(len(self.dense_feature_cols), x.shape[1])], axis=1)
        x = torch.cat([sparse_embeds, dense_input], axis=-1)
        outputs = [self.dnn_nets[i](x).squeeze() for i in range(self.num_task)]
        # ESSM原文计算概率，然后相乘，我这里任务是评分预测，改成相加
        output_sum = outputs[0]
        for i in range(1,self.num_task):
            outputs[i] = output_sum + outputs[i] # 这里注意不能用 +=，会inplace报错
            output_sum = outputs[i]
        return outputs
    def parameters(self, recurse: bool = True):
        paras = []
        for layer in self.embed_layers.values():
            paras += [p for p in layer.parameters()]
        for net in self.dnn_nets:
            paras += [p for p in net.parameters()]
        return paras
model = ESSM(dense_feature_cols=[i for i in range(num_number_features)], sparse_feature_cols=[(i,len(category_feature_vals[i])) for i in range(num_number_features, num_features)], sparse_feature_embedding_dim =embedding_dim, num_task=2, dnn_layer_dims=[128, 32], dnn_dropout=0.).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = MSELoss(reduction='sum').to(device)
alpha = 0.5

for epoch in range(num_epochs):
    # train:
    epoch_train_losses = []
    model.train()
    for i, inputs in enumerate(train_loader):
        optimizer.zero_grad()
        input = inputs[0].to(device)
        label_1, label_2 = inputs[1].to(device), inputs[2].to(device)
        output1, output2 = model(input)
        loss_1, loss_2 = criterion(output1, label_1), criterion(output2, label_2)
        loss = alpha * loss_1 + loss_2
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1, norm_type=2)
        optimizer.step()
        epoch_train_losses.append([input.shape[0], loss_1.item(), loss_2.item(), loss.item()])
    # validate:
    model.eval()
    epoch_test_losses = []
    for i, inputs in enumerate(test_loader):
        input = inputs[0].to(device)
        label_1, label_2 = inputs[1].to(device), inputs[2].to(device)
        output1, output2 = model(input)
        loss_1, loss_2 = criterion(output1, label_1), criterion(output2, label_2)
        loss = alpha * loss_1 + loss_2
        epoch_test_losses.append([input.shape[0], loss_1.item(), loss_2.item(), loss.item()])
    num_train, num_test = sum([x[0] for x in epoch_train_losses]), sum([x[0] for x in epoch_test_losses])
    train_overall_loss = sum([x[-1] for x in epoch_train_losses])/num_train
    train_task1_loss = sum([x[1] for x in epoch_train_losses])/num_train
    train_task2_loss = sum([x[2] for x in epoch_train_losses])/num_train
    test_overall_loss = sum([x[-1] for x in epoch_test_losses])/num_test
    test_task1_loss = sum([x[1] for x in epoch_test_losses])/num_test
    test_task2_loss = sum([x[2] for x in epoch_test_losses])/num_test
    # print
    print('['+datetime.now().strftime("%Y-%m-%d %H:%M:%S")+']', 'epoch=[{}/{}], train_mse_overall_loss: {:.6f}, train_mse_task1_loss: {:.6f}, train_mse_task2_loss: {:.6f}, validate_mse_overall_loss: {:.6f}, validate_mse_task1_loss: {:.6f}, validate_mse_task2_loss: {:.6f}'.format(epoch+1, num_epochs,  train_overall_loss, train_task1_loss, train_task2_loss, test_overall_loss, test_task1_loss, test_task2_loss))
    gc.collect()

[2023-09-05 16:37:53] epoch=[1/10], train_mse_overall_loss: 0.032886, train_mse_task1_loss: 0.007735, train_mse_task2_loss: 0.029018, validate_mse_overall_loss: 0.000093, validate_mse_task1_loss: 0.000021, validate_mse_task2_loss: 0.000083
[2023-09-05 16:38:07] epoch=[2/10], train_mse_overall_loss: 0.000051, train_mse_task1_loss: 0.000012, train_mse_task2_loss: 0.000045, validate_mse_overall_loss: 0.000028, validate_mse_task1_loss: 0.000006, validate_mse_task2_loss: 0.000024
[2023-09-05 16:38:20] epoch=[3/10], train_mse_overall_loss: 0.000019, train_mse_task1_loss: 0.000004, train_mse_task2_loss: 0.000017, validate_mse_overall_loss: 0.000013, validate_mse_task1_loss: 0.000003, validate_mse_task2_loss: 0.000011
[2023-09-05 16:38:33] epoch=[4/10], train_mse_overall_loss: 0.000009, train_mse_task1_loss: 0.000002, train_mse_task2_loss: 0.000008, validate_mse_overall_loss: 0.000007, validate_mse_task1_loss: 0.000002, validate_mse_task2_loss: 0.000006
[2023-09-05 16:38:47] epoch=[5/10], trai

In [12]:
# MMOE:
# 一个输入，多个输出。中间共享多个expert，并行每个任务在其上的gate，进入各自任务的tower网络，预测。模块概率共享。
# soft sharing

import os, gc
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Module, Parameter, MSELoss
from torch.utils.data import Dataset, DataLoader, TensorDataset 
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
from torch.nn import Module, Sequential, ReLU, Dropout, Sigmoid

# category: [1,2] + [all]
# number: [0] + []
number_feature_data = MinMaxScaler().fit_transform(np.array([[user_info[u][0]]  for u, i, r in ratings], dtype=np.float32))
category_feature_data = np.array([user_info[u][1:] + item_info[i] for u, i, r in ratings], dtype=np.int32)
data = np.concatenate([number_feature_data, category_feature_data, ratings[:,-1:]], axis=-1)
num_number_features = number_feature_data.shape[-1]
num_category_features = category_feature_data.shape[-1]
num_features = data.shape[-1] - 1
category_feature_vals = {}
for i in range(num_number_features, num_features):
    category_feature_vals[i] = sorted(list(set(list(data[:, i]))))
    for rid in range(data.shape[0]):
        data[rid, i] = category_feature_vals[i].index(data[rid, i])
# print(len(user_info[list(user_info.keys())[0]]), len(item_info[list(item_info.keys())[0]]))
# print(data.shape)

num_users = len(user_info)
num_items = len(item_info)
device = torch.device("cuda:0" if torch.cuda.is_available() else ('mps:0' if torch.backends.mps.is_available() else "cpu"))
batch_size = 100
num_epochs = 10
embedding_dim = 8 # sparse feature embedding dim
X_train, X_test, y_train, y_test = train_test_split(data[:,:-1], data[:,-1], test_size=0.4, random_state=0)
train_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train>1.8).float(), torch.from_numpy(y_train>3.8).float()), batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test>1.8).float(), torch.from_numpy(y_test>3.8).float()), batch_size=batch_size, shuffle=False, pin_memory=True)

class MMOE(Module):
    def __init__(self, dense_feature_cols:[(int,int)], sparse_feature_cols:[(int,int)], sparse_feature_embedding_dim, 
                 hidden_dim:int, num_task:int, n_expert:int, dnn_layer_dims:list[int], dnn_dropout=0.):
        super(MMOE, self).__init__()
        self.dense_feature_cols, self.sparse_feature_cols, self.sparse_feature_embedding_dim = dense_feature_cols, sparse_feature_cols, sparse_feature_embedding_dim
        self.num_task, self.n_expert, self.dnn_layer_dims, self.hidden_dim = num_task, n_expert, dnn_layer_dims, hidden_dim
        # sparse feature embedding dict
        self.embed_layers = nn.ModuleDict({'embed_' + str(i): nn.Embedding(num_embeddings=valcount, embedding_dim=sparse_feature_embedding_dim) for i, valcount in sparse_feature_cols})
        self.input_dim = len(dense_feature_cols) + len(sparse_feature_cols) * sparse_feature_embedding_dim
        # experts: 线性层
        self.experts = torch.nn.Parameter(torch.randn(self.input_dim, hidden_dim, n_expert), requires_grad=True)
        self.experts_bias = torch.nn.Parameter(torch.randn(hidden_dim, n_expert), requires_grad=True)
        # gates: 线性层
        self.gates = [torch.nn.Parameter(torch.randn(self.input_dim, n_expert), requires_grad=True) for i in range(num_task)]
        self.gates_bias = [torch.nn.Parameter(torch.randn(n_expert), requires_grad=True) for i in range(num_task)]
        # dnn for each task
        self.dnn_nets = []
        for i in range(num_task):
            net = nn.Sequential()
            pre_layer_dim = hidden_dim
            for layer_dim in dnn_layer_dims:
                net.append(nn.Linear(pre_layer_dim, layer_dim))
                net.append(nn.BatchNorm1d(layer_dim))
                net.append(nn.Dropout(dnn_dropout))
                pre_layer_dim = layer_dim
            net.append(nn.Linear(dnn_layer_dims[-1], 1))
            net.append(nn.Sigmoid())
            self.dnn_nets.append(net)
    def forward(self, x):
        dense_input = x[:, :len(self.dense_feature_cols)]
        sparse_embeds = torch.cat([self.embed_layers['embed_' + str(i)](x[:, i].long()) for i in range(len(self.dense_feature_cols), x.shape[1])], axis=1)
        x = torch.cat([sparse_embeds, dense_input], axis=-1)
        experts_out = torch.einsum('ij, jkl -> ikl', x, self.experts) + self.experts_bias.unsqueeze(0) # [batch, hidden_dim, n_expert]
        gates_out = [torch.softmax(torch.einsum('ab, bc -> ac', x, gate) + self.gates_bias[i].unsqueeze(0), dim=-1) for i, gate in enumerate(self.gates)] # [batch, n_expert]
        towers_input = [torch.bmm(experts_out, gate_out.unsqueeze(-1)).squeeze() for gate_out in gates_out]
        outputs = [self.dnn_nets[i](tower_input).squeeze() for i,tower_input in enumerate(towers_input)]
        return outputs
    def parameters(self, recurse: bool = True):
        paras = [self.experts, self.experts_bias] + self.gates + self.gates_bias
        for layer in self.embed_layers.values():
            paras += [p for p in layer.parameters()]
        for net in self.dnn_nets:
            paras += [p for p in net.parameters()]
        return paras
model = MMOE(dense_feature_cols=[i for i in range(num_number_features)], sparse_feature_cols=[(i,len(category_feature_vals[i])) for i in range(num_number_features, num_features)], sparse_feature_embedding_dim =embedding_dim, hidden_dim=embedding_dim, num_task=2, n_expert=3, dnn_layer_dims=[128, 32], dnn_dropout=0.).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = MSELoss(reduction='sum').to(device)
alpha = 0.5

for epoch in range(num_epochs):
    # train:
    epoch_train_losses = []
    model.train()
    for i, inputs in enumerate(train_loader):
        optimizer.zero_grad()
        input = inputs[0].to(device)
        label_1, label_2 = inputs[1].to(device), inputs[2].to(device)
        output1, output2 = model(input)
        loss_1, loss_2 = criterion(output1, label_1), criterion(output2, label_2)
        loss = alpha * loss_1 + loss_2
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1, norm_type=2)
        optimizer.step()
        epoch_train_losses.append([input.shape[0], loss_1.item(), loss_2.item(), loss.item()])
    # validate:
    model.eval()
    epoch_test_losses = []
    for i, inputs in enumerate(test_loader):
        input = inputs[0].to(device)
        label_1, label_2 = inputs[1].to(device), inputs[2].to(device)
        output1, output2 = model(input)
        loss_1, loss_2 = criterion(output1, label_1), criterion(output2, label_2)
        loss = alpha * loss_1 + loss_2
        epoch_test_losses.append([input.shape[0], loss_1.item(), loss_2.item(), loss.item()])
    num_train, num_test = sum([x[0] for x in epoch_train_losses]), sum([x[0] for x in epoch_test_losses])
    train_overall_loss = sum([x[-1] for x in epoch_train_losses])/num_train
    train_task1_loss = sum([x[1] for x in epoch_train_losses])/num_train
    train_task2_loss = sum([x[2] for x in epoch_train_losses])/num_train
    test_overall_loss = sum([x[-1] for x in epoch_test_losses])/num_test
    test_task1_loss = sum([x[1] for x in epoch_test_losses])/num_test
    test_task2_loss = sum([x[2] for x in epoch_test_losses])/num_test
    # print
    print('['+datetime.now().strftime("%Y-%m-%d %H:%M:%S")+']', 'epoch=[{}/{}], train_mse_overall_loss: {:.6f}, train_mse_task1_loss: {:.6f}, train_mse_task2_loss: {:.6f}, validate_mse_overall_loss: {:.6f}, validate_mse_task1_loss: {:.6f}, validate_mse_task2_loss: {:.6f}'.format(epoch+1, num_epochs,  train_overall_loss, train_task1_loss, train_task2_loss, test_overall_loss, test_task1_loss, test_task2_loss))
    gc.collect()

# 这种方法共享很稀疏，收敛很快

[2023-09-05 16:40:36] epoch=[1/10], train_mse_overall_loss: 0.012091, train_mse_task1_loss: 0.007442, train_mse_task2_loss: 0.008370, validate_mse_overall_loss: 0.000036, validate_mse_task1_loss: 0.000023, validate_mse_task2_loss: 0.000025
[2023-09-05 16:40:45] epoch=[2/10], train_mse_overall_loss: 0.000020, train_mse_task1_loss: 0.000013, train_mse_task2_loss: 0.000014, validate_mse_overall_loss: 0.000011, validate_mse_task1_loss: 0.000007, validate_mse_task2_loss: 0.000007
[2023-09-05 16:40:54] epoch=[3/10], train_mse_overall_loss: 0.000007, train_mse_task1_loss: 0.000005, train_mse_task2_loss: 0.000005, validate_mse_overall_loss: 0.000005, validate_mse_task1_loss: 0.000003, validate_mse_task2_loss: 0.000003
[2023-09-05 16:41:04] epoch=[4/10], train_mse_overall_loss: 0.000004, train_mse_task1_loss: 0.000002, train_mse_task2_loss: 0.000003, validate_mse_overall_loss: 0.000003, validate_mse_task1_loss: 0.000002, validate_mse_task2_loss: 0.000002
[2023-09-05 16:41:12] epoch=[5/10], trai

In [28]:
# CGC:
# 一个输入，多个输出。各自任务有自己的expert外，中间共享多个expert，并行每个任务在其上的gate，进入各自任务的tower网络，预测。模块概率共享。
# 相较于MMOE，增加了任务私有的expert
# soft sharing

import os, gc
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Module, Parameter, MSELoss
from torch.utils.data import Dataset, DataLoader, TensorDataset 
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
from torch.nn import Module, Sequential, ReLU, Dropout, Sigmoid

# category: [1,2] + [all]
# number: [0] + []
number_feature_data = MinMaxScaler().fit_transform(np.array([[user_info[u][0]]  for u, i, r in ratings], dtype=np.float32))
category_feature_data = np.array([user_info[u][1:] + item_info[i] for u, i, r in ratings], dtype=np.int32)
data = np.concatenate([number_feature_data, category_feature_data, ratings[:,-1:]], axis=-1)
num_number_features = number_feature_data.shape[-1]
num_category_features = category_feature_data.shape[-1]
num_features = data.shape[-1] - 1
category_feature_vals = {}
for i in range(num_number_features, num_features):
    category_feature_vals[i] = sorted(list(set(list(data[:, i]))))
    for rid in range(data.shape[0]):
        data[rid, i] = category_feature_vals[i].index(data[rid, i])
# print(len(user_info[list(user_info.keys())[0]]), len(item_info[list(item_info.keys())[0]]))
# print(data.shape)

num_users = len(user_info)
num_items = len(item_info)
device = torch.device("cuda:0" if torch.cuda.is_available() else ('mps:0' if torch.backends.mps.is_available() else "cpu"))
batch_size = 100
num_epochs = 10
embedding_dim = 8 # sparse feature embedding dim
X_train, X_test, y_train, y_test = train_test_split(data[:,:-1], data[:,-1], test_size=0.4, random_state=0)
train_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train>1.8).float(), torch.from_numpy(y_train>3.8).float()), batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test>1.8).float(), torch.from_numpy(y_test>3.8).float()), batch_size=batch_size, shuffle=False, pin_memory=True)

class Expert(Module):
    def __init__(self, input_dim, hidden_dim:int, n_expert:int):
        super(Expert, self).__init__()
        self.input_dim, self.hidden_dim, self.n_expert = input_dim, hidden_dim, n_expert
        # experts: 线性层
        self.experts = torch.nn.Parameter(torch.randn(self.input_dim, hidden_dim, n_expert), requires_grad=True)
        self.experts_bias = torch.nn.Parameter(torch.randn(hidden_dim, n_expert), requires_grad=True)
    def forward(self, x):
        expert_out = torch.einsum('ij, jkl -> ikl', x, self.experts) + self.experts_bias.unsqueeze(0) # [batch, hidden_dim, n_expert]
        return expert_out.permute((0,2,1)) # [batch, n_expert, hidden_dim]
    def parameters(self, recurse: bool = True):
        return [self.experts, self.experts_bias]
class CGC(Module):
    def __init__(self, dense_feature_cols:[(int,int)], sparse_feature_cols:[(int,int)], sparse_feature_embedding_dim, 
                 hidden_dim:int, num_task:int, shared_n_expert:int, self_n_expert:int, dnn_layer_dims:list[int], dnn_dropout=0.):
        super(CGC, self).__init__()
        self.dense_feature_cols, self.sparse_feature_cols, self.sparse_feature_embedding_dim = dense_feature_cols, sparse_feature_cols, sparse_feature_embedding_dim
        self.num_task, self.shared_n_expert, self.self_n_expert, self.dnn_layer_dims, self.hidden_dim = num_task, shared_n_expert, self_n_expert, dnn_layer_dims, hidden_dim
        # sparse feature embedding dict
        self.embed_layers = nn.ModuleDict({'embed_' + str(i): nn.Embedding(num_embeddings=valcount, embedding_dim=sparse_feature_embedding_dim) for i, valcount in sparse_feature_cols})
        self.input_dim = len(dense_feature_cols) + len(sparse_feature_cols) * sparse_feature_embedding_dim
        # experts
        self.shared_expert = Expert(self.input_dim, hidden_dim, shared_n_expert)
        self.self_experts = [Expert(self.input_dim, hidden_dim, self_n_expert) for i in range(num_task)]
        # gates:
        self.gates = [nn.Sequential(nn.Linear((shared_n_expert + self_n_expert) * hidden_dim, shared_n_expert + self_n_expert), nn.Softmax()) for i in range(num_task)]
        # dnn for each task
        self.dnn_nets = []
        for i in range(num_task):
            net = nn.Sequential()
            pre_layer_dim = hidden_dim
            for layer_dim in dnn_layer_dims:
                net.append(nn.Linear(pre_layer_dim, layer_dim))
                net.append(nn.BatchNorm1d(layer_dim))
                net.append(nn.Dropout(dnn_dropout))
                pre_layer_dim = layer_dim
            net.append(nn.Linear(dnn_layer_dims[-1], 1))
            net.append(nn.Sigmoid())
            self.dnn_nets.append(net)
    def forward(self, x):
        batch_len = x.shape[0]
        dense_input = x[:, :len(self.dense_feature_cols)]
        sparse_embeds = torch.cat([self.embed_layers['embed_' + str(i)](x[:, i].long()) for i in range(len(self.dense_feature_cols), x.shape[1])], axis=1)
        x = torch.cat([sparse_embeds, dense_input], axis=-1)
        # torch.Size([100, 169]) torch.Size([100, 8, 3]) [torch.Size([100, 8, 2]), torch.Size([100, 8, 2])]
        shared_expert_out = self.shared_expert(x)
        experts_out = [torch.cat([shared_expert_out, self.self_experts[i](x)], dim=1) for i in range(self.num_task)]
        gates_out = [gate(experts_out[i].reshape((batch_len,-1))) for i, gate in enumerate(self.gates)]
        towers_input = [torch.bmm(experts_out[i].permute((0,2,1)), gate_out.unsqueeze(-1)).squeeze() for i, gate_out in enumerate(gates_out)]
        outputs = [self.dnn_nets[i](tower_input).squeeze() for i,tower_input in enumerate(towers_input)]
        return outputs
    def parameters(self, recurse: bool = True):
        paras = [p for p in self.shared_expert.parameters()]
        for expert in self.self_experts:
            paras += [p for p in expert.parameters()]
        for gate in self.gates:
            paras += [p for p in gate.parameters()]
        for layer in self.embed_layers.values():
            paras += [p for p in layer.parameters()]
        for net in self.dnn_nets:
            paras += [p for p in net.parameters()]
        return paras
model = CGC(dense_feature_cols=[i for i in range(num_number_features)], sparse_feature_cols=[(i,len(category_feature_vals[i])) for i in range(num_number_features, num_features)], sparse_feature_embedding_dim =embedding_dim, hidden_dim=embedding_dim, num_task=2, shared_n_expert=3, self_n_expert=2, dnn_layer_dims=[128, 32], dnn_dropout=0.).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = MSELoss(reduction='sum').to(device)
alpha = 0.5

for epoch in range(num_epochs):
    # train:
    epoch_train_losses = []
    model.train()
    for i, inputs in enumerate(train_loader):
        optimizer.zero_grad()
        input = inputs[0].to(device)
        label_1, label_2 = inputs[1].to(device), inputs[2].to(device)
        output1, output2 = model(input)
        loss_1, loss_2 = criterion(output1, label_1), criterion(output2, label_2)
        loss = alpha * loss_1 + loss_2
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1, norm_type=2)
        optimizer.step()
        epoch_train_losses.append([input.shape[0], loss_1.item(), loss_2.item(), loss.item()])
    # validate:
    model.eval()
    epoch_test_losses = []
    for i, inputs in enumerate(test_loader):
        input = inputs[0].to(device)
        label_1, label_2 = inputs[1].to(device), inputs[2].to(device)
        output1, output2 = model(input)
        loss_1, loss_2 = criterion(output1, label_1), criterion(output2, label_2)
        loss = alpha * loss_1 + loss_2
        epoch_test_losses.append([input.shape[0], loss_1.item(), loss_2.item(), loss.item()])
    num_train, num_test = sum([x[0] for x in epoch_train_losses]), sum([x[0] for x in epoch_test_losses])
    train_overall_loss = sum([x[-1] for x in epoch_train_losses])/num_train
    train_task1_loss = sum([x[1] for x in epoch_train_losses])/num_train
    train_task2_loss = sum([x[2] for x in epoch_train_losses])/num_train
    test_overall_loss = sum([x[-1] for x in epoch_test_losses])/num_test
    test_task1_loss = sum([x[1] for x in epoch_test_losses])/num_test
    test_task2_loss = sum([x[2] for x in epoch_test_losses])/num_test
    # print
    print('['+datetime.now().strftime("%Y-%m-%d %H:%M:%S")+']', 'epoch=[{}/{}], train_mse_overall_loss: {:.6f}, train_mse_task1_loss: {:.6f}, train_mse_task2_loss: {:.6f}, validate_mse_overall_loss: {:.6f}, validate_mse_task1_loss: {:.6f}, validate_mse_task2_loss: {:.6f}'.format(epoch+1, num_epochs,  train_overall_loss, train_task1_loss, train_task2_loss, test_overall_loss, test_task1_loss, test_task2_loss))
    gc.collect()


[2023-09-05 17:28:46] epoch=[1/10], train_mse_overall_loss: 0.011408, train_mse_task1_loss: 0.007582, train_mse_task2_loss: 0.007617, validate_mse_overall_loss: 0.000035, validate_mse_task1_loss: 0.000027, validate_mse_task2_loss: 0.000022
[2023-09-05 17:28:57] epoch=[2/10], train_mse_overall_loss: 0.000019, train_mse_task1_loss: 0.000014, train_mse_task2_loss: 0.000012, validate_mse_overall_loss: 0.000010, validate_mse_task1_loss: 0.000008, validate_mse_task2_loss: 0.000006
[2023-09-05 17:29:11] epoch=[3/10], train_mse_overall_loss: 0.000007, train_mse_task1_loss: 0.000005, train_mse_task2_loss: 0.000004, validate_mse_overall_loss: 0.000005, validate_mse_task1_loss: 0.000004, validate_mse_task2_loss: 0.000003
[2023-09-05 17:29:25] epoch=[4/10], train_mse_overall_loss: 0.000004, train_mse_task1_loss: 0.000003, train_mse_task2_loss: 0.000002, validate_mse_overall_loss: 0.000003, validate_mse_task1_loss: 0.000002, validate_mse_task2_loss: 0.000002
[2023-09-05 17:29:36] epoch=[5/10], trai