## Import required packages

In [0]:
# Download required packages
!pip -q install gdown missingno torch

%matplotlib inline

import pyspark
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import missingno as msno
import torch
import torch.nn as nn
from typing import *
# `builtins` keeps the Python built-ins reachable (e.g. `builtins.round`) even though the
# star import of pyspark.sql.functions above shadows several of them
import builtins
import datetime
import gdown

import tqdm as tq
def tqdm(*args, **kwargs):
  ''' Small trick to prevent tqdm printing newlines at each step.

  Forwards every argument to `tq.tqdm`, using `leave=True` and `position=0` as defaults.
  Both are only defaults: a caller passing `leave` or `position` explicitly gets its own
  value instead of a "multiple values for keyword argument" TypeError.
  '''
  kwargs.setdefault('leave', True)
  kwargs.setdefault('position', 0)
  return tq.tqdm(*args, **kwargs)

## Data acquisition
We retrieve our datasets and download them to a temporary directory in the driver node.

In [0]:
# Fetch the zipped datasets from Google Drive into the local /tmp directory
gdown.download('https://drive.google.com/uc?id=1ggmDp-AWFzbQReLG0pLpQE_3fO0C0RnM', '/tmp/data.zip', quiet=False)
# Extract the archive under /tmp/, then delete the zip since only its content is needed
!unzip -q /tmp/data.zip -d /tmp/
!rm /tmp/data.zip

Then we move the extracted datasets to the Databricks File System (DBFS).

In [0]:
# Move the extracted folder, with all of its content, from the local filesystem (file:/) to DBFS (dbfs:/)
# NOTE(review): on a re-run dbfs:/data already exists — confirm how dbutils.fs.mv behaves in that case
dbutils.fs.mv("file:/tmp/data", "dbfs:/data", recurse=True)

In [0]:
%fs ls /data/

path,name,size
dbfs:/data/.DS_Store,.DS_Store,6148
dbfs:/data/key_stats_yahoo.csv,key_stats_yahoo.csv,2047081
dbfs:/data/prices/,prices/,0


In [0]:
%fs ls /data/

path,name,size
dbfs:/data/.DS_Store,.DS_Store,6148
dbfs:/data/key_stats_yahoo.csv,key_stats_yahoo.csv,2047081
dbfs:/data/prices/,prices/,0


## Dataset loading

In [0]:
# Read the key statistics CSV, taking the column names from its header row
key_stats_df = spark.read.load("dbfs:/data/key_stats_yahoo.csv",
                               format="csv",
                               sep=",",
                               inferSchema="true",
                               header="true"
                              )

# Drop the first ID column
key_stats_df = key_stats_df.drop(key_stats_df.columns[0])

# Use legacy format to parse dates
spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")
key_stats_df = key_stats_df.withColumn("Date", to_date(key_stats_df["Date"], 'MM/dd/yyyy HH:mm'))

# Cast numerical columns (everything after the first two columns) to double.
# This is done with a single `select` rather than one `withColumn` per column: every
# `withColumn` call adds a projection to the plan, and doing so in a loop over many
# columns produces very large plans.
key_stats_df = key_stats_df.select(
  *[key_stats_df[c] for c in key_stats_df.columns[:2]],
  *[key_stats_df[c].cast("double").alias(c) for c in key_stats_df.columns[2:]]
)

# One prices dataframe per stock: every stock has its own CSV file under /data/prices/
prices_files = []
for entry in dbutils.fs.ls('/data/prices/'):
  if entry.path.endswith('.csv'):
    prices_files.append(entry.path)

# A dataframe is named after its file: last path component without the extension
dfs_names = []
for csv_path in prices_files:
  file_name = csv_path.rsplit('/', 1)[1]
  dfs_names.append(file_name[:-len('.csv')])

prices_dfs = []
progress = tqdm(prices_files, desc='Reading stock price data', total=len(prices_files))
for csv_path in progress:
  raw_prices = spark.read.load(csv_path, format="csv", sep=",", inferSchema="true", header="true")
  # Turn the textual date into a proper date column before storing the dataframe
  prices_dfs.append(raw_prices.withColumn("Date", to_date(raw_prices["Date"], 'dd-MM-yyyy')))

## Dataset analysis

In [0]:
# The schema printed here belongs to the first prices dataframe, so label it accordingly
print("Prices dataframe format:")
prices_dfs[0].printSchema()

In [0]:
# The schema printed here belongs to the key stats dataframe, so label it accordingly
print("Key stats dataframe format:")
key_stats_df.printSchema()

### Utility functions

In [0]:
# TODO: add remaining utility functions

def missing_values_summary(df):
  ''' Summarize how much of each column of `df` is missing.

  For every column, the number of null values is expressed as a percentage of the total
  number of rows, rounded to 3 decimals. The result is the first (and only) row of the
  aggregated dataframe, with one field per original column.
  '''
  total_rows = df.count()

  percentages = []
  for name in df.columns:
    # `when` without `otherwise` yields null for non-matching rows and `count` skips nulls,
    # so this counts exactly the rows where the column is null
    null_count = count(when(isnull(name), name))
    # `round` operates on a column expression here, not on a Python number
    percentages.append(round(100 * null_count / total_rows, 3).alias(name))

  # Single aggregation over the whole dataframe, collapsed to its only row
  return df.agg(*percentages).first()

In [0]:
print("Overview of the missing values in the key_stats dataframe\n")
# Percentage of missing values for every column of the key stats dataframe
key_stats_summary = missing_values_summary(key_stats_df)
# Bare expression as last statement, so that the notebook displays the summary
key_stats_summary

In [0]:
def prices_df_nan_summary(prices_dfs, names):
    ''' Utility function to summarize the prices dataframes that have missing values.

    Args:
        prices_dfs: prices dataframes, one per stock.
        names: stock names, in the same order as `prices_dfs`.

    Returns:
        A list with one `(name, missing percentage, missing count)` tuple for every
        dataframe holding at least one null value; dataframes without nulls are skipped.
        It can be wrapped with
        `pd.DataFrame(result, columns=['Stock name', 'Missing data (%)', 'Count'])`.

    Raises:
        AssertionError: if the columns containing nulls do not all contain the same amount.
    '''
    nan_dfs = []
    for prices_df, name in zip(prices_dfs, names):
        # Only consider columns that contain null values
        nan_values = []
        for column in prices_df.columns:
            nan_absolute = prices_df.filter(prices_df[column].isNull()).count()
            if nan_absolute > 0:
                nan_values.append(nan_absolute)
        if not nan_values:
            continue
        # Either we have all the data for a given day or we don't have any data for it
        assert len(set(nan_values)) == 1, \
            f"{name}: columns disagree on the number of missing values: {nan_values}"
        missing = nan_values[0]
        # The bare name `round` is shadowed by the star import of pyspark.sql.functions,
        # whose `round` works on columns: use the built-in explicitly for plain numbers
        percentage = builtins.round(100 * missing / prices_df.count(), 3)
        nan_dfs.append((name, percentage, missing))
    return nan_dfs

### Missing values imputation

In [0]:
# TODO: move this single-aggregation logic into prices_df_nan_summary

# For every stock with missing data: (name, percentage of missing rows, number of missing rows)
nan_dfs = []
for prices_df, name in tqdm(zip(prices_dfs, dfs_names), total=len(prices_dfs)):
  # One aggregation per dataframe returning the number of nulls of each column
  nan_absolute = prices_df.agg(*[count(when(isnull(c), c)).alias(c) for c in prices_df.columns]).first()
  nan_values = [value for value in nan_absolute if value > 0]
  if nan_values:
    # Number of nulls of the first column that has any.
    # The bare name `count` cannot hold it: it is the Spark aggregation function used above.
    missing = nan_values[0]
    # `round` is shadowed by pyspark.sql.functions as well, hence the explicit built-in
    nan_dfs.append((name, builtins.round(100 * missing / prices_df.count(), 3), missing))

nan_dfs

In [0]:
# Sanity check on the first prices dataframe: show its column names...
print(prices_dfs[0].columns)
# ...and count the rows whose Date is null
# NOTE(review): presumably this also catches dates that `to_date` failed to parse — verify
prices_dfs[0].filter(prices_dfs[0]['Date'].isNull()).count()