In [1]:
import duckdb
import torch
from torch.utils.data import Dataset
import pathlib

In [62]:

class DiabetesDataset(Dataset):
    """Map-style dataset backed by an in-memory DuckDB copy of parquet files.

    The constructor loads ``data.parquet``, ``target.parquet``,
    ``static_data_vocab.parquet`` and ``dynamic_data_vocab.parquet`` from the
    given directory into DuckDB tables of the same names.

    Each item is a tuple ``(static_data, dynamic_data, target)``:

    * ``static_data`` -- 1-D float tensor of length ``len_static_features``
      with a 1 at position ``feature - 1`` for every id in ``static_features``.
    * ``dynamic_data`` -- 2-D float tensor of shape
      ``(COALESCE(duration, 0) + 1, len_dynamic_features)`` with a 1 at
      ``[period, feature - 1]`` for every id listed for that period.
    * ``target`` -- the ``target`` column value of the matching row.

    NOTE(review): the DuckDB connection lives on the instance; behaviour with
    multi-process ``DataLoader`` workers has not been verified here.
    """

    # Tables created at construction; each is read from "<name>.parquet".
    _TABLES = ("data", "target", "static_data_vocab", "dynamic_data_vocab")

    def __init__(self, dataPath: str):
        """Load the parquet files found in ``dataPath`` into memory.

        Args:
            dataPath: Directory containing the four parquet files.

        Raises:
            FileNotFoundError: If ``dataPath`` does not exist.
        """
        path = pathlib.Path(dataPath)
        # Validate before opening a connection so nothing is left dangling.
        if not path.exists():
            raise FileNotFoundError(f"Path {path} does not exist")

        self.conn = duckdb.connect(':memory:')

        # load data from parquet files to memory
        for table in self._TABLES:
            # Double any single quote so the path cannot terminate the SQL literal.
            parquet_file = str(path / f"{table}.parquet").replace("'", "''")
            self.conn.execute(
                f"CREATE TABLE {table} AS SELECT * FROM parquet_scan('{parquet_file}')"
            )

        # calculate the number of samples and features
        self.len_samples = self._scalar("SELECT COUNT(*) FROM data")
        self.len_static_features = self._scalar("SELECT max(id) FROM static_data_vocab")
        self.len_dynamic_features = self._scalar("SELECT max(id) FROM dynamic_data_vocab")

    def _scalar(self, query: str) -> int:
        """Run a single-value query and return it as a Python int (0 for NULL)."""
        row = self.conn.execute(query).fetchone()
        value = row[0] if row else None
        return int(value) if value is not None else 0

    def __len__(self):
        """Return the number of rows in the ``data`` table."""
        return self.len_samples

    def __getitem__(self, idx):
        """Build the multi-hot tensors and target for one sample.

        Args:
            idx: Zero-based sample position; negative values count from the end.

        Returns:
            Tuple ``(static_data, dynamic_data, target)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            LookupError: If no row with the matching ``idx`` value exists.
            ValueError: If ``periods`` and ``dynamic_features`` differ in length.
        """
        position = int(idx)
        if position < 0:
            position += self.len_samples
        # IndexError is what Python's sequence-iteration protocol expects at the end.
        if not 0 <= position < self.len_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.len_samples}"
            )

        # The ``idx`` column is looked up with an offset of one from the position.
        key = position + 1
        data = self.conn.execute(
            "SELECT static_features, periods, dynamic_features, COALESCE(duration, 0) + 1 "
            "FROM data WHERE idx = ?",
            [key],
        ).fetchone()
        target = self.conn.execute(
            "SELECT target FROM target WHERE idx = ?", [key]
        ).fetchone()
        if data is None or target is None:
            # Deliberately not IndexError: a gap must not silently end iteration.
            raise LookupError(f"no row with idx = {key} in data/target tables")

        static_features, periods, dynamic_features, n_periods = data
        static_features = static_features or ()
        periods = periods or ()
        dynamic_features = dynamic_features or ()
        if len(periods) != len(dynamic_features):
            raise ValueError(
                f"sample {position}: {len(periods)} periods but "
                f"{len(dynamic_features)} dynamic feature lists"
            )

        static_data = torch.zeros(self.len_static_features)
        dynamic_data = torch.zeros(int(n_periods), self.len_dynamic_features)

        # fill the static data (feature ids are shifted down by one to get the column)
        for feature in static_features:
            static_data[feature - 1] = 1

        # fill the dynamic data: one row per period, one column per feature id
        for period, features in zip(periods, dynamic_features):
            for feature in features:
                dynamic_data[period, feature - 1] = 1

        return static_data, dynamic_data, target[0]

In [63]:
# Build the dataset from the prepared parquet directory (loads all tables into memory).
ds = DiabetesDataset(dataPath="/data/vgribanov/data/readm/prepared_data")

In [38]:
# Sanity check: list up to ten rows whose duration is NULL.
(
    ds.conn
    .query("SELECT * FROM data where duration is null LIMIT 10")
    .df()
)

Unnamed: 0,idx,static_features,dynamic_features,periods,duration


In [65]:
# Manual walk-through of one sample: repeats the encoding done in
# DiabetesDataset.__getitem__ step by step, printing intermediate values.
idx = 0
data = ds.conn.execute(
    "SELECT static_features, periods, dynamic_features, COALESCE(duration, 0) + 1 "
    "FROM data WHERE idx = ?",
    [idx + 1],
).fetchone()
target = ds.conn.execute("SELECT target FROM target WHERE idx = ?", [idx + 1]).fetchone()

static_data = torch.zeros(ds.len_static_features)
dynamic_data = torch.zeros(data[3], ds.len_dynamic_features)

# Static features: set column (feature - 1) for every listed id.
for feature in data[0]:
    static_data[feature - 1] = 1

# Dynamic features: data[2][i] holds the ids observed in period data[1][i].
for i, period in enumerate(data[1]):
    for feature in data[2][i]:
        print("period, feature:", period, feature)
        dynamic_data[period, feature - 1] = 1

# Tensor.nonzero() returns the (row, column) positions in row-major order,
# avoiding a Python-level scan over every element of the matrix.
non_zero_indices = [tuple(position) for position in dynamic_data.nonzero().tolist()]
print("data", data[2])
print("non zero", non_zero_indices)

print(data)
print(target)
print(static_data)
print(dynamic_data)

period, feature: 0 3434
period, feature: 0 3443
period, feature: 0 3453
period, feature: 0 3460
period, feature: 0 3470
period, feature: 3 3438
period, feature: 3 3440
period, feature: 3 3453
period, feature: 3 3458
period, feature: 3 3470
period, feature: 4 3442
period, feature: 4 3453
period, feature: 4 3460
period, feature: 4 3469
period, feature: 6 3434
period, feature: 6 3440
period, feature: 6 3453
period, feature: 6 3458
period, feature: 6 3469
period, feature: 7 3440
period, feature: 7 3453
period, feature: 7 3458
period, feature: 7 3468
period, feature: 8 3433
period, feature: 8 3440
period, feature: 8 3453
period, feature: 8 3458
period, feature: 8 3470
period, feature: 9 3434
period, feature: 9 3440
period, feature: 9 3453
period, feature: 9 3458
period, feature: 9 3473
period, feature: 10 3434
period, feature: 10 3441
period, feature: 10 3453
period, feature: 10 3462
period, feature: 10 3471
period, feature: 11 61
period, feature: 11 88
period, feature: 11 98
period, featur

In [66]:
# Smoke test: materialise every sample, reporting shapes and target every 1000th one.
for i in range(ds.len_samples):
    item = ds[i]
    if i % 1000 != 0:
        continue
    static_part, dynamic_part, label = item
    print(i, static_part.shape, dynamic_part.shape, label)

0 torch.Size([109]) torch.Size([556, 4493]) 0
1000 torch.Size([109]) torch.Size([335, 4493]) 0
2000 torch.Size([109]) torch.Size([194, 4493]) 1
3000 torch.Size([109]) torch.Size([234, 4493]) 0
4000 torch.Size([109]) torch.Size([77, 4493]) 0
5000 torch.Size([109]) torch.Size([128, 4493]) 0
6000 torch.Size([109]) torch.Size([268, 4493]) 0
7000 torch.Size([109]) torch.Size([261, 4493]) 0
8000 torch.Size([109]) torch.Size([49, 4493]) 0
9000 torch.Size([109]) torch.Size([94, 4493]) 0
10000 torch.Size([109]) torch.Size([97, 4493]) 0
11000 torch.Size([109]) torch.Size([68, 4493]) 0
12000 torch.Size([109]) torch.Size([127, 4493]) 0
13000 torch.Size([109]) torch.Size([292, 4493]) 0
14000 torch.Size([109]) torch.Size([106, 4493]) 0
15000 torch.Size([109]) torch.Size([257, 4493]) 0
16000 torch.Size([109]) torch.Size([310, 4493]) 0
17000 torch.Size([109]) torch.Size([65, 4493]) 0
18000 torch.Size([109]) torch.Size([220, 4493]) 0
19000 torch.Size([109]) torch.Size([161, 4493]) 0
20000 torch.Size([1

In [64]:
# Sanity check: rows whose idx does not directly follow the previous idx
# (an empty result means idx values are consecutive).
(
    ds.conn
    .query("FROM (SELECT *, lag(idx) over(ORDER BY idx) as prev_idx FROM data) where idx-prev_idx <> 1")
    .df()
)

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Unnamed: 0,idx,static_features,dynamic_features,periods,duration,prev_idx


In [67]:
# Class balance: number of rows per target value.
(
    ds.conn
    .query("select target, count(*) FROM target group by all")
    .df()
)

Unnamed: 0,target,count_star()
0,0,82750
1,1,20128
