# Multi-Objective Learning

## Data
### Introduction to the Census Data


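Dataset information:
```
Extracted by Barry Becker from the 1994 Census database. A set of reasonably clean records was extracted using the following conditions: ((AAGE>16) && (AGI>100) && (AFNLWGT>1) && (HRSWK>0))

The prediction task is to determine whether a person makes over 50K a year.

Classes: >50K, <=50K.

Attribute list:
age: continuous.
workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
fnlwgt: continuous.
education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
education-num: continuous.
marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
sex: Female, Male.
capital-gain: continuous.
capital-loss: continuous.
hours-per-week: continuous.
native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
```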


### EDA

In [1]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelEncoder

In [2]:
!ls data/

adult.data  adult.names  adult.test


In [3]:
column_names = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation',
                'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
                'income_50k']
train_df = pd.read_csv("./data/adult.data", header=None, names=column_names)
train_df.head()

Unnamed: 0,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income_50k
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K


In [4]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32561 entries, 0 to 32560
Data columns (total 15 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             32561 non-null  int64 
 1   workclass       32561 non-null  object
 2   fnlwgt          32561 non-null  int64 
 3   education       32561 non-null  object
 4   education_num   32561 non-null  int64 
 5   marital_status  32561 non-null  object
 6   occupation      32561 non-null  object
 7   relationship    32561 non-null  object
 8   race            32561 non-null  object
 9   sex             32561 non-null  object
 10  capital_gain    32561 non-null  int64 
 11  capital_loss    32561 non-null  int64 
 12  hours_per_week  32561 non-null  int64 
 13  native_country  32561 non-null  object
 14  income_50k      32561 non-null  object
dtypes: int64(6), object(9)
memory usage: 3.7+ MB


In [5]:
test_df = pd.read_csv("./data/adult.test", delimiter=","
                      , names=column_names
                      , header=None
                     )
test_df.head()

Unnamed: 0,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income_50k
0,|1x3 Cross validator,,,,,,,,,,,,,,
1,25,Private,226802.0,11th,7.0,Never-married,Machine-op-inspct,Own-child,Black,Male,0.0,0.0,40.0,United-States,<=50K.
2,38,Private,89814.0,HS-grad,9.0,Married-civ-spouse,Farming-fishing,Husband,White,Male,0.0,0.0,50.0,United-States,<=50K.
3,28,Local-gov,336951.0,Assoc-acdm,12.0,Married-civ-spouse,Protective-serv,Husband,White,Male,0.0,0.0,40.0,United-States,>50K.
4,44,Private,160323.0,Some-college,10.0,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688.0,0.0,40.0,United-States,>50K.
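The first row of `adult.test` shown above is a non-data line (`|1x3 Cross validator`), which is why the numeric columns are parsed as floats. A minimal sketch, assuming the junk line is always the first line of the file, is to skip it at read time with `skiprows=1`. The inline CSV text and the name `demo_test_df` are illustrative stand-ins, not the real file.

```python
import io
import pandas as pd

# Two illustrative rows in the same layout as adult.test (not the real file).
raw = (
    "|1x3 Cross validator\n"
    "25, Private, 226802, <=50K.\n"
    "38, Private, 89814, >50K.\n"
)
cols = ["age", "workclass", "fnlwgt", "income_50k"]

# skiprows=1 drops the junk first line before parsing, so integer
# columns are parsed as integers instead of being promoted to float.
demo_test_df = pd.read_csv(io.StringIO(raw), header=None, names=cols, skiprows=1)
print(demo_test_df.dtypes)
```

With this variant no all-NaN row is created, so a later `dropna` is no longer needed to remove it.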


In [6]:
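# Tag each row with its origin (1 = train, 0 = test)
train_df["tag"] = 1
test_df["tag"] = 0
# Drop rows containing NaN, which removes the junk first row of adult.test
test_df.dropna(inplace=True)
# Normalize the labels: strip the trailing "." from the test-set labels
test_df["income_50k"] = test_df["income_50k"].apply(lambda x: x[:-1])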
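The slice `x[:-1]` always removes the last character, so it relies on every test label ending in a period. As a hedged alternative, the sketch below uses pandas' `str.rstrip(".")`, which removes a trailing period only when one is present. The sample labels are made up for illustration.

```python
import pandas as pd

# Illustrative labels: two with a trailing period, one without.
labels = pd.Series([" <=50K.", " >50K.", " <=50K"])

# rstrip(".") strips trailing periods only; a label without one is unchanged,
# whereas x[:-1] would have cut the final "K" from the third value.
cleaned = labels.str.rstrip(".")
print(cleaned.tolist())
```

The leading space is kept on purpose, because the label comparisons in the preprocessing cell match strings such as `' <=50K'`.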

In [7]:
# Merge the train and test sets
data = pd.concat([train_df, test_df])
data.reset_index(inplace=True, drop=True)

### Data Preprocessing

In [8]:
label_columns = ['income_50k', 'marital_status']

# categorical columns
categorical_columns = ['workclass', 'education', 'occupation', 'relationship', 'race', 'sex', 'native_country']
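# Binarize both labels. The values keep the leading space left by the ", " separator.
for col in label_columns:
    if col == 'income_50k':
        # Income: 0 for <=50K, 1 otherwise
        data[col] = data[col].apply(lambda x: 0 if x == ' <=50K' else 1)
    else:
        # Marital status: 0 for Never-married, 1 otherwise
        data[col] = data[col].apply(lambda x: 0 if x == ' Never-married' else 1)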

In [9]:
# Feature engineering: label-encode categorical columns, min-max scale numeric ones
for col in column_names:
    if col not in label_columns + ['tag']:
        if col in categorical_columns:
            le = LabelEncoder()
            data[col] = le.fit_transform(data[col])
        else:
            mm = MinMaxScaler()
            data[col] = mm.fit_transform(data[[col]]).reshape(-1)
            
data = data[['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'occupation',
             'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
             'income_50k', 'marital_status', 'tag']]
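To make the two transforms concrete, here is a minimal sketch on a toy frame. The frame `toy` and its values are invented for illustration; only `LabelEncoder` and `MinMaxScaler` come from the notebook.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder, MinMaxScaler

toy = pd.DataFrame({
    "workclass": [" Private", " State-gov", " Private"],
    "age": [20, 30, 40],
})

# LabelEncoder maps each distinct string to an integer id (classes are sorted).
toy["workclass"] = LabelEncoder().fit_transform(toy["workclass"])
# MinMaxScaler rescales a numeric column linearly to the range [0, 1].
toy["age"] = MinMaxScaler().fit_transform(toy[["age"]]).reshape(-1)
print(toy)
```

The integer ids feed the embedding layers later on, while the scaled numeric columns are passed to the model as single dense values.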

In [10]:
# Encode user features and item features as name -> (vocab_size, column_index)
user_feat_dict, item_feat_dict = dict(), dict()

for idx, col in enumerate(data.columns):
    if col not in label_columns + ["tag"]:
        # User features: the first 7 columns
        if idx < 7:
            if col in categorical_columns:
                user_feat_dict[col] = (data[col].nunique() + 1, idx)
            else:
                user_feat_dict[col] = (1, idx)
        # Item features: the remaining columns
        else:
            if col in categorical_columns:
                item_feat_dict[col] = (data[col].nunique() + 1, idx)
            else:
                item_feat_dict[col] = (1, idx)

user_feat_dict, item_feat_dict
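Each entry of these dictionaries is a pair `(vocab_size, column_index)`: a size greater than 1 marks a sparse feature that gets an embedding table, and a size of 1 marks a dense feature. The sketch below rebuilds that structure on a toy frame. The helper `build_feat_dict` and the toy data are hypothetical and only illustrate the layout.

```python
import pandas as pd

def build_feat_dict(df, columns, categorical):
    """Hypothetical helper: map column name -> (vocab_size, column_index)."""
    feat = {}
    for col in columns:
        idx = df.columns.get_loc(col)
        # Categorical: number of distinct ids + 1; dense: size 1.
        size = df[col].nunique() + 1 if col in categorical else 1
        feat[col] = (size, idx)
    return feat

toy = pd.DataFrame({"age": [0.1, 0.5, 0.9], "workclass": [0, 1, 0], "sex": [1, 0, 1]})
feat = build_feat_dict(toy, ["age", "workclass", "sex"], categorical={"workclass", "sex"})
print(feat)
```

The column index is what the models use to pick the right column out of the input tensor in `forward`.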

In [11]:
# Split the data back into train and test sets
train_data, test_data = data[data["tag"] == 1], data[data["tag"] == 0]

In [12]:
# Reassign instead of dropping in place: the frames above are slices of `data`,
# and an in-place drop on them triggers pandas' SettingWithCopyWarning.
train_data = train_data.drop(columns="tag")
test_data = test_data.drop(columns="tag")


## Algorithms

In [13]:
import torch
import torch.nn as nn
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
import numpy as np
import torchsnooper
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

### Custom Dataset

In [14]:
class TrainDateSet(Dataset):
    
    def __init__(self, data):
        
        self.features = data[0]
        self.label1 = data[1]
        self.label2 = data[2]
    
    def __getitem__(self, index):
        
        return self.features[index], self.label1[index], self.label2[index]
    
    def __len__(self):
        
        return len(self.features)

In [15]:
# Build the datasets as (features, income label, marital-status label)
train_datasets = (train_data.iloc[:, :-2].values, train_data.iloc[:, -2].values, train_data.iloc[:, -1].values)
test_datasets = (test_data.iloc[:, :-2].values, test_data.iloc[:, -2].values, test_data.iloc[:, -1].values)
train_datasets = TrainDateSet(train_datasets)
test_datasets = TrainDateSet(test_datasets)
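A dataset with this `(features, label1, label2)` layout can be handed directly to a `DataLoader`, which stacks the NumPy rows into tensors. The sketch below uses random arrays and a stand-in class `ToyDataset` (both assumptions for illustration) to show the batch shapes that the training loop receives.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Stand-in with the same (features, label1, label2) layout."""

    def __init__(self, features, label1, label2):
        self.features, self.label1, self.label2 = features, label1, label2

    def __getitem__(self, index):
        return self.features[index], self.label1[index], self.label2[index]

    def __len__(self):
        return len(self.features)

features = np.random.rand(10, 13)          # 10 rows, 13 feature columns
y1 = np.random.randint(0, 2, size=10)      # first binary label
y2 = np.random.randint(0, 2, size=10)      # second binary label

loader = DataLoader(ToyDataset(features, y1, y2), batch_size=4)
x, b1, b2 = next(iter(loader))
print(x.shape, b1.shape, b2.shape)
```

The features arrive as `float64` because they come from NumPy, which is why the models call `.long()` on sparse columns and `.float()` on the concatenated input.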

### ESMM Model
![](./imgs/ESMM.png)

In [16]:
class ESMM(nn.Module):

    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, hidden_dim=[128, 64], dropouts=[0.5, 0.5],
                 output_size=1, task_name=["ctr", "cvr"]):
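        """
        :param user_feature_dict: user features, name -> (vocab_size, column_index)
        :param item_feature_dict: item features, name -> (vocab_size, column_index)
        :param emb_dim: embedding dimension, default 128
        :param hidden_dim: hidden layer sizes of each tower, default [128, 64]
        :param dropouts: dropout rate per hidden layer, default [0.5, 0.5]
        :param output_size: output size of each tower, default 1
        :param task_name: task names, one tower per name, default ["ctr", "cvr"]
        """
        super(ESMM, self).__init__()

        if user_feature_dict is None or item_feature_dict is None:
            raise ValueError("User features and item features must not be empty!")
        if not isinstance(user_feature_dict, dict) or not isinstance(item_feature_dict, dict):
            raise TypeError("Inputs must be dictionaries!")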

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.num_tasks = len(task_name)
        self.task_name = task_name

        # Shared embeddings (shared bottom)
        user_cate_feature_nums, item_cate_feature_nums = 0, 0
        # Embedding tables for user features
        for user_cate, num in self.user_feature_dict.items():
            # Only sparse features (vocab_size > 1) get an embedding
            if num[0] > 1:
                user_cate_feature_nums += 1
                setattr(self, user_cate, nn.Embedding(num[0], emb_dim))
        # Embedding tables for item features
        for item_cate, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_cate_feature_nums += 1
                setattr(self, item_cate, nn.Embedding(num[0], emb_dim))

        # Build the task-specific towers
        # Input size: sparse embeddings + dense features
        hidden_size = emb_dim * (user_cate_feature_nums + item_cate_feature_nums) \
                      + (len(self.user_feature_dict) - user_cate_feature_nums) \
                      + (len(self.item_feature_dict) - item_cate_feature_nums)

        for i in range(self.num_tasks):
            setattr(self, 'task_{}_dnn'.format(self.task_name[i]), nn.ModuleList())
            hid_dim = [hidden_size] + hidden_dim
            for j in range(len(hid_dim) - 1):
                getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('hidden_{}'.format(j),
                                                                      nn.Linear(hid_dim[j], hid_dim[j + 1]))
                getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('batchnorm_{}'.format(j),
                                                                      nn.BatchNorm1d(hid_dim[j + 1]))
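                # NOTE: this name does not include j, so each later layer re-registers the
                # same key and the tower ends up with a single ReLU after the first block.
                getattr(self, "task_{}_dnn".format(self.task_name[i])).add_module("{}_activation".format(task_name[i])
                                                                             , nn.ReLU())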
                getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('dropout_{}'.format(j),
                                                                      nn.Dropout(dropouts[j]))
            getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('task_{}_last_layer'.format(j),
                                                                  nn.Linear(hid_dim[-1], output_size))

    def forward(self, x):
        # assert x.size()[1] == len(self.item_feature_dict) + len(self.user_feature_dict)
        # Look up embeddings for sparse features; pass dense features through as-is
        user_embed_list, item_embed_list = list(), list()
        for user_feature, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_embed_list.append(getattr(self, user_feature)(x[:, num[1]].long()))
            else:
                user_embed_list.append(x[:, num[1]].unsqueeze(1))
        for item_feature, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_embed_list.append(getattr(self, item_feature)(x[:, num[1]].long()))
            else:
                item_embed_list.append(x[:, num[1]].unsqueeze(1))
        # Concatenate the per-feature vectors
        user_embed = torch.cat(user_embed_list, dim=1)
        item_embed = torch.cat(item_embed_list, dim=1)
        # Shared input to every tower
        hidden = torch.cat([user_embed, item_embed], dim=1).float()

        # Task towers
        task_outputs = list()
        for i in range(self.num_tasks):
            x = hidden
            # Run the input through each module of the tower in order
            for mod in getattr(self, 'task_{}_dnn'.format(self.task_name[i])):
                x = mod(x)
            task_outputs.append(x)

        if self.num_tasks == 2:
            # NOTE: ESMM defines pCTCVR = pCTR * pCVR on probabilities; the tower
            # outputs multiplied here are raw logits, with no sigmoid applied.
            pCTCVR = torch.mul(task_outputs[0], task_outputs[1])
            # task_outputs[0] comes from the first tower ("ctr" by default)
            pCTR = task_outputs[0]

            return pCTCVR, pCTR
        elif self.num_tasks == 1:
            return task_outputs
        else:
            raise ValueError("The number of tasks must be 1 or 2!")
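In the ESMM formulation, the click-and-conversion probability is the product of two probabilities, pCTCVR = pCTR × pCVR, and the loss is the sum of a CTR loss and a CTCVR loss. The `forward` method above multiplies the raw tower outputs instead. The following is a minimal sketch of the probability-space version, assuming sigmoid is applied to each tower output first. The random logits and the toy `click` and `conversion` labels are placeholders, not the notebook's data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ctr_logit = torch.randn(5, 1)   # stand-in for the CTR tower output
cvr_logit = torch.randn(5, 1)   # stand-in for the CVR tower output

# Convert logits to probabilities before multiplying.
p_ctr = torch.sigmoid(ctr_logit)
p_cvr = torch.sigmoid(cvr_logit)
p_ctcvr = p_ctr * p_cvr          # pCTCVR = pCTR * pCVR

# Toy labels: a conversion can only happen on a clicked impression.
click = torch.tensor([[1.], [0.], [1.], [1.], [0.]])
conversion = torch.tensor([[1.], [0.], [0.], [1.], [0.]])

# Both inputs are already probabilities, so plain BCELoss applies.
bce = nn.BCELoss()
loss = bce(p_ctr, click) + bce(p_ctcvr, conversion)
print(loss.item())
```

Because `p_ctcvr` is a product of two values in (0, 1), it can never exceed `p_ctr`, which matches the fact that a conversion requires a click.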

In [17]:
a = torch.from_numpy(np.array([[1, 2, 4, 2, 0.5, 0.1],
                               [4, 5, 3, 8, 0.6, 0.43],
                               [6, 3, 2, 9, 0.12, 0.32],
                               [9, 1, 1, 1, 0.12, 0.45],
                               [8, 3, 1, 4, 0.21, 0.67]]))

user_cate_dict = {'user_id': (11, 0), 'user_list': (12, 3), 'user_num': (1, 4)}
item_cate_dict = {'item_id': (8, 1), 'item_cate': (6, 2), 'item_num': (1, 5)}
esmm = ESMM(user_cate_dict, item_cate_dict)
esmm

ESMM(
  (user_id): Embedding(11, 128)
  (user_list): Embedding(12, 128)
  (item_id): Embedding(8, 128)
  (item_cate): Embedding(6, 128)
  (task_ctr_dnn): ModuleList(
    (hidden_0): Linear(in_features=514, out_features=128, bias=True)
    (batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ctr_activation): ReLU()
    (dropout_0): Dropout(p=0.5, inplace=False)
    (hidden_1): Linear(in_features=128, out_features=64, bias=True)
    (batchnorm_1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout_1): Dropout(p=0.5, inplace=False)
    (task_1_last_layer): Linear(in_features=64, out_features=1, bias=True)
  )
  (task_cvr_dnn): ModuleList(
    (hidden_0): Linear(in_features=514, out_features=128, bias=True)
    (batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (cvr_activation): ReLU()
    (dropout_0): Dropout(p=0.5, inplace=False)
    (hidden_1): Linea
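The `in_features=514` of the first linear layer follows from the feature dictionaries: every sparse feature contributes `emb_dim` values and every dense feature contributes one. The short check below repeats that arithmetic for the demo dictionaries. The helper names `all_feats`, `n_sparse`, and `n_dense` are introduced here only for illustration.

```python
user_cate_dict = {'user_id': (11, 0), 'user_list': (12, 3), 'user_num': (1, 4)}
item_cate_dict = {'item_id': (8, 1), 'item_cate': (6, 2), 'item_num': (1, 5)}
emb_dim = 128

all_feats = list(user_cate_dict.values()) + list(item_cate_dict.values())
# A vocab size above 1 marks a sparse feature with its own embedding table.
n_sparse = sum(1 for size, _ in all_feats if size > 1)
n_dense = len(all_feats) - n_sparse

hidden_size = emb_dim * n_sparse + n_dense
print(n_sparse, n_dense, hidden_size)  # 4 2 514
```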

In [18]:
tasks = esmm(a)
print(tasks)

(tensor([[-0.0026],
        [ 1.4870],
        [ 0.9397],
        [-0.6603],
        [-0.0971]], grad_fn=<MulBackward0>), tensor([[ 0.0644],
        [-0.7170],
        [ 1.0652],
        [-0.3548],
        [ 0.4381]], grad_fn=<AddmmBackward>))


In [19]:
w = SummaryWriter(log_dir="./log", comment="model info")
w.add_graph(esmm, a)
w.close()

### MMoE
![](./imgs/mmoe.png)

In [1]:
# @torchsnooper.snoop()
class MMoE(nn.Module):

    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, n_expert=3, mmoe_hidden_dim=128,
                 hidden_dim=[128, 64], output_size=1, num_tasks=2, expert_activation=None):
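        """
        :param user_feature_dict: user features, name -> (vocab_size, column_index)
        :param item_feature_dict: item features, name -> (vocab_size, column_index)
        :param emb_dim: embedding dimension
        :param n_expert: number of expert networks
        :param mmoe_hidden_dim: output dimension of each expert
        :param hidden_dim: hidden layer sizes of each task tower
        :param output_size: output size of each tower
        :param num_tasks: number of tasks (one gate and one tower per task)
        :param expert_activation: currently unused
        """
        super(MMoE, self).__init__()

        if user_feature_dict is None or item_feature_dict is None:
            raise ValueError("User features and item features must not be empty!")
        if not isinstance(user_feature_dict, dict) or not isinstance(item_feature_dict, dict):
            raise TypeError("Inputs must be dictionaries!")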

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.num_tasks = num_tasks

        # Shared embeddings (shared bottom)
        user_cate_feature_nums, item_cate_feature_nums = 0, 0
        # Embedding tables for user features
        for user_cate, num in self.user_feature_dict.items():
            # Only sparse features (vocab_size > 1) get an embedding
            if num[0] > 1:
                user_cate_feature_nums += 1
                setattr(self, user_cate, nn.Embedding(num[0], emb_dim))
        # Embedding tables for item features
        for item_cate, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_cate_feature_nums += 1
                setattr(self, item_cate, nn.Embedding(num[0], emb_dim))

        # Build the task-specific towers
        # Input size: sparse embeddings + dense features
        hidden_size = emb_dim * (user_cate_feature_nums + item_cate_feature_nums) \
                      + (len(self.user_feature_dict) - user_cate_feature_nums) \
                      + (len(self.item_feature_dict) - item_cate_feature_nums)

        # Expert networks: one weight tensor of shape hidden_size x mmoe_hidden_dim x n_expert
        self.erperts = torch.nn.Parameter(torch.rand(hidden_size, mmoe_hidden_dim, n_expert), requires_grad=True)
        self.erperts.data.normal_(0, 1)
        self.erperts_bias = torch.nn.Parameter(torch.rand(mmoe_hidden_dim, n_expert), requires_grad=True)

        # Gating networks: one hidden_size x n_expert weight matrix per task
        self.gates = torch.nn.ParameterList([torch.nn.Parameter(torch.rand(hidden_size, n_expert), requires_grad=True)
                      for _ in range(num_tasks)])
        for gate in self.gates:
            gate.data.normal_(0, 1,)

        self.gate_bias = torch.nn.ParameterList([torch.nn.Parameter(torch.rand(n_expert), requires_grad=True) for _ in range(num_tasks)])

        for i in range(self.num_tasks):
            setattr(self, 'task_{}_dnn'.format(i + 1), nn.ModuleList())
            # input: mmoe_hidden_dim + hidden_dim
            hid_dim = [mmoe_hidden_dim] + hidden_dim
            for j in range(len(hid_dim) - 1):
                getattr(self, 'task_{}_dnn'.format(i + 1)).add_module('hidden_{}'.format(j),
                                                                      nn.Linear(hid_dim[j], hid_dim[j + 1]))
                getattr(self, 'task_{}_dnn'.format(i + 1)).add_module('batchnorm_{}'.format(j),
                                                                      nn.BatchNorm1d(hid_dim[j + 1]))
            getattr(self, 'task_{}_dnn'.format(i + 1)).add_module('task_last_layer',
                                                                  nn.Linear(hid_dim[-1], output_size))
            
        self.Softmax = nn.Softmax(dim=-1)

    def forward(self, x):

        assert x.size()[1] == len(self.item_feature_dict) + len(self.user_feature_dict)
        # Look up embeddings for sparse features; pass dense features through as-is
        user_embed_list, item_embed_list = list(), list()
        for user_feature, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_embed_list.append(getattr(self, user_feature)(x[:, num[1]].long()))
            else:
                user_embed_list.append(x[:, num[1]].unsqueeze(1))
        for item_feature, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_embed_list.append(getattr(self, item_feature)(x[:, num[1]].long()))
            else:
                item_embed_list.append(x[:, num[1]].unsqueeze(1))
        # Concatenate the per-feature vectors
        user_embed = torch.cat(user_embed_list, dim=1)
        item_embed = torch.cat(item_embed_list, dim=1)
        # hidden_input
        # B*hidden
        hidden = torch.cat([user_embed, item_embed], dim=1).float()
        # MMoE
        expert_outs = torch.matmul(hidden, self.erperts.permute(1, 0, 2)).permute(1, 0, 2)  # B*mmoe_hidden_dim*experts
        expert_outs += self.erperts_bias
        # Gating units
        gates_out = list()
        for idx, gate in enumerate(self.gates):
            gate_out = torch.mm(hidden, gate)  # B * num_experts
            if self.gate_bias:
                gate_out += self.gate_bias[idx]
            # Normalize the gate weights over the experts
            gate_out = self.Softmax(gate_out)
            gates_out.append(gate_out)
        # Mix the expert outputs with each task's gate weights
        outs = list()
        for gate_out in gates_out:
            expand_gate_out = torch.unsqueeze(gate_out, dim=1)  # B * 1 * experts
            weighted_expert_output = expert_outs * expand_gate_out.expand_as(expert_outs)  # B * mmoe_hidden * expert
            outs.append(torch.sum(weighted_expert_output, 2))  # B * mmoe_hidden

        # task_tower
        task_outputs = list()
        for i in range(self.num_tasks):
            x = outs[i]
            for mod in getattr(self, 'task_{}_dnn'.format(i + 1)):
                x = mod(x)
            task_outputs.append(x)

        return task_outputs
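The core of the `forward` pass above is a per-task weighted sum over experts: each gate produces softmax weights over the experts, and those weights mix the expert outputs. The sketch below reproduces that step on small random tensors, using `torch.einsum` as an assumed, more explicit way to write the expert projection, and checks it against the permute-based `matmul` from the class. The sizes `B`, `H`, `M`, `E` and all tensors are illustrative.

```python
import torch

torch.manual_seed(0)
B, H, M, E = 4, 6, 5, 3          # batch, input dim, expert output dim, experts
hidden = torch.randn(B, H)
experts = torch.randn(H, M, E)   # one H x M projection per expert
gate_w = torch.randn(H, E)       # gate weights for a single task

# Expert outputs: for every sample b and expert e, project hidden[b] to M values.
expert_outs = torch.einsum("bh,hme->bme", hidden, experts)   # B x M x E

# Gate: softmax over experts, then weighted sum of the expert outputs.
gate = torch.softmax(hidden @ gate_w, dim=-1)                # B x E
mixed = (expert_outs * gate.unsqueeze(1)).sum(dim=2)         # B x M

# The permute-based matmul used in the class computes the same tensor.
alt = torch.matmul(hidden, experts.permute(1, 0, 2)).permute(1, 0, 2)
print(mixed.shape, torch.allclose(expert_outs, alt, atol=1e-5))
```

Each task has its own `gate_w`, so two tasks can weight the same shared experts differently while still sharing their parameters.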


In [24]:
mmoe = MMoE(user_cate_dict, item_cate_dict)
mmoe

MMoE(
  (user_id): Embedding(11, 128)
  (user_list): Embedding(12, 128)
  (item_id): Embedding(8, 128)
  (item_cate): Embedding(6, 128)
  (gates): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 514x3]
      (1): Parameter containing: [torch.FloatTensor of size 514x3]
  )
  (gate_bias): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 3]
      (1): Parameter containing: [torch.FloatTensor of size 3]
  )
  (task_1_dnn): ModuleList(
    (hidden_0): Linear(in_features=128, out_features=128, bias=True)
    (batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (hidden_1): Linear(in_features=128, out_features=64, bias=True)
    (batchnorm_1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (task_last_layer): Linear(in_features=64, out_features=1, bias=True)
  )
  (task_2_dnn): ModuleList(
    (hidden_0): Linear(in_features=128, out_features=128, bias=True)
  

In [25]:
outs = mmoe(a)
outs

[tensor([[-0.1679],
         [ 0.1511],
         [-0.0939],
         [ 0.5660],
         [-0.1203]], grad_fn=<AddmmBackward>),
 tensor([[-0.5108],
         [ 0.0298],
         [ 1.0607],
         [-0.0723],
         [ 0.0566]], grad_fn=<AddmmBackward>)]

In [26]:
# NOTE: this call logs the graph of `esmm` again, not the `mmoe` instance built above
w = SummaryWriter(log_dir="./log", comment="model info")
w.add_graph(esmm, a)
w.close()

## Training and Evaluation

In [27]:
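# Hyperparameters
learning_rate = 0.01
epochs = 100
count = 0
# Note: `comment` has no effect when log_dir is given explicitly
writer = SummaryWriter("./log", comment="metrics")
# Fall back to the CPU when CUDA is unavailable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")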

model = MMoE(user_feat_dict, item_feat_dict)
model.to(device)

optimizer = Adam(model.parameters(), lr = learning_rate)
loss_fun = nn.BCEWithLogitsLoss()

train_dataload = DataLoader(train_datasets, batch_size = 64, shuffle = True)
test_dataload = DataLoader(test_datasets, batch_size = 64)
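`nn.BCEWithLogitsLoss` expects raw logits, which is why the task towers end in a plain linear layer without a sigmoid. It combines the sigmoid and the binary cross-entropy in one numerically more stable step. The sketch below checks that equivalence on random placeholder data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 1)                       # raw tower outputs
targets = torch.randint(0, 2, (8, 1)).float()    # binary labels

# One step: sigmoid and cross-entropy fused.
with_logits = nn.BCEWithLogitsLoss()(logits, targets)
# Two steps: explicit sigmoid, then plain BCE on probabilities.
two_step = nn.BCELoss()(torch.sigmoid(logits), targets)
print(with_logits.item(), two_step.item())
```

Both values agree up to floating-point error, so applying a sigmoid inside the model and also using `BCEWithLogitsLoss` would apply the sigmoid twice.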

In [28]:
for epoch in tqdm(range(epochs)):
    # Switch back to training mode; the validation pass below calls model.eval()
    model.train()
    y_train_income_true = []
    y_train_income_predict = []
    y_train_marry_true = []
    y_train_marry_predict = []
    total_loss, count = 0, 0
    for x, y1, y2 in train_dataload:
        x, y1, y2 = x.to(device), y1.to(device), y2.to(device)
        predict = model(x)
        y_train_income_true += list(y1.squeeze().cpu().numpy())
        y_train_marry_true += list(y2.squeeze().cpu().numpy())
        y_train_income_predict += list(predict[0].squeeze().cpu().detach().numpy())
        y_train_marry_predict += list(predict[1].squeeze().cpu().detach().numpy())
        loss1 = loss_fun(predict[0], y1.unsqueeze(1).float())
        loss2 = loss_fun(predict[1], y2.unsqueeze(1).float())
        # Joint objective: unweighted sum of the two task losses
        loss = loss1 + loss2
        # Gradient update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += float(loss)
        count += 1

    y1_auc = roc_auc_score(y_train_income_true, y_train_income_predict)
    y2_auc = roc_auc_score(y_train_marry_true, y_train_marry_predict)
    loss_value = total_loss / count
    print(loss_value, y1_auc, y2_auc)
    writer.add_scalar("Train loss", loss_value, global_step = epoch)
    writer.add_scalar("Train_y1_auc", y1_auc, global_step = epoch)
    writer.add_scalar("Train_y2_auc", y2_auc, global_step = epoch)
    
    # Validation
    total_eval_loss = 0
    model.eval()
    count_eval = 0
    y_val_income_true = []
    y_val_marry_true = []
    y_val_income_predict = []
    y_val_marry_predict = []
    # No gradients are needed for evaluation
    with torch.no_grad():
        for x, y1, y2 in test_dataload:
            x, y1, y2 = x.to(device), y1.to(device), y2.to(device)
            predict = model(x)
            y_val_income_true += list(y1.squeeze().cpu().numpy())
            y_val_marry_true += list(y2.squeeze().cpu().numpy())
            y_val_income_predict += list(predict[0].squeeze().cpu().detach().numpy())
            y_val_marry_predict += list(predict[1].squeeze().cpu().detach().numpy())
            loss_1 = loss_fun(predict[0], y1.unsqueeze(1).float())
            loss_2 = loss_fun(predict[1], y2.unsqueeze(1).float())
            loss = loss_1 + loss_2
            total_eval_loss += float(loss)
            count_eval += 1

    y1_val_auc = roc_auc_score(y_val_income_true, y_val_income_predict)
    y2_val_auc = roc_auc_score(y_val_marry_true, y_val_marry_predict)
    print(total_eval_loss / count_eval, y1_val_auc, y2_val_auc)
    # Use the same step index as the training scalars so the curves line up
    writer.add_scalar("Val loss", total_eval_loss / count_eval, global_step = epoch)
    writer.add_scalar("Val_y1_auc", y1_val_auc, global_step = epoch)
    writer.add_scalar("Val_y2_auc", y2_val_auc, global_step = epoch)

writer.close()

  0%|          | 0/100 [00:00<?, ?it/s]

0.7151841826897941 0.8560233626952178 0.9191840090790677


  1%|          | 1/100 [00:05<09:32,  5.79s/it]

0.6690055573687834 0.8742806953934772 0.9298375768913343
0.6734854857673345 0.8714917031213821 0.9287788441099215


  2%|▏         | 2/100 [00:10<08:44,  5.35s/it]

0.6663197818924399 0.8747982802303649 0.932559750759544
0.661307070133494 0.8745611194827292 0.9324897185627783


  3%|▎         | 3/100 [00:14<08:11,  5.07s/it]

0.6562127282806471 0.8785249496027288 0.9387888620043522
0.650787933577723 0.876099512602621 0.9371075803282997


  4%|▍         | 4/100 [00:19<07:50,  4.90s/it]

0.6386312778089561 0.8790680022858333 0.9452028480319106
0.6413241743690139 0.8775659610569123 0.941307774015969


  5%|▌         | 5/100 [00:24<08:00,  5.06s/it]

0.6322365729247823 0.8790498109671071 0.9483800239005412
0.6227313517821561 0.8789520786101107 0.9493571856019413


  6%|▌         | 6/100 [00:29<07:56,  5.07s/it]

0.6127230543716281 0.8812453776800048 0.9545843143866852
0.6149067049644785 0.8801268712835899 0.9523688082569174


  7%|▋         | 7/100 [00:34<07:51,  5.07s/it]

0.6076856318642111 0.8829767312123928 0.9592293929765363
0.6045419970757835 0.8829729522107881 0.9553938142946286


  8%|▊         | 8/100 [00:40<08:19,  5.43s/it]

0.601495927104763 0.8850177239900212 0.9583522684222368
0.5974222891222986 0.8853016919197858 0.9567852496844187


  9%|▉         | 9/100 [00:45<07:48,  5.14s/it]

0.5907308163596134 0.888505334342847 0.9593556514085111
0.5945895760373189 0.8866334292114019 0.9576951186173747


 10%|█         | 10/100 [00:49<07:28,  4.98s/it]

0.5874656818661035 0.8892805249805489 0.9610585454682536
0.5892383661157724 0.889931649730134 0.957647091612515


 11%|█         | 11/100 [00:54<07:10,  4.83s/it]

0.586059527537402 0.8923660235512759 0.9608211959031735
0.5851585967484533 0.8914767884685469 0.9587086723986394


 12%|█▏        | 12/100 [00:58<06:53,  4.70s/it]

0.6014273774390128 0.8942511460008059 0.9606300777580249
0.5808543626709397 0.894869011696464 0.9586104298977857


 13%|█▎        | 13/100 [01:03<06:49,  4.70s/it]

0.5781103006764954 0.8951193423691913 0.9596699741670701
0.5782877752495187 0.8956505593162487 0.9591805500222884


 14%|█▍        | 14/100 [01:08<06:42,  4.68s/it]

0.5758254492984098 0.8930637965365821 0.9605548520273912
0.575984377231017 0.8965266823134062 0.9591442313380345


 15%|█▌        | 15/100 [01:12<06:31,  4.61s/it]

0.5759162718174504 0.8953240156144243 0.9617730711496633
0.5728376449208362 0.8985993103630449 0.9597058734660892


 16%|█▌        | 16/100 [01:17<06:27,  4.61s/it]

0.5707737123264985 0.8979823422932897 0.9618361664343332
0.5722804633948095 0.8988097839792412 0.9598412967840682


 17%|█▋        | 17/100 [01:21<06:19,  4.57s/it]

0.5689875243925581 0.8993666180101164 0.9613193755049615
0.5720579650523855 0.8990456407259326 0.9596502904977033


 18%|█▊        | 18/100 [01:26<06:08,  4.50s/it]

0.5701600489663142 0.9001996131312884 0.9612202536440624
0.5709908173102995 0.9000019604856887 0.9594742848954396


 19%|█▉        | 19/100 [01:30<06:12,  4.60s/it]

0.566136585263645 0.900588802804223 0.9612756380368574
0.5678239733507685 0.9012781334855495 0.9598751274769345


 20%|██        | 20/100 [01:35<06:02,  4.53s/it]

0.5622284294343463 0.900974751495086 0.9613798411125345
0.566530081694159 0.9018730480269465 0.9598735101755682


 21%|██        | 21/100 [01:39<05:54,  4.49s/it]

0.5622193501276128 0.9016400728405494 0.9616332486735655
0.5662385262768489 0.9018307969807695 0.9601484022042294


 22%|██▏       | 22/100 [01:44<05:50,  4.49s/it]

0.5838371166995927 0.9011533818811537 0.9612060109735916
0.5665862283912764 0.901961971530446 0.9597194472454137


 23%|██▎       | 23/100 [01:48<05:44,  4.47s/it]

0.56434172613948 0.9022882483453742 0.9613192058483747
0.5656464718179759 0.9020762936419592 0.9600926694857171


 24%|██▍       | 24/100 [01:52<05:37,  4.43s/it]

0.5600436425676533 0.9028630103788792 0.9603838551534494
0.5658546576682617 0.9019971828852489 0.9600713086998138


 25%|██▌       | 25/100 [01:57<05:33,  4.45s/it]

0.5639175688519197 0.9028812853358525 0.9619528732004653
0.5638481624468371 0.9030290277765739 0.960417635817396


 26%|██▌       | 26/100 [02:02<05:34,  4.52s/it]

0.5620702814822104 0.902793716091225 0.96140503511569
0.566363634552609 0.9022912763752395 0.9593965945298057


 27%|██▋       | 27/100 [02:06<05:35,  4.59s/it]

0.5610890238892798 0.9020599786597014 0.9617029690479539
0.563822153450698 0.9031374529535026 0.9602540616149207


 28%|██▊       | 28/100 [02:11<05:25,  4.52s/it]

0.5603108105706234 0.9021210659443666 0.9614563986473754
0.5629814355570113 0.9034863033246949 0.9602592686407481


 29%|██▉       | 29/100 [02:15<05:22,  4.54s/it]

0.562975388765335 0.9023714893107184 0.9615808841680173
0.5635009122855068 0.903373265847225 0.9602290469259307


 30%|███       | 30/100 [02:20<05:15,  4.50s/it]

0.5648402225737478 0.9016461888873625 0.9612296780674648
0.5620173139750138 0.9039224340028289 0.9604835964695492


 31%|███       | 31/100 [02:24<05:10,  4.50s/it]

0.5626391276425007 0.9008365392918893 0.9618578824774573
0.5623853265770761 0.9036505352745031 0.9602901727882849


 32%|███▏      | 32/100 [02:29<05:01,  4.44s/it]

0.5656685087026334 0.9021200100114981 0.9617052169977307
0.5623342149501231 0.903782661691573 0.9603441577088924


 33%|███▎      | 33/100 [02:33<04:58,  4.45s/it]

0.568230897655674 0.9016759641033008 0.9613584813482432
0.5618918830261718 0.9039224494803475 0.960365045712253


 34%|███▍      | 34/100 [02:38<04:54,  4.47s/it]

0.5588907655547647 0.9026160057258745 0.9616209570538441
0.5621331037847372 0.9036663919923034 0.960510771838936


 35%|███▌      | 35/100 [02:43<05:10,  4.77s/it]

0.561808386035994 0.9023222472927868 0.9608519885736968
0.5624034590945965 0.9038487584347318 0.9602306792023096


 36%|███▌      | 36/100 [02:48<05:02,  4.73s/it]

0.5671127211813833 0.9025127438551502 0.9615303265051194
0.5634772248258759 0.9036638691567723 0.9598208387774991


 37%|███▋      | 37/100 [02:52<04:51,  4.63s/it]

0.5630401095923255 0.9010468894831385 0.9604655787313616
0.5627385253170852 0.9040408241221461 0.9598666623162114


 38%|███▊      | 38/100 [02:57<05:00,  4.84s/it]

0.5607817087687698 0.9019555249439573 0.9613823180987034
0.5616637521033672 0.9039575112191373 0.9605487548888816


 39%|███▉      | 39/100 [03:02<04:50,  4.77s/it]

0.5634933355976554 0.90115071591203 0.9613126401384616
0.56119546346908 0.9042265749819738 0.9604446079544683


 40%|████      | 40/100 [03:06<04:38,  4.64s/it]

0.560697736342748 0.9024920747533561 0.9617760995197395
0.5617407215484702 0.9038788467308798 0.9605473857448679


 41%|████      | 41/100 [03:11<04:28,  4.55s/it]

0.5626456114591337 0.902091510278827 0.9611369098457454
0.560796673850601 0.9042425606790958 0.9604508803454815


 42%|████▏     | 42/100 [03:15<04:25,  4.57s/it]

0.5589387966137306 0.9035452266502402 0.9619299695612331
0.5607929883748001 0.9043373553213154 0.9605025484177029


 43%|████▎     | 43/100 [03:20<04:19,  4.56s/it]

0.5601848561389774 0.9026258541294608 0.9616756119233156
0.5605899047523678 0.90468349970634 0.9603832617454993


 44%|████▍     | 44/100 [03:24<04:13,  4.52s/it]

0.5603261281462276 0.9025881437348368 0.9617116300167156
0.5599334138081444 0.9046667607699799 0.9607740197256172


 45%|████▌     | 45/100 [03:29<04:07,  4.50s/it]

0.5596932420543596 0.9026226340569505 0.9616492218412225
0.5617316304934751 0.9037804716226919 0.960528416682414


 46%|████▌     | 46/100 [03:33<03:59,  4.44s/it]

0.5664870791575488 0.902269178825054 0.9611850244537916
0.5609081348876354 0.9045018091155569 0.9604208575844037


 47%|████▋     | 47/100 [03:37<03:53,  4.40s/it]

0.5582721187787898 0.9027711024001877 0.9611309718652036
0.5606438675071964 0.9043877114280632 0.9604539031587496


 48%|████▊     | 48/100 [03:42<03:57,  4.56s/it]

0.5607225685727363 0.9018376054704431 0.9616889554138757
0.5594081493398295 0.9045606056291116 0.9608549468332713


 49%|████▉     | 49/100 [03:47<03:48,  4.48s/it]

0.561563428476745 0.9034721268223467 0.9618037535434051
0.5596554621849172 0.9045853025896159 0.9606686598151791


 50%|█████     | 50/100 [03:51<03:41,  4.44s/it]

0.5610405662480522 0.9014022370303738 0.9616799296834524
0.561360719922952 0.9042911213936864 0.9603836767672784


 51%|█████     | 51/100 [03:55<03:36,  4.42s/it]

0.5614903050310471 0.903042247142238 0.9614280235832158
0.5602927565691748 0.9044522707377081 0.9606384851646872


 52%|█████▏    | 52/100 [04:00<03:33,  4.44s/it]

0.5635136438351052 0.9006919601271386 0.9613670829371993
0.5681080032425452 0.9012213284127206 0.9598227513005435


 53%|█████▎    | 53/100 [04:04<03:28,  4.45s/it]

0.5642415578458824 0.9010990588397157 0.9613552918044095
0.562664526501901 0.9032605843526827 0.9606833866704778


 54%|█████▍    | 54/100 [04:09<03:23,  4.42s/it]

0.5633670878176595 0.9010766228799533 0.961989451160602
0.5606065140726055 0.9043623824688828 0.9607996419722633


 55%|█████▌    | 55/100 [04:13<03:19,  4.44s/it]

0.5681764887828453 0.9008642862803374 0.9619831399355692
0.5650847407830019 0.9016377562096836 0.9605033891576991


 56%|█████▌    | 56/100 [04:18<03:18,  4.51s/it]

0.5611168412601246 0.9030479136334734 0.9618170291713304
0.5605411979210634 0.9047852308564764 0.9605578746715863


 57%|█████▋    | 57/100 [04:22<03:13,  4.51s/it]

0.559895534842622 0.9024356503009616 0.961796161411141
0.5605812405907569 0.9043884724060607 0.9606130126681676


 58%|█████▊    | 58/100 [04:27<03:07,  4.47s/it]

0.5589081612287783 0.901722602880794 0.961876035732256
0.5589988666105364 0.90479814942533 0.9608833223429576


 59%|█████▉    | 59/100 [04:31<03:04,  4.50s/it]

0.5620978469942131 0.9013938836604529 0.9616641261723822
0.5602669591168289 0.9041159597361641 0.9608808429087202


 60%|██████    | 60/100 [04:36<02:58,  4.47s/it]

0.5591407661344491 0.9027832404007861 0.9616123130507414
0.5603081371437589 0.9045304244678521 0.9607389824745887


 61%|██████    | 61/100 [04:40<02:53,  4.44s/it]

0.5621840015345929 0.9028448190601529 0.9616445732507413
0.5591193429379192 0.9049695319887292 0.9607601828139275


 62%|██████▏   | 62/100 [04:44<02:49,  4.46s/it]

0.5589486355875053 0.9027911128507866 0.9616181237888428
0.5605067584978339 0.9042518549290118 0.9607424138917733


 63%|██████▎   | 63/100 [04:49<02:43,  4.42s/it]

0.556164799601424 0.9028670145599553 0.9620734396539494
0.5591345308805964 0.9049922065534703 0.9607525049110126


 64%|██████▍   | 64/100 [04:53<02:38,  4.39s/it]

0.5649104133540509 0.9028755979350553 0.9620264362965474
0.5607856483965351 0.9039704762205467 0.9607871506724247


 65%|██████▌   | 65/100 [04:57<02:33,  4.39s/it]

0.5641400524214202 0.9031893877283037 0.9616443357315199
0.5599755846798303 0.9049915745547944 0.9604358368756298


 66%|██████▌   | 66/100 [05:02<02:31,  4.45s/it]

0.560880774493311 0.903560563813787 0.9618971070803497
0.5589679365425072 0.904822810271624 0.9608169595047504


 67%|██████▋   | 67/100 [05:07<02:28,  4.49s/it]

0.5602930362318076 0.9032759324044053 0.9619076427543964
0.5592525823879804 0.9047395979724864 0.9607514181529517


 68%|██████▊   | 68/100 [05:11<02:22,  4.46s/it]

0.558261386085959 0.9025958802726858 0.9619021967779569
0.559698363886835 0.9047453323931256 0.9606690940905459


 69%|██████▉   | 69/100 [05:15<02:17,  4.43s/it]

0.5575677996756984 0.9035483317201606 0.9621324207663869
0.5602657743427749 0.9044830993751622 0.9606025216021616


 70%|███████   | 70/100 [05:20<02:12,  4.42s/it]

0.5589913601968802 0.9024010136119157 0.9620559226113515
0.559882496218316 0.9044470599731145 0.9606884653390537


 71%|███████   | 71/100 [05:25<02:13,  4.62s/it]

0.5592384101129045 0.902207809261305 0.9620094876035155
0.55977333113112 0.9047127986490398 0.9607521583464342


 72%|███████▏  | 72/100 [05:30<02:11,  4.69s/it]

0.5617773389115053 0.9013450807433182 0.9619454592076175
0.558678841719693 0.9050497958205748 0.960916209610027


 73%|███████▎  | 73/100 [05:35<02:11,  4.85s/it]

0.5587806830219194 0.9034464394257314 0.9619045550045148
0.5588423502585743 0.9049817953426289 0.9608116112859465


 74%|███████▍  | 74/100 [05:40<02:09,  4.97s/it]

0.5643516483260136 0.9031542387549945 0.9618448613344122
0.5588300333510213 0.9048039070622472 0.9609648890975806


 75%|███████▌  | 75/100 [05:46<02:08,  5.14s/it]

0.558157370370977 0.9030879136251095 0.9617208508522139
0.5595419117650255 0.9048305232350573 0.9606580361133469


 76%|███████▌  | 76/100 [05:50<01:59,  4.98s/it]

0.5623120795277988 0.9016420801584778 0.9615968149215276
0.5596990167274924 0.9043685193050057 0.9609716791961742


 77%|███████▋  | 77/100 [05:56<01:56,  5.08s/it]

0.5604218982014002 0.9029307050850591 0.9609933125105886
0.5596163114185183 0.9047093755378439 0.9607154075261007


 78%|███████▊  | 78/100 [06:01<01:53,  5.14s/it]

0.5568553395131055 0.9034107990777209 0.9619751830416434
0.5583838869054566 0.9050466693618187 0.9609737200764698


 79%|███████▉  | 79/100 [06:06<01:46,  5.09s/it]

0.5565889847044851 0.9034260212386783 0.9614647033373045
0.5584530178713658 0.9049935608363474 0.9611393693878413


 80%|████████  | 80/100 [06:11<01:43,  5.16s/it]

0.5555622532087214 0.9033407938649675 0.9618044406525821
0.5595848514663682 0.904781601378366 0.9607685709602997


 81%|████████  | 81/100 [06:18<01:45,  5.56s/it]

0.5566548569529665 0.9025740820545569 0.9618617336819799
0.5595857177244422 0.9044815309866113 0.9607745780796602


 82%|████████▏ | 82/100 [06:24<01:45,  5.85s/it]

0.5604666579003428 0.9027396648740901 0.9617961614111412
0.5600363469311203 0.9041489732833266 0.960842588169259


 83%|████████▎ | 83/100 [06:29<01:32,  5.43s/it]

0.5611180631553425 0.9034730154787214 0.9614601735064343
0.5585275399544384 0.9049024034109976 0.9609760176712678


 84%|████████▍ | 84/100 [06:33<01:21,  5.12s/it]

0.5595686821376576 0.9031135487478205 0.9612607506713565
0.5580561712003644 0.9053222104661869 0.9610071464439945


 85%|████████▌ | 85/100 [06:38<01:13,  4.91s/it]

0.5574103808870503 0.9033853835054086 0.9616760106162948
0.5594389811597311 0.9046440733073065 0.9607803177880807


 86%|████████▌ | 86/100 [06:42<01:06,  4.76s/it]

0.5581415682446723 0.9024867009959852 0.9616542097448777
0.5585225037132593 0.9053460948569649 0.9607282068833423


 87%|████████▋ | 87/100 [06:46<01:00,  4.66s/it]

0.5618340064497555 0.9019807418754331 0.9615971881660188
0.5585005156886133 0.9053079143981784 0.9608383994442917


 88%|████████▊ | 88/100 [06:51<00:54,  4.57s/it]

0.5582096257630517 0.902281275006529 0.961454099800623
0.5600241700769173 0.9043366588329785 0.9608007073374489


 89%|████████▉ | 89/100 [06:55<00:49,  4.51s/it]

0.5610808945169636 0.9037466484586203 0.961872693497494
0.5592338145364713 0.9050501621218482 0.96080427994761


 90%|█████████ | 90/100 [06:59<00:44,  4.47s/it]

0.5600126084159402 0.9025278091943945 0.9614157913432997
0.5588892019684985 0.9051075811362481 0.9608481674311152


 91%|█████████ | 91/100 [07:04<00:40,  4.45s/it]

0.5613987506604662 0.9013617456640365 0.9617969927284169
0.5595833186082146 0.9049113081433622 0.9608405558461136


 92%|█████████▏| 92/100 [07:09<00:36,  4.53s/it]

0.5618715912688013 0.9029535069621523 0.9618489246096685
0.5591291035790809 0.9050572817804017 0.9607212905667852


 93%|█████████▎| 93/100 [07:13<00:31,  4.47s/it]

0.561071980233286 0.9022442337178811 0.9614887945726454
0.5587623250507887 0.9048974635029793 0.9609492487665103


 94%|█████████▍| 94/100 [07:17<00:26,  4.43s/it]

0.5554987598867978 0.9030441499123576 0.9620925005714883
0.5579435339962335 0.9052646031419777 0.9610365573688414


 95%|█████████▌| 95/100 [07:22<00:22,  4.45s/it]

0.5592061816477308 0.9018864502067014 0.9612531500562633
0.5597017447817302 0.9047107581961715 0.9607791433192314


 96%|█████████▌| 96/100 [07:26<00:17,  4.43s/it]

0.5616295591288922 0.9032216720916524 0.9608474248115089
0.5583269606053478 0.9053212843946578 0.9608543842006532


 97%|█████████▋| 97/100 [07:31<00:13,  4.43s/it]

0.5574330589350532 0.9027741238318613 0.9622134826836103
0.5598420275334993 0.9047125019965999 0.9607896793102751


 98%|█████████▊| 98/100 [07:36<00:09,  4.76s/it]

0.5588142169456856 0.9016752009042968 0.961525185910536
0.559430333392325 0.9046375959657743 0.9608220681233521


 99%|█████████▉| 99/100 [07:41<00:04,  4.80s/it]

0.5568529555610582 0.9035976469215585 0.9618692749172679
0.5579143247342063 0.9055014814048964 0.960877721688226


100%|██████████| 100/100 [07:47<00:00,  4.67s/it]

0.5570190098940158 0.9029221844386441 0.9618039656141386





# Key Concepts

## Difference between BCELoss and BCEWithLogitsLoss
[BCELoss and BCEWithLogitsLoss](https://blog.csdn.net/qq_22210253/article/details/85222093)

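$$
\mathrm{BCELoss} = -\frac{1}{N}\sum_{n=1}^{N}\bigl(y_{n}\ln x_{n} + (1-y_{n})\ln(1-x_{n})\bigr)
$$

Here $x_{n}$ is the predicted probability (the output of Sigmoid), $y_{n}$ is the 0/1 target, and $N$ is the number of elements.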

In [None]:
import torch
from torch import nn

In [None]:
input_ = torch.rand(3, 3)
input_

In [None]:
target = torch.FloatTensor([[0, 1, 1]
                           , [0, 0, 1]
                           , [1, 0, 1]])
activation = nn.Sigmoid()
predict = activation(input_)
predict

In [None]:
loss = nn.BCELoss()
loss(predict, target)

In [None]:
# BCEWithLogitsLoss = Sigmoid + BCELoss in one step, so it takes the raw input_ rather than predict
loss = nn.BCEWithLogitsLoss()
loss(input_, target)
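
The relationship between the two losses can be checked numerically. The cell below is a minimal sketch, not part of the original notebook: it assumes a fixed random seed and uses the illustrative names `logits`, `probs`, `manual`, `bce`, and `bce_logits`. It computes the loss three ways: from the formula above by hand, with Sigmoid followed by `nn.BCELoss`, and with `nn.BCEWithLogitsLoss` applied to the raw scores.

```python
import torch
from torch import nn

torch.manual_seed(0)  # assumption: fixed seed so the sketch is reproducible

logits = torch.randn(3, 3)  # raw scores, before Sigmoid
target = torch.tensor([[0., 1., 1.],
                       [0., 0., 1.],
                       [1., 0., 1.]])

probs = torch.sigmoid(logits)  # probabilities in (0, 1)

# 1) The formula written out by hand, averaged over all elements
manual = -(target * torch.log(probs)
           + (1 - target) * torch.log(1 - probs)).mean()

# 2) Sigmoid followed by BCELoss (expects probabilities)
bce = nn.BCELoss()(probs, target)

# 3) BCEWithLogitsLoss (expects raw scores and applies Sigmoid internally)
bce_logits = nn.BCEWithLogitsLoss()(logits, target)

print(manual.item(), bce.item(), bce_logits.item())
```

All three values should agree up to floating-point error. In practice `BCEWithLogitsLoss` is usually the safer choice because it fuses the two steps in a numerically more stable way, while `BCELoss` must only be given values that have already passed through Sigmoid.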

## The einsum function
[Learn einsum in PyTorch in one article (in Chinese)](https://zhuanlan.zhihu.com/p/361209187)

In [None]:
import numpy as np
A = torch.tensor([[5], [3]])       # shape (2, 1)

B = torch.tensor([[[0, 1, 0],
                   [1, 1, 0],
                   [1, 1, 1]]])    # shape (1, 3, 3)

In [None]:
A.shape, B.shape

In [None]:
torch.einsum('ij,jkl->ikl', A, B)

In [None]:
C = B.permute(1, 0, 2)
C

In [None]:
# Same result as the einsum above: batched matmul, then swap the first two axes back
torch.matmul(A, C).permute(1, 0, 2)
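
To make the subscript string `'ij,jkl->ikl'` concrete, the sketch below spells out the same contraction with explicit loops and compares it with both `torch.einsum` and the `matmul` plus `permute` route. It is an illustration under assumptions: the tensors are created as floats (the cells above use integers), and the names `via_einsum`, `explicit`, and `via_matmul` are introduced here only for the comparison.

```python
import torch

A = torch.tensor([[5.], [3.]])            # shape (2, 1): indices i, j
B = torch.tensor([[[0., 1., 0.],
                   [1., 1., 0.],
                   [1., 1., 1.]]])        # shape (1, 3, 3): indices j, k, l

# einsum: out[i, k, l] = sum over j of A[i, j] * B[j, k, l]
via_einsum = torch.einsum('ij,jkl->ikl', A, B)

# The same contraction with explicit loops over i and the summed index j
explicit = torch.zeros(2, 3, 3)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        explicit[i] += A[i, j] * B[j]

# matmul route: move j next to A's j axis, multiply, then restore axis order
via_matmul = torch.matmul(A, B.permute(1, 0, 2)).permute(1, 0, 2)

print(via_einsum.shape)
print(torch.equal(via_einsum, explicit), torch.equal(via_einsum, via_matmul))
```

The index `j` appears in both inputs but not in the output, so it is the one summed over; `i`, `k`, and `l` survive, which gives the output shape `(2, 3, 3)`.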

## Using nn.Parameter
[PyTorch study notes (16): Parameters (in Chinese)](https://blog.csdn.net/qq_43328040/article/details/107761093?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_title~default-1.no_search_link&spm=1001.2101.3001.4242.2)
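
As a small illustration of what `nn.Parameter` does, the sketch below defines a hypothetical module, `ScaledLinear`, which is not taken from this notebook or from the linked article. It shows that a tensor wrapped in `nn.Parameter` and assigned as a module attribute is registered automatically, appears in `named_parameters()`, and receives gradients, while a plain tensor attribute is not registered.

```python
import torch
from torch import nn


class ScaledLinear(nn.Module):
    """Hypothetical module used only to illustrate nn.Parameter."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Wrapped in nn.Parameter: registered with the module and trainable
        self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.01)
        self.scale = nn.Parameter(torch.ones(1))
        # Plain tensor attribute: NOT registered, invisible to optimizers
        self.plain = torch.zeros(out_features)

    def forward(self, x):
        return self.scale * (x @ self.weight) + self.plain


model = ScaledLinear(4, 2)

# Only the nn.Parameter attributes are listed, in registration order
names = [name for name, _ in model.named_parameters()]
print(names)

# Gradients flow into registered parameters after backward()
model(torch.randn(3, 4)).sum().backward()
print(model.weight.grad.shape, model.scale.grad.shape)
```

Because an optimizer is normally built from `model.parameters()`, anything that should be learned has to be an `nn.Parameter` (or live inside a submodule); a plain tensor such as `plain` here would stay fixed during training.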