In [None]:
import time
import socket
import shutil
import gc
from importlib import reload
from tqdm.auto import tqdm

from pathlib import Path

import torch
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import OPTForCausalLM, AutoTokenizer, DataCollatorWithPadding

import src.setup_embds
reload(src.setup_embds)
from src.setup_embds import path_to_galactica_folder, path_to_orig_data, path_to_raw_data, path_to_tokenized_data, output_dir
from src.setup_embds import ModelClass, checkpoint
from src.setup_embds import device
from src.setup_embds import make_tensors, make_metadata

In [None]:
# Load the tokenized splits plus the tokenizer and model from the configured checkpoint.
datasets = load_from_disk(str(path_to_tokenized_data))

tokenizer = AutoTokenizer.from_pretrained(str(checkpoint))

# `.base_model` presumably keeps only the transformer backbone (no task head) — its
# `last_hidden_state` is what the embedding cell consumes.
# NOTE(review): `device_map="auto"` lets accelerate place the weights, and the chained
# `.to(device)` then assumes everything fits on `device` — confirm for multi-GPU setups.
model = ModelClass.from_pretrained(str(checkpoint), device_map="auto").base_model.to(device)

In [None]:
# Define dirs: every run writes into its own "<timestamp>_<hostname>" folder.
current_time = time.strftime("%b%d_%H-%M-%S")
host = socket.gethostname()
_output_dir = Path(path_to_galactica_folder) / output_dir / f"{current_time}_{host}"

if not _output_dir.exists():
    _output_dir.mkdir(parents=True)

# Output TSVs: embeddings (one vector per row) and their per-row metadata.
tensors_file = _output_dir / 'tensors.tsv'
metadata_file = _output_dir / 'metadata.tsv'

# Snapshot the checkpoint next to the outputs so the run folder is self-contained:
# copy it verbatim when `checkpoint` is a local path, otherwise re-save the loaded objects.
checkpoint_dir = _output_dir / 'checkpoint'
if checkpoint_dir.exists():
    shutil.rmtree(str(checkpoint_dir))

if not Path(checkpoint).exists():
    # NOTE(review): `model` is the `.base_model` backbone, so this saves the backbone's
    # weights/config rather than the full ModelClass — confirm that is intended.
    model.save_pretrained(str(checkpoint_dir))
    tokenizer.save_pretrained(str(checkpoint_dir))
else:
    shutil.copytree(str(checkpoint), str(checkpoint_dir))

# Use the `test` split.
# (For a quick smoke run, subset it, e.g. `datasets['test'].select(range(64))`.)
dataset = datasets['test']

# Keep a copy of the exact split used for this run.
# NOTE: `save_to_disk` writes an Arrow dataset *directory*, so despite the
# `.json` suffix, `dataset.json` is a folder, not a JSON file.
path_to_dataset = _output_dir / 'dataset.json'
dataset.save_to_disk(str(path_to_dataset))

# Copy metadata: one TSV row per example (id as string, decoded label, single-line text).
# Only written when the `make_metadata` flag imported from src.setup_embds is set,
# mirroring how `make_tensors` gates the embedding cell.
if make_metadata:
    # Public ClassLabel API instead of the private `_str2int` mapping.
    int2str = dict(enumerate(dataset.features['labels'].names))

    # Tabs and line breaks would split a value across TSV columns/rows.
    _tsv_breaking = str.maketrans({'\t': ' ', '\r': ' ', '\n': ' '})

    def _tsv_safe(value):
        """Replace tabs/CR/LF in a string with spaces; return non-strings unchanged."""
        return value.translate(_tsv_breaking) if isinstance(value, str) else value

    metadata = dataset.remove_columns(['input_ids', 'attention_mask', 'token_type_ids'])
    _has_title = 'title' in metadata.column_names

    def _to_metadata_row(seq):
        """Build the metadata fields for one example (updates id/text/title, adds `_labels`)."""
        row = {
            'id': str(seq['id']),
            '_labels': int2str[seq['labels']],
            'text': _tsv_safe(seq['text']),
        }
        if _has_title:
            # Titles are free text too and go into the TSV unchanged otherwise.
            row['title'] = _tsv_safe(seq['title'])
        return row

    # A single pass instead of four: each `map` call is a full pass over the dataset.
    metadata = metadata.map(_to_metadata_row, batched=False)
    metadata.to_csv(str(metadata_file), sep='\t', index=False)

In [None]:
# Tensors: one tab-separated embedding per example, in dataset order (shuffle=False),
# so row i of tensors.tsv lines up with row i of metadata.tsv.
if make_tensors:

    # Work on a stripped copy so `dataset` keeps its text columns and this cell can be re-run.
    embedding_inputs = dataset.remove_columns(['id', 'title', 'text', 'token_type_ids', 'labels'])
    embedding_inputs.set_format("torch")

    batch_size = 1
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    dataloader = DataLoader(embedding_inputs, shuffle=False, batch_size=batch_size, collate_fn=data_collator)

    # len(dataloader) also counts a final partial batch, unlike len(dataset) // batch_size.
    num_steps = len(dataloader)
    progress_bar = tqdm(range(num_steps))

    # Cleaning
    gc.collect()
    torch.cuda.empty_cache()

    # Inference only: no_grad avoids building an autograd graph for every forward pass,
    # and `with` guarantees the file is closed even if a batch fails.
    with open(tensors_file, 'w', encoding='utf-8') as tensors_f, torch.no_grad():
        for batch in dataloader:
            _batch = {k: v.to(device) for k, v in batch.items()}
            hidden = model(**_batch)['last_hidden_state']

            # Embed each sequence by the hidden state of its last non-padding token.
            # For unpadded inputs this equals hidden[:, -1, :]; using the attention mask
            # keeps it correct for padded batches (left or right padding).
            mask = _batch.get('attention_mask')
            if mask is None:
                embeddings = hidden[:, -1, :]
            else:
                positions = torch.arange(mask.shape[1], device=mask.device)
                # Largest position whose mask is 1 (positions are distinct, so the max is unique).
                last_idx = (mask * positions).argmax(dim=1).to(hidden.device)
                rows = torch.arange(hidden.shape[0], device=hidden.device)
                embeddings = hidden[rows, last_idx]

            for embd in embeddings.tolist():
                tensors_f.write('\t'.join(map(str, embd)) + '\n')

            # Advance once per processed batch.
            progress_bar.update(1)

    progress_bar.close()