# Inference Setup

In [2]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch.nn.functional as F
import re

In [3]:
class Predictor:
    """Field-by-field table extractor on top of a causal language model.

    For every input text the model is prompted with
    ``<bos> text <sep> [previous attributes] attribute:`` and decoded greedily.
    Comma-separated values produced for one attribute fan out into separate
    table rows, each of which conditions the prompt for the next attribute.
    """

    def __init__(self, model_path) -> None:
        """Load the model and tokenizer stored at ``model_path``.

        Args:
            model_path: Path or identifier passed to both
                ``AutoModelForCausalLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
        """
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )
        # Progress of the current generateTable() run, as a percentage.
        self.status = 0

        self.hparam = {
            # Token budget for one prompt; generateTable() reserves 100 of
            # these tokens for the conditional part.
            'max_len': 1000,
        }
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.model.eval()

        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Variant names accepted by output_is_valid() for 'mutation_name'.
        # NOTE(review): 'MI', 'NI' and 'HYPSILON' differ from the usual
        # MU / NU / UPSILON spellings - presumably they mirror the labels the
        # model was trained on; confirm before changing them.
        self.greek_names = [
            'ALPHA',
            'BETA',
            'GAMMA',
            'DELTA',
            'EPSILON',
            'ZETA',
            'ETA',
            'THETA',
            'IOTA',
            'KAPPA',
            'LAMBDA',
            'MI',
            'NI',
            'XI',
            'OMICRON',
            'PI',
            'RHO',
            'SIGMA',
            'TAU',
            'HYPSILON',
            'PHI',
            'CHI',
            'PSI',
            'OMEGA',
        ]
        # Closed vocabulary accepted by output_is_valid() for 'effect'.
        self.effects_names = [
            'binding_to_antibodies',
            'binding_to_host_receptor',
            'ct_value',
            'disease_severity',
            'effectiveness_of_available_antiviral_drugs',
            'effectiveness_of_available_diagnostics',
            'effectiveness_of_available_vaccines',
            'entry_efficiency',
            'fatality_rate',
            'host_virus_interactions',
            'immune_escape',
            'infection_duration',
            'infectivity',
            'intermolecular_interactions',
            'protein_conformational_optimization',
            'protein_flexibility',
            'protein_functioning',
            'protein_stability',
            'risk_of_hospitalization',
            'risk_of_reinfection',
            'sensitivity_to_convalescent_sera',
            'sensitivity_to_antibodies',
            'sensitivity_to_vaccinated_sera',
            'viral_fitness',
            'viral_incubation_period',
            'viral_load',
            'viral_replication',
            'viral_transmission',
            'viral_virulence',
        ]

    def generate(self, input_ids, max_new_tokens=100):
        """Greedily decode a continuation of ``input_ids``.

        Decoding stops at the EOS token, the SEP token, the first token of
        ``' | '``, or after ``max_new_tokens`` tokens. The argmax is taken over
        the whole logits tensor, so a batch of exactly one sequence is assumed.

        Args:
            input_ids: Prompt token ids, shape ``(1, prompt_len)``.
            max_new_tokens: Upper bound on generated tokens (previously a
                hard-coded 100).

        Returns:
            A 4-tuple ``(generated_sequence, output_indexes, distributions,
            ended_with_eos)``:

            * ``generated_sequence`` - list of 0-dim token-id tensors
              (the stop token is not included);
            * ``output_indexes`` - position, within ``generated_sequence``, of
              the first token of every comma-separated value;
            * ``distributions`` - one softmax vector per generated token;
            * ``ended_with_eos`` - ``True`` when a stop token ended decoding,
              ``False`` when the token budget ran out.
        """
        comma_id = self.tokenizer.encode(',')[0]
        eos_id = self.tokenizer.eos_token_id
        sep_id = self.tokenizer.sep_token_id
        sepo_id = self.tokenizer(' | ')['input_ids'][0]

        generated_sequence = []
        distributions = []
        output_indexes = [0]
        ended_with_eos = False
        previous_was_comma = False

        with torch.no_grad():
            while len(generated_sequence) < max_new_tokens:
                # A comma closes a value, so the token about to be generated
                # is the first token of the next value.
                if previous_was_comma:
                    output_indexes.append(len(generated_sequence))

                # The full sequence is re-encoded at every step (no KV cache).
                outputs = self.model(
                    input_ids,
                    return_dict=True
                )
                next_token_logits = outputs.logits[:, -1, :]
                predicted_token_tensor = torch.argmax(next_token_logits)

                if (predicted_token_tensor == eos_id) \
                        or (predicted_token_tensor == sepo_id) \
                        or (predicted_token_tensor == sep_id):
                    ended_with_eos = True
                    break

                distributions.append(
                    F.softmax(next_token_logits[0], 0).detach()
                )
                input_ids = torch.cat(
                    (input_ids, predicted_token_tensor.view(1, 1).detach()),
                    dim=-1
                )
                generated_sequence.append(predicted_token_tensor.detach())
                previous_was_comma = bool(predicted_token_tensor == comma_id)

        # If a stop token directly follows a comma (or is the very first
        # prediction), the last recorded start index has no generated token
        # and no distribution behind it. Drop it so callers cannot index past
        # the end of `distributions`.
        output_indexes = [
            index for index in output_indexes if index < len(distributions)
        ]
        return generated_sequence, output_indexes, distributions, ended_with_eos

    def _select_outputs(self, field, outputs_text, output_indexes, distributions):
        """Return de-duplicated ``(value, confidence)`` pairs for one generation.

        Values are compared after stripping whitespace. For ``multiple``
        fields, values rejected by ``output_is_valid`` are dropped. The
        confidence is the probability of the most likely token at the
        position of the value's first token.
        """
        selected = []
        seen = set()
        # zip() stops at the shorter list, so a trailing empty chunk of text
        # without a matching start index is ignored.
        for text, output_index in zip(outputs_text, output_indexes):
            value = text.strip()
            if value in seen:
                continue
            seen.add(value)
            if field['multiple'] and not self.output_is_valid(value, field['value']):
                continue
            selected.append((value, float(distributions[output_index].max())))
        return selected

    def generateTable(self, inputs, output_attributes, pre_table_outputs=None):
        """Extract one table of attribute tuples per input text.

        Args:
            inputs: Sized sequence of ``(input_text, doi)`` pairs.
            output_attributes: Ordered list of dicts with keys ``'value'``
                (attribute name) and ``'multiple'`` (whether the prompt is
                ``<name>_list:`` and the comma-separated values are validated).
            pre_table_outputs: Optional mapping ``doi -> list of instances`` to
                extend instead of starting from scratch. An entry equal to
                ``[[]]`` yields an empty result for that DOI.

        Returns:
            A list with one ``dict(doi=..., outputs=...)`` per input, where
            ``outputs`` is a list of instances and every instance is a list of
            ``dict(attribute=..., value=..., confidence=...)``.

        Side effects:
            Updates ``self.status`` with the percentage of processed inputs.
        """
        self.status = 0
        table_outputs = []
        for it, (input_text, doi) in enumerate(tqdm(inputs)):
            # Round the percentage itself; round(x, 2) * 100 produced values
            # such as 56.99999999999999.
            self.status = float(round(100 * (it + 1) / len(inputs)))

            prefix_input_ids = self.tokenizer.encode(
                input_text,
                return_tensors='pt',
                truncation=True,
                max_length=self.hparam['max_len'] - 100
            )
            if pre_table_outputs is not None:
                generated_outputs = pre_table_outputs[doi]
                if generated_outputs == [[]]:
                    table_outputs.append(dict(doi=doi, outputs=[]))
                    continue
            else:
                generated_outputs = [[]]

            for field in output_attributes:
                tmp_generated_outputs = []
                for instance in generated_outputs:
                    # Attributes already fixed for this row become part of
                    # the prompt: "attr: value | attr: value | ".
                    past_conditional = ''.join(
                        output['attribute'] + ': ' + output['value'] + ' | '
                        for output in instance
                    )
                    suffix = '_list:' if field['multiple'] else ':'
                    conditional_text = past_conditional + field['value'] + suffix

                    conditional_ids = self.tokenizer.encode(
                        conditional_text,
                        return_tensors='pt',
                        truncation=True,
                        max_length=100
                    )
                    input_ids = torch.cat(
                        (
                            torch.tensor([[self.tokenizer.bos_token_id]]),
                            prefix_input_ids,
                            torch.tensor([[self.tokenizer.sep_token_id]]),
                            conditional_ids
                        ),
                        dim=-1
                    ).to(self.device)

                    generated_sequence, output_indexes, distributions, ended_with_eos \
                        = self.generate(input_ids)

                    outputs_text = self.tokenizer.decode(generated_sequence).split(',')
                    if not ended_with_eos:
                        # The token budget ran out: the last chunk is either
                        # an unfinished value or the empty text after a
                        # trailing comma, so it is discarded.
                        outputs_text = outputs_text[:-1]

                    try:
                        selected = self._select_outputs(
                            field, outputs_text, output_indexes, distributions
                        )
                    except Exception as exc:
                        # Best effort: report and stop processing this field's
                        # remaining instances instead of aborting the run.
                        print(outputs_text)
                        print('an error occurs... skip and keep running...', repr(exc))
                        break

                    # Without a valid mutation name there is nothing to attach
                    # further attributes to: drop every row for this input.
                    if len(selected) == 0 and field['value'] == 'mutation_name' and field['multiple']:
                        tmp_generated_outputs = []
                        break

                    for value, confidence in selected:
                        tmp_generated_outputs.append(
                            instance + [
                                dict(
                                    attribute=field['value'],
                                    value=value,
                                    confidence=np.round(np.float64(confidence), 2),
                                )
                            ]
                        )

                generated_outputs = tmp_generated_outputs

            table_outputs.append(dict(doi=doi, outputs=generated_outputs))
        return table_outputs

    def output_is_valid(self, output, attribute):
        """Check a generated value against the rules for its attribute.

        * ``mutation_name`` containing ``_``: must look like ``<PROTEIN>_X123Y``.
        * ``mutation_name`` containing ``.``: must look like a dotted lineage
          (one or two capital letters followed by one to three numeric parts).
        * any other ``mutation_name``: must be listed in ``self.greek_names``.
        * ``effect``: must be listed in ``self.effects_names``.

        Returns:
            ``True`` when the value is accepted, ``False`` otherwise, including
            for attributes that have no validation rule.
        """
        if attribute == 'mutation_name':
            if '_' in output:
                return re.search(r'^([A-Z0-9]+_)[A-Z]\d{1,4}[A-Z]$', output) is not None
            if '.' in output:
                return re.search(
                    r'^([A-Z]{1,2}\.[0-9]{1,3})(\.[0-9]{1,3}){,2}$', output
                ) is not None
            return output in self.greek_names
        if attribute == 'effect':
            return output in self.effects_names
        return False
        

In [4]:
# Load the checkpoint (model and tokenizer) once; the cells below reuse `pred`.
checkpoint_dir = 'backend/api/checkpoints/model_0'
pred = Predictor(checkpoint_dir)

# Insert your list of abstracts and the matching list of DOIs here

In [None]:
# One abstract per entry. The text is written as adjacent string literals (one
# per sentence) that Python concatenates into a single string.
abstract_list = [
    (
        'The recently reported B.1.1.529 Omicron variant of SARS-CoV-2 includes 34 mutations in the spike protein relative to the Wuhan strain that initiated the COVID-19 pandemic, including 15 mutations in the receptor binding domain (RBD). '
        'Functional studies have shown omicron to substantially escape the activity of many SARS-CoV-2-neutralizing antibodies. '
        'Here we report a 3.1 Å resolution cryo-electron microscopy (cryo-EM) structure of the Omicron spike protein ectodomain. '
        'The structure depicts a spike that is exclusively in the 1-RBD-up conformation with increased mobility and inter-protomer asymmetry. '
        'Many mutations cause steric clashes and/or altered interactions at antibody binding surfaces, whereas others mediate changes of the spike structure in local regions to interfere with antibody recognition. '
        'Overall, the structure of the omicron spike reveals how mutations alter its conformation and explains its extraordinary ability to evade neutralizing antibodies. '
        'Highlights SARS-CoV-2 omicron spike exclusively adopts 1-RBD-up conformation Omicron substitutions alter conformation and mobility of RBD A subset of omicron mutations change the local conformation of spike The structure reveals the basis of antibody neutralization escape'
    ),
]
# DOI of each abstract, in the same order as `abstract_list`.
doi_list = ['10.1101/2021.12.21.473620']

# Compute the predictions

In [6]:

# zip() silently drops the unmatched tail, so fail loudly if the abstracts and
# DOIs are out of sync instead of quietly skipping or mislabelling inputs.
assert len(abstract_list) == len(doi_list), (
    f'abstract_list has {len(abstract_list)} entries '
    f'but doi_list has {len(doi_list)}'
)
abstract_doi_list = list(zip(abstract_list, doi_list))

# Attributes are generated in this order, each conditioned on the values chosen
# for the previous ones. 'multiple': True prompts with '<value>_list:' and
# validates every comma-separated item; False prompts with '<value>:'.
prediction_attributes = [
    {'value': 'mutation_name', 'multiple': True},
    {'value': 'effect', 'multiple': True},
    {'value': 'level', 'multiple': False},
]
# None starts every table from scratch; pass a {doi: instances} mapping to
# extend previously generated rows instead.
pre_table_outputs = None

In [9]:

# Run the attribute-by-attribute extraction over every (abstract, DOI) pair.
prediction_results = pred.generateTable(
    inputs=abstract_doi_list,
    output_attributes=prediction_attributes,
    pre_table_outputs=pre_table_outputs,
)

  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [01:04<00:00, 64.46s/it]100%|██████████| 1/1 [01:04<00:00, 64.47s/it]


In [10]:
# Flatten the nested results into one record per extracted row, then build the
# DataFrame once: concatenating inside the loop copies the whole frame on every
# iteration and concatenates onto an empty frame.
prediction_records = []
for result in prediction_results:
    if len(result['outputs']) > 0:
        for instance in result['outputs']:
            record = {'doi': result['doi']}
            for output in instance:
                record[output['attribute']] = output['value']
            prediction_records.append(record)
    else:
        # Keep inputs without predictions visible as a row of empty strings.
        prediction_records.append(
            dict(doi=result['doi'], mutation_name='', effect='', level='')
        )

df_predictions = pd.DataFrame(prediction_records)


In [11]:
# Display the assembled table: one row per extracted instance, with a `doi`
# column followed by one column per predicted attribute.
df_predictions

Unnamed: 0,doi,mutation_name,effect,level
0,empty_doi,OMICRON,binding_to_antibodies,lower
1,empty_doi,OMICRON,protein_conformational_optimization,no evidence
2,empty_doi,OMICRON,immune_escape,undefined
