In [None]:
import os
import argparse
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import QuantileTransformer
import numpy as np
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from scipy.stats import pearsonr
from sklearn.metrics import r2_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score



import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch.nn import Linear, Dropout
from torch_geometric.data import Dataset, DataLoader
from models import SelfAttention, HybridModel
from tqdm import tqdm 
import wandb


# Load graphs and data features

In [None]:


# Initialize W&B
wandb.init(project='Immunogenicity', entity='kevingivechian')

# Command-line hyperparameters.
# parse_known_args (not parse_args): inside a Jupyter kernel sys.argv carries the
# kernel's own arguments (e.g. `-f <connection-file>`), which parse_args rejects
# with SystemExit. Unknown arguments are ignored; known flags behave as before.
parser = argparse.ArgumentParser(description='Train a model on protein graph data')
parser.add_argument('--batch_size', type=int, default=60, help='Input batch size for training (default: 60)')
parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate (default: 0.00001)')
parser.add_argument('--epochs', type=int, default=5, help='epochs (default: 5)')
args, _unknown_args = parser.parse_known_args()

batch_size = args.batch_size
lr = args.lr
epochs = args.epochs

# Record the run configuration in W&B; later cells read `epochs` and
# `learning_rate` back from wandb.config.
wandb.config.update({"batch_size": batch_size, "learning_rate": lr, "epochs": epochs})













 # Import tqdm for the progress bar

# Set the directory where your files are located
directory = '/gpfs/gibbs/project/krishnaswamy_smita/kbg32/extracted_folder_FULL2/PyGs'

# Get a list of all .pt files in the directory
files = [f for f in os.listdir(directory) if f.endswith('.pt')]

# Initialize an empty list to store the graphs
graphs = []

# Loop through the files and load each graph, showing a progress bar
for file in tqdm(files, desc="Loading graphs"):
    file_path = os.path.join(directory, file)
    graph = torch.load(file_path)
    graphs.append(graph)

# Now `graphs` contains all the loaded graph objects
print(f"Loaded {len(graphs)} graphs.")


graphs = [x for x in graphs if 'NXVPMVATV' not in x.name]


# Per-peptide score table. It is read with pd.read_table (tab-separated)
# despite the .csv suffix.
expanded_df = pd.read_table('complete_score_Mprops_1_2.csv')
expanded_df = expanded_df.dropna(subset='Foreignness_Score')
expanded_df['pep_pair'] = expanded_df['peptide'] + expanded_df['allele']

# Quantile-map the foreignness score (described as U-shaped) onto a normal distribution.
foreignness_raw = np.asarray(expanded_df['Foreignness_Score'].tolist()).reshape(-1, 1)
qt = QuantileTransformer(n_quantiles=50, random_state=0, output_distribution='normal')
transformed_data_q = qt.fit_transform(foreignness_raw)
expanded_df['quant_foreign'] = transformed_data_q.ravel()
# dropna returns a new frame, so its result must be assigned to take effect.
expanded_df = expanded_df.dropna(subset='quant_foreign')

# Lookup tables. Built with dict(zip(...)), so for a repeated key the last row wins.
f_dict = dict(zip(expanded_df['peptide'], expanded_df['smoothed_foreign']))
fp2_dict = dict(zip(expanded_df['peptide'], expanded_df['Mprop1']))
new_imm_dict = dict(zip(expanded_df['peptide'], expanded_df['immunogenicity']))
new_imm_dict_pair = dict(zip(expanded_df['pep_pair'], expanded_df['immunogenicity']))

expanded_pep_pair = expanded_df['pep_pair'].tolist()




# Cut off the h-bonding features for now: they are the last two node-feature columns.
# NOTE: this mutates every graph in place, so re-running this cell strips two more columns.
for data in graphs:
    trimmed_x = data.x[:, :-2]
    data.x = trimmed_x



# Allele name -> HLA sequence lookup.
hla_df = pd.read_csv('HLA_27_seqs_csv.csv')
hla_dict_true = {allele: seq for allele, seq in zip(hla_df['allele'], hla_df['seqs'])}

# Graph names, excluding any that contain 'NXVPMVATV'. This repeats the filter applied
# when the graphs were loaded, so it keeps every name here.
strings = [g.name for g in graphs if 'NXVPMVATV' not in g.name]




# From each graph name take the span starting at the "LPKPLTLR" marker and ending just
# before the next underscore. Names lacking either piece are reported and skipped.
extracted_substrings = []

for graph_name in strings:
    marker_pos = graph_name.find("LPKPLTLR")
    if marker_pos == -1:
        print(f"'LPKPLTLR' not found in string: {graph_name}")
        continue
    underscore_pos = graph_name.find("_", marker_pos)
    if underscore_pos == -1:
        print(f"Underscore not found in string: {graph_name}")
        continue
    extracted_substrings.append(graph_name[marker_pos:underscore_pos])

        
# For every extracted "LPKPLTLR<peptide>" substring, look up the peptide's
# properties, keeping graph order. One pass fills all four lists.
marker_len = len("LPKPLTLR")  # the peptide starts right after the 8-character marker

new_values_fp2_values = []  # Mprop1 per graph
new_values_f_values = []    # smoothed_foreign per graph
new_imm_values = []         # immunogenicity per graph (peptide-level lookup)
peptide_order = []          # the peptide itself, in graph order

for substring in extracted_substrings:
    new_name = substring[marker_len:]
    new_values_fp2_values.append(fp2_dict[new_name])
    new_values_f_values.append(f_dict[new_name])
    new_imm_values.append(new_imm_dict[new_name])
    peptide_order.append(new_name)



def _hla_prefix(name, start_index=24, end_sequence="LPKPLTLR"):
    """Return ``name[start_index:]`` up to and including the first ``end_sequence``.

    The search for ``end_sequence`` starts at ``start_index``. Returns "" when it is
    not found. (The old inline version added ``len(end_sequence)`` before comparing
    with -1, so its not-found branch could never trigger; it only returned "" because
    the resulting slice happened to be empty.)
    """
    marker_pos = name.find(end_sequence, start_index)
    if marker_pos == -1:
        return ""
    return name[start_index:marker_pos + len(end_sequence)]


# HLA-sequence portion of each graph name: from offset 24 through the "LPKPLTLR" marker.
# NOTE(review): offset 24 assumes a fixed-length name prefix; confirm against the file naming scheme.
hla_name = [_hla_prefix(x.name) for x in graphs]


short_strings = hla_name

# Replace each allele's full HLA sequence with the first graph-derived prefix
# (in hla_name order) that occurs inside it. Alleles with no match keep their sequence.
for allele_key, full_seq in hla_dict_true.items():
    match = next((s for s in short_strings if s in full_seq), None)
    if match is not None:
        hla_dict_true[allele_key] = match

# Invert: sequence/prefix -> list of allele names that share it.
hla_dict = hla_dict_true
inverted_hla_dict = {}
for allele_key, prefix in hla_dict.items():
    inverted_hla_dict.setdefault(prefix, []).append(allele_key)

        
        

# For each graph, the allele list whose prefix key contains the graph's HLA string
# (minus its first two characters). The first matching key wins; otherwise the
# placeholder ["Allele not found"] is used.
corresponding_alleles = []
for hla_prefix in hla_name:
    trimmed = hla_prefix[2:]
    for seq_key, alleles in inverted_hla_dict.items():
        if trimmed in seq_key:
            corresponding_alleles.append(alleles)
            break
    else:
        corresponding_alleles.append(["Allele not found"])


        
# Graph-level regression target: the peptide's Mprop1 value, one per graph.
scores = new_values_fp2_values

# zip() silently stops at the shorter input. Graphs whose name lacked the
# "LPKPLTLR" marker produced no score, which would shift every later label onto
# the wrong graph. Require a 1:1 correspondence instead.
if len(scores) != len(graphs):
    raise ValueError(
        f"Got {len(scores)} scores for {len(graphs)} graphs; graph names and score "
        "lookups are misaligned (see the 'not found' messages printed above)."
    )

for graph, score in zip(graphs, scores):
    graph.y = torch.tensor([score], dtype=torch.float32)  # one-element graph-level label
    # Append the coordinates as extra node-feature columns.
    # NOTE: mutates in place; re-running this cell would append the coordinates again.
    graph.x = torch.cat([graph.x, graph.coords], dim=-1).to(dtype=torch.float32)



    
import torch
from torch_geometric.data import Data

def pad_graph(graph, max_nodes, feature_size=None, coord_size=None):
    """Zero-pad a graph's node features and coordinates up to ``max_nodes`` rows.

    The graph is modified in place and also returned. Graphs that already have
    ``max_nodes`` or more nodes are returned unchanged. Edges are not touched,
    so the padded nodes are isolated.

    Args:
        graph: object with ``x`` (nodes x features), ``coords`` (nodes x coords)
            and a settable ``num_nodes``.
        max_nodes: target node count.
        feature_size: number of columns of ``x``; inferred from ``graph.x`` if None.
        coord_size: number of columns of ``coords``; inferred from ``graph.coords`` if None.

    Returns:
        The same ``graph`` object.
    """
    num_nodes_to_add = max_nodes - graph.num_nodes
    if num_nodes_to_add <= 0:
        return graph

    if feature_size is None:
        feature_size = graph.x.size(1)
    if coord_size is None:
        coord_size = graph.coords.size(1)

    # new_zeros matches the source tensor's dtype and device, so the concatenation
    # also works for non-float32 or GPU-resident graphs.
    padded_features = torch.cat([graph.x, graph.x.new_zeros((num_nodes_to_add, feature_size))], dim=0)
    padded_coords = torch.cat([graph.coords, graph.coords.new_zeros((num_nodes_to_add, coord_size))], dim=0)

    graph.x = padded_features
    graph.coords = padded_coords
    graph.num_nodes = max_nodes
    return graph

# Zero-pad every graph to the node count of the largest graph.
graphs_to_pad = graphs
if not graphs_to_pad:
    raise ValueError("No graphs to pad: the graph list is empty.")
max_nodes = max(graph.num_nodes for graph in graphs_to_pad)

# Column counts are read from the data rather than hard-coded (previously 23 and 3),
# so upstream feature changes do not require editing constants here.
feature_size = graphs_to_pad[0].x.size(1)
coord_size = graphs_to_pad[0].coords.size(1)

padded_graphs = [pad_graph(graph, max_nodes, feature_size, coord_size) for graph in graphs_to_pad]

# pad_graph modifies and returns each graph, so these are the same (now padded) objects.
graphs = padded_graphs
    

# Load and process sequence data 

In [None]:
#### LOAD and PREPROCESS SEQUENCE DATA

# One row per extracted peptide, in graph order. Built from named columns instead of
# a single unnamed column followed by positional renaming.
seq_df = pd.DataFrame({
    'peptide': peptide_order,
    'protfp2': new_values_fp2_values,
    'f': new_values_f_values,
})
# Each entry is a list of candidate allele names, or ["Allele not found"].
seq_df['allele'] = corresponding_alleles
    
    
import pandas as pd

def find_matching_allele(peptide, alleles, expanded_pep_pair):
    """Return the first ``peptide + allele`` key present in ``expanded_pep_pair``.

    Args:
        peptide: peptide string.
        alleles: iterable of candidate allele names, tried in order.
        expanded_pep_pair: container of known peptide+allele keys. A set gives
            O(1) membership tests; a list works but is scanned linearly.

    Returns:
        The matching concatenated key, or the integer sentinel ``0`` (not None)
        when no candidate matches. Callers test ``== 0`` / ``!= 0``.
    """
    candidates = (peptide + allele for allele in alleles)
    return next((combo for combo in candidates if combo in expanded_pep_pair), 0)

# Matched peptide+allele key for each row (0 when none of its candidate alleles match).
# Membership is tested against a set: `in` on the original list is a linear scan, and
# it runs once per row per candidate allele.
expanded_pep_pair_set = set(expanded_pep_pair)
seq_df['combo2'] = seq_df.apply(
    lambda row: find_matching_allele(row['peptide'], row['allele'], expanded_pep_pair_set),
    axis=1,
)






# For peptides with no matching peptide+allele pair (combo2 == 0), total their
# immunogenicity across every row of expanded_df. Peptides totalling 0 (z_peps)
# are kept later with label 0. A single groupby replaces one full-table filter
# per peptide (the old loop also computed each sum twice).
imm_total_by_peptide = expanded_df.groupby('peptide')['immunogenicity'].sum()
unmatched_peptides = seq_df.loc[seq_df['combo2'] == 0, 'peptide'].tolist()

sums = [imm_total_by_peptide.get(pep, 0) for pep in unmatched_peptides]
z_peps = [pep for pep, total in zip(unmatched_peptides, sums) if total == 0]
        
        
        
        
        
# Binary immunogenicity label per row, looked up by the matched peptide+allele key.
# Defaults to 0 when there was no match (combo2 == 0) or the key is unknown.
final_imm = [new_imm_dict_pair.get(combo, 0) for combo in seq_df['combo2'].tolist()]
seq_df['final_immuno'] = final_imm
seq_df['length'] = [len(pep) for pep in seq_df['peptide'].tolist()]





# Keep rows whose peptide+allele pair was found, plus unmatched peptides whose total
# immunogenicity is zero (z_peps).
keep_mask = (seq_df['combo2'] != 0) | (seq_df['peptide'].isin(z_peps))

# .copy(): later cells add columns to filtered_df. Assigning into a boolean-mask
# selection triggers pandas' SettingWithCopyWarning.
filtered_df = seq_df[keep_mask].copy()

# seq_df has a default RangeIndex and one row per graph, so its index labels are
# positions into `graphs`. Verify that before relying on it.
if len(seq_df) != len(graphs):
    raise ValueError(
        f"seq_df has {len(seq_df)} rows but there are {len(graphs)} graphs; "
        "row positions cannot be used to select graphs."
    )
remaining_indices = filtered_df.index.tolist()
filtered_graphs = [graphs[i] for i in remaining_indices]

    
    
    

    
    
# Give every edge a single constant feature equal to 1.0 (needed before the DGL
# conversion below). edge_index holds one column per edge, so size(1) is the edge count.
for graph in filtered_graphs:
    edge_count = graph.edge_index.size(1)
    graph.edge_attr = torch.ones((edge_count, 1))

    # Now, each graph object has an edge_attr tensor filled with ones
    
    
    
    
    
import torch
import dgl
#from egnn_pytorch import EGNNConv
from torch_geometric.data import Data

def _pyg_to_dgl(pyg_graph):
    """Build a DGL graph from a PyG-style graph object.

    Copies the edge list (``edge_index`` rows: sources, destinations), the node
    count (including zero-padded nodes), node features ``x`` into ``ndata['x']``
    and edge features ``edge_attr`` into ``edata['edge_attr']``.
    """
    src, dst = pyg_graph.edge_index
    dgl_graph = dgl.graph((src, dst), num_nodes=pyg_graph.num_nodes)
    dgl_graph.ndata['x'] = pyg_graph.x
    dgl_graph.edata['edge_attr'] = pyg_graph.edge_attr
    return dgl_graph


# DGL versions of the filtered graphs, in the same order as filtered_graphs/filtered_df.
# ndata['x'] holds the node features; per the original notes, columns [:20] are
# residue features and [20:] are the 3 coordinates.
dgl_filtered_graphs = [_pyg_to_dgl(g) for g in filtered_graphs]
    

    
    
    
    
    
    
import numpy as np


def pad_peptide_sequence(sequence, max_length=11, padding_char='J'):
    """Right-pad ``sequence`` with ``padding_char`` to exactly ``max_length`` characters.

    Args:
        sequence: peptide string.
        max_length: target length.
        padding_char: single fill character.

    Returns:
        The padded string (unchanged if already ``max_length`` long).

    Raises:
        ValueError: if ``sequence`` is longer than ``max_length``. ``str.ljust`` never
            truncates, so an over-long peptide would otherwise yield a one-hot matrix
            with a different row count and break stacking into one array later.
    """
    if len(sequence) > max_length:
        raise ValueError(
            f"Peptide {sequence!r} has length {len(sequence)} > max_length={max_length}"
        )
    return sequence.ljust(max_length, padding_char)

# Fixed-length (padded) copy of each peptide, stored next to the original.
peptides2 = filtered_df['peptide'].tolist()
padded_peptides = list(map(pad_peptide_sequence, peptides2))
filtered_df['peptide_padded'] = padded_peptides


def one_hot_encode_sequence(sequence, amino_acids, padding_char='J'):
    """One-hot encode ``sequence`` over the alphabet ``amino_acids + padding_char``.

    Returns a float array of shape ``(len(sequence), alphabet_size)``. Characters
    outside the alphabet produce an all-zero row.
    """
    alphabet = amino_acids + padding_char
    char_to_int = {char: idx for idx, char in enumerate(alphabet)}

    encoded = np.zeros((len(sequence), len(char_to_int)))
    for row, char in enumerate(sequence):
        col = char_to_int.get(char)
        if col is not None:
            encoded[row, col] = 1
    return encoded

# One-hot alphabet: the 20 standard amino acids plus the 'J' padding character.
amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
padding_char = 'J'

# Model inputs and labels in filtered_df row order, which matches filtered_graphs.
# filtered_df already carries the 'length' column from seq_df.
protein_sequences = filtered_df['peptide_padded'].tolist()
protein_reg_values = filtered_df['protfp2'].tolist()
protein_immuno_values = filtered_df['final_immuno'].tolist()
protein_sequences_non_padded = filtered_df['peptide'].tolist()
protein_reg_values_f = filtered_df['f'].tolist()

# One one-hot matrix of shape (len(padded peptide), 21) per peptide.
encoded_sequences = [one_hot_encode_sequence(seq, amino_acids, padding_char) for seq in protein_sequences]

# Build Data Loaders

In [None]:
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.data import Dataset, DataLoader


# Split the graphs into a training set and a validation set
#train_graphs, val_graphs = train_test_split(filtered_graphs, test_size=0.20, random_state=42)





class ProteinDataset(Dataset):
    """In-memory torch_geometric ``Dataset`` over an already-built list of graphs.

    Defines no download()/process() hooks; ``get`` just indexes the stored list.
    """

    def __init__(self, list_of_graphs, transform=None, pre_transform=None):
        super().__init__('.', transform, pre_transform)
        self.list_of_graphs = list_of_graphs

    def len(self):
        """Number of stored graphs."""
        return len(self.list_of_graphs)

    def get(self, idx):
        """Return the graph at position ``idx``."""
        return self.list_of_graphs[idx]

    
    
    
# Stack the per-peptide encodings and labels into NumPy arrays for splitting.
encoded_sequences_array, regression_values_array, binary_values_array, regression_values_array_f = (
    np.array(values)
    for values in (encoded_sequences, protein_reg_values, protein_immuno_values, protein_reg_values_f)
)





from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.data import Dataset, DataLoader

# Train (70%) / validation (15%) / test (15%) splits.
# All per-sample arrays are split with ONE shared set of row indices. This gives
# exactly what the former per-array train_test_split calls produced (same length and
# random_state => same permutation), but alignment no longer depends on that
# coincidence. A length mismatch between graphs and sequence/label arrays now raises
# instead of silently shuffling them with different permutations.
# NOTE(review): the split is not stratified on the (imbalanced) binary labels.
# To restrict to 9-mers, filter the graphs/arrays before this cell.
n_samples = len(dgl_filtered_graphs)
aligned_lengths = {
    'encoded_sequences_array': len(encoded_sequences_array),
    'regression_values_array': len(regression_values_array),
    'regression_values_array_f': len(regression_values_array_f),
    'binary_values_array': len(binary_values_array),
}
mismatched = {name: n for name, n in aligned_lengths.items() if n != n_samples}
if mismatched:
    raise ValueError(f"Expected {n_samples} samples (one per graph); got {mismatched}")

train_idx, non_train_idx = train_test_split(np.arange(n_samples), test_size=0.30, random_state=42)
val_idx, test_idx = train_test_split(non_train_idx, test_size=0.5, random_state=42)

train_graphs = [dgl_filtered_graphs[i] for i in train_idx]
non_train_graphs = [dgl_filtered_graphs[i] for i in non_train_idx]
val_graphs = [dgl_filtered_graphs[i] for i in val_idx]
test_graphs = [dgl_filtered_graphs[i] for i in test_idx]

X_train, X_non_train = encoded_sequences_array[train_idx], encoded_sequences_array[non_train_idx]
y_train, y_non_train = regression_values_array[train_idx], regression_values_array[non_train_idx]
y_trainf, y_non_trainf = regression_values_array_f[train_idx], regression_values_array_f[non_train_idx]
y_train_b, y_non_train_b = binary_values_array[train_idx], binary_values_array[non_train_idx]

X_val, X_test = encoded_sequences_array[val_idx], encoded_sequences_array[test_idx]
y_val, y_test = regression_values_array[val_idx], regression_values_array[test_idx]
y_val_f, y_test_f = regression_values_array_f[val_idx], regression_values_array_f[test_idx]
y_val_b, y_test_b = binary_values_array[val_idx], binary_values_array[test_idx]

# float32 tensors for every split: sequences (X), regression targets (y),
# auxiliary property values (_f) and binary labels (_b).
X_train_tensor, X_val_tensor, X_test_tensor = (
    torch.tensor(arr, dtype=torch.float32) for arr in (X_train, X_val, X_test)
)
y_train_tensor, y_val_tensor, y_test_tensor = (
    torch.tensor(arr, dtype=torch.float32) for arr in (y_train, y_val, y_test)
)
y_train_tensor_f, y_val_tensor_f, y_test_tensor_f = (
    torch.tensor(arr, dtype=torch.float32) for arr in (y_trainf, y_val_f, y_test_f)
)
y_train_tensor_b, y_val_tensor_b, y_test_tensor_b = (
    torch.tensor(arr, dtype=torch.float32) for arr in (y_train_b, y_val_b, y_test_b)
)

# Sequence-only datasets and loaders (regression targets and binary targets).
# NOTE: `DataLoader` here is torch_geometric.data.DataLoader, because that import
# comes last above. These loaders are not used by the training/evaluation code
# below, which iterates the DGL graph loaders (*_loader_g / *_loader_gb).
train_dataset, val_dataset, test_dataset = (
    TensorDataset(X_part, y_part)
    for X_part, y_part in (
        (X_train_tensor, y_train_tensor),
        (X_val_tensor, y_val_tensor),
        (X_test_tensor, y_test_tensor),
    )
)
train_dataset_b, val_dataset_b, test_dataset_b = (
    TensorDataset(X_part, y_part)
    for X_part, y_part in (
        (X_train_tensor, y_train_tensor_b),
        (X_val_tensor, y_val_tensor_b),
        (X_test_tensor, y_test_tensor_b),
    )
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

train_loader_b = DataLoader(train_dataset_b, batch_size=batch_size, shuffle=True)
val_loader_b = DataLoader(val_dataset_b, batch_size=batch_size, shuffle=False)
test_loader_b = DataLoader(test_dataset_b, batch_size=batch_size, shuffle=False)

# Graph loaders are created further below with dgl's GraphDataLoader and the custom
# `collate` function. No torch_geometric graph loaders are built here: the graphs are
# DGL graphs, which PyG's DataLoader is not designed to collate, and train_loader_g /
# val_loader_g / test_loader_g are assigned below before any use.

device = 'cuda'



from torch.utils.data import Dataset

class CustomGraphDataset(Dataset):
    """Map-style dataset yielding aligned ``(graph, sequence, label, labelf)`` tuples.

    Args:
        graphs: per-sample graphs (DGL graphs in this notebook).
        sequences: per-sample encoded sequences, indexable by position.
        labels: per-sample targets.
        labelsf: per-sample auxiliary values (consumed as ``peptide_property`` by the
            training loops).

    Raises:
        ValueError: if the four inputs differ in length. ``__len__`` counts only
            ``graphs``, so a mismatch would otherwise silently drop samples or raise
            IndexError mid-epoch.
    """

    def __init__(self, graphs, sequences, labels, labelsf):
        lengths = (len(graphs), len(sequences), len(labels), len(labelsf))
        if len(set(lengths)) != 1:
            raise ValueError(
                "CustomGraphDataset inputs must have equal lengths; got "
                f"graphs={lengths[0]}, sequences={lengths[1]}, labels={lengths[2]}, labelsf={lengths[3]}"
            )
        self.graphs = graphs
        self.sequences = sequences
        self.labels = labels
        self.labelsf = labelsf

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.sequences[idx], self.labels[idx], self.labelsf[idx]


def collate(samples):
    """Batch a list of ``(graph, seq, label, labelf)`` samples.

    Graphs are merged with ``dgl.batch``. The three tensor fields are stacked
    along a new leading batch dimension.
    """
    graphs, seq_data, labels, labelsf = (list(field) for field in zip(*samples))
    return (
        dgl.batch(graphs),
        torch.stack(seq_data, dim=0),
        torch.stack(labels, dim=0),
        torch.stack(labelsf, dim=0),
    )


    
# Paired (graph, sequence, target, peptide property) datasets: regression targets for
# pre-training and binary immunogenicity targets (suffix _b) for fine-tuning.
train_dataset, val_dataset, test_dataset = (
    CustomGraphDataset(split_graphs, X_part, y_part, yf_part)
    for split_graphs, X_part, y_part, yf_part in (
        (train_graphs, X_train_tensor, y_train_tensor, y_train_tensor_f),
        (val_graphs, X_val_tensor, y_val_tensor, y_val_tensor_f),
        (test_graphs, X_test_tensor, y_test_tensor, y_test_tensor_f),
    )
)

train_dataset_b, val_dataset_b, test_dataset_b = (
    CustomGraphDataset(split_graphs, X_part, y_part, yf_part)
    for split_graphs, X_part, y_part, yf_part in (
        (train_graphs, X_train_tensor, y_train_tensor_b, y_train_tensor_f),
        (val_graphs, X_val_tensor, y_val_tensor_b, y_val_tensor_f),
        (test_graphs, X_test_tensor, y_test_tensor_b, y_test_tensor_f),
    )
)


from dgl.dataloading import GraphDataLoader

def _graph_loader(dataset, shuffle):
    """DGL loader over ``dataset`` using the shared batch size and ``collate``."""
    return GraphDataLoader(dataset, batch_size=batch_size, collate_fn=collate, shuffle=shuffle)


# Regression loaders (pre-training); only the training loader shuffles.
train_loader_g = _graph_loader(train_dataset, shuffle=True)
val_loader_g = _graph_loader(val_dataset, shuffle=False)
test_loader_g = _graph_loader(test_dataset, shuffle=False)

# Binary-label loaders (fine-tuning and evaluation).
train_loader_gb = _graph_loader(train_dataset_b, shuffle=True)
val_loader_gb = _graph_loader(val_dataset_b, shuffle=False)
test_loader_gb = _graph_loader(test_dataset_b, shuffle=False)

# Run training loops and log performance

In [None]:



# Model and optimizer for the pre-training stage.
device = 'cuda'
num_epochs = wandb.config.epochs

# vae_input_dim = 11 * 21: 11 padded residues x 21 one-hot channels (20 amino acids + 'J').
hybrid_model = HybridModel(
    gat_hidden_channels=100,
    vae_input_dim=11*21,
    vae_hidden_dim=400,
    vae_latent_dim=30,
    final_output_dim=1,
)
hybrid_model.to(device)

optimizer = torch.optim.Adam(hybrid_model.parameters(), lr=wandb.config.learning_rate)

# Pre-training loss (regression target).
def hybrid_loss_function(recon_x, x, mu, logvar, final_output, y):
    """Pre-training objective: 0.5*reconstruction + 0.5*KL + 2.0*regression error.

    Args:
        recon_x: reconstructed sequence, compared with ``x`` flattened to (-1, 11*21).
        x: one-hot sequence batch.
        mu, logvar: latent Gaussian parameters for the KL term.
        final_output: model predictions, one value per sample.
        y: regression targets, one value per sample.

    Returns:
        Scalar tensor; all terms use ``reduction='sum'``.
    """
    reconstruction = F.mse_loss(recon_x, x.view(-1, 11*21), reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Flatten both sides to 1-D. squeeze() turned a batch of one into a 0-d
    # prediction compared against a (1,)-shaped target (broadcast warning).
    regression_loss = F.mse_loss(final_output.view(-1), y.view(-1), reduction='sum')
    return 0.5*reconstruction + 0.5*kl_divergence + 2.0*regression_loss

# Pre-training loop (regression targets).
def train_model():
    """Pre-train ``hybrid_model`` for ``num_epochs`` epochs.

    Reads the module-level ``hybrid_model``, ``optimizer``, ``train_loader_g``,
    ``val_loader_g``, ``hybrid_loss_function``, ``device`` and ``num_epochs``.

    Returns:
        (train_losses, val_losses): per-epoch summed loss divided by the number
        of samples in the corresponding dataset.
    """
    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        # Training pass.
        hybrid_model.train()
        running = 0
        for graph_data, sequence_data, target, peptide_property in train_loader_g:
            graph_data = graph_data.to(device)
            sequence_data = sequence_data.to(device)
            target = target.to(device)
            peptide_property = peptide_property.to(device)

            optimizer.zero_grad()
            recon_batch, mu, logvar, final_output = hybrid_model(graph_data, sequence_data, target, peptide_property)
            loss = hybrid_loss_function(recon_batch, sequence_data, mu, logvar, final_output, target)
            loss.backward()
            optimizer.step()
            running += loss.item()
        epoch_train = running / len(train_loader_g.dataset)
        train_losses.append(epoch_train)

        # Validation pass (no gradients).
        hybrid_model.eval()
        running = 0
        with torch.no_grad():
            for graph_data, sequence_data, target, peptide_property in val_loader_g:
                graph_data = graph_data.to(device)
                sequence_data = sequence_data.to(device)
                target = target.to(device)
                peptide_property = peptide_property.to(device)

                recon_batch, mu, logvar, final_output = hybrid_model(graph_data, sequence_data, target, peptide_property)
                loss = hybrid_loss_function(recon_batch, sequence_data, mu, logvar, final_output, target)
                running += loss.item()
        epoch_val = running / len(val_loader_g.dataset)
        val_losses.append(epoch_val)

        print(f"Epoch {epoch+1}, Train Loss: {epoch_train:.4f}, Val Loss: {epoch_val:.4f}")

    return train_losses, val_losses

# Loss-curve plot shared by the pre-training and fine-tuning stages.
def plot_losses(train_losses, val_losses):
    """Plot per-epoch training and validation losses on one figure and show it."""
    plt.figure(figsize=(10, 5))
    for series, series_label in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
        plt.plot(series, label=series_label)
    plt.title('Training and Validation Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

# Run pre-training and plot its loss curves.
if __name__ == '__main__':
    train_losses, val_losses = train_model()
    plot_losses(train_losses, val_losses)
    print("DONE PRE-TRAINING")

# NOTE: both names alias the SAME model object; no copy is made, so fine-tuning
# below also changes `pre_trained`. Use copy.deepcopy(hybrid_model) if an
# untouched pre-trained snapshot is needed.
pre_trained = fine_tuned = hybrid_model









# Loss-term weights for fine-tuning.
ALPHA = 5.0   # class-weighted BCE term (immunogenic 1 vs 0 prediction)
BETA = 0.10   # reconstruction MSE term
THETA = 0.10  # KL-divergence term

# BCE class weights: label 0 -> 1.0, label 1 -> 13429 / 3745.
# NOTE(review): hard-coded counts, presumably negatives/positives of the full dataset.
# Confirm they still reflect the filtered training labels (y_train_b).
N_NEG, N_POS = 13429, 3745
positive_weight = N_NEG / N_POS
class_weights = torch.tensor([1.0, positive_weight], device=device)

# Fine-tuning loss (binary target with class-weighted BCE). Re-binds hybrid_loss_function.
def hybrid_loss_function(recon_x, x, mu, logvar, final_output, y, alpha=ALPHA, beta=BETA, theta=THETA):
    """Fine-tuning objective: beta*reconstruction MSE + theta*KL + alpha*weighted BCE.

    Args:
        recon_x: reconstructed sequence, compared with ``x`` flattened to (-1, 11*21).
        x: one-hot sequence batch.
        mu, logvar: latent Gaussian parameters for the KL term.
        final_output: logits, one per sample.
        y: binary labels. ``y.long()`` indexes the module-level ``class_weights``,
            so values must be 0 or 1.
        alpha, beta, theta: term weights (default to ALPHA, BETA, THETA at definition time).

    Returns:
        Scalar tensor; all terms use ``reduction='sum'``.
    """
    reconstruction = F.mse_loss(recon_x, x.view(-1, 11*21), reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    per_sample_weight = class_weights[y.long()]
    bce = F.binary_cross_entropy_with_logits(
        final_output.view(-1), y.view(-1), weight=per_sample_weight, reduction='sum'
    )
    return beta * reconstruction + theta * kl_divergence + alpha * bce



# Fresh Adam optimizer for fine-tuning, with a small L2 weight decay.
# fine_tuned is the same object as hybrid_model.
optimizer = torch.optim.Adam(
    fine_tuned.parameters(),
    lr=wandb.config.learning_rate,
    weight_decay=1e-6,
)

# Device used by the fine-tuning loop below.
device = 'cuda'

# Fine-tuning loop (binary immunogenicity targets). Re-binds train_model.
def train_model():
    """Fine-tune ``hybrid_model`` for ``num_epochs`` epochs and log the losses to W&B.

    Reads the module-level ``hybrid_model`` (the same object as ``fine_tuned``),
    ``optimizer``, ``train_loader_gb``, ``val_loader_gb``, ``hybrid_loss_function``,
    ``device`` and ``num_epochs``.

    Returns:
        (train_losses, val_losses): per-epoch summed loss divided by the number
        of samples in the corresponding dataset.
    """
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        hybrid_model.train()
        train_loss = 0
        for graph_data, sequence_data, target, peptide_property in train_loader_gb:
            graph_data = graph_data.to(device)
            sequence_data, target, peptide_property = sequence_data.to(device), target.to(device), peptide_property.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar, final_output = hybrid_model(graph_data, sequence_data, target, peptide_property)
            loss = hybrid_loss_function(recon_batch, sequence_data, mu, logvar, final_output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader_gb.dataset)
        train_losses.append(train_loss)

        # Validation pass (no gradients).
        hybrid_model.eval()
        val_loss = 0
        with torch.no_grad():
            for graph_data, sequence_data, target, peptide_property in val_loader_gb:
                graph_data = graph_data.to(device)
                sequence_data, target, peptide_property = sequence_data.to(device), target.to(device), peptide_property.to(device)

                recon_batch, mu, logvar, final_output = hybrid_model(graph_data, sequence_data, target, peptide_property)
                loss = hybrid_loss_function(recon_batch, sequence_data, mu, logvar, final_output, target)
                val_loss += loss.item()

        val_loss /= len(val_loader_gb.dataset)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Log both losses in ONE call. Each wandb.log call advances W&B's step
        # counter, so two separate calls put the train and val loss of the same
        # epoch on different steps.
        wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch + 1})

    return train_losses, val_losses
# The fine-tuning stage reuses plot_losses(train_losses, val_losses), defined with
# identical behavior in the pre-training section above; it is not redefined here.



# Fine-tune for the epoch count given on the command line, then plot the curves.
num_epochs = epochs

if __name__ == '__main__':
    train_losses, val_losses = train_model()
    plot_losses(train_losses, val_losses)
    print("DONE FINE-TUNING")
    
    
    
    
    
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

def _collect_binary_predictions(model, loader):
    """Run ``model`` over ``loader`` without gradients and gather its predictions.

    Returns:
        (true_targets, predicted_probs, predicted_labels) as NumPy arrays.
        Probabilities are sigmoid(final_output); labels are the probabilities
        rounded with np.round, so exactly 0.5 becomes 0.
    """
    targets, probs_all, labels_all = [], [], []
    with torch.no_grad():
        for graph_data, sequence_data, target, peptide_property in loader:
            graph_data = graph_data.to(device)
            sequence_data, target, peptide_property = (
                sequence_data.to(device), target.to(device), peptide_property.to(device)
            )
            _, _, _, final_output = model(graph_data, sequence_data, target, peptide_property)
            # reshape(-1) yields a 1-D vector for every batch size, including 1,
            # which squeeze() alone would turn into a 0-d scalar.
            probs = torch.sigmoid(final_output).reshape(-1).cpu().numpy()
            targets.extend(target.cpu().numpy())
            probs_all.extend(probs.tolist())
            labels_all.extend(np.round(probs))
    return np.array(targets), np.array(probs_all), np.array(labels_all)


def _binary_metrics(true_targets, predicted_probs, predicted_labels):
    """Accuracy/precision/recall/F1 on the 0/1 labels and ROC AUC on the probabilities.

    zero_division=0 states explicitly the value sklearn already returns when a score
    is undefined (no predicted or no actual positives), without the warning.
    roc_auc_score still raises ValueError if only one class is present in the targets.
    """
    return {
        'Accuracy': accuracy_score(true_targets, predicted_labels),
        'Precision': precision_score(true_targets, predicted_labels, zero_division=0),
        'Recall': recall_score(true_targets, predicted_labels, zero_division=0),
        'F1 Score': f1_score(true_targets, predicted_labels, zero_division=0),
        'ROC AUC': roc_auc_score(true_targets, predicted_probs),
    }


def _report_metrics(split, header, metrics):
    """Log ``metrics`` to W&B as '<split> <name>' and print them under ``header``."""
    wandb.log({f'{split} {name}': value for name, value in metrics.items()})
    print(header)
    for name, value in metrics.items():
        print(f'{name}: {value:.4f}')


fine_tuned.eval()  # evaluation mode for both passes below

# Held-out test split.
true_targets, predicted_probs, predicted_labels = _collect_binary_predictions(fine_tuned, test_loader_gb)
_report_metrics('Test', 'test_metrics', _binary_metrics(true_targets, predicted_probs, predicted_labels))

# Training split, as a reference for over-fitting.
true_targets, predicted_probs, predicted_labels = _collect_binary_predictions(fine_tuned, train_loader_gb)
_report_metrics('Train', 'train_metrics', _binary_metrics(true_targets, predicted_probs, predicted_labels))



