/
default_fine_tuning_method.py
154 lines (141 loc) · 5.3 KB
/
default_fine_tuning_method.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from typing import List
from sentence_transformers import SentenceTransformer
from embedding_studio.core.config import settings
from embedding_studio.core.plugin import FineTuningMethod
from embedding_studio.embeddings.data.clickstream.parsers.s3_parser import (
AWSS3ClickstreamParser,
)
from embedding_studio.embeddings.data.clickstream.search_event import (
DummyEventType,
SearchResult,
)
from embedding_studio.embeddings.data.clickstream.splitter import (
ClickstreamSessionsSplitter,
)
from embedding_studio.embeddings.data.clickstream.text_query_item import (
TextQueryItem,
)
from embedding_studio.embeddings.data.clickstream.text_query_retriever import (
TextQueryRetriever,
)
from embedding_studio.embeddings.data.loaders.s3.s3_loader import (
AWSS3DataLoader,
)
from embedding_studio.embeddings.data.storages.producers.clip import (
CLIPItemStorageProducer,
)
from embedding_studio.embeddings.data.utils.fields_normalizer import (
DatasetFieldsNormalizer,
)
from embedding_studio.embeddings.losses.prob_cosine_margin_ranking_loss import (
CosineProbMarginRankingLoss,
)
from embedding_studio.embeddings.models.text_to_image.clip import (
TextToImageCLIPModel,
)
from embedding_studio.models.clickstream.sessions import SessionWithEvents
from embedding_studio.models.plugin import FineTuningBuilder, PluginMeta
from embedding_studio.workers.fine_tuning.data.prepare_data import prepare_data
from embedding_studio.workers.fine_tuning.experiments.experiments_tracker import (
ExperimentsManager,
)
from embedding_studio.workers.fine_tuning.experiments.finetuning_settings import (
FineTuningSettings,
)
from embedding_studio.workers.fine_tuning.experiments.initial_params.clip import (
INITIAL_PARAMS,
)
from embedding_studio.workers.fine_tuning.experiments.metrics_accumulator import (
MetricsAccumulator,
)
class DefaultFineTuningMethod(FineTuningMethod):
    """Default fine-tuning plugin.

    On construction it instantiates every component the fine-tuning run
    needs (S3 data loader, text query retriever, S3 clickstream parser,
    session splitter, field normalizer, CLIP item storage producer, metric
    accumulators, experiments manager, initial hyperparameters and
    fine-tuning settings) and later hands them to a ``FineTuningBuilder``.
    """

    meta = PluginMeta(
        name="Default Fine Tuning Method",
        version="0.0.1",
        description="A default fine-tuning plugin",
    )

    def __init__(self):
        """Create and store all components used by this plugin."""
        # Uncomment and pass your credentials to use your own S3 bucket:
        # creds = {
        #     "role_arn": "arn:aws:iam::123456789012:role/some_data",
        #     "aws_access_key_id": "TESTACCESSKEIDTEST11",
        #     "aws_secret_access_key": "QWERTY1232qdsadfasfg5349BBdf30ekp23odk03",
        # }
        # self.data_loader = AWSS3DataLoader(**creds)

        # With empty creds, use anonymous session.
        creds = {}
        self.data_loader = AWSS3DataLoader(**creds)

        self.retriever = TextQueryRetriever()
        self.parser = AWSS3ClickstreamParser(
            TextQueryItem, SearchResult, DummyEventType
        )
        self.splitter = ClickstreamSessionsSplitter()
        self.normalizer = DatasetFieldsNormalizer("item", "item_id")
        self.storage_producer = CLIPItemStorageProducer(self.normalizer)

        # Train metrics pass four extra positional flags; test metrics rely
        # on the accumulator's defaults.
        self.accumulators = [
            MetricsAccumulator("train_loss", True, True, True, True),
            MetricsAccumulator(
                "train_not_irrelevant_dist_shift", True, True, True, True
            ),
            MetricsAccumulator(
                "train_irrelevant_dist_shift", True, True, True, True
            ),
            MetricsAccumulator("test_loss"),
            MetricsAccumulator("test_not_irrelevant_dist_shift"),
            MetricsAccumulator("test_irrelevant_dist_shift"),
        ]

        self.manager = ExperimentsManager(
            tracking_uri=settings.MLFLOW_TRACKING_URI,
            main_metric="test_not_irrelevant_dist_shift",
            accumulators=self.accumulators,
        )

        # Build a *new* dict instead of calling ``INITIAL_PARAMS.update(...)``:
        # INITIAL_PARAMS is a module-level object shared by every importer,
        # so updating it in place would leak these overrides to other users
        # of that module each time this plugin is instantiated.
        self.initial_params = {
            **INITIAL_PARAMS,
            "not_irrelevant_only": [True],
            "negative_downsampling": [
                0.5,
            ],
            "examples_order": [
                [
                    11,
                ]
            ],
        }

        self.settings = FineTuningSettings(
            loss_func=CosineProbMarginRankingLoss(),
            step_size=35,
            test_each_n_sessions=0.5,
            num_epochs=3,
        )

    def upload_initial_model(self) -> None:
        """Upload the initial model through the experiments manager.

        Wraps ``SentenceTransformer("clip-ViT-B-32")`` in a
        ``TextToImageCLIPModel`` and passes it to
        ``self.manager.upload_initial_model``.
        """
        model = TextToImageCLIPModel(SentenceTransformer("clip-ViT-B-32"))
        self.manager.upload_initial_model(model)

    def get_fine_tuning_builder(
        self, clickstream: List[SessionWithEvents]
    ) -> FineTuningBuilder:
        """Prepare ranking data and assemble a ``FineTuningBuilder``.

        Args:
            clickstream: Sessions with events, forwarded to ``prepare_data``
                together with this plugin's parser, splitter, retriever,
                data loader and storage producer.

        Returns:
            A ``FineTuningBuilder`` populated with the components created in
            ``__init__``, the prepared ranking data and
            ``initial_max_evals=2``.
        """
        ranking_dataset = prepare_data(
            clickstream,
            self.parser,
            self.splitter,
            self.retriever,
            self.data_loader,
            self.storage_producer,
        )
        fine_tuning_builder = FineTuningBuilder(
            data_loader=self.data_loader,
            query_retriever=self.retriever,
            clickstream_parser=self.parser,
            clickstream_sessions_splitter=self.splitter,
            dataset_fields_normalizer=self.normalizer,
            item_storage_producer=self.storage_producer,
            accumulators=self.accumulators,
            experiments_manager=self.manager,
            fine_tuning_settings=self.settings,
            initial_params=self.initial_params,
            ranking_data=ranking_dataset,
            initial_max_evals=2,
        )
        return fine_tuning_builder