In [1]:
from tqdm.notebook import tqdm
import pandas as pd

In [2]:
DEV = False
# model_name = 'dunzhang/stella_en_400M_v5'
model_name = "avsolatorio/GIST-Embedding-v0" # train when I've got a spare two hours

In [3]:
def import_labelled_data(path="data/labelled/data.json", group_relevant=True):
    data = pd.read_json(path, encoding="latin-1")
    data["relevance"] = data["class"].apply(
        lambda x: "relevant" if x != "irrelevant" else x
    )
    return data

data = import_labelled_data(path='../../data/labelled/data.json',group_relevant=False)

# drop null classes
data = data.dropna(subset=["class"])

if DEV:
    data = data.sample(5000)

# train test split
from sklearn.model_selection import train_test_split
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

data.head()

Unnamed: 0,url,text,class,relevance
0,https://www.wetlands.org/wp-content/uploads/20...,\n \n \nFlamingo\nFlamingo\nFlamingo\nFlaming...,Birds,relevant
1,https://www.wetlands.org/publications/flamingo...,\n\n \n \n \n \n \n \nABOUT THE GROUP \n \nThe...,Birds,relevant
2,https://www.wetlands.org/publications/the-stat...,\n\n\n\n\n\n(FIRST PAGE) \n \n \n \n \nTHE STA...,Birds,relevant
3,https://www.sciencedirect.com/science/article/...,\nPlease contact us via our\nsupport center fo...,Mammals,relevant
4,https://www.wetlands.org/publications/strategi...,Strategies for wise use of Wetlands:\nBest Pra...,Wetlands,relevant


In [4]:
from chunking import chunk_dataset_and_explode

# roughly 4 characters per token
max_len = 2048

train_data = chunk_dataset_and_explode(train_data, max_len=max_len, overlap=int(max_len * 0.2))
test_data = chunk_dataset_and_explode(test_data, max_len=max_len, overlap=int(max_len * 0.2))

In [5]:
from datasets import Dataset

train_dataset = Dataset.from_pandas(train_data, split="train")
test_dataset = Dataset.from_pandas(test_data, split="test")

train_dataset

Dataset({
    features: ['chunk_id', 'url', 'text', 'class', 'relevance'],
    num_rows: 122195
})

In [6]:
# embeddings
from sentence_transformers import SentenceTransformer


model = SentenceTransformer(model_name,trust_remote_code=True,similarity_fn_name='dot')

model.cuda()
# model.cpu()

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

In [7]:
model.device


device(type='cuda', index=0)

In [8]:
# clear cuda cache
import torch
torch.cuda.empty_cache()

In [9]:
train_embeddings = model.encode(
    train_dataset["text"],
    show_progress_bar=True,
    normalize_embeddings=True,
)

Batches:   0%|          | 0/3819 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [10]:
test_embeddings = model.encode(
    test_dataset["text"],
    show_progress_bar=True,
    normalize_embeddings=True,
)

Batches:   0%|          | 0/842 [00:00<?, ?it/s]

In [11]:
train_embeddings.shape

(122195, 768)

In [12]:
train_embeddings[0]

array([-3.09738684e-02, -5.73649332e-02,  1.82388201e-02, -3.32309529e-02,
        4.77531515e-02, -2.00871639e-02,  2.39493307e-02,  2.52473261e-02,
       -7.57582858e-02, -6.51207045e-02,  2.09446419e-02, -1.14720818e-02,
       -4.41362150e-02,  2.30766200e-02,  9.02883895e-03,  5.82449995e-02,
        3.54298092e-02,  1.09698288e-02, -1.88599620e-02,  1.50688067e-02,
        1.46394931e-02, -6.05719164e-03, -6.56022877e-03,  2.88828518e-02,
        4.75077294e-02, -2.93905623e-02,  2.93811113e-02,  2.09692419e-02,
       -5.65504804e-02,  2.66036410e-02,  2.96469592e-02, -2.05935221e-02,
       -5.85723855e-02,  4.49135900e-02,  3.42251100e-02, -2.81485058e-02,
       -2.74868850e-02, -3.43200676e-02, -2.10785698e-02,  1.02992086e-02,
       -2.56073344e-02,  2.80118058e-03, -3.74934673e-02, -1.70398075e-02,
       -4.25287001e-02,  4.05347385e-02, -1.88313238e-02,  3.40081044e-02,
       -1.64272671e-03, -2.94003934e-02, -9.01349708e-02, -3.67292203e-03,
        1.39281042e-02,  

In [13]:
# add to pandas dataframe
train_data["embeddings"] = train_embeddings.tolist()
test_data["embeddings"] = test_embeddings.tolist()

In [14]:
train_data.head()

Unnamed: 0,chunk_id,url,text,class,relevance,embeddings
0,4176,https://www.conservationevidence.com/individua...,Cease or prohibit shipping(Summarised by: Anaë...,Subtidal Benthic Invertebrate Conservation,relevant,"[-0.03097386844456196, -0.05736493319272995, 0..."
1,4176,https://www.conservationevidence.com/individua...,"ed, and their biomass converted to energy valu...",Subtidal Benthic Invertebrate Conservation,relevant,"[-0.023276297375559807, -0.050277337431907654,..."
2,8508,https://budget.finance.go.ug/sites/default/fil...,LG Draft Budget Estimates 2024/25 VOTE: 718 Lu...,irrelevant,irrelevant,"[0.0002742694632615894, 0.009270773269236088, ..."
3,8508,https://budget.finance.go.ug/sites/default/fil...,"ies/Fees 1,100,213 1,034,087 Registration fees...",irrelevant,irrelevant,"[0.011356459930539131, -0.01759961247444153, -..."
4,8508,https://budget.finance.go.ug/sites/default/fil...,"0 0 0 0 Tourism Development 2,000 0 0 0 2,000 ...",irrelevant,irrelevant,"[0.014761500991880894, -0.024308089166879654, ..."


In [15]:
import os

# Create the folder if it doesn't exist
if not os.path.exists(f"embeddings/{model_name}/dev"):
    os.makedirs(f"embeddings/{model_name}/dev")

if DEV:
    train_data.to_json(f"embeddings/{model_name}/dev/train_embeddings.json", orient='records', indent=4)
    test_data.to_json(f"embeddings/{model_name}/dev/test_embeddings.json", orient='records', indent=4)
else:
    train_data.to_json(f"embeddings/{model_name}/train_embeddings.json", orient='records')
    test_data.to_json(f"embeddings/{model_name}/test_embeddings.json", orient='records')

    # also save dev data
    devTrain = train_data.sample(10000)
    devTest = test_data.sample(int(10000*0.2))
    devTrain.to_json(f"embeddings/{model_name}/dev/train_embeddings.json", orient='records')
    devTest.to_json(f"embeddings/{model_name}/dev/test_embeddings.json", orient='records')



In [16]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(0))
# get version of cuda
print(torch.version.cuda)
# clear cache
torch.cuda.empty_cache()

True
0
NVIDIA GeForce RTX 3050 6GB Laptop GPU
12.1
