In [1]:
import pandas as pd
import numpy as np
import obonet, networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
import joblib, os

# Input paths (relative to the notebook's working directory)
PATH_TERMS = "../data/bronze/Train/train_terms.tsv"
PATH_OBO = "../data/bronze/Train/go-basic.obo"

# 1. Label propagation
# Parse the ontology file; the result is handed to networkx further down.
go_graph = obonet.read_obo(PATH_OBO)

# Passing `names=` without `header=` makes pandas treat EVERY line of the file
# as data, including a header line if the TSV has one.
df_terms = pd.read_csv(PATH_TERMS, sep="\t", names=['Protein_ID', 'term', 'ontology'])

# Guard against a header line having been ingested as an annotation record:
# GO identifiers start with "GO:", so a first row whose `term` does not is
# column captions rather than data and is dropped. A header-less file is left
# untouched.
if not df_terms.empty and not str(df_terms['term'].iloc[0]).startswith("GO:"):
    df_terms = df_terms.iloc[1:].reset_index(drop=True)

def get_ancestors(go_id, graph):
    """Return the set of nodes reachable from ``go_id`` in ``graph``.

    The lookup uses ``nx.descendants``, i.e. it follows edges in their stored
    direction. obonet is documented to orient edges from a term to its parent
    terms, which is why graph "descendants" are used here as ontology
    "ancestors" (assumes ``graph`` was built by ``obonet.read_obo`` -- confirm
    if the graph ever comes from elsewhere).

    NOTE(review): no edge-type filter is applied, so every relation present in
    the graph is traversed. If only some relations (e.g. ``is_a``/``part_of``)
    should propagate labels, the graph must be filtered first -- confirm.

    Args:
        go_id: Identifier of the starting node.
        graph: Graph (or any container supporting ``in``) to search.

    Returns:
        A set of reachable node identifiers, excluding ``go_id`` itself;
        an empty set when ``go_id`` is not in ``graph``.
    """
    # Unknown identifiers contribute nothing instead of raising inside networkx.
    if go_id not in graph:
        return set()
    return nx.descendants(graph, go_id)

# Directly annotated terms for every protein: {protein_id: [term, ...]}.
protein_to_terms = df_terms.groupby('Protein_ID')['term'].apply(list).to_dict()
expanded_data = []

# Ancestor sets depend only on the term, not on the protein, so each distinct
# term is resolved once and reused instead of re-walking the graph for every
# (protein, term) occurrence.
ancestor_cache = {}

print("üîÑ Propagacja etykiet...")
for prot_id, terms in protein_to_terms.items():
    extended_set = set(terms)
    for t in terms:
        if t not in ancestor_cache:
            ancestor_cache[t] = get_ancestors(t, go_graph)
        # The cached set is only read here, never mutated.
        extended_set.update(ancestor_cache[t])
    # Emit rows in a fixed order. Iterating the set directly yields an order
    # that changes between interpreter sessions (string hash randomisation),
    # which makes the row order of df_silver -- and any tie-breaking done on
    # it later -- irreproducible. key=str keeps the sort safe even if a
    # non-string value (e.g. NaN from an empty field) slipped into the terms.
    expanded_data.extend(
        {'Protein_ID': prot_id, 'term': final_term}
        for final_term in sorted(extended_set, key=str)
    )

# Long-format table: one row per (protein, term) pair after propagation.
df_silver = pd.DataFrame(expanded_data)

# 2. Select the 1500 most frequent terms and binarize
term_counts = df_silver['term'].value_counts()
top_terms = list(term_counts.nlargest(1500).index)

# Keep only the (protein, term) rows whose term made the cut, then gather the
# surviving terms per protein. A protein with no surviving term has no rows
# left and therefore does not appear in `protein_labels`.
is_top_term = df_silver['term'].isin(top_terms)
df_filtered = df_silver.loc[is_top_term]
protein_labels = df_filtered.groupby('Protein_ID')['term'].apply(list)

# `classes=top_terms` fixes which terms become columns and in what order;
# row i of `y` belongs to protein_labels.index[i].
mlb = MultiLabelBinarizer(classes=top_terms)
y = mlb.fit_transform(protein_labels)

# Dataset statistics, reported before the split.
# Row sums of the indicator matrix = number of selected terms per protein;
# computed once and reused for mean / min / max.
labels_per_protein = y.sum(axis=1)

summary_lines = [
    "\nüìä Statystyki datasetu:",
    f"  Liczba bia≈Çek: {len(protein_labels):,}",
    f"  Liczba GO terms: {len(top_terms)}",
    f"  ≈örednia liczba labelek na bia≈Çko: {labels_per_protein.mean():.2f}",
    f"  Min labelek: {labels_per_protein.min()}",
    f"  Max labelek: {labels_per_protein.max()}",
]
print("\n".join(summary_lines))

# Train/validation split (85/15).
# `train_test_split` is imported with the other imports at the top of the cell.
protein_ids = protein_labels.index.values

# The label matrix was built from `protein_labels`, so it must have exactly
# one row per protein ID; everything below relies on that alignment.
assert y.shape[0] == len(protein_ids), "label matrix and protein IDs are misaligned"

# Row POSITIONS are split rather than the arrays themselves, so the label
# matrix and the ID vector are sliced with the very same indices.
train_idx, val_idx = train_test_split(
    range(len(protein_ids)),
    test_size=0.15,
    random_state=42,  # fixed seed: the same partition on every run
    shuffle=True
)

y_train = y[train_idx]
y_val = y[val_idx]
train_ids = protein_ids[train_idx]
val_ids = protein_ids[val_idx]

# The two parts must cover every protein exactly once.
assert len(train_ids) + len(val_ids) == len(protein_ids)

print(f"\n‚úÇÔ∏è Split:")
print(f"  Train: {len(train_ids):,} bia≈Çek")
print(f"  Val: {len(val_ids):,} bia≈Çek")

# 3. Save
os.makedirs("../data/gold", exist_ok=True)
os.makedirs("../models", exist_ok=True)

# Train and validation parts are written to separate files.
np.save("../data/gold/y_train_labels.npy", y_train)
np.save("../data/gold/y_val_labels.npy", y_val)

# The ID vectors are taken from a pandas index and are typically object-dtype.
# np.save stores object arrays via pickle, and np.load then rejects the file
# unless the reader opts in with allow_pickle=True. Converting to a
# fixed-width unicode array writes a plain .npy file that loads with default
# settings (and still loads for readers that already pass allow_pickle=True).
np.save("../data/gold/train_protein_ids.npy", np.asarray(train_ids, dtype=str))
np.save("../data/gold/val_protein_ids.npy", np.asarray(val_ids, dtype=str))

# Term list that defines the column order of the saved label matrices.
joblib.dump(top_terms, "../models/top_terms_1500.pkl")

print(f"\n‚úÖ Zapisano:")
print(f"  y_train: {y_train.shape}")
print(f"  y_val: {y_val.shape}")
print(f"  top_terms: {len(top_terms)} termin√≥w")

üîÑ Propagacja etykiet...

üìä Statystyki datasetu:
  Liczba bia≈Çek: 82,404
  Liczba GO terms: 1500
  ≈örednia liczba labelek na bia≈Çko: 38.91
  Min labelek: 1
  Max labelek: 530

‚úÇÔ∏è Split:
  Train: 70,043 bia≈Çek
  Val: 12,361 bia≈Çek

‚úÖ Zapisano:
  y_train: (70043, 1500)
  y_val: (12361, 1500)
  top_terms: 1500 termin√≥w
