# Select sequences from clades not used in training

In [7]:
from tqdm.notebook import tqdm
import pandas as pd
# Register tqdm's pandas integration (adds `progress_apply` to DataFrames)
tqdm.pandas()

In [2]:
# Clades held out of training; sequences are sampled only from these
CLADES = ["GRY", "O"]
# Columns loaded from the GISAID metadata table
COLS = [
    "Virus name",
    "Accession ID",
    "Collection date",
    "Submission date",
    "Clade",
    "Host",
    "Is complete?",
]

In [3]:
import os

# GISAID metadata export (tab-separated). Set the GISAID_METADATA_TSV environment
# variable to point at another copy instead of editing this absolute local path.
path_metadata = os.environ.get(
    "GISAID_METADATA_TSV",
    "/home/disco/Github/GISAID/metadata_tsv_2021_11_11/metadata.tsv",
)
# load only the columns listed in COLS
metadata = pd.read_csv(path_metadata, sep="\t", usecols=COLS)

In [4]:
# Drop rows with no clade assignment or no "Is complete?" value
# (a missing "Is complete?" presumably marks an incomplete genome).
# Reassigning instead of mutating in place keeps the cell's lineage explicit.
metadata = metadata.dropna(
    subset=["Is complete?", "Clade"],
    how="any",
    axis="index",
)

In [5]:
# Keep only human-host sequences from the held-out clades.
# `@CLADES` lets pandas resolve the variable directly instead of pasting its
# repr into the query string, so CLADES can stay a plain list and clade names
# never need quoting/escaping.
metadata = metadata.query("`Clade` in @CLADES and `Host` == 'Human'")

In [8]:
# Build the id used for each sequence in the fasta file:
#   "Virus name|Collection date|Submission date"
# (the Accession ID is not part of this id).
# Vectorized string concatenation replaces a row-wise apply over ~1M rows.
metadata["fasta_id"] = metadata["Virus name"].str.cat(
    [metadata["Collection date"], metadata["Submission date"]],
    sep="|",
)
# str.cat yields NaN (instead of raising) when any part is missing; fail loudly
assert metadata["fasta_id"].notna().all(), "missing name/date fields in metadata"

  0%|          | 0/966768 [00:00<?, ?it/s]

In [9]:
import random
from collections import namedtuple

# Subsample up to SAMPLES_PER_CLADE sequence ids from each held-out clade.
SAMPLES_PER_CLADE = 1000
SEED = 42  # fixed seed so the selected subset is reproducible across runs
SampleClade = namedtuple("SampleClade", ["fasta_id", "clade"])

rng = random.Random(SEED)
list_fasta_selected = []
for clade in tqdm(CLADES):
    samples_clade = metadata.loc[metadata["Clade"] == clade, "fasta_id"].tolist()
    rng.shuffle(samples_clade)
    # take SAMPLES_PER_CLADE ids, or all of them when fewer are available
    list_fasta_selected.extend(
        SampleClade(fasta_id, clade) for fasta_id in samples_clade[:SAMPLES_PER_CLADE]
    )

  0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
from pathlib import Path

# Persist the sampled (fasta_id, clade) pairs; the folder is created first so
# the write also works on a fresh checkout.
selected_samples = pd.DataFrame(list_fasta_selected)
Path("nextclade-comparison").mkdir(parents=True, exist_ok=True)
selected_samples.to_csv("nextclade-comparison/outer_clades.csv")

In [11]:
# Sanity check: number of sampled sequences per clade
(
    pd.DataFrame(list_fasta_selected)
    .groupby("clade")
    .size()
)

clade
GRY    1000
O      1000
dtype: int64

___
## Inference and histogram of probabilities

In [28]:
import numpy as np

from pathlib import Path
from Bio import SeqIO
from supervised_dna.fcgr import FCGR
from supervised_dna.utils import preprocess_seq

# FCGR encoder. The argument is presumably the k-mer length (the model loaded
# later is "resnet50_8mers") — TODO confirm against supervised_dna.fcgr.
KMER = 8
fcgr = FCGR(KMER)

In [24]:
# All fasta files to classify. Sorted so that the processing order (and the row
# order of the results CSV) does not depend on filesystem traversal order.
list_fasta = sorted(Path("data-outer-comparison").rglob("*.fasta"))
len(list_fasta)

1952

In [25]:
from supervised_dna import ModelLoader

# Clades the network was trained on, indexed by output unit
# (predictions map back via CLADES[argmax]).
# NOTE: this rebinds CLADES — from here on it no longer holds the held-out clades.
CLADES = ['S', 'L', 'G', 'V', 'GR', 'GH', 'GV', 'GK']
MODEL = "resnet50_8mers"
WEIGHTS_PATH = "checkpoints/model-02-0.969.hdf5"

# -1- Get the compiled model from ./supervised_dna/models with the best weights
loader = ModelLoader()
model = loader(model_name=MODEL, n_outputs=len(CLADES), weights_path=WEIGHTS_PATH)


 **load model weights_path** : checkpoints/model-02-0.969.hdf5

**Model created**


In [29]:
# NOTE(review): dead, commented-out draft. It refers to `list_test` and
# `compute_analysis`, neither of which is defined in the visible notebook, and
# the following cell performs the inference instead. Consider deleting this cell.
# from collections import namedtuple
# Results = namedtuple("Results",["path_fasta","prob","pred_class"])
# list_results = []
# for path_fasta in tqdm(list_test, desc="Computing Saliency Maps"):
#     prob, pred_class = compute_analysis(path_fasta)
#     list_results.append(Results(path_fasta, prob, pred_class))

# pd.DataFrame(list_results).to_csv("results_nextclade_comparison.csv")

In [48]:
from tqdm.notebook import tqdm
from collections import namedtuple

# One row per fasta file: highest model output ("prob") and the training clade
# it corresponds to ("pred_class").
Results = namedtuple("Results", ["path_fasta", "prob", "pred_class"])

list_results = []
for path_fasta in tqdm(list_fasta):
    # only the first record of each fasta file is classified
    fasta = next(SeqIO.parse(path_fasta, "fasta"), None)
    if fasta is None:
        # a bare next() would surface as an opaque StopIteration
        raise ValueError(f"no fasta record found in {path_fasta}")
    array_freq = fcgr(sequence=preprocess_seq(str(fasta.seq)))

    # preprocessing (divide by 10), then add channel axis and batch axis
    input_model = np.expand_dims(array_freq / 10., axis=-1)
    input_model = np.expand_dims(input_model, axis=0)

    # predict_on_batch avoids model.predict's per-call setup, which is not meant
    # to be used inside a loop over single inputs
    probs = model.predict_on_batch(input_model)[0]
    pred_class = CLADES[np.argmax(probs)]
    prob_class = probs.max()

    list_results.append(Results(path_fasta, prob_class, pred_class))
    

  0%|          | 0/1952 [00:00<?, ?it/s]

In [46]:
# Quick check of the last prediction made by the inference loop:
# `probs` holds the model output for the last file processed.
# (The previous version referenced an undefined `prob` and failed on a fresh kernel.)
probs.max(), CLADES[np.argmax(probs)]

(0.9999782, 'GR')

In [49]:
from pathlib import Path

# to_csv does not create missing folders, and "outer-comparison" is not created
# elsewhere in the visible notebook.
Path("outer-comparison").mkdir(exist_ok=True, parents=True)
pd.DataFrame(list_results).to_csv("outer-comparison/results_pred.csv")