# Klasyfikator topologii domen białkowych (poziom C.A.T w CATH) na podstawie cech z InterPLM

In [None]:
# Install InterPLM (provides `load_sae_from_hf`, used below) and Biopython (FASTA parsing).
# `%pip` installs into the environment of the running kernel; `!pip` may target another interpreter.
# NOTE: run this cell once per kernel — `%cd` changes the working directory that every later
# relative path (e.g. cath_data/...) resolves against, and re-running it would nest a second clone.
!git clone https://github.com/ElanaPearl/interPLM.git
%cd interPLM
%pip install -e .
%pip install biopython

## Co chcemy osiągnąć?
<img src="https://raw.githubusercontent.com/DavidJones1102/ML_project/refs/heads/main/pipeline.svg">


## Pobieranie danych
### Zbiór danych: domeny białkowe CATH (CATH Protein Domains)


Domeny w bazie CATH są klasyfikowane według ścisłej, czteropoziomowej hierarchii (C → A → T → H):

| Poziom | Oznaczenie | Nazwa | Liczba Klas* |
| --- | --- | --- | --- |
| **C** | `1` | **Class** | ~4 |
| **A** | `1.10` | **Architecture** | ~40 |
| **T** | `1.10.8` | **Topology** | **~1,272**  |
| **H** | `1.10.8.10` | **Homology** | >6,000 |

\* _Liczby klas są przybliżone i zależą od wersji bazy CATH._

Model będzie trenowany na poziomie C.A.T (topologia), ponieważ wyższe poziomy hierarchii (C i C.A) są zbyt ogólne, a poziom C.A.T.H jest zbyt rozdrobniony. Etykietą jest więc ciąg postaci `1.10.8`.


In [3]:
%%bash
# Download the CATH domain labels (CLF domain list) and domain sequences (FASTA) into cath_data/.
# -nc skips files that already exist, so re-running the cell does not download them again.
# NOTE: "latest-release" moves over time; the label file fetched here reports CATH v4.4.0.
mkdir -p cath_data
cd cath_data

echo "Pobieranie etykiet..."
wget -q -nc ftp://orengoftp.biochem.ucl.ac.uk/cath/releases/latest-release/cath-classification-data/cath-domain-list.txt

echo "Pobieranie sekwencji..."
wget -q -nc ftp://orengoftp.biochem.ucl.ac.uk/cath/releases/latest-release/sequence-data/cath-domain-seqs.fa

echo "Pliki w katalogu:"
ls -lh

Pobieranie etykiet...
Pobieranie sekwencji...
Pliki w katalogu:
total 154M
-rw-r--r-- 1 root root  43M Jan 20 10:01 cath-domain-list.txt
-rw-r--r-- 1 root root 112M Jan 20 10:01 cath-domain-seqs.fa


In [5]:
# Peek at the CLF header comments and the first domain rows to confirm the column layout.
!head -n 20 cath_data/cath-domain-list.txt

#---------------------------------------------------------------------
# FILE NAME:    CathDomainList.v4.4.0
# FILE DATE:    16.12.2024
#
# CATH VERSION: v4.4.0
# VERSION DATE: 16.12.2024
#
# FILE FORMAT:  Cath List File (CLF) Format 2.0
#
# FILE DESCRIPTION:
# Contains all classified protein domains in CATH
# for class 1 (mainly alpha), class 2 (mainly beta),
# class 3 (alpha and beta) and class 4 (few secondary structures).
#
# See 'README.file_formats' for file format information
#---------------------------------------------------------------------
1oaiA00     1    10     8    10     1     1     1     1     1    59 1.000
1go5A00     1    10     8    10     1     1     1     1     2    69 999.000
3frhA01     1    10     8    10     2     1     1     1     1    58 1.200
3friA01     1    10     8    10     2     1     1     1     2    54 1.800


In [15]:
import pandas as pd
from Bio import SeqIO

LABEL_FILE = 'cath_data/cath-domain-list.txt'
SEQ_FILE = 'cath_data/cath-domain-seqs.fa'

# Column layout of the CATH List File (CLF 2.0) rows shown in the preview above.
# NOTE(review): s35/s60/s95/s100 look like sequence-identity cluster numbers. Overlapping or
# near-identical sequences occur in this data (e.g. 1oaiA00 vs 1go5A00, 3frhA01 vs 3friA01), so
# keeping a cluster id for a grouped train/test split would avoid leakage — confirm the column
# semantics in CATH's README.file_formats.
column_names = [
    'domain_id', 'class_C', 'arch_A', 'top_T', 'hom_H',
    's35', 's60', 's95', 's100', 's100_count', 'domain_len', 'resolution'
]


def _cath_domain_id(fasta_id):
    """Extract the bare domain id from a CATH FASTA record id.

    Example: 'cath|4_4_0|3avrA01/886-901_1246-1312_1380-1395' -> '3avrA01'
    (third '|'-separated field, with the '/segment-ranges' suffix removed).

    Returns:
        The domain id string, or None when the id has fewer than three '|' fields.
    """
    fields = fasta_id.split('|')
    if len(fields) < 3:
        return None
    return fields[2].split('/')[0]


# Labels: only the C, A and T numbers are needed; they are joined into the target, e.g. "1.10.8".
df_labels = (
    pd.read_csv(
        LABEL_FILE,
        sep=r'\s+',
        comment='#',
        header=None,
        names=column_names,
        usecols=['domain_id', 'class_C', 'arch_A', 'top_T'],
    )
    .assign(target_label=lambda d: (
        d['class_C'].astype(str) + "." + d['arch_A'].astype(str) + "." + d['top_T'].astype(str)
    ))
    .loc[:, ['domain_id', 'target_label']]
)


print(f"Wczytywanie sekwencji z {SEQ_FILE}...")
seq_data = []

with open(SEQ_FILE, "r") as handle:
    for record in SeqIO.parse(handle, "fasta"):
        clean_id = _cath_domain_id(record.id)
        if clean_id is None:
            print(f"Pominięto nietypowy nagłówek: {record.id}")
            continue
        seq_data.append({'domain_id': clean_id, 'sequence': str(record.seq)})

# Explicit columns keep the merge below valid even if no record could be parsed.
df_seqs = pd.DataFrame(seq_data, columns=['domain_id', 'sequence'])


print("Łączenie sekwencji z etykietami")
full_dataset = pd.merge(df_labels, df_seqs, on='domain_id', how='inner')

# The inner join silently drops domains that appear in only one of the two files; report how many.
n_labels_without_seq = int((~df_labels['domain_id'].isin(df_seqs['domain_id'])).sum())
n_seqs_without_label = int((~df_seqs['domain_id'].isin(df_labels['domain_id'])).sum())
print(f"Etykiety bez sekwencji: {n_labels_without_seq}, sekwencje bez etykiety: {n_seqs_without_label}")
assert len(full_dataset) > 0, "Brak wspólnych domain_id między plikiem etykiet a plikiem sekwencji."

print("-" * 50)
print(f"GOTOWY ZBIÓR DANYCH: {len(full_dataset)} próbek.")
print("-" * 50)
pd.set_option('display.max_colwidth', 50)
full_dataset.head()

Wczytywanie sekwencji z cath_data/cath-domain-seqs.fa...
Łączenie sekwencji z etykietami
--------------------------------------------------
GOTOWY ZBIÓR DANYCH: 601328 próbek.
--------------------------------------------------
  domain_id target_label                                           sequence
0   1oaiA00       1.10.8  PTLSPEQQEMLQAFSTQSGMNLEWSQKCLQDNNWDYTRSAQAFTHL...
1   1go5A00       1.10.8  PAPTPSSSPVPTLSPEQQEMLQAFSTQSGMNLEWSQKCLQDNNWDY...
2   3frhA01       1.10.8  YPMNINDALTSILASKKYRALCPDTVRRILTEEWGRHKSPKQTVEA...
3   3friA01       1.10.8  YPMNINDALTSILASKKYRALCPDTVRRILTEEWGRHKSPKQTVEA...
4   3b89A01       1.10.8  SLNINDALTSILASKKYRALCPDTVRRILTEEWGRHKSPKQTVEAA...


In [None]:
import torch
from transformers import AutoModel, AutoTokenizer
from interplm.sae.inference import load_sae_from_hf

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Używam urządzenia: {DEVICE}")

# InterPLM's short model name and the matching Hugging Face checkpoint (ESM-2, t6, 8M params).
MODEL_NAME = "esm2-8m"
HF_MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
# Used both as the SAE's layer and as the index into `outputs.hidden_states` during extraction.
# NOTE(review): with HF `output_hidden_states`, index 0 is presumably the embedding output, so 6
# would be the last transformer block of esm2_t6 — confirm this matches InterPLM's layer numbering.
LAYER_ID = 6

print("Ładowanie modelu ESM-2")
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
base_model = AutoModel.from_pretrained(HF_MODEL_NAME).to(DEVICE)
base_model.eval()
# Fail early (rather than with an IndexError mid-extraction) if LAYER_ID exceeds the model depth.
assert 0 <= LAYER_ID <= base_model.config.num_hidden_layers, (
    f"LAYER_ID={LAYER_ID} poza zakresem 0..{base_model.config.num_hidden_layers}"
)

print(f"Ładowanie SAE dla warstwy {LAYER_ID}...")
sae = load_sae_from_hf(plm_model=MODEL_NAME, plm_layer=LAYER_ID)
sae = sae.to(DEVICE)
# Trailing ';' suppresses the (long) module repr that `.eval()` would otherwise display.
sae.eval();

In [None]:
import numpy as np
from tqdm.auto import tqdm

# --- Data limit for quick experiments ---
# Number of proteins to embed; set to None (or 0) to process the whole dataset.
# NOTE(review): with ~1.2k topology labels, a 2k random sample leaves most classes with one or
# zero examples — consider a per-class / stratified sample before training a classifier.
SAMPLE_SIZE = 2000

if SAMPLE_SIZE:
    df_subset = full_dataset.sample(n=min(SAMPLE_SIZE, len(full_dataset)), random_state=42).copy()
else:
    df_subset = full_dataset.copy()

print(f"Przetwarzanie {len(df_subset)} próbek...")

def extract_sae_features(sequences, batch_size=32, max_length=512):
    """Embed protein sequences with ESM-2 and return one mean-pooled SAE feature vector per protein.

    For each batch, the sequences are tokenized and run through `base_model`. The hidden state
    at index `LAYER_ID` is encoded with the InterPLM `sae`. The per-token SAE activations are
    then averaged over residue tokens only: padding and the special tokens added by the
    tokenizer (start/end markers) are excluded, so the vector describes amino acids alone.

    Args:
        sequences: list of amino-acid strings.
        batch_size: number of sequences per forward pass.
        max_length: tokenizer truncation limit in tokens; longer proteins are cut (presumably
            keeping the N-terminal part, as with the tokenizer's default right-side truncation).

    Returns:
        np.ndarray with one row per input sequence, in input order, and one column per SAE feature.

    Raises:
        ValueError: raised by np.vstack when `sequences` is empty.
    """
    all_features = []

    with torch.no_grad():
        for start in tqdm(range(0, len(sequences), batch_size)):
            batch_seqs = sequences[start:start + batch_size]

            inputs = tokenizer(
                batch_seqs,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                return_special_tokens_mask=True,
            ).to(DEVICE)
            # Needed only for pooling; the model's forward() must not receive this key.
            special_tokens_mask = inputs.pop("special_tokens_mask")

            # Run ESM-2 and take the hidden state of the selected layer.
            outputs = base_model(**inputs, output_hidden_states=True)
            dense_acts = outputs.hidden_states[LAYER_ID]

            # Per-token sparse features from the InterPLM autoencoder.
            sae_acts = sae.encode(dense_acts)

            # 1 only for real residues: attention_mask removes padding and
            # special_tokens_mask removes the start/end tokens, which are not amino acids.
            residue_mask = (
                (inputs["attention_mask"] * (1 - special_tokens_mask))
                .unsqueeze(-1)
                .to(sae_acts.dtype)
            )
            feature_sums = (sae_acts * residue_mask).sum(dim=1)
            # Clamp avoids division by zero for a sequence with no residue tokens (e.g. "").
            residue_counts = residue_mask.sum(dim=1).clamp(min=1.0)
            all_features.append((feature_sums / residue_counts).cpu().numpy())

    return np.vstack(all_features)

# Features are produced in the row order of df_subset, so X[i] is paired with y[i].
X = extract_sae_features(
    sequences=df_subset['sequence'].tolist(),
    batch_size=32,
)
y = df_subset['target_label'].values

print(f"\nWymiary macierzy cech X: {X.shape}")
print(f"Liczba etykiet y: {len(y)}")