In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np

In [3]:
# ! pip install transformers

In [None]:
# pip install transformers
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from transformers import BertTokenizer, BertModel
import pandas as pd
import numpy as np
# MODEL_NAME = "hfl/chinese-pert-large"
MODEL_NAME = "hfl/chinese-roberta-wwm-ext"
# MODEL_NAME = "hfl/rbt3"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME,
                 hidden_dropout_prob=0,
                 attention_probs_dropout_prob=0,)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

获取训练集中的物料名称向量

In [45]:
trainData = pd.read_csv("./Data_train.csv", index_col=0)
train_mat_names = list(trainData["MAT_NAME（物料名称）"].drop_duplicates())
trainData[:5]

Unnamed: 0_level_0,id（供应商代码）,MAT_NAME（物料名称）
QUOTE_MAT_LINE_ID（报价单行号）,Unnamed: 1_level_1,Unnamed: 2_level_1
8e24b75ce1c811ebaf97005056b12bb8,6603,液压缸吊座
f3afd8f4e1dc11ebaf97005056b12bb8,20958,六角头螺栓-C级GB/T5780-2000
1b99daf8e1de11ebaf97005056b12bb8,41982,高压水银荧光灯泡
1b9aac07e1de11ebaf97005056b12bb8,41982,手电筒螺口小灯泡（圆头）
bff9cf95e1e511ebaf97005056b12bb8,20238,外接头


In [46]:
def get_matVec(mat_names):
    input_id = tokenizer(mat_names, padding=True, truncation=True, max_length=10, return_tensors="pt")
    input_ids = input_id["input_ids"]
    input_chunks = torch.chunk(input_ids, 10, 0)
    outputs = []
    for chunk in input_chunks:
        chunk = chunk.to(device)
        with torch.no_grad():  
            output = model(chunk)['pooler_output']
        output = output.cpu().detach()
        torch.cuda.empty_cache()
        outputs.append(output)
    matvec = torch.cat(outputs,0)
    return matvec

获取测试集中的物料名称向量

In [47]:
test_mat_names = ["太阳能电池", "石墨烯", "推进油缸保护套"]

In [48]:
testMatVec = get_matVec(test_mat_names)
# 测试数据中的物料名称向量
testMatVec.shape

torch.Size([3, 768])

In [49]:
trainMatVec = torch.load("trainMatVec"+".pt")
train_mat_index = pd.DataFrame(np.arange(len(train_mat_names)), index = train_mat_names)
test_mat_index = pd.DataFrame(np.arange(len(test_mat_names)), index = test_mat_names)

计算测试集中物料名称与训练集中的物料名称的相似度

In [50]:
def cos_similar(p, q):
    sim_matrix = p.matmul(q.transpose(-2, -1))
    a = torch.norm(p, p=2, dim=-1)
    b = torch.norm(q, p=2, dim=-1)
    sim_matrix /= a.unsqueeze(-1)
    sim_matrix /= b.unsqueeze(-2)
    return sim_matrix

torch.cuda.empty_cache()
with torch.no_grad(): 
    # sim_mat_score = torch.cdist(testMatVec,trainMatVec,p=2)
    # sim_mat_score = F.cosine_similarity(testMatVec.unsqueeze(1), trainMatVec.unsqueeze(0), dim=2)
    sim_mat_score = cos_similar(testMatVec,trainMatVec)
    sim_mat_score = sim_mat_score.sort(axis=1, descending=True)
    sim_mat_index = sim_mat_score.indices[:,:30]
    sim_mat_score = sim_mat_score.values[:,:30] 
    torch.cuda.empty_cache()
torch.cuda.empty_cache()
print(sim_mat_score.shape)
print(sim_mat_index.shape)

torch.Size([3, 30])
torch.Size([3, 30])


In [51]:
sim_mat_index[:3]

tensor([[17933,  2528, 22351, 18151, 18668, 22546, 23290, 18365, 26237, 23561,
         25963, 18353, 26884,  4287,  2176, 17428, 22131, 11322, 23663, 31221,
         26566, 25520, 10792, 23109,  4579,  2287, 19913,  2490, 11397,  9389],
        [14727,  9973, 26483, 25948, 27437,  2752, 19919, 11223, 16342,  6642,
         30095, 30681, 30273,  3677,  1705, 25940, 30656, 15740, 30248, 21109,
         23070,  9964, 17907, 26852, 19286, 11007, 11224, 19649, 25946, 21280],
        [   76, 27962, 22807, 18107, 17826, 18110, 20437,  4602, 18827, 22817,
         18329, 17810, 16253,  3748, 19187, 28002, 30461, 28825,  2918,  3632,
         27928, 24979, 30489, 27916, 22827, 22797, 17625, 17529, 11075,  7351]])

In [52]:
sim_mat_score[:3]

tensor([[0.9754, 0.9744, 0.9728, 0.9727, 0.9725, 0.9723, 0.9719, 0.9715, 0.9714,
         0.9701, 0.9700, 0.9693, 0.9688, 0.9688, 0.9687, 0.9682, 0.9682, 0.9681,
         0.9680, 0.9668, 0.9668, 0.9667, 0.9667, 0.9666, 0.9666, 0.9666, 0.9664,
         0.9663, 0.9660, 0.9659],
        [0.9694, 0.9693, 0.9689, 0.9686, 0.9684, 0.9683, 0.9681, 0.9680, 0.9679,
         0.9679, 0.9678, 0.9677, 0.9676, 0.9675, 0.9674, 0.9674, 0.9671, 0.9670,
         0.9668, 0.9668, 0.9668, 0.9667, 0.9667, 0.9664, 0.9663, 0.9662, 0.9662,
         0.9662, 0.9660, 0.9660],
        [0.9951, 0.9821, 0.9793, 0.9782, 0.9777, 0.9774, 0.9774, 0.9771, 0.9762,
         0.9762, 0.9758, 0.9757, 0.9756, 0.9756, 0.9754, 0.9754, 0.9753, 0.9753,
         0.9748, 0.9747, 0.9747, 0.9746, 0.9746, 0.9745, 0.9740, 0.9739, 0.9738,
         0.9736, 0.9736, 0.9733]])

In [53]:
test_mat_index[:3]

Unnamed: 0,0
太阳能电池,0
石墨烯,1
推进油缸保护套,2


获取历史每个物料对应的可提供供应商

In [16]:
train_id_mat = trainData.copy()
train_id_mat["mat_id"] = train_id_mat["MAT_NAME（物料名称）"].apply(lambda x:train_mat_index.loc[x])
train_id_mat[:2]

Unnamed: 0_level_0,id（供应商代码）,MAT_NAME（物料名称）,mat_id
QUOTE_MAT_LINE_ID（报价单行号）,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
8e24b75ce1c811ebaf97005056b12bb8,6603,液压缸吊座,0
f3afd8f4e1dc11ebaf97005056b12bb8,20958,六角头螺栓-C级GB/T5780-2000,1


In [17]:
mat_supply_ids = train_id_mat.groupby(by="mat_id")["id（供应商代码）"].apply(set)
# 提供物料的历史供应商
mat_supply_ids[:200]

mat_id
0                                                 {6603}
1                                                {20958}
2                                                {41982}
3                                                {41982}
4                                  {27177, 21484, 20238}
                             ...                        
195                         {43944, 10298, 22123, 33944}
196                         {43944, 10298, 22123, 33944}
197                  {13094, 43944, 22123, 33944, 10298}
198                         {43944, 10298, 22123, 33944}
199    {43944, 22123, 24431, 32919, 33944, 13945, 10298}
Name: id（供应商代码）, Length: 200, dtype: object

In [24]:
query_mat_index = test_mat_index.loc[test_mat_names]
query_mat_index = query_mat_index[0].to_list()
# 物料在相似度矩阵中的下标
query_mat_index[:5]

[0, 1, 2]

In [54]:
item_sim_mat = pd.DataFrame(sim_mat_index[query_mat_index].tolist())
item_sim_mat[:4]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,17933,2528,22351,18151,18668,22546,23290,18365,26237,23561,...,26566,25520,10792,23109,4579,2287,19913,2490,11397,9389
1,14727,9973,26483,25948,27437,2752,19919,11223,16342,6642,...,23070,9964,17907,26852,19286,11007,11224,19649,25946,21280
2,76,27962,22807,18107,17826,18110,20437,4602,18827,22817,...,27928,24979,30489,27916,22827,22797,17625,17529,11075,7351


In [55]:
item_sim_score = pd.DataFrame(sim_mat_score[query_mat_index].tolist())
item_sim_score[:4]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,0.975351,0.974373,0.972824,0.972731,0.97254,0.97232,0.971894,0.971539,0.971442,0.970059,...,0.966772,0.966697,0.966665,0.966642,0.966639,0.966576,0.966416,0.966314,0.966026,0.965949
1,0.969413,0.969256,0.968894,0.968589,0.96839,0.968343,0.968125,0.968029,0.967892,0.967856,...,0.966754,0.966745,0.966728,0.966437,0.966301,0.966225,0.96622,0.966178,0.966032,0.965999
2,0.995081,0.982073,0.979273,0.978168,0.977674,0.97744,0.977398,0.977059,0.976213,0.976181,...,0.974664,0.974613,0.974605,0.974502,0.974006,0.973928,0.973836,0.973586,0.973552,0.973256


In [27]:
item_sim_mat.shape
# 行数是测试集数量，列数对应在训练集中相关的前10个物料向量

(3, 10)

In [58]:
from operator import add
from functools import reduce
def simMat2supp(x, suppNum = 50):
    suplists = list(mat_supply_ids.loc[x].apply(list))
    scores = list(item_sim_score.loc[x.name])
    return pd.Series([suplists,scores])
item_sim_mat[:10].apply(simMat2supp, axis=1)

Unnamed: 0,0,1
0,"[[25798, 28555, 269, 8213, 30039], [14761], [3...","[0.9753511548042297, 0.9743733406066895, 0.972..."
1,"[[28357, 44758], [16410], [14436], [6119, 3146...","[0.9694129228591919, 0.9692555665969849, 0.968..."
2,"[[9600, 26309, 41896, 30351, 10681], [13094], ...","[0.9950808882713318, 0.9820733666419983, 0.979..."


In [59]:
result = item_sim_mat.apply(simMat2supp, axis=1)
result = result.rename(columns={0:'supps',1:'scores'})
result

Unnamed: 0,supps,scores
0,"[[25798, 28555, 269, 8213, 30039], [14761], [3...","[0.9753511548042297, 0.9743733406066895, 0.972..."
1,"[[28357, 44758], [16410], [14436], [6119, 3146...","[0.9694129228591919, 0.9692555665969849, 0.968..."
2,"[[9600, 26309, 41896, 30351, 10681], [13094], ...","[0.9950808882713318, 0.9820733666419983, 0.979..."


In [60]:
def scoreBroadCast(x, suppNum = 50):
    suppslist = []
    scoreslist = []
    for i in range(len(x['supps'])):
        suppslist.extend(x['supps'][i])
        scoreslist.extend([x['scores'][i]]*len(x['supps'][i]))
    return pd.Series([suppslist[:suppNum],scoreslist[:suppNum]])
broadResult = result.apply(scoreBroadCast, axis=1)
broadResult = broadResult.rename(columns={0:'supps',1:'scores'})
broadResult[:10]

Unnamed: 0,supps,scores
0,"[25798, 28555, 269, 8213, 30039, 14761, 3587, ...","[0.9753511548042297, 0.9753511548042297, 0.975..."
1,"[28357, 44758, 16410, 14436, 6119, 31465, 3060...","[0.9694129228591919, 0.9694129228591919, 0.969..."
2,"[9600, 26309, 41896, 30351, 10681, 13094, 525,...","[0.9950808882713318, 0.9950808882713318, 0.995..."


In [63]:
similar_mats = train_mat_index.index[sim_mat_index]

太阳能电池

In [67]:
for i, mat in enumerate(test_mat_names):
    print("物料名称：{}".format(mat))
    print("相似物料")
    print(list(similar_mats[i]))
    print("推荐供应商ID")
    print(broadResult["supps"][i])
    print("相关供应物料相似度")
    print(broadResult["scores"][i])
    print()

物料名称：太阳能电池
相似物料
['备用电池', '锂聚合物电池', '5号电池', '聚合物锂电池', '直流钻电池', '充电式镍氢电池', '锂亚硫酰氯电池', '特碱性电池', '锂亚电池', '三轮车电池', '锂离子充电电池', 'SAFT电池', '7号电池', '层迭电池', '积层电池', '太阳能户外灯', 'UPS电池包', '万用表电池', '积层电池 9V', '台达PLC电池', '铅酸蓄电池', 'PLC电池', '层叠电池', '碱性电池', 'ups电池', '蓄电池', '可充锂电池', '干电池', '充电电池', '充电式电镐']
推荐供应商ID
[25798, 28555, 269, 8213, 30039, 14761, 3587, 22188, 12814, 5683, 41493, 40247, 15755, 27804, 35695, 718, 27044, 35895, 31715, 27383, 9490, 9651, 1357, 16840, 19272, 8927, 18618, 13139, 269, 31877, 40147, 38027, 2093, 22188, 12814, 41937, 14051, 36452, 10469, 30539, 43883, 44925, 32848, 7289, 21565, 6129, 21137, 38548, 6908, 10]
相关供应物料相似度
[0.9753511548042297, 0.9753511548042297, 0.9753511548042297, 0.9753511548042297, 0.9753511548042297, 0.9743733406066895, 0.9728236794471741, 0.9728236794471741, 0.9728236794471741, 0.9728236794471741, 0.9728236794471741, 0.9728236794471741, 0.9727308750152588, 0.9727308750152588, 0.9727308750152588, 0.9725401401519775, 0.9723196625709534, 0.9723196625709534, 0