In [None]:
import requests
from requests.auth import HTTPBasicAuth
from requests.packages.urllib3.util.retry import Retry
from requests.packages.urllib3.util.timeout import Timeout

import pandas as pd
import numpy as np
import datetime

# API key handed to UMLS_QUERY(key=...) further down in this notebook.
# It is an empty string here, so it has to be filled in before running.
UMLS_API_KEY = ''

class UMLS_QUERY:
    """Small client for the UMLS REST API using TGT / service-ticket auth.

    Constructing an instance creates a ``requests`` session with retry
    handling and immediately requests a ticket-granting ticket (TGT) with the
    supplied API key. Every query refreshes the TGT if needed, obtains a
    service ticket from it, and returns the response as a pandas DataFrame.

    Network and authentication failures are reported with ``print`` and show
    up as ``None`` tickets or empty DataFrames instead of exceptions.
    """

    TGT_URL = 'https://utslogin.nlm.nih.gov/cas/v1/api-key'
    SERVICE_TICKET_URL = 'https://utslogin.nlm.nih.gov/cas/v1/tickets/'
    UMLS_VERSION = '2021AB'

    ### Set retry strategy
    retry_strategy = Retry(
        total=20,
        backoff_factor=1,
        status_forcelist=[429, 502, 503, 504],
        allowed_methods=['GET', 'POST'])

    ### Set timeout strategy
    timeout_strategy = Timeout(connect=20, read=20)

    # A TGT is requested once every 8 hours; refresh a little early.
    TGT_MAX_AGE = datetime.timedelta(hours=7, minutes=45)

    # (detail_type prefix, URL suffix) pairs, checked in order.
    # 'atom' selects the atom list and 'pref' the preferred atom; these two
    # suffixes used to be crossed.
    _DETAIL_SUFFIXES = (
        ('atom', '/atoms'),
        ('pref', '/atoms/preferred'),
        ('defin', '/definitions'),
        ('relat', '/relations'),
    )

    def __init__(self, key):
        """Store the API key, build the retrying session and fetch a TGT.

        Args:
            key: UMLS API key used to request ticket-granting tickets.
        """
        self.umls_api_key = key
        self.session = requests.Session()
        self.session.mount('https://', requests.adapters.HTTPAdapter(max_retries=self.retry_strategy))
        self.refresh_tgt()

    ### Wrapper to refresh the TGT
    def refresh_tgt(self):
        """Request a new TGT and record when it was obtained."""
        self.tgt = self.get_umls_tgt()
        self.tgt_datetime = datetime.datetime.now()

    ### Get a UMLS TGT once every 8 hours
    def get_umls_tgt(self):
        """Request a ticket-granting ticket for the stored API key.

        Returns:
            The last path segment of the response's ``Location`` header, or
            ``None`` when the request fails, the status is not 201, or the
            header is missing.
        """
        try:
            r = self.session.post(url=self.TGT_URL, data={'apikey': self.umls_api_key}, timeout=self.timeout_strategy)
        except Exception as e:
            # Deliberately broad: this client is best-effort and reports
            # failures instead of raising.
            print(e)
            return None

        if r.status_code != 201:
            print('UMLS TGT request failed with HTTP status {}'.format(r.status_code))
            return None

        location = r.headers.get('Location')
        if not location:
            # Without this guard a missing header raised AttributeError.
            print('UMLS TGT response did not contain a Location header')
            return None
        return location.split('/')[-1]

    ### Updates the TGT if necessary
    def check_tgt(self):
        """Refresh the TGT when there is none or it is about to expire."""
        if self.tgt is None or datetime.datetime.now() >= self.tgt_datetime + self.TGT_MAX_AGE:
            self.refresh_tgt()

    ### Get the details for a particular cui
    def cui_query(self, cui, detail_type=None, page_size=25):
        """Fetch details for a single CUI.

        Args:
            cui: Concept unique identifier, appended to the request URL.
            detail_type: Optional selector matched by prefix: ``'atom...'``
                (atoms), ``'pref...'`` (preferred atom), ``'defin...'``
                (definitions), ``'relat...'`` (relations). ``None``, an empty
                string or an unrecognised value queries the bare CUI URL.
            page_size: Value sent as the ``pageSize`` request parameter.

        Returns:
            A DataFrame built from the ``result`` field of the response;
            empty when the request failed.
        """
        # Create the URL for getting CUI details
        url = 'https://uts-ws.nlm.nih.gov/rest/content/' + self.UMLS_VERSION + '/CUI/' + cui

        # Get different result types
        if detail_type:
            for prefix, suffix in self._DETAIL_SUFFIXES:
                if detail_type.startswith(prefix):
                    url += suffix
                    break

        # Check the tgt and get a service ticket
        self.check_tgt()
        service_ticket = self.get_umls_service_ticket_()

        # return results as dataframe
        result = self.query_UMLS_(url=url, payload={'ticket': service_ticket, 'pageSize': page_size})
        return self._to_frame(result)

    ### Get the CUIs that match a particular string
    def string_query(self, q, search_type='normalizedString', page_size=25):
        """Search UMLS for concepts matching a string.

        Args:
            q: Search string, sent as the ``string`` request parameter.
            search_type: Sent as ``searchType``. Accepted string types:
                'exact', 'words', 'leftTruncation', 'rightTruncation',
                'approximate', 'normalizedString', 'normalizedWords'.
            page_size: Value sent as the ``pageSize`` request parameter.

        Returns:
            A DataFrame built from ``result['results']`` of the response;
            empty when the request failed.
        """
        # Create the URL for searching UMLS
        url = 'https://uts-ws.nlm.nih.gov/rest/search/' + self.UMLS_VERSION

        # Check the tgt and get a service ticket
        self.check_tgt()
        service_ticket = self.get_umls_service_ticket_()

        payload = {'ticket': service_ticket, 'string': q, 'searchType': search_type, 'pageSize': page_size}
        result = self.query_UMLS_(url=url, payload=payload)

        # query_UMLS_ returns None on failure; calling .get on that used to
        # raise AttributeError.
        if not isinstance(result, dict):
            return pd.DataFrame()

        # return results as dataframe
        return pd.DataFrame(result.get('results'))

    ### Get a UMLS Service Ticket before every API call
    def get_umls_service_ticket_(self):
        """Request a service ticket from the current TGT.

        Returns:
            The response body decoded as UTF-8, or ``None`` when there is no
            TGT, the request fails, or the status is not 200.
        """
        if self.tgt is None:
            # Previously this case was only handled by catching the TypeError
            # raised when concatenating the URL with None.
            print('No UMLS TGT available; cannot request a service ticket')
            return None

        try:
            url = self.SERVICE_TICKET_URL + self.tgt
            r = self.session.post(url=url, data={'service': 'http://umlsks.nlm.nih.gov'}, timeout=self.timeout_strategy)
        except Exception as e:
            # Deliberately broad: best-effort, report and carry on.
            print(e)
            return None

        if r.status_code != 200:
            print('UMLS service ticket request failed with HTTP status {}'.format(r.status_code))
            return None

        # Use the public ``content`` attribute rather than private ``_content``.
        return r.content.decode('utf-8')

    ### Send a GET request to the UMLS API and unpack the result
    def query_UMLS_(self, url, payload=None):
        """GET ``url`` with ``payload`` as query parameters.

        Args:
            url: Full request URL.
            payload: Optional dict of query parameters.

        Returns:
            The ``result`` field of the JSON response, or ``None`` when the
            request fails, the status is not 200, or the body is not JSON.
        """
        try:
            r = self.session.get(url=url, params=payload, timeout=self.timeout_strategy)
        except Exception as e:
            # Deliberately broad: best-effort, report and carry on.
            print(e)
            return None

        # Any non-200 status is treated as "no result", silently as before.
        if r.status_code != 200:
            return None

        try:
            return r.json().get('result')
        except ValueError as e:
            # Body was not valid JSON.
            print(e)
            return None

    @staticmethod
    def _to_frame(result):
        """Convert a ``result`` payload into a DataFrame.

        ``None`` gives an empty frame. A dict that pandas cannot expand on
        its own (for example one holding only scalar values) becomes a
        single-row frame instead of raising ``ValueError``.
        """
        if result is None:
            return pd.DataFrame()
        try:
            return pd.DataFrame(result)
        except ValueError:
            if isinstance(result, dict):
                return pd.DataFrame([result])
            raise

In [None]:
# Create the client; the constructor immediately requests a TGT using UMLS_API_KEY.
a = UMLS_QUERY(key=UMLS_API_KEY)

In [None]:
# Query CUI C0001645 with detail_type='pref'; the cell displays the resulting DataFrame.
a.cui_query('C0001645', detail_type='pref')

In [None]:
# Query CUI C0001645 with detail_type='relat', which selects the '/relations' endpoint.
a.cui_query('C0001645', detail_type='relat')

In [None]:
# Query CUI C0001645 with no detail_type, i.e. the bare CUI URL with no suffix.
a.cui_query('C0001645', detail_type=None)

In [None]:
# Search for the phrase using searchType 'normalizedWords' and display the matches.
a.string_query('adrenergic receptor antagonist', search_type='normalizedWords')