# In [None]:
import os
import numpy as np
from Bio import SeqIO
from collections import Counter
from itertools import product
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict, GridSearchCV
from sklearn.metrics import classification_report, matthews_corrcoef, accuracy_score, roc_auc_score, average_precision_score, f1_score, precision_score, recall_score, confusion_matrix, roc_curve, precision_recall_curve
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfTransformer
from imblearn.over_sampling import RandomOverSampler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import joblib
from sklearn.svm import SVC
from sklearn.preprocessing import label_binarize
from torch.nn import functional as F
import pandas as pd
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Pick the compute device once for the whole notebook: CUDA when a GPU is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"دستگاه: {device}")

# Data augmentation: random point mutations.
def augment_sequence(seq, mutation_rate=0.03):
    """Return a randomly point-mutated copy of *seq*.

    Each position is independently replaced, with probability *mutation_rate*,
    by a letter drawn uniformly from the ``ACGT`` letters that differ from the
    current one (a non-ACGT character such as ``N`` may become any of the four).
    Uses NumPy's global RNG, so the result depends on ``np.random`` state.
    """
    alphabet = ('A', 'C', 'G', 'T')
    mutated = []
    for base in seq:
        if np.random.rand() < mutation_rate:
            alternatives = [letter for letter in alphabet if letter != base]
            base = np.random.choice(alternatives)
        mutated.append(base)
    return ''.join(mutated)

# Mixup over the four feature groups.
def mixup_data(tfidf, tfidf_perm, distance, distance_perm, ssf, ssf_perm, pcp, pcp_perm, labels, alpha=0.4):
    """Blend every feature group with a permuted copy of itself (mixup).

    A single coefficient ``lam ~ Beta(alpha, alpha)`` is drawn and each group is
    mixed as ``lam * x + (1 - lam) * x_perm``.

    The labels are returned unchanged (hard labels, not mixed). To keep them
    consistent with the features, ``lam`` is folded to ``max(lam, 1 - lam)`` so
    the mixed features are always dominated by the un-permuted sample, whose
    label is the one returned. Without this fold, Beta(alpha, alpha) is
    symmetric, so half of all draws would label a sample with the class of its
    minority component.

    Args:
        tfidf, distance, ssf, pcp: feature batches (tensors or arrays).
        *_perm: the same batches in permuted row order.
        labels: labels of the un-permuted batch.
        alpha: Beta distribution concentration.

    Returns:
        ``(tfidf_mixed, distance_mixed, ssf_mixed, pcp_mixed, labels)``.
    """
    lam = np.random.beta(alpha, alpha)
    # Python float keeps the dtype of torch tensors / arrays unchanged.
    lam = float(max(lam, 1.0 - lam))

    def _blend(original, permuted):
        return lam * original + (1 - lam) * permuted

    return (
        _blend(tfidf, tfidf_perm),
        _blend(distance, distance_perm),
        _blend(ssf, ssf_perm),
        _blend(pcp, pcp_perm),
        labels,
    )

# Extract k-mer term-frequency (TF) features from labelled FASTA files.
def _kmer_frequency_vector(seq, k, kmer_dict):
    """Return relative frequencies of the k-mers in *kmer_dict* within *seq*.

    Entry ``kmer_dict[kmer]`` holds ``count(kmer) / n_windows`` with
    ``n_windows = len(seq) - k + 1``. Windows absent from *kmer_dict* (e.g.
    containing ``N`` or lowercase letters) are skipped but still count toward
    ``n_windows``. Sequences shorter than *k* give an all-zero vector.
    """
    feature_vector = np.zeros(len(kmer_dict))
    n_windows = len(seq) - k + 1
    if n_windows <= 0:
        return feature_vector
    kmer_count = Counter(seq[i:i + k] for i in range(n_windows))
    for kmer, count in kmer_count.items():
        index = kmer_dict.get(kmer)
        if index is not None:
            feature_vector[index] = count / n_windows
    return feature_vector


def extract_tf_features(fasta_files, k=5, augment=True):
    """Build k-mer frequency features for every record in *fasta_files*.

    The class label of a file comes from its base name split on ``_``: the
    first part, or the first two parts joined when the name contains a
    ``reticulum`` or ``region`` part (``Endoplasmic_reticulum``,
    ``Extracellular_region``), mapped through ``location_to_label``.

    Args:
        fasta_files: paths to FASTA files named like ``<Location>_....fasta``.
        k: k-mer length; the feature dimension is ``4 ** k`` (ACGT alphabet).
        augment: if True, each record is followed by one copy passed through
            ``augment_sequence`` (random point mutations), doubling the rows.

    Returns:
        ``(features, labels, sequences)``: an array with ``4 ** k`` columns,
        a label array, and the list of sequences (augmented copies included),
        all in the same row order.

    Raises:
        KeyError: if a file name does not map to a known location.

    NOTE(review): augmentation happens here, before any train/test split, so
    the augmented twin of a test sequence can end up in the training data.
    """
    kmer_dict = {''.join(n): i for i, n in enumerate(product('ACGT', repeat=k))}
    location_to_label = {
        'Cytoplasm': 0,
        'Endoplasmic_reticulum': 1,
        'Extracellular_region': 2,
        'Mitochondria': 3,
        'Nucleus': 4
    }
    features = []
    labels = []
    sequences = []

    for fasta_file in tqdm(fasta_files, desc="استخراج ویژگی‌های TF"):
        parts = os.path.basename(fasta_file).split('_')
        # Two-word location names span the first two '_'-separated parts.
        if 'reticulum' in parts or 'region' in parts:
            location = parts[0] + '_' + parts[1]
        else:
            location = parts[0]
        label = location_to_label[location]

        for record in SeqIO.parse(fasta_file, "fasta"):
            seq = str(record.seq)
            variants = [seq]
            if augment:
                variants.append(augment_sequence(seq))
            for variant in variants:
                features.append(_kmer_frequency_vector(variant, k, kmer_dict))
                labels.append(label)
                sequences.append(variant)

    return np.array(features), np.array(labels), sequences

# Distance-based subsequence profiles (gapped nucleotide-pair frequencies).
def extract_distance_based_features(sequences, max_k=12):
    """Gapped dinucleotide profiles for each sequence.

    For every gap ``k`` in ``0..max_k`` and every ordered pair ``(X, Y)`` over
    ``ACGT`` (16 pairs, in ``itertools.product`` order), the feature is the
    number of positions ``j`` with ``seq[j] == X`` and ``seq[j + k + 1] == Y``,
    divided by the number of positions available at that gap,
    ``len(seq) - k - 1``. If a sequence is too short for gap ``k``, its 16
    entries for that gap are 0.

    Every valid position is counted, up to and including
    ``j = len(seq) - k - 2``, which is the last index for which ``j + k + 1`` is
    in range. The counts therefore agree with the ``len(seq) - k - 1``
    denominator.

    Returns:
        Array of shape ``(len(sequences), (max_k + 1) * 16)``; column
        ``k * 16 + pair_index``.
    """
    patterns = [''.join(p) for p in product('ACGT', repeat=2)]
    n_patterns = len(patterns)
    features = []

    for seq in tqdm(sequences, desc="استخراج ویژگی‌های Distance-based"):
        feature_vector = np.zeros((max_k + 1) * n_patterns)
        for k in range(max_k + 1):
            n_positions = len(seq) - k - 1
            if n_positions <= 0:
                continue  # Too short for this gap: entries stay 0.
            # Pairs (seq[j], seq[j + k + 1]) for j = 0 .. n_positions - 1.
            pair_counts = Counter(zip(seq[:n_positions], seq[k + 1:]))
            offset = k * n_patterns
            for i, (first, second) in enumerate(patterns):
                feature_vector[offset + i] = pair_counts[(first, second)] / n_positions
        features.append(feature_vector)

    return np.array(features)

# Simple sequence-derived "secondary-structure" descriptors.
def extract_ssf_features(sequences):
    """Compute six crude, sequence-only structure descriptors per sequence.

    No folding is performed; these are plain string statistics, each divided
    by the sequence length:

    0. GC content
    1. adjacent ``AU``/``UA`` dinucleotides
    2. adjacent ``GC``/``CG`` dinucleotides
    3. occurrences of ``AAA``/``UUU``/``AAU``/``UUA``
    4. occurrences of ``AUGCA``/``UACGU``
    5. a heuristic score ``-(1.0 * gc_pairs + 0.5 * au_pairs)``

    Each sequence is upper-cased and ``T`` is mapped to ``U`` first. The rest
    of this pipeline works in the ACGT alphabet, and without the mapping every
    U-based count above would be identically zero for such input.
    Empty sequences yield an all-zero vector.

    Returns:
        Array of shape ``(len(sequences), 6)``.
    """
    au_dimers = {'AU', 'UA'}
    gc_dimers = {'GC', 'CG'}
    loop_trimers = {'AAA', 'UUU', 'AAU', 'UUA'}
    hairpin_pentamers = {'AUGCA', 'UACGU'}
    features = []

    for raw_seq in tqdm(sequences, desc="استخراج ویژگی‌های SSF"):
        seq = raw_seq.upper().replace('T', 'U')
        seq_len = len(seq)
        if seq_len == 0:
            features.append(np.zeros(6))
            continue

        dimers = [seq[i:i + 2] for i in range(seq_len - 1)]
        au_pairs = sum(d in au_dimers for d in dimers)
        gc_pairs = sum(d in gc_dimers for d in dimers)
        loop_count = sum(seq[i:i + 3] in loop_trimers for i in range(seq_len - 2))
        hairpin_count = sum(seq[i:i + 5] in hairpin_pentamers for i in range(seq_len - 4))
        gc_content = (seq.count('G') + seq.count('C')) / seq_len
        free_energy = -1.0 * gc_pairs - 0.5 * au_pairs

        features.append(np.array([
            gc_content,
            au_pairs / seq_len,
            gc_pairs / seq_len,
            loop_count / seq_len,
            hairpin_count / seq_len,
            free_energy / seq_len,
        ]))

    return np.array(features)

# Per-nucleotide physicochemical property summaries.
def extract_pcp_features(sequences, max_length=1000):
    """Summarise per-base physicochemical properties of each sequence.

    Each base (upper-cased) is mapped to a 3-value property vector from the
    table below; unknown characters map to zeros. Only the first *max_length*
    bases are used (``max_length=None`` uses the whole sequence). The feature
    vector is ``[mean(p0), mean(p1), mean(p2), std(p0), std(p1), std(p2)]``.

    Statistics are computed over the real bases only, with no zero padding
    up to *max_length*. Padding would pull the mean and std of every shorter
    sequence towards zero in proportion to its length. Empty sequences yield
    an all-zero vector.

    Returns:
        Array of shape ``(len(sequences), 6)``.
    """
    properties = {
        'A': [0.62, 0.3, -0.5],
        'C': [0.29, 0.4, -0.2],
        'G': [0.48, 0.5, -0.3],
        'T': [0.58, 0.2, -0.1],
        'U': [0.58, 0.2, -0.1],
        'N': [0.0, 0.0, 0.0]
    }
    unknown = [0.0, 0.0, 0.0]

    features = []
    for seq in tqdm(sequences, desc="استخراج ویژگی‌های PCP"):
        window = seq[:max_length].upper()
        if not window:
            features.append(np.zeros(6))
            continue
        prop_matrix = np.array([properties.get(base, unknown) for base in window], dtype=float)
        features.append(np.concatenate([prop_matrix.mean(axis=0), prop_matrix.std(axis=0)]))

    return np.array(features)

# One feature branch (conv + residual + attention) of the hybrid model.
class FeatureBranch(nn.Module):
    """Map a flat feature vector to ``output_dim`` values.

    The input is reshaped to a length-1 "sequence" with ``input_dim``
    channels. Two parallel Conv1d layers (kernel 3 and 5, 64 channels each)
    are applied and concatenated to 128 channels. The result goes through
    BatchNorm and dropout, gets a linear projection of the input added as a
    residual, passes through 8-head self-attention and LayerNorm, and is
    finally projected to ``output_dim``.

    Because the sequence length is 1, the padded convolutions only ever use
    their centre tap, and the attention has a single position (its weights
    are trivially 1).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Layer creation order (and attribute names) fixes parameter init order and state_dict keys.
        self.conv1d_3 = nn.Conv1d(in_channels=input_dim, out_channels=64, kernel_size=3, padding=1)
        self.conv1d_5 = nn.Conv1d(in_channels=input_dim, out_channels=64, kernel_size=5, padding=2)
        self.bn = nn.BatchNorm1d(128)
        self.dropout1 = nn.Dropout(0.5)
        self.mha = nn.MultiheadAttention(embed_dim=128, num_heads=8)
        self.norm = nn.LayerNorm(128)
        self.linear = nn.Linear(128, output_dim)
        # Projects the raw input to 128 dims so it can be added as a residual.
        self.residual_fc = nn.Linear(input_dim, 128)

    def forward(self, x):
        # Assumes x is (batch, input_dim); shapes below follow from that.
        shortcut = self.residual_fc(x).unsqueeze(2)            # (B, 128, 1)
        as_sequence = x.unsqueeze(2)                            # (B, input_dim, 1)
        conv_outputs = [self.conv1d_3(as_sequence), self.conv1d_5(as_sequence)]
        hidden = self.dropout1(self.bn(torch.cat(conv_outputs, dim=1))) + shortcut
        # nn.MultiheadAttention (batch_first=False) expects (seq, batch, embed).
        hidden = hidden.permute(2, 0, 1)                        # (1, B, 128)
        attended, _ = self.mha(hidden, hidden, hidden)
        hidden = self.norm(attended).permute(1, 0, 2).squeeze(1)  # (B, 128)
        return self.linear(hidden)

# Full hybrid classifier: four feature branches fused by an MLP head.
class RNALocateModel(nn.Module):
    """Four ``FeatureBranch`` towers (TF-IDF, distance, SSF, PCP) plus a classifier head.

    Each branch yields ``linear_dim`` values. The four outputs are
    concatenated and passed through Linear(4 * linear_dim, 256) -> ReLU ->
    Dropout(0.5) -> Linear(256, num_classes).

    ``forward`` returns ``(logits, hidden)``. ``hidden`` is the 256-d
    activation after ReLU and dropout (dropout is inactive in eval mode); the
    notebook reuses it as an embedding. Note the default ``tfidf_dim=512``.
    The notebook constructs the model with ``tfidf_dim=256`` to match its PCA
    output.
    """

    def __init__(self, tfidf_dim=512, distance_dim=208, ssf_dim=6, pcp_dim=6, linear_dim=256, num_classes=5):
        super().__init__()
        # Creation order (and names) fixes parameter init order and state_dict keys.
        self.tfidf_branch = FeatureBranch(tfidf_dim, linear_dim)
        self.distance_branch = FeatureBranch(distance_dim, linear_dim)
        self.ssf_branch = FeatureBranch(ssf_dim, linear_dim)
        self.pcp_branch = FeatureBranch(pcp_dim, linear_dim)
        self.fc = nn.Linear(linear_dim * 4, 256)
        self.dropout = nn.Dropout(0.5)
        self.output = nn.Linear(256, num_classes)

    def forward(self, tfidf, distance, ssf, pcp):
        branch_outputs = [
            self.tfidf_branch(tfidf),
            self.distance_branch(distance),
            self.ssf_branch(ssf),
            self.pcp_branch(pcp),
        ]
        fused = torch.cat(branch_outputs, dim=1)
        hidden = self.dropout(F.relu(self.fc(fused)))
        return self.output(hidden), hidden

# Conditional generator for the WGAN-GP.
class Generator(nn.Module):
    """Map ``(noise, one-hot label)`` to a synthetic feature vector.

    Three linear layers with LeakyReLU(0.2) between them, followed by
    ``tanh``, so every output value lies in [-1, 1].
    """

    def __init__(self, noise_dim=128, label_dim=5, output_dim=256+208+6+6):
        super().__init__()
        layers = [
            nn.Linear(noise_dim + label_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, output_dim),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        conditioned = torch.cat((noise, labels), dim=1)
        return self.model(conditioned)

# Conditional critic for the WGAN-GP.
class Discriminator(nn.Module):
    """Score ``(features, one-hot label)`` pairs with one unbounded real value each.

    Three linear layers with LeakyReLU(0.2) between them and no final
    activation; the output shape is ``(batch, 1)``.
    """

    def __init__(self, input_dim=256+208+6+6, label_dim=5):
        super().__init__()
        layers = [
            nn.Linear(input_dim + label_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, features, labels):
        joint = torch.cat((features, labels), dim=1)
        return self.model(joint)

# WGAN-GP gradient penalty.
def compute_gradient_penalty(discriminator, real_samples, fake_samples, labels, device):
    """Return the WGAN-GP gradient penalty for one batch.

    One mixing weight is drawn per sample. The critic is evaluated at the
    resulting random points between each real and fake sample, and the
    function returns the batch mean of ``(||grad_x D(x_hat)||_2 - 1) ** 2``.
    The graph is kept (``create_graph=True``) so the penalty can be
    back-propagated into the critic's parameters.
    """
    eps = torch.rand(real_samples.size(0), 1).to(device).expand_as(real_samples)
    x_hat = (eps * real_samples + (1 - eps) * fake_samples).requires_grad_(True)
    critic_scores = discriminator(x_hat, labels)
    (grads,) = torch.autograd.grad(
        outputs=critic_scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(critic_scores).to(device),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

# Inverse-frequency class weights with smoothing.
def compute_class_weights(labels, smoothing=0.05, num_classes=None):
    """Return smoothed inverse-frequency class weights as a float32 tensor.

    For each class ``c``::

        w_c = (1 - smoothing) * n_samples / (n_classes * count_c) + smoothing / n_classes

    A class with zero samples receives only the ``smoothing / n_classes``
    term, instead of ``inf`` from ``n / 0`` and a NumPy divide-by-zero
    warning. An infinite entry can make a weighted, label-smoothed
    ``nn.CrossEntropyLoss`` non-finite.

    Args:
        labels: 1-D sequence of non-negative integer class ids.
        smoothing: blend factor towards the constant ``1 / n_classes`` term.
        num_classes: optional minimum number of classes, so classes absent
            from *labels* still get an entry. Defaults to ``max(labels) + 1``.

    Returns:
        A ``torch.float32`` tensor of length ``n_classes`` on the module-level
        ``device``.

    Raises:
        ValueError: if *labels* is empty.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        raise ValueError("labels must contain at least one class id")
    class_counts = np.bincount(labels, minlength=num_classes or 0)
    n_classes = len(class_counts)
    n_samples = len(labels)
    weights = np.zeros(n_classes, dtype=np.float64)
    present = class_counts > 0
    weights[present] = n_samples / (n_classes * class_counts[present])
    weights = (1 - smoothing) * weights + smoothing / n_classes
    return torch.tensor(weights, dtype=torch.float32).to(device)

# Train the conditional WGAN-GP.
def train_wgan_gp(generator, discriminator, train_loader, num_epochs_gan, device, noise_dim=100, n_critic=10, lambda_gp=10, num_classes=5):
    """Train a conditional WGAN-GP on batches of ``(tfidf, distance, ssf, pcp, labels)``.

    The four feature tensors of a batch are concatenated along dim 1 into
    one "real" vector per sample. The labels are one-hot encoded and fed to
    both networks as the condition. For every real batch, the critic
    (``discriminator``) is updated ``n_critic`` times, each time against fresh
    generator samples. The generator is then updated once. Both networks use
    Adam (lr=2e-4, betas=(0.5, 0.999)).

    Args:
        generator: called as ``generator(noise, one_hot_labels)``.
        discriminator: called as ``discriminator(features, one_hot_labels)``.
        train_loader: yields ``(tfidf, distance, ssf, pcp, labels)`` with
            integer class-id labels.
        num_epochs_gan: number of passes over ``train_loader``.
        device: device batches are moved to; the networks are assumed to
            already live there.
        noise_dim: width of the Gaussian noise input; must match the generator.
        n_critic: critic updates per generator update.
        lambda_gp: weight of the gradient penalty in the critic loss.
        num_classes: width of the one-hot condition; must match the
            ``label_dim`` of both networks.

    Returns:
        ``(generator, discriminator)`` — the same objects, trained in place.

    NOTE(review): the Generator in this file ends in ``tanh`` (outputs in
    [-1, 1]); if the real features are standardised they can exceed that
    range — confirm this is intended.
    """
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in tqdm(range(num_epochs_gan), desc="آموزش WGAN-GP"):
        # Reset each epoch: an empty loader is then reported instead of raising
        # UnboundLocalError (or re-printing an earlier epoch's losses).
        d_loss = g_loss = None
        for tfidf, distance, ssf, pcp, labels in train_loader:
            real_features = torch.cat([tfidf, distance, ssf, pcp], dim=1).to(device)
            # The condition is the same for every critic step and the generator
            # step of this batch, so encode it once.
            condition = F.one_hot(labels.to(device), num_classes=num_classes).float()
            batch_size = real_features.size(0)

            # Critic updates.
            for _ in range(n_critic):
                d_optimizer.zero_grad()
                noise = torch.randn(batch_size, noise_dim, device=device)
                # The critic step never back-propagates into the generator.
                with torch.no_grad():
                    fake_features = generator(noise, condition)
                real_output = discriminator(real_features, condition)
                fake_output = discriminator(fake_features, condition)
                gradient_penalty = compute_gradient_penalty(
                    discriminator, real_features, fake_features, condition, device
                )
                d_loss = -torch.mean(real_output) + torch.mean(fake_output) + lambda_gp * gradient_penalty
                d_loss.backward()
                d_optimizer.step()

            # Generator update.
            g_optimizer.zero_grad()
            noise = torch.randn(batch_size, noise_dim, device=device)
            fake_output = discriminator(generator(noise, condition), condition)
            g_loss = -torch.mean(fake_output)
            g_loss.backward()
            g_optimizer.step()

        if d_loss is None:
            print(f"Epoch {epoch+1}/{num_epochs_gan}: train_loader produced no batches")
        else:
            # Losses of the last critic / generator step of the epoch.
            print(f"Epoch {epoch+1}/{num_epochs_gan}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    return generator, discriminator

# Sample synthetic feature vectors from the trained WGAN-GP generator.
def generate_synthetic_data(generator, num_samples, labels, noise_dim=100, num_classes=5,
                            split_dims=(256, 208, 6, 6)):
    """Sample synthetic samples for the given class labels.

    Noise and one-hot labels are created on the generator's own device (taken
    from its parameters), and sampling runs under ``torch.no_grad()`` because
    no autograd graph is needed.

    Args:
        generator: module called as ``generator(noise, one_hot_labels)``.
        num_samples: number of samples; must equal ``len(labels)``.
        labels: sequence of integer class ids, one per sample.
        noise_dim: width of the Gaussian noise input.
        num_classes: width of the one-hot condition.
        split_dims: column widths of the (tfidf, distance, ssf, pcp) groups
            in the generator output. Only the first three are used for
            slicing; pcp takes every remaining column.

    Returns:
        ``(tfidf, distance, ssf, pcp, labels)``: four NumPy arrays, plus
        *labels* returned unchanged.

    Raises:
        ValueError: if ``num_samples != len(labels)``.
    """
    if num_samples != len(labels):
        raise ValueError(
            f"num_samples ({num_samples}) must equal len(labels) ({len(labels)})"
        )
    try:
        gen_device = next(generator.parameters()).device
    except StopIteration:
        gen_device = torch.device("cpu")

    with torch.no_grad():
        noise = torch.randn(num_samples, noise_dim).to(gen_device)
        labels_one_hot = F.one_hot(
            torch.tensor(labels, dtype=torch.long), num_classes=num_classes
        ).float().to(gen_device)
        synthetic_features = generator(noise, labels_one_hot).cpu().numpy()

    tfidf_end = split_dims[0]
    distance_end = tfidf_end + split_dims[1]
    ssf_end = distance_end + split_dims[2]
    synthetic_tfidf = synthetic_features[:, :tfidf_end]
    synthetic_distance = synthetic_features[:, tfidf_end:distance_end]
    synthetic_ssf = synthetic_features[:, distance_end:ssf_end]
    synthetic_pcp = synthetic_features[:, ssf_end:]

    return synthetic_tfidf, synthetic_distance, synthetic_ssf, synthetic_pcp, labels

# Helper: top-1 accuracy of the hybrid model on a loader.
def _evaluate_accuracy(model, data_loader, device):
    """Return top-1 accuracy of *model* on *data_loader* (eval mode, no gradients).

    Raises:
        ValueError: if *data_loader* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for tfidf, distance, ssf, pcp, labels in data_loader:
            tfidf, distance, ssf, pcp, labels = tfidf.to(device), distance.to(device), ssf.to(device), pcp.to(device), labels.to(device)
            outputs, _ = model(tfidf, distance, ssf, pcp)
            _, predicted = torch.max(outputs, 1)
            # Accept one-hot labels as well as class ids.
            if labels.ndim > 1:
                labels = labels.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("val_loader produced no samples; cannot compute validation accuracy")
    return correct / total


# PyTorch training loop with mixup, checkpointing and early stopping.
def train_model(model, train_loader, val_loader, num_epochs, device, checkpoint_path, class_weights, patience=7):
    """Train *model* with class-weighted, label-smoothed cross-entropy.

    Each epoch:
      * trains on *train_loader* with Adam (lr=3e-4, weight_decay=1e-4). About
        70% of batches first go through ``mixup_data`` (alpha=0.6), which
        mixes features but returns hard labels;
      * measures accuracy on *val_loader*;
      * saves ``model.state_dict()`` to *checkpoint_path* whenever validation
        accuracy improves, and stops after *patience* epochs without
        improvement.

    A per-epoch metrics CSV, ``training_metrics.csv``, is written to the
    directory of *checkpoint_path*.

    Args:
        model: called as ``model(tfidf, distance, ssf, pcp) -> (logits, hidden)``.
        train_loader, val_loader: yield ``(tfidf, distance, ssf, pcp, labels)``.
        num_epochs: maximum number of epochs.
        device: device each batch is moved to.
        checkpoint_path: where the best state_dict is saved.
        class_weights: per-class weights for ``nn.CrossEntropyLoss``.
        patience: epochs without validation improvement before stopping.

    Returns:
        ``(train_losses, train_accuracies, val_accuracies)`` — per-epoch lists.
        Training accuracy is measured on the (possibly mixed) training batches.
    """
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)
    optimizer = optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-4)
    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint. With a 0.0 start, a run stuck at 0% validation accuracy would
    # never create checkpoint_path, and loading it afterwards would fail.
    best_val_acc = float('-inf')
    patience_counter = 0
    train_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in tqdm(range(num_epochs), desc="آموزش مدل PyTorch"):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for tfidf, distance, ssf, pcp, labels in train_loader:
            tfidf, distance, ssf, pcp, labels = tfidf.to(device), distance.to(device), ssf.to(device), pcp.to(device), labels.to(device)
            if np.random.rand() < 0.7:
                idx = torch.randperm(tfidf.size(0))
                tfidf, distance, ssf, pcp, labels = mixup_data(
                    tfidf, tfidf[idx], distance, distance[idx], ssf, ssf[idx], pcp, pcp[idx], labels, alpha=0.6
                )
            # Collapse one-hot/soft labels *before* casting: casting soft float
            # targets to long first would truncate them.
            if labels.ndim > 1:
                labels = labels.argmax(dim=1)
            labels = labels.long()
            optimizer.zero_grad()
            outputs, _ = model(tfidf, distance, ssf, pcp)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss_epoch = train_loss / len(train_loader)
        train_acc_epoch = train_correct / train_total
        train_losses.append(train_loss_epoch)
        train_accuracies.append(train_acc_epoch)

        val_acc = _evaluate_accuracy(model, val_loader, device)
        val_accuracies.append(val_acc)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss_epoch:.4f}, Train Accuracy: {train_acc_epoch:.4f}, Val Accuracy: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"بهترین مدل ذخیره شد: {checkpoint_path}")
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"صبر: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print("Early Stopping: توقف آموزش به دلیل عدم بهبود در دقت اعتبارسنجی")
            break

    metrics_path = os.path.join(os.path.dirname(checkpoint_path), 'training_metrics.csv')
    metrics_df = pd.DataFrame({
        'epoch': range(1, len(train_losses) + 1),
        'train_loss': train_losses,
        'train_accuracy': train_accuracies,
        'val_accuracy': val_accuracies
    })
    metrics_df.to_csv(metrics_path, index=False)
    print(f"متریک‌های آموزش ذخیره شد: {metrics_path}")

    return train_losses, train_accuracies, val_accuracies

# Use the trained network as a feature extractor.
def extract_features(model, data_loader, device):
    """Collect the second output of ``model(...)`` (its hidden embedding) for every batch.

    The model is switched to eval mode, and the forward passes run under
    ``torch.no_grad()``: only the outputs are needed, so no autograd graph is
    built per batch. Features and labels are gathered batch by batch
    together, so they stay aligned even if *data_loader* shuffles.

    Returns:
        ``(features, labels)`` as NumPy arrays, concatenated over all batches.
    """
    model.eval()
    features = []
    labels_list = []

    with torch.no_grad():
        for tfidf, distance, ssf, pcp, labels in tqdm(data_loader, desc="استخراج ویژگی‌ها"):
            _, fc_out = model(tfidf.to(device), distance.to(device), ssf.to(device), pcp.to(device))
            features.append(fc_out.cpu().numpy())
            labels_list.append(labels.cpu().numpy())

    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels_list, axis=0)
    return features, labels

# Output directory for every artefact written below (scalers, PCA, checkpoints,
# metrics CSVs, plots).
output_dir = "F:\\payan-nameh\\Optimized_version2"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Training FASTA files, one per subcellular location; extract_tf_features
# derives each record's class label from the file name.
fasta_files = [
    "F:\\New Version\\Data\\rnalocate\\Cytoplasm_train.fasta",
    "F:\\New Version\\Data\\rnalocate\\Endoplasmic_reticulum_train.fasta",
    "F:\\New Version\\Data\\rnalocate\\Extracellular_region_train.fasta",
    "F:\\New Version\\Data\\rnalocate\\Mitochondria_train.fasta",
    "F:\\New Version\\Data\\rnalocate\\Nucleus_train.fasta"
]

# Feature extraction. With augment=True every FASTA record contributes two rows
# (the original sequence followed by one randomly mutated copy). The other
# extractors run on the same `sequences` list, so all four feature matrices stay
# row-aligned with `y`.
# NOTE(review): augmentation happens before the train/test split below, so the
# augmented twin of a test sequence can land in the training split (leakage).
tf_features, y, sequences = extract_tf_features(fasta_files, k=5, augment=True)
# TF-IDF re-weighting of the k-mer frequencies.
# NOTE(review): this fitted TfidfTransformer is not saved with joblib below
# (unlike the scalers and PCA); inference on new data would need it too.
tfidf = TfidfTransformer()
tfidf_features = tfidf.fit_transform(tf_features).toarray()
distance_features = extract_distance_based_features(sequences, max_k=12)
ssf_features = extract_ssf_features(sequences)
pcp_features = extract_pcp_features(sequences, max_length=1000)

# Sanity check: the four feature matrices should have the same number of rows.
print("تعداد نمونه‌ها در TF-IDF:", tfidf_features.shape[0])
print("تعداد نمونه‌ها در Distance-based:", distance_features.shape[0])
print("تعداد نمونه‌ها در SSF:", ssf_features.shape[0])
print("تعداد نمونه‌ها در PCP:", pcp_features.shape[0])

# Standardise each feature group, then reduce the standardised TF-IDF block to
# 256 dimensions with PCA (the network is later built with tfidf_dim=256).
# NOTE(review): the TF-IDF transformer, scalers and PCA are all fitted on the
# full dataset before the split, so test-set statistics leak into them.
scaler_tfidf = StandardScaler()
scaler_distance = StandardScaler()
scaler_ssf = StandardScaler()
scaler_pcp = StandardScaler()
tfidf_features_scaled = scaler_tfidf.fit_transform(tfidf_features)
distance_features_scaled = scaler_distance.fit_transform(distance_features)
ssf_features_scaled = scaler_ssf.fit_transform(ssf_features)
pcp_features_scaled = scaler_pcp.fit_transform(pcp_features)

pca = PCA(n_components=256, random_state=42)
tfidf_features_scaled = pca.fit_transform(tfidf_features_scaled)

# Persist the fitted preprocessing objects.
joblib.dump(scaler_tfidf, os.path.join(output_dir, 'scaler_tfidf.pkl'))
joblib.dump(scaler_distance, os.path.join(output_dir, 'scaler_distance.pkl'))
joblib.dump(scaler_ssf, os.path.join(output_dir, 'scaler_ssf.pkl'))
joblib.dump(scaler_pcp, os.path.join(output_dir, 'scaler_pcp.pkl'))
joblib.dump(pca, os.path.join(output_dir, 'pca_tfidf.pkl'))
print("اسکیلرها و PCA ذخیره شدند.")

# Stratified 80/20 train/test split applied jointly to the four feature groups and the labels.
X_tfidf_train, X_tfidf_test, X_distance_train, X_distance_test, X_ssf_train, X_ssf_test, X_pcp_train, X_pcp_test, y_train, y_test = train_test_split(
    tfidf_features_scaled, distance_features_scaled, ssf_features_scaled, pcp_features_scaled, y, test_size=0.2, random_state=42, stratify=y
)

print("تعداد نمونه‌ها در train:", X_tfidf_train.shape[0])
print("تعداد نمونه‌ها در test:", X_tfidf_test.shape[0])
print("تعداد نمونه‌ها در هر کلاس (train):", np.bincount(y_train))
print("تعداد نمونه‌ها در هر کلاس (test):", np.bincount(y_test))

# Train a conditional WGAN-GP on the training split only. The 476-wide feature
# vector is 256 (PCA TF-IDF) + 208 (distance, max_k=12) + 6 (SSF) + 6 (PCP),
# in the same order train_wgan_gp concatenates the batch tensors.
generator = Generator(noise_dim=100, label_dim=5, output_dim=256+208+6+6).to(device)
discriminator = Discriminator(input_dim=256+208+6+6, label_dim=5).to(device)
train_dataset_temp = TensorDataset(
    torch.tensor(X_tfidf_train, dtype=torch.float32),
    torch.tensor(X_distance_train, dtype=torch.float32),
    torch.tensor(X_ssf_train, dtype=torch.float32),
    torch.tensor(X_pcp_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.long)
)
train_loader_temp = DataLoader(train_dataset_temp, batch_size=32, shuffle=True)
# NOTE(review): the Generator ends in tanh (outputs in [-1, 1]) while these real
# features are standardised and can exceed that range; confirm this is intended.
generator, discriminator = train_wgan_gp(generator, discriminator, train_loader_temp, num_epochs_gan=200, device=device)

# Generate 4000 synthetic samples: 2000 labelled 1 (Endoplasmic_reticulum) and
# 2000 labelled 2 (Extracellular_region), per location_to_label in extract_tf_features.
synthetic_tfidf, synthetic_distance, synthetic_ssf, synthetic_pcp, synthetic_labels = generate_synthetic_data(
    generator, num_samples=4000, labels=[1] * 2000 + [2] * 2000
)

# t-SNE sanity check on the TF-IDF/PCA slice only: the first 1000 real training
# rows (any class) vs. the first 1000 synthetic rows (all class 1, given the
# label order above).
tsne = TSNE(n_components=2, random_state=42)
combined_features = np.concatenate([X_tfidf_train[:1000], synthetic_tfidf[:1000]], axis=0)
tsne_results = tsne.fit_transform(combined_features)
plt.scatter(tsne_results[:1000, 0], tsne_results[:1000, 1], c='blue', label='Real')
plt.scatter(tsne_results[1000:, 0], tsne_results[1000:, 1], c='red', label='Synthetic')
plt.legend()
plt.title('t-SNE of Real vs Synthetic Data')
plt.savefig(os.path.join(output_dir, 'tsne_synthetic_data.png'))
plt.close()
print(f"t-SNE ذخیره شد: {os.path.join(output_dir, 'tsne_synthetic_data.png')}")

# Append the synthetic samples to the real training split.
X_tfidf_train = np.concatenate([X_tfidf_train, synthetic_tfidf], axis=0)
X_distance_train = np.concatenate([X_distance_train, synthetic_distance], axis=0)
X_ssf_train = np.concatenate([X_ssf_train, synthetic_ssf], axis=0)
X_pcp_train = np.concatenate([X_pcp_train, synthetic_pcp], axis=0)
y_train = np.concatenate([y_train, synthetic_labels], axis=0)

# Random oversampling to fixed per-class target counts. The four feature groups
# are stacked column-wise for resampling and then sliced back apart using their
# original column widths.
# NOTE(review): these targets are hard-coded; imbalanced-learn rejects a target
# smaller than a class's current count, so confirm they still fit the data.
sampling_strategy = {0: 8496, 1: 5000, 2: 4000, 3: 2500, 4: 7768}
ros = RandomOverSampler(sampling_strategy=sampling_strategy, random_state=42)
X_train_combined = np.concatenate([X_tfidf_train, X_distance_train, X_ssf_train, X_pcp_train], axis=1)
X_train_combined_balanced, y_train_balanced = ros.fit_resample(X_train_combined, y_train)
X_tfidf_train_balanced = X_train_combined_balanced[:, :X_tfidf_train.shape[1]]
X_distance_train_balanced = X_train_combined_balanced[:, X_tfidf_train.shape[1]:X_tfidf_train.shape[1]+X_distance_train.shape[1]]
X_ssf_train_balanced = X_train_combined_balanced[:, X_tfidf_train.shape[1]+X_distance_train.shape[1]:X_tfidf_train.shape[1]+X_distance_train.shape[1]+X_ssf_train.shape[1]]
X_pcp_train_balanced = X_train_combined_balanced[:, X_tfidf_train.shape[1]+X_distance_train.shape[1]+X_ssf_train.shape[1]:]

print("شکل y_train_balanced:", y_train_balanced.shape)
# Defensive: flatten one-hot labels to class ids if needed (y is built from
# integer labels here, so this branch is presumably never taken).
if y_train_balanced.ndim > 1:
    y_train_balanced = np.argmax(y_train_balanced, axis=1)
    print("y_train_balanced به بردار یک‌بعدی تبدیل شد:", y_train_balanced.shape)

print("تعداد نمونه‌ها در هر کلاس بعد از RandomOverSampler (train):", np.bincount(y_train_balanced))
print("شکل داده‌های متعادل‌شده TF-IDF (train):", X_tfidf_train_balanced.shape)
print("شکل داده‌های متعادل‌شده Distance-based (train):", X_distance_train_balanced.shape)
print("شکل داده‌های متعادل‌شده SSF (train):", X_ssf_train_balanced.shape)
print("شکل داده‌های متعادل‌شده PCP (train):", X_pcp_train_balanced.shape)

# Wrap the arrays as tensors. The test split never receives synthetic or
# oversampled rows.
train_dataset = TensorDataset(
    torch.tensor(X_tfidf_train_balanced, dtype=torch.float32),
    torch.tensor(X_distance_train_balanced, dtype=torch.float32),
    torch.tensor(X_ssf_train_balanced, dtype=torch.float32),
    torch.tensor(X_pcp_train_balanced, dtype=torch.float32),
    torch.tensor(y_train_balanced, dtype=torch.long)
)
test_dataset = TensorDataset(
    torch.tensor(X_tfidf_test, dtype=torch.float32),
    torch.tensor(X_distance_test, dtype=torch.float32),
    torch.tensor(X_ssf_test, dtype=torch.float32),
    torch.tensor(X_pcp_test, dtype=torch.float32),
    torch.tensor(y_test, dtype=torch.long)
)

# NOTE(review): the train loader keeps a possible final batch of one sample; the
# BatchNorm1d layers in FeatureBranch can raise on a single-sample batch in
# training mode. Consider drop_last=True for train_loader.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Class weights for the cross-entropy loss (hard-coded; classes 1 and 2 up-weighted).
# NOTE(review): compute_class_weights is not used here, and the training set has
# already been oversampled, so class imbalance is corrected twice. Confirm this
# is intended.
class_weights = torch.tensor([1.0, 2.8, 3.5, 1.0, 1.0]).to(device)

# Train the hybrid network; tfidf_dim=256 matches the PCA output width.
# NOTE(review): test_loader is passed as the validation loader, so early stopping
# and checkpoint selection see the test set, and the reported test metrics are
# optimistically biased. A separate validation split would avoid this.
model = RNALocateModel(tfidf_dim=256).to(device)
train_losses, train_accuracies, val_accuracies = train_model(
    model, train_loader, test_loader, num_epochs=100, device=device,
    checkpoint_path=os.path.join(output_dir, 'best_pytorch_model.pth'), class_weights=class_weights
)

# Restore the checkpoint with the best validation (here: test) accuracy.
model.load_state_dict(torch.load(os.path.join(output_dir, 'best_pytorch_model.pth')))
print("بهترین مدل PyTorch بارگذاری شد.")

# Use the network as a feature extractor: its second output (the 256-d hidden
# layer) becomes the SVM input. The train loader shuffles, but extract_features
# collects features and labels together per batch, so they stay aligned.
X_train_features, y_train_features = extract_features(model, train_loader, device)
X_test_features, y_test_features = extract_features(model, test_loader, device)

np.save(os.path.join(output_dir, 'X_test_features.npy'), X_test_features)
np.save(os.path.join(output_dir, 'y_test_features.npy'), y_test_features)
print("ویژگی‌های تست ذخیره شدند.")

print("شکل ویژگی‌های استخراج‌شده (train):", X_train_features.shape)
print("شکل ویژگی‌های استخراج‌شده (test):", X_test_features.shape)

# Tune an RBF-kernel SVM on the extracted features (15-fold CV, macro-F1).
# NOTE(review): the training features include RandomOverSampler duplicates,
# which can fall into both the fit and the held-out folds and inflate CV scores.
param_grid = {
    'C': [10, 50, 100, 200],
    'kernel': ['rbf'],
    'gamma': [0.01, 0.05, 0.1]
}
svm_model = SVC(probability=True, random_state=42)
grid = GridSearchCV(svm_model, param_grid, cv=15, scoring='f1_macro', n_jobs=4)
grid.fit(X_train_features, y_train_features)

print("بهترین پارامترها:", grid.best_params_)
print("بهترین F1-score در GridSearch:", grid.best_score_)

best_svm = grid.best_estimator_

# Cross-validated evaluation on the training features. Accuracy, precision,
# recall and MCC use the pooled 10-fold out-of-fold predictions; F1 is the mean
# of the per-fold macro-F1 scores.
y_train_pred = cross_val_predict(best_svm, X_train_features, y_train_features, cv=10)
f1_macro_train = cross_val_score(best_svm, X_train_features, y_train_features, cv=10, scoring='f1_macro').mean()
accuracy_train = accuracy_score(y_train_features, y_train_pred)
precision_train = precision_score(y_train_features, y_train_pred, average='macro')
recall_train = recall_score(y_train_features, y_train_pred, average='macro')
mcc_train = matthews_corrcoef(y_train_features, y_train_pred)

print("\nمعیارهای کلی روی داده آموزش (cross-validation):")
print(f"F1-score (ماکرو): {f1_macro_train:.4f}")
print(f"Accuracy: {accuracy_train:.4f}")
print(f"Precision: {precision_train:.4f}")
print(f"Recall: {recall_train:.4f}")
print(f"MCC: {mcc_train:.4f}")

# Display names in label order (matching location_to_label in extract_tf_features).
class_names = ['Cytoplasm', 'Endoplasmic_reticulum', 'Extracellular_region', 'Mitochondria', 'Nucleus']
print("\nگزارش معیارها برای هر کلاس (train):")
print(classification_report(y_train_features, y_train_pred, target_names=class_names))

# Evaluate the tuned SVM on the held-out test features.
# NOTE: SVC.predict uses the decision function while predict_proba uses Platt
# scaling, so the two can occasionally disagree on the top class.
y_test_pred = best_svm.predict(X_test_features)
y_test_pred_proba = best_svm.predict_proba(X_test_features)
# One-vs-rest indicator matrix for the five classes, shared by the AUC metrics
# and the per-class curves below.
label_binarized = label_binarize(y_test_features, classes=[0, 1, 2, 3, 4])
f1_macro_test = f1_score(y_test_features, y_test_pred, average='macro')
f1_weighted_test = f1_score(y_test_features, y_test_pred, average='weighted')
accuracy_test = accuracy_score(y_test_features, y_test_pred)
precision_test = precision_score(y_test_features, y_test_pred, average='macro')
recall_test = recall_score(y_test_features, y_test_pred, average='macro')
mcc_test = matthews_corrcoef(y_test_features, y_test_pred)
auc_roc_test = roc_auc_score(label_binarized, y_test_pred_proba, average='macro')
# AUC-PR is support-weighted across classes, unlike the macro-averaged AUC-ROC above.
auc_pr_test = average_precision_score(label_binarized, y_test_pred_proba, average='weighted')

# Persist the scalar test metrics.
test_metrics = {
    'f1_macro': f1_macro_test,
    'f1_weighted': f1_weighted_test,
    'accuracy': accuracy_test,
    'precision': precision_test,
    'recall': recall_test,
    'mcc': mcc_test,
    'auc_roc': auc_roc_test,
    'auc_pr': auc_pr_test
}
test_metrics_df = pd.DataFrame([test_metrics])
test_metrics_df.to_csv(os.path.join(output_dir, 'test_metrics.csv'), index=False)
print(f"متریک‌های تست ذخیره شد: {os.path.join(output_dir, 'test_metrics.csv')}")

print("\nمعیارهای کلی روی داده تست:")
print(f"F1-score (ماکرو): {f1_macro_test:.4f}")
print(f"F1-score (وزنی): {f1_weighted_test:.4f}")
print(f"Accuracy: {accuracy_test:.4f}")
print(f"Precision: {precision_test:.4f}")
print(f"Recall: {recall_test:.4f}")
print(f"MCC: {mcc_test:.4f}")
print(f"AUC-ROC: {auc_roc_test:.4f}")
print(f"AUC-PR: {auc_pr_test:.4f}")

print("\nگزارش معیارها برای هر کلاس (test):")
print(classification_report(y_test_features, y_test_pred, target_names=class_names))

# Confusion-matrix heatmap (rows: true class, columns: predicted class).
cm = confusion_matrix(y_test_features, y_test_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'))
plt.close()
print(f"Confusion Matrix ذخیره شد: {os.path.join(output_dir, 'confusion_matrix.png')}")

# One-vs-rest ROC and precision-recall curves per class, drawn on explicit Axes
# objects rather than re-selecting subplots inside the loop.
fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(12, 5))
for i, class_name in enumerate(class_names):
    y_true_i = label_binarized[:, i]
    y_score_i = y_test_pred_proba[:, i]
    fpr, tpr, _ = roc_curve(y_true_i, y_score_i)
    ax_roc.plot(fpr, tpr, label=f'{class_name} (AUC={roc_auc_score(y_true_i, y_score_i):.2f})')
    precision, recall, _ = precision_recall_curve(y_true_i, y_score_i)
    ax_pr.plot(recall, precision, label=f'{class_name} (AUC-PR={average_precision_score(y_true_i, y_score_i):.2f})')

# Chance-level diagonal; the ROC y-axis is the true positive rate.
ax_roc.plot([0, 1], [0, 1], 'k--')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve')
ax_roc.legend(loc='lower right')

ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_title('Precision-Recall Curve')
ax_pr.legend(loc='lower left')
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'roc_pr_curves.png'))
plt.close(fig)
print(f"ROC و PR Curves ذخیره شد: {os.path.join(output_dir, 'roc_pr_curves.png')}")

# Persist the tuned SVM.
joblib.dump(best_svm, os.path.join(output_dir, 'best_svm_model_hybrid.pkl'))
print(f"مدل ذخیره شد: {os.path.join(output_dir, 'best_svm_model_hybrid.pkl')}")