### Setup

### Data preprocessing
- remove all columns except: 'modified_sequence', 'precursor_charge', 'precursor_intensity'
- filter out unwanted charge states
- select the most abundant charge state per sequence by precursor_intensity (after normalizing the values)
- filter by sequence length according to its frequency in the dataset (currently, all sequences whose length occurs fewer than 100 times are removed)
- search for occurrences of UNIMOD modifications and add them to the vocabulary
- generate the integer-encoded sequence (input to the first model layer, an embedding layer)
- generate the precursor_charge one-hot encoding

### Preprocessing functions

# Dataset-Class

In [1]:
import re
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.utils import class_weight
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import wandb
from wandb.keras import WandbCallback
import timeit
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, ConfusionMatrixDisplay
import seaborn as sn

In [2]:
def combine_files_into_df(dir_path='data/', file_types=['.parquet', '.tsv', '.csv'], columns_to_keep=['modified_sequence', 'precursor_charge', 'precursor_intensity']):
    """Read all .parquet, .tsv and .csv files in ``dir_path`` into one DataFrame.

    A mix of file types is combined into a single frame. Only regular files
    whose name ends with one of ``file_types`` are considered; of those, only
    the three suffixes above have a reader, anything else is skipped. Files are
    read in sorted name order so the row order of the result is reproducible.

    :param dir_path: directory to scan (not recursive).
    :param file_types: filename suffixes to include.
    :param columns_to_keep: columns selected from every file; each file must
        contain all of them (otherwise pandas raises KeyError).
    :return: concatenated DataFrame with a fresh RangeIndex.
    :raises ValueError: if no readable file was found in ``dir_path``.
    """
    suffixes = tuple(file_types)
    columns = list(columns_to_keep)
    dfs = []

    for file in sorted(os.listdir(dir_path)):
        file_path = os.path.join(dir_path, file)
        if not file.endswith(suffixes) or not os.path.isfile(file_path):
            continue

        if file.endswith('.parquet'):
            df = pd.read_parquet(file_path, engine='fastparquet')
        elif file.endswith('.tsv'):
            df = pd.read_csv(file_path, sep='\t')
        elif file.endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            continue  # suffix was requested, but there is no reader for it

        dfs.append(df[columns])

    if not dfs:
        # pd.concat([]) would only report "No objects to concatenate".
        raise ValueError(f"No files ending with {list(file_types)} found in {dir_path}.")
    return pd.concat(dfs, ignore_index=True)

In [148]:
class PrecursorChargeStateDataset:
    """Load, clean, encode and split precursor data for charge-state prediction.

    The constructor runs the complete preprocessing pipeline (steps 1-12, each
    announced with a ``print``) on every ``file_type`` file in ``dir_path``
    and stores the results on the instance, most importantly:

    - ``df``: processed data, narrowed to 'modified_sequence_vector', the label
      column for ``classification_type`` and 'top_2_charge_states'.
    - ``df_test``, ``df_train``, ``df_val``: random test hold-out and stratified
      train/validation split of the remainder.
    - ``test_label``/``test_data``, ``train_label``/``train_data``,
      ``val_label``/``val_data``: the splits converted according to
      ``data_type``.
    - ``vocabulary``, ``voc_len``, ``padding_length``: encoding parameters
      needed to build a model on top of this dataset.
    """

    def __init__(self, classification_type="multi_class", model_type="embedding", charge_states=[1, 2, 3, 4, 5, 6], dir_path='data/', file_type='.parquet', columns_to_keep=['modified_sequence','precursor_charge','precursor_intensity'], test_ratio=0.1, random_state=None):
        """Run the preprocessing pipeline and build the data splits.

        :param classification_type: 'multi_class' (label = one-hot vector of
            the most abundant charge) or 'multi_label' (label = binary vector
            of every observed charge). Case-insensitive.
        :param model_type: 'embedding', 'conv2d' or 'prosit' (case-insensitive);
            'multi_label' only supports 'embedding'.
        :param charge_states: charges to keep; also fixes the length and order
            of the label vectors (``num_classes == len(charge_states)``).
        :param dir_path: directory scanned (non-recursively) for input files.
        :param file_type: filename suffix of the input files (a leading '.' is
            added if missing); files are read as parquet.
        :param columns_to_keep: columns read from each file; must include
            'modified_sequence', 'precursor_charge' and 'precursor_intensity'.
        :param test_ratio: fraction of sequences sampled as test set; a value
            <= 0 disables the test set (``df_test`` is then empty).
        :param random_state: optional seed for the test sampling and the
            train/validation split; ``None`` keeps them unseeded.
        :raises TypeError: if an argument has the wrong type.
        :raises ValueError: if an argument has an invalid value, or no input
            file / no usable sequence is found.
        """

        # ---------------------------------------------------------------
        # Check all inputs
        # ---------------------------------------------------------------

        # Normalise case *before* the membership test; lowering after the
        # check had no effect, so e.g. 'Multi_Class' was rejected.
        if not isinstance(classification_type, str):
            raise TypeError("classification_type must be a string.")
        classification_type = classification_type.lower()
        if classification_type not in ["multi_class", "multi_label"]:
            raise ValueError("classification_type must be either 'multi_class' or 'multi_label'.")

        if not isinstance(model_type, str):
            raise TypeError("model_type must be a string.")
        model_type = model_type.lower()
        if model_type not in ["embedding", "conv2d", "prosit"]:
            raise ValueError("model_type must be 'embedding', 'conv2d', 'prosit'.")

        # Every model_type supports multi_class; multi_label needs 'embedding'.
        if classification_type == "multi_label" and model_type != "embedding":
            raise ValueError("classification_type and model_type are not compatible.")

        if not isinstance(charge_states, list):
            raise TypeError("charge_states must be a list.")
        if not all(isinstance(item, int) for item in charge_states):
            raise ValueError("charge_states must be a list of integers.")
        if not charge_states:
            raise ValueError("charge_states must not be empty.")
        # Copy so that mutating self.charge_states can never alter the shared
        # default argument.
        charge_states = list(charge_states)

        if not isinstance(dir_path, str):
            raise TypeError("dir_path must be a string.")
        if not os.path.isdir(dir_path):
            raise ValueError("dir_path must be a valid directory. Is not: {}".format(dir_path))

        if not isinstance(file_type, str):
            raise TypeError("file_type must be a string.")
        if not file_type.startswith("."):
            file_type = "." + file_type

        if not isinstance(columns_to_keep, list):
            raise TypeError("columns_to_keep must be a list.")
        if not all(isinstance(item, str) for item in columns_to_keep):
            raise ValueError("columns_to_keep must be a list of strings. In Order: 'modified_sequence', 'precursor_charge', 'precursor_intensity'.")
        # The pipeline below addresses these columns by name.
        required_columns = ['modified_sequence', 'precursor_charge', 'precursor_intensity']
        missing_columns = [c for c in required_columns if c not in columns_to_keep]
        if missing_columns:
            raise ValueError(f"columns_to_keep must contain {required_columns}; missing: {missing_columns}.")
        columns_to_keep = list(columns_to_keep)

        # ---------------------------------------------------------------
        # Pipeline steps (numbered by execution order, see the prints)
        # ---------------------------------------------------------------

        def combine_parquet_into_df(dir_path='data/', file_type='.parquet', columns_to_keep=['modified_sequence','precursor_charge','precursor_intensity']):
            """Step 1: read every ``file_type`` file in ``dir_path`` as parquet
            (fastparquet engine), keep ``columns_to_keep`` and concatenate.

            Files are read in sorted name order. Row order matters downstream
            (see normalize_precursor_intensities), so it must be reproducible.

            :raises ValueError: if no matching file exists.
            """
            dfs = []
            for file in sorted(os.listdir(dir_path)):
                if file.endswith(file_type):
                    file_path = os.path.join(dir_path, file)
                    df = pd.read_parquet(file_path, engine='fastparquet')
                    dfs.append(df[columns_to_keep])
            if not dfs:
                # pd.concat([]) would only report "No objects to concatenate".
                raise ValueError(f"No '{file_type}' files found in {dir_path}.")
            df = pd.concat(dfs, ignore_index=True)
            print(f"Step 1/12 complete. Combined {len(dfs)} files into one DataFrame.")
            return df

        def drop_na(df, column='precursor_intensity'):
            """Step 2: drop all rows with NaN in ``column``."""
            df = df[df[column].notna()]
            print(f"Step 2/12 complete. Dropped rows with NaN for intensities.")
            return df

        def keep_desired_charges(df, charge_list=[1, 2, 3, 4, 5, 6]):
            """Step 3: keep only rows whose precursor_charge is in ``charge_list``."""
            df = df[df['precursor_charge'].isin(charge_list)]
            print(f"Step 3/12 complete. Removed charge states not in {charge_list}.")
            return df

        def complete_vocabulary(df):
            """Step 6: build the token vocabulary.

            Starts with 'X' followed by the 20 standard amino-acid letters and
            appends every distinct ``<residue>[UNIMOD:<id>]`` token found in
            'modified_sequence', in order of first appearance. 'X' is index 0;
            sequence_encoder uses it for padding and for unknown tokens.

            :return: (vocabulary list, its length)
            """
            vocabulary = list('XACDEFGHIKLMNPQRSTVWY')
            annotations = re.findall(r'(\w\[UNIMOD:\d+\])', ' '.join(df['modified_sequence']))
            # dict.fromkeys de-duplicates while keeping first-seen order.
            for item in dict.fromkeys(annotations):
                if item not in vocabulary:
                    vocabulary.append(item)
            vocab_len = len(vocabulary)
            print(f"Step 6/12 complete. Completed vocabulary with {vocab_len} entries.")
            return vocabulary, vocab_len

        def aggregate_sequences(df):
            """Step 4: one row per unique modified_sequence, with its
            precursor_charge and precursor_intensity values collected into
            parallel lists (used to pick the most abundant charge later)."""
            df = df.groupby("modified_sequence")[["precursor_charge", "precursor_intensity"]].agg(list).reset_index()
            print(f"Step 4/12 complete. Aggregated all sequences to unique sequences.")
            return df

        def normalize_precursor_intensities(df_charge_list, df_intensity_list):
            """Map each charge of one aggregated sequence to a single intensity.

            NOTE(review): despite the name, no normalisation happens. The
            original loop re-created the per-charge list on every observation,
            so only the *last* intensity seen for each charge survived, and the
            (mis-parenthesised) min-max branch was unreachable. That behaviour
            is kept unchanged here because it decides 'most_abundant_charge';
            confirm the intended normalisation (e.g. summing, as
            get_most_abundant does) before changing it.
            """
            return {charge: intensity for charge, intensity in zip(df_charge_list, df_intensity_list)}

        def get_most_abundant(df_charge_list, df_intensity_list, distributions=False):
            """Sum the intensities per charge of one aggregated sequence.

            :return: the {charge: summed intensity} dict if ``distributions``,
                otherwise the charge with the largest summed intensity.
            """
            charge_dict = dict()
            for charge, intensity in zip(df_charge_list, df_intensity_list):
                charge_dict[charge] = charge_dict.get(charge, 0) + intensity
            if distributions:
                return charge_dict
            return max(charge_dict, key=charge_dict.get)

        def one_hot_encode_charge(df, charge_list=[1, 2, 3, 4, 5, 6]):
            """Add 'most_abundant_charge_vector': one-hot encoding of
            'most_abundant_charge' over ``charge_list`` (in that order)."""
            df['most_abundant_charge_vector'] = df['most_abundant_charge'].apply(lambda x: [1 if x == i else 0 for i in charge_list])
            return df

        def normalize_and_select_most_abundant(df):
            """Step 8: add 'normalized' (charge -> intensity dict),
            'pre_normalization' (charge -> summed intensity dict),
            'most_abundant_charge' (argmax of 'normalized'; the first key wins
            on ties) and 'most_abundant_charge_vector' (its one-hot encoding
            over ``charge_states``)."""
            df['normalized'] = df.apply(lambda x: normalize_precursor_intensities(x["precursor_charge"], x["precursor_intensity"]), axis=1)
            df['pre_normalization'] = df.apply(lambda x: get_most_abundant(x["precursor_charge"], x["precursor_intensity"], True), axis=1)
            df['most_abundant_charge'] = df['normalized'].apply(lambda x: max(x, key=x.get))
            df = one_hot_encode_charge(df, charge_states)
            print(f"Step 8/12 complete. Applied normalization, selected most abundant charge state and one-hot encoded it.")
            return df

        def get_topK_charge_states(df, k=2):
            """Step 11: add 'top_<k>_charge_states', the (at most) k charges
            with the highest value in 'normalized', best first."""
            df[f'top_{k}_charge_states'] = df['normalized'].apply(lambda d: self._top_k_charges(d, k))
            print(f"Step 11/12 complete. Selected top {k} charge states per sequence.")
            return df

        def remove_rare_sequence_lengths(df, representation_threshold=100):
            """Step 5: drop sequences whose length occurs fewer than
            ``representation_threshold`` times.

            Length is ``len()`` of the 'modified_sequence' *string*, so a
            modified residue such as 'C[UNIMOD:4]' counts as 11 characters
            although sequence_encoder turns it into a single token. The
            returned padding_length (maximum remaining string length) is
            therefore an upper bound on the token count of every sequence.

            :return: (filtered copy of df, padding_length as int)
            :raises ValueError: if no sequence is left.
            """
            before_len = len(df)
            lengths = df['modified_sequence'].apply(len)
            len_counts = lengths.value_counts()
            frequent_lengths = len_counts[len_counts >= representation_threshold].index
            keep = lengths.isin(frequent_lengths)
            df = df[keep].copy()
            if df.empty:
                # An empty frame would give a NaN padding_length below.
                raise ValueError(f"No sequence length is represented at least {representation_threshold} times.")
            padding_length = int(lengths[keep].max())
            after_len = len(df)
            print(f"Step 5/12 complete. Removed {before_len - after_len} of {before_len} sequences if sequence-length is represented less than {representation_threshold} times.")
            return df, padding_length

        def encode_charge_states(df, charge_list=[1, 2, 3, 4, 5, 6]):
            """Step 9: add 'charge_state_vector', a binary vector over
            ``charge_list`` marking every charge observed for the sequence."""
            df['charge_state_vector'] = df['precursor_charge'].apply(lambda x: [1 if i in x else 0 for i in charge_list])
            print(f"Step 9/12 complete. Encoded all occuring charge states per unique sequence in a binary vector.")
            return df

        def filter_skipped_charges(df):
            """Keep only rows whose 'charge_state_vector' has no gap
            (currently not part of the pipeline)."""
            return df[df['charge_state_vector'].apply(lambda x: not self._has_skipped_charges(x))]

        def skip_charges_for_occurrences(df, cutoff=1000):
            """Step 10: drop rare charge-state patterns that contain a gap.

            Removes rows whose 'charge_state_vector' both has skipped charges
            and occurs fewer than ``cutoff`` times in ``df``. Gap-free
            patterns are always kept, however rare.
            """
            from collections import Counter
            # Lists are unhashable, so count and compare the vectors as tuples
            # (one pass instead of recomputing value_counts() per pattern).
            vectors = df['charge_state_vector'].map(tuple)
            pattern_counts = Counter(vectors)
            rare_gapped = {
                pattern for pattern, count in pattern_counts.items()
                if count < cutoff and self._has_skipped_charges(pattern)
            }
            keep_mask = vectors.map(lambda pattern: pattern not in rare_gapped).astype(bool)
            df_out = df[keep_mask].copy()
            print(f"Step 10/12 complete. Removed {len(df) - len(df_out)} of {len(df)} sequences if unique charge state distribution is represented less than {cutoff} times.")
            return df_out

        def sequence_encoder(df, padding_length=50, vocabulary=None):
            """Step 7: add 'modified_sequence_vector', the vocabulary indices
            of each sequence's tokens, right-padded with 'X' to
            ``padding_length``.

            A token is a capital letter followed by a bracketed annotation
            (e.g. 'C[UNIMOD:4]') or any single character; tokens missing from
            ``vocabulary`` are mapped to the index of 'X'. Sequences longer
            than ``padding_length`` are not truncated.

            :raises ValueError: if df has no 'modified_sequence' column.
            """
            if 'modified_sequence' not in df.columns:
                raise ValueError("DataFrame must contain a 'modified_sequence' column.")

            aa_dictionary = {aa: index for index, aa in enumerate(vocabulary)}
            pad_index = aa_dictionary['X']
            token_pattern = re.compile(r'[A-Z]\[[^\]]*\]|.')

            def encode_sequence(sequence):
                tokens = token_pattern.findall(sequence)
                tokens += ['X'] * (padding_length - len(tokens))
                return [aa_dictionary.get(aa, pad_index) for aa in tokens]

            df['modified_sequence_vector'] = df['modified_sequence'].apply(encode_sequence)
            print(f"Step 7/12 complete. Encoded all sequences.")
            return df

        def plot_most_abundant_charge_distribution(df):
            """Bar plot of 'most_abundant_charge' counts, annotated with the
            percentage of each charge (not part of the pipeline)."""
            sns.set_theme(style="darkgrid")
            sns.set_context("paper")
            ax = sns.countplot(x='most_abundant_charge', data=df, palette="viridis")
            plt.xlabel('Precursor Charge')
            plt.ylabel('Count')
            plt.title('Distribution of Precursor Charge')
            total = len(df['most_abundant_charge'])
            for p in ax.patches:
                percentage = '{:.1f}%'.format(100 * p.get_height()/total)
                x = p.get_x() + p.get_width() / 2 - 0.05
                y = p.get_y() + p.get_height() + 5
                ax.annotate(percentage, (x, y))
            plt.show()

        def plot_topK_charge_distribution(df, column_name='top_2_charge_states'):
            """Bar plot of how often each charge appears in ``column_name``,
            annotated relative to the number of sequences (not part of the
            pipeline)."""
            charge_state_counter = {charge: 0 for charge in charge_states}
            for row in df[column_name]:
                for k in row:
                    charge_state_counter[k] = charge_state_counter[k] + 1
            sns.set_theme(style="darkgrid")
            sns.set_context("paper")
            palette = sns.color_palette("viridis", len(charge_state_counter))
            plt.figure(figsize=(8, 6))
            plt.bar(range(len(charge_state_counter)), list(charge_state_counter.values()), align='center', color=palette)
            plt.xticks(range(len(charge_state_counter)), list(charge_state_counter.keys()))
            plt.xlabel('Charge State')
            plt.ylabel('Count')
            plt.title('Charge State Distribution')
            plt.tight_layout()

            total = len(df['most_abundant_charge'])
            for p in plt.gca().patches:
                height = p.get_height()
                plt.gca().text(p.get_x() + p.get_width()/2., height + 3, '{:.1f}%'.format(height/total*100), ha='center', fontsize=12)

            plt.show()

        # ---------------------------------------------------------------
        # Run the pipeline
        # ---------------------------------------------------------------
        self.dir_path = dir_path
        self.file_type = file_type
        self.columns_to_keep = columns_to_keep

        self.charge_states = charge_states
        self.num_classes = len(self.charge_states)

        self.classification_types = ['multi_class', 'multi_label']
        self.classification_type = classification_type

        self.model_types = ['embedding', 'conv2d', 'prosit']
        self.model_type = model_type

        self.df = combine_parquet_into_df(dir_path, file_type, columns_to_keep)
        self.df = drop_na(self.df, 'precursor_intensity')
        self.df = keep_desired_charges(self.df, self.charge_states)
        self.df = aggregate_sequences(self.df)
        self.df, self.padding_length = remove_rare_sequence_lengths(self.df)
        self.vocabulary, self.voc_len = complete_vocabulary(self.df)
        self.df = sequence_encoder(self.df, self.padding_length, self.vocabulary)
        self.df = normalize_and_select_most_abundant(self.df)
        self.df = encode_charge_states(self.df, self.charge_states)
        self.df = skip_charges_for_occurrences(self.df)
        self.df = get_topK_charge_states(self.df)
        print(f"Step 12/12 complete. Generated dataset with {len(self.df)} sequences.")

        if self.classification_type == 'multi_class':
            self.df = self.df[['modified_sequence_vector', 'most_abundant_charge_vector', 'top_2_charge_states']]
        elif self.classification_type == 'multi_label':
            self.df = self.df[['modified_sequence_vector', 'charge_state_vector', 'top_2_charge_states']]
        else:
            raise ValueError("classification_type must be one of the following: 'multi_class', 'multi_label'")

        if self.classification_type == "multi_class":
            if model_type == "embedding":
                self.data_type = "tensor"
            elif model_type == "conv2d":
                self.data_type = "2d_tensor"
            elif model_type == "prosit":
                self.data_type = "tensor"
            else:
                raise ValueError("model_type must be one of the following: 'embedding', 'conv2d', 'prosit'")
        elif self.classification_type == "multi_label":
            self.data_type = "tensor_multi_label"
        else:
            raise ValueError("classification_type must be one of the following: 'multi_class', 'multi_label'")

        # ---------------------------------------------------------------
        # Test hold-out and stratified train/validation split
        # ---------------------------------------------------------------
        self.validation_ratio = 0.2
        self.random_state = random_state
        self.test_ratio = test_ratio
        if test_ratio > 0:
            self.df_test = self.df.sample(frac=self.test_ratio, random_state=random_state)
            self.test_mode = True
        else:
            # Empty frame that still has the label/data columns, so the
            # to_array/to_tensor conversions below work (pd.DataFrame() has
            # no columns and raised KeyError there).
            self.df_test = self.df.iloc[0:0]
            self.test_mode = False
        self.training_validation_df = self.df.drop(self.df_test.index)
        self.training_validation_split = StratifiedShuffleSplit(n_splits=1, test_size=self.validation_ratio, random_state=random_state)

        def label_column_for(multi_label):
            """Name of the label column for the chosen classification type."""
            return 'charge_state_vector' if multi_label else 'most_abundant_charge_vector'

        def create_training_validation_split(df=self.training_validation_df, sssplit=self.training_validation_split):
            """Stratified split of ``df`` into two new DataFrames holding
            'modified_sequence_vector' and the label column."""
            label_column = label_column_for(self.data_type == "tensor_multi_label")
            trainval_ds_embed = np.array(df['modified_sequence_vector'])
            trainval_labels_embed = np.array(df[label_column])
            train_indices, val_indices = next(sssplit.split(trainval_ds_embed, trainval_labels_embed))
            df_train = pd.DataFrame({'modified_sequence_vector': trainval_ds_embed[train_indices], label_column: trainval_labels_embed[train_indices]})
            df_val = pd.DataFrame({'modified_sequence_vector': trainval_ds_embed[val_indices], label_column: trainval_labels_embed[val_indices]})
            return df_train, df_val
        self.df_train, self.df_val = create_training_validation_split(self.training_validation_df, self.training_validation_split)

        def to_array(df, multi_label=False):
            """Return (labels, data) as lists of NumPy arrays, one per row."""
            label = [np.array(x) for x in df[label_column_for(multi_label)]]
            data = [np.array(x) for x in df['modified_sequence_vector']]
            return label, data

        def to_tensor(df, multi_label=False):
            """Return (labels, data) converted to TensorFlow tensors."""
            label, data = to_array(df, multi_label)
            return tf.convert_to_tensor(label), tf.convert_to_tensor(data)

        def to_2d_tensor(df):
            """Return (label tensor, list of data arrays reshaped to
            (1, padding_length, 1))."""
            label, data = to_array(df)
            label = tf.convert_to_tensor(label)
            data = [np.reshape(np.array(x), (1, self.padding_length, 1)) for x in data]
            return label, data

        if self.data_type == "array":
            self.test_label, self.test_data = to_array(self.df_test)
            self.train_label, self.train_data = to_array(self.df_train)
            self.val_label, self.val_data = to_array(self.df_val)
        elif self.data_type == "tensor":
            self.test_label, self.test_data = to_tensor(self.df_test)
            self.train_label, self.train_data = to_tensor(self.df_train)
            self.val_label, self.val_data = to_tensor(self.df_val)
        elif self.data_type == "tensor_multi_label":
            self.test_label, self.test_data = to_tensor(self.df_test, True)
            self.train_label, self.train_data = to_tensor(self.df_train, True)
            self.val_label, self.val_data = to_tensor(self.df_val, True)
        elif self.data_type == "2d_tensor":
            self.test_label, self.test_data = to_2d_tensor(self.df_test)
            self.train_label, self.train_data = to_2d_tensor(self.df_train)
            self.val_label, self.val_data = to_2d_tensor(self.df_val)

    @staticmethod
    def _top_k_charges(label_dict, k=2):
        """Return up to ``k`` keys of ``label_dict`` ordered by descending value.

        Each key appears at most once. Ties keep the dict's insertion order
        (``sorted`` is stable, also with ``reverse=True``). The previous
        nested-loop version appended tied keys repeatedly when k > 2.
        """
        return sorted(label_dict, key=label_dict.get, reverse=True)[:k]

    @staticmethod
    def _has_skipped_charges(charge_state_vector):
        """Return True if the 0/1 vector has a gap between observed charges.

        e.g. [1, 1, 1, 0, 0, 0] -> False (contiguous)
             [1, 0, 0, 0, 0, 1] -> True  (intermediate charges skipped)
        An all-zero vector returns False.
        """
        was_found = False
        was_concluded = False
        for i in charge_state_vector:
            if i == 1 and not was_found:
                was_found = True
            if i == 0 and was_found:
                was_concluded = True
            if i == 1 and was_concluded:
                return True
        return False
            

In [149]:
# Build the multi-class dataset for the embedding model from all parquet files in data/.
my_dataset = PrecursorChargeStateDataset(
    classification_type="multi_class",
    model_type="embedding",
    charge_states=[1, 2, 3, 4, 5, 6],
    dir_path='data/',
    file_type='.parquet',
    columns_to_keep=['modified_sequence', 'precursor_charge', 'precursor_intensity'],
)

Step 1/12 complete. Combined 12 files into one DataFrame.
Step 2/12 complete. Dropped rows with NaN for intensities.
Step 3/12 complete. Removed charge states not in [1, 2, 3, 4, 5, 6].
Step 4/12 complete. Aggregated all sequences to unique sequences.
Step 5/12 complete. Removed 857 of 831677 sequences if sequence-length is represented less than 100 times.
Step 6/12 complete. Completed vocabulary with 23 entries.
Step 7/12 complete. Encoded all sequences.
Step 8/12 complete. Applied normalization, selected most abundant charge state and one-hot encoded it.
Step 9/12 complete. Encoded all occuring charge states per unique sequence in a binary vector.
Step 10/12 complete. Removed 728 of 830820 sequences if unique charge state distribution is represented less than 1000 times.
Step 11/12 complete. Selected top 2 charge states per sequence.
Step 12/12 complete. Generated dataset with 830092 sequences.



# Models

General idea:
Input: mod_seq_encoded, precursor_charge // precursor_charge_onehot
Output: one node per charge state (6 by default) --> the highest value is the most probable charge for the input sequence
Use: softmax output, categorical cross-entropy loss

stratified split:
- Prosit
- CCE
- SCCE // ?

evaluate models by:
- categorical accuracy
- f1 score // ?

In [188]:
# imports optimized
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Dense, Embedding, Flatten, Dropout, Bidirectional, GRU, Conv2D, Lambda
from tensorflow.keras.models import Model
from dlomix.layers.attention import AttentionLayer, DecoderAttentionLayer
import subprocess
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, ConfusionMatrixDisplay
import seaborn as sn


class ModelClass:
    """Build, train, evaluate and inspect a Keras model for a precursor charge-state dataset.

    The architecture is selected by ``dataset.model_type`` and the loss/metric by
    ``dataset.classification_type``. The dataset object must expose ``num_classes``,
    ``voc_len``, ``padding_length``, ``model_type``, ``classification_type``,
    ``charge_states`` and the ``train_/val_/test_`` ``data``/``label`` splits.
    """

    def __init__(self, dataset):
        """Store the dataset metadata and build the model for ``dataset.model_type``.

        Raises:
            ValueError: if ``dataset.model_type`` is not a supported architecture.
        """
        self.predicted = None
        self.prediction = None
        self.evaluated = None
        self.evaluation = None
        self.metrics = None
        self.loss = None
        self.history = None
        self.dataset = dataset
        self.num_classes = dataset.num_classes
        self.voc_len = dataset.voc_len
        self.max_len_seq = dataset.padding_length
        self.model_type = dataset.model_type
        self.classification_type = dataset.classification_type
        # Per-sample input shape of the dataset passed in (not of a notebook global).
        self.shape = dataset.train_data[0].shape
        self.wandb = False
        self.compiled = False
        self.fitted = False
        self.pretrained = False

        builders = {
            "embedding": self.embedding_model,
            "conv2d": self.conv2d_model,
            "prosit": self.prosit_model,
            "multihead": self.multihead_model,
            "multilabel": self.multilabel_model,
        }
        if self.model_type not in builders:
            raise ValueError("model_type must be one of the following: 'embedding', 'conv2d', 'prosit', 'multihead', 'multilabel'")
        self.model = builders[self.model_type]()

    def _embed(self, inputs):
        """Apply the shared Embedding layer (vocabulary ``voc_len``, output dim = padding length)."""
        return Embedding(input_dim=self.voc_len, output_dim=self.max_len_seq, input_length=self.max_len_seq)(inputs)

    def prosit_model(self):
        """Recurrent model: embedding, BiGRU/GRU encoder, dlomix attention, dense softmax head."""
        inputs = Input(shape=self.shape)
        y = self._embed(inputs)
        # Encoder
        y = Bidirectional(GRU(256, return_sequences=True))(y)
        y = Dropout(0.5)(y)
        y = GRU(512, return_sequences=True)(y)
        y = Dropout(0.5)(y)
        # Attention
        y = AttentionLayer()(y)
        # Regressor
        y = Dense(512, activation="relu")(y)
        y = Dropout(0.1)(y)
        # Output
        outputs = Dense(self.num_classes, activation="softmax")(y)
        return Model(inputs=inputs, outputs=outputs)

    def conv2d_model(self):
        """Convolutional model: rescale by 1/100, Conv2D(1x3), dense layer, softmax head."""
        inputs = Input(shape=self.shape)
        # `layers.Rescaling` comes from the `tensorflow.keras.layers` module imported as `layers`.
        y = layers.Rescaling(scale=1./100)(inputs)
        y = Conv2D(filters=128, kernel_size=(1, 3), strides=1, activation="relu", padding='same')(y)
        y = Flatten()(y)
        y = Dense(210, activation="relu")(y)
        outputs = Dense(self.num_classes, activation="softmax")(y)
        return Model(inputs=inputs, outputs=outputs)

    def embedding_model(self):
        """Embedding -> Flatten -> Dense(relu) -> softmax over ``num_classes``."""
        inputs = Input(shape=self.shape)
        y = self._embed(inputs)
        y = Flatten()(y)
        y = Dense(self.max_len_seq, activation="relu")(y)
        outputs = Dense(self.num_classes, activation="softmax")(y)
        return Model(inputs=inputs, outputs=outputs)

    def multihead_model(self):
        """Six sigmoid heads, head ``i`` fed only column ``i`` of the flattened embedding."""
        inputs = Input(shape=self.shape)
        y = self._embed(inputs)
        y = Flatten()(y)
        branch_outputs = []
        for i in range(6):
            # Bind `i` as a default argument: a plain closure would see the final
            # loop value whenever Keras re-calls the Lambda, so every head would
            # slice the same column.
            out = Lambda(lambda t, i=i: t[:, i:i + 1])(y)
            out = Dense(2, activation="sigmoid")(out)
            branch_outputs.append(out)
        return Model(inputs=inputs, outputs=branch_outputs)

    def multilabel_model(self):
        """Embedding -> Flatten -> Dense(relu) -> independent sigmoid per class."""
        inputs = Input(shape=self.shape)
        y = self._embed(inputs)
        y = Flatten()(y)
        y = Dense(self.max_len_seq, activation="relu")(y)
        outputs = Dense(self.num_classes, activation="sigmoid")(y)
        return Model(inputs=inputs, outputs=outputs)

    def summary(self):
        """Print the Keras layer summary of the built model."""
        self.model.summary()

    def wandb_init(self, api_key=None, project_name="precursor-charge-state-prediction"):
        """Start a Weights & Biases run and record the model configuration.

        Args:
            api_key: W&B API key. When None, no explicit login is done and wandb
                falls back to its own stored/environment credentials. Never
                hard-code a key in source.
            project_name: W&B project the run is logged to.
        """
        if api_key is not None:
            # In-process login keeps the key off the command line / process list.
            wandb.login(key=api_key)
        wandb.init(project=project_name)
        config = wandb.config
        config.model_type = self.model_type
        config.classification_type = self.classification_type
        config.num_classes = self.num_classes
        config.voc_len = self.voc_len
        config.max_len_seq = self.max_len_seq
        self.wandb = True

    def compile(self, lr=0.0001):
        """Compile with the loss/metric matching ``classification_type``.

        Args:
            lr: Adam learning rate for non-prosit multi-class models. The prosit
                model and multi-label models use Keras' default ``'adam'`` settings.

        Raises:
            ValueError: for an unknown ``classification_type``.
        """
        if self.classification_type == "multi_class":
            if self.model_type == "prosit":
                optimizer = 'adam'
            else:
                optimizer = keras.optimizers.Adam(learning_rate=lr)
            self.loss = 'categorical_crossentropy'
            self.metrics = 'categorical_accuracy'
        elif self.classification_type == "multi_label":
            optimizer = 'adam'
            self.loss = 'binary_crossentropy'
            self.metrics = 'binary_accuracy'
        else:
            raise ValueError("classification_type must be one of the following: 'multi_class', 'multi_label'")
        self.model.compile(optimizer=optimizer, loss=self.loss, metrics=[self.metrics])
        self.compiled = True

    def fit(self, batch_size=4096, callbacks=None, epochs=30, no_wandb=False):
        """Train on the dataset's train split, validating on its val split.

        Raises:
            ValueError: if the model is not compiled, or W&B was not initialised
                and ``no_wandb`` is False.
        """
        if not self.compiled:
            raise ValueError("Model must be compiled before fitting. Use model_class.compile().")
        if not self.wandb and not no_wandb:
            raise ValueError("You did not initialize weights&biases. Set model_class.fit(no_wandb=True) or use model_class.wandb_init(api_key= '...', project_name = '...')")
        if callbacks is None:
            callbacks = [] if no_wandb else [WandbCallback()]
        self.history = self.model.fit(self.dataset.train_data, self.dataset.train_label, epochs=epochs, batch_size=batch_size, validation_data=(self.dataset.val_data, self.dataset.val_label), callbacks=callbacks, verbose=1)
        self.fitted = True

    def plot_training(self):
        """Plot per-epoch training/validation loss and the compiled metric from the last fit().

        Raises:
            ValueError: if fit() has not been run.
        """
        if not self.fitted:
            raise ValueError("Model was not trained. No data to plot. Use model_class.fit()")
        # model.fit returns a History object; the per-epoch values live in `.history`.
        logs = self.history.history
        loss = logs['loss']
        val_loss = logs['val_loss']
        accuracy = logs[self.metrics]
        val_accuracy = logs['val_' + self.metrics]

        epochs = range(1, len(loss) + 1)
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

        ax1.plot(epochs, loss, 'b', label='Training Loss')
        ax1.plot(epochs, val_loss, 'r', label='Validation Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epochs')
        ax1.set_ylabel('Loss')
        ax1.legend()

        ax2.plot(epochs, accuracy, 'b', label='Training Accuracy')
        ax2.plot(epochs, val_accuracy, 'r', label='Validation Accuracy')
        ax2.set_title('Training and Validation Accuracy')
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Accuracy')
        ax2.legend()

        plt.tight_layout()
        plt.show()

    def load_weights(self, path):
        """Replace ``self.model`` with the full saved model at ``path``.

        Despite the name this uses ``tf.keras.saving.load_model``, so the architecture
        is loaded as well as the weights.
        """
        self.model = tf.keras.saving.load_model(path)
        self.pretrained = True

    def _use_internal_test_split(self, test_mode):
        """Decide whether evaluate/predict should use the dataset's own test split.

        Unfitted models use it only in ``test_mode``. Models fitted in this
        session use it unless a pretrained model was loaded.
        """
        if self.fitted:
            return not self.pretrained
        return test_mode

    def evaluate(self, test_data=None, test_label=None, test_mode=False):
        """Evaluate the model, store the result in ``self.evaluation`` and print it.

        Raises:
            ValueError: if explicit data is required but not provided.
        """
        if self._use_internal_test_split(test_mode):
            test_data, test_label = self.dataset.test_data, self.dataset.test_label
        elif test_data is None or test_label is None:
            raise ValueError("You did not provide test_data and test_label. Use model_class.evaluate(test_data, test_label) or set apply test_ratio>0 to model_class")
        self.evaluation = self.model.evaluate(test_data, test_label)
        self.evaluated = True
        print(f"test loss, test acc: {self.evaluation}")

    def predict(self, test_data=None, test_label=None, test_mode=False, no_verification=False):
        """Predict, store the result in ``self.prediction`` and optionally report metrics.

        Test-split selection follows evaluate(). Unless ``no_verification`` is set,
        multi-class predictions are compared with ``test_label`` (one-hot).

        Raises:
            ValueError: if data or labels are missing, or for multi-label verification.
        """
        if self._use_internal_test_split(test_mode):
            test_data, test_label = self.dataset.test_data, self.dataset.test_label
        elif test_data is None:
            raise ValueError("You did not provide test_data and test_label. Use model_class.predict(test_data, test_label) or set apply test_ratio>0 to model_class")
        self.prediction = self.model.predict(test_data)
        self.predicted = True

        if no_verification:
            return
        if self.classification_type != "multi_class":
            raise ValueError("Not implemented for multi-label.")
        if test_label is None:
            raise ValueError("You did not provide test_label for prediction-verification.")
        self._report_multi_class(test_label)

    def _report_multi_class(self, test_label):
        """Print accuracy and weighted/per-class scores and plot the confusion matrix."""
        # Fixing the label set keeps the confusion matrix and per-class arrays at
        # num_classes entries even when a class is absent from truth and predictions.
        class_indices = np.arange(self.num_classes)
        predicted_labels = np.argmax(self.prediction, axis=1)
        true_labels = np.argmax(test_label, axis=1)

        cm = confusion_matrix(true_labels, predicted_labels, labels=class_indices)
        print("Accuracy: ", accuracy_score(true_labels, predicted_labels))
        # TODO lookup logic for weighted, macro etc. -> presentation
        print("Precision_weighted: ", precision_score(true_labels, predicted_labels, average='weighted'))
        print("Recall_weighted: ", recall_score(true_labels, predicted_labels, average='weighted'))
        print("F1_weighted: ", f1_score(true_labels, predicted_labels, average='weighted'))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=self.dataset.charge_states)
        disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix')
        plt.show()

        per_class = pd.DataFrame()
        per_class['charge'] = list(self.dataset.charge_states)
        per_class['precision'] = precision_score(true_labels, predicted_labels, labels=class_indices, average=None)
        per_class['recall'] = recall_score(true_labels, predicted_labels, labels=class_indices, average=None)
        per_class['f1'] = f1_score(true_labels, predicted_labels, labels=class_indices, average=None)
        print(per_class)
    

In [194]:
# Build the model selected by my_dataset.model_type and print its layer summary.
model_class = ModelClass(my_dataset)
model_class.summary()

In [195]:
# Compile with the loss/metric matching the dataset's classification_type (compile() defaults to lr=0.0001).
model_class.compile() # TODO callback wandb etc.

In [175]:
# Start a Weights & Biases run; fit() refuses to run without this unless called with no_wandb=True.
model_class.wandb_init()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

Problem at: C:\Users\micro\OneDrive\Dokumente\GitHub\BachelorThesis\venv\Lib\site-packages\wandb\sdk\wandb_init.py 837 getcaller


KeyboardInterrupt: 

In [196]:
# Quick single-epoch training run without the W&B callback.
model_class.fit(epochs=1, no_wandb=True)

In [165]:
# After fit() (and no reloaded model), evaluate() without arguments uses the dataset's own test split.
model_class.evaluate()

test loss, test acc: [0.9858239889144897, 0.6071991920471191]


In [197]:
# After fit(), predict() without arguments predicts the dataset's test split and prints/plots verification metrics.
model_class.predict()

### Analysis / Plots etc.

In [None]:
# eval

## Postprocessing

In [None]:
def generate_charge_prediction_text(charge_predictions, charge_states=None):
    """Format a charge-state probability vector as a human-readable summary.

    Args:
        charge_predictions: 1-D sequence of probabilities, one per charge state
            (e.g. one row of a softmax prediction).
        charge_states: Optional label for each position. Defaults to
            ``1, 2, ..., len(charge_predictions)``, i.e. position ``i`` is charge ``i + 1``.

    Returns:
        One line naming the most probable charge state and its percentage, followed by
        "Prediction percentages for other states:" and one line per remaining state.

    Raises:
        ValueError: if ``charge_predictions`` is empty / not 1-D, or ``charge_states``
            does not have one entry per prediction.
    """
    predictions = np.asarray(charge_predictions)
    if predictions.ndim != 1 or predictions.size == 0:
        raise ValueError("charge_predictions must be a non-empty 1-D sequence of probabilities.")
    if charge_states is None:
        charge_states = range(1, predictions.size + 1)
    charge_states = list(charge_states)
    if len(charge_states) != predictions.size:
        raise ValueError("charge_states must have one entry per prediction.")

    max_charge_index = int(np.argmax(predictions))
    # Round once, on the percentage: rounding the fraction first loses precision
    # (0.534 would be shown as 53.0% instead of 53.4%).
    max_percentage = round(predictions[max_charge_index] * 100, 2)
    charge_text = f"The predicted charge state for the input sequence is {charge_states[max_charge_index]} [{max_percentage}%]."

    other_lines = [
        f"Charge state {charge}: {round(prediction * 100, 2)}%\n"
        for index, (charge, prediction) in enumerate(zip(charge_states, predictions))
        if index != max_charge_index
    ]
    percentage_text = "Prediction percentages for other states:\n" + "".join(other_lines)

    return charge_text + "\n" + percentage_text


# Example: format a hand-made probability vector over six charge states (1-6).
charge_predictions = np.array([0, 0.3, 0.53, 0.17, 0, 0])
output_text = generate_charge_prediction_text(charge_predictions)
print(output_text)