<a href="https://colab.research.google.com/github/ManVien/CBS_Minihack_F23/blob/main/Test_CBS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
!pip install biopython
from Bio import SeqIO
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Function to extract sequences from fasta file
def extract_sequences(fasta_file):
    """Return every sequence in a FASTA file as a plain string.

    Args:
        fasta_file: Path or handle passed straight to ``SeqIO.parse``.

    Returns:
        A list with one ``str`` per FASTA record, in file order.
    """
    return [str(entry.seq) for entry in SeqIO.parse(fasta_file, "fasta")]

# Function to generate k-mers from DNA sequences
def generate_kmers(sequences, k):
    """Build a k-mer count matrix for the given sequences.

    A new ``CountVectorizer`` is created and fitted on every call, so the
    k-mer vocabulary (and therefore the column layout of the result) is
    learned from ``sequences`` alone and the fitted vectorizer is discarded.

    Args:
        sequences: Iterable of sequence strings.
        k: Length of the character n-grams to count.

    Returns:
        The matrix produced by ``CountVectorizer.fit_transform``.
    """
    kmer_counter = CountVectorizer(ngram_range=(k, k), analyzer='char')
    return kmer_counter.fit_transform(sequences)

# Function to get top predictions indices
def get_top_predictions_indices(predictions, top_n):
    """Return the indices of the ``top_n`` highest-scoring predictions.

    The ranking is computed from the ``predictions`` argument itself rather
    than from module-level globals, so the function works for any model and
    any data set.

    Args:
        predictions: Either a 1-D array-like of scores (one per sample), or
            a 2-D array-like such as the output of ``predict_proba``; in the
            2-D case column 1 (the positive class in binary classification)
            is used as the score.
        top_n: Maximum number of indices to return.

    Returns:
        A NumPy array of sample indices ordered from highest to lowest
        score, containing at most ``top_n`` entries.
    """
    prediction_scores = np.asarray(predictions)
    if prediction_scores.ndim == 2:
        # predict_proba-style input: keep the positive-class probability.
        prediction_scores = prediction_scores[:, 1]
    # argsort is ascending, so reverse before taking the leading top_n.
    return np.argsort(prediction_scores)[::-1][:top_n]

# Function to write top predictions to a file
def write_top_predictions_to_file(indices, sequences_file, output_file):
    """Write the identifiers of the selected FASTA records, one per line.

    Args:
        indices: Positions (within ``sequences_file``) of the records to keep;
            the output preserves the order of ``indices``.
        sequences_file: Path of the FASTA file the indices refer to.
        output_file: Path of the text file to create or overwrite.
    """
    # Materialise all records first so they can be addressed by position.
    with open(sequences_file, 'r') as fasta_handle:
        all_records = list(SeqIO.parse(fasta_handle, 'fasta'))

    chosen_records = [all_records[position] for position in indices]

    with open(output_file, 'w') as out_handle:
        out_handle.writelines(f"{entry.id}\n" for entry in chosen_records)

# Load and preprocess data
accessible_sequences = extract_sequences('accessible.fasta')
non_accessible_sequences = extract_sequences('notaccessible.fasta')
test_sequences = extract_sequences('test.fasta')

# Take a smaller subset of the data for training (e.g., 10% of the total data).
# NOTE: the subset is the leading slice of each file, not a random sample.
train_fraction = 0.1

# Reduce the size of accessible and non-accessible sequences
num_accessible_samples = int(len(accessible_sequences) * train_fraction)
num_non_accessible_samples = int(len(non_accessible_sequences) * train_fraction)

accessible_sequences = accessible_sequences[:num_accessible_samples]
non_accessible_sequences = non_accessible_sequences[:num_non_accessible_samples]

# Filter sequences overlapping 'N' characters or repetitive elements
# TODO: implement filtering logic based on 'N' characters and repetitive elements

# Generate labels for accessible (1) and non-accessible (0) sequences
labels_accessible = [1] * len(accessible_sequences)
labels_non_accessible = [0] * len(non_accessible_sequences)

# Combine sequences and labels
all_sequences = accessible_sequences + non_accessible_sequences
all_labels = labels_accessible + labels_non_accessible

# Generate k-mers for training data
k = 5  # Adjust k as needed
# Fit the vectorizer once and keep it: the test sequences must later be
# projected onto exactly the same k-mer columns the classifier is trained on.
kmer_vectorizer = CountVectorizer(analyzer='char', ngram_range=(k, k))
X = kmer_vectorizer.fit_transform(all_sequences)

# Split data into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, all_labels, test_size=0.2, random_state=42)

# Train a RandomForestClassifier (or any chosen classifier)
clf = RandomForestClassifier(n_estimators=100, random_state=42)  # Modify parameters as needed
clf.fit(X_train, y_train)

# Evaluate model on validation set
val_predictions = clf.predict(X_val)

# Calculate evaluation metrics
accuracy = accuracy_score(y_val, val_predictions)
precision = precision_score(y_val, val_predictions)
recall = recall_score(y_val, val_predictions)
f1 = f1_score(y_val, val_predictions)

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1-score: {f1}")

# Use trained model to predict on test data (test.fasta).
# transform() (not fit_transform()) reuses the training vocabulary; refitting
# on the test set would learn a separate vocabulary whose columns need not
# line up with the features the classifier was trained on.
test_kmers = kmer_vectorizer.transform(test_sequences)
predictions = clf.predict(test_kmers)

# Get top 10,000 predictions of accessible sites (disabled).
# Class probabilities are passed so the ranking uses scores, not hard labels.
#top_predictions_indices = get_top_predictions_indices(clf.predict_proba(test_kmers), 10000)

# Write sequence identifiers to output file (disabled)
#write_top_predictions_to_file(top_predictions_indices, 'test.fasta', 'test_top_predictions.txt')


Accuracy: 0.9143998478219517
Precision: 0.676923076923077
Recall: 0.09302325581395349
F1-score: 0.16356877323420077
