In [1]:
import os
import time
from collections import defaultdict

import numpy as np
import torch
import torch.cuda
import torch.distributed as dist

from src import dist_utils, slurm, util
from src.index_io import load_or_initialize_index, save_embeddings_and_index
from src.model_io import create_checkpoint_directories, load_or_initialize_atlas_model
from src.options import get_options
from src.tasks import get_task
from src.index import DistributedFAISSIndex, DistributedIndex

# Export TOKENIZERS_PARALLELISM="true" into this process's environment.
# NOTE(review): presumably read by the tokenizer library when it is first used — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="true")


In [2]:
options = get_options()

# Evaluation settings, spelled out as command-line tokens: every flag and every
# value is its own string, in the order the parser receives them.
opt = options.parse(
    [
        "--name", "eval-test",
        "--generation_max_length", "16",
        "--gold_score_mode", "pdist",
        "--precision", "fp32",
        "--reader_model_type", "google/t5-base-lm-adapt",
        "--text_maxlength", "128",
        "--model_path", "/atlas-data/models/atlas_nq/base/",
        "--per_gpu_batch_size", "2",
        "--n_context", "5",
        "--retriever_n_context", "5",
        "--checkpoint_dir", "/atlas-data/checkpoints",
        "--main_port", "12345",
        "--index_mode", "flat",
        "--task", "qa",
        "--write_results",  # bare flag: no value token follows it
    ]
)

In [3]:
def qa_data_iterator(questions):
    """Lazily yield each element of ``questions`` in its original order."""
    yield from questions

@torch.no_grad()
def evaluate(model, index, opt, questions, step=None):
    """Retrieve passages, generate answers for ``questions`` and score them.

    For every batch the queries are tokenized, passages are retrieved from
    ``index``, the reader generates an answer, and the per-sample metrics
    returned by ``task.evaluation`` are accumulated. The per-example records
    are printed before the function returns.

    Args:
        model: Model to evaluate (possibly wrapped); it is switched to eval mode
            and unwrapped with ``util.get_unwrapped_model_if_wrapped``.
        index: Passage index handed to the model's ``retrieve`` method.
        opt: Options object. Reads ``per_gpu_batch_size``, ``n_context``,
            ``decoder_prompt_format``, ``dont_write_passages`` and ``task``.
        questions: Iterable of raw examples, each passed through ``task.process``.
        step: Unused; kept so existing callers do not break.

    Returns:
        dict: metric name -> averaged value. Every metric except ``"eval_loss"``
        is multiplied by 100.
    """
    model.eval()
    metrics = defaultdict(list)
    dataset_wpred = []
    unwrapped_model = util.get_unwrapped_model_if_wrapped(model)
    reader_tokenizer = unwrapped_model.reader_tokenizer

    task = get_task(opt, reader_tokenizer)
    data_iterator = qa_data_iterator(questions)
    # Examples for which task.process returns a falsy value are dropped.
    data_iterator = filter(None, map(task.process, data_iterator))
    data_iterator = list(task.batch_iterator(data_iterator, opt.per_gpu_batch_size))

    # The reader logits have two consumers: the eval-loss metric and the
    # multiple-choice records. They used to be computed only for the former,
    # which raised NameError on `logits` for a multiple-choice task that does
    # not report "eval_loss".
    wants_eval_loss = "eval_loss" in task.metrics
    wants_choice_logits = opt.task == "multiple_choice"

    for batch in data_iterator:
        query = batch.get("query", [""])
        answers = batch.get("target", [""])
        batch_metadata = batch.get("metadata")
        target_tokens = batch.get("target_tokens")

        query_enc, labels, decoder_input_ids = unwrapped_model.tokenize(query, answers, target_tokens=target_tokens)

        query_ids_retriever = query_enc["input_ids"].cuda()
        query_mask_retriever = query_enc["attention_mask"].cuda()
        retrieved_passages, _ = unwrapped_model.retrieve(
            index,
            opt.n_context,
            query,
            query_ids_retriever,
            query_mask_retriever,
            batch_metadata=batch_metadata,
            filtering_fun=task.filter,
        )

        reader_tokens, _ = unwrapped_model.tokenize_passages(query, retrieved_passages)

        logits = None
        if wants_eval_loss or wants_choice_logits:
            eval_loss, logits = unwrapped_model.compute_reader_loss_and_logits(reader_tokens, decoder_input_ids, labels)
            if wants_eval_loss:
                metrics["eval_loss"].append(eval_loss)

        generation = unwrapped_model.generate(reader_tokens, query, choices=batch.get("choices"))

        for k, g in enumerate(generation):
            if opt.decoder_prompt_format is not None:
                # Strip the decoder prompt (plus one extra leading token) so
                # only the generated continuation is decoded.
                query_ids = reader_tokenizer.encode(
                    opt.decoder_prompt_format.format_map({"query": query[k]}), add_special_tokens=False
                )
                g = g[len(query_ids) + 1 :]
            pred = reader_tokenizer.decode(g, skip_special_tokens=True)
            # Prefer the batch's "answers" entry; otherwise fall back to the single target.
            gold = [answers[k]] if "answers" not in batch else batch["answers"][k]
            sample_metrics = task.evaluation(pred, gold)
            for key, value in sample_metrics.items():
                metrics[key].append(value)

            ex = {"query": query[k], "answers": gold, "generation": pred}
            if not opt.dont_write_passages:
                ex["passages"] = retrieved_passages[k]
            if batch_metadata is not None:
                ex["metadata"] = batch_metadata[k]
            if wants_choice_logits:
                ex["choice_logits"] = task.get_choice_logits(logits[k])
            if "id" in batch:
                ex["id"] = batch["id"][k]
            dataset_wpred.append(ex)

    metrics, dataset_wpred = task.evaluation_postprocessing(metrics, dataset_wpred)
    metrics = util.avg_dist_dict(task.metrics, metrics)
    # Report everything except the loss on a 0-100 scale.
    metrics = {key: value if key == "eval_loss" else 100 * value for key, value in metrics.items()}

    print('results', dataset_wpred)

    return metrics


In [4]:
# Toy fixtures: passage sets to index, plus the questions asked against them.
# Each passage is a {'title', 'text'} dict. The lists are built with
# comprehensions instead of `[...] * 5` so that every entry is its own dict:
# list multiplication repeats references to the *same* dict objects, so an
# in-place edit of one passage would silently show up in every copy.

# Five blank passages. ("passeges" is a typo; the name is kept so that any
# existing reference to it keeps working.)
passegesEmpty = [{'title': '', 'text': ''} for _ in range(5)]

# Five repetitions of the two passages that contain the expected answers.
passagesCorrect = [
    {'title': title, 'text': text}
    for _ in range(5)
    for title, text in (
        ('My favourite number', 'My favourite number is 3455'),
        ('The secret word', 'The secret word is FROG'),
    )
]

# Same titles, but the texts state values that differ from the targets below.
passagesIncorrect = [
    {'title': title, 'text': text}
    for _ in range(5)
    for title, text in (
        ('My favourite number', 'My favourite number is 290'),
        ('The secret word', 'The secret word is CAR'),
    )
]

# Questions together with their gold targets.
questions = [
    {'question': 'What is my favourite number?', 'target': '3455'},
    {'question': 'What is the secret word?', 'target': 'FROG'},
]


# The logger is handed to build_index below; torch's RNG is seeded from the parsed options.
logger = util.init_logger()
torch.manual_seed(opt.seed)

# Set the runtime fields by hand for a single-process run:
# rank 0, world size 1, not distributed, running as the main process.
# NOTE(review): presumably these are normally filled in by the distributed/slurm setup — confirm.
opt.device = 'cuda'
opt.global_rank = 0
opt.world_size = 1
opt.is_distributed = False
opt.is_main = True

# Passage set to index for this run; point `passages` at another fixture list to compare.
index = DistributedIndex()
passages = passagesCorrect

index.init_embeddings(passages)
# Of the seven returned values only the model, the returned options and the step are kept.
model, _, _, _, _, opt, step = load_or_initialize_atlas_model(opt, eval_only=True)

model.build_index(index, passages, opt.per_gpu_embedder_batch_size, logger)
# evaluate() prints the per-example records and returns the metrics dict.
metrics = evaluate(model, index, opt, questions, step)
# NOTE(review): this lookup assumes the task reports an "eval_loss" metric — confirm.
print(metrics['eval_loss'])




[11/30/2022 12:48:10] {model_io.py:130} INFO - Loading /atlas-data/models/atlas_nq/base
[11/30/2022 12:48:10] {model_io.py:131} INFO - loading checkpoint /atlas-data/models/atlas_nq/base/model.pth.tar


Some weights of the model checkpoint at facebook/contriever were not used when initializing Contriever: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[11/30/2022 12:48:16] {atlas.py:53} INFO - Atlas Init
[11/30/2022 12:48:16] {model_io.py:194} INFO - Model loaded from /atlas-data/models/atlas_nq/base/
[11/30/2022 12:48:16] {atlas.py:84} INFO - 10 passages encoded on process: 0
hey torch.Size([2, 640]) False
True torch.Size([10, 128, 768]) torch.Size([2, 640, 768])
hey torch.Size([2, 512]) True
decoder tensor([[    0, 32099,  6154,  ...,     0,     0,     0],
        [    0, 32099,     3,  ...,     0,     0,     0]], device='cuda:0')
hey torch.Size([2, 640]) False
True torch.Size([10, 128, 768]) torch.Size([2, 640, 768])
hey torch.Size([2, 1]) True
decoder tensor([[0],
        [0]], device='cuda:0')
hey torch.Size([2, 1]) True
decoder tensor([[32099],
        [32099]], device='cuda:0')
hey torch.Size([2, 1]) True
decoder tensor([[6154],
        [   3]], device='cuda:0')
hey torch.Size([2, 1]) True
decoder tensor([[3769],
        [7422]], device='cuda:0')
hey torch.Size([2, 1]) True
decoder tensor([[    1],
        [15927]], device='c

In [7]:
                # last_hidden_state = output.last_hidden_state
                # output.last_hidden_state = last_hidden_state.view(self.config.bsz, -1, last_hidden_state.size(-1))
                # print(return_dict, last_hidden_state.size(), output.last_hidden_state.size())
# Output of the debug print:
#     True torch.Size([10, 128, 768]) torch.Size([2, 640, 768])
#
# output.last_hidden_state starts as [10, 128, 768]:
#     (batch_size * num_passages, seq_len, hidden_size)
# It is then *unstacked* by the view and becomes [2, 640, 768]:
#     (batch_size, seq_len * num_passages, hidden_size)

# Throwaway random 3-D tensor for experimenting with reshapes.
a = torch.rand(3, 4, 5)
a

tensor([[[0.7026, 0.0875, 0.1418, 0.8254, 0.6902],
         [0.3329, 0.5169, 0.5192, 0.7989, 0.4210],
         [0.8796, 0.5450, 0.0854, 0.7404, 0.3570],
         [0.8371, 0.2315, 0.3826, 0.0085, 0.6328]],

        [[0.8045, 0.1771, 0.8830, 0.7865, 0.6945],
         [0.7434, 0.3773, 0.0839, 0.5978, 0.9541],
         [0.3975, 0.5133, 0.0864, 0.3684, 0.9206],
         [0.7616, 0.2692, 0.6412, 0.1603, 0.8845]],

        [[0.4705, 0.6577, 0.2136, 0.7203, 0.5471],
         [0.5106, 0.9789, 0.9993, 0.2279, 0.9662],
         [0.8593, 0.8149, 0.6101, 0.6662, 0.4376],
         [0.6804, 0.7272, 0.1327, 0.2583, 0.0055]]])

In [6]:










# Going the other way, from
#     (batch_size, seq_len * num_passages)
# to
#     (batch_size * num_passages, seq_len)
#
#   a: torch.Size([2, 640]) -> b: torch.Size([10, 128])
#   a: torch.Size([2, 15])  -> b: torch.Size([10, 3])   (toy sizes used below)

# Random stand-in for "a".
a = torch.rand(2, 15)
print(a.shape)
# General form: a.view(bs * num_passages, -1); num_passages is 5 in this toy case.
b = a.view(5 * a.shape[0], -1)
print(b.shape)
print(a, b, sep="\n")

torch.Size([2, 15])
torch.Size([10, 3])
tensor([[0.7006, 0.2428, 0.0318, 0.5096, 0.2905, 0.6130, 0.6844, 0.3915, 0.3460,
         0.1144, 0.5269, 0.6296, 0.4549, 0.4846, 0.6101],
        [0.0440, 0.7096, 0.7897, 0.5727, 0.0889, 0.1923, 0.4750, 0.7143, 0.3561,
         0.0122, 0.7269, 0.5587, 0.8398, 0.2020, 0.4995]])
tensor([[0.7006, 0.2428, 0.0318],
        [0.5096, 0.2905, 0.6130],
        [0.6844, 0.3915, 0.3460],
        [0.1144, 0.5269, 0.6296],
        [0.4549, 0.4846, 0.6101],
        [0.0440, 0.7096, 0.7897],
        [0.5727, 0.0889, 0.1923],
        [0.4750, 0.7143, 0.3561],
        [0.0122, 0.7269, 0.5587],
        [0.8398, 0.2020, 0.4995]])
