### Load model

In [1]:
import torch
from torch.utils.data import DataLoader
import numpy as np
from utils import count_parameters, read
from config import hyena_config
from hyena_simp import Config, HyenaConfig, AuthenticHyenaBlock, FastaModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = FastaModel(hyena_config, AuthenticHyenaBlock)
model.load_state_dict(torch.load('models/model_state_dict.pt'))
model.eval()

# model_2 = FastaModel(hyena_config, AuthenticHyenaBlock)
# model_2.load_state_dict(torch.load('models/model_state_dict.pt'))
# model_2.eval()

FastaModel(
  (tok_emb): Embedding(13, 10)
  (dropout): Dropout(p=0.2, inplace=False)
  (layers): Sequential(
    (0): AuthenticHyenaBlock(
      (proj_input): Projection(
        (linear): Linear(in_features=10, out_features=30, bias=True)
        (conv): Conv1d(30, 30, kernel_size=(3,), stride=(1,), padding=(2,), groups=30)
      )
      (proj_output): Linear(in_features=10, out_features=10, bias=True)
      (filter): AuthenticHyenaFilter(
        (pos_emb): PositionalEmbedding()
        (mlp): Sequential(
          (0): Linear(in_features=5, out_features=64, bias=True)
          (1): Sin()
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): Sin()
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): Sin()
          (6): Linear(in_features=64, out_features=20, bias=False)
        )
        (window): AuthenticWindow()
      )
      (dropout): Dropout(p=0.0, inplace=False)
      (fft_conv): FFTConv()
    )
    (1): AuthenticHyenaB

### Aggregate model weights

In [None]:
# # add weights of two state_dicts
# for name, param in model.named_parameters():
#     if name in model_2.state_dict():
#         param.data = model_2.state_dict()[name]

### Data

In [3]:
path = './data/all_genomes.fasta'

data = read(path)
# selsct the first 2250 sequences
data = data[:2250]
print('Number of genome sequences: ',len(data))

# preprocessing
CONTEXT_LENGTH = 30000

# cut sequences to CONTEXT_LENGTH
for i in range(len(data)):
    if len(data[i]) >= CONTEXT_LENGTH:
        data[i] = data[i][:CONTEXT_LENGTH]
# apply 'P' padding to the sequences
for i in range(len(data)):
    data[i] = data[i] + 'P' * (CONTEXT_LENGTH - len(data[i]))

min_length = min([len(x) for x in data])
max_length = max([len(x) for x in data])

print('Min and max length of genome sequences after padding:\nMin: ', min_length,'\nMax: ', max_length)

# Tokenize

chars = set()

for genome in data:
    for char in genome:
        chars.add(char)
vocabulary = list(chars)


tok2id = {ch: i for i, ch in enumerate(vocabulary)}
id2tok = {i: ch for i, ch in enumerate(vocabulary)}

encode = lambda s: [tok2id[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([id2tok[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits

n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

Number of genome sequences:  2250
Min and max length of genome sequences after padding:
Min:  30000 
Max:  30000


In [4]:
class Embeddings_DS(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        idx = np.random.randint(len(self.data))
        data = self.data[idx]
        data = torch.tensor(encode(data))

        return data

# Datasets

# train_ds = Embeddings_DS(train_data)
val_ds = Embeddings_DS(val_data)

# Dataloader
# loader = DataLoader(train_ds, batch_size=hyena_config.batch_size, shuffle=True, num_workers=10)
val_loader = DataLoader(val_ds, batch_size=hyena_config.batch_size, shuffle=True, num_workers=10)

model = FastaModel(hyena_config, AuthenticHyenaBlock)
m = model.to('cuda')

In [15]:
val_loss_ = []
model.eval()
embeddings = []
for batch in val_loader:
    batch = batch.to('cuda')
    logits, genome_embedding = model(batch)
    loss = torch.nn.functional.cross_entropy(
        logits.transpose(1, 2), batch
    )
    val_loss_.append(loss.item())
    embeddings.append(genome_embedding.detach().cpu().numpy())

# accumulate val loss
tot_val_loss = sum(val_loss_) / len(val_loss_) 
print(f'Validation loss: {tot_val_loss}')

Validation loss: 5.948329902225071


In [17]:
for i in range(len(embeddings)):
    embeddings[i] = embeddings[i].squeeze(0)

print('Shape of embeddings', embeddings[0].shape)
print(embeddings[0])

Shape of embeddings (10, 1000)
[[-1.04405728e+08 -2.75099040e+07 -6.44010200e+07 ... -4.77720320e+07
  -4.02863480e+07 -1.53021344e+08]
 [-2.93668640e+07 -4.15380480e+07 -4.62524200e+07 ... -1.01571030e+07
  -7.29748800e+07 -9.51161360e+07]
 [-5.72926480e+07  5.68492800e+07  1.58430000e+08 ... -7.48765600e+07
  -1.17196960e+08 -8.06539840e+07]
 ...
 [-7.05785760e+07  1.13798616e+08  1.16629528e+08 ... -1.91570752e+08
  -1.89325888e+08 -1.18564888e+08]
 [ 8.38409920e+07 -1.32136528e+08 -4.35611520e+07 ...  1.66617376e+08
   1.91753056e+08  1.34499360e+08]
 [-3.86077280e+07  1.57334304e+08 -1.52868600e+07 ... -1.56067104e+08
  -2.83712000e+07 -6.36987680e+07]]


### Clustering

In [None]:
# k-means clustering
from sklearn.cluster import KMeans

embeddings = torch.cat(embeddings, dim=0)
print('Shape of embeddings', embeddings.shape)
embeddings = embeddings.cpu().detach().numpy()
print('Shape of embeddings', embeddings.shape)

kmeans = KMeans(n_clusters=5, random_state=0).fit(embeddings)
print('Centroids: ', kmeans.cluster_centers_)
print('Labels: ', kmeans.labels_)
print('Inertia: ', kmeans.inertia_)
print('Number of iterations: ', kmeans.n_iter_)
print('Predictions: ', kmeans.predict(embeddings))

# t-SNE visualization
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, random_state=0)
embeddings_2d = tsne.fit_transform(embeddings)

plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=kmeans.labels_)
plt.title('t-SNE visualization of genome embeddings')
plt.show()


aggregate genomic data in one file -> split in parts of 2500 -> train separate models on each data chunk.

validation data: either a fraction of each chunk or the aggragated validation data of all data chunks.

aggregate model weights -> only add weights, if it improves over all validation accuracy.

Kmenas clustering and write someting in overleaf.



generate new genomes, based on previous period -> generate embeddings for them -> cluster embeddings -> compare new embeddings and groubnd truth cluster