## One-hot encoded training

In [None]:
%load_ext autoreload
%autoreload 2

from embeddings.tokenization import load_labels, check_id_and_labels_exist
from embeddings.tokenization import KmerTokenizer
from embeddings.integer_embeddings import IntegerEmbeddings, OneHotEmbeddings
from embeddings.esmc_embeddings import ESMcEmbeddings
import numpy as np


In [None]:
# Tokenize the parquet genomes found under ../downloads.
# kmer_prefix / kmer_suffix_size / kmer_offset select which k-mers are
# extracted; reverse_complement=False presumably means forward strand only
# (confirm in embeddings.tokenization.KmerTokenizer).
tokenizer_config = {
    "input_path": "../downloads",
    "genome_col": "genome_name",
    "dna_sequence_col": "dna_sequence",
    "kmer_prefix": "CACATG",
    "kmer_suffix_size": 12,
    "file_type": "parquet",
    "reverse_complement": False,
    "kmer_offset": 0,
}
tokenizer = KmerTokenizer(**tokenizer_config)
token_collection = tokenizer.run_tokenizer(nr_of_cores=2)

In [None]:
# Build one-hot embeddings from the token collection produced above.
embedder = OneHotEmbeddings(
    token_collection=token_collection,
)
embeddings = embedder.run_embedder()

In [None]:
# Spot-check one genome: show the embedding stored under its "forward" key.
_example_genome = embeddings["GCF_000164865.1"]
print(_example_genome["forward"])


In [None]:
# Flatten the nested {genome_id: {strand_id: embedding}} mapping into parallel
# per-strand lists: the [genome_id, strand_id] pairs, the embeddings (X), the
# strand ids, and the genome ids (groups) for grouping rows by genome.
gid_and_strand_id = [
    [genome_id, strand]
    for genome_id, strand_map in embeddings.items()
    for strand in strand_map
]

X = [embeddings[pair[0]][pair[1]] for pair in gid_and_strand_id]

ids = [pair[1] for pair in gid_and_strand_id]
groups = [pair[0] for pair in gid_and_strand_id]

In [None]:
# Load per-genome labels for the "madin_categorical_motility_binary" column.
labels = load_labels(
    file_path="../downloads/labels.csv",
    id="genome_name",
    label="madin_categorical_motility_binary",
    sep=",",
    freq_others=None,
)
# All three lookups complete before any name is bound.
label_dict_literal, label_dict, int2label = (
    labels[key] for key in ("label_dict", "label_dict_int", "int2label")
)

### ESM-C embeddings

In [None]:
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import torch
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Embed the same token collection with ESM-C, reusing the tokenizer's
# k-mer suffix size and "mean_per_token" pooling.
esmc_embedder = ESMcEmbeddings(
    token_collection=token_collection,
    kmer_suffix_size=tokenizer.kmer_suffix_size,
    hidden_state=None,
    pooling="mean_per_token",
)
esmc_embeddings = esmc_embedder.run_embedder()

In [None]:
# Labels are loaded here so this section does not depend on the one-hot cells.
labels = load_labels(file_path="../downloads/labels.csv", id="genome_name", label="madin_categorical_motility_binary", sep=",", freq_others=None)
label_dict_literal, label_dict, int2label = labels["label_dict"], labels["label_dict_int"], labels["int2label"]

# Flatten {genome_id: {strand_id: embedding}} into parallel per-strand lists.
gid_and_strand_id = [[gid, strand_id] for gid, strands in esmc_embeddings.items() for strand_id in strands]
ids = [strand_id for _, strand_id in gid_and_strand_id]
groups = [gid for gid, _ in gid_and_strand_id]

# Take element [0] of each stored embedding (NOTE(review): presumably drops a
# leading batch axis - confirm against ESMcEmbeddings), stack everything into
# one tensor and convert to a float32 NumPy array once. detach() keeps
# .numpy() working if the tensors track gradients; casting to float32 first
# avoids .numpy() failing on dtypes NumPy lacks (e.g. bfloat16).
X = (
    torch.stack([esmc_embeddings[gid][strand_id][0] for gid, strand_id in gid_and_strand_id])
    .detach()
    .to("cpu", dtype=torch.float32)
    .numpy()
)

# Keep only strands whose genome has a label. The same mask is applied to
# X, y, ids and groups so all four stay row-aligned for downstream use.
has_label = np.array([gid in label_dict for gid in groups], dtype=bool)
X = X[has_label]
y = np.array([label_dict[gid] for gid in groups if gid in label_dict])
ids = [strand_id for strand_id, keep in zip(ids, has_label) if keep]
groups = [gid for gid, keep in zip(groups, has_label) if keep]
print(len(X), len(y))



In [None]:
# Project the labeled embeddings onto their first two principal components.
pca = PCA(n_components=2)
projected_mean_embeddings = pca.fit(X).transform(X)

# Scatter the 2-D projection, colouring each point by its label in y.
# (No clustering is performed here; points are coloured by label only.)
fig, ax = plt.subplots(figsize=(4, 4))
sns.scatterplot(
    x=projected_mean_embeddings[:, 0],
    y=projected_mean_embeddings[:, 1],
    hue=y,
    ax=ax,
)
ax.set_title("PCA of mean embeddings ")
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
plt.show()

In [None]:
from utilities.RF_classification import hist_gradient_boosting_classifier, pca_plot, model_context


# Bundle the features, labels and run settings into a model context, then
# hand it to the HistGradientBoosting classifier helper. k_folds=5 presumably
# configures 5-fold cross-validation (confirm in utilities.RF_classification).
ctx = model_context(
    X=X,
    y=y,
    output_directory="../results",
    phenotype="madin_categorical_motility_binary",
    kmer_prefix=tokenizer.kmer_prefix,
    kmer_suffix_size=tokenizer.kmer_suffix_size,
    model_type="HistGradientBoosting",
    int2label=int2label,
    k_folds=5,
)
hist_gradient_boosting_classifier(context=ctx)


In [None]:
# Draw the PCA plot for the data held in ctx. NOTE(review): save=False
# presumably means display only, without writing a file - confirm in
# utilities.RF_classification.pca_plot.
pca_plot(ctx, save=False)