In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
import sys
sys.path.append("../")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

import src.dataset_manager as dataset_manager

2024-10-27 19:56:42 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [56]:
dataset_manager.DatasetManager.list_dataset_groups()

['relations', 'sst2', 'geometry_of_truth']

In [57]:
dataset_manager.DatasetManager.list_datasets_by_group()

{'geometry_of_truth': ['sp_en_trans',
  'neg_sp_en_trans',
  'cities',
  'neg_cities',
  'smaller_than',
  'larger_than',
  'common_claim_true_false',
  'companies_true_false',
  'counterfact_true_false'],
 'sst2': ['sst2'],
 'relations': ['commonsense/word_sentiment',
  'commonsense/fruit_outside_color',
  'commonsense/task_done_by_person',
  'commonsense/work_location',
  'commonsense/task_done_by_tool',
  'commonsense/substance_phase',
  'commonsense/object_superclass',
  'commonsense/fruit_inside_color',
  'factual/pokemon_evolutions',
  'factual/country_capital_city',
  'factual/person_plays_pro_sport',
  'factual/star_constellation',
  'factual/country_language',
  'factual/presidents_birth_year',
  'factual/landmark_on_continent',
  'factual/country_largest_city',
  'factual/company_hq',
  'factual/food_from_country',
  'factual/landmark_in_country',
  'factual/company_ceo',
  'factual/superhero_archnemesis',
  'factual/city_in_country',
  'factual/person_band_lead_singer',
  'f

In [58]:
custom_dataset = dataset_manager.DatasetManager.from_dataset_group(
    "geometry_of_truth",
    # "sst2",
    batch_size=4
)

In [59]:
len(custom_dataset)

11319

In [60]:
batch = next(iter(custom_dataset))
batch

[RawExample(feature='Eighty-eight is smaller than eighty-three.', label='0'),
 RawExample(feature='The headquarters of Bloc populaire is in Vancouver.', label='0'),
 RawExample(feature='The city of Mbuji-Mayi is in the Democratic Republic of the Congo.', label='1'),
 RawExample(feature='Craig Erickson plays as quarterback.', label='1')]

In [61]:
train, test = custom_dataset.split([0.7, 0.3])

In [62]:
print(len(train), len(test))

7923 3396


In [63]:
custom_dataset.examples[:10]

[RawExample(feature='Eighty-eight is smaller than eighty-three.', label='0'),
 RawExample(feature='The headquarters of Bloc populaire is in Vancouver.', label='0'),
 RawExample(feature='The city of Mbuji-Mayi is in the Democratic Republic of the Congo.', label='1'),
 RawExample(feature='Craig Erickson plays as quarterback.', label='1'),
 RawExample(feature='Lamplugh Island belongs to the continent of Antarctica.', label='1'),
 RawExample(feature='100 Questions was originally aired on CBS.', label='0'),
 RawExample(feature='Cats do not only like litter boxes, but also like getting in cardboard boxes or utility tubs.', label='1'),
 RawExample(feature='Joel Palmer lost their life at Dayton.', label='1'),
 RawExample(feature='In Haryana, they understand Hindi.', label='1'),
 RawExample(feature='The original language of Die Tageszeitung is German.', label='1')]

In [68]:
from src.dataset_manager import DatasetLoader, RawExample, RelationDatasetLoader
from src.utils.env_utils import DEFAULT_DATA_DIR
import os, json
from pathlib import Path
import random
    
loader = RelationDatasetLoader(group="relations", name="factual/country_capital_city")
loader.load()


2024-10-27 20:25:36 src.dataset_manager INFO     Loaded 48 examples from factual/country_capital_city.


[RawExample(feature='The capital of United States is Washington D.C..', label='1'),
 RawExample(feature='The capital of United States is Canberra.', label='0'),
 RawExample(feature='The capital city of Canada is Ottawa.', label='1'),
 RawExample(feature='The capital city of Canada is Riyadh.', label='0'),
 RawExample(feature='The capital of Mexico is Mexico City.', label='1'),
 RawExample(feature='The capital of Mexico is Tokyo.', label='0'),
 RawExample(feature='The capital city of Brazil is Bras\\u00edlia.', label='1'),
 RawExample(feature='The capital city of Brazil is Lima.', label='0'),
 RawExample(feature='The capital of Argentina is Buenos Aires.', label='1'),
 RawExample(feature='The capital of Argentina is Islamabad.', label='0'),
 RawExample(feature='The capital of Chile is Santiago.', label='1'),
 RawExample(feature='The capital of Chile is Canberra.', label='0'),
 RawExample(feature='The capital city of Peru is Lima.', label='1'),
 RawExample(feature='The capital city of Pe

In [69]:
from src.dataset_manager import GeometryOfTruthDatasetLoader

cities = GeometryOfTruthDatasetLoader(group="GoT", name="cities").load()

In [70]:
cities

[RawExample(feature='The city of Krasnodar is in Russia.', label='1'),
 RawExample(feature='The city of Krasnodar is in South Africa.', label='0'),
 RawExample(feature='The city of Lodz is in Poland.', label='1'),
 RawExample(feature='The city of Lodz is in the Dominican Republic.', label='0'),
 RawExample(feature='The city of Maracay is in Venezuela.', label='1'),
 RawExample(feature='The city of Maracay is in China.', label='0'),
 RawExample(feature='The city of Baku is in Azerbaijan.', label='1'),
 RawExample(feature='The city of Baku is in Ukraine.', label='0'),
 RawExample(feature='The city of Baoji is in China.', label='1'),
 RawExample(feature='The city of Baoji is in Guatemala.', label='0'),
 RawExample(feature='The city of Addis Ababa is in Ethiopia.', label='1'),
 RawExample(feature='The city of Addis Ababa is in Indonesia.', label='0'),
 RawExample(feature="The city of Abidjan is in Côte d'Ivoire.", label='1'),
 RawExample(feature='The city of Abidjan is in China.', label='0

In [74]:
relation_dataset = dataset_manager.DatasetManager.from_dataset_group(
    "relations",
    batch_size=32
)

2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 120 examples from commonsense/word_sentiment.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 60 examples from commonsense/fruit_outside_color.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 64 examples from commonsense/task_done_by_person.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 76 examples from commonsense/work_location.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 104 examples from commonsense/task_done_by_tool.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 100 examples from commonsense/substance_phase.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 152 examples from commonsense/object_superclass.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 72 examples from commonsense/fruit_inside_color.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 88 examples from factual/pokemon_evolutions.
2024-10-27 20:27:14 src.dataset_manager INFO     Loaded 48 examples

In [75]:
len(relation_dataset)

680