<a href="https://colab.research.google.com/github/DIPANJAN001/Andrew-Ng-Machine-Learning-Notes/blob/master/trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from transformers import Trainer, TrainingArguments

# Define the custom dataset for triplets
class TripletDataset(Dataset):
    """Dataset of (anchor, positive, negative) sentence triplets.

    The three sentence sequences are index-aligned: item ``i`` is built from
    ``anchor_sentences[i]``, ``positive_sentences[i]`` and
    ``negative_sentences[i]``. Sentences are tokenized lazily, on access.

    Args:
        anchor_sentences: Sequence of anchor sentences.
        positive_sentences: Sequence of sentences paired with each anchor as
            the positive example.
        negative_sentences: Sequence of sentences paired with each anchor as
            the negative example.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')`` and
            returning a mapping with ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length every sentence is padded/truncated to.

    Raises:
        ValueError: If the three sentence sequences differ in length.
    """

    def __init__(self, anchor_sentences, positive_sentences, negative_sentences, tokenizer, max_length=128):
        # Fail fast: a mismatch would otherwise surface as an IndexError in the
        # middle of an epoch, or silently ignore the surplus sentences.
        if not (len(anchor_sentences) == len(positive_sentences) == len(negative_sentences)):
            raise ValueError(
                "anchor, positive and negative sentences must have the same length, got "
                f"{len(anchor_sentences)}, {len(positive_sentences)} and {len(negative_sentences)}"
            )
        self.anchor_sentences = anchor_sentences
        self.positive_sentences = positive_sentences
        self.negative_sentences = negative_sentences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of triplets."""
        return len(self.anchor_sentences)

    def _encode(self, sentence):
        """Tokenize one sentence, padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            sentence,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return the tokenized triplet at ``idx``.

        The result maps ``'<role>_input_ids'`` and ``'<role>_attention_mask'``
        (role in anchor, positive, negative) to the tokenizer output with its
        leading axis removed.
        """
        sentences = {
            'anchor': self.anchor_sentences[idx],
            'positive': self.positive_sentences[idx],
            'negative': self.negative_sentences[idx],
        }

        item = {}
        for role, sentence in sentences.items():
            encoding = self._encode(sentence)
            # Assumes the tokenizer returns tensors shaped (1, max_length).
            # squeeze(0) drops only that leading batch axis; a bare squeeze()
            # would also collapse the sequence axis when max_length == 1.
            item[f'{role}_input_ids'] = encoding['input_ids'].squeeze(0)
            item[f'{role}_attention_mask'] = encoding['attention_mask'].squeeze(0)
        return item

# Define the triplet loss function
def triplet_loss(anchor_output, positive_output, negative_output, margin=1.0):
    """Return the mean triplet margin loss over a batch of embeddings.

    For every row, the loss is ``max(0, d(anchor, positive) -
    d(anchor, negative) + margin)``, where ``d`` is the norm of the
    difference taken along dimension 1. The per-row losses are averaged
    into a single scalar tensor.

    Args:
        anchor_output: Anchor embeddings.
        positive_output: Positive embeddings, same shape as the anchors.
        negative_output: Negative embeddings, same shape as the anchors.
        margin: Required gap between the positive and negative distances.

    Returns:
        Scalar tensor holding the batch-mean loss.
    """
    anchor_to_positive = (anchor_output - positive_output).norm(dim=1)
    anchor_to_negative = (anchor_output - negative_output).norm(dim=1)
    # Hinge: only triplets that violate the margin contribute to the loss.
    hinge = (anchor_to_positive - anchor_to_negative + margin).relu()
    return hinge.mean()

# Create the encoder used to produce sentence embeddings
model = BertModel.from_pretrained('bert-base-uncased')

# Define the training data (anchor, positive, negative triplets).
# The three lists are index-aligned; append further triplets to all three.
# (Literal `...` placeholders are Ellipsis objects, which the tokenizer cannot encode.)
anchor_sentences = ["This is an example sentence.", "Another anchor sentence."]
positive_sentences = ["Similar sentence to anchor.", "Positive example."]
negative_sentences = ["Dissimilar sentence.", "Negative example."]

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create the custom dataset
dataset = TripletDataset(anchor_sentences, positive_sentences, negative_sentences, tokenizer)


class TripletTrainer(Trainer):
    """Trainer that optimizes ``triplet_loss`` over pooled sentence embeddings.

    The stock ``Trainer`` forwards each batch as ``model(**inputs)`` and reads
    a ``loss`` from the output. ``BertModel`` neither accepts the
    ``anchor_*``/``positive_*``/``negative_*`` keys nor returns a loss, so the
    loss computation is overridden here.
    """

    @staticmethod
    def _embed(model, input_ids, attention_mask):
        """Encode a batch and mean-pool the token states over non-padding positions."""
        # NOTE: mean pooling is a design choice (the usual Sentence-BERT
        # default); swap in CLS pooling here if that is preferred.
        token_states = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(token_states.dtype)
        # clamp guards against division by zero for an all-padding row.
        return (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Embed the three sentences of every triplet and return the triplet loss."""
        embeddings = {
            role: self._embed(model, inputs[f'{role}_input_ids'], inputs[f'{role}_attention_mask'])
            for role in ('anchor', 'positive', 'negative')
        }
        loss = triplet_loss(embeddings['anchor'], embeddings['positive'], embeddings['negative'])
        return (loss, embeddings) if return_outputs else loss


# Define training arguments
training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    logging_dir='./logs',
    save_steps=1000,
    # Keep the custom triplet keys: by default the Trainer drops every batch
    # entry that is not a parameter of model.forward().
    remove_unused_columns=False,
    # Step-wise evaluation is not enabled because no eval_dataset is provided;
    # the Trainer raises an error when asked to evaluate without one.
)

# Create a Trainer instance
trainer = TripletTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=None,  # Default collator stacks the per-item tensors key by key
    compute_metrics=None,  # Define metrics for evaluation if needed
)

# Train the model
trainer.train()

# Save the fine-tuned model
trainer.save_model()

# You can now use the fine-tuned model for generating sentence embeddings and other downstream tasks.
