# LightGBM Training

##### Importing Dependencies

In [None]:
!pip install lightgbm

import re, random, pickle, glob, os, difflib, itertools, logging, warnings, collections
warnings.simplefilter(action='ignore')
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential, load_model, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import optimizers
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import *
from sklearn.svm import SVC

import seaborn as sns
from matplotlib import pyplot as plt
import pandas as pd

from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
import lightgbm as lgb
import os

In [None]:
# --- User configuration: replace every placeholder before running the notebook ---
# Raw strings keep Windows-style backslashes literal; in a regular string the
# '\t' inside 'path\to\...' would be parsed as a TAB character.
root = r'path\to\datafolder' # path to folder containing data files
model_path = r'path\to\model_folder' # path to saved transformer model
model_name = 'name_of_transformer_model' # Find the name of the transformer model in 'Train Transformer.ipynb'
lineages_path = 'name_of_lineages_file' # The lineages file records which subfamily each viral strain belongs to.
# Place the lineages file inside the datafolder.

# Names read by the data-loading cells further down the notebook.
datapath = root # data folder alias used when opening the input files
viral_refs = 'name_of_training_fasta' # file name of the training-set FASTA inside the datafolder

## 1) Preprocessing

#### 1.1) Loading in Training Set Genomes 

In [None]:
# Read every line of the training-set FASTA. `root` is the configured data
# folder; `viral_refs` must be set to the FASTA file name before running this cell.
# The context manager closes the file handle as soon as the lines are read.
with open(os.path.join(root, viral_refs)) as fasta_file: ### insert appropriate path to training set FASTA
    viral_genomes = fasta_file.readlines()

# The slicing below treats even-indexed lines as headers and odd-indexed lines
# as sequences (i.e. it assumes a two-line-per-record FASTA -- TODO confirm).
# Strain name = header without its first and last character (presumably the
# leading '>' and the trailing newline) and without its final 5 characters.
strains = [header[1:-1][:-5] for header in viral_genomes[0::2]]
# Upper-case each sequence and drop every character that is not A, T, C, G or N.
genomes = np.array([
    re.sub('[^ATCGN]+', '', sequence.replace('\n', '').upper())
    for sequence in viral_genomes[1::2]
])

#### 1.2) Create Dictionary (Viral Strain &rarr; Class)

In [None]:
### insert appropriate path to classification TXT
# The lineages file is expected inside the data folder (`root`) under the name
# configured in `lineages_path`. The context manager closes the handle after reading.
with open(os.path.join(root, lineages_path)) as lineages_file:
    genome_cls = lineages_file.readlines()

# Placeholder strain name used until the first 'vir_name' line has been seen.
strain = 'N\\A'
strain_cls = []
for line in genome_cls:
    # A line that mentions the current strain (followed by a space) carries its class.
    if (strain + ' ') in line:
        # Class text starts two characters after the first '>' and excludes the last character.
        cls = line[line.index('>') + 2:-1]
        try:
            # Keep only the text before the first 'pa'
            # (presumably e.g. 'Alphapa...' -> 'Alpha'; verify against the lineages file).
            cls = cls[:cls.index('pa')]
        except ValueError:
            # No 'pa' in the class text: skip the rest of this line.
            continue
        strain_cls.append((strain, cls))
    # A 'vir_name' line sets the strain that the following lines refer to;
    # the name is taken from character 10 up to (not including) the last character.
    if 'vir_name' in line:
        strain = line[10:-1]

def Convert(tup, di=None):
    """Build a dictionary from an iterable of ``(key, value)`` pairs.

    Args:
        tup: Iterable of 2-item pairs. When a key occurs more than once, the
            last pair wins.
        di: Ignored. Kept (and now optional) only so existing two-argument
            calls keep working; the object passed in is never read or modified.

    Returns:
        A new ``dict`` holding the pairs from ``tup``.
    """
    return dict(tup)
      
# Collapse the (strain, class) pairs into a strain -> class dictionary.
strain_cls = Convert(strain_cls, {})

#### 1.3) Create Training Dataframe

In [None]:
# Strain -> class lookup as a two-column frame ('strains', 'cls').
temp = pd.DataFrame(list(strain_cls.items()), columns=['strains', 'cls'])

# One row per training genome, joined to its class on the strain name.
# merge() defaults to an inner join, so strains missing from `temp` are dropped.
training_df = pd.DataFrame({'strains': strains, 'genomes': genomes}).merge(temp, on='strains')
training_df

#### 1.4) Gather Genomes by Class

In [None]:
def _with_reverse_complements(sequences):
    """Return every sequence immediately followed by its reverse complement.

    Args:
        sequences: Iterable of nucleotide strings.

    Returns:
        A list twice as long as the input: ``[s0, rc(s0), s1, rc(s1), ...]``.
    """
    augmented = []
    for sequence in sequences:
        augmented.append(sequence)
        augmented.append(str(Seq(sequence).reverse_complement()))
    return augmented


# Genomes are read from `training_df`, the frame holding the 'genomes' and 'cls' columns.
# Every class other than Alpha/Beta/Gamma is pooled into "other".
_is_other = ~training_df['cls'].isin(['Alpha', 'Beta', 'Gamma'])

alpha_genomes = _with_reverse_complements(training_df.loc[training_df['cls'] == 'Alpha', 'genomes'])
beta_genomes = _with_reverse_complements(training_df.loc[training_df['cls'] == 'Beta', 'genomes'])
gamma_genomes = _with_reverse_complements(training_df.loc[training_df['cls'] == 'Gamma', 'genomes'])
other_genomes = _with_reverse_complements(training_df.loc[_is_other, 'genomes'])

print(f'# Alpha Genomes: {len(alpha_genomes)} | # Beta Genomes: {len(beta_genomes)} | # Gamma Genomes: {len(gamma_genomes)} | # Other Genomes: {len(other_genomes)}')

#### 1.5) Generate Reads from the Viral Genomes

In [None]:
maxlen = 150 # length of each read
num_reads = 7000 // maxlen # reads sampled per genome (46 when maxlen is 150)


def _sample_reads(genome_list, read_len, reads_per_genome):
    """Sample fixed-length reads at random positions from each genome.

    Args:
        genome_list: Iterable of genome strings.
        read_len: Number of characters in every read.
        reads_per_genome: How many reads to draw from each genome
            (start positions are drawn independently, so reads may overlap
            or repeat).

    Returns:
        A list of reads, each a list of ``read_len`` single-character strings.
        Genomes shorter than ``read_len`` contribute no reads, because they
        cannot yield a full-length window.
    """
    reads = []
    for genome in genome_list:
        # Valid start offsets are 0 .. len(genome) - read_len inclusive;
        # randint's upper bound is exclusive, hence the "+ 1".
        n_starts = len(genome) - read_len + 1
        if n_starts < 1:
            continue
        for start in np.random.randint(0, n_starts, reads_per_genome):
            reads.append(list(genome[start:start + read_len]))
    return reads


alpha_reads = _sample_reads(alpha_genomes, maxlen, num_reads)
beta_reads = _sample_reads(beta_genomes, maxlen, num_reads)
gamma_reads = _sample_reads(gamma_genomes, maxlen, num_reads)
other_reads = _sample_reads(other_genomes, maxlen, num_reads)

print(f'# Alpha Reads:  {len(alpha_reads)} | # Beta Reads: {len(beta_reads)} | # Gamma Reads:  {len(gamma_reads)} | # Other Reads: {len(other_reads)}')
print(f'# Total Reads: {sum([len(alpha_reads),len(beta_reads),len(gamma_reads),len(other_reads)])}')

#### 1.7) Tokenize Viral Reads

In [None]:
# Nucleotide alphabet; each symbol gets a 1-based integer id in this order.
tokens = "ACGTN"
mapping = {symbol: token_id for token_id, symbol in enumerate(tokens, start=1)}

def seqs2cat(seqs, mapping):
    """Encode sequences of symbols as arrays of integer token ids.

    Args:
        seqs: Iterable of sequences; each sequence is an iterable of symbols.
        mapping: Dict from symbol to integer id. Symbols missing from the
            dict are encoded as 5.

    Returns:
        A NumPy array built from one encoded array per input sequence.
    """
    encoded = [
        np.array([mapping.get(symbol, 5) for symbol in seq])
        for seq in seqs
    ]
    return np.array(encoded)

# Encode each class's reads as integer token ids using the shared mapping.
alpha_tokenized, beta_tokenized, gamma_tokenized, other_tokenized = (
    seqs2cat(class_reads, mapping)
    for class_reads in (alpha_reads, beta_reads, gamma_reads, other_reads)
)

#### 1.8) Load in Transformer

In [None]:
print(tf.__version__)

# Load the saved model without compiling it.
transformer = load_model(
    os.path.join(model_path, model_name),
    compile=False,
)
# Re-wire the model so that its output is the output of the second-to-last layer.
transformer = Model(
    inputs=transformer.input,
    outputs=transformer.layers[-2].output,
)
transformer.summary()

#### 1.9) Encode Tokenized Reads through Transformer

In [None]:
# Run the tokenized reads through the transformer and average its output over axis 1,
# giving one vector per read (assumes axis 1 is the position axis -- TODO confirm).
alpha_pred = transformer.predict(alpha_tokenized, verbose=1).mean(axis=1)
beta_pred = transformer.predict(beta_tokenized, verbose=1).mean(axis=1)

In [None]:
# Run the tokenized reads through the transformer and average its output over axis 1,
# giving one vector per read (assumes axis 1 is the position axis -- TODO confirm).
gamma_pred = transformer.predict(gamma_tokenized, verbose=1).mean(axis=1)
other_pred = transformer.predict(other_tokenized, verbose=1).mean(axis=1)


## 2) Training

#### 2.1) Organize Viral Encodings and Corresponding Classifications

In [None]:
# Integer class id -> class name. Id 3 is the pooled class for the "other" reads,
# so it is named "Other" to match `labelgt` below.
classDict = {0: "Alpha", 1: "Beta", 2: "Gamma", 3: "Other"}
# Feature matrix: the per-read encodings stacked in the order Alpha, Beta, Gamma, Other.
pred = np.concatenate((alpha_pred, beta_pred, gamma_pred, other_pred))
# Integer ground-truth labels, aligned row-for-row with `pred`.
gt = np.concatenate(([0]*len(alpha_pred), [1]*len(beta_pred), [2]*len(gamma_pred), [3]*len(other_pred)))
# The same ground truth as class-name strings.
labelgt = np.concatenate((["Alpha"]*len(alpha_pred), ["Beta"]*len(beta_pred), ["Gamma"]*len(gamma_pred), ["Other"]*len(other_pred)))

#### 2.2) Creating Training/Validation Split

In [None]:
# Default setting is a 70/30 split; the fixed random_state makes the split repeatable.
X_train, X_validate, y_train, y_validate = train_test_split(
    pred,
    gt,
    test_size=0.3,
    random_state=42,
)

# LightGBM Dataset wrappers around the two partitions.
train_data = lgb.Dataset(X_train, label=y_train)
validate_data = lgb.Dataset(X_validate, label=y_validate)

print(
    f'Training Set Size: {len(X_train)}',
    f'Validation Set Size: {len(X_validate)}',
    f'Total Size: {len(X_train)+len(X_validate)} (Sanity Check)',
    sep='\n',
)

#### 2.3) Fitting LightGBM Model to Training Set

In [None]:
# LightGBM classifier with library-default hyperparameters, fitted on the training partition.
classifier = lgb.LGBMClassifier()
classifier.fit(X=X_train, y=y_train)

## 3) Validation

#### 3.1) Model Performance on Validation Set

In [None]:
from sklearn.metrics import accuracy_score
from IPython.display import display, HTML


def _class_accuracy(y_true, y_hat, class_id):
    """Return the fraction of samples of ``class_id`` that were predicted correctly.

    Args:
        y_true: NumPy array of ground-truth integer labels.
        y_hat: NumPy array of predicted labels, aligned with ``y_true``.
        class_id: The label whose accuracy is computed.

    Returns:
        The accuracy as a float, or NaN when ``y_true`` contains no sample of
        ``class_id`` (avoids a division by zero).
    """
    in_class = y_true == class_id
    if not in_class.any():
        return float('nan')
    return float(np.mean(y_hat[in_class] == y_true[in_class]))


# Predict the validation set once and reuse the result for every class.
y_pred = classifier.predict(X_validate)

val_alpha_accuracy = _class_accuracy(y_validate, y_pred, 0)
val_beta_accuracy = _class_accuracy(y_validate, y_pred, 1)
val_gamma_accuracy = _class_accuracy(y_validate, y_pred, 2)
val_other_accuracy = _class_accuracy(y_validate, y_pred, 3)

# Single-row table: one column per class.
accuracy_scores = np.array([[val_alpha_accuracy, val_beta_accuracy, val_gamma_accuracy, val_other_accuracy]])
accuracies = pd.DataFrame(accuracy_scores, columns=["Alpha","Beta","Gamma","Other"], index=["Accuracy"])
display(accuracies)

# Configure the figure BEFORE plotting: rc settings only affect figures created
# afterwards, and sns.set() resets style and palette, so it has to come first.
sns.set(rc={'figure.figsize':(7,5)})
sns.set_style("darkgrid")
sns.set_palette("rocket")
sns.barplot(data=accuracies)
plt.show()

## 4) Saving Model

In [None]:
# `sklearn.externals.joblib` was removed from scikit-learn (0.23+); the
# standalone `joblib` package, a scikit-learn dependency, is the replacement.
# The fallback keeps the cell working on old scikit-learn installs without it.
try:
    import joblib
except ImportError:
    from sklearn.externals import joblib

# Persist the fitted classifier next to the transformer model.
joblib.dump(classifier, os.path.join(model_path, 'LightGBM_Model.pkl'))