# In [None]:
# Created on Databricks.
import websocket
import json
from datetime import datetime

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, FloatType, TimestampType, IntegerType


class websockettospark:
    """Collect ticker messages from a websocket feed into a Spark DataFrame.

    The instance opens a websocket to the Coinbase Exchange feed, subscribes
    to the ``ticker`` channel for ``BTC-USD``, appends every ticker message
    to ``self.df`` and closes the connection once more than ``max_messages``
    messages (of any type) have been received.

    Attributes:
        spark: The ``SparkSession`` used to build DataFrames.
        schema: Schema shared by ``self.df`` and every appended row.
        df: DataFrame accumulating the received ticker rows.
        messages_count: Number of websocket messages received so far.
        max_messages: Message count after which the socket is closed.
    """

    def __init__(self, max_messages: int = 10):
        """Create the Spark session and an empty DataFrame to append to.

        Args:
            max_messages: The connection is closed as soon as
                ``messages_count`` exceeds this value.
        """
        self.spark = (
            SparkSession.builder
            .appName("WebSocketToDataFrame")
            .master('local[*]')
            .getOrCreate()
        )

        # All columns are nullable, so a ticker message missing a field still
        # produces a valid row.
        # NOTE(review): FloatType is 32-bit; confirm that precision is
        # acceptable for prices before relying on exact values.
        self.schema = StructType([
            StructField('type', StringType(), True),
            StructField('product_id', StringType(), True),
            StructField('price', FloatType(), True),
            StructField('time', StringType(), True),
            StructField('trade_id', IntegerType(), True),
        ])

        # Use the session owned by this instance rather than a
        # notebook-provided global `spark`, which does not exist when the
        # file is run as a plain script.
        empty_rdd = self.spark.sparkContext.emptyRDD()
        self.df = self.spark.createDataFrame(empty_rdd, self.schema)

        self.messages_count = 0
        self.max_messages = max_messages

    @staticmethod
    def _row_from_message(msg):
        """Convert a decoded ticker message into a tuple matching the schema.

        Missing or null fields become ``None`` instead of raising
        ``KeyError``; ``price`` and ``trade_id`` are converted to ``float``
        and ``int`` when present.
        """
        price = msg.get("price")
        trade_id = msg.get("trade_id")
        return (
            msg.get("type"),
            msg.get("product_id"),
            float(price) if price is not None else None,
            msg.get("time"),
            int(trade_id) if trade_id is not None else None,
        )

    def on_message(self, ws, message):
        """Handle one websocket message.

        Every message increments ``messages_count``. Only messages whose
        ``type`` is ``"ticker"`` are appended to ``self.df``; anything else
        (for example a subscription acknowledgement or an error frame) does
        not carry the ticker fields and is skipped.

        Args:
            ws: The websocket application object; closed once the message
                limit is exceeded.
            message: Raw JSON text received from the feed.
        """
        self.messages_count += 1
        msg = json.loads(message)

        if msg.get("type") == "ticker":
            row_df = self.spark.createDataFrame(
                [self._row_from_message(msg)],
                schema=self.schema,
            )
            self.df = self.df.union(row_df)

        if self.messages_count > self.max_messages:
            ws.close()

    def on_error(self, ws, error):
        """Print an error reported by the websocket client."""
        print(error)

    def on_close(self, ws, _, __):
        """Print a banner when the connection is closed."""
        print("------------Connection closed------------------")

    def on_open(self, ws):
        """Send the ticker subscription request for BTC-USD."""
        channel = "ticker"
        product_ids = 'BTC-USD'
        message = {
            'type': 'subscribe',
            'channels': [{'name': channel, 'product_ids': [product_ids]}]
        }
        message_json = json.dumps(message)

        print(f"Sending: {message_json}")
        ws.send(message_json)

    def run(self):
        """Stream messages until the socket closes, then show the DataFrame.

        Does nothing if a run is already in progress on this instance. The
        in-progress marker is removed when the socket loop ends, so the
        instance can be run again afterwards.
        """
        if hasattr(self, '_already_running'):
            return

        # Presence of the attribute marks a run in progress.
        self._already_running = True
        try:
            websocket.enableTrace(False)
            ws = websocket.WebSocketApp("wss://ws-feed.exchange.coinbase.com",
                                        on_message=self.on_message,
                                        on_error=self.on_error,
                                        on_close=self.on_close)
            ws.on_open = self.on_open
            ws.run_forever()
        finally:
            del self._already_running

        self.df.show(truncate=False)


if __name__ == "__main__":
    # Script entry point: build the collector and start it.
    collector = websockettospark()
    collector.run()
