In [13]:
import pandas as pd
import numpy as np
import scanpy as sc
from sklearn.linear_model import LinearRegression
from collections import Counter
from scipy import stats
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import math
from pathlib import Path
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
# import torch
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.utils.class_weight import compute_class_weight
# from torch.utils.data import TensorDataset, DataLoader
# import torch.nn as nn
# import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import r2_score
import csv

In [14]:
## Optional step:
## drop cell types whose cell count is below a minimum.
## Specifying a lower bound is recommended; when none is given, a default derived from the number of age groups is used.
def filter_low_counts(celltype_df, age_df, celltype_col, threshold):
    """Drop every cell type whose number of cells is below a minimum.

    Parameters
    ----------
    celltype_df : pandas.DataFrame
        One row per cell; must contain the column ``celltype_col``.
    age_df : array-like
        Age labels. Only the number of distinct values is used, and only
        when ``threshold`` is None.
    celltype_col : str
        Name of the cell-type column in ``celltype_df``.
    threshold : number or None
        Minimum number of cells a type must have to be kept. When None, the
        minimum is ``(number of unique ages + 1) * 100``.

    Returns
    -------
    pandas.DataFrame
        The rows of ``celltype_df`` whose cell type meets the minimum, with
        the original index preserved. The input object itself is returned
        when nothing has to be removed.
    """
    print("Checking low count cell types...")

    # Resolve the default once instead of once per cell type.
    # `is None` (not `== None`) so that a threshold of 0 is honoured.
    if threshold is None:
        threshold = (len(np.unique(age_df)) + 1) * 100

    celltype_count = Counter(celltype_df[celltype_col])
    low_count_types = [
        key for key, n_cells in celltype_count.items() if n_cells < threshold
    ]
    for key in low_count_types:
        print(key, " has too low counts")

    if not low_count_types:
        return celltype_df

    # One boolean mask instead of re-filtering the frame for every low-count type.
    keep_mask = ~celltype_df[celltype_col].isin(low_count_types)
    return celltype_df[keep_mask]

def get_skewed_count_info(adata, class_col, age_col, age_threshold):
    """Return the cell classes dominated by a single age group.

    Parameters
    ----------
    adata : object with an ``obs`` DataFrame
        ``adata.obs`` must contain the columns ``class_col`` and ``age_col``.
    class_col : str
        Column holding the cell class of each cell.
    age_col : str
        Column holding the age group of each cell.
    age_threshold : float
        A class is reported when any one age group accounts for more than
        this fraction of the class's cells.

    Returns
    -------
    pandas.Index
        Unique class labels for which at least one age fraction exceeds
        ``age_threshold``.
    """
    print("Checking skewed count cell types...")

    # observed=True counts only the (class, age) pairs that actually occur.
    # This avoids the pandas warning about the implicit `observed` default on
    # categorical columns and keeps unused categories from producing 0/0 rows.
    group_counts = adata.obs.groupby([class_col, age_col], observed=True).size()
    total_counts = adata.obs.groupby([class_col], observed=True).size()

    # Fraction of each age group within its class; the division aligns the
    # per-class totals on the shared class level.
    class_age_fraction = group_counts / total_counts

    # Keep the class label (index level 0) of every over-threshold fraction.
    skewed = class_age_fraction[class_age_fraction > age_threshold]
    classes_to_filter = skewed.index.get_level_values(0).unique()

    return classes_to_filter

## Read h5ad file 
## and do cell type filtering based on age distribution and cell count thresholds.
def read_and_filter_h5ad(filepath_1, filepath_2 = None, class_col="celltype", age_col="age", age_threshold=0.8, count_threshold=None):
    """Load one or two .h5ad files and drop low-count and age-skewed cell types.

    Parameters
    ----------
    filepath_1 : str
        Path of the first .h5ad file.
    filepath_2 : str, optional (default: None)
        Path of a second .h5ad file. When given, it is concatenated onto the
        first before any filtering.
    class_col : str, optional (default: 'celltype')
        The column name in adata.obs representing the cell ontology class.
    age_col : str, optional (default: 'age')
        The column name in adata.obs representing the age of the cells.
    age_threshold : float, optional (default: 0.8)
        If one age group has more than this fraction of the cells of a
        class, the class is filtered out.
    count_threshold : number, optional (default: None)
        Cell types with fewer cells than this are filtered out. None lets
        ``filter_low_counts`` derive the minimum from the number of age groups.

    Returns
    -------
    AnnData
        A copy of the data with the cell classes removed by both criteria.

    Raises
    ------
    Any exception from reading or filtering propagates unchanged.
    """
    adata = sc.read_h5ad(filepath_1)
    if filepath_2 is not None:
        # NOTE(review): the notebook output attaches a warning to this call,
        # presumably the deprecation of AnnData.concatenate. Kept as-is because
        # a replacement could change obs names / var column suffixes - confirm.
        adata = adata.concatenate(sc.read_h5ad(filepath_2))

    celltype_df = adata.obs[[class_col]].copy()
    age_df = adata.obs[[age_col]].copy()

    # Step 1: remove cell types with too few cells.
    celltype_df = filter_low_counts(celltype_df, age_df, class_col, count_threshold)
    filtered_adata = adata[celltype_df.index].copy()

    # Step 2: remove cell types dominated by one age group.
    classes_to_filter = get_skewed_count_info(filtered_adata, class_col, age_col, age_threshold)

    # Report every class that is removed, not only the first one.
    for skewed_class in classes_to_filter:
        print(skewed_class, " has skewed cell counts")

    keep_mask = ~filtered_adata.obs[class_col].isin(classes_to_filter)
    final_filtered_adata = filtered_adata[keep_mask].copy()

    return final_filtered_adata

In [15]:
# current_dir = Path.cwd()
# print(f"Current working directory: {current_dir}")
# Both Tabula Muris Senis brain files are expected in the working directory.
current_dir = Path.cwd()
file1 = current_dir.joinpath(
    "tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad"
)
file2 = current_dir.joinpath(
    "tabula-muris-senis-facs-processed-official-annotations-Brain_Non-Myeloid.h5ad"
)
print(f"File 1 path: {file1}")
# assert file1.is_file(), f"File not found: {file1}"

# Load, concatenate and filter; the remaining thresholds keep their defaults.
adata = read_and_filter_h5ad(
    filepath_1=str(file1),
    filepath_2=str(file2),
    class_col="cell_ontology_class",
    age_col="age",
)

File 1 path: /home/hang/SC_Ageing_Prediction/tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad



This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(
  adata1 = adata1.concatenate(adata2)


Checking low count cell types...
macrophage  has too low counts
CD8-positive, alpha-beta T cell  has too low counts
ependymal cell  has too low counts
interneuron  has too low counts
oligodendrocyte precursor cell  has too low counts
Bergmann glial cell  has too low counts
neuroepithelial cell  has too low counts
T cell  has too low counts
neuronal stem cell  has too low counts
mature NK T cell  has too low counts
medium spiny neuron  has too low counts
Checking skewed count cell types...


  group_counts = adata.obs.groupby([class_col, age_col]).size()
  total_counts = adata.obs.groupby([class_col]).size()


In [16]:
# Display the summary of the filtered AnnData object.
adata

AnnData object with n_obs × n_vars = 19154 × 22966
    obs: 'FACS.selection', 'age', 'cell', 'cell_ontology_class', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex', 'subtissue', 'tissue', 'n_genes', 'n_counts', 'louvain', 'leiden', 'batch'
    var: 'n_cells', 'means-0', 'dispersions-0', 'dispersions_norm-0', 'highly_variable-0', 'means-1', 'dispersions-1', 'dispersions_norm-1', 'highly_variable-1'
    obsm: 'X_pca', 'X_tsne', 'X_umap'

In [17]:
# Print the first 5 values of the per-cell `n_genes` column in adata.obs.
print(adata.obs["n_genes"].head())

index
A10_B001060_B009250_S214.mm10-plus-1-0-0    1505
A10_B001061_B009251_S298.mm10-plus-1-0-0    2384
A10_B002503_B009456_S10.mm10-plus-1-0-0      599
A10_B002702_B009296_S154.mm10-plus-1-0-0    1547
A10_D045853_B009304_S106.mm10-plus-1-0-0     931
Name: n_genes, dtype: int64


In [18]:
# X = adata.X.toarray() if not isinstance(adata.X, np.ndarray) else adata.X

# # Round values to nearest int and ensure non-negative
# X = np.round(X).astype(int)
# X[X < 0] = 0  # just in case

# # Get gene names
# genes = np.array(adata.var_names)

# # Prepare the list of gene strings
# gene_strings = []
# for row in X:
#     repeated_genes = np.repeat(genes, row)
#     gene_string = " ".join(repeated_genes)
#     gene_strings.append(gene_string)

# # Get relevant metadata
# df_meta = adata.obs[['age', 'sex', 'cell_ontology_class', 'tissue']].copy()
# df_meta.columns = ['age', 'gender', 'cell_ontology_class', 'tissue']

# # Create final DataFrame
# df_final = pd.DataFrame({
#     'genes': gene_strings,
#     'age': df_meta['age'].values,
#     'gender': df_meta['gender'].values,
#     'cell_ontology_class': df_meta['cell_ontology_class'].values,
#     'tissue': df_meta['tissue'].values
# })

# # Set index as 0 to n-1
# df_final.index = range(df_final.shape[0])

# # Optional: Save to CSV
# df_final.to_csv("processed_cells.csv", index_label="index")

# # Return final DataFrame
# df_final



## The approach above is too memory intensive (it holds every gene string in memory);
## the next cell writes each row to the file directly instead.

In [19]:
# # Prepare output file
# output_file = "processed_cells_streamed.csv"

# # Get gene names
# genes = np.array(adata.var_names)

# # Get expression matrix (dense row-by-row)
# X = adata.X

# # Prepare metadata
# ages = adata.obs["age"].values
# genders = adata.obs["sex"].values
# classes = adata.obs["cell_ontology_class"].values
# tissues = adata.obs["tissue"].values

# # Open file and write line-by-line
# with open(output_file, mode='w', newline='', encoding='utf-8') as f:
#     writer = csv.writer(f)
#     # Write header
#     writer.writerow(["index", "genes", "age", "gender", "cell_ontology_class", "tissue"])

#     for i in range(adata.n_obs):
#         # Get row i as dense array
#         row = X[i].toarray().flatten() if not isinstance(X, np.ndarray) else X[i]
        
#         # Convert to int and clip negatives
#         row = np.maximum(np.round(row).astype(int), 0)

#         # Efficiently repeat gene names
#         repeated_genes = np.repeat(genes, row)
#         gene_string = " ".join(repeated_genes)

#         # Write row to file
#         writer.writerow([
#             i,
#             gene_string,
#             ages[i],
#             genders[i],
#             classes[i],
#             tissues[i]
#         ])
# f.close()

In [20]:
from collections import Counter
import numpy as np
import json, math, os, pathlib
from sklearn.model_selection import StratifiedKFold

# ─── Pull data from AnnData ─────────────────────────────────────────────────────
# Gene names and expression matrix.
genes = np.asarray(adata.var_names)
X = adata.X

# Per-cell metadata columns, extracted as plain arrays.
obs = adata.obs
ages, genders, classes, tissues = (
    obs[column].values
    for column in ("age", "sex", "cell_ontology_class", "tissue")
)

# Number of cells per age group.
print(Counter(ages))

Counter({'3m': 7394, '18m': 6928, '24m': 4832})


In [21]:
# # ─── Split config ───────────────────────────────────────────────────────────────
# N_SPLITS   = 11
# SPLIT_SIZE = math.ceil(adata.n_obs / N_SPLITS)
# OUT_DIR    = "fine_tune_chunks"
# os.makedirs(OUT_DIR, exist_ok=True)

# INSTRUCTION = "Predict the age of a single cell from gene expression and metadata."

# def bag_of_words(counts_row: np.ndarray, gene_names: np.ndarray) -> str:
#     """Convert a vector of counts to space-separated tokens."""
#     counts_row = np.maximum(np.round(counts_row).astype(int), 0)
#     return " ".join(np.repeat(gene_names, counts_row))

# # ─── Build chunk files in **Alpaca** format ─────────────────────────────────────

# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
# for split_idx, (_, split_indices) in enumerate(skf.split(np.zeros(len(ages)), ages)):
#     records = []
#     for i in split_indices:
#         row = X[i].toarray().ravel() if not isinstance(X, np.ndarray) else X[i]
#         cell_input = (
#             f"Genes: {bag_of_words(row, genes)}\n"
#             f"Gender: {genders[i]}\n"
#             f"Class: {classes[i]}\n"
#             f"Tissue: {tissues[i]}"
#         )

#         records.append({
#             "instruction": INSTRUCTION,
#             "input": cell_input,
#             "output": str(ages[i])
#         })
#         break
# #     fname = os.path.join(OUT_DIR, f"cell_data_part_{split_idx+1}.json")
# #     with open(fname, "w", encoding="utf-8") as f:
# #         json.dump(records, f, indent=2, ensure_ascii=False)

# #     print(f"Wrote {len(records):>5} samples → {fname}")


In [22]:
import inflect

# inflect engine; its number_to_words() is used further down to spell counts in English.
p = inflect.engine()

In [23]:

################################## The commented-out code below translates gene symbols into a more readable token format ###################


################### NOTE: superseded - adding all gene symbols directly to the tokenizer works fine, so this translation is not needed ######################
















# from sklearn.model_selection import StratifiedKFold
# import inflect

# p = inflect.engine()
# global unseen_cnt
# global seen_cnt
# unseen_cnt = 0
# seen_cnt = 0
# # ─── Load gene symbol ➜ biotype_ID mapping ──────────────────────────────────
# GENEINFO_PATH = "geneInfo.tab"  # adjust if the file lives elsewhere

# gene_mapping: dict[str, str] = {}
# with open(GENEINFO_PATH, "r", encoding="utf-8") as fh:
#     for raw_line in fh:
#         line = raw_line.strip()
#         # The first line of geneInfo.tab is just the row‑count (e.g. "33696") – skip it
#         if not line or line.isdigit():
#             continue
#         try:
#             gene_id, symbol, biotype = line.split("\t")
#         except ValueError:  # line didn’t have three columns – ignore it
#             continue
#         # Keep the last 6 digits of the Ensembl ID and prepend the gene biotype
#         #   ENSMUSG00000051951  →  "051951"
#         numeric_id = gene_id[-6:]
#         #gene_mapping[symbol] = f"{biotype} {numeric_id}"
#         gene_mapping[symbol] = f"{numeric_id}"  # use only the numeric ID for simplicity


# def translate_gene(symbol: str) -> str:
#     """Return the canonical training token for a gene symbol.

#     Example
#     -------
#     >>> translate_gene("Xkr4")
#     'protein_coding_051951'
#     """
#     new_name = gene_mapping.get(symbol, symbol)  # fall back to the symbol if unseen
#     if new_name == symbol:
#         global unseen_cnt
#         # expand the unseen gene name into gene mapping
#         # new_gene = f"gene {str(unseen_cnt)}"
#         new_gene = f"{str(unseen_cnt)}"
#         new_name = new_gene
#         gene_mapping[symbol] = new_gene
#         unseen_cnt += 1

#     return new_name


# # ─── Prompt details ──────────────────────────────────────────────────────────
# INSTRUCTION = "Predict the age of a single cell from gene expression and metadata."


# def bag_of_words(counts_row: np.ndarray, gene_names: np.ndarray) -> str:
#     """Convert a vector of counts to "gene english(count)" tokens separated by space."""
#     counts_row = np.maximum(np.round(counts_row).astype(int), 0)
#     tokens = [
#         f"{translate_gene(gene)} {p.number_to_words(count)}"
#         for gene, count in zip(gene_names, counts_row)
#         if count > 0
#     ]
#     return " ".join(tokens)

# # ─── Build chunk files in **Alpaca** format ───────────────────────────────────
# N_SPLITS = 11
# SPLIT_SIZE = math.ceil(adata.n_obs / N_SPLITS)
# OUT_DIR = "fine_tune_chunks"
# os.makedirs(OUT_DIR, exist_ok=True)

# lengths = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
# for split_idx, (_, split_indices) in enumerate(skf.split(np.zeros(len(ages)), ages)):
#     records = []
#     for i in split_indices:
#         row = X[i].toarray().ravel() if not isinstance(X, np.ndarray) else X[i]
#         bow = bag_of_words(row, genes)
#         lengths.append(len(bow))
#         cell_input = (
#             f"Genes: {bow}\n"
#             f"Gender: {genders[i]}\n"
#             f"Class: {classes[i]}\n"
#             f"Tissue: {tissues[i]}"
#         )
        
#         records.append({
#             "instruction": INSTRUCTION,
#             "input": cell_input,
#             "output": str(ages[i])
#         })
#         # print("unseen:", unseen_cnt, "seen:", seen_cnt)

#     fname = os.path.join(OUT_DIR, f"cell_data_part_noname_{split_idx + 1}.json")
#     with open(fname, "w", encoding="utf-8") as f:
#         json.dump(records, f, indent=2, ensure_ascii=False)

#     print(f"Wrote {len(records):>5} samples → {fname}")


In [24]:
import os, json, math, numpy as np
from sklearn.model_selection import StratifiedKFold
import inflect
p = inflect.engine()

# Number of stratified chunks; one output JSON file is written per chunk.
N_SPLITS   = 11
# Cells per chunk, rounded up.
# NOTE(review): not referenced anywhere in the visible code - confirm before removing.
SPLIT_SIZE = math.ceil(adata.n_obs / N_SPLITS)
# Directory that receives the chunk files and the gene token list.
OUT_DIR    = "fine_tune_chunks"
os.makedirs(OUT_DIR, exist_ok=True)

# Instruction text stored in the "instruction" field of every record.
INSTRUCTION = "Predict the age of a single cell from gene expression and metadata."

# ─── NEW ────
gene_vocab: dict[str, int] = {}          # {gene_name: token_id}, in order of first appearance
_count_words: dict[int, str] = {}        # memo: integer count -> its English spelling

def bag_of_words(counts_row: np.ndarray,
                 gene_names: np.ndarray,
                 vocab: dict[str, int] = gene_vocab) -> str:
    """
    Convert a vector of counts to 'gene english(count)' tokens separated by space
    and populate `vocab` with any previously unseen gene names.

    Parameters
    ----------
    counts_row : array-like
        Expression values for one cell. Values are rounded to the nearest
        integer and negatives are clipped to 0.
    gene_names : sequence of str
        Gene name for each position of `counts_row`.
    vocab : dict, optional
        Mapping gene name -> running index, updated in place. Defaults to the
        module-level `gene_vocab`, so calls share one vocabulary on purpose.

    Returns
    -------
    str
        "<gene> <count in words>" for every gene with a positive count, in
        the order of `gene_names`, joined by single spaces ("" if none).
    """
    counts_row = np.maximum(np.round(counts_row).astype(int), 0)

    # Only pair up as many entries as both inputs provide.
    n_pairs = min(len(gene_names), len(counts_row))

    tokens = []
    # Visit only the non-zero entries instead of looping over every gene;
    # flatnonzero returns ascending positions, so token/vocab order is kept.
    for idx in np.flatnonzero(counts_row[:n_pairs]):
        gene = gene_names[idx]
        count = int(counts_row[idx])      # plain int for inflect and the memo key

        # keep track of first appearance
        if gene not in vocab:
            vocab[gene] = len(vocab)      # next free index

        # The same counts recur constantly; spell each one only once.
        words = _count_words.get(count)
        if words is None:
            words = _count_words[count] = p.number_to_words(count)

        tokens.append(f"{gene} {words}")

    return " ".join(tokens)
# ──────────────────────────────────────────────────────────────

lengths = []  # character length of every bag-of-words string, in processing order
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
X_is_dense = isinstance(X, np.ndarray)

# Only the held-out indices of each fold are used: every fold becomes one
# age-stratified chunk file.
for split_idx, (_, split_indices) in enumerate(skf.split(np.zeros(len(ages)), ages)):
    records = []

    for i in split_indices:
        # Densify a single row at a time when X is not a plain ndarray.
        row = X[i] if X_is_dense else X[i].toarray().ravel()

        bow = bag_of_words(row, genes)        # also registers new genes in gene_vocab
        lengths.append(len(bow))

        cell_input = (
            f"Genes: {bow}\n"
            f"Gender: {genders[i]}\n"
            f"Class: {classes[i]}\n"
            f"Tissue: {tissues[i]}"
        )
        record = {
            "instruction": INSTRUCTION,
            "input": cell_input,
            "output": str(ages[i])
        }
        records.append(record)

    # Chunk files are numbered from 1.
    fname = os.path.join(OUT_DIR, f"cell_data_part_{split_idx+1}.json")
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)

    print(f"Wrote {len(records):>5} samples → {fname}")
    # break

# ─── Optionally save the new-token list for later use ─────────
# Persist the gene names in order of first appearance.
gene_tokens_path = os.path.join(OUT_DIR, "gene_tokens.json")
with open(gene_tokens_path, "w") as f:
    json.dump(list(gene_vocab), f, indent=2)

print(f"Discovered {len(gene_vocab)} unique genes.")


Wrote  1742 samples → fine_tune_chunks/cell_data_part_1.json
Wrote  1742 samples → fine_tune_chunks/cell_data_part_2.json
Wrote  1742 samples → fine_tune_chunks/cell_data_part_3.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_4.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_5.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_6.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_7.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_8.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_9.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_10.json
Wrote  1741 samples → fine_tune_chunks/cell_data_part_11.json
Discovered 22727 unique genes.


In [25]:
# Longest and shortest bag-of-words string, measured in characters (len of the str), not tokens.
max(lengths), min(lengths)

(70201, 681)

In [26]:
# NOTE(review): this line runs unconditionally every time the cell executes. It deletes
# ~/.hf_cache/qwen3_14b, which appears to be the same directory as CACHE_DIR below, so the
# "already cached" and "Loading enhanced tokenizer" branches would never be taken.
# Comment it out unless a clean rebuild is really wanted.
import shutil, os; shutil.rmtree(os.path.expanduser("~/.hf_cache/qwen3_14b"), ignore_errors=True)
############### use the line above to clear the cache if needed ###############

import json, glob, itertools
from __future__ import annotations
import os
from pathlib import Path
from huggingface_hub import login, snapshot_download
from transformers import AutoTokenizer

REPO_ID  = "unsloth/Qwen3-14B-unsloth-bnb-4bit"
CACHE_DIR = Path.home() / ".hf_cache" / "qwen3_14b"  # local copy of the repo (minus weights)
TOKENIZER_DIR = CACHE_DIR / "tokenizer_plus"           # tokenizer with the extra gene tokens

# Log in only when a token is supplied through the environment; nothing is hard-coded.
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the variable is unset
if HF_TOKEN:
    login(token=HF_TOKEN, add_to_git_credential=False)

# Turn telemetry off unless the user already chose a value.
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")

# Make sure the cache directory is there before it is used as a download target.
CACHE_DIR.mkdir(parents=True, exist_ok=True)


# tokenizer.json serves as the marker that the repo was fetched before.
repo_is_cached = (CACHE_DIR / "tokenizer.json").exists()
if repo_is_cached:
    print("Repo already cached – skipping download.")
else:
    print("Downloading tokenizer repo – this happens once…")
    snapshot_download(
        repo_id=REPO_ID,
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,   # request real files rather than symlinks
        token=HF_TOKEN,                 # None is fine for a public repo
        ignore_patterns=["*.safetensors", "*.bin"],  # leave the model weights out
    )


def build_extra_tokens(gene_file=None) -> list[str]:
    """Return the list of gene tokens to add to the tokenizer, each exactly once.

    Parameters
    ----------
    gene_file : str or Path, optional
        JSON file containing a list of gene names. Defaults to
        ``fine_tune_chunks/gene_tokens.json`` (relative to the working
        directory), which is the file the original version always read.

    Returns
    -------
    list[str]
        One ``" GENE "`` token (gene name padded with a space on both sides)
        per gene, de-duplicated in first-occurrence order. An empty list,
        after printing a warning, when the file does not exist.
    """
    if gene_file is None:
        gene_path = Path("fine_tune_chunks") / "gene_tokens.json"
    else:
        gene_path = Path(gene_file)

    if not gene_path.exists():
        print(f"⚠️  Gene token file {gene_path} not found – proceeding without it.")
        return []

    with open(gene_path, "r", encoding="utf-8") as f:
        raw_genes = json.load(f)

    # dict.fromkeys de-duplicates while preserving order.
    return list(dict.fromkeys(f" {g} " for g in raw_genes))   # <space>GENE<space>

# Reuse the enhanced tokenizer when it was saved before; otherwise build it
# from the cached base tokenizer plus the gene tokens.
if TOKENIZER_DIR.exists():
    print("Loading enhanced tokenizer…")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
else:
    print("Building enhanced tokenizer (one‑off)…")
    tokenizer = AutoTokenizer.from_pretrained(CACHE_DIR, trust_remote_code=True)
    print(f"   → Original vocab size: {len(tokenizer)}")
    extra_tokens = build_extra_tokens()
    added = tokenizer.add_tokens(extra_tokens, special_tokens=False)
    print(f"   → Added {added} new tokens (vocab now {len(tokenizer)})")
    if extra_tokens:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        # Reported only here: in the loading branch nothing is saved.
        print(f"Enhanced tokenizer saved to {TOKENIZER_DIR}")
    else:
        # Do not persist a tokenizer without gene tokens as the "enhanced"
        # one, otherwise later runs would load it and never add the genes.
        print("No extra tokens available – tokenizer not saved; it will be rebuilt on the next run.")





print("Vocab size:", len(tokenizer))
# print("Tokenize sample:", tokenizer.tokenize(" protein_coding 000123"))


  from .autonotebook import tqdm as notebook_tqdm
For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Downloading tokenizer repo – this happens once…


Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 51.21it/s]


uilding enhanced tokenizer (one‑off)…
   → Original vocab size: 151669
   → Added 22727 new tokens (vocab now 174396)
Enhanced tokenizer saved to /home/hang/.hf_cache/qwen3_14b/tokenizer_plus
Vocab size: 174396


In [None]:
def examples_in(path):
    """Yield the examples stored in `path`, one at a time.

    The file may be a single JSON array or JSON-Lines (one JSON value per
    line). The format is chosen from the first non-whitespace character:
    '[' means a JSON array, anything else means JSON-Lines. Blank lines in
    JSON-Lines files are skipped, and an empty file yields nothing.

    Parameters
    ----------
    path : str or path-like
        UTF-8 encoded file to read.

    Yields
    ------
    object
        Each array element, or each parsed line.
    """
    with open(path, encoding="utf-8") as f:
        # Skip all leading whitespace. Looking at the first character only
        # would misread an array preceded by a newline or indentation.
        first = f.read(1)
        while first and first.isspace():
            first = f.read(1)
        f.seek(0)

        if first == '[':          # normal JSON array
            yield from json.load(f)
        else:                     # JSON-Lines
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)

# ───────────────────────────────
# 3. Scan all files and track the maximum prompt length in *tokens*
max_tokens = 0        # longest prompt seen so far, in tokens
cnt = 0               # number of files processed
total_items = 0       # number of examples processed
average_tokens = []   # token count of every example (name kept; it is a list, not a mean)
for cnt, path in enumerate(glob.glob("fine_tune_chunks/cell_data_part_*.json"), start=1):
    print(f"Processing file #{cnt}: {path}")
    print("Current max token length:", max_tokens)

    for ex in examples_in(path):
        total_items += 1
        # Prompt = instruction and input joined by one space, outer whitespace trimmed.
        prompt = f"{ex.get('instruction','')} {ex.get('input','')}".strip()
        token_ids = tokenizer(prompt, add_special_tokens=False).input_ids
        n_tokens = len(token_ids)

       ##### # Print the first 10 *tokens* (not just their IDs)
        # first_10_token_ids = token_ids[:200]
        # first_10_tokens = tokenizer.convert_ids_to_tokens(first_10_token_ids)
        # print("Prompt:", prompt)
        # print("First 10 tokens:", first_10_tokens)
        # # print 10 a line
        # for i in range(0, len(first_10_tokens), 10):
        #     print(" ".join(first_10_tokens[i:i+10]))

        # Update the running maximum and record this example's length.
        if n_tokens > max_tokens:
            max_tokens = n_tokens
        average_tokens.append(n_tokens)
        # break
    # break
    # Running statistics, reported after every file.
    print(f"Maximum prompt length (tokens): {max_tokens}")
    print(f"average tokens per item: {np.mean(average_tokens)}")

Processing file #1: fine_tune_chunks/cell_data_part_10.json
Current max token length: 0
Maximum prompt length (tokens): 9949
average tokens per item: 3063.382538770821
Processing file #2: fine_tune_chunks/cell_data_part_8.json
Current max token length: 9949
Maximum prompt length (tokens): 12031
average tokens per item: 3080.094485927628
Processing file #3: fine_tune_chunks/cell_data_part_9.json
Current max token length: 12031
Maximum prompt length (tokens): 12529
average tokens per item: 3081.0696917480377
Processing file #4: fine_tune_chunks/cell_data_part_3.json
Current max token length: 12529
Maximum prompt length (tokens): 12529
average tokens per item: 3080.2391959798997
Processing file #5: fine_tune_chunks/cell_data_part_4.json
Current max token length: 12529
Maximum prompt length (tokens): 12529
average tokens per item: 3084.909947162876
Processing file #6: fine_tune_chunks/cell_data_part_5.json
Current max token length: 12529
Maximum prompt length (tokens): 12529
average tokens

In [70]:
# How many examples exceed each token budget (average_tokens holds per-example token counts).
for token_limit in (6000, 7000, 8000, 9000, 10000):
    n_over = sum(1 for n_tokens in average_tokens if n_tokens > token_limit)
    print(f"more than {token_limit} tokens: {n_over}")
print(f"total items: {len(average_tokens)}")

more than 6000 tokens: 48
more than 7000 tokens: 13
more than 8000 tokens: 6
more than 9000 tokens: 3
more than 10000 tokens: 0
total items: 1742
