In [None]:
import os 
import pandas as pd 

# Gather every CSV from data/<subdir>/: subdirectories whose name ends in
# "high.alt" supply label-1 events, those ending in "low.alt" label-0 events.
# Reads are accumulated and concatenated; assigning each read directly to
# high_df/low_df kept only whichever CSV os.listdir happened to return last.
data_dir = 'data'
high_frames = []
low_frames = []
for subdir in sorted(os.listdir(data_dir)):  # sorted -> deterministic row order
    subdir_path = os.path.join(data_dir, subdir)
    if not os.path.isdir(subdir_path):
        continue  # stray files directly under data/ belong to no class folder
    for file_name in sorted(os.listdir(subdir_path)):
        if not file_name.endswith('.csv'):
            continue
        csv_path = os.path.join(subdir_path, file_name)
        if subdir.endswith('high.alt'):
            high_frames.append(pd.read_csv(csv_path))
        elif subdir.endswith('low.alt'):
            low_frames.append(pd.read_csv(csv_path))

# pd.concat([]) raises, so fall back to an empty frame and let the assert
# below give a clearer message.
high_df = pd.concat(high_frames, ignore_index=True) if high_frames else pd.DataFrame()
low_df = pd.concat(low_frames, ignore_index=True) if low_frames else pd.DataFrame()
assert not high_df.empty and not low_df.empty, f"no high/low CSVs found under {data_dir!r}"

high_df['label'] = 1
low_df['label'] = 0

df = pd.concat([high_df, low_df], ignore_index=True)
df

In [None]:
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from openai import OpenAI
# Set your OpenAI API key
# client = OpenAI()

def format_example(row):
    """Render one event's splice-site strengths as prompt text.

    Args:
        row: A record exposing ``A2_5SS`` and ``A4_5SS`` attributes (in this
            notebook, a DataFrame row from ``iterrows``).

    Returns:
        Two clauses, one per site, ending in ". " so the caller can append
        the answer or question directly.
    """
    # The second clause used to say "A2_5SS" while interpolating A4_5SS,
    # telling the model both values came from the same site.
    # NOTE(review): the "5' end of the alternatively spliced exon" location
    # was copied from the A2 clause — confirm it is correct for A4_5SS.
    return (f"In the AF type of alternative splicing, the splicing site strength of the A2_5SS site at the 5' end of the alternatively spliced exon is {row.A2_5SS}, "
            f"In the AF type of alternative splicing, the splicing site strength of the A4_5SS site at the 5' end of the alternatively spliced exon is {row.A4_5SS}. ")

# Function to create prompt with examples
def create_prompt(examples, new_event):
    """Build the few-shot High/Low classification prompt for the LLM.

    Args:
        examples: DataFrame of labelled events. Each row is rendered with
            ``format_example`` and must carry a ``label`` column
            (1 -> "High", anything else -> "Low").
        new_event: Already-formatted text of the event to classify (in this
            notebook, the output of ``format_example``).

    Returns:
        The prompt string: fixed instructions, then one line per example with
        its answer, then the question about ``new_event``.
    """
    # "alternative splicing events" replaces "variable shear events", a
    # mistranslation that gave the model a meaningless definition of PSI.
    prompt = """You are an expert in genomics and epigenetics, specializing in splicing site strength patterns and their effects on alternative splicing events.
Context:
- The data represents splicing site strengths around splice sites in genomic regions.
- Splicing site strength refers to the effectiveness or probability that a particular splicing site will be recognised and used during RNA splicing.
- Each numerical feature represents splicing site strength values in a specific splice site.
Task:
- Classify each event as 'high' or 'low' based on the Splicing site strength.
- 'High' means that the PSI values for all alternative splicing events in the sample are in the highest 25% range, and 'Low' means that the PSI values for all alternative splicing events in the sample are in the lowest 25% range.
- Consider the relative differences between values, as well as the overall magnitude of Splicing site strength.

Here are a few Examples:
"""
    for _, example in examples.iterrows():
        prompt += f"{format_example(example)}the event is: {'High' if example.label == 1 else 'Low'}\n\n"

    prompt += f"Now reply in one word whether:\n{new_event} is high or low, then provide your reasoning in detail:"
    return prompt

# Function to get LLM prediction
def get_llm_prediction(prompt, client=None, model="gpt-4o-mini", max_tokens=10):
    """Send ``prompt`` as a single user message and return the reply text.

    Args:
        prompt: Full prompt string.
        client: An ``OpenAI`` client. When omitted, a module-level ``client``
            is reused if one exists; otherwise one is created with ``OpenAI()``
            (which reads OPENAI_API_KEY from the environment) and stored as
            ``client`` for subsequent calls.
        model: Chat-completions model name.
        max_tokens: Reply length cap. The default of 10 fits the one-word
            label but truncates any reasoning that follows it.

    Returns:
        The reply with surrounding whitespace removed, or "" when the API
        returns no message content.
    """
    if client is None:
        # The `client = OpenAI()` line in the imports cell is commented out,
        # so a fresh kernel would otherwise hit NameError here.
        if 'client' not in globals():
            globals()['client'] = OpenAI()
        client = globals()['client']
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        n=1,
        stop=None,
        temperature=0.0,  # deterministic-as-possible decoding for evaluation
    )
    content = response.choices[0].message.content
    return (content or "").strip()

# Few-shot evaluation: label a small stratified hold-out set with the LLM.
N_SHOTS_PER_CLASS = 4  # labelled examples per class included in every prompt
SEED = 42


def parse_label(response):
    """Map a raw LLM reply to 1 ('high'), 0 ('low') or None (unrecognised).

    The prompt asks for one word followed by reasoning, and the reply is
    length-capped, so replies look like "High", "High." or
    "High\\n\\nReasoning: ...". Only the first word is inspected; an exact
    comparison of the whole reply scored all of those as low.
    """
    words = response.split()
    if not words:
        return None
    first = words[0].strip(".,:;!*'\"").lower()
    if first == 'high':
        return 1
    if first == 'low':
        return 0
    return None


# Split the data (stratified so both classes appear in the small test set)
train, test = train_test_split(df, test_size=0.01, stratify=df['label'], random_state=SEED)

# Draw N_SHOTS_PER_CLASS labelled examples from each class
high_examples = train[train['label'] == 1].sample(N_SHOTS_PER_CLASS, random_state=SEED)
low_examples = train[train['label'] == 0].sample(N_SHOTS_PER_CLASS, random_state=SEED)
examples = pd.concat([high_examples, low_examples])

# Make predictions on test set
predictions = []
n_unparsed = 0
for _, row in test.iterrows():
    prompt = create_prompt(examples, format_example(row))
    label = parse_label(get_llm_prediction(prompt))
    if label is None:
        # Keep the original fallback (score as low) but count it, so a
        # systematic formatting problem is visible instead of silent.
        n_unparsed += 1
        label = 0
    predictions.append(label)

if n_unparsed:
    print(f"Warning: {n_unparsed}/{len(test)} replies were neither 'high' nor 'low'; counted as low.")

# Calculate accuracy
accuracy = accuracy_score(test['label'], predictions)
print(f"Accuracy: {accuracy}")

# Print classification report
print(classification_report(test['label'], predictions))