In [1]:
import time
import socket
import shutil
import gc
from importlib import reload
from tqdm.auto import tqdm

from pathlib import Path

import torch
from torch.utils.data import DataLoader
from datasets import load_from_disk

import src.setup_embds_from_api
reload(src.setup_embds_from_api)
from src.setup_embds_from_api import path_to_galactica_folder, path_to_orig_data, path_to_raw_data, output_dir
from src.setup_embds_from_api import checkpoint_model
from src.setup_embds_from_api import device
from src.setup_embds_from_api import make_tensors, make_metadata

import openai

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
datasets = load_from_disk(str(path_to_raw_data))

In [3]:
#Define dirs
current_time = time.strftime("%b%d_%H-%M-%S")
host = socket.gethostname()
_output_dir = Path(path_to_galactica_folder,output_dir,current_time + '_' + host)
tensors_file = Path(_output_dir, 'tensors.tsv')
metadata_file = Path(_output_dir, 'metadata.tsv')

if not _output_dir.exists():
    _output_dir.mkdir(parents=True)

dataset = datasets['test'].select(range(2048))

# copy dataset
path_to_dataset = Path(_output_dir, 'dataset.json')
dataset.save_to_disk(str(path_to_dataset))

#copy metadata
if make_metadata:
    int2str = {v: k for k,v in (dataset.features['labels']._str2int).items()}
    metadata = dataset.map(lambda seq: {'id': str(seq['id']) }, batched=False)
    metadata = metadata.map(lambda seq: {'_labels': int2str[seq['labels']] }, batched=False)
    metadata = metadata.map(lambda seq: {'text': seq['text'].replace('\t', ' ') }, batched=False)
    metadata = metadata.map(lambda seq: {'text': seq['text'].replace('\n', ' ') }, batched=False)
    metadata.to_csv(str(metadata_file), sep='\t', index=False)

Creating CSV from Arrow format: 100%|██████████| 3/3 [00:00<00:00, 24.41ba/s]                  


In [4]:
# Tensors
if make_tensors:

    dataset = dataset.remove_columns(['id', 'title', 'labels'])

    batch_size = 32

    num_steps = len(dataset) // batch_size
    progress_bar = tqdm(range(num_steps))

    # Cleaning
    gc.collect()
    torch.cuda.empty_cache()

    tensors_f = open(tensors_file, 'w')

    for i in range(num_steps):

        print("Processing batch {}/{}".format(i, num_steps))

        batch = dataset[i*batch_size:(i+1)*batch_size]
        batch = [str(e) for e in batch['text']]
        embeddings = openai.Embedding.create(input=batch, model=checkpoint_model)['data']

        for i in range(len(embeddings)):
            embd = [str(e) for e in embeddings[i]['embedding']]
            embd = '\t'.join(embd)
            tensors_f.write(embd + '\n')

        progress_bar.update(1)

    tensors_f.close()

  0%|          | 0/64 [00:00<?, ?it/s]

Processing batch 0/64


  2%|▏         | 1/64 [00:01<01:54,  1.82s/it]

Processing batch 1/64


  3%|▎         | 2/64 [00:02<01:20,  1.30s/it]

Processing batch 2/64


  5%|▍         | 3/64 [00:04<01:18,  1.28s/it]

Processing batch 3/64


  6%|▋         | 4/64 [00:05<01:16,  1.28s/it]

Processing batch 4/64


  8%|▊         | 5/64 [00:06<01:14,  1.26s/it]

Processing batch 5/64


  9%|▉         | 6/64 [00:08<01:18,  1.36s/it]

Processing batch 6/64


 11%|█         | 7/64 [00:09<01:18,  1.38s/it]

Processing batch 7/64


 12%|█▎        | 8/64 [00:10<01:08,  1.23s/it]

Processing batch 8/64


 14%|█▍        | 9/64 [00:11<01:12,  1.32s/it]

Processing batch 9/64


 16%|█▌        | 10/64 [00:12<01:04,  1.19s/it]

Processing batch 10/64


 17%|█▋        | 11/64 [00:13<00:58,  1.11s/it]

Processing batch 11/64


 19%|█▉        | 12/64 [00:14<00:54,  1.05s/it]

Processing batch 12/64


 20%|██        | 13/64 [00:15<00:51,  1.02s/it]

Processing batch 13/64


 22%|██▏       | 14/64 [00:16<00:52,  1.05s/it]

Processing batch 14/64


 23%|██▎       | 15/64 [00:18<01:00,  1.24s/it]

Processing batch 15/64


 25%|██▌       | 16/64 [00:19<00:58,  1.22s/it]

Processing batch 16/64


 27%|██▋       | 17/64 [00:23<01:38,  2.09s/it]

Processing batch 17/64


 28%|██▊       | 18/64 [00:25<01:38,  2.13s/it]

Processing batch 18/64


 30%|██▉       | 19/64 [00:26<01:19,  1.77s/it]

Processing batch 19/64


 31%|███▏      | 20/64 [00:27<01:03,  1.43s/it]

Processing batch 20/64


 33%|███▎      | 21/64 [00:28<00:54,  1.27s/it]

Processing batch 21/64


 34%|███▍      | 22/64 [00:32<01:24,  2.00s/it]

Processing batch 22/64


 36%|███▌      | 23/64 [00:33<01:12,  1.76s/it]

Processing batch 23/64


 38%|███▊      | 24/64 [00:35<01:21,  2.03s/it]

Processing batch 24/64


 39%|███▉      | 25/64 [00:37<01:08,  1.76s/it]

Processing batch 25/64


 41%|████      | 26/64 [00:38<00:57,  1.52s/it]

Processing batch 26/64


 42%|████▏     | 27/64 [00:38<00:49,  1.33s/it]

Processing batch 27/64


 44%|████▍     | 28/64 [00:40<00:52,  1.45s/it]

Processing batch 28/64


 45%|████▌     | 29/64 [00:41<00:43,  1.25s/it]

Processing batch 29/64


 47%|████▋     | 30/64 [00:46<01:21,  2.38s/it]

Processing batch 30/64


 48%|████▊     | 31/64 [00:47<01:03,  1.93s/it]

Processing batch 31/64


 50%|█████     | 32/64 [00:49<00:59,  1.87s/it]

Processing batch 32/64


 52%|█████▏    | 33/64 [00:49<00:49,  1.58s/it]

Processing batch 33/64


 53%|█████▎    | 34/64 [00:50<00:42,  1.40s/it]

Processing batch 34/64


 55%|█████▍    | 35/64 [00:51<00:36,  1.26s/it]

Processing batch 35/64


 56%|█████▋    | 36/64 [00:53<00:35,  1.27s/it]

Processing batch 36/64


 58%|█████▊    | 37/64 [00:54<00:31,  1.16s/it]

Processing batch 37/64


 59%|█████▉    | 38/64 [00:55<00:28,  1.09s/it]

Processing batch 38/64


 61%|██████    | 39/64 [00:56<00:28,  1.16s/it]

Processing batch 39/64


 62%|██████▎   | 40/64 [00:57<00:27,  1.16s/it]

Processing batch 40/64


 64%|██████▍   | 41/64 [00:58<00:27,  1.18s/it]

Processing batch 41/64


 66%|██████▌   | 42/64 [01:00<00:27,  1.23s/it]

Processing batch 42/64


 67%|██████▋   | 43/64 [01:01<00:27,  1.33s/it]

Processing batch 43/64


 69%|██████▉   | 44/64 [01:02<00:25,  1.27s/it]

Processing batch 44/64


 70%|███████   | 45/64 [01:03<00:22,  1.17s/it]

Processing batch 45/64


 72%|███████▏  | 46/64 [01:04<00:19,  1.08s/it]

Processing batch 46/64


 73%|███████▎  | 47/64 [01:05<00:17,  1.02s/it]

Processing batch 47/64


 75%|███████▌  | 48/64 [01:06<00:15,  1.01it/s]

Processing batch 48/64


 77%|███████▋  | 49/64 [01:07<00:15,  1.04s/it]

Processing batch 49/64


 78%|███████▊  | 50/64 [01:08<00:13,  1.01it/s]

Processing batch 50/64


 80%|███████▉  | 51/64 [01:09<00:12,  1.05it/s]

Processing batch 51/64


 81%|████████▏ | 52/64 [01:10<00:11,  1.06it/s]

Processing batch 52/64


 83%|████████▎ | 53/64 [01:12<00:13,  1.27s/it]

Processing batch 53/64


 84%|████████▍ | 54/64 [01:13<00:11,  1.15s/it]

Processing batch 54/64


 86%|████████▌ | 55/64 [01:14<00:11,  1.25s/it]

Processing batch 55/64


 88%|████████▊ | 56/64 [01:16<00:11,  1.46s/it]

Processing batch 56/64


 89%|████████▉ | 57/64 [01:17<00:08,  1.28s/it]

Processing batch 57/64


 91%|█████████ | 58/64 [01:18<00:06,  1.12s/it]

Processing batch 58/64


 92%|█████████▏| 59/64 [01:18<00:04,  1.02it/s]

Processing batch 59/64


 94%|█████████▍| 60/64 [01:23<00:08,  2.11s/it]

Processing batch 60/64


 95%|█████████▌| 61/64 [01:24<00:05,  1.84s/it]

Processing batch 61/64


 97%|█████████▋| 62/64 [01:25<00:03,  1.56s/it]

Processing batch 62/64


 98%|█████████▊| 63/64 [01:26<00:01,  1.32s/it]

Processing batch 63/64


100%|██████████| 64/64 [01:27<00:00,  1.14s/it]