In [None]:
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, GRU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.model_selection import KFold
from tqdm import tqdm
import matplotlib.pyplot as plt

# Assuming these are utility functions from LSTM_peptides.py
from sample import _onehotencode, _sample_with_temp, _save_flags


In [None]:
class SequenceHandler:
    """Load, inspect, pad and one-hot encode sequences read from a text file.

    Parameters
    ----------
    window, step, refs :
        Stored on the instance; not used by any method of this class.
    """

    def __init__(self, window=0, step=1, refs=True):
        self.window = window
        self.step = step
        self.refs = refs
        self.sequences = []    # sequence strings (space-padded after pad_sequences)
        self.vocab = None      # sorted list of distinct characters seen at load time
        self.X = None          # element [0] of _onehotencode(seq), stacked over sequences
        self.y = None          # element [1] of _onehotencode(seq), stacked over sequences
        self.generated = None  # list of sampled sequence strings, assigned by the caller

    def load_sequences(self, infile):
        """Read one sequence per line from ``infile`` and build ``self.vocab``.

        Surrounding whitespace is stripped and blank lines are skipped; blank
        lines would otherwise become all-padding examples after padding.
        The vocabulary is computed before padding, so it does not contain ``' '``
        unless the raw sequences do.
        """
        with open(infile, "r") as f:
            stripped = (line.strip() for line in f)
            self.sequences = [seq for seq in stripped if seq]
        self.vocab = sorted(set("".join(self.sequences)))

    def analyze_training(self):
        """Print the number of loaded sequences and the vocabulary."""
        print("Number of sequences loaded: %d" % len(self.sequences))
        print("Vocabulary: ", self.vocab)

    def pad_sequences(self, padlen=0):
        """Right-pad every sequence with spaces to ``padlen`` (or to the longest one if ``padlen <= 0``).

        Sequences already longer than ``padlen`` are left unchanged (not truncated).

        Raises
        ------
        ValueError
            If no sequences have been loaded.
        """
        if not self.sequences:
            raise ValueError("No sequences loaded; call load_sequences() first")
        maxlen = max(len(seq) for seq in self.sequences)
        padlen = padlen if padlen > 0 else maxlen
        self.sequences = [seq.ljust(padlen) for seq in self.sequences]

    def one_hot_encode(self, target='all'):
        """Encode every sequence with ``_onehotencode`` and stack the results into ``X`` and ``y``.

        ``target`` is accepted for compatibility but currently ignored.
        """
        # Encode each sequence once and reuse the result for both X and y.
        encoded = [_onehotencode(seq) for seq in self.sequences]
        self.X = np.array([enc[0] for enc in encoded])
        self.y = np.array([enc[1] for enc in encoded])

    def _write_generated(self, path):
        """Write ``self.generated`` to ``path``, one sequence per line, creating parent directories."""
        if self.generated is None:
            raise ValueError("No generated sequences to write; assign `generated` first")
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, 'w') as f:
            f.write("\n".join(self.generated))

    def analyze_generated(self, num, fname, plot=True):
        """Write the generated sequences to ``fname``.

        No analysis is performed; ``num`` and ``plot`` are accepted for
        compatibility but ignored.
        """
        self._write_generated(fname)

    def save_generated(self, logdir, filename):
        """Write the generated sequences to ``filename`` (``logdir`` is accepted but unused)."""
        self._write_generated(filename)


class Model:
    """Character-level recurrent model (stacked LSTM or GRU + softmax Dense) built with Keras.

    Every recurrent layer returns full sequences, so the network emits one softmax
    distribution over ``outshape`` classes per input time step; :meth:`sample`
    reads the distribution at the last time step.

    Parameters
    ----------
    n_vocab : int
        Width of the input feature axis (``input_shape=(None, n_vocab)``).
    outshape : int
        Number of units in the softmax output layer.
    session_name : str
        Run name; ``logdir`` is ``"./" + session_name``.
    n_units : int
        Units per recurrent layer.
    batch : int
        Batch size passed to ``fit``.
    layers : int
        Number of stacked recurrent layers (at least one is always built).
    cell : str
        ``"LSTM"`` selects LSTM layers; any other value selects GRU.
    loss : str
        Keras loss identifier passed to ``compile``.
    lr : float
        Adam learning rate.
    dropoutfract : float
        Input dropout fraction applied inside every recurrent layer.
    l2_reg, ask, seed :
        Stored on the instance; not used by any method of this class.
    vocab : sequence of str, optional
        Index -> character table used by :meth:`sample`. May also be assigned
        to ``self.vocab`` after construction.
    """

    def __init__(self, n_vocab, outshape, session_name, n_units=64, batch=128, layers=2, cell="LSTM",
                 loss='categorical_crossentropy', lr=0.01, dropoutfract=0.1, l2_reg=None, ask=False, seed=42,
                 vocab=None):
        self.n_vocab = n_vocab
        self.outshape = outshape
        self.session_name = session_name
        self.n_units = n_units
        self.batch = batch
        self.layers = layers
        self.cell = cell
        self.loss = loss
        self.lr = lr
        self.dropoutfract = dropoutfract
        self.l2_reg = l2_reg
        self.ask = ask
        self.seed = seed
        self.vocab = vocab
        self.model = None
        self.logdir = "./" + self.session_name

        self._build_model()

    def _build_model(self):
        """Build and compile the Keras network described by the constructor arguments."""
        cell_type = LSTM if self.cell == "LSTM" else GRU
        dropout = self.dropoutfract or 0.0

        model = Sequential()
        for i in range(max(1, self.layers)):
            extra = {"input_shape": (None, self.n_vocab)} if i == 0 else {}
            # All recurrent layers return the full sequence: sample() indexes the
            # last time step of the output, and the Dense layer then acts per step.
            # (return_sequences does not change weight shapes, so checkpoints stay loadable.)
            model.add(cell_type(self.n_units, return_sequences=True, dropout=dropout, **extra))

        model.add(Dense(self.outshape, activation="softmax"))
        model.compile(optimizer=Adam(learning_rate=self.lr), loss=self.loss)

        self.model = model

    def train(self, X, y, epochs=100, valsplit=0.2, sample=0):
        """Fit the model, keeping the best checkpoints under ``<logdir>/checkpoint``.

        ``sample`` is accepted for compatibility but unused.

        Returns
        -------
        The Keras ``History`` object returned by ``fit``.
        """
        ckpt_dir = os.path.join(self.logdir, "checkpoint")
        os.makedirs(ckpt_dir, exist_ok=True)
        checkpoint = ModelCheckpoint(filepath=os.path.join(ckpt_dir, "model_epoch_{epoch:02d}.hdf5"),
                                     save_best_only=True)
        return self.model.fit(X, y, batch_size=self.batch, epochs=epochs, validation_split=valsplit,
                              callbacks=[checkpoint])

    def plot_losses(self):
        """Plot training (and, if recorded, validation) loss from the most recent ``fit``."""
        hist = self.model.history.history
        fig, ax = plt.subplots()
        ax.plot(hist['loss'], label='Train')
        # val_loss only exists when training used a validation split.
        if 'val_loss' in hist:
            ax.plot(hist['val_loss'], label='Validation')
        ax.set(title='Model loss', ylabel='Loss', xlabel='Epoch')
        ax.legend(loc='upper left')
        plt.show()

    def sample(self, num=100, minlen=7, maxlen=50, start=None, temp=2.5, show=False):
        """Generate up to ``num`` sequences by autoregressive sampling.

        Each draw starts from ``start`` (default ``'B'``), repeatedly appends the
        character sampled at temperature ``temp`` from the last-step prediction,
        and stops at a ``' '`` character or once the sequence (start included)
        exceeds ``maxlen`` characters. A leading ``'B'`` start token and trailing
        whitespace are stripped; results shorter than ``minlen`` are discarded.
        If ``maxlen`` is falsy, a random cap in [7, 50) is drawn per sequence.

        NOTE(review): early termination relies on ``' '`` being in ``self.vocab``;
        a vocabulary built from unpadded sequences will not contain it.

        Raises
        ------
        ValueError
            If ``self.vocab`` has not been set.
        """
        vocab = getattr(self, "vocab", None)
        if vocab is None:
            raise ValueError("Model.vocab is not set; assign the training vocabulary before sampling")

        sampled = []
        lcntr = 0
        for rs in tqdm(range(num)):
            # Seed both RNGs per draw so each sample is reproducible; this block
            # uses NumPy's RNG directly, and which RNG _sample_with_temp uses is not
            # visible here.
            random.seed(rs)
            np.random.seed(rs)
            longest = maxlen if maxlen else np.random.randint(7, 50)

            start_aa = start if start else 'B'
            sequence = start_aa

            while sequence[-1] != ' ' and len(sequence) <= longest:
                x, _, _ = _onehotencode(sequence)
                # verbose=0: predict() is called once per character; avoid a progress bar each time.
                preds = self.model.predict(x, verbose=0)[0][-1]
                next_aa = _sample_with_temp(preds, temp=temp)
                sequence += vocab[next_aa]

            if start_aa == 'B':
                sequence = sequence[1:].rstrip()
            else:
                sequence = sequence.rstrip()

            if len(sequence) < minlen:
                lcntr += 1
                continue

            sampled.append(sequence)
            if show:
                print(sequence)

        print("\t%i sequences were shorter than %i" % (lcntr, minlen))
        return sampled

    def load_model(self, filename):
        """Load weights from ``filename`` into the already-built network."""
        self.model.load_weights(filename)

    def finetuneinit(self, new_session_name):
        """Switch to a new session name/logdir and make sure that directory exists."""
        self.session_name = new_session_name
        self.logdir = "./" + self.session_name
        os.makedirs(self.logdir, exist_ok=True)

In [None]:
# Parameters for pretraining
SEED = 42
sessname = "train100"
infile = "new_sequences.csv"
neurons = 64  # other candidate values: 128, 512
layers = 2  # other candidate values: 3
epochs = 100  # other candidate values: 150, 200, 300
batchsize = 128  # other candidate values: 100, 512, 1000
window = 0
step = 1
target = 'all'
valsplit = 0.2
dropout = 0.1
learningrate = 0.01
l2_rate = None

# Seed the Python, NumPy and TensorFlow RNGs up front for reproducibility.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Loading sequence data, analyze, pad and encode it
data = SequenceHandler(window=window, step=step)
print("Loading sequences...")
data.load_sequences(infile)
data.analyze_training()

# Pad sequences
print("\nPadding sequences...")
data.pad_sequences(padlen=0)

# One-hot encode padded sequences
print("One-hot encoding sequences...")
data.one_hot_encode(target=target)

# Building the LSTM model
print("\nBuilding model...")
model = Model(n_vocab=len(data.vocab), outshape=len(data.vocab), session_name=sessname,
              n_units=neurons, batch=batchsize, layers=layers, cell="LSTM",
              loss='categorical_crossentropy', lr=learningrate, dropoutfract=dropout,
              l2_reg=l2_rate, ask=True, seed=SEED)
# Model.sample() maps predicted indices back to characters through `model.vocab`.
model.vocab = data.vocab
print("Model built!")

# Training model on data
print("\nTRAINING MODEL FOR {} EPOCHS...\n".format(epochs))
model.train(data.X, data.y, epochs=epochs, valsplit=valsplit, sample=0)
model.plot_losses()  # Plot loss

# `save_model_instance` was not defined or imported in this notebook. Persist the final
# weights next to the best-epoch checkpoints that ModelCheckpoint writes during train().
os.makedirs(model.logdir, exist_ok=True)
model.model.save_weights(os.path.join(model.logdir, "final.weights.h5"))

In [None]:
# Parameters for sampling
sample_name = "testsample"
# NOTE(review): this path is not produced by the pretraining cell above, which writes to
# ./<sessname>/checkpoint/; point it at the checkpoint you actually want to sample from.
modfile = "pretrained_model/checkpoint/model_epoch_67.hdf5"
sample_size = 100
temperature = 2.5
sample_length = 36

print("\nUSING PRETRAINED MODEL... ({})\n".format(modfile))
# `load_model_instance` was not defined or imported in this notebook. Rebuild the
# architecture with the pretraining hyperparameters, then load the checkpoint weights.
model = Model(n_vocab=len(data.vocab), outshape=len(data.vocab), session_name=sample_name,
              n_units=neurons, batch=batchsize, layers=layers, cell="LSTM",
              loss='categorical_crossentropy', lr=learningrate, dropoutfract=dropout,
              l2_reg=l2_rate, ask=True, seed=42)
model.load_model(modfile)
# Model.sample() maps predicted indices back to characters through `model.vocab`.
model.vocab = data.vocab
os.makedirs(model.logdir, exist_ok=True)  # outputs below are written into this directory

# Generating new data through sampling
print("\nSAMPLING {} SEQUENCES...\n".format(sample_size))
data.generated = model.sample(sample_size, start='B', maxlen=sample_length, show=False, temp=temperature)
data.analyze_generated(sample_size, fname=os.path.join(model.logdir, 'analysis_temp' + str(temperature) + '.txt'),
                       plot=True)
data.save_generated(model.logdir, os.path.join(model.logdir, 'sampled_sequences_temp' + str(temperature) + '.csv'))


In [None]:
# Parameters for finetuning
finetune_name = "finetune10"
finetune_file = "finetune_set.csv"
finetune_epochs = 10

# Load and encode the finetuning set the same way as the pretraining data
# (previously `finetune_file` was never read and the pretraining data was reused).
print("Loading finetuning sequences...")
finetune_data = SequenceHandler(window=window, step=step)
finetune_data.load_sequences(finetune_file)
finetune_data.analyze_training()
finetune_data.pad_sequences(padlen=0)
finetune_data.one_hot_encode(target=target)

print("\nUSING PRETRAINED MODEL FOR FINETUNING... ({})\n".format(modfile))
print("Loading model...")
# `load_model_instance` was not defined or imported in this notebook. Rebuild the network
# with the pretraining dimensions (they must match the checkpoint) and load its weights.
model = Model(n_vocab=len(data.vocab), outshape=len(data.vocab), session_name=sessname,
              n_units=neurons, batch=batchsize, layers=layers, cell="LSTM",
              loss='categorical_crossentropy', lr=learningrate, dropoutfract=dropout,
              l2_reg=l2_rate, ask=True, seed=42)
model.load_model(modfile)
model.vocab = data.vocab  # index -> character table used by Model.sample()
model.finetuneinit(finetune_name)  # Generate new session folders for finetuning run

print("Finetuning model...")
model.train(finetune_data.X, finetune_data.y, epochs=finetune_epochs, valsplit=valsplit, sample=0)
model.plot_losses()  # Plot loss

# Persist the finetuned weights in the finetuning session directory
# (`save_model_instance` was not defined or imported in this notebook).
os.makedirs(model.logdir, exist_ok=True)
model.model.save_weights(os.path.join(model.logdir, "final.weights.h5"))