Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add temp table bulk loading #17

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions src/DustyTables/Sql.fs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ open System.Threading.Tasks
open System.Data
open Microsoft.Data.SqlClient
open System.Threading
open System.Text.RegularExpressions

type Sql() =
static member dbnull = SqlParameter(Value=DBNull.Value)
Expand Down Expand Up @@ -400,3 +401,141 @@ module Sql =
with
| error -> return Error error
}

type private TempTableLoader(fieldCount, items: obj seq) =
let enumerator = items.GetEnumerator()

interface IDataReader with
member this.FieldCount: int = fieldCount
member this.Read(): bool = enumerator.MoveNext()
member this.GetValue(i: int): obj =
let row : obj[] = unbox enumerator.Current
row.[i]
member this.Dispose(): unit = ()

member __.Close(): unit = invalidOp "NotImplementedException"
member __.Depth: int = invalidOp "NotImplementedException"
member __.GetBoolean(_: int): bool = invalidOp "NotImplementedException"
member __.GetByte(_ : int): byte = invalidOp "NotImplementedException"
member __.GetBytes(_ : int, _ : int64, _ : byte [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
member __.GetChar(_ : int): char = invalidOp "NotImplementedException"
member __.GetChars(_ : int, _ : int64, _ : char [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
member __.GetData(_ : int): IDataReader = invalidOp "NotImplementedException"
member __.GetDataTypeName(_ : int): string = invalidOp "NotImplementedException"
member __.GetDateTime(_ : int): System.DateTime = invalidOp "NotImplementedException"
member __.GetDecimal(_ : int): decimal = invalidOp "NotImplementedException"
member __.GetDouble(_ : int): float = invalidOp "NotImplementedException"
member __.GetFieldType(_ : int): System.Type = invalidOp "NotImplementedException"
member __.GetFloat(_ : int): float32 = invalidOp "NotImplementedException"
member __.GetGuid(_ : int): System.Guid = invalidOp "NotImplementedException"
member __.GetInt16(_ : int): int16 = invalidOp "NotImplementedException"
member __.GetInt32(_ : int): int = invalidOp "NotImplementedException"
member __.GetInt64(_ : int): int64 = invalidOp "NotImplementedException"
member __.GetName(_ : int): string = invalidOp "NotImplementedException"
member __.GetOrdinal(_ : string): int = invalidOp "NotImplementedException"
member __.GetSchemaTable(): DataTable = invalidOp "NotImplementedException"
member __.GetString(_ : int): string = invalidOp "NotImplementedException"
member __.GetValues(_ : obj []): int = invalidOp "NotImplementedException"
member __.IsClosed: bool = invalidOp "NotImplementedException"
member __.IsDBNull(_ : int): bool = invalidOp "NotImplementedException"
member __.Item with get (_ : int): obj = invalidOp "NotImplementedException"
member __.Item with get (_ : string): obj = invalidOp "NotImplementedException"
member __.NextResult(): bool = invalidOp "NotImplementedException"
member __.RecordsAffected: int = invalidOp "NotImplementedException"

type TempTable =
{ Name : string
Columns : Map<string, int> }

let private tempTableNameRegex = Regex("(#[a-z0-9\\-_]+)", RegexOptions.IgnoreCase)

let private tempTableColumnRegex =
[ "bigint"
"binary"
"bit"
"char"
"datetimeoffset"
"datetime2"
"datetime"
"date"
"decimal"
"float"
"image"
"int"
"nchar"
"ntext"
"nvarchar"
"real"
"timestamp"
"varbinary" ]
|> String.concat "|"
|> fun x -> Regex(@"[\[]{0,1}([a-z0-9\-_]+)[\]]{0,1} (?:"+x+")", RegexOptions.IgnoreCase)

let createTempTable table (props : SqlProps) =
let connection = getConnection props
if not (connection.State.HasFlag ConnectionState.Open) then connection.Open()

use command = new SqlCommand(table, connection)
command.ExecuteNonQuery() |> ignore

let name = tempTableNameRegex.Match(table).Groups.[1].Value

let columns =
tempTableColumnRegex.Matches(table)
|> Seq.cast
|> Seq.mapi(fun i (m : Match) -> m.Groups.[1].Value, i )
|> Map.ofSeq

let info =
{ TempTable.Name = name
Columns = columns }

{ props with ExistingConnection = Some connection }, info

let tempTableData data (props, info : TempTable) =
props, info, data

let loadTempTable mapper (props : SqlProps, info : TempTable, data) =
let items =
data
|> Seq.map(fun item ->
let cols = mapper item

let arr = Array.zeroCreate info.Columns.Count
cols
|> List.iter(fun (name, p : SqlParameter) ->
let index = info.Columns |> Map.find name
arr.[index] <- p.Value
)
box arr
)

use reader = new TempTableLoader(info.Columns.Count, items)

use bulkCopy = new SqlBulkCopy(props.ExistingConnection.Value)
props.Timeout |> Option.iter (fun x -> bulkCopy.BulkCopyTimeout <- x)
bulkCopy.BatchSize <- 5000
bulkCopy.DestinationTableName <- info.Name
bulkCopy.WriteToServer(reader)

props

let executeStream (read: RowReader -> 't) (props : SqlProps) =
seq {
if props.SqlQuery.IsNone then failwith "No query provided to execute. Please use Sql.query"
let connection = getConnection props
try
if not (connection.State.HasFlag ConnectionState.Open)
then connection.Open()
use command = new SqlCommand(props.SqlQuery.Value, connection)
props.Timeout |> Option.iter (fun x -> command.CommandTimeout <- x)
do populateCmd command props
if props.NeedPrepare then command.Prepare()
use reader = command.ExecuteReader()
let rowReader = RowReader(reader)
while reader.Read() do
read rowReader
finally
if props.ExistingConnection.IsNone
then connection.Dispose()
}
16 changes: 16 additions & 0 deletions tests/DustyTables.Tests/Tests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -259,4 +259,20 @@ let tests = testList "DustyTables" [
| Error ex -> raise ex
| otherwise ->
fail()

testDatabase "temp table loading" <| fun connectionString ->
let data = [ 1; 2; 3]

connectionString
|> Sql.connect
|> Sql.createTempTable "create table #Temp(Id int not null)"
|> Sql.tempTableData data
|> Sql.loadTempTable (fun row ->
[ "Id", Sql.int row ]
)
|> Sql.query "select Id from #Temp"
|> Sql.executeStream (fun read -> read.int "Id")
|> fun stream ->
let actual = stream |> Seq.toList
Expect.equal actual data "Result doesn't match"
]