In [2]:
import pandas as pd
from transformers import DistilBertTokenizerFast
import webdataset as wds
import sys
from utils.utils import GEODataset
from sklearn.model_selection import train_test_split

In [3]:
df = pd.read_parquet("data/geo_data.parquet")

In [4]:
texts = df["text"].values.tolist()
labels = df[["lat",  "lon"]].astype(float).values.tolist()

In [5]:
train_ratio = 0.78
test_ratio = 0.17
validation_ratio = 0.5

# train is now 75% of the entire data set
# the _junk suffix means that we drop that variable completely
x_train, x_test, y_train, y_test = train_test_split(texts, labels, test_size=1 - train_ratio)

# test is now 10% of the initial data set
# validation is now 15% of the initial data set
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=test_ratio/(test_ratio + validation_ratio)) 


In [6]:
len(x_train)/len(texts), len(x_test)/len(texts), len(x_val)/len(texts)

(0.78, 0.16417903737304806, 0.05582096262695195)

In [7]:
len(x_train)/30

4857645

In [12]:
TOKEN_MODEL = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizerFast.from_pretrained(TOKEN_MODEL)
tokenizer.model_max_length = 300

In [14]:
tok_train = tokenizer(x_train, truncation=True, padding=True)

In [15]:
tok_test = tokenizer(x_test, truncation=True, padding=True)

In [16]:
tok_val = tokenizer(x_val, truncation=True, padding=True)

In [17]:
labels_train = y_train
labels_test = y_test
labels_val = y_val

In [18]:
train_dataset = GEODataset(tok_train, labels_train)
test_dataset = GEODataset(tok_test, labels_test)
val_dataset = GEODataset(tok_val, labels_val)

In [19]:
dataset_map = {'train' : train_dataset, 'test': test_dataset, 'val': val_dataset}

In [20]:
for key in dataset_map.keys():
    dataset = dataset_map[key]
    sink = wds.TarWriter(f"data/{key}_geo_wds_wiki_names.tar")
    for index, enc in enumerate(dataset):
        if index%10000==0:
            print(f"{index:6d}", end="\r", flush=True, file=sys.stderr)
        sink.write({
            "__key__": "sample%06d" % index,
            "enc_dict.pyd": enc,
        })
    sink.close()

3400000

In [None]:
!ls -l data/geo_wds.tar
!tar tvf data/geo_wds.tar | head

In [None]:
webds = wds.WebDataset('data/train_geo_wds.tar').decode('torch')

In [None]:
import torch

In [None]:
dataloader = torch.utils.data.DataLoader(webds, num_workers=1, batch_size=16)

In [None]:
for index, batch in enumerate(dataloader):
    if index%10000==0:
        print(batch['enc_dict.pyd']['labels'])