# Analysing text transformer dataset 

In [None]:
%load_ext autoreload

## Libraries

In [None]:
import sys, os
sys.path.append("..")
import pandas as pd 
from researchpkg.industry_classification.dataset.sec_datamodule import DatasetType


## Loading dataset

In [None]:
from researchpkg.industry_classification.config import ROOT_DIR
# Dataset folder (under the project's ROOT_DIR) passed as `dataset_dir` to the datasets built below.
DATASET_DIR =os.path.join(ROOT_DIR,"data/sec_data_v2/count30_sic1agg_including_is_2023")

In [None]:
import os

# Checkpoint whose tokenizer is loaded in the next cell.
MODEL = "meta-llama/Llama-2-7b-hf"
# Hugging Face access token. It is read from the HF_TOKEN environment variable
# so the secret never has to be written into (and committed with) the notebook.
# Falls back to the empty string, which is the value the notebook shipped with.
TOKEN = os.environ.get("HF_TOKEN", "")

In [None]:
import huggingface_hub
# An empty TOKEN is forwarded as None so that login() asks for credentials
# instead of being handed an empty string as the access token.
huggingface_hub.login(token=TOKEN or None)
# Create the tokenizer of the checkpoint selected in MODEL
from transformers import  AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Reuse the end-of-sequence token as padding token, and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

## 1. Descriptive template

In [None]:
%autoreload
from researchpkg.industry_classification.dataset.sec_transformer_datamodule import SecTextTransformerDataset, TextTransformerTemplateType
# Settings shared by the two TEST-split datasets below; only the template differs.
_shared_test_kwargs = dict(
    dataset_dir=DATASET_DIR,
    type=DatasetType.TEST,
    tokenizer=tokenizer,
    sic_digits=1,
    seq_max_length=1750,
    balance_sampling=False,
    load_in_memory=False,
    max_tag_depth=5,
)

# Same split rendered with the RAW template ...
desc_dataset = SecTextTransformerDataset(
    template_type=TextTransformerTemplateType.RAW,
    **_shared_test_kwargs,
)

# ... and with the DESCRIPTIVE template.
desc_dataset_relative = SecTextTransformerDataset(
    template_type=TextTransformerTemplateType.DESCRIPTIVE,
    **_shared_test_kwargs,
)

### Distribution of dataset size (number of tags)

In [None]:
import tqdm
import matplotlib.pyplot as plt
# Collect the "n_tags" entry of every sample of desc_dataset (one indexed access per sample, with a progress bar).
dataset_size = [desc_dataset[i]["n_tags"] for i in tqdm.tqdm(range(len(desc_dataset)))]


In [None]:
import seaborn as sns
# Histogram (50 bins) of the per-sample tag counts, with a KDE curve overlaid.
sns.histplot(dataset_size,bins=50, kde=True)

### Displaying single samples

In [None]:
i_small_sample = 1500  # Index picked with the (now commented-out) search loop below: a sample with not too many tags.
import random
# while True:
#     i = random.choice(list(range(len(desc_dataset))))
#     sample = desc_dataset.__getitem__(i,verbose=False)
#     if sample["length"]>00:
#         continue
#     i_small_sample = i
#     break

# __getitem__ is called explicitly because the extra `verbose` argument cannot be passed through the [] syntax.
sample = desc_dataset.__getitem__(i_small_sample,verbose=True)
# Print the entry of desc_dataset.data_files that sits at the same index.
print("file ", desc_dataset.data_files[i_small_sample])


In [None]:
# `df` is the very object stored in sample["df"], so the new column is added to it in place.
df=sample["df"]
# Expose the "net_change" column under the name "amount".
df["amount"]= df["net_change"]
# Display only the tag / amount pairs.
df[["tag","amount"]]

In [None]:
# Same sample index, rendered by the dataset built with the DESCRIPTIVE ("relative") template.
sample = desc_dataset_relative.__getitem__(i_small_sample, verbose=True)

### Taxonomy tree tags

In [None]:
# Take the root tree stored at index 1 of the dataset's `bs_taxonomy_tree` and display it.
taxonomy_tree = desc_dataset.bs_taxonomy_tree.root_trees[1]
# Alternative kept for reference: list the concept names of the nodes ordered by node number.
# nodes= [node for node in taxonomy_tree]
# nodes = sorted(nodes, key=lambda x: x.number)
# for node in nodes:
#     print(node.concept_name, end=", ")
taxonomy_tree.display()

In [None]:
# Tree the dataset exposes as `is_taxonomy_tree`.
income_statement_calculation_tree = desc_dataset.is_taxonomy_tree
# NOTE(review): bare access to an attribute named `max`; confirm the tree object exposes it
# (the name may be a truncated leftover).
income_statement_calculation_tree.max

In [None]:
from researchpkg.industry_classification.preprocessing.gaap_taxonomy_parser import (
    CalculationTree,
    CalculationTreeType,
)


from researchpkg.industry_classification.config import (
  
    SEC_TAX_VERSION,
    SEC_TAX,
    SEC_TAX_DATA_DIR,
    SEC_TAX_MAX_TAGS_DEPTH,
    SEC_TAX_MIN_TAGS_DEPTH,
)

# Income-statement calculation tree, rebuilt directly from the taxonomy settings
# imported from the project config (this rebinds the name set in the previous cell).
income_statement_calculation_tree = CalculationTree.build_taxonomy_tree(
    SEC_TAX_DATA_DIR,
    SEC_TAX,
    SEC_TAX_VERSION,
    type = CalculationTreeType.INCOME_STATEMENT
)

In [None]:
# Display the tree held in income_statement_calculation_tree.
income_statement_calculation_tree.display()

## 2. Comparative template

In [None]:
# NOTE(review): this rebinds the notebook-wide DATASET_DIR to a different folder
# ("sec_data2/..._including_is" instead of "sec_data_v2/..._including_is_2023").
# Later cells that read DATASET_DIR without redefining it use this value - confirm that is intended.
DATASET_DIR =os.path.join(ROOT_DIR,"data/sec_data2/count30_sic1agg_including_is")
# VAL split rendered with the COMPARATIVE template.
comp_dataset  = SecTextTransformerDataset(
            dataset_dir= DATASET_DIR,
            type=DatasetType.VAL,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=1536,
            balance_sampling=False,
            max_tag_depth=10,
            max_comparative_pair_depth_gap=2,
            template_type=TextTransformerTemplateType.COMPARATIVE
        )

In [None]:
# Ask the dataset's `bs_taxonomy_tree` whether these two concept names are in the same branch, and print the answer.
print(comp_dataset.bs_taxonomy_tree.are_in_same_branch("CashAndCashEquivalentsAtCarryingValue","AssetsCurrent"))

In [None]:
import random
# Draw a random valid index directly; no list of every index is built just to pick one.
i = random.randrange(len(comp_dataset))
# i = 76387
print("Sample",i,"\n----------------")
# __getitem__ is called explicitly because `verbose` cannot be passed through the [] syntax.
sample = comp_dataset.__getitem__(i,verbose=True)

In [None]:
%timeit
# NOTE(review): the `%timeit` line above carries no statement, so this loop does not
# appear to be timed - confirm whether a cell-level timing was intended.
# The loop runs once (i = 0) and leaves the first sample of comp_dataset in `sample`.
for i in range(1):
    print("Sample",i,"\n----------------")
    sample = comp_dataset.__getitem__(i,verbose=True)

In [None]:
# `sample` was produced by comp_dataset, so its index must be resolved against
# comp_dataset's own file list (desc_dataset is a different split and folder).
comp_dataset.data_files[sample['sample_idx']]

# 3. Dataset statistics

## 3.1. Train dataset target distribution

In [None]:
import tqdm
from collections import Counter
import numpy as np
# TRAIN split with the RAW template and no limit on tag depth (max_tag_depth=None).
# It is used below for its file list and its registrant / SIC indexes.
train_dataset  = SecTextTransformerDataset(
            dataset_dir= DATASET_DIR,
            type=DatasetType.TRAIN,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=1750,
            balance_sampling=False,
            load_in_memory=False,
            template_type=TextTransformerTemplateType.RAW  ,
            max_tag_depth=None
        )


from researchpkg.industry_classification.utils.sics_loader import load_sic_codes
def get_file_label(filepath):
    """Return the target class id of the data file at *filepath*.

    Only the "cik" column of the first row is read. The cik is mapped to a SIC
    code through ``train_dataset.registrants_index_dict`` and the SIC code to
    its class id through ``train_dataset.sic_id_index``.
    """
    first_row = pd.read_csv(filepath, usecols=["cik"], nrows=1)
    registrant_cik = first_row["cik"][0]
    registrant_sic = train_dataset.registrants_index_dict[registrant_cik]
    return train_dataset.sic_id_index[registrant_sic]



# Class id of every file of the TRAIN split, in dataset order.
target_list = [
    get_file_label(train_dataset.data_files[idx])
    for idx in tqdm.tqdm(range(len(train_dataset)), desc="Target list")
]

# SIC code -> industry title.
sic_code_df = load_sic_codes()[["sic", "industry_title"]]
sic_to_text = sic_code_df.set_index("sic")["industry_title"].to_dict()

# Dataset indexes; the reverse index maps a class id back to its SIC code.
accounts_index = train_dataset.accounts_index
sic_id_index = train_dataset.sic_id_index
sic_reverse_index = {class_id: sic for sic, class_id in train_dataset.sic_id_index.items()}

# Industry title of every file, ordered by class id so the bars come out in class order.
labels = [sic_to_text[sic_reverse_index[class_id]] for class_id in sorted(target_list)]
labels_count = Counter(labels)

import seaborn as sns
import matplotlib.pyplot as plt

# One bar per industry title, height = number of files.
plt.figure(figsize=(12, 6))
sns.barplot(x=list(labels_count), y=list(labels_count.values()))

## 3.2. Test dataset : target distribution

In [None]:
import tqdm
from collections import Counter
import numpy as np
# TEST split with the RAW template and no limit on tag depth (max_tag_depth=None).
# It is used below for its file list and its registrant / SIC indexes.
test_dataset  = SecTextTransformerDataset(
            dataset_dir= DATASET_DIR,
            type=DatasetType.TEST,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=1750,
            balance_sampling=False,
            load_in_memory=False,
            template_type=TextTransformerTemplateType.RAW  ,
            max_tag_depth=None
        )


# Same package root as every other project import in this notebook.
from researchpkg.industry_classification.utils.sics_loader import load_sic_codes
def get_file_label(filepath):
    """Return the target class id of the data file at *filepath*.

    Only the "cik" column of the first row is read. The cik is mapped to a SIC
    code through ``test_dataset.registrants_index_dict`` and the SIC code to
    its class id through ``test_dataset.sic_id_index``.
    """
    first_row = pd.read_csv(filepath, usecols=["cik"], nrows=1)
    registrant_cik = first_row["cik"][0]
    registrant_sic = test_dataset.registrants_index_dict[registrant_cik]
    return test_dataset.sic_id_index[registrant_sic]



# Class id of every file of the TEST split, in dataset order.
target_list = [
    get_file_label(test_dataset.data_files[idx])
    for idx in tqdm.tqdm(range(len(test_dataset)), desc="Target list")
]

# SIC code -> industry title.
sic_code_df = load_sic_codes()[["sic", "industry_title"]]
sic_to_text = sic_code_df.set_index("sic")["industry_title"].to_dict()

# Dataset indexes; the reverse index maps a class id back to its SIC code.
accounts_index = test_dataset.accounts_index
sic_id_index = test_dataset.sic_id_index
sic_reverse_index = {class_id: sic for sic, class_id in test_dataset.sic_id_index.items()}

# Industry title of every file, ordered by class id so the bars come out in class order.
labels = [sic_to_text[sic_reverse_index[class_id]] for class_id in sorted(target_list)]
labels_count = Counter(labels)

import seaborn as sns
import matplotlib.pyplot as plt

# One bar per industry title, height = number of files.
plt.figure(figsize=(12, 6))
sns.barplot(x=list(labels_count), y=list(labels_count.values()))







# Text transformer Dataset length distributions

In [None]:

%load_ext autoreload
%autoreload

import os
from researchpkg.industry_classification.config import ROOT_DIR
from researchpkg.industry_classification.dataset.sec_transformer_datamodule import SecTextTransformerDataset, TextTransformerTemplateType
from researchpkg.industry_classification.dataset.sec_datamodule import DatasetType
DATASET_DIR =os.path.join(ROOT_DIR,"data/sec_data_v2/count30_sic1agg_including_is_2023")
from transformers import  AutoTokenizer
# From here on the notebook tokenizes with the Gemma checkpoint below (rebinding `tokenizer`).
tokenizer = AutoTokenizer.from_pretrained("TechxGenus/gemma-2b-GPTQ")
# Reuse the end-of-sequence token as padding token, and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# TRAIN split rendered with the COMPARATIVE template.
dataset_train  = SecTextTransformerDataset(
            dataset_dir= DATASET_DIR,
            type=DatasetType.TRAIN,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=2000,
            min_tag_depth=None,
            max_tag_depth=5,
            balance_sampling=False,
            load_in_memory=False,
            template_type=TextTransformerTemplateType.COMPARATIVE,
        )

import seaborn as sns
# seq_length_list = [x["length"].item() for x in dataset_train]

#Histogram of sequence length
# sns.histplot(seq_length_list, bins=100, kde=True)

In [None]:
# Fetch sample 5 with verbose=1; __getitem__ is called explicitly so the extra argument can be passed.
out=dataset_train.__getitem__(5,verbose=1)


# SFT Dataset : complete instruction length distribution

In [None]:

# Candidate sectors, one per line. The indentation of the continuation lines
# and of the closing line is part of the string.
labels_subtext = """- Mining
    - Construction
    - Manufacturing
    - Transportation & Public Utilities
    - Wholesale Trade
    - Retail Trade
    - Finance
    - Services
    """


def partial_instruction_formatter(prompt) -> str:
    """Return the chat prompt up to the point where the model has to answer.

    The text is made of a first user turn describing the task and listing
    ``labels_subtext``, a second user turn wrapping *prompt*, and the opening
    of the model turn.
    """
    task_turn = (
        "<start_of_turn>user You are asked to predict the industry sector "
        "of a company based on its balance sheet and income statement.\n"
        "The value of the accounts are normalized by the total assets and given in percentage of totals assets.\n "
        "Given the provided informations about the balance sheet and income statement, "
        "you should predict the most probable industry sector of the "
        "related company.\n"
        "You should answer on a single line with only the name of the predicted "
        "industry sector and  nothing else.\n"
        "Here are the possible industry sectors: \n\n"
        f"{labels_subtext}\n"
        "You must strictly respect the spelling of the predicted industry sector.\n"
        "\n<end_of_turn>\n"
    )
    data_turn = "<start_of_turn> user \n" f"{prompt}\n<end_of_turn>\n"
    answer_opening = (
        "<start_of_turn>model \n"
        "Based on the information provided, the most probable industry sector of the company is: \n"
    )
    return "".join((task_turn, data_turn, answer_opening))


def instruction_formatter(prompt, label) -> str:
    """Return the complete instruction: the partial prompt followed by *label* and the end-of-turn marker."""
    question = partial_instruction_formatter(prompt)
    return f"{question}{label} <end_of_turn> \n"


In [None]:
%load_ext autoreload
%autoreload

import os
from researchpkg.industry_classification.config import ROOT_DIR
from researchpkg.industry_classification.dataset.sec_transformer_datamodule import TextTransformerTemplateType
from researchpkg.industry_classification.dataset.sec_transformer_sft_dataset import SecTextTransformerDatasetSFT
from researchpkg.industry_classification.dataset.sec_datamodule import DatasetType
DATASET_DIR =os.path.join(ROOT_DIR,"data/sec_data_v2/count30_sic1agg_including_is_2023")
from transformers import  AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("TechxGenus/gemma-2b-GPTQ")
# Reuse the end-of-sequence token as padding token, and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# SFT dataset on the TRAIN split, DESCRIPTIVE template, wired to the two
# formatter functions defined in the previous cell.
dataset_train  = SecTextTransformerDatasetSFT(
            dataset_dir= DATASET_DIR,
            type=DatasetType.TRAIN,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=2000,
            max_tag_depth=5,
            balance_sampling=False,
            load_in_memory=False,
            instruction_formatter=instruction_formatter,
            partial_instruction_formatter=partial_instruction_formatter,
            template_type=TextTransformerTemplateType.DESCRIPTIVE,
            bread_first_tree_exploration=False
        )

import seaborn as sns
# Iterate over the whole dataset and collect each sample's "length" entry as a Python scalar (via .item()).
seq_length_list = [x["length"].item() for x in dataset_train]

# Histogram of the sequence lengths (100 bins, KDE overlaid)
sns.histplot(seq_length_list, bins=100, kde=True)

In [None]:
# Print the "complete_instruction" entry of sample 8.
print(dataset_train[8]["complete_instruction"])


In [None]:
# Look up the node named "AssetsCurrent" in the dataset's `bs_taxonomy_tree` and display its `number` attribute.
dataset_train.bs_taxonomy_tree.get_node_by_concept_name("AssetsCurrent").number

## Descriptive Relative template

In [None]:
# Candidate sectors, one per line, injected verbatim into the system prompt.
labels_subtext = """
- Mining
- Construction
- Manufacturing
- Transportation & Public Utilities
- Wholesale Trade
- Retail Trade
- Finance
- Services
"""


def partial_instruction_formatter(prompt: str) -> str:
    """Return the Llama-3 style chat prompt up to the start of the assistant answer.

    Args:
        prompt: Text inserted in the user turn.

    Returns:
        A system turn describing the task and listing ``labels_subtext``, a
        user turn wrapping *prompt*, and the opening of the assistant turn.
    """
    # NOTE(review): the sentence ending with "industry sector and " looks truncated;
    # it is left untouched here so the prompt text itself does not change.
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        "You are asked to predict the industry sector "
        "of a company based on its balance sheet and income statement.\n"
        "The value of the accounts are normalized by the total assets and given in percentage of totals assets.\n "
        "Given the provided informations about the balance sheet and income statement, "
        "you should predict the most probable industry sector of the "
        "related company.\n"
        "You should answer on a single line with the name of the predicted "
        "industry sector and \n"
        "Here are the possible industry sectors: \n"
        f"{labels_subtext}\n"
        "\n\n You must strictly respect the spelling of the predicted industry sector.\n"
        "<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|> {prompt} <|eot_id|>\n"
        "<|start_header_id|>assistant<|end_header_id|> Based on the information provided, the most probable industry sector of the company is: \n"
    )


def instruction_formatter(prompt: str, label: str) -> str:
    """Return the complete instruction: the partial prompt followed by the answer.

    Args:
        prompt: Text inserted in the user turn.
        label: Expected answer, appended to the assistant turn.

    Returns:
        The partial prompt, then *label*, then the end-of-turn marker.
    """
    # The assistant turn is closed with <|eot_id|>, the same terminator used for
    # the system and user turns above (it used to be the Gemma-style <end_of_turn>).
    return partial_instruction_formatter(prompt) + f"{label} <|eot_id|> \n"


In [None]:
%load_ext autoreload
%autoreload

import os
from researchpkg.industry_classification.config import ROOT_DIR
from researchpkg.industry_classification.dataset.sec_transformer_sft_dataset import TextTransformerTemplateType
from researchpkg.industry_classification.dataset.sec_transformer_sft_dataset import SecTextTransformerDatasetSFT
from researchpkg.industry_classification.dataset.sec_datamodule import DatasetType
DATASET_DIR =os.path.join(ROOT_DIR,"data/sec_data_v2/count30_sic1agg_including_is_2023")
from transformers import  AutoTokenizer
# Switch to the Llama 3.1 tokenizer (rebinding `tokenizer`).
tokenizer = AutoTokenizer.from_pretrained("hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4")
# Reuse the end-of-sequence token as padding token, and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# SFT dataset on the TRAIN split, wired to the formatter functions of the previous cell.
# NOTE(review): the section heading mentions a "Descriptive Relative" template while
# template_type is RAW here - confirm which one is intended.
dataset_train  = SecTextTransformerDatasetSFT(
            dataset_dir= DATASET_DIR,
            type=DatasetType.TRAIN,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=2000,
            max_tag_depth=5,
            balance_sampling=False,
            load_in_memory=False,
            instruction_formatter=instruction_formatter,
            partial_instruction_formatter=partial_instruction_formatter,
            template_type=TextTransformerTemplateType.RAW,
            bread_first_tree_exploration=False
        )

In [None]:
# Print the "complete_instruction" entry of sample 10.
print(dataset_train[10]["complete_instruction"])

In [None]:
import seaborn as sns
# Iterate over the whole dataset and collect each sample's "length" entry as a Python scalar (via .item()).
seq_length_list = [x["length"].item() for x in dataset_train]

# Histogram of the sequence lengths (100 bins, KDE overlaid)
sns.histplot(seq_length_list, bins=100, kde=True)

# Dataset NO CHANGES

In [None]:
# Candidate sectors, one per line, injected verbatim into the prompt.
labels_subtext = """
- Mining
- Construction
- Manufacturing
- Transportation & Public Utilities
- Wholesale Trade
- Retail Trade
- Finance
- Services
"""


def partial_instruction_formatter(prompt) -> str:
    """Return the chat prompt up to the point where the model has to answer.

    The text is made of a first user turn describing the task (based on the
    list of account names) and listing ``labels_subtext``, a second user turn
    wrapping *prompt*, and the opening of the model turn.
    """
    task_turn = (
        "<start_of_turn>user You are asked to predict the industry sector "
        "of a company based on its balance sheet and income statement.\n"
        "You are given the list of all accounts name in the balance sheet and income statement.\n"
        "Based on that list you should indicate the most probable industry sector of the "
        "related company.\n"
        "Here are the possible industry sectors: \n\n"
        f"{labels_subtext}\n"
        "You must strictly respect the spelling of the predicted industry sector.\n"
        "\n<end_of_turn>\n"
    )
    data_turn = "<start_of_turn> user \n" f"{prompt}\n<end_of_turn>\n"
    answer_opening = (
        "<start_of_turn>model \n"
        "Based on the information provided, the most probable industry sector of the company is: \n"
    )
    return "".join((task_turn, data_turn, answer_opening))


def instruction_formatter(prompt, label) -> str:
    """Return the complete instruction: the partial prompt followed by *label* and the end-of-turn marker."""
    question = partial_instruction_formatter(prompt)
    return f"{question}{label} <end_of_turn> \n"

In [None]:
%load_ext autoreload
%autoreload

import os
from researchpkg.industry_classification.config import ROOT_DIR
from researchpkg.industry_classification.dataset.sec_transformer_sft_dataset import TextTransformerTemplateType
from researchpkg.industry_classification.dataset.sec_transformer_sft_dataset import SecTextTransformerDatasetSFT
from researchpkg.industry_classification.dataset.sec_datamodule import DatasetType
DATASET_DIR =os.path.join(ROOT_DIR,"data/sec_data_v2/count30_sic1agg_including_is_2023")
from transformers import  AutoTokenizer
# Back to the Gemma tokenizer (rebinding `tokenizer`).
tokenizer = AutoTokenizer.from_pretrained("TechxGenus/gemma-2b-GPTQ")
# Reuse the end-of-sequence token as padding token, and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# SFT dataset on the TRAIN split, NO_CHANGE template, no limit on tag depth,
# wired to the formatter functions of the previous cell.
dataset_train  = SecTextTransformerDatasetSFT(
            dataset_dir= DATASET_DIR,
            type=DatasetType.TRAIN,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=2000,
            max_tag_depth=None,
            balance_sampling=False,
            load_in_memory=False,
            instruction_formatter=instruction_formatter,
            partial_instruction_formatter=partial_instruction_formatter,
            template_type=TextTransformerTemplateType.NO_CHANGE,
            bread_first_tree_exploration=False
        )



In [None]:
# Print the "complete_instruction" entry of sample 10.
print(dataset_train[10]["complete_instruction"])

In [None]:
import seaborn as sns
# Iterate over the whole dataset and collect each sample's "length" entry as a Python scalar (via .item()).
seq_length_list = [x["length"].item() for x in dataset_train]

# Histogram of the sequence lengths (100 bins, KDE overlaid)
sns.histplot(seq_length_list, bins=100, kde=True)

## Dataset with explanation prompt

In [None]:
# Candidate sectors, one per line, injected verbatim into the system prompt.
labels_subtext = """
- Mining
- Construction
- Manufacturing
- Transportation & Public Utilities
- Wholesale Trade
- Retail Trade
- Finance
- Services
"""


def partial_instruction_formatter(prompt: str) -> str:
    """Return the Llama-3 style chat prompt up to the start of the assistant answer.

    Args:
        prompt: Text inserted in the user turn.

    Returns:
        A system turn describing the task and listing ``labels_subtext``, a
        user turn wrapping *prompt*, and the opening of the assistant turn.
    """
    # NOTE(review): the sentence ending with "industry sector and " looks truncated;
    # it is left untouched here so the prompt text itself does not change.
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        "You are asked to predict the industry sector "
        "of a company based on its balance sheet and income statement.\n"
        "The value of the accounts are normalized by the total assets and given in percentage of totals assets.\n "
        "Given the provided informations about the balance sheet and income statement, "
        "you should predict the most probable industry sector of the "
        "related company.\n"
        "You should answer on a single line with the name of the predicted "
        "industry sector and \n"
        "Here are the possible industry sectors: \n"
        f"{labels_subtext}\n"
        "\n\n You must strictly respect the spelling of the predicted industry sector.\n"
        "<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|> {prompt} <|eot_id|>\n"
        "<|start_header_id|>assistant<|end_header_id|> Based on the information provided, the most probable industry sector of the company is: \n"
    )


def instruction_formatter(prompt: str, label: str) -> str:
    """Return the complete instruction: the partial prompt followed by the answer.

    Args:
        prompt: Text inserted in the user turn.
        label: Expected answer, appended to the assistant turn.

    Returns:
        The partial prompt, then *label*, then the end-of-turn marker.
    """
    # The assistant turn is closed with <|eot_id|>, the same terminator used for
    # the system and user turns above (it used to be the Gemma-style <end_of_turn>).
    return partial_instruction_formatter(prompt) + f"{label} <|eot_id|> \n"


In [None]:
%load_ext autoreload
%autoreload

import os
from researchpkg.industry_classification.config import ROOT_DIR
from researchpkg.industry_classification.dataset.sec_transformer_sft_dataset import TextTransformerTemplateType
from researchpkg.industry_classification.dataset.sec_transformer_sft_dataset import SecTextTransformerDatasetSFT
from researchpkg.industry_classification.dataset.sec_datamodule import DatasetType
DATASET_DIR =os.path.join(ROOT_DIR,"data/sec_data_v2/count30_sic1agg_including_is_2023")
from transformers import  AutoTokenizer
# Llama 3.1 tokenizer (rebinding `tokenizer`).
tokenizer = AutoTokenizer.from_pretrained("hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4")
# Reuse the end-of-sequence token as padding token, and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# NOTE: despite its name, `dataset_train` is built on the TEST split in this cell.
dataset_train  = SecTextTransformerDatasetSFT(
            dataset_dir= DATASET_DIR,
            type=DatasetType.TEST,
            tokenizer=tokenizer,
            sic_digits=1,
            seq_max_length=2000,
            max_tag_depth=5,
            balance_sampling=False,
            load_in_memory=False,
            instruction_formatter=instruction_formatter,
            partial_instruction_formatter=partial_instruction_formatter,
            template_type=TextTransformerTemplateType.DESCRIPTIVE,
            bread_first_tree_exploration=False)

In [None]:
# Follow-up user turn asking for a justification, then the opening of a new assistant turn.
explanation_instruction = "<|start_header_id|>user<|end_header_id|> Please provide a justification of  your answer. <|eot_id|>\n"
explanation_instruction += "<|start_header_id|>assistant<|end_header_id|>"

# The same placeholder string is supplied as prediction for every sample (one entry per dataset item).
dataset_exp = dataset_train.get_sft_dataset_with_explanation_prompt(
    explanation_prompt=explanation_instruction,
    y_pred_list=["Wholesale Trade,Finance"]*len(dataset_train)
)

In [None]:
# Number of token ids the tokenizer produces for the first "text_with_explanation_prompt" entry.
len(tokenizer.encode(dataset_exp["text_with_explanation_prompt"][0]))