In [1]:
import duckdb
import torch
from torch.utils.data import Dataset
import pathlib

In [2]:

class DiabetesDataset(Dataset):
    def __init__(self, dataPath: str):
        self.conn = duckdb.connect(':memory:')
        path = pathlib.Path(dataPath)
        if not path.exists():
            raise FileNotFoundError(f"Path {path} does not exist")

        # load data from parquet files to memory
        self.conn.execute(f"CREATE TABLE data AS SELECT * FROM parquet_scan('{path / 'data.parquet'}')")
        self.conn.execute(f"CREATE TABLE target AS SELECT * FROM parquet_scan('{path / 'target.parquet'}')")
        self.conn.execute(f"CREATE TABLE static_data_vocab AS SELECT * FROM parquet_scan('{path / 'static_data_vocab.parquet'}')")
        self.conn.execute(f"CREATE TABLE dynamic_data_vocab AS SELECT * FROM parquet_scan('{path / 'dynamic_data_vocab.parquet'}')")

        # calculate the number of samples and features
        self.len_samples = self.conn.execute("SELECT COUNT(*) FROM data").df().iloc[0, 0]
        self.len_static_features = self.conn.execute("SELECT max(id) FROM static_data_vocab").df().iloc[0, 0]
        self.len_dynamic_features = self.conn.execute("SELECT max(id) FROM dynamic_data_vocab").df().iloc[0, 0]

    def __len__(self):
        return self.len_samples

    def __getitem__(self, idx):
        # get the sample
        data = self.conn.execute(f"SELECT static_features, periods, dynamic_features, COALESCE(duration, 0) + 1 FROM data WHERE idx = {idx+1}").fetchone()
        target = self.conn.execute(f"SELECT target FROM target WHERE idx = {idx+1}").fetchone()

        static_data = torch.zeros(self.len_static_features)
        dynamic_data = torch.zeros(data[3], self.len_dynamic_features)

        # fill the static data
        for i, feature in enumerate(data[0]):
            static_data[feature-1] = 1

        # fill the dynamic data
        for i, period in enumerate(data[1]):
            for j, feature in enumerate(data[2][i]):
                dynamic_data[period][feature-1] = 1

        return static_data, dynamic_data, target[0]

In [3]:
ds = DiabetesDataset("/data/vgribanov/data/readm/6_hours")

In [4]:
ds.conn.query("SELECT * FROM data where duration is null LIMIT 10").df()

Unnamed: 0,idx,static_features,dynamic_features,periods,duration


In [5]:
idx = 0
data = ds.conn.execute(f"SELECT static_features, periods, dynamic_features, COALESCE(duration, 0) + 1 FROM data WHERE idx = {idx+1}").fetchone()
target = ds.conn.execute(f"SELECT target FROM target WHERE idx = {idx+1}").fetchone()

static_data = torch.zeros(ds.len_static_features)
dynamic_data = torch.zeros(data[3], ds.len_dynamic_features)

for i, feature in enumerate(data[0]):
    static_data[feature-1] = 1


for i, period in enumerate(data[1]):
    for j, feature in enumerate(data[2][i]):
        print("period, feature:",period, feature)
        dynamic_data[period][feature-1] = 1

non_zero_indices = [
    (i, j)
    for i, row in enumerate(dynamic_data)
    for j, val in enumerate(row)
    if val != 0
]
print("data", data[2])
print("non zero", non_zero_indices)

print(data)
print(target)
print(static_data)
print(dynamic_data)

period, feature: 0 3431
period, feature: 0 3451
period, feature: 0 3462
period, feature: 0 3467
period, feature: 0 3481
period, feature: 0 4401
period, feature: 0 4404
period, feature: 0 4405
period, feature: 0 4450
period, feature: 1 36
period, feature: 1 47
period, feature: 1 112
period, feature: 1 175
period, feature: 1 195
period, feature: 1 204
period, feature: 1 333
period, feature: 1 437
period, feature: 1 504
period, feature: 1 524
period, feature: 1 597
period, feature: 1 610
period, feature: 1 660
period, feature: 1 668
period, feature: 1 840
period, feature: 1 870
period, feature: 1 937
period, feature: 1 967
period, feature: 1 1033
period, feature: 1 1062
period, feature: 1 1073
period, feature: 1 1081
period, feature: 1 1139
period, feature: 1 1148
period, feature: 1 1175
period, feature: 1 1264
period, feature: 1 1277
period, feature: 1 1288
period, feature: 1 1438
period, feature: 1 1463
period, feature: 1 1468
period, feature: 1 1478
period, feature: 1 1532
period, feat

In [6]:
for i in range(ds.len_samples):
    item = ds[i]
    if i % 1000 == 0:
        print(i, item[0].shape, item[1].shape, item[2])

0 torch.Size([109]) torch.Size([60, 4493]) 1
1000 torch.Size([109]) torch.Size([7, 4493]) 0
2000 torch.Size([109]) torch.Size([14, 4493]) 1
3000 torch.Size([109]) torch.Size([22, 4493]) 0
4000 torch.Size([109]) torch.Size([18, 4493]) 0
5000 torch.Size([109]) torch.Size([35, 4493]) 0
6000 torch.Size([109]) torch.Size([73, 4493]) 0
7000 torch.Size([109]) torch.Size([12, 4493]) 1
8000 torch.Size([109]) torch.Size([24, 4493]) 1
9000 torch.Size([109]) torch.Size([13, 4493]) 0
10000 torch.Size([109]) torch.Size([9, 4493]) 0
11000 torch.Size([109]) torch.Size([5, 4493]) 0
12000 torch.Size([109]) torch.Size([20, 4493]) 0
13000 torch.Size([109]) torch.Size([206, 4493]) 0
14000 torch.Size([109]) torch.Size([10, 4493]) 0
15000 torch.Size([109]) torch.Size([16, 4493]) 0
16000 torch.Size([109]) torch.Size([5, 4493]) 1
17000 torch.Size([109]) torch.Size([21, 4493]) 0
18000 torch.Size([109]) torch.Size([30, 4493]) 0
19000 torch.Size([109]) torch.Size([10, 4493]) 0
20000 torch.Size([109]) torch.Size([

In [64]:
ds.conn.query("FROM (SELECT *, lag(idx) over(ORDER BY idx) as prev_idx FROM data) where idx-prev_idx <> 1").df()

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Unnamed: 0,idx,static_features,dynamic_features,periods,duration,prev_idx


In [67]:
ds.conn.query("select target, count(*) FROM target group by all").df()

Unnamed: 0,target,count_star()
0,0,82750
1,1,20128
