# 多目标学习

## 数据
### 人口普查数据简介


数据集信息：
```
Barry Becker 从 1994 年人口普查数据库中提取。使用以下条件提取了一组合理干净的记录： ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0))

预测任务是确定一个人是否赚了超过 50K年。

>50K，<=50K。
属性列表：
年龄：连续。
工作班级：私人、Self-emp-not-inc、Self-emp-inc、联邦政府、地方政府、州政府、无薪、从未工作。
fnlwgt：连续。
教育：学士，一些大学，11th，HS-grad，教授学校，Assoc-acdm，Assoc-voc，9th，7th-8th，12th，硕士，1st-4th，10th，博士，5th-6th，学前班。
教育编号：连续。
婚姻状况：已婚公民配偶、离婚、未婚、分居、丧偶、已婚配偶缺席、已婚 AF 配偶。
职业：技术支持、工艺维修、其他服务、销售、执行管理、专业教授、处理清洁工、机器操作检查、行政文员、农业-捕鱼、运输-搬家、私人住宅- serv，保护性服务，武装部队。
关系：妻子、自己的孩子、丈夫、非家庭成员、其他亲属、未婚。
种族：白人、亚太岛民、美洲印第安人-爱斯基摩人、其他、黑人。
性别：女，男。
资本收益：持续。
资本损失：持续。
每周小时数：连续。
祖国：美国、柬埔寨、英国、波多黎各、加拿大、德国、美国边远地区（关岛-USVI-etc）、印度、日本、希腊、南部、中国、古巴、伊朗、洪都拉斯、菲律宾、意大利、波兰、牙买加、越南、墨西哥、葡萄牙、爱尔兰、法国、多米尼加共和国、老挝、厄瓜多尔、台湾、海地、哥伦比亚、匈牙利、危地马拉、尼加拉瓜、苏格兰、泰国、南斯拉夫、萨尔瓦多、特立纳达和多巴哥、秘鲁、香港，荷兰-荷兰。
```


### EDA

In [1]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelEncoder

In [2]:
!ls data/

adult.data  adult.names  adult.test


In [3]:
column_names = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation',
                'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
                'income_50k']
train_df = pd.read_csv("./data/adult.data", header=None, names=column_names)
train_df.head()

Unnamed: 0,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income_50k
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K


In [4]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32561 entries, 0 to 32560
Data columns (total 15 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             32561 non-null  int64 
 1   workclass       32561 non-null  object
 2   fnlwgt          32561 non-null  int64 
 3   education       32561 non-null  object
 4   education_num   32561 non-null  int64 
 5   marital_status  32561 non-null  object
 6   occupation      32561 non-null  object
 7   relationship    32561 non-null  object
 8   race            32561 non-null  object
 9   sex             32561 non-null  object
 10  capital_gain    32561 non-null  int64 
 11  capital_loss    32561 non-null  int64 
 12  hours_per_week  32561 non-null  int64 
 13  native_country  32561 non-null  object
 14  income_50k      32561 non-null  object
dtypes: int64(6), object(9)
memory usage: 3.7+ MB


In [5]:
test_df = pd.read_csv("./data/adult.test", delimiter=","
                      , names=column_names
                      , header=None
                     )
test_df.head()

Unnamed: 0,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income_50k
0,|1x3 Cross validator,,,,,,,,,,,,,,
1,25,Private,226802.0,11th,7.0,Never-married,Machine-op-inspct,Own-child,Black,Male,0.0,0.0,40.0,United-States,<=50K.
2,38,Private,89814.0,HS-grad,9.0,Married-civ-spouse,Farming-fishing,Husband,White,Male,0.0,0.0,50.0,United-States,<=50K.
3,28,Local-gov,336951.0,Assoc-acdm,12.0,Married-civ-spouse,Protective-serv,Husband,White,Male,0.0,0.0,40.0,United-States,>50K.
4,44,Private,160323.0,Some-college,10.0,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688.0,0.0,40.0,United-States,>50K.


In [6]:
# 设置标记
train_df["tag"] = 1
test_df["tag"] = 0
# 划分数据
test_df.dropna(inplace=True)
# 规范化数据
test_df["income_50k"] = test_df["income_50k"].apply(lambda x: x[:-1])

In [7]:
# 合并数据
data = pd.concat([train_df, test_df])
data.reset_index(inplace=True, drop=True)

### 数据预处理

In [8]:
label_columns = ['income_50k', 'marital_status']

# categorical columns
categorical_columns = ['workclass', 'education', 'occupation', 'relationship', 'race', 'sex', 'native_country']
for col in label_columns:
    if col == 'income_50k':
        data[col] = data[col].apply(lambda x: 0 if x == ' <=50K' else 1)
    else:
        data[col] = data[col].apply(lambda x: 0 if x == ' Never-married' else 1)

In [9]:
# feature engine
for col in column_names:
    if col not in label_columns + ['tag']:
        if col in categorical_columns:
            le = LabelEncoder()
            data[col] = le.fit_transform(data[col])
        else:
            mm = MinMaxScaler()
            data[col] = mm.fit_transform(data[[col]]).reshape(-1)
            
data = data[['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'occupation',
             'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
             'income_50k', 'marital_status', 'tag']]

In [10]:
# 编码用户特征和物品特征
user_feat_dict, item_feat_dict = dict(), dict()

for idx, col in enumerate(data.columns):
    if col not in label_columns + ["tag"]:
        #　用户特征
        if idx < 7:
            if col in categorical_columns:
                user_feat_dict[col] = (data[col].nunique() + 1, idx)
            else:
                user_feat_dict[col] = (1, idx)
        # 物品特征
        else:
            if col in categorical_columns:
                item_feat_dict[col] = (data[col].nunique() + 1, idx)
            else:
                item_feat_dict[col] = (1, idx)

                user_feat_dict, item_feat_dict

In [11]:
# 重新划分数据集
train_data, test_data = data[data["tag"] == 1], data[data["tag"] == 0]

In [12]:
train_data.drop(columns="tag", inplace=True)
test_data.drop(columns="tag", inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return super().drop(


## 算法

In [13]:
import torch
import torch.nn as nn
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
import numpy as np
import torchsnooper
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

### 自定义数据格式

In [14]:
class TrainDateSet(Dataset):
    
    def __init__(self, data):
        
        self.features = data[0]
        self.label1 = data[1]
        self.label2 = data[2]
    
    def __getitem__(self, index):
        
        return self.features[index], self.label1[index], self.label2[index]
    
    def __len__(self):
        
        return len(self.features)

In [15]:
# 加载Dataloader
train_datasets = (train_data.iloc[:, :-2].values, train_data.iloc[:, -2].values, train_data.iloc[:, -1].values)
test_datasets = (test_data.iloc[:, :-2].values, test_data.iloc[:, -2].values, test_data.iloc[:, -1].values)
train_datasets = TrainDateSet(train_datasets)
test_datasets = TrainDateSet(test_datasets)

### ESMM模型
![](./imgs/ESMM.png)

In [16]:
class ESMM(nn.Module):

    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, hidden_dim=[128, 64], dropouts=[0.5, 0.5],
                 output_size=1, task_name=["ctr", "cvr"]):
        """

        :param user_feature_dict: 用户特征
        :param item_feature_dict:　物品特征
        :param emb_dim: 128
        :param hidden_dim: [128, 64]
        :param dropout: 0.5
        :param output_size: 1
        :param num_tasks:2
        """
        super(ESMM, self).__init__()

        if user_feature_dict is None or item_feature_dict is None:
            Exception("用户特征和物品特征不能为空！")
        if isinstance(user_feature_dict, dict) is False or isinstance(item_feature_dict, dict):
            Exception("输入数据类型必须为字典类型！")

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.num_tasks = len(task_name)
        self.task_name = task_name

        # 共享Embedding(Share bottom)
        user_cate_feature_nums, item_cate_feature_nums = 0, 0
        
        # 用户特征Embedding编码
        for user_cate, num in self.user_feature_dict.items():
            # 必须为Spase Feature
            if num[0] > 1:
                user_cate_feature_nums += 1
                setattr(self, user_cate, nn.Embedding(num[0], emb_dim))
                
        # 物品特征
        for item_cate, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_cate_feature_nums += 1
                setattr(self, item_cate, nn.Embedding(num[0], emb_dim))

        # 构建独立任务（tower）
        # Spase feat + Dense feat
        hidden_size = emb_dim * (user_cate_feature_nums + item_cate_feature_nums) \
                      + (len(self.user_feature_dict) - user_cate_feature_nums) \
                      + (len(self.item_feature_dict) - item_cate_feature_nums)

        for i in range(self.num_tasks):
            setattr(self, 'task_{}_dnn'.format(self.task_name[i]), nn.ModuleList())
            hid_dim = [hidden_size] + hidden_dim
            for j in range(len(hid_dim) - 1):
                getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('hidden_{}'.format(j),
                                                                      nn.Linear(hid_dim[j], hid_dim[j + 1]))
                getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('batchnorm_{}'.format(j),
                                                                      nn.BatchNorm1d(hid_dim[j + 1]))
                getattr(self, "task_{}_dnn".format(self.task_name[i])).add_module("{}_activation".format(task_name[i])
                                                                             , nn.ReLU())
                getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('dropout_{}'.format(j),
                                                                      nn.Dropout(dropouts[j]))
            getattr(self, 'task_{}_dnn'.format(self.task_name[i])).add_module('task_{}_last_layer'.format(j),
                                                                  nn.Linear(hid_dim[-1], output_size))

    def forward(self, x):
#         assert x.size()[1] != len(self.item_feature_dict) + len(self.user_feature_dict)
        # 编码Embedding向量
        user_embed_list, item_embed_list = list(), list()
        for user_feature, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_embed_list.append(getattr(self, user_feature)(x[:, num[1]].long()))
            else:
                user_embed_list.append(x[:, num[1]].unsqueeze(1))
        for item_feature, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_embed_list.append(getattr(self, item_feature)(x[:, num[1]].long()))
            else:
                item_embed_list.append(x[:, num[1]].unsqueeze(1))
        # 拼接向量
        user_embed = torch.cat(user_embed_list, dim=1)
        item_embed = torch.cat(item_embed_list, dim=1)
        # hidden_input
        hidden = torch.cat([user_embed, item_embed], axis=1).float()

        # 子网络
        task_outputs = list()
        for i in range(self.num_tasks):
            x = hidden
            #　Module list
            for mod in getattr(self,'task_{}_dnn'.format(self.task_name[i])):
                x = mod(x)
            task_outputs.append(x)
        
        if self.num_tasks == 2:
            
            pCTCVR = torch.mul(task_outputs[0], task_outputs[1])
            pCVR = task_outputs[0]
            
            return pCTCVR, pCVR
        
        elif len(self.num_tasks) == 1:
            return task_outputs
        else:
            Exception("目标数目为：1或２!")

In [17]:
a = torch.from_numpy(np.array([[1, 2, 4, 2, 0.5, 0.1],
                               [4, 5, 3, 8, 0.6, 0.43],
                               [6, 3, 2, 9, 0.12, 0.32],
                               [9, 1, 1, 1, 0.12, 0.45],
                               [8, 3, 1, 4, 0.21, 0.67]]))

user_cate_dict = {'user_id': (11, 0), 'user_list': (12, 3), 'user_num': (1, 4)}
item_cate_dict = {'item_id': (8, 1), 'item_cate': (6, 2), 'item_num': (1, 5)}
esmm = ESMM(user_cate_dict, item_cate_dict)
esmm

ESMM(
  (user_id): Embedding(11, 128)
  (user_list): Embedding(12, 128)
  (item_id): Embedding(8, 128)
  (item_cate): Embedding(6, 128)
  (task_ctr_dnn): ModuleList(
    (hidden_0): Linear(in_features=514, out_features=128, bias=True)
    (batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ctr_activation): ReLU()
    (dropout_0): Dropout(p=0.5, inplace=False)
    (hidden_1): Linear(in_features=128, out_features=64, bias=True)
    (batchnorm_1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout_1): Dropout(p=0.5, inplace=False)
    (task_1_last_layer): Linear(in_features=64, out_features=1, bias=True)
  )
  (task_cvr_dnn): ModuleList(
    (hidden_0): Linear(in_features=514, out_features=128, bias=True)
    (batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (cvr_activation): ReLU()
    (dropout_0): Dropout(p=0.5, inplace=False)
    (hidden_1): Linea

In [18]:
tasks = esmm(a)
print(tasks)

(tensor([[-1.4044],
        [-0.4822],
        [ 0.0017],
        [ 0.0052],
        [-0.4014]], grad_fn=<MulBackward0>), tensor([[-1.1930],
        [ 0.6939],
        [-0.7318],
        [-0.2432],
        [ 0.5550]], grad_fn=<AddmmBackward>))


In [19]:
w = SummaryWriter(log_dir="./log", comment="model info")
w.add_graph(esmm, a)
w.close()

### MMoE
![](./imgs/mmoe.png)

In [20]:
# @torchsnooper.snoop()
class MMoE(nn.Module):

    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, n_expert=3, mmoe_hidden_dim=128,
                 hidden_dim=[128, 64], output_size=1, num_tasks=2, expert_activation=None):
        """

        :param user_feature_dict:
        :param item_feature_dict:
        :param emb_dim:
        :param n_expert:
        :param mmoe_hidden_dim:
        :param hidden_dim:
        :param output_size:
        :param num_tasks:
        """
        super(MMoE, self).__init__()

        if user_feature_dict is None or item_feature_dict is None:
            Exception("用户特征和物品特征不能为空！")
        if isinstance(user_feature_dict, dict) is False or isinstance(item_feature_dict, dict):
            Exception("输入数据类型必须为字典类型！")

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.num_tasks = num_tasks

        # 共享Embedding(Share bottom)
        user_cate_feature_nums, item_cate_feature_nums = 0, 0
        # 用户特征Embedding编码
        for user_cate, num in self.user_feature_dict.items():
            # 必须为Spase Feature
            if num[0] > 1:
                user_cate_feature_nums += 1
                setattr(self, user_cate, nn.Embedding(num[0], emb_dim))
        # 物品特征
        for item_cate, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_cate_feature_nums += 1
                setattr(self, item_cate, nn.Embedding(num[0], emb_dim))

        # 构建独立任务（tower）
        # Spase feat + Dense feat
        hidden_size = emb_dim * (user_cate_feature_nums + item_cate_feature_nums) \
                      + (len(self.user_feature_dict) - user_cate_feature_nums) \
                      + (len(self.item_feature_dict) - item_cate_feature_nums)

        # 专家网络
        self.erperts = torch.nn.Parameter(torch.rand(hidden_size, mmoe_hidden_dim, n_expert), requires_grad=True)
        self.erperts.data.normal_(0, 1)
        self.erperts_bias = torch.nn.Parameter(torch.rand(mmoe_hidden_dim, n_expert), requires_grad=True)

        # 门控网络
        self.gates = torch.nn.ParameterList([torch.nn.Parameter(torch.rand(hidden_size, n_expert), requires_grad=True)
                      for _ in range(num_tasks)])
        for gate in self.gates:
            gate.data.normal_(0, 1,)

        self.gate_bias = torch.nn.ParameterList([torch.nn.Parameter(torch.rand(n_expert), requires_grad=True) for _ in range(num_tasks)])

        for i in range(self.num_tasks):
            setattr(self, 'task_{}_dnn'.format(i + 1), nn.ModuleList())
            # input: mmoe_hidden_dim + hidden_dim
            hid_dim = [mmoe_hidden_dim] + hidden_dim
            for j in range(len(hid_dim) - 1):
                getattr(self, 'task_{}_dnn'.format(i + 1)).add_module('hidden_{}'.format(j),
                                                                      nn.Linear(hid_dim[j], hid_dim[j + 1]))
                getattr(self, 'task_{}_dnn'.format(i + 1)).add_module('batchnorm_{}'.format(j),
                                                                      nn.BatchNorm1d(hid_dim[j + 1]))
            getattr(self, 'task_{}_dnn'.format(i + 1)).add_module('task_last_layer',
                                                                  nn.Linear(hid_dim[-1], output_size))
            
        self.Softmax = nn.Softmax(dim=-1)

    def forward(self, x):

        assert x.size()[1] == len(self.item_feature_dict) + len(self.user_feature_dict)
        # 编码Embedding向量
        user_embed_list, item_embed_list = list(), list()
        for user_feature, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_embed_list.append(getattr(self, user_feature)(x[:, num[1]].long()))
            else:
                user_embed_list.append(x[:, num[1]].unsqueeze(1))
        for item_feature, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_embed_list.append(getattr(self, item_feature)(x[:, num[1]].long()))
            else:
                item_embed_list.append(x[:, num[1]].unsqueeze(1))
        # 拼接向量
        user_embed = torch.cat(user_embed_list, dim=1)
        item_embed = torch.cat(item_embed_list, dim=1)
        # hidden_input
        # B*hidden
        hidden = torch.cat([user_embed, item_embed], dim=1).float()
        # MMoE
        expert_outs = torch.matmul(hidden, self.erperts.permute(1, 0, 2)).permute(1, 0, 2)  # B*mmoe_hidden_dim*experts
        expert_outs += self.erperts_bias
        # 门控单元
        gates_out = list()
        for idx, gate in enumerate(self.gates):
            gate_out = torch.mm(hidden, gate)  # B * num_experts
            if self.gate_bias:
                gate_out += self.gate_bias[idx]
            # 归一化
            gate_out = self.Softmax(gate_out)
            gates_out.append(gate_out)
        # 各个模块
        outs = list()
        for gate_out in gates_out:
            expand_gate_out = torch.unsqueeze(gate_out, dim=1)  # B * 1 * experts
            weighted_expert_output = expert_outs * expand_gate_out.expand_as(expert_outs)  # B * mmoe_hidden * expert
            outs.append(torch.sum(weighted_expert_output, 2))  # B * mmoe_hidden

        # task_tower
        task_outputs = list()
        for i in range(self.num_tasks):
            x = outs[i]
            for mod in getattr(self, 'task_{}_dnn'.format(i + 1)):
                x = mod(x)
            task_outputs.append(x)

        return task_outputs

In [21]:
mmoe = MMoE(user_cate_dict, item_cate_dict)
mmoe

MMoE(
  (user_id): Embedding(11, 128)
  (user_list): Embedding(12, 128)
  (item_id): Embedding(8, 128)
  (item_cate): Embedding(6, 128)
  (gates): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 514x3]
      (1): Parameter containing: [torch.FloatTensor of size 514x3]
  )
  (gate_bias): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 3]
      (1): Parameter containing: [torch.FloatTensor of size 3]
  )
  (task_1_dnn): ModuleList(
    (hidden_0): Linear(in_features=128, out_features=128, bias=True)
    (batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (hidden_1): Linear(in_features=128, out_features=64, bias=True)
    (batchnorm_1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (task_last_layer): Linear(in_features=64, out_features=1, bias=True)
  )
  (task_2_dnn): ModuleList(
    (hidden_0): Linear(in_features=128, out_features=128, bias=True)
  

In [22]:
outs = mmoe(a)
outs

[tensor([[-0.1936],
         [-0.5239],
         [ 0.5940],
         [ 0.3665],
         [-0.0912]], grad_fn=<AddmmBackward>),
 tensor([[ 0.2057],
         [ 0.2599],
         [-0.3309],
         [-0.3445],
         [ 0.1629]], grad_fn=<AddmmBackward>)]

In [23]:
w = SummaryWriter(log_dir="./log", comment="model info")
w.add_graph(esmm, a)
w.close()

In [29]:
!tensorboard --logdir ./log/

2021-11-22 19:29:05.920514: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lcoal/cuda-10.1/lib64:
2021-11-22 19:29:05.920554: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
W1122 19:29:16.926579 140120317765376 plugin_event_accumulator.py:320] Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events.  Overwriting the graph with the newest event.
W1122 19:29:16.926889 140120317765376 plugin_event_accumulator.py:358] Found more than one "run metadata" event with tag step1. Overwriting it with the newest event.
W1122 19:29:16.941798 140120317765376 plugin_event_accumulator.py:320] Found more than one graph event per run, or there was a metagraph containing a graph_

## 训练与评估

In [24]:
# 定义超参数
learning_rate = 0.01
epochs = 100
count = 0
writer = SummaryWriter("./log", comment="mertics")
device = torch.device("cuda")

model = MMoE(user_feat_dict, item_feat_dict)
model.to(device)

optimizer = Adam(model.parameters(), lr = learning_rate)
loss_fun = nn.BCEWithLogitsLoss()

train_dataload = DataLoader(train_datasets, batch_size = 64, shuffle = True)
test_dataload = DataLoader(test_datasets, batch_size = 64)

In [25]:
for epoch in tqdm(range(epochs)):
    y_train_income_true = []
    y_train_income_predict = []
    y_train_marry_true = []
    y_train_marry_predict = []
    total_loss, count = 0, 0
    for x, y1, y2 in train_dataload:
        x, y1, y2 = x.to(device), y1.to(device), y2.to(device)
        predict = model(x)
        y_train_income_true += list(y1.squeeze().cpu().numpy())
        y_train_marry_true += list(y2.squeeze().cpu().numpy())
        y_train_income_predict += list(predict[0].squeeze().cpu().detach().numpy())
        y_train_marry_predict += list(predict[1].squeeze().cpu().detach().numpy())
        loss1 = loss_fun(predict[0], y1.unsqueeze(1).float())
        loss2 = loss_fun(predict[1], y2.unsqueeze(1).float())
        loss = loss1 + loss2
        # 梯度更新
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += float(loss)
        count += 1
    
    y1_auc = roc_auc_score(y_train_income_true, y_train_income_predict)
    y2_auc = roc_auc_score(y_train_marry_true, y_train_marry_predict)
    loss_value = total_loss / count
    print(loss_value, y1_auc, y2_auc)
    writer.add_scalar("Train loss", loss_value, global_step = epoch)
    writer.add_scalar("Train_y1_auc", y1_auc, global_step = epoch)
    writer.add_scalar("Train_y2_auc", y2_auc, global_step = epoch)
    
    # 验证
    total_eval_loss = 0
    model.eval()
    count_eval = 0
    y_val_income_true = []
    y_val_marry_true = []
    y_val_income_predict = []
    y_val_marry_predict = []
    for x, y1, y2 in test_dataload:
        x, y1, y2 = x.to(device), y1.to(device), y2.to(device)
        predict = model(x)
        y_val_income_true += list(y1.squeeze().cpu().numpy())
        y_val_marry_true += list(y2.squeeze().cpu().numpy())
        y_val_income_predict += list(predict[0].squeeze().cpu().detach().numpy())
        y_val_marry_predict += list(predict[1].squeeze().cpu().detach().numpy())
        loss_1 = loss_fun(predict[0], y1.unsqueeze(1).float())
        loss_2 = loss_fun(predict[1], y2.unsqueeze(1).float())
        loss = loss_1 + loss_2
        total_eval_loss += float(loss)
        count_eval += 1
        
    y1_val_auc = roc_auc_score(y_val_income_true, y_val_income_predict)
    y2_val_auc = roc_auc_score(y_val_marry_true, y_val_marry_predict)
    print(total_eval_loss / count_eval, y1_val_auc, y2_val_auc)
    writer.add_scalar("Val loss",   total_eval_loss / count_eval, global_step = epoch + 1)
    writer.add_scalar("Val_y1_auc", y1_val_auc, global_step = epoch + 1)
    writer.add_scalar("Val_y2_auc", y2_val_auc, global_step = epoch + 1)
    
writer.close()        

  0%|          | 0/100 [00:00<?, ?it/s]

0.7134879486148624 0.8589090067395307 0.9182721976730421


  1%|          | 1/100 [00:06<10:10,  6.16s/it]

0.6581966375603395 0.8767561784095812 0.9326415761314085
0.6719974691600837 0.8717420003929226 0.9291301322352661


  2%|▏         | 2/100 [00:11<09:27,  5.79s/it]

0.670862022451326 0.8756476266288288 0.9347652694236518
0.6594614575677396 0.8756502569887188 0.9330729396840634


  3%|▎         | 3/100 [00:16<08:56,  5.53s/it]

0.6431583658152935 0.8800620010325142 0.9407865343838423
0.6481030554106287 0.8770978822008122 0.9381252372630308


  4%|▍         | 4/100 [00:21<08:47,  5.49s/it]

0.6369938169039931 0.8805162612616286 0.9439524365722732
0.6327112291908451 0.878372615791444 0.9450717092172237


  5%|▌         | 5/100 [00:25<08:14,  5.21s/it]

0.6145372372047574 0.8841828992821957 0.9539998644104557
0.6129984452822822 0.8819975177155678 0.9520447211724097


  6%|▌         | 6/100 [00:31<08:12,  5.24s/it]

0.6020808457159529 0.8856642685490291 0.9559097310233933
0.6031981532723589 0.8846006531925581 0.9552584059516622


  7%|▋         | 7/100 [00:36<08:10,  5.27s/it]

0.6065269106743383 0.8871708442925573 0.9572458529907353
0.5966423655071521 0.8873194753822844 0.9562008455371342


  8%|▊         | 8/100 [00:41<07:52,  5.13s/it]

0.6009359171577529 0.8886856270390743 0.9584637836968095
0.5920291548276463 0.8885254320394541 0.9576879776756276


  9%|▉         | 9/100 [00:46<07:33,  4.98s/it]

0.5917320820630766 0.8903621347909806 0.9604472388543173
0.5871770969191328 0.890815965493801 0.9584202429585416


 10%|█         | 10/100 [00:50<07:14,  4.83s/it]

0.5918824527777877 0.8928483444122647 0.9604264898537387
0.597069631918006 0.8844444128015174 0.9578501185554638


 11%|█         | 11/100 [00:54<06:55,  4.67s/it]

0.5857271664282855 0.8879804207045643 0.9599228388270231
0.5882012526740727 0.8893357162520963 0.9588848191938792


 12%|█▏        | 12/100 [00:59<06:41,  4.56s/it]

0.5786383621832903 0.8926936659291864 0.9594545611986767
0.5801289377840189 0.8943497280496799 0.9588654158560586


 13%|█▎        | 13/100 [01:03<06:29,  4.48s/it]

0.5711258086503721 0.8966299536581382 0.9604757411609173
0.5765526722128592 0.8958253443541521 0.959559124759971


 14%|█▍        | 14/100 [01:07<06:20,  4.43s/it]

0.5742072131119522 0.8965570420163007 0.960804696800097
0.5758752908360044 0.8963457759168986 0.9594951001630248


 15%|█▌        | 15/100 [01:12<06:13,  4.39s/it]

0.5789870921303244 0.8972899012462308 0.9616642534148223
0.5720961389583783 0.8985082174273558 0.9599123660548228


 16%|█▌        | 16/100 [01:16<06:12,  4.43s/it]

0.5673243423302968 0.8981047677773617 0.9615191461360424
0.5767200903011914 0.8969289559196143 0.9590380456626129


 17%|█▋        | 17/100 [01:21<06:15,  4.52s/it]

0.5695473646416384 0.8985812339610594 0.9599789018461656
0.571458490693499 0.8986316145239382 0.9598926760524741


 18%|█▊        | 18/100 [01:26<06:18,  4.61s/it]

0.576927653130363 0.8991555359842058 0.9609567006191346
0.5696167560598002 0.899577399252704 0.960079031527767


 19%|█▉        | 19/100 [01:30<06:09,  4.56s/it]

0.5644997853858799 0.8985419867136462 0.9616886500320194
0.5688184966975439 0.9001408841130084 0.9599245813865709


 20%|██        | 20/100 [01:35<06:02,  4.53s/it]

0.5623557502148198 0.9005964870681679 0.9619821050303891
0.5657414086675363 0.9016163430627079 0.9603571752734611


 21%|██        | 21/100 [01:39<05:58,  4.53s/it]

0.5639102150412166 0.9003427599910592 0.9619802048766157
0.5652050966715297 0.9022331273378791 0.9603639418398917


 22%|██▏       | 22/100 [01:43<05:51,  4.50s/it]

0.5663874039462968 0.9014121586174263 0.9619516516730395
0.5636013382076046 0.9026389040224626 0.9604558263782315


 23%|██▎       | 23/100 [01:48<05:45,  4.49s/it]

0.5637088349052504 0.9015424670062798 0.9613051073860028
0.5643054437777851 0.9023661075980585 0.9604505508952031


 24%|██▍       | 24/100 [01:52<05:41,  4.49s/it]

0.5673334786704942 0.9010820802755714 0.9618347243533446
0.5678452048414114 0.9014125608937174 0.9597813689227259


 25%|██▌       | 25/100 [01:57<05:45,  4.60s/it]

0.5709808607896169 0.8979505806689847 0.9598668776018322
0.5664796483774316 0.9020793375539495 0.9598746889229925


 26%|██▌       | 26/100 [02:03<05:55,  4.80s/it]

0.5597550635244332 0.9017759536276102 0.9618067819134813
0.5625866053732065 0.9032146573958394 0.9603330355530675


 27%|██▋       | 27/100 [02:07<05:48,  4.77s/it]

0.560952916098576 0.9019047669827983 0.9615659204570521
0.5630473503897842 0.9033175184048333 0.9604169384096641


 28%|██▊       | 28/100 [02:12<05:42,  4.76s/it]

0.559250314913544 0.9020638260190641 0.9616622345014381
0.5616538259978379 0.9037758464242185 0.9606343285290324


 29%|██▉       | 29/100 [02:16<05:31,  4.66s/it]

0.5653258394961264 0.9013843279907312 0.9614761721225794
0.5625163927410816 0.9032063459683541 0.9604261694353198


 30%|███       | 30/100 [02:21<05:25,  4.65s/it]

0.5588048206824883 0.9021719911820196 0.9614413246596291
0.5658829324606125 0.9022384645022079 0.9602149276282881


 31%|███       | 31/100 [02:26<05:22,  4.67s/it]

0.5637386310334299 0.9019303707411667 0.9612244526445883
0.5628174183293738 0.9033843529097115 0.9602288928972291


 32%|███▏      | 32/100 [02:30<05:14,  4.63s/it]

0.5715660269353904 0.90178262377781 0.9605305572041464
0.5659629859484014 0.9029971956799975 0.9596826408036045


 33%|███▎      | 33/100 [02:35<05:15,  4.71s/it]

0.5611193614847519 0.9026428954222906 0.9611516954172938
0.5632050235285506 0.9032301529715392 0.9605891446372893


 34%|███▍      | 34/100 [02:39<05:01,  4.56s/it]

0.5628741932850257 0.901697574135374 0.9615892652034103
0.5614604462809085 0.9037683320889408 0.9606115066097524


 35%|███▌      | 35/100 [02:44<05:04,  4.68s/it]

0.5591508630444022 0.9018142181256209 0.9616825084635734
0.5605278467733874 0.9044653595592662 0.9608013705165807


 36%|███▌      | 36/100 [02:49<04:59,  4.69s/it]

0.5591440302484175 0.9023613168089248 0.9619334135899473
0.5642950139603118 0.9022915910847843 0.960230135823279


 37%|███▋      | 37/100 [02:54<04:59,  4.75s/it]

0.575007000507093 0.9030088232077735 0.9617669635125347
0.5613171691861742 0.9040523471347399 0.9607444932792443


 38%|███▊      | 38/100 [02:59<04:54,  4.76s/it]

0.5644969845519346 0.9024510501931938 0.9615841924714618
0.5605637835730269 0.9045726007060224 0.9604954331474063


 39%|███▉      | 39/100 [03:04<04:51,  4.78s/it]

0.5731830439146828 0.9033469935500275 0.9617310641787455
0.5612844776779354 0.9041460815669358 0.9606340525609424


 40%|████      | 40/100 [03:08<04:43,  4.73s/it]

0.5571377619808795 0.9032654358044044 0.9611226417267864
0.5609285147578871 0.9042100062983183 0.9604774182071868


 41%|████      | 41/100 [03:13<04:35,  4.67s/it]

0.5614533256081974 0.9022648610005517 0.9615115285552904
0.5612814498439519 0.9038790943711773 0.9606930626679377


 42%|████▏     | 42/100 [03:18<04:34,  4.73s/it]

0.5651972228405522 0.9025317924659085 0.9616566188684116
0.5596658104997722 0.9048875604706652 0.9607437145785864


 43%|████▎     | 43/100 [03:23<04:33,  4.79s/it]

0.5652201705119189 0.9027948661171215 0.9615519916512671


 43%|████▎     | 43/100 [03:24<04:31,  4.75s/it]


KeyboardInterrupt: 

# 知识点

## BCELoss和BCEWithLogitsLoss的区别
[BCELoss和BCEWithLogitsLoss](https://blog.csdn.net/qq_22210253/article/details/85222093)

$$
BCELoss = − n/1∑(y_{n}×lnx_{n}+(1−y_{n})×ln(1−x_{n}))
$$

In [None]:
import torch
from torch import nn

In [None]:
input_ = torch.rand(3, 3)
input_

In [None]:
target = torch.FloatTensor([[0, 1, 1]
                           , [0, 0, 1]
                           , [1, 0, 1]])
activation = nn.Sigmoid()
predict = activation(input_)
predict

In [None]:
loss = nn.BCELoss()
loss(predict, target)

In [None]:
# Sigmoid + BCEloss
loss = nn.BCEWithLogitsLoss()
loss(input_, target)

## einsum函数
[一文学会 Pytorch 中的 einsum](https://zhuanlan.zhihu.com/p/361209187)

In [None]:
import numpy as np
A = torch.tensor([[5],[3]])

B = torch.tensor([[[0, 1, 0],
              [1, 1, 0],

              [1, 1, 1]]])

In [None]:
A.shape, B.shape

In [None]:
torch.einsum('ij,jkl->ikl', A, B)

In [None]:
C = B.permute(1, 0, 2)
C

In [None]:
torch.matmul(A, C).permute(1,0, 2)

## nn.Parameter的使用
[pytorch学习笔记（十六）：Parameters](https://blog.csdn.net/qq_43328040/article/details/107761093?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_title~default-1.no_search_link&spm=1001.2101.3001.4242.2)