# Training Hebrew Text Encoder

In [2]:
!nvidia-smi

Wed Oct  2 08:33:00 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:0E:00.0 Off |                    0 |
| N/A   31C    P0             67W /  400W |    1210MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          Off |   00

In [3]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

## Setup the environment

In [4]:
!pip install -q -U torch transformers bitsandbytes datasets huggingface_hub accelerate tqdm

In [5]:
from huggingface_hub import notebook_login
import os
import sys
from datasets import load_dataset

In [6]:
os.environ["HF_TOKEN"] = "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"
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [7]:
def is_running_in_colab():
    return 'COLAB_GPU' in os.environ

if is_running_in_colab():
    # Load the Drive helper and mount
    print("Mounting google drive...")
    from google.colab import drive
    drive.mount('/content/drive')
else:
    print("Not running in Google Colab!")

Not running in Google Colab!


In [8]:
!pwd

/home/nlp/achimoa/projects/hebrew_text_encoder


In [9]:
project_dir = '/home/nlp/achimoa/projects/hebrew_text_encoder'

os.chdir(project_dir)
print(f"Current working directory set to: {os.getcwd()}")


if project_dir not in sys.path:
    sys.path.insert(0, project_dir)  # Add it to the front of PYTHONPATH
    print(f"PYTHONPATH updated with: {project_dir}")

Current working directory set to: /home/nlp/achimoa/projects/hebrew_text_encoder


In [10]:
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import logging
import os
from datasets import DatasetDict, Dataset
from tqdm import tqdm
import pickle

## Generic Helper Code

In [11]:
def setup_logger(file_path: str):
    os.environ["PYTHONUNBUFFERED"] = "1"
    
    # Create or retrieve the logger
    logger = logging.getLogger('default')

    # Remove all existing handlers
    if logger.hasHandlers():
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)

    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # Stream Handler (for console output)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    # File Handler (for file output)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    file_handler = logging.FileHandler(file_path, delay=False)  # Log file name (you can specify the path)
    file_handler.setLevel(logging.DEBUG) # Set the log level for file handler
    file_handler.setFormatter(formatter) # Use the same formatter
    logger.addHandler(file_handler)

    return logger

In [12]:
def load_checkpoint(model, optimizer, checkpoint_dir, device):
    logger = logging.getLogger('default')

    latest_checkpoint = None
    if os.path.exists(checkpoint_dir):
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
        if checkpoint_files:
            latest_checkpoint = sorted(checkpoint_files)[-1]  # Get the latest checkpoint

    if latest_checkpoint:
        logger.info(f"Loading checkpoint {latest_checkpoint}")
        checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch']  # return the epoch to resume from

    logger.info("No checkpoint found. Starting from scratch.")
    return 0  # Start from the first epoch if no checkpoint found


# Save model and optimizer state
def save_checkpoint(model, optimizer, epoch, checkpoint_dir):
    logger = logging.getLogger('default')
    
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)
    logger.info(f"Checkpoint saved at {checkpoint_path}")

## Train the model

### Helper functions

In [13]:
def transform_dataset_wiki40b(subsets=['train', 'validation', 'test']):
    logger = logging.getLogger('default')
    logger.info("Transforming Wiki40B dataset")

    dataset = load_dataset("wiki40b", "he")
    decoded_dataset = dataset.map(lambda x: {'text': decode_text(x['text'])})

    def transform_entry(entry):
        # Process the 'text' using parse_wiki_article
        article = parse_wiki_article(entry['text'])

        # Extract anchor_text and positive_text based on the parsed output
        anchor_text = article['title']
        if 'sections' in article and len(article['sections']) > 0:
            anchor_text += " " + article['sections'][0]['section']
            positive_text = article['sections'][0]['paragraphs'][0]
        else:
            positive_text = article['abstract'][0]

        # Return the transformed data
        return {
            'anchor_text': 'query : ' + anchor_text,
            'positive_text': 'document: ' + positive_text
        }

    # Apply the transformation to the train, validation, and test subsets
    transformed_dataset = {}
    for subset in subsets:
        # Transform each subset of the dataset using map (this processes each 'text' entry)
        logger.info(f"Transforming {subset} subset")
        transformed_subset = decoded_dataset[subset].map(transform_entry)
        transformed_dataset[subset] = transformed_subset

    # Return the transformed dataset as a DatasetDict
    logger.info("Done transforming Wiki40B dataset")
    return DatasetDict(transformed_dataset)


def decode_text(text):
    decoded_text = bytes(text, "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8")
    return decoded_text


def parse_wiki_article(text):
    lines = text.strip().split('\n')

    PARAGRAPH_DIVIDER = '_NEWLINE_'

    # Initialize variables
    article_dict = {'title': '', 'abstract': '', 'sections': []}
    current_section = None
    abstract_parsed = False

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        if line == "_START_ARTICLE_":
            # The next line is the title
            article_dict['title'] = lines[i + 1].strip()
            i += 2  # Move to the next relevant line
        elif line == "_START_PARAGRAPH_":
            # If the abstract has not been parsed and the current section is None, this is the abstract
            paragraph = lines[i + 1].strip()
            if not abstract_parsed and not current_section:
                article_dict['abstract'] = paragraph.split(PARAGRAPH_DIVIDER)
                abstract_parsed = True
            elif current_section:
                current_section['paragraphs'] = paragraph.split(PARAGRAPH_DIVIDER)
            i += 2
        elif line == "_START_SECTION_":
            # The next line is the section name
            section_name = lines[i + 1].strip()
            current_section = {'section': section_name, 'paragraphs': ''}
            article_dict['sections'].append(current_section)
            i += 2
        else:
            i += 1  # Move to the next line if none of the cases match

    return article_dict


def transform_dataset_synthesized(file_path, test_size=0.2):
    logger = logging.getLogger('default')
    logger.info("Transforming synthesized dataset")
    with open(file_path, 'rb') as f:
        data = pickle.load(f)

    def transform_entry(entry):
        # Return the transformed data
        return {
            'anchor_text': 'query: ' + entry['user_query'],
            'positive_text': 'document: ' + entry['positive_document'],
            'negative_text': 'document: ' + entry['hard_negative_document'],
        }

    # Apply the transformation to each entry
    transformed_data = list(map(transform_entry, data))

    # Convert the list of dictionaries to a Hugging Face Dataset
    dataset = Dataset.from_list(transformed_data)

    # Split the dataset into train and test sets
    train_test_dataset = dataset.train_test_split(test_size=test_size)
    train_validation_dataset = DatasetDict({
        'train': train_test_dataset['train'],  # Keep the 'train' split
        'validation': train_test_dataset['test']  # Rename 'test' to 'validation'
    })

    logger.info("Done transforming synthesized dataset")
    return train_validation_dataset


def transform_dataset(dataset_name, **kwargs):
    if dataset_name == 'wiki40b':
        return transform_dataset_wiki40b(**kwargs)
    elif dataset_name == 'synthesized_dataset':
        return transform_dataset_synthesized(**kwargs)
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

In [14]:
class InfoNCELoss(torch.nn.Module):
    def __init__(self, temperature=0.07):
        """
        Parameters:
        - temperature: Scaling factor applied to the logits before applying the softmax function.
        """
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negatives):
        """
        Compute the InfoNCE loss.

        Parameters:
        - anchor: Tensor of shape (batch_size, embedding_dim) - anchor samples
        - positive: Tensor of shape (batch_size, embedding_dim) - positive samples corresponding to each anchor
        - negatives: Tensor of shape (batch_size, num_negatives, embedding_dim) - negative samples

        Returns:
        - loss: Computed InfoNCE loss
        """
        batch_size = anchor.size(0)
        num_negatives = negatives.size(1)

        # Normalize embeddings to unit vectors
        anchor = F.normalize(anchor, dim=-1)
        positive = F.normalize(positive, dim=-1)
        negatives = F.normalize(negatives, dim=-1)

        # Calculate the positive logits (similarity between anchor and positive)
        positive_logits = torch.sum(anchor * positive, dim=-1, keepdim=True)  # Shape: (batch_size, 1)

        # Calculate the negative logits (similarity between anchor and negatives)
        negative_logits = torch.bmm(negatives, anchor.unsqueeze(2)).squeeze(2)  # Shape: (batch_size, num_negatives)

        # Concatenate positive and negative logits
        logits = torch.cat([positive_logits, negative_logits], dim=1)  # Shape: (batch_size, 1 + num_negatives)

        # Apply temperature scaling
        logits = logits / self.temperature

        # Create labels - 0 for the positive samples, as it is the first in the concatenated logits
        labels = torch.zeros(batch_size, dtype=torch.long, device=logits.device)

        # Compute the InfoNCE loss using cross-entropy
        loss = F.cross_entropy(logits, labels)

        return loss

In [15]:
def validate(model, val_dataloader, criterion, device, epoch, epochs):
    model.eval()
    total_val_loss = 0.0

    # Track progress in the validation loop using tqdm
    val_progress = tqdm(enumerate(val_dataloader), desc=f"Epoch {epoch + 1}/{epochs} [Val]", leave=False, total=len(val_dataloader))

    with torch.no_grad():
        for batch_idx, batch in val_progress:

            if len(batch) == 4:
                anchor_ids, anchor_mask, positive_ids, positive_mask = [x.to(device) for x in batch]

                # Forward pass to get the embeddings
                anchor_outputs = model(input_ids=anchor_ids, attention_mask=anchor_mask)
                anchor_embeds = anchor_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                positive_outputs = model(input_ids=positive_ids, attention_mask=positive_mask)
                positive_embeds = positive_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                # Set negatives as the other positives in the batch
                # Create a matrix where the negatives are shifted versions of positives
                batch_size = positive_embeds.size(0)
    #             negatives_embeds = torch.stack([positive_embeds[i:] + positive_embeds[:i] for i in range(1, batch_size)], dim=0)
                # Create the negatives for each index `i` by excluding the positive embedding at index `i`
                negatives_embeds_list = []

                for i in range(batch_size):
                    # Exclude the current index `i` using slicing
                    negatives_embeds = torch.cat([positive_embeds[:i], positive_embeds[i+1:]], dim=0)

                    # Append the result to the list
                    negatives_embeds_list.append(negatives_embeds)

                # Stack the negatives for each sample in the batch
                # Each entry in the batch now has (batch_size - 1) negative embeddings
                negatives_embeds = torch.stack(negatives_embeds_list)

            elif len(batch) == 6:
                anchor_ids, anchor_mask, positive_ids, positive_mask, negative_ids, negative_mask = [x.to(device) for x in batch]

                # Forward pass to get the embeddings
                anchor_outputs = model(input_ids=anchor_ids, attention_mask=anchor_mask)
                anchor_embeds = anchor_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                positive_outputs = model(input_ids=positive_ids, attention_mask=positive_mask)
                positive_embeds = positive_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                negative_outputs = model(input_ids=negative_ids, attention_mask=negative_mask)
                negative_embeds = negative_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                # Set negatives as the other positives in the batch
                # Create a matrix where the negatives are shifted versions of positives
                batch_size = positive_embeds.size(0)
    #             negatives_embeds = torch.stack([positive_embeds[i:] + positive_embeds[:i] for i in range(1, batch_size)], dim=0)
                # Create the negatives for each index `i` by excluding the positive embedding at index `i`
                negatives_embeds_list = []

                for i in range(batch_size):
                    # Exclude the current index `i` using slicing
                    negatives_embeds = torch.cat([positive_embeds[:i], negative_embeds, positive_embeds[i+1:]], dim=0)

                    # Append the result to the list
                    negatives_embeds_list.append(negatives_embeds)

                # Stack the negatives for each sample in the batch
                # Each entry in the batch now has (batch_size - 1) negative embeddings
                negatives_embeds = torch.stack(negatives_embeds_list)

            # Compute the validation loss
            val_loss = criterion(anchor_embeds, positive_embeds, negatives_embeds)
            total_val_loss += val_loss.item()

            # Update tqdm progress bar with the current batch number and average loss
            val_progress.set_postfix({
                "Batch": batch_idx + 1,
                "Val Loss": total_val_loss / (batch_idx + 1)
            })

    avg_val_loss = total_val_loss / len(val_dataloader)
    return avg_val_loss

In [16]:
def train(
    model,
    optimizer,
    criterion,
    train_dataloader,
    val_dataloader,
    device,
    epochs,
    start_epoch=0,
    checkpoint_dir='checkpoints',
    clip_value = None
):
    logger = logging.getLogger('default')
    logger.info("Start training")

    model.to(device)
    model.train()
    best_val_loss = float('inf')  # Initialize best validation loss to infinity

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    for epoch in range(start_epoch, epochs):
        total_train_loss = 0.0
        model.to(device)
        model.train()

        # Track progress in the training loop using tqdm
        train_progress = tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch + 1}/{epochs} [Train]", leave=False, total=len(train_dataloader))

        for batch_idx, batch in train_progress:

            if len(batch) == 4: # in case we have only (anchor, positive)
                anchor_ids, anchor_mask, positive_ids, positive_mask = [x.to(device) for x in batch]

                # Forward pass to get the embeddings
                anchor_outputs = model(input_ids=anchor_ids, attention_mask=anchor_mask)
                anchor_embeds = anchor_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                positive_outputs = model(input_ids=positive_ids, attention_mask=positive_mask)
                positive_embeds = positive_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                # Set negatives as the other positives in the batch
                # Create a matrix where the negatives are shifted versions of positives
                batch_size = positive_embeds.size(0)
    #             negatives_embeds = torch.stack([positive_embeds[i:] + positive_embeds[:i] for i in range(1, batch_size)], dim=0)
                # Create the negatives for each index `i` by excluding the positive embedding at index `i`
                negatives_embeds_list = []

                for i in range(batch_size):
                    # Exclude the current index `i` using slicing
                    negatives_embeds = torch.cat([positive_embeds[:i], positive_embeds[i+1:]], dim=0)

                    # Append the result to the list
                    negatives_embeds_list.append(negatives_embeds)

                # Stack the negatives for each sample in the batch
                # Each entry in the batch now has (batch_size - 1) negative embeddings
                negatives_embeds = torch.stack(negatives_embeds_list)

            elif len(batch) == 6: # in case we have only (anchor, positive, negative)
                anchor_ids, anchor_mask, positive_ids, positive_mask, negative_ids, negative_mask = [x.to(device) for x in batch]

                # Forward pass to get the embeddings
                anchor_outputs = model(input_ids=anchor_ids, attention_mask=anchor_mask)
                anchor_embeds = anchor_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                positive_outputs = model(input_ids=positive_ids, attention_mask=positive_mask)
                positive_embeds = positive_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                negative_outputs = model(input_ids=negative_ids, attention_mask=negative_mask)
                negative_embeds = negative_outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

                # Set negatives as the other positives in the batch
                # Create a matrix where the negatives are shifted versions of positives
                batch_size = positive_embeds.size(0)
    #             negatives_embeds = torch.stack([positive_embeds[i:] + positive_embeds[:i] for i in range(1, batch_size)], dim=0)
                # Create the negatives for each index `i` by excluding the positive embedding at index `i`
                negatives_embeds_list = []

                for i in range(batch_size):
                    # Exclude the current index `i` using slicing
                    negatives_embeds = torch.cat([positive_embeds[:i], negative_embeds, positive_embeds[i+1:]], dim=0)

                    # Append the result to the list
                    negatives_embeds_list.append(negatives_embeds)

                # Stack the negatives for each sample in the batch
                # Each entry in the batch now has (batch_size - 1) negative embeddings
                negatives_embeds = torch.stack(negatives_embeds_list)

            # Compute the InfoNCE loss
            loss = criterion(anchor_embeds, positive_embeds, negatives_embeds)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            if clip_value is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()

            total_train_loss += loss.item()

            # Update tqdm progress bar with the current batch number and average loss
            train_progress.set_postfix({
                "Batch": batch_idx + 1,
                "Train Loss": total_train_loss / (batch_idx + 1)
            })

        avg_train_loss = total_train_loss / len(train_dataloader)
        logger.info(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss}")

        # Compute validation loss after each epoch
        avg_val_loss = validate(model, val_dataloader, criterion, device, epoch, epochs)
        logger.info(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss}")

        # Save checkpoint after each epoch
        if loss < best_val_loss:
            best_val_loss = loss
            save_checkpoint(model, optimizer, epoch, checkpoint_dir)

### Wiki40b

In [17]:
from datetime import datetime
from torch.optim import AdamW

In [18]:
MODEL_NAME = 'intfloat/multilingual-e5-large'
BATCH_SIZE = 64
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-4
CLIP_VALUE = 1.0
INFONCE_TEMPERATURE = 0.07
EPOCHS = 10

In [19]:
model_name_slug = MODEL_NAME.replace('/', '_').replace('-', '_')
log_file = f"./logs/hte_training_{model_name_slug}_01_wiki40b.log"
logger = setup_logger(log_file)

In [20]:
%%time

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Initialize the InfoNCE loss and the optimizer
criterion = InfoNCELoss(temperature=INFONCE_TEMPERATURE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Datasets to train on
dataset_names = ['wiki40b']

# Iterate over datasets and train
for dataset_name in dataset_names:
    start_datetime = datetime.now()

    logger.info(f"Switching to new dataset: {dataset_name}")
    dataset = transform_dataset(dataset_name, subsets=['train', 'validation'])

    # Tokenize the train dataset
    logger.info(f"Tokenizing train dataset")
    anchor_inputs_train = tokenizer(dataset['train']['anchor_text'], return_tensors='pt', padding=True, truncation=True)
    positive_inputs_train = tokenizer(dataset['train']['positive_text'], return_tensors='pt', padding=True, truncation=True)

    # Create DataLoader for training
    logger.info(f"Creating train dataloader")
    train_dataset = TensorDataset(anchor_inputs_train['input_ids'], anchor_inputs_train['attention_mask'],
                                  positive_inputs_train['input_ids'], positive_inputs_train['attention_mask'])
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Tokenize the validation dataset
    logger.info(f"Tokenizing validation dataset")
    anchor_inputs_val = tokenizer(dataset['validation']['anchor_text'], return_tensors='pt', padding=True, truncation=True)
    positive_inputs_val = tokenizer(dataset['validation']['positive_text'], return_tensors='pt', padding=True, truncation=True)

    # Create DataLoader for validation
    logger.info(f"Creating validation dataloader")
    val_dataset = TensorDataset(anchor_inputs_val['input_ids'], anchor_inputs_val['attention_mask'],
                                positive_inputs_val['input_ids'], positive_inputs_val['attention_mask'])
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Load the latest checkpoint if available and resume training
    logger.info(f"Loading checkpoint")
    checkpoint_dir = f"checkpoints/{model_name_slug}/checkpoints_01_wiki40b"
    start_epoch = load_checkpoint(model, optimizer, checkpoint_dir=checkpoint_dir, device=device)

    # Train the model for this dataset
    train(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        epochs=EPOCHS,
        start_epoch=start_epoch,
        checkpoint_dir=checkpoint_dir,
        clip_value=CLIP_VALUE
    )

    end_datetime = datetime.now()
    logger.info(f"Total training on {dataset_name} elapsed time is {(end_datetime - start_datetime).total_seconds()} seconds")

logger.info(f"End train base model: {MODEL_NAME}")

2024-10-02 08:34:51,804 - default - INFO - Using device: cuda
2024-10-02 08:34:56,570 - default - INFO - Start train base model: intfloat/multilingual-e5-large
2024-10-02 08:34:57,070 - default - INFO - Switching to new dataset: wiki40b
2024-10-02 08:34:57,071 - default - INFO - Transforming Wiki40B dataset
2024-10-02 08:35:07,541 - default - INFO - Transforming train subset
2024-10-02 08:35:07,551 - default - INFO - Transforming validation subset
2024-10-02 08:35:07,556 - default - INFO - Done transforming Wiki40B dataset
2024-10-02 08:35:07,561 - default - INFO - Tokenizing train dataset
2024-10-02 08:36:05,951 - default - INFO - Creating train dataloader
2024-10-02 08:36:05,952 - default - INFO - Tokenizing validation dataset
2024-10-02 08:36:08,663 - default - INFO - Creating validation dataloader
2024-10-02 08:36:08,664 - default - INFO - Loading checkpoint
2024-10-02 08:36:08,665 - default - INFO - No checkpoint found. Starting from scratch.
2024-10-02 08:36:08,665 - default - IN

### Synthesized data

In [None]:
MODEL_NAME = 'intfloat/multilingual-e5-base'
BATCH_SIZE = 16
LEARNING_RATE = 5e-5
INFONCE_TEMPERATURE = 0.07

In [None]:
model_name_slug = MODEL_NAME.replace('/', '_').replace('-', '_')
log_file = f"./logs/hte_training_{model_name_slug}_02_synthesized.log"
logger = setup_logger(log_file)

In [None]:
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Initialize the InfoNCE loss and the optimizer
criterion = InfoNCELoss(temperature=INFONCE_TEMPERATURE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Load model at checkpoint
epoch = load_checkpoint(model=model, optimizer=optimizer, checkpoint_dir='checkpoints/checkpoints_01_wiki40b', device=device)
logger.info(f"Loaded model from epoch {epoch}")

2024-09-12 08:46:38,985 - default - INFO - Using device: cuda


tokenizer_config.json:   0%|          | 0.00/418 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

2024-09-12 08:46:49,938 - default - INFO - Start train base model: intfloat/multilingual-e5-base
2024-09-12 08:46:50,945 - default - INFO - Loading checkpoint checkpoint_epoch_2.pth
  checkpoint = torch.load(checkpoint_path, map_location=device)
2024-09-12 08:47:33,963 - default - INFO - Loaded model from epoch 2


In [None]:
%%time

# Datasets to train on
dataset_names = ['synthesized_dataset']

# Iterate over datasets and train
for dataset_name in dataset_names:
    start_datetime = datetime.now()

    logger.info(f"Switching to new dataset: {dataset_name}")
    dataset = transform_dataset(dataset_name, tokenizer=tokenizer, file_path='data/synthetic_data_20240906_0018.pkl')

    # Tokenize the train dataset
    anchor_inputs_train = tokenizer(dataset['train']['anchor_text'], return_tensors='pt', padding=True, truncation=True)
    positive_inputs_train = tokenizer(dataset['train']['positive_text'], return_tensors='pt', padding=True, truncation=True)
    negative_inputs_train = tokenizer(dataset['train']['negative_text'], return_tensors='pt', padding=True, truncation=True)

    # Create DataLoader for training
    train_dataset = TensorDataset(anchor_inputs_train['input_ids'], anchor_inputs_train['attention_mask'],
                                  positive_inputs_train['input_ids'], positive_inputs_train['attention_mask'],
                                  negative_inputs_train['input_ids'], negative_inputs_train['attention_mask'])
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Tokenize the validation dataset
    anchor_inputs_val = tokenizer(dataset['validation']['anchor_text'], return_tensors='pt', padding=True, truncation=True)
    positive_inputs_val = tokenizer(dataset['validation']['positive_text'], return_tensors='pt', padding=True, truncation=True)
    negative_inputs_val = tokenizer(dataset['validation']['negative_text'], return_tensors='pt', padding=True, truncation=True)

    # Create DataLoader for validation
    val_dataset = TensorDataset(anchor_inputs_val['input_ids'], anchor_inputs_val['attention_mask'],
                                positive_inputs_val['input_ids'], positive_inputs_val['attention_mask'],
                                negative_inputs_val['input_ids'], negative_inputs_val['attention_mask'])
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Load the latest checkpoint if available and resume training
    checkpoint_dir = "checkpoints/checkpoints_02_synthesized"
    start_epoch = load_checkpoint(model, optimizer, checkpoint_dir=checkpoint_dir, device=device)

    # Train the model for this dataset
    train(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        epochs=3,
        start_epoch=start_epoch,
        checkpoint_dir=checkpoint_dir
    )

    end_datetime = datetime.now()
    logger.info(f"Total training on {dataset_name} elapsed time is {(end_datetime - start_datetime).total_seconds()} seconds")

logger.info(f"End train base model: {MODEL_NAME}")

2024-09-12 08:47:34,022 - default - INFO - Switching to new dataset: synthesized_dataset
2024-09-12 08:47:43,775 - default - INFO - No checkpoint found. Starting from scratch.
2024-09-12 08:54:49,910 - default - INFO - Epoch 1, Train Loss: 0.4075799766257405
2024-09-12 08:55:21,333 - default - INFO - Epoch 1, Validation Loss: 0.29725874787569045
2024-09-12 08:55:30,123 - default - INFO - Checkpoint saved at checkpoints/checkpoints_02_synthesized/checkpoint_epoch_0.pth
2024-09-12 09:02:35,619 - default - INFO - Epoch 2, Train Loss: 0.1710497047379613
2024-09-12 09:03:07,053 - default - INFO - Epoch 2, Validation Loss: 0.2627878464460373
2024-09-12 09:03:19,624 - default - INFO - Checkpoint saved at checkpoints/checkpoints_02_synthesized/checkpoint_epoch_1.pth
2024-09-12 09:10:25,313 - default - INFO - Epoch 3, Train Loss: 0.11907066453620792
2024-09-12 09:10:56,767 - default - INFO - Epoch 3, Validation Loss: 0.26005379532277584
2024-09-12 09:11:10,023 - default - INFO - Checkpoint save

CPU times: user 14min 18s, sys: 9min 20s, total: 23min 39s
Wall time: 23min 36s


## Evaluation

In [None]:
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoModel, AutoTokenizer
import json
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
import logging
import os
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

### HebNLI

In [None]:
# !git clone https://github.com/NNLP-IL/HebNLI.git eval/HebNLI # do not uncomment -- train jsonl file was manually added

In [None]:
label_mapping = {
    'entailment': 0,
    'neutral': 1,
    'contradiction': 2,
}

def load_jsonl_data(file_path):
    data = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            record = json.loads(line)
            if 'translation1' in record and 'translation2' in record and 'original_label' in record and record['original_label'] in label_mapping.keys():
                data.append({
                    'premise': record['translation1'],
                    'hypothesis': record['translation2'],
                    'label': label_mapping[record['original_label']]
                })
    return data

# Load the data for each subset
train_data = load_jsonl_data('eval/HebNLI/HebNLI_train.jsonl')
val_data = load_jsonl_data('eval/HebNLI/HebNLI_val.jsonl')
test_data = load_jsonl_data('eval/HebNLI/HebNLI_test.jsonl')

# Create a DatasetDict containing all the subsets
nli_dataset = DatasetDict({
    'train': Dataset.from_pandas(pd.DataFrame(train_data)),
    'validation': Dataset.from_pandas(pd.DataFrame(val_data)),
    'test': Dataset.from_pandas(pd.DataFrame(test_data))
})
nli_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 300067
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 579
    })
})

In [None]:
class CustomNLIModel(nn.Module):
    def __init__(self, pretrained_model, hidden_size=768, num_labels=3, freeze_backbone_weights=False):
        super(CustomNLIModel, self).__init__()
        # Use a pre-trained model like BERT or any custom backbone
        self.backbone = pretrained_model

        # Freeze the base model's weights
        if freeze_backbone_weights:
            for param in self.backbone.parameters():
                param.requires_grad = False

        self.classifier = nn.Linear(hidden_size, num_labels)  # Final classifier layer for entailment/contradiction/neutral

    def forward(self, input_ids, attention_mask=None):
        # Forward pass through the backbone model (e.g., BERT, RoBERTa)
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # Use [CLS] token representation for classification

        # Pass the pooled output through the classifier
        logits = self.classifier(pooled_output)
        return logits


def tokenize_dataset(dataset, tokenizer):
    # Tokenize the dataset
    def preprocess_function(examples):
        return tokenizer(examples['premise'], examples['hypothesis'], truncation=True, padding='max_length')

    # Tokenize the entire dataset
    tokenized_dataset = dataset.map(preprocess_function, batched=True)
    return tokenized_dataset


def create_dataloader(tokenized_dataset, batch_size=16):
    # Convert the dataset to PyTorch tensors
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    # Create a DataLoader
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size)
    return dataloader


def train_nli(model, train_dataloader, val_dataloader, device, epochs=3, lr=2e-5):
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr)

    for epoch in range(epochs):
        total_train_loss = 0.0
        model.to(device)
        model.train()

        # Track progress in the training loop using tqdm
        train_progress = tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch + 1}/{epochs} [Train]", leave=False, total=len(train_dataloader))

        # Training loop
        for batch_idx, batch in train_progress:
            total_loss = 0.0
            correct = 0
            total = 0

            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask)

            # Calculate loss
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update tqdm progress bar with the current batch number and average loss
            train_progress.set_postfix({
                "Batch": batch_idx + 1,
                "Train Loss": total_train_loss / (batch_idx + 1)
            })

        # Print epoch loss and accuracy
        avg_train_loss = total_train_loss / len(train_dataloader)
        train_accuracy = correct / total
        logger.info(f"Epoch {epoch+1}/{epochs}, Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f}")

        # Compute validation loss after each epoch
        avg_val_loss, val_accuracy = validate_nli(model, val_dataloader, criterion, device, epoch, epochs)


def validate_nli(model, val_dataloader, criterion, device, epoch, epochs):
    model.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0

    # Track progress in the validation loop using tqdm
    val_progress = tqdm(enumerate(val_dataloader), desc=f"Epoch {epoch + 1}/{epochs} [Val]", leave=False, total=len(val_dataloader))

    with torch.no_grad():
        for batch_idx, batch in val_progress:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Calculate loss
            loss = criterion(outputs, labels)
            total_val_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate average validation loss and accuracy
    avg_val_loss = total_val_loss / len(val_dataloader)
    val_accuracy = correct / total

    logger.info(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    return avg_val_loss, val_accuracy

In [None]:
MODEL_NAME = 'intfloat/multilingual-e5-base'
model_name_slug = MODEL_NAME.replace('/', '_').replace('-', '_')

BATCH_SIZE = 32
LR = 2e-5
EPOCHS = 3

In [None]:
# Load and tokenize the train and validation dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
train_tokenized_dataset = tokenize_dataset(nli_dataset['train'], tokenizer)
val_tokenized_dataset = tokenize_dataset(nli_dataset['validation'], tokenizer)

Map:   0%|          | 0/300067 [00:00<?, ? examples/s]

Map:   0%|          | 0/1999 [00:00<?, ? examples/s]

### Base model

In [None]:
log_file = f"./logs/hte_eval_{model_name_slug}_00_base_hebnli.log"
logger = setup_logger(log_file)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define backbone model
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Define NLI model
nli_model = CustomNLIModel(model)

# Create the train and validation DataLoader
train_dataloader = create_dataloader(train_tokenized_dataset, batch_size=BATCH_SIZE)
val_dataloader = create_dataloader(val_tokenized_dataset, batch_size=BATCH_SIZE)

In [None]:
train_nli(
    model=nli_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=EPOCHS,
    lr=LR
)

### 1st pass (Wiki40b) trained model

In [None]:
log_file = f"./logs/hte_eval_{model_name_slug}_01_wiki40b_hebnli.log"
logger = setup_logger(log_file)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define backbone model
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Load model at checkpoint
optimizer = AdamW(model.parameters(), lr=LR)
epoch = load_checkpoint(model=model, optimizer=optimizer, checkpoint_dir='checkpoints/checkpoints_01_wiki40b', device=device)
logger.info(f"Loaded model from epoch {epoch}")

# Define NLI model
nli_model = CustomNLIModel(model)

# Create the train and validation DataLoader
train_dataloader = create_dataloader(train_tokenized_dataset, batch_size=BATCH_SIZE)
val_dataloader = create_dataloader(val_tokenized_dataset, batch_size=BATCH_SIZE)

2024-09-14 23:31:18,550 - default - INFO - Using device: cuda
2024-09-14 23:31:19,245 - default - INFO - Start train base model: intfloat/multilingual-e5-base
2024-09-14 23:31:19,249 - default - INFO - Loading checkpoint checkpoint_epoch_2.pth
  checkpoint = torch.load(checkpoint_path, map_location=device)
2024-09-14 23:31:23,325 - default - INFO - Loaded model from epoch 2


In [None]:
train_nli(
    model=nli_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=EPOCHS,
    lr=LR
)

2024-09-15 01:26:32,504 - default - INFO - Epoch 1/3, Loss: 0.7150, Accuracy: 0.6667
2024-09-15 01:26:47,004 - default - INFO - Validation Loss: 0.6193, Validation Accuracy: 0.7394
2024-09-15 03:21:59,500 - default - INFO - Epoch 2/3, Loss: 0.5688, Accuracy: 1.0000
2024-09-15 03:22:14,044 - default - INFO - Validation Loss: 0.6165, Validation Accuracy: 0.7674
2024-09-15 05:17:30,848 - default - INFO - Epoch 3/3, Loss: 0.4737, Accuracy: 1.0000
2024-09-15 05:17:45,449 - default - INFO - Validation Loss: 0.6666, Validation Accuracy: 0.7559


### 2nd pass (Synthesized data) trained model

In [None]:
log_file = f"./logs/hte_eval_{model_name_slug}_02_synthesized_hebnli.log"
logger = setup_logger(log_file)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define backbone model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Load model at checkpoint
optimizer = AdamW(model.parameters(), lr=LR)
epoch = load_checkpoint(model=model, optimizer=optimizer, checkpoint_dir='checkpoints/checkpoints_02_synthesized', device=device)
logger.info(f"Loaded model from epoch {epoch}")

# Define NLI model
nli_model = CustomNLIModel(model)

# Create the train and validation DataLoader
train_dataloader = create_dataloader(train_tokenized_dataset, batch_size=BATCH_SIZE)
val_dataloader = create_dataloader(val_tokenized_dataset, batch_size=BATCH_SIZE)

2024-09-15 05:17:45,487 - default - INFO - Using device: cuda
2024-09-15 05:17:47,875 - default - INFO - Start train base model: intfloat/multilingual-e5-base
2024-09-15 05:17:49,426 - default - INFO - Loading checkpoint checkpoint_epoch_2.pth
  checkpoint = torch.load(checkpoint_path, map_location=device)
2024-09-15 05:18:34,025 - default - INFO - Loaded model from epoch 2


In [None]:
train_nli(
    model=nli_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=EPOCHS,
    lr=LR
)

Epoch 1/3 [Train]:  78%|███████▊  | 14648/18755 [1:30:05<25:15,  2.71it/s, Batch=14648, Train Loss=0.741]

In [None]:
from google.colab import runtime
runtime.unassign()