In [2]:
import pandas as pd
from transformers import DistilBertTokenizerFast
import webdataset as wds
import sys
from utils.utils import GEODataset, StreamTokenizedDataset
from sklearn.model_selection import train_test_split
import torch
MAX_SEQ_LENGTH = 200

In [22]:
df = pd.read_csv("data/wiki_exploded.gz").dropna()

In [23]:
texts = df["text"].values.tolist()
labels = df[["lat",  "lon"]].astype(float).values.tolist()

In [24]:
train_ratio = 0.78
test_ratio = 0.17
validation_ratio = 0.5

x_train, x_test, y_train, y_test = train_test_split(texts, labels, test_size=1 - train_ratio)

x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=test_ratio/(test_ratio + validation_ratio)) 


In [25]:
len(x_train)/len(texts), len(x_test)/len(texts), len(x_val)/len(texts)

(0.7799999384400924, 0.16417907640302723, 0.0558209851568803)

In [26]:
len(x_train)/100

91228.2

In [None]:
TOKEN_MODEL = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizerFast.from_pretrained(TOKEN_MODEL)
# tokenizer.model_max_length = 200

In [None]:
TEXTBATCHES = 10000
train_dataset = StreamTokenizedDataset(x_train, y_train, tokenizer, TEXTBATCHES, MAX_SEQ_LENGTH)
test_dataset = StreamTokenizedDataset(x_test, y_test, tokenizer, TEXTBATCHES, MAX_SEQ_LENGTH)
val_dataset = StreamTokenizedDataset(x_val, y_val, tokenizer, TEXTBATCHES, MAX_SEQ_LENGTH)

In [None]:
dataset_map = {'train' : train_dataset, 'test': test_dataset, 'val': val_dataset}

In [32]:
for key in ['train', 'test', 'val']:
    dataset = dataset_map[key]
    sink = wds.TarWriter(f"data/{key}_wiki_exploded.tar")
    for index, enc in enumerate(dataset):
        if index%10000==0:
            print(f"{index:6d}", end="\r", flush=True, file=sys.stderr)
        sink.write({
            "__key__": "sample%06d" % index,
            "data.pyd": enc,
        })
    sink.close()

5320000

KeyboardInterrupt: 

In [None]:
!ls -l data/geo_wds.tar
!tar tvf data/geo_wds.tar | head

In [None]:
webds = wds.WebDataset('data/train_geo_wds.tar').decode('torch')

In [None]:
import torch

In [None]:
dataloader = torch.utils.data.DataLoader(webds, num_workers=1, batch_size=16)

In [None]:
for index, batch in enumerate(dataloader):
    if index%10000==0:
        print(batch['enc_dict.pyd']['labels'])