In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os, time, json
import pandas as pd
from openai import OpenAI
from tqdm.auto import tqdm
import spacy

import sys
sys.path.append("../../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

# Configure the root logger: DEBUG verbosity, project-wide message/date formats,
# and stdout as the destination stream.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

# Logger used by the rest of this notebook.
logger = logging.getLogger(__name__)

import torch
import transformers

# Record library versions and GPU availability for reproducibility.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
if torch.cuda.is_available():
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
else:
    # torch.cuda.get_device_name() raises when no CUDA device is usable, so it
    # is only queried when CUDA is available; otherwise the cell would crash
    # on CPU-only machines.
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}")
logger.info(f"{transformers.__version__=}")

2024-07-26 15:38:01 __main__ INFO     torch.__version__='2.3.1', torch.version.cuda='12.1'
2024-07-26 15:38:01 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-07-26 15:38:01 __main__ INFO     transformers.__version__='4.42.4'


In [45]:
from dataclasses_json import DataClassJsonMixin
from dataclasses import dataclass, field, fields
from typing import Optional
import random

@dataclass(frozen=False)
class BridgeSample(DataClassJsonMixin):
    """One bridge example: a bridge string together with a pair of entities.

    Attributes:
        bridge: The bridge linking the two entities.
        entity_pair: The two entities; must contain exactly two items.
        description: Optional sentence describing the link. When ``None``,
            ``__str__`` falls back to a generic sentence.
    """

    bridge: str
    entity_pair: list[str]
    description: Optional[str] = None

    def __post_init__(self):
        """Validate that exactly two entities were supplied.

        Raises:
            AssertionError: If ``entity_pair`` does not have length 2.
        """
        # The message must reference `entity_pair` (the actual field name);
        # a non-existent attribute here would turn a failed check into an
        # AttributeError and hide the real problem.
        assert len(self.entity_pair) == 2, (
            f"entity_pair must have length 2, got {len(self.entity_pair)} - {self.entity_pair}"
        )

    def __str__(self):
        """Return ``description`` if set, otherwise a generic sentence."""
        if self.description is not None:
            return self.description
        return (
            f"{self.bridge} is a common link between "
            f"{self.entity_pair[0]} and {self.entity_pair[1]}."
        )

    
@dataclass(frozen=False)
class BridgeRelation(DataClassJsonMixin):
    """A named bridge relation with an answer template and its examples.

    Attributes:
        name: Name of the relation (used in the initialization log message).
        answer_template: Template containing the ``<bridge>``, ``<entity1>``
            and ``<entity2>`` placeholders.
        swappable: Boolean flag stored on the relation; not used by the code
            in this class (presumably whether the two entities may be
            exchanged -- TODO confirm).
        examples: The samples belonging to this relation.
    """

    name: str
    answer_template: str
    swappable: bool
    examples: list[BridgeSample] = field(default_factory=list)

    def __post_init__(self):
        """Check the template placeholders and fill in every example's description.

        Raises:
            AssertionError: If any of the three placeholders is missing from
                ``answer_template``.
        """
        for placeholder in ("<bridge>", "<entity1>", "<entity2>"):
            assert placeholder in self.answer_template

        # Each example's description is (over)written with the filled template.
        for sample in self.examples:
            filled = self.answer_template.replace("<bridge>", sample.bridge)
            filled = filled.replace("<entity1>", sample.entity_pair[0])
            filled = filled.replace("<entity2>", sample.entity_pair[1])
            sample.description = filled

        logger.info(f"initialized bridge relation {self.name} with {len(self.examples)} examples")

    def __len__(self):
        """Return the number of examples in this relation."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example(s) at ``idx`` from the underlying list."""
        return self.examples[idx]

class BridgeDataset(DataClassJsonMixin):
    """A shuffled, flat collection of the examples of several bridge relations.

    NOTE(review): this class is not decorated with ``@dataclass`` and defines
    its own ``__init__``; the ``DataClassJsonMixin`` (de)serialization helpers
    presumably do not work on it -- confirm before relying on them.

    Attributes:
        relations: The relations the dataset was built from.
        examples: Every example from every relation, in shuffled order.
    """

    relations: list[BridgeRelation]
    examples: list[BridgeSample]

    def __init__(self, relations: list[BridgeRelation], seed: Optional[int] = None):
        """Collect the examples of all relations and shuffle them.

        Args:
            relations: Relations whose examples make up the dataset.
            seed: Optional seed for a reproducible shuffle. When ``None``
                (the default) the module-level ``random`` state is used,
                exactly as before this parameter existed.
        """
        self.relations = relations
        # Build a fresh list so shuffling never reorders a relation's own list.
        self.examples = [
            example for relation in relations for example in relation.examples
        ]

        if seed is None:
            random.shuffle(self.examples)
        else:
            # A private generator keeps the shuffle reproducible without
            # touching the global random state.
            random.Random(seed).shuffle(self.examples)

        logger.info(f"initialized bridge dataset with {len(self.relations)} relations and {len(self)} examples")

    def __len__(self):
        """Return the total number of examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example(s) at ``idx`` in the shuffled list."""
        return self.examples[idx]


def load_bridge_relation(file_name: str) -> BridgeRelation:
    """Load a single bridge relation from a JSON file.

    Args:
        file_name: Path to the JSON file describing the relation.

    Returns:
        The ``BridgeRelation`` built from the file's contents via ``from_dict``.
    """
    # Pin the encoding: without it the locale default is used, which can
    # mis-decode non-ASCII text on platforms whose default is not UTF-8.
    with open(file_name, "r", encoding="utf-8") as f:
        data = json.load(f)
    return BridgeRelation.from_dict(data)

def load_bridge_relations() -> list[BridgeRelation]:
    """Load every ``.json`` relation file from the cleaned bridge-dataset directory.

    The directory is ``<env_utils.DEFAULT_DATA_DIR>/bridge_dataset/cleaned``.

    Returns:
        One ``BridgeRelation`` per ``.json`` file, ordered by file name.
    """
    bridge_data_dir = os.path.join(env_utils.DEFAULT_DATA_DIR, "bridge_dataset", "cleaned")
    # os.listdir returns entries in arbitrary order; sorting keeps the order of
    # relations stable across runs and machines.
    return [
        load_bridge_relation(os.path.join(bridge_data_dir, file_name))
        for file_name in sorted(os.listdir(bridge_data_dir))
        if file_name.endswith(".json")
    ]
        

In [46]:
# Load all relation JSON files from the cleaned bridge-dataset directory.
relations = load_bridge_relations()

2024-07-26 16:55:39 __main__ INFO     initialized bridge relation superpower_characters with 27 examples
2024-07-26 16:55:39 __main__ INFO     initialized bridge relation sport_players with 23 examples
2024-07-26 16:55:39 __main__ INFO     initialized bridge relation movie_actor with 52 examples
2024-07-26 16:55:39 __main__ INFO     initialized bridge relation architect_building with 21 examples
