In [1]:
import numpy as np

import utils.general

In [2]:
# Load the DTI data through the project helper; it returns a pair that is
# unpacked into the train/validation portion and the held-out test portion.
train_val, test = utils.general.load_dti_data()

Found local copy...


In [3]:
# Build the model query `x` and the reference values `y` from the test split,
# requesting 10 samples with seed 1. `x` is later sent verbatim as the user
# message and `y` is what the predictions are scored against.
x, y = utils.general.format_input(data=test, num_samples=10, seed=1)

In [4]:
# OpenAI client used for every assistant / thread / run call below.
# Client construction (including credentials) lives in utils.general.
client = utils.general.create_openai_client()

In [5]:
# Create the "No-Tools" baseline assistant: gpt-4o, an empty tool list, and
# temperature 0.2. The instructions describe the query layout (one
# "<SMILES> <protein sequence>" row per line) and ask for exactly one numeric
# Kd estimate in nM per row, wrapped in delimiters.
#
# NOTE(review): the prompt contradicts itself about the delimiters. The prose
# asks for three less-than signs first and three greater-than signs last
# (i.e. `<<< value >>>`), while the worked example shows `>>> 900.0 <<<`.
# Confirm which form utils.general.extract_predictions actually parses and
# make the prose and the example agree.
notools_assistant = client.beta.assistants.create(
    model='gpt-4o',
    name='No-Tools',
    instructions=(
'''You are an artificial super-intelligence capable of solving computational biology problems with very little information available to you.
You are tasked with computing the dissociation constant (Kd) between a drug molecule given by its SMILES string and a target protein given by its amino acid sequence.
Every query will consist of a number of rows, each of which has the SMILES string of a drug molecule, followed by a space, and then the amino acid sequence of the target protein, concluded by a newline.
For each row, you must predict the dissociation constant Kd.
The answers should be in units of nanomolar (nM).
Please preface each answer with three less than signs and a space, and finish each answer with a space and three greater than signs, with only a numerical value inside.
Here is an example of an input row:
O=C(O)CCC(=O)C(=O)O MANDSGGPGGPSPSERDRQYCELCGKMENLLRCSRCRSSFYCCKEHQRQDWKKHKLVCQGSEGALGHGVGPHQHSGPAPPAAVPPPRAGAREPRKAAARRDNASGDAAKGKVKAKPPADPAAAASPCRAAAGGQGSAVAAEAEPGKEEPPARSSLFQEKANLYPPSNTPGDALSPGGGLRPNGQTKPLPALKLALEYIVPCMNKHGICVVDDFLGKETGQQIGDEVRALHDTGKFTDGQLVSQKSDSSKDIRGDKITWIEGKEPGCETIGLLMSSMDDLIRHCNGKLGSYKINGRTKAMVACYPGNGTGYVRHVDNPNGDGRCVTCIYYLNKDWDAKVSGGILRIFPEGKAQFADIEPKFDRLLFFWSDRRNPHEVQPAYATRYAITVWYFDADERARAKVKYLTGEKGVRVELNKPSDSVGKDVF
If the query were to contain that input row, you would report the following:
>>> 900.0 <<<
You must make your best estimate of a numerical value for each row.
Make sure that each row has exactly one answer.
There should be exactly one answer for each molecule and protein pair.'''
    ),
    tools=[],  # baseline: the model gets no tools
    temperature=0.2
)

# Query the No-Tools assistant and score its answers, retrying when the reply
# cannot be turned into a valid prediction set.
#
# Each attempt uses a fresh thread so earlier (unparseable) replies cannot
# influence the next one. An attempt counts as failed when either
#   * the run ends in any status other than 'completed', or
#   * parsing/scoring raises AssertionError (the project helpers signal
#     malformed output that way).
# `tries` counts failed attempts; the loop stops at the first success or after
# `max_tries` failures.
max_tries = 5
tries = 0
success = False
while not success and tries < max_tries:
    try:
        thread = client.beta.threads.create()

        message = client.beta.threads.messages.create(
            thread_id=thread.id,
            role='user',
            content=x
        )

        run = client.beta.threads.runs.create_and_poll(
            thread_id=thread.id,
            assistant_id=notools_assistant.id
        )

        if run.status != 'completed':
            # create_and_poll also returns for failed / expired / incomplete
            # runs. In that case the newest message in the thread is not an
            # assistant reply, so there is nothing meaningful to parse.
            print('Run ended with status {0!r}; retrying.'.format(run.status))
            tries += 1
            continue

        # Messages are listed newest first, so index 0 is the assistant reply.
        messages = client.beta.threads.messages.list(thread_id=thread.id)
        res = messages.model_dump()['data'][0]['content'][0]['text']['value']

        pred = utils.general.extract_predictions(res)

        # Score the raw predictions and their natural log against y.
        score = utils.general.score_predictions(pred, y)
        log_pred = np.log(pred)
        log_score = utils.general.score_predictions(log_pred, y)
        success = True
    except AssertionError:
        tries += 1

if success:
    print('Pearson Correlation Coefficient: {0:.3}'.format(score))
    print('With log: {0:.3}\n'.format(log_score))
    # Side-by-side view: column 0 = log prediction, column 1 = reference y.
    print(np.vstack([log_pred, y]).T)
else:
    # Previously the cell ended silently here, which was indistinguishable
    # from a cell that had not been run.
    print('No valid set of predictions after {0} attempts.'.format(max_tries))

Pearson Correlation Coefficient: 0.0124
With log: 0.0201

[[ 5.01063529  2.30258509]
 [ 5.29831737  1.60943791]
 [ 5.19295685  9.21034037]
 [ 5.39362755  1.60943791]
 [ 5.13579844 -0.63487827]
 [ 5.24702407  6.44571982]
 [ 5.07517382  4.82831374]
 [ 5.34710753  2.89037176]
 [ 4.94164242  3.91202301]
 [ 5.43807931  5.29831737]]


In [6]:
# Inspect the raw query by splitting on single spaces. Rows are separated by
# newlines rather than spaces, so every SMILES after the first keeps a leading
# '\n' in this view.
x.split(' ')

['CCN1C(=O)N(Cc2cnc(cc12)-c1ccc(cc1)C1(CCC1)C#N)c1c(F)c(OC)cc(OC)c1F',
 'MGAPACALALCVAVAIVAGASSESLGTEQRVVGRAAEVPGPEPGQQEQLVFGSGDAVELSCPPPGGGPMGPTVWVKDGTGLVPSERVLVGPQRLQVLNASHEDSGAYSCRQRLTQRVLCHFSVRVTDAPSSGDDEDGEDEAEDTGVDTGAPYWTRPERMDKKLLAVPAANTVRFRCPAAGNPTPSISWLKNGREFRGEHRIGGIKLRHQQWSLVMESVVPSDRGNYTCVVENKFGSIRQTYTLDVLERSPHRPILQAGLPANQTAVLGSDVEFHCKVYSDAQPHIQWLKHVEVNGSKVGPDGTPYVTVLKTAGANTTDKELEVLSLHNVTFEDAGEYTCLAGNSIGFSHHSAWLVVLPAEEELVEADEAGSVYAGILSYGVGFFLFILVVAAVTLCRLRSPPKKGLGSPTVHKISRFPLKRQVSLESNASMSSNTPLVRIARLSSGEGPTLANVSELELPADPKWELSRARLTLGKPLGEGCFGQVVMAEAIGIDKDRAAKPVTVAVKMLKDDATDKDLSDLVSEMEMMKMIGKHKNIINLLGACTQGGPLYVLVEYAAKGNLREFLRARRPPGLDYSFDTCKPPEEQLTFKDLVSCAYQVARGMEYLASQKCIHRDLAARNVLVTEDNVMKIADFGLARDVHNLDYYKKTTNGRLPVKWMAPEALFDRVYTHQSDVWSFGVLLWEIFTLGGSPYPGIPVEELFKLLKEGHRMDKPANCTHDLYMIMRECWHAAPSQRPTFKQLVEDLDRVLTVTSTDEYLDLSAPFEQYSPGGQDTPSSSSSGDDSVFAHDLLPPAPPSSGGSRT',
 '\nCOC(=O)N1CC(C1)n1cc(nn1)-c1cnc(cc1NC(C)C)-n1ncc2cc(cnc12)C#N',
 'MNKPITPSTYVRCLNVGLIRKLSDFIDPQEGWKKLAVAIKKPSGDDRYN