In [1]:
import argparse
import csv
import json
import math
import random
import re

import lexicon
import config
import utils
import torch

from collections import defaultdict
from semantic_memory import vsm, vsm_utils
from nltk.corpus import wordnet as wn
from minicons import scorer, cwe
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm



In [2]:
def lemma2concept(entry):
    """Build a ``lexicon.Concept`` from one row of the annotated lemma CSV.

    Args:
        entry: mapping (e.g. a ``csv.DictReader`` row) that must contain the keys
            lemma, singular, plural, article, generic and taxonomic_phrase.

    Returns:
        The ``lexicon.Concept`` constructed from those six fields, passed by keyword.

    Raises:
        KeyError: if any of the required keys is missing from ``entry``.
    """
    field_names = (
        "lemma",
        "singular",
        "plural",
        "article",
        "generic",
        "taxonomic_phrase",
    )
    return lexicon.Concept(**{name: entry[name] for name in field_names})

lemma_path = "../data/things/things-lemmas-annotated.csv"

# Read the annotated THINGS lemmas, keyed by lemma, skipping rows whose
# "remove" column is "1". The csv module expects files opened with newline="".
concepts = defaultdict(lexicon.Concept)
with open(lemma_path, "r", newline="") as f:
    for row in csv.DictReader(f):
        if row["remove"] != "1":
            concepts[row["lemma"]] = lemma2concept(row)

# Collect every annotated sense. A row may list several senses joined by "&";
# a sense of "-" means the row has none.
things_senses = set()
with open("../data/things/things-senses-annotated.csv", "r", newline="") as f:
    for row in csv.DictReader(f):
        if row["sense"] != "-":
            things_senses.update(row["sense"].split("&"))

triple_path = "../data/things/things-triples.csv"
triples = utils.read_csv_dict(triple_path)

# anchors:            every distinct "anchor" value in the triples
# hyponyms:           every distinct "hyponym" value in the triples
# anchor_children[a]: the hyponyms that appear in a triple together with anchor a
anchors = set()
hyponyms = set()
anchor_children = defaultdict(set)
concept_universe = set()  # NOTE(review): never populated by the visible cells

for triple in triples:
    anchor = triple["anchor"]
    hyponym = triple["hyponym"]

    anchors.add(anchor)
    hyponyms.add(hyponym)
    anchor_children[anchor].add(hyponym)

In [3]:
# NOTE(review): `prop` is not used by the visible cells; kept in case other code refers to it.
prop = lexicon.Property("daxable", "is daxable", "are daxable")

# Strips a leading indefinite article and the whitespace after it.
# "an" must come before "a": regex alternation is tried left-to-right, so
# r'^(a|an)' would turn "an aardvark" into "n aardvark". Requiring whitespace
# also leaves article-less forms such as "apple" untouched.
# Assumes `article` holds the singular with its article, e.g. "an aardvark" — TODO confirm.
_ARTICLE_RE = re.compile(r"^(?:an|a)\s+")

# Surface form used to query the LM for each concept: the bare singular (article
# stripped) when the concept's generic form is "s", otherwise the plural.
concept_space = {
    c: _ARTICLE_RE.sub("", concept.article) if concept.generic == "s" else concept.plural
    for c, concept in concepts.items()
}

# One (context, target word) pair per concept, in concept_space order; the
# context is just the word itself.
query_pairs = [(v, v) for v in concept_space.values()]

In [4]:
# Contextual word-embedding extractor (minicons CWE) over Mistral-7B-Instruct-v0.2,
# placed on device "cuda:0".
#
# A disabled alternative used the static input-embedding matrix instead
# (lm.model.model.embed_tokens.weight with vocab from lm.tokenizer), cached via
# torch.save / torch.load at ../data/embeddings/Mistral-7B-Instruct-v0.2.pt and
# wrapped in vsm.VectorSpaceModel("Mistral-Instruct-v2").
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
DEVICE = "cuda:0"

lm = cwe.CWE(MODEL_NAME, DEVICE)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Run every (context, word) pair through the LM in batches of 16 and collect the
# target word's representation from every layer (layer="all").
# Result: layerwise[layer_idx] is one stacked tensor with a row per concept, in
# query_pairs order.
per_layer_rows = defaultdict(list)

for start in tqdm(range(0, len(query_pairs), 16)):
    batch_pairs = query_pairs[start:start + 16]

    layer_outputs = lm.extract_representation(batch_pairs, layer = "all")
    for layer_idx, layer_emb in enumerate(layer_outputs):
        per_layer_rows[layer_idx].extend(layer_emb)

layerwise = {layer_idx: torch.stack(rows) for layer_idx, rows in per_layer_rows.items()}

100%|██████████| 84/84 [00:14<00:00,  5.90it/s]


In [6]:
# One VectorSpaceModel per layer, with rows labelled by concept lemma.
# layerwise[k] rows were produced in query_pairs order, which was built from
# concept_space.values(), so concept_space.keys() supplies matching labels.
layerwise_vsms = {
    layer: vsm.VectorSpaceModel(f"Mistral-Instruct-v2-layer-{layer}")
    for layer in layerwise
}

concept_labels = list(concept_space.keys())
# The loop variable must not be called `vsm`: that would rebind the imported
# semantic_memory.vsm module and make this cell fail when re-run.
for layer, layer_vsm in layerwise_vsms.items():
    layer_vsm.load_vectors_from_tensor(layerwise[layer], concept_labels)

In [7]:
# Negative sampling at layer 32. For every anchor, rank the hyponyms that are
# NOT its own children (restricted to known concepts) by similarity to the
# anchor. Take ceil(n_children / 2) of the most similar and the same number of
# the least similar, then shuffle each anchor's list with a fixed seed.
#
# Anchors and candidates are iterated in sorted order: set iteration order of
# strings changes between interpreter sessions (hash randomisation). Without
# sorting, the seeded shuffle below would hand different permutations to each
# anchor on every run.
anchor_neighbors = defaultdict(list)
for anchor in sorted(anchors):
    space = sorted(c for c in hyponyms - anchor_children[anchor] if c in concepts)
    half = math.ceil(len(anchor_children[anchor]) / 2)

    neighbors = layerwise_vsms[32].neighbor(anchor, k=len(space), space=space, names_only=True, ignore_first=False)
    # assumes neighbors[0] lists candidates from most to least similar — TODO confirm
    ranked = neighbors[0]
    anchor_neighbors[anchor].extend(ranked[:half])
    # Farthest items, taken only from those not already chosen as nearest, so
    # the two halves never overlap when there are fewer than 2 * half candidates.
    anchor_neighbors[anchor].extend(ranked[half:][::-1][:half])

anchor_neighbors = dict(anchor_neighbors)
random.seed(42)
anchor_neighbors = {k: random.sample(v, len(v)) for k, v in anchor_neighbors.items()}

In [8]:
# Flatten {anchor: [negative, ...]} into (anchor, negative) rows, preserving both
# the anchor order and each anchor's (shuffled) negative order.
negative_sample_triples = [
    (anchor, negative)
    for anchor, negatives in anchor_neighbors.items()
    for negative in negatives
]

In [9]:
# Persist the (anchor, negative hyponym) pairs. The csv module requires
# newline="" on the file so it controls line endings itself (otherwise Windows
# gets a blank line after every row).
with open("../data/things/things-mistralai_Mistral-7B-Instruct-v0.2_layer32_ns-triples.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["anchor", "hyponym"])
    writer.writerows(negative_sample_triples)

In [10]:
# For every anchor, record its layer-32 similarity to every hyponym that is a
# known concept, then flatten to (anchor, concept, similarity) rows.
# The candidate list does not depend on the anchor, so it is built once. It and
# the anchors are sorted so the row order is identical on every run (set
# iteration order of strings is not stable across interpreter sessions).
candidate_concepts = sorted(c for c in hyponyms if c in concepts)

anchor_concept_sims = defaultdict(list)
for anchor in sorted(anchors):
    neighbor_sims = layerwise_vsms[32].neighbor(anchor, k=len(candidate_concepts), space=candidate_concepts, ignore_first=False)
    # presumably neighbor_sims[0] is a list of (concept, similarity) pairs — TODO confirm
    for concept, sim in neighbor_sims[0]:
        anchor_concept_sims[(anchor, concept)].append(sim)

# Each (anchor, concept) key should be seen once; max() collapses any repeats.
anchor_concept_sims_csv = [
    (anchor, concept, max(sims))
    for (anchor, concept), sims in anchor_concept_sims.items()
]

In [11]:
# Number of (anchor, concept, similarity) rows about to be written.
n_similarity_rows = len(anchor_concept_sims_csv)
n_similarity_rows

56672

In [12]:
# Persist every (anchor, concept, similarity) row. The csv module requires
# newline="" on the file so it controls line endings itself (otherwise Windows
# gets a blank line after every row).
with open("../data/things/things-mistralai_Mistral-7B-Instruct-v0.2_layer32_anchor_sims.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["anchor", "concept", "similarity"])
    writer.writerows(anchor_concept_sims_csv)