In [4]:
from sentence_transformers import (
    InputExample,
    SentencesDataset,
    SentenceTransformer,
    evaluation,
    losses,
    models,
)
from torch.utils.data import DataLoader
import pandas as pd

In [5]:
# Define the base model for word embeddings
word_embedding_model = models.Transformer("distilroberta-base", max_seq_length=512)

# Define the pooling layer that aggregates word embeddings into a sentence embedding
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean"
)

# Construct a sentence transformer
model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model],
)


Downloading (…)lve/main/config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading model.safetensors:   0%|          | 0.00/331M [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [6]:
data = pd.read_csv("853.csv")

data.head()

Unnamed: 0,caption,mean,precision,votes,not_funny,somewhat_funny,funny
0,Who knew the Swiss had a navy?,1.675661,0.018367,1816,946,513,357
1,Weapons down. It's the Swiss.,1.672884,0.020759,1394,721,408,265
2,"Wow, and that’s just the tip of the Jarlsberg!",1.643323,0.01838,1685,886,512,286
3,All we need now is to find a port,1.639823,0.012828,3837,2151,917,769
4,We must be directly over where Wisconsin used ...,1.638809,0.022989,1041,542,333,166


In [7]:
df = data.drop(columns=['mean', 'precision', 'votes', 'not_funny', 'somewhat_funny', 'funny'], axis=1)

df.head()

Unnamed: 0,caption
0,Who knew the Swiss had a navy?
1,Weapons down. It's the Swiss.
2,"Wow, and that’s just the tip of the Jarlsberg!"
3,All we need now is to find a port
4,We must be directly over where Wisconsin used ...


In [10]:
train_examples=[
  InputExample(texts=[row["caption"]]) for _, row in data.iterrows()
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

In [11]:
train_loss = losses.BatchSemiHardTripletLoss(
    model=model,
    distance_metric=losses.BatchHardTripletLossDistanceFunction.cosine_distance,
)

In [17]:
pos_pairs = data.sample(frac = 0.1)
pos_pairs["score"] = 1 # group label

neg_pairs = data.copy()
neg_pairs["caption"] = neg_pairs["caption"].sample(frac = 1).values
neg_pairs = neg_pairs.sample(frac = 0.1)
neg_pairs["score"] = 0 # group label

eval_data = pd.concat([pos_pairs, neg_pairs])

evaluator = evaluation.EmbeddingSimilarityEvaluator(
    eval_data["caption"].values.tolist(), # sentence
    eval_data["caption"].values.tolist(), # sentence
    scores=eval_data["score"].values.tolist(), # similarity
    show_progress_bar=True,
)

In [18]:
# training parameters
num_epochs = 4
warmup_steps = 100
evaluation_steps = 1000
model_save_path = "model"

In [19]:
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=evaluation_steps,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Iteration:   0%|          | 0/293 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]



Iteration:   0%|          | 0/293 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]



Iteration:   0%|          | 0/293 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]



Iteration:   0%|          | 0/293 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]

Batches:   0%|          | 0/59 [00:00<?, ?it/s]

