# 1. Useful functions

This notebook contains useful functions that can be imported and reused within a Jupyter notebook.

In [None]:
import dns
from dns import resolver
import requests
from urllib.parse import urlparse
import geoip2.database
import pyspark.sql.types as t

# Get A-Record from domain
def getARecords(domain):
    """Resolve the A records of ``domain``.

    Returns a list holding the text form of every record in the answer,
    or ``None`` when the lookup (or the conversion) raises any exception.
    """
    try:
        answer = dns.resolver.resolve(domain, 'A')
        return [record.to_text() for record in answer]
    except Exception:
        # NXDOMAIN, NoAnswer, NoNameservers and Timeout all end up here too;
        # every failure is reported to the caller as None.
        return None

# Get A-Record error from domain
def getARecords_error(domain):
    """Classify the outcome of an A-record lookup for ``domain``.

    Returns 0 on success, 1 for NXDOMAIN, 2 for NoAnswer, 3 for
    NoNameservers, 4 for Timeout and 8 for any other exception.
    """
    status = 0
    try:
        dns.resolver.resolve(domain, 'A')
    except dns.resolver.NXDOMAIN:
        status = 1
    except dns.resolver.NoAnswer:
        status = 2
    except dns.resolver.NoNameservers:
        status = 3
    except dns.resolver.Timeout:
        status = 4
    except Exception:
        status = 8
    return status

# Check whether the domain has an IPv6 (AAAA) record
def IPv6Record(domain):
    """Report whether ``domain`` has an AAAA record.

    Returns ``True`` when the AAAA lookup succeeds, ``False`` when the
    resolver raises NoAnswer, and ``None`` for every other exception
    (NXDOMAIN, NoNameservers, Timeout, anything else).
    """
    try:
        dns.resolver.resolve(domain, 'AAAA')
    except dns.resolver.NoAnswer:
        # The name exists but carries no AAAA record.
        return False
    except Exception:
        return None
    return True
    
def IPv6Record_error(domain):
    """Classify the outcome of an AAAA-record lookup for ``domain``.

    Returns 0 on success, 1 for NXDOMAIN, 2 for NoAnswer, 3 for
    NoNameservers, 4 for Timeout and 8 for any other exception.
    """
    status = 0
    try:
        dns.resolver.resolve(domain, 'AAAA')
    except dns.resolver.NXDOMAIN:
        status = 1
    except dns.resolver.NoAnswer:
        status = 2
    except dns.resolver.NoNameservers:
        status = 3
    except dns.resolver.Timeout:
        status = 4
    except Exception:
        status = 8
    return status

# Get MX-Record from domain   
def getMXRecords(domain):
    """Resolve the MX records of ``domain``.

    Returns a list with the text form of each record's ``exchange`` name,
    or ``None`` when the lookup (or the conversion) raises any exception.
    """
    try:
        answer = dns.resolver.resolve(domain, 'MX')
        return [record.exchange.to_text() for record in answer]
    except Exception:
        # NXDOMAIN, NoAnswer, NoNameservers and Timeout all end up here too;
        # every failure is reported to the caller as None.
        return None

# Get MX-Record error from domain  
def getMXRecords_error(domain):
    """Classify the outcome of an MX-record lookup for ``domain``.

    Returns 0 on success, 1 for NXDOMAIN, 2 for NoAnswer, 3 for
    NoNameservers, 4 for Timeout and 8 for any other exception.
    """
    status = 0
    try:
        dns.resolver.resolve(domain, 'MX')
    except dns.resolver.NXDOMAIN:
        status = 1
    except dns.resolver.NoAnswer:
        status = 2
    except dns.resolver.NoNameservers:
        status = 3
    except dns.exception.Timeout:
        status = 4
    except Exception:
        status = 8
    return status
    
# Get Redirect from domain
def getRedirectUrl(domain):
    """Request ``http://<domain>`` and report where the response came from.

    Returns the last two dot-separated labels of the network location of
    the response URL, or ``None`` when the request raises any exception.
    The request uses a 5 second timeout.
    """
    try:
        landed_on = requests.get("http://" + domain, timeout=5).url
        labels = urlparse(landed_on).netloc.split('.')
        return '.'.join(labels[-2:])
    except Exception:
        # Connection errors, read timeouts and redirect loops all end up
        # here; every failure is reported to the caller as None.
        return None

# Get status code from domain
def getStatusCodeUrl(domain):
    """Request ``http://<domain>`` and return the response status code.

    Returns ``None`` when the request raises any exception. The request
    uses a 5 second timeout.
    """
    target = "http://" + domain
    try:
        return requests.get(target, timeout=5).status_code
    except Exception:
        # Connection errors, read timeouts and redirect loops all end up
        # here; every failure is reported to the caller as None.
        return None

# Get SOA information from domain
def getSOAInformation(domain):
    """Resolve the SOA records of ``domain``.

    Returns a list holding the text form of every record in the answer,
    or ``None`` when the lookup (or the conversion) raises any exception.
    """
    try:
        answer = dns.resolver.resolve(domain, 'SOA')
        return [record.to_text() for record in answer]
    except Exception:
        # NXDOMAIN, NoAnswer, NoNameservers and Timeout all end up here too;
        # every failure is reported to the caller as None.
        return None
    
    
# Get SOA information error from domain  
def getSOAInformation_error(domain):
    """Classify the outcome of an SOA-record lookup for ``domain``.

    Returns 0 on success, 1 for NXDOMAIN, 2 for NoAnswer, 3 for
    NoNameservers, 4 for Timeout and 8 for any other exception.
    """
    status = 0
    try:
        dns.resolver.resolve(domain, 'SOA')
    except dns.resolver.NXDOMAIN:
        status = 1
    except dns.resolver.NoAnswer:
        status = 2
    except dns.resolver.NoNameservers:
        status = 3
    except dns.exception.Timeout:
        status = 4
    except Exception:
        status = 8
    return status
    
    
# Get nameserver from domain    
def getNameServers(domain):
    """Resolve the NS records of ``domain``.

    Returns a list holding the text form of every record in the answer,
    or ``None`` when the lookup (or the conversion) raises any exception.
    """
    try:
        answer = dns.resolver.resolve(domain, 'NS')
        return [record.to_text() for record in answer]
    except Exception:
        # NXDOMAIN, NoAnswer, NoNameservers and Timeout all end up here too;
        # every failure is reported to the caller as None.
        return None
    
    
# Get nameserver error from domain  
def getNameServers_error(domain):
    """Classify the outcome of an NS-record lookup for ``domain``.

    Returns 0 on success, 1 for NXDOMAIN, 2 for NoAnswer, 3 for
    NoNameservers, 4 for Timeout and 8 for any other exception.
    """
    status = 0
    try:
        dns.resolver.resolve(domain, 'NS')
    except dns.resolver.NXDOMAIN:
        status = 1
    except dns.resolver.NoAnswer:
        status = 2
    except dns.resolver.NoNameservers:
        status = 3
    except dns.exception.Timeout:
        status = 4
    except Exception:
        status = 8
    return status
    
# Get geolite2 location data
def getGeoLite2_Location(ip):
    """Look up ``ip`` in the GeoLite2 City database.

    The database is read from ``../GeoLite2-City.mmdb`` (relative to the
    current working directory) and opened anew on every call.

    Returns a Row with the fields ``iso_code``, ``city``, ``postal``,
    ``latitude`` and ``longitude``, or ``None`` when the database cannot
    be opened or the lookup fails for any reason.
    """
    try:
        with geoip2.database.Reader('../GeoLite2-City.mmdb') as reader:
            response = reader.city(ip)
            location_row = t.Row('iso_code', 'city', 'postal', 'latitude', 'longitude')
            return location_row(
                response.country.iso_code,
                response.city.name,
                response.postal.code,
                response.location.latitude,
                response.location.longitude,
            )
    except Exception:
        # Deliberately best-effort, but limited to Exception so that
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        return None

# Get geolite2 ASN data
def getGeoLite2_ASN(ip):
    """Look up ``ip`` in the GeoLite2 ASN database.

    The database is read from ``../GeoLite2-ASN.mmdb`` (relative to the
    current working directory) and opened anew on every call.

    Returns a Row with the single field ``autonomous_system_organization``,
    or ``None`` when the database cannot be opened or the lookup fails for
    any reason.
    """
    try:
        with geoip2.database.Reader('../GeoLite2-ASN.mmdb') as reader:
            response = reader.asn(ip)
            asn_row = t.Row('autonomous_system_organization')
            return asn_row(response.autonomous_system_organization)
    except Exception:
        # Deliberately best-effort, but limited to Exception so that
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        return None

def lambda_dot_remove(arr):
    """Return a new list with one trailing dot removed from each string.

    Strings that do not end in a dot are returned unchanged, so names
    that are already relative are not truncated by a character.
    """
    return [x[:-1] if x.endswith('.') else x for x in arr]


def fn_remove_dot(arr):
    """None-safe wrapper around ``lambda_dot_remove``.

    Returns ``None`` when ``arr`` is ``None``; otherwise a list with the
    trailing dot stripped from every entry.
    """
    # Identity check: ``== None`` would be evaluated element-wise by
    # array-like inputs instead of yielding a single boolean.
    return None if arr is None else lambda_dot_remove(arr)

In [None]:
from concurrent.futures import ThreadPoolExecutor

def execute_threaded_fn(fn, args):
    """Call ``fn`` once per element of ``args`` on a pool of 32 threads.

    Args:
        fn: Callable taking a single argument.
        args: Iterable of arguments; each one is used as a dictionary key,
            so it must be hashable. Duplicates collapse to one entry.

    Returns:
        A dict mapping every argument to the value ``fn`` returned for it,
        or to ``None`` when the call raised or its result was not available
        within 60 seconds.
    """
    dictionary = {}

    with ThreadPoolExecutor(max_workers=32) as executor:
        # Submit everything first so the calls run concurrently, then
        # collect the results in submission order.
        futures = [(arg, executor.submit(fn, arg)) for arg in args]
        for arg, future in futures:
            try:
                dictionary[arg] = future.result(timeout=60)
            except Exception:
                # Record the failure under the same key instead of aborting
                # the whole batch because of one bad argument.
                dictionary[arg] = None
    return dictionary

In [None]:
import psycopg2
import math

from pyspark.sql.functions import monotonically_increasing_id

def mapped_row_value(value):
    """Render a Python value as SQL text for an inline ``VALUES`` list.

    Mapping:
        * ``None``            -> ``NULL``
        * ``str``             -> single-quoted literal, embedded quotes doubled
        * ``list`` / ``tuple`` -> ``ARRAY [...]`` with each item rendered
          recursively
        * anything else       -> ``str(value)``

    Always returns a ``str`` so that the result can be joined with other
    rendered values.
    """
    if value is None:
        return 'NULL'
    if isinstance(value, str):
        # Double embedded single quotes so the data cannot terminate the
        # literal early (broken statement or SQL injection).
        # NOTE(review): assumes standard_conforming_strings is on, so that
        # backslashes are literal - confirm for the target server.
        escaped = value.replace("'", "''")
        return f"'{escaped}'"
    if isinstance(value, (list, tuple)):
        return f"ARRAY [{','.join(mapped_row_value(item) for item in value)}]"
    # Numbers, booleans etc.: convert explicitly, ','.join() rejects non-str.
    return str(value)

def upsert(table, keys, updateable, df):
    """Insert the rows of a Spark DataFrame into a PostgreSQL table,
    updating existing rows on key conflict.

    Args:
        table: Name of the target table.
        keys: Column names forming the conflict target.
        updateable: Column names overwritten with the incoming value when
            a row with the same keys already exists. When empty, rows
            that conflict are left untouched.
        df: Spark DataFrame containing at least the ``keys`` and
            ``updateable`` columns.

    The frame is written in roughly 10000-row batches, one statement and
    one commit per batch; progress is printed after every batch. Table
    and column names are interpolated into the SQL text unchanged, so
    they must come from trusted code, not from data.
    """
    column_names = [*keys, *updateable]
    # Explicit join instead of a tuple repr: a single column would
    # otherwise render as "(col,)", which is not valid SQL.
    columns = f"({', '.join(column_names)})"
    insert_into = f"INSERT INTO {table} {columns}"
    if updateable:
        update_excluded = ','.join(f"{x} = EXCLUDED.{x}" for x in updateable)
        on_conflict = f"ON CONFLICT ({','.join(keys)}) DO UPDATE SET {update_excluded}"
    else:
        # Nothing to overwrite: an empty SET list would be a syntax error.
        on_conflict = f"ON CONFLICT ({','.join(keys)}) DO NOTHING"

    insert_batch_size = 10000
    total_len = df.count()
    split_size = int(math.ceil(total_len / insert_batch_size))
    df = df.withColumn('data_insertion_index_for_splitting', monotonically_increasing_id())

    conn = psycopg2.connect("dbname='domainanalysis' host='bda_gr4_database' user='postgres' password='postgres'")
    try:
        cur = conn.cursor()
        try:
            processed = 0
            for x in range(split_size):
                pandas_df = df.filter(df["data_insertion_index_for_splitting"] % split_size == x) \
                    .drop('data_insertion_index_for_splitting') \
                    .toPandas()

                values = []
                for _, row in pandas_df.iterrows():
                    # str() guards the join against non-string renderings.
                    row_values = [str(mapped_row_value(row[col])) for col in column_names]
                    values.append(f"({','.join(row_values)})")

                # The generated ids are not consecutive, so a modulo bucket
                # can be empty; "VALUES" without rows is a syntax error.
                if not values:
                    continue

                cur.execute(f"{insert_into} VALUES {','.join(values)} {on_conflict}")
                conn.commit()
                processed = len(pandas_df) + processed
                print(f"[Table '{table}' - {processed/total_len * 100}% - ({processed}/{total_len})] Successfully inserted or updated data!")
        finally:
            cur.close()
    finally:
        # Release the connection even when a batch fails.
        conn.close()