### Fine-tuning a DPR retriever (starting from the pretrained NQ encoders)

In [1]:
import logging

# Root logger: emit only WARNING and above, prefixed with timestamp, level and logger name.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(levelname)s - %(name)s -  %(message)s",
)

# Haystack's own loggers are made more verbose so that training progress is visible.
haystack_logger = logging.getLogger("haystack")
haystack_logger.setLevel(logging.INFO)

In [None]:
from haystack.nodes import DensePassageRetriever
from haystack.utils import fetch_archive_from_http
from haystack.document_stores import InMemoryDocumentStore

In [3]:
# Training data file, passed to retriever.train(train_filename=...).
# NOTE(review): presumably in Haystack's DPR JSON format — confirm.
train_filename = "train.json"

# Pretrained checkpoints for the question encoder and the passage (context) encoder.
query_model, passage_model = (
    "facebook/dpr-question_encoder-single-nq-base",
    "facebook/dpr-ctx_encoder-single-nq-base",
)

# Output directory handed to retriever.train(save_dir=...).
save_dir = "save_folder"

In [4]:
import torch
import torch.distributed as dist
from haystack.modeling.training import Trainer
from haystack.modeling.data_handler.processor import TextSimilarityProcessor
from haystack.modeling.data_handler.data_silo import DataSilo
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
import os

# Single-process "distributed" setup: one process (rank 0) in a world of size 1.
# NOTE(review): presumably required because downstream Haystack training code
# queries torch.distributed — confirm before removing.
rank = 0
world_size = 1
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['NCCL_BUFFSIZE'] = '2097152'

# init_process_group raises if the default group already exists, so guard it;
# this lets the notebook cell be re-run without restarting the kernel.
if not dist.is_initialized():
    # NCCL only supports CUDA devices; use Gloo on CPU-only machines instead.
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://', world_size=world_size, rank=rank)


In [None]:
# Bi-encoder retriever initialised from the two pretrained DPR checkpoints above.
# NOTE(review): the empty in-memory store looks like a placeholder, since only
# training is done in this notebook — confirm it is never queried.
retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    # Encoders
    query_embedding_model=query_model,
    passage_embedding_model=passage_model,
    # Maximum input lengths for queries and passages respectively
    max_seq_len_query=64,
    max_seq_len_passage=512,
    use_gpu=True,
)

In [None]:
# Fine-tune both encoders on `train_filename`; the final model is written to `save_dir`.
retriever.train(
    # Data: an empty data_dir means train_filename is taken as given.
    data_dir="",
    train_filename=train_filename,
    num_positives=1,
    # Optimisation schedule: 8 gradient-accumulation steps of batch_size 12 each.
    n_epochs=1,
    batch_size=12,
    grad_acc_steps=8,
    # NOTE(review): n_gpu=8 is hard-coded while the process group above uses
    # world_size = 1 — confirm this matches the target machine.
    n_gpu=8,
    # Outputs
    save_dir=save_dir,
    checkpoint_root_dir="checkpoints",
    checkpoint_every=2000,
)

### Loading the fine-tuned model

In [None]:
# Reload the fine-tuned retriever from the same directory that retriever.train()
# wrote it to (save_dir), so the trained weights are the ones picked up.
# No document store is attached; pass one before using it for retrieval.
reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=None)