-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_mrl.py
64 lines (58 loc) · 2.12 KB
/
train_mrl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from sentence_transformers import SentenceTransformer, InputExample, losses
#data function
def read_chatbot_csv(chatbot_csv):
    """Read a tab-separated file of (query, key, label) triples.

    Each line must contain exactly three tab-separated fields whose third
    field is the string '0' or '1'; any other line is silently skipped.

    Args:
        chatbot_csv: path to the .tsv file.

    Returns:
        list of InputExample, each pairing [query, key] with a float
        label (0.0 or 1.0).
    """
    samples = []
    # Explicit UTF-8: the corpus is Chinese text, and the original call
    # relied on the platform/locale default encoding, which breaks on
    # non-UTF-8 systems (e.g. Windows cp936/gbk mismatches).
    with open(chatbot_csv, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.rstrip('\n').split("\t")
            if len(fields) != 3:
                continue  # skip malformed rows
            query, key, value = fields
            if value not in ('0', '1'):
                continue  # keep only binary labels (drops the header, too)
            samples.append(InputExample(texts=[query, key], label=float(value)))
    return samples
# load data: train/dev/test splits as tab-separated (query, key, label) files
path = "./data/"
train_csv = path + 'train.tsv'
dev_csv = path + 'dev.tsv'
test_csv = path + 'test.tsv'
# (the original pre-initialized these three to [] and immediately
# overwrote them; the dead assignments are removed)
train_samples = read_chatbot_csv(train_csv)
dev_samples = read_chatbot_csv(dev_csv)
test_samples = read_chatbot_csv(test_csv)
# sanity-check split sizes
print(len(train_samples))
print(len(dev_samples))
print(len(test_samples))
#build model
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, MatryoshkaLoss
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from datetime import datetime
import math

# Base encoder: a locally stored hfl/chinese-roberta-wwm-ext checkpoint.
#model = SentenceTransformer("./PairSupCon-roberta-wwm-ext", device="cuda:1")
#hfl/chinese-roberta-wwm-ext
model = SentenceTransformer("./chinese-roberta-wwm-ext", device="cuda:1")

# MatryoshkaLoss wraps CoSENTLoss so the learned embeddings remain usable
# when truncated to each of the listed dimensions (all weighted equally).
base_loss = CoSENTLoss(model=model)
train_loss = MatryoshkaLoss(
    model=model,
    loss=base_loss,
    matryoshka_dims=[768, 512, 256, 128, 64],
    matryoshka_weights=[1, 1, 1, 1, 1],
)

# Dev-set similarity evaluator, run every `evaluation_steps` batches.
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='dev')
# Timestamped output directory so repeated runs never overwrite each other.
model_save_path = 'mrl_model/mrl_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

num_epochs = 5
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=16)
# Warm up the learning rate over 10% of total training steps.
# (Fixed: the original line read "warmup_steps = warmup_steps = ...",
# a harmless but erroneous duplicated assignment.)
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path)