# Training Hebrew Text Encoder

In [1]:
!nvidia-smi

Wed Oct  9 17:14:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:0E:00.0 Off |                    0 |
| N/A   35C    P0             68W /  400W |    1210MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          Off |   00

In [3]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

## Setup the environment

In [4]:
!pip install -q -U torch transformers bitsandbytes datasets huggingface_hub accelerate tqdm

In [5]:
from huggingface_hub import notebook_login
import os
import sys

In [6]:
os.environ["HF_TOKEN"] = "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"
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [7]:
def is_running_in_colab():
    return 'COLAB_GPU' in os.environ

if is_running_in_colab():
    # Load the Drive helper and mount
    print("Mounting google drive...")
    from google.colab import drive
    drive.mount('/content/drive')
else:
    print("Not running in Google Colab!")

Not running in Google Colab!


In [8]:
project_dir = '/home/nlp/achimoa/projects/hebrew_text_encoder'
src_dir = os.path.join(project_dir, 'src')

os.chdir(project_dir)
print(f"Current working directory set to: {os.getcwd()}")


if src_dir not in sys.path:
    sys.path.insert(0, src_dir)  # Add it to the front of PYTHONPATH
    print(f"PYTHONPATH updated with: {src_dir}")
else:
    print(f"PYTHONPATH already includes: {src_dir}")

Current working directory set to: /home/nlp/achimoa/projects/hebrew_text_encoder
PYTHONPATH updated with: /home/nlp/achimoa/projects/hebrew_text_encoder/src


In [9]:
%reload_ext autoreload
%autoreload 2
from transformers import AutoModel, AutoTokenizer
from datasets import concatenate_datasets
import torch
from torch.optim import AdamW
from datetime import datetime
from data import *
from loss import *
from trainings import *
from utils import *

### Wiki40b

In [10]:
from datetime import datetime
from torch.optim import AdamW

In [11]:
MODEL_NAME = 'onlplab/alephbert-base'
BATCH_SIZE = 64
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-4
CLIP_VALUE = 1.0
INFONCE_TEMPERATURE = 0.07
EPOCHS = 10

In [12]:
model_name_slug = MODEL_NAME.replace('/', '_').replace('-', '_')
log_file = f"./logs/hte_training_{model_name_slug}_01_wiki40b.log"
logger = setup_logger(log_file)

In [13]:
%%time

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Initialize the InfoNCE loss and the optimizer
criterion = InfoNCELoss(temperature=INFONCE_TEMPERATURE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Datasets to train on
dataset_names = ['wiki40b']

# Iterate over datasets and train
for dataset_name in dataset_names:
    start_datetime = datetime.now()

    logger.info(f"Switching to new dataset: {dataset_name}")
    dataset = transform_dataset(dataset_name, subsets=['train', 'validation'])

    # Tokenize the train dataset
    logger.info(f"Tokenizing train dataset")
    anchor_inputs_train = tokenizer(dataset['train']['anchor_text'], return_tensors='pt', padding=True, truncation=True)
    positive_inputs_train = tokenizer(dataset['train']['positive_text'], return_tensors='pt', padding=True, truncation=True)

    # Create DataLoader for training
    logger.info(f"Creating train dataloader")
    train_dataset = TensorDataset(anchor_inputs_train['input_ids'], anchor_inputs_train['attention_mask'],
                                  positive_inputs_train['input_ids'], positive_inputs_train['attention_mask'])
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Tokenize the validation dataset
    logger.info(f"Tokenizing validation dataset")
    anchor_inputs_val = tokenizer(dataset['validation']['anchor_text'], return_tensors='pt', padding=True, truncation=True)
    positive_inputs_val = tokenizer(dataset['validation']['positive_text'], return_tensors='pt', padding=True, truncation=True)

    # Create DataLoader for validation
    logger.info(f"Creating validation dataloader")
    val_dataset = TensorDataset(anchor_inputs_val['input_ids'], anchor_inputs_val['attention_mask'],
                                positive_inputs_val['input_ids'], positive_inputs_val['attention_mask'])
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Load the latest checkpoint if available and resume training
    logger.info(f"Loading checkpoint")
    checkpoint_dir = f"checkpoints/{model_name_slug}/checkpoints_01_wiki40b"
    start_epoch = load_checkpoint(model, optimizer, checkpoint_dir=checkpoint_dir, device=device)

    # Train the model for this dataset
    train(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        epochs=EPOCHS,
        start_epoch=start_epoch,
        checkpoint_dir=checkpoint_dir,
        clip_value=CLIP_VALUE
    )

    end_datetime = datetime.now()
    logger.info(f"Total training on {dataset_name} elapsed time is {(end_datetime - start_datetime).total_seconds()} seconds")

logger.info(f"End train base model: {MODEL_NAME}")

2024-10-03 19:11:16,772 - default - INFO - Using device: cuda
Some weights of BertModel were not initialized from the model checkpoint at onlplab/alephbert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2024-10-03 19:11:19,722 - default - INFO - Start train base model: onlplab/alephbert-base
2024-10-03 19:11:20,257 - default - INFO - Switching to new dataset: wiki40b
2024-10-03 19:11:20,258 - default - INFO - Transforming Wiki40B dataset
2024-10-03 19:11:29,090 - default - INFO - Transforming train subset
2024-10-03 19:11:29,098 - default - INFO - Transforming validation subset
2024-10-03 19:11:29,102 - default - INFO - Done transforming Wiki40B dataset
2024-10-03 19:11:29,109 - default - INFO - Tokenizing train dataset


### Synthesized data

In [10]:
MODEL_NAME = 'onlplab/alephbert-base'
BATCH_SIZE = 32
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-4
CLIP_VALUE = 1.0
INFONCE_TEMPERATURE = 0.07
EPOCHS = 20
TRAIN_ID = '02_synthesized'

model_name_slug = MODEL_NAME.replace('/', '_').replace('-', '_')

SOURCE_CHECKPOINT_DIR = f'checkpoints/{model_name_slug}/checkpoints_01_wiki40b'
CHECKPOINT_DIR = f'checkpoints/{model_name_slug}/checkpoints_{TRAIN_ID}'

In [11]:
# Logger
log_file = f"./logs/{model_name_slug}/{TRAIN_ID}.log"
logger = setup_logger(log_file)

# Get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Initialize the InfoNCE loss and the optimizer
criterion = InfoNCELoss(temperature=INFONCE_TEMPERATURE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Load model at checkpoint
start_epoch = load_checkpoint(model=model, optimizer=optimizer, checkpoint_dir=SOURCE_CHECKPOINT_DIR, device=device)
logger.info(f"Loaded model from epoch {start_epoch}")

2024-10-05 08:27:49,777 - default - INFO - Using device: cuda
Some weights of BertModel were not initialized from the model checkpoint at onlplab/alephbert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2024-10-05 08:27:53,701 - default - INFO - Start train base model: onlplab/alephbert-base
2024-10-05 08:27:54,418 - default - INFO - Loading checkpoint checkpoint_epoch_7.pth
  checkpoint = torch.load(checkpoint_path, map_location=device)
2024-10-05 08:28:21,620 - default - INFO - Loaded model from epoch 7


In [12]:
dataset_name = 'synthesized_query_document'
logger.info(f"Switching to new dataset: {dataset_name}")
dataset = transform_dataset(dataset_name, data_folder_path='./data/synthetic_data_202409')

# Tokenize the train dataset and creating the dataloader
logger.info("Tokenize the train dataset and creating the dataloader")
train_dataloader = tokenize_inputs_and_create_dataloader(tokenizer, dataset['train'], batch_size=BATCH_SIZE, shuffle=True)

# Tokenize the validation dataset and creating the dataloader
logger.info("Tokenize the validation dataset and creating the dataloader")
val_dataloader = tokenize_inputs_and_create_dataloader(tokenizer, dataset['validation'], batch_size=BATCH_SIZE, shuffle=False)

2024-10-05 08:28:21,834 - default - INFO - Switching to new dataset: synthesized_query_document
2024-10-05 08:28:21,836 - default - INFO - Transforming synthesized dataset
2024-10-05 08:28:21,837 - default - INFO - Loading synthesize query document dataset from {data_folder_path}
Loading data files:   0%|                                                                                                   | 0/13 [00:00<?, ?it/s]2024-10-05 08:28:21,847 - default - DEBUG - Loading data from synthetic_data_20240906_0018.pkl
2024-10-05 08:28:21,922 - default - DEBUG - Loading data from synthetic_data_20240920_0557.pkl
2024-10-05 08:28:21,924 - default - DEBUG - Loading data from synthetic_data_20240924_1958.pkl
Loading data files:  23%|█████████████████████                                                                      | 3/13 [00:00<00:00, 23.31it/s]2024-10-05 08:28:21,976 - default - DEBUG - Loading data from synthetic_data_20240924_1959.pkl
2024-10-05 08:28:22,077 - default - DEBUG - L

In [13]:
%%time

start_datetime = datetime.now()
# Train the model for this dataset
train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=EPOCHS,
    start_epoch=start_epoch,
    checkpoint_dir=CHECKPOINT_DIR,
    clip_value=CLIP_VALUE
)

end_datetime = datetime.now()
logger.info(f"Total training on {dataset_name} elapsed time is {(end_datetime - start_datetime).total_seconds()} seconds")

logger.info(f"End train base model: {MODEL_NAME}")

2024-10-05 08:29:35,785 - default - INFO - Start training
2024-10-05 09:24:59,483 - default - INFO - Epoch 8, Train Loss: 0.5895706873030766                                                                 
2024-10-05 09:27:10,531 - default - INFO - Epoch 8, Validation Loss: 0.4877923820943882                                                            
2024-10-05 09:27:12,876 - default - INFO - Checkpoint saved at checkpoints/onlplab_alephbert_base/checkpoints_02_synthesized/checkpoint_epoch_7.pth
2024-10-05 10:21:27,240 - default - INFO - Epoch 9, Train Loss: 0.3431768832647282                                                                 
2024-10-05 10:23:38,519 - default - INFO - Epoch 9, Validation Loss: 0.46702060711363125                                                           
2024-10-05 10:23:42,275 - default - INFO - Checkpoint saved at checkpoints/onlplab_alephbert_base/checkpoints_02_synthesized/checkpoint_epoch_8.pth
2024-10-05 11:17:59,340 - default - INFO - Epoch 10, T

CPU times: user 7h 23min 15s, sys: 4h 36min 39s, total: 11h 59min 54s
Wall time: 12h 16min


## Evaluation

In [None]:
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoModel, AutoTokenizer
import json
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
import logging
import os
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

### HebNLI

In [None]:
# !git clone https://github.com/NNLP-IL/HebNLI.git eval/HebNLI # do not uncomment -- train jsonl file was manually added

In [None]:
label_mapping = {
    'entailment': 0,
    'neutral': 1,
    'contradiction': 2,
}

def load_jsonl_data(file_path):
    data = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            record = json.loads(line)
            if 'translation1' in record and 'translation2' in record and 'original_label' in record and record['original_label'] in label_mapping.keys():
                data.append({
                    'premise': record['translation1'],
                    'hypothesis': record['translation2'],
                    'label': label_mapping[record['original_label']]
                })
    return data

# Load the data for each subset
train_data = load_jsonl_data('eval/HebNLI/HebNLI_train.jsonl')
val_data = load_jsonl_data('eval/HebNLI/HebNLI_val.jsonl')
test_data = load_jsonl_data('eval/HebNLI/HebNLI_test.jsonl')

# Create a DatasetDict containing all the subsets
nli_dataset = DatasetDict({
    'train': Dataset.from_pandas(pd.DataFrame(train_data)),
    'validation': Dataset.from_pandas(pd.DataFrame(val_data)),
    'test': Dataset.from_pandas(pd.DataFrame(test_data))
})
nli_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 300067
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 579
    })
})

In [None]:
class CustomNLIModel(nn.Module):
    def __init__(self, pretrained_model, hidden_size=768, num_labels=3, freeze_backbone_weights=False):
        super(CustomNLIModel, self).__init__()
        # Use a pre-trained model like BERT or any custom backbone
        self.backbone = pretrained_model

        # Freeze the base model's weights
        if freeze_backbone_weights:
            for param in self.backbone.parameters():
                param.requires_grad = False

        self.classifier = nn.Linear(hidden_size, num_labels)  # Final classifier layer for entailment/contradiction/neutral

    def forward(self, input_ids, attention_mask=None):
        # Forward pass through the backbone model (e.g., BERT, RoBERTa)
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # Use [CLS] token representation for classification

        # Pass the pooled output through the classifier
        logits = self.classifier(pooled_output)
        return logits


def tokenize_dataset(dataset, tokenizer):
    # Tokenize the dataset
    def preprocess_function(examples):
        return tokenizer(examples['premise'], examples['hypothesis'], truncation=True, padding='max_length')

    # Tokenize the entire dataset
    tokenized_dataset = dataset.map(preprocess_function, batched=True)
    return tokenized_dataset


def create_dataloader(tokenized_dataset, batch_size=16):
    # Convert the dataset to PyTorch tensors
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    # Create a DataLoader
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size)
    return dataloader


def train_nli(model, train_dataloader, val_dataloader, device, epochs=3, lr=2e-5):
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr)

    for epoch in range(epochs):
        total_train_loss = 0.0
        model.to(device)
        model.train()

        # Track progress in the training loop using tqdm
        train_progress = tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch + 1}/{epochs} [Train]", leave=False, total=len(train_dataloader))

        # Training loop
        for batch_idx, batch in train_progress:
            total_loss = 0.0
            correct = 0
            total = 0

            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask)

            # Calculate loss
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update tqdm progress bar with the current batch number and average loss
            train_progress.set_postfix({
                "Batch": batch_idx + 1,
                "Train Loss": total_train_loss / (batch_idx + 1)
            })

        # Print epoch loss and accuracy
        avg_train_loss = total_train_loss / len(train_dataloader)
        train_accuracy = correct / total
        logger.info(f"Epoch {epoch+1}/{epochs}, Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f}")

        # Compute validation loss after each epoch
        avg_val_loss, val_accuracy = validate_nli(model, val_dataloader, criterion, device, epoch, epochs)


def validate_nli(model, val_dataloader, criterion, device, epoch, epochs):
    model.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0

    # Track progress in the validation loop using tqdm
    val_progress = tqdm(enumerate(val_dataloader), desc=f"Epoch {epoch + 1}/{epochs} [Val]", leave=False, total=len(val_dataloader))

    with torch.no_grad():
        for batch_idx, batch in val_progress:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Calculate loss
            loss = criterion(outputs, labels)
            total_val_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate average validation loss and accuracy
    avg_val_loss = total_val_loss / len(val_dataloader)
    val_accuracy = correct / total

    logger.info(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    return avg_val_loss, val_accuracy

In [None]:
MODEL_NAME = 'intfloat/multilingual-e5-base'
model_name_slug = MODEL_NAME.replace('/', '_').replace('-', '_')

BATCH_SIZE = 32
LR = 2e-5
EPOCHS = 3

In [None]:
# Load and tokenize the train and validation dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
train_tokenized_dataset = tokenize_dataset(nli_dataset['train'], tokenizer)
val_tokenized_dataset = tokenize_dataset(nli_dataset['validation'], tokenizer)

Map:   0%|          | 0/300067 [00:00<?, ? examples/s]

Map:   0%|          | 0/1999 [00:00<?, ? examples/s]

### Base model

In [None]:
log_file = f"./logs/hte_eval_{model_name_slug}_00_base_hebnli.log"
logger = setup_logger(log_file)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define backbone model
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Define NLI model
nli_model = CustomNLIModel(model)

# Create the train and validation DataLoader
train_dataloader = create_dataloader(train_tokenized_dataset, batch_size=BATCH_SIZE)
val_dataloader = create_dataloader(val_tokenized_dataset, batch_size=BATCH_SIZE)

In [None]:
train_nli(
    model=nli_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=EPOCHS,
    lr=LR
)

### 1st pass (Wiki40b) trained model

In [None]:
log_file = f"./logs/hte_eval_{model_name_slug}_01_wiki40b_hebnli.log"
logger = setup_logger(log_file)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define backbone model
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Load model at checkpoint
optimizer = AdamW(model.parameters(), lr=LR)
epoch = load_checkpoint(model=model, optimizer=optimizer, checkpoint_dir='checkpoints/checkpoints_01_wiki40b', device=device)
logger.info(f"Loaded model from epoch {epoch}")

# Define NLI model
nli_model = CustomNLIModel(model)

# Create the train and validation DataLoader
train_dataloader = create_dataloader(train_tokenized_dataset, batch_size=BATCH_SIZE)
val_dataloader = create_dataloader(val_tokenized_dataset, batch_size=BATCH_SIZE)

2024-09-14 23:31:18,550 - default - INFO - Using device: cuda
2024-09-14 23:31:19,245 - default - INFO - Start train base model: intfloat/multilingual-e5-base
2024-09-14 23:31:19,249 - default - INFO - Loading checkpoint checkpoint_epoch_2.pth
  checkpoint = torch.load(checkpoint_path, map_location=device)
2024-09-14 23:31:23,325 - default - INFO - Loaded model from epoch 2


In [None]:
train_nli(
    model=nli_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=EPOCHS,
    lr=LR
)

2024-09-15 01:26:32,504 - default - INFO - Epoch 1/3, Loss: 0.7150, Accuracy: 0.6667
2024-09-15 01:26:47,004 - default - INFO - Validation Loss: 0.6193, Validation Accuracy: 0.7394
2024-09-15 03:21:59,500 - default - INFO - Epoch 2/3, Loss: 0.5688, Accuracy: 1.0000
2024-09-15 03:22:14,044 - default - INFO - Validation Loss: 0.6165, Validation Accuracy: 0.7674
2024-09-15 05:17:30,848 - default - INFO - Epoch 3/3, Loss: 0.4737, Accuracy: 1.0000
2024-09-15 05:17:45,449 - default - INFO - Validation Loss: 0.6666, Validation Accuracy: 0.7559


### 2nd pass (Synthesized data) trained model

In [None]:
log_file = f"./logs/hte_eval_{model_name_slug}_02_synthesized_hebnli.log"
logger = setup_logger(log_file)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define backbone model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
logger.info(f"Start train base model: {MODEL_NAME}")

# Load model at checkpoint
optimizer = AdamW(model.parameters(), lr=LR)
epoch = load_checkpoint(model=model, optimizer=optimizer, checkpoint_dir='checkpoints/checkpoints_02_synthesized', device=device)
logger.info(f"Loaded model from epoch {epoch}")

# Define NLI model
nli_model = CustomNLIModel(model)

# Create the train and validation DataLoader
train_dataloader = create_dataloader(train_tokenized_dataset, batch_size=BATCH_SIZE)
val_dataloader = create_dataloader(val_tokenized_dataset, batch_size=BATCH_SIZE)

2024-09-15 05:17:45,487 - default - INFO - Using device: cuda
2024-09-15 05:17:47,875 - default - INFO - Start train base model: intfloat/multilingual-e5-base
2024-09-15 05:17:49,426 - default - INFO - Loading checkpoint checkpoint_epoch_2.pth
  checkpoint = torch.load(checkpoint_path, map_location=device)
2024-09-15 05:18:34,025 - default - INFO - Loaded model from epoch 2


In [None]:
train_nli(
    model=nli_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=EPOCHS,
    lr=LR
)

Epoch 1/3 [Train]:  78%|███████▊  | 14648/18755 [1:30:05<25:15,  2.71it/s, Batch=14648, Train Loss=0.741]

In [None]:
from google.colab import runtime
runtime.unassign()