In [1]:
import sqlite3
import pandas as pd
import numpy as np
import csv
import random

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam

import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe, Mapper, MaxTokenBucketizer, ShardingFilter


In [2]:
csv_file = "/groups/icecube/petersen/GraphNetDatabaseRepository/Upgrade_Data/sqlite3/upgrade_pure_numu_selection_event_no.csv"
db_path = "/groups/icecube/petersen/GraphNetDatabaseRepository/Upgrade_Data/sqlite3/dev_step4_upgrade_028_with_noise_dynedge_pulsemap_v3_merger_aftercrash.db"


In [3]:
query_all ="SELECT * FROM truth"
with sqlite3.connect(db_path) as conn:
    db_tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type = 'table'", conn)
    print('Tables within the SQL database is:')
    print(db_tables)
    print()
    mini_db = {name:  pd.read_sql_query(f"SELECT * FROM {name} LIMIT 3000", conn) for name in db_tables.name}

Tables within the SQL database is:
                                      name
0                                    truth
1                        pisa_dependencies
2                         SplitInIcePulses
3              SplitInIcePulses_TruthFlags
4                  SplitIceCubePulsesTWSRT
5       SplitIceCubePulsesTWSRT_TruthFlags
6        SplitInIcePulses_GraphSage_Pulses
7   SplitInIcePulses_GraphSage_Predictions
8  SplitInIcePulses_dynedge_v2_Predictions
9       SplitInIcePulses_dynedge_v2_Pulses



In [5]:
mini_db["SplitInIcePulses_dynedge_v2_Pulses"].describe()

Unnamed: 0,charge,dom_number,dom_time,dom_type,dom_x,dom_y,dom_z,event_no,event_time,is_bad_dom,...,is_errata_dom,is_saturated_dom,pmt_area,pmt_dir_x,pmt_dir_y,pmt_dir_z,pmt_number,rde,string,width
count,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,...,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0
mean,0.996919,57.886,10693.175805,93.953333,42.207459,-54.537325,-340.829778,689930.899,59000.17,-1.0,...,-1.0,-1.0,0.02592,-0.005142,-0.013368,-0.331288,5.063667,1.06545,83.200333,2.175756
std,0.550068,26.17662,1626.229152,48.833384,56.274106,51.696705,88.189197,310.444116,9.90248e-12,0.0,...,0.0,0.0,0.015758,0.386108,0.379991,0.772572,7.267454,0.136492,16.927819,1.416558
min,0.118422,11.0,9797.400937,20.0,-368.93,-404.48,-652.074662,689357.0,59000.17,-1.0,...,-1.0,-1.0,0.008171,-0.878662,-0.878662,-1.0,0.0,1.0,8.0,0.750868
25%,0.706281,38.0,10148.451085,20.0,18.164966,-80.56375,-407.177,689735.0,59000.17,-1.0,...,-1.0,-1.0,0.008171,0.0,0.0,-1.0,0.0,1.0,85.0,2.0
50%,0.963712,54.0,10369.437995,120.0,47.164966,-58.873716,-339.22,689937.0,59000.17,-1.0,...,-1.0,-1.0,0.032429,0.0,0.0,-0.838671,1.0,1.0,89.0,2.0
75%,1.189269,77.0,10771.465614,130.0,62.623333,-35.16375,-277.868527,690226.0,59000.17,-1.0,...,-1.0,-1.0,0.0444,0.0,0.0,0.309017,9.0,1.0,91.0,2.0
max,10.291269,123.0,26447.381146,130.0,429.76,351.02,-62.78,690443.0,59000.17,-1.0,...,-1.0,-1.0,0.0444,0.878662,0.878662,1.0,23.0,1.35,93.0,6.25


In [6]:
mini_db["truth"].describe()

Unnamed: 0,CascadeFilter_13,DeepCoreFilter_13,EventID,L3_oscNext_bool,L4_oscNext_bool,L5_oscNext_bool,L6_oscNext_bool,L7_oscNext_bool,MuonFilter_13,OnlineL2Filter_17,...,event_time,inelasticity,interaction_type,pid,position_x,position_y,position_z,stopped_muon,track_length,zenith
count,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,...,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0,3000.0
mean,-1.0,0.260333,254.022,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.306125e+17,-0.2162,0.115,2.239333,26.162849,-23.869276,-231.741314,-0.816667,154.456544,0.597393
std,0.0,0.43889,304.680902,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,28823090.0,0.803394,1.112139,11.032042,175.903603,173.939024,282.852285,0.468112,440.534421,1.287551
min,-1.0,0.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.306125e+17,-1.0,-1.0,-14.0,-1712.766257,-1524.145615,-2475.092566,-1.0,-1.0,-1.0
25%,-1.0,0.0,4.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.306125e+17,-1.0,-1.0,-1.0,-1.0,-84.994875,-402.716683,-1.0,-1.0,-1.0
50%,-1.0,0.0,110.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.306125e+17,0.015703,1.0,-1.0,-1.0,-1.0,-235.331844,-1.0,-1.0,0.771274
75%,-1.0,1.0,455.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.306125e+17,0.480928,1.0,13.0,89.279698,4.831071,-1.0,-1.0,-1.0,1.659241
max,-1.0,1.0,1273.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.306125e+17,1.0,2.0,16.0,1010.625241,1452.757406,561.743984,1.0,3364.091649,3.125282


In [4]:

@functional_datapipe("read_csv")
class ReadCSV(IterDataPipe):
    def __init__(self, csv_file):
        self.csv_file = csv_file

    def __iter__(self):
        with open(self.csv_file, "r") as f:
            for line in f:
                yield int(line.strip())

@functional_datapipe("read_csv_dp")
class ReadCSVMultiple(IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for csv_file_path in self.datapipe:
            with open(csv_file_path, "r") as f:
                for line in f:
                    yield int(line.strip())

@functional_datapipe("query_sql")
class QuerySQL(IterDataPipe):
    def __init__(self, datapipe, db_path, input_cols, pulsemap, target_cols, truth_table):

        self.datapipe = datapipe
        self.db_path = db_path
        self.input_cols_str = ", ".join(input_cols)
        self.target_cols_str = ", ".join(target_cols)
        self.pulsemap = pulsemap
        self.truth_table = truth_table

    def __iter__(self):
        with sqlite3.connect(self.db_path) as conn:
            for event_no in self.datapipe:
                features = torch.Tensor(conn.execute(f"SELECT {self.input_cols_str} FROM {self.pulsemap} WHERE event_no == {event_no}").fetchall())
                truth = torch.Tensor(conn.execute(f"SELECT {self.target_cols_str} FROM {self.truth_table} WHERE event_no == {event_no}").fetchall())
                yield (features, truth)

@functional_datapipe("transform_data")
class TransfromData(IterDataPipe):
    def __init__(self, datapipe, feature_transform, truth_transform = None):
        self.datapipe = datapipe 
        # self.input_cols = input_cols
        # self.target_cols = target_cols
        self.feature_transform = feature_transform

        if not truth_transform:
            self.truth_transform = lambda features : features
        else:
          self.truth_transform = truth_transform


    def __iter__(self):
        for features, truth in self.datapipe:
            features = self.feature_transform(features)
            truth = self.truth_transform(truth)

            yield (features, truth)

def upgrade_transform_func(x):
    features, truth = x
    features[:, 0] = torch.log10(features[:, 0]) / 2.0  # charge
    features[:, 1] /= 2e04  # dom_time
    features[:, 1] -= 1.0
    features[:, 2] /= 500.0  # dom_x
    features[:, 3] /= 500.0  # dom_y
    features[:, 4] /= 500.0  # dom_z
    features[:, 5] /= 0.05  # pmt_area
    # features[:,6] /= 1.  # pmt_dir_x
    # features[:,7] /= 1.  # pmt_dir_y
    # features[:,8] /= 1.  # pmt_dir_z
    truth = torch.log10(truth)
    return (features, truth)

def upgrade_feature_transform(features):
    features[:, 0] = torch.log10(features[:, 0]) / 2.0  # charge
    features[:, 1] /= 2e04  # dom_time
    features[:, 1] -= 1.0
    features[:, 2] /= 500.0  # dom_x
    features[:, 3] /= 500.0  # dom_y
    features[:, 4] /= 500.0  # dom_z
    features[:, 5] /= 0.05  # pmt_area
    # features[:,6] /= 1.  # pmt_dir_x
    # features[:,7] /= 1.  # pmt_dir_y
    # features[:,8] /= 1.  # pmt_dir_z
    return features
    

def Prometheus_feature_transform(features):
    features[:, 0] /= 100.0  # dom_x
    features[:, 1] /= 100.0  # dom_y
    features[:, 2] += 350.0  # dom_z
    features[:, 2] /= 100.0
    features[:, 3] /= 1.05e04  # dom_time
    features[:, 3] -= 1.0
    features[:, 3] *= 20.0
    return features

def log10_target_transform(target):
   return torch.log10(target)


@functional_datapipe("pad_batch")
class PadBatch(IterDataPipe):
    def __init__(self, batch):
        self.batch = batch
        
    def __iter__(self):
        for batch in self.batch:

          (xx, y) = zip(*batch)
          x_lens = [len(x) for x in xx]
          xx_pad = pad_sequence(xx, batch_first=True, padding_value=0)

          pad_mask = torch.zeros_like(xx_pad[:, :, 0]).type(torch.bool)

          for i, length in enumerate(x_lens):
              pad_mask[i, length:] = True

          yield (xx_pad, torch.tensor(y), pad_mask)


def len_fn(datapipe):
  features, _ = datapipe
  return features.shape[0]

In [9]:
def build_datapipe(csv_file, db_path, input_cols, pulsemap, target_cols, truth_table):
    datapipe = ReadCSV(
        csv_file
        )
    datapipe = QuerySQL(
        datapipe = datapipe, 
        db_path = db_path, 
        input_cols = input_cols, 
        pulsemap = pulsemap, 
        target_cols = target_cols,
        truth_table = truth_table,
        )
    return datapipe

def build_datapipe_transform(
        csv_file, 
        db_path, 
        input_cols, 
        pulsemap, 
        target_cols, 
        truth_table,
        feature_transform,
        truth_transform = None
    ):
    datapipe = ReadCSV(csv_file)
    datapipe = QuerySQL(
        datapipe = datapipe, 
        db_path = db_path, 
        input_cols = input_cols, 
        pulsemap = pulsemap, 
        target_cols = target_cols,
        truth_table = truth_table,
        )
    datapipe = TransfromData( datapipe, feature_transform, truth_transform)
    return datapipe

def build_datapipe_batch(
        csv_file, 
        db_path, 
        input_cols, 
        pulsemap, 
        target_cols, 
        truth_table, 
        max_token_count,
        feature_transform,
        truth_transform = None
    ):
    datapipe = ReadCSV(csv_file)
    datapipe = QuerySQL(
        datapipe = datapipe, 
        db_path = db_path, 
        input_cols = input_cols, 
        pulsemap = pulsemap, 
        target_cols = target_cols,
        truth_table = truth_table,
        )
    # datapipe = Mapper()
    datapipe = TransfromData( datapipe, feature_transform, truth_transform)
    datapipe = MaxTokenBucketizer(datapipe,max_token_count = max_token_count, len_fn=len_fn, include_padding=True)
    return datapipe

def build_datapipe_padded(
        csv_file, 
        db_path, 
        input_cols, 
        pulsemap, 
        target_cols, 
        truth_table, 
        max_token_count,
        feature_transform,
        truth_transform = None
    ):
    datapipe = ReadCSV( csv_file).sharding_filter()
    # datapipe = ShardingFilter( datapipe)
    datapipe = QuerySQL(
        datapipe = datapipe, 
        db_path = db_path, 
        input_cols = input_cols, 
        pulsemap = pulsemap, 
        target_cols = target_cols,
        truth_table = truth_table,
        )
    datapipe = TransfromData( datapipe, feature_transform, truth_transform)
    datapipe = MaxTokenBucketizer( datapipe, max_token_count = max_token_count, len_fn = len_fn, include_padding = True)
    datapipe = PadBatch(datapipe)
    return datapipe

def build_datapipe_padded_func(
        csv_file, 
        db_path, 
        input_cols, 
        pulsemap, 
        target_cols, 
        truth_table, 
        max_token_count,
        feature_transform,
        truth_transform = None
    ):
    datapipe = ReadCSV( csv_file).sharding_filter() \
        .query_sql(
        db_path = db_path, 
        input_cols = input_cols, 
        pulsemap = pulsemap, 
        target_cols = target_cols,
        truth_table = truth_table,
        ) \
        .transform_data(
        feature_transform, 
        truth_transform
        ) \
        .max_token_bucketize(
        max_token_count = max_token_count,
        len_fn = len_fn,
        include_padding = True
        ) \
        .pad_batch()
    
    # datapipe = ShardingFilter( datapipe)
    # datapipe = QuerySQL(
    #     datapipe = datapipe, 
    #     db_path = db_path, 
    #     input_cols = input_cols, 
    #     pulsemap = pulsemap, 
    #     target_cols = target_cols,
    #     truth_table = truth_table,
    #     )
    # datapipe = TransfromData( datapipe, feature_transform, truth_transform)
    # datapipe = MaxTokenBucketizer( datapipe, max_token_count = max_token_count, len_fn = len_fn, include_padding = True)
    # datapipe = PadBatch(datapipe)
    return datapipe

def build_datapipe_map_padded_func(
        csv_file, 
        db_path, 
        input_cols, 
        pulsemap, 
        target_cols, 
        truth_table, 
        max_token_count,
    ):
    datapipe = ReadCSV( csv_file).sharding_filter() \
        .query_sql(
        db_path = db_path, 
        input_cols = input_cols, 
        pulsemap = pulsemap, 
        target_cols = target_cols,
        truth_table = truth_table,
        ) \
        .map(upgrade_transform_func
        ) \
        .max_token_bucketize(
        max_token_count = max_token_count,
        len_fn = len_fn,
        include_padding = True
        ) \
        .pad_batch()
    
    # datapipe = ShardingFilter( datapipe)
    # datapipe = QuerySQL(
    #     datapipe = datapipe, 
    #     db_path = db_path, 
    #     input_cols = input_cols, 
    #     pulsemap = pulsemap, 
    #     target_cols = target_cols,
    #     truth_table = truth_table,
    #     )
    # datapipe = TransfromData( datapipe, feature_transform, truth_transform)
    # datapipe = MaxTokenBucketizer( datapipe, max_token_count = max_token_count, len_fn = len_fn, include_padding = True)
    # datapipe = PadBatch(datapipe)
    return datapipe

In [10]:


pulsemap = "SplitInIcePulses_dynedge_v2_Pulses"
truth_table = "truth"
input_cols = ["charge",	"dom_time", "dom_x", "dom_y", "dom_z", "pmt_area", "pmt_dir_x", "pmt_dir_y", "pmt_dir_z"]
target_cols = ["energy"]
max_token_count = 400

datapipe = build_datapipe(
    csv_file = csv_file, 
    db_path = db_path, 
    input_cols = input_cols, 
    pulsemap = pulsemap, 
    target_cols = target_cols, 
    truth_table = truth_table,
)

datapipe_transform = build_datapipe_transform(
    csv_file = csv_file, 
    db_path = db_path, 
    input_cols = input_cols, 
    pulsemap = pulsemap, 
    target_cols = target_cols, 
    truth_table = truth_table, 
    feature_transform = upgrade_feature_transform,
    truth_transform = log10_target_transform,
)

datapipe_batch = build_datapipe_batch(
    csv_file = csv_file, 
    db_path = db_path, 
    input_cols = input_cols, 
    pulsemap = pulsemap, 
    target_cols = target_cols, 
    truth_table = truth_table, 
    max_token_count = max_token_count, 
    feature_transform = upgrade_feature_transform,
    truth_transform = log10_target_transform,
)

datapipe_padded = build_datapipe_padded(
    csv_file = csv_file, 
    db_path = db_path, 
    input_cols = input_cols, 
    pulsemap = pulsemap, 
    target_cols = target_cols, 
    truth_table = truth_table, 
    max_token_count = max_token_count, 
    feature_transform = upgrade_feature_transform, 
    truth_transform = log10_target_transform, 
)

datapipe_padded_func = build_datapipe_padded_func(
    csv_file = csv_file, 
    db_path = db_path, 
    input_cols = input_cols, 
    pulsemap = pulsemap, 
    target_cols = target_cols, 
    truth_table = truth_table, 
    max_token_count = max_token_count, 
    feature_transform = upgrade_feature_transform, 
    truth_transform = log10_target_transform, 
)
datapipe_map = build_datapipe_map_padded_func(
    csv_file = csv_file, 
    db_path = db_path, 
    input_cols = input_cols, 
    pulsemap = pulsemap, 
    target_cols = target_cols, 
    truth_table = truth_table, 
    max_token_count = max_token_count,
)

In [11]:
for i, batch in enumerate(datapipe_map):

  xx, y, pad, = batch
  print(sum([x for x in xx]))
  print(y)

  print()
  if i == 2:
      break

400
40

396
33

396
33



In [13]:
for i, batch in enumerate(datapipe_map):

  xx, y, pad, = batch
  print(sum([x for x in xx]))
  print(y)

  print()
  if i == 2:
      break

tensor([[-2.8842e+00, -1.7621e+01,  7.9036e-01, -3.6699e+00, -2.5945e+01,
          3.4795e+01,  3.6395e-01, -8.7866e-01, -3.9309e+01],
        [-7.3822e-01, -1.7906e+01,  6.7714e-01, -1.4778e+00, -2.6468e+01,
          3.4556e+01,  3.6395e-01, -8.7866e-01, -3.9309e+01],
        [-5.7206e-01, -1.8345e+01,  6.6343e-01, -1.8990e+00, -2.6685e+01,
          3.4071e+01,  8.7866e-01,  1.8069e-01, -3.8530e+01],
        [-1.3951e+00, -1.8339e+01,  7.5559e-01, -1.2806e+00, -2.6674e+01,
          3.2382e+01, -1.4233e+00, -1.3934e+00, -3.7766e+01],
        [-6.6976e-01, -1.8371e+01,  1.0478e+00, -9.5608e-01, -2.7507e+01,
          3.2628e+01, -5.1471e-01,  2.9931e-02, -3.7839e+01],
        [-2.1951e+00, -1.8347e+01,  1.6641e+00, -1.1207e+00, -2.7895e+01,
          3.1903e+01,  1.4558e+00,  0.0000e+00, -3.6000e+01],
        [-1.0564e+00, -1.7971e+01,  2.3121e+00, -1.3465e+00, -2.8413e+01,
          3.0454e+01,  1.2127e+00, -1.2725e+00, -2.5087e+01],
        [-1.0466e+00, -1.7555e+01,  2.7820e+00, 

In [14]:
for i, batch in enumerate(datapipe_padded_func):

  xx, y, pad, = batch
  print(sum([x for x in xx]))
  print(y)

  print()
  if i == 2:
      break

tensor([[-2.8842e+00, -1.7621e+01,  7.9036e-01, -3.6699e+00, -2.5945e+01,
          3.4795e+01,  3.6395e-01, -8.7866e-01, -3.9309e+01],
        [-7.3822e-01, -1.7906e+01,  6.7714e-01, -1.4778e+00, -2.6468e+01,
          3.4556e+01,  3.6395e-01, -8.7866e-01, -3.9309e+01],
        [-5.7206e-01, -1.8345e+01,  6.6343e-01, -1.8990e+00, -2.6685e+01,
          3.4071e+01,  8.7866e-01,  1.8069e-01, -3.8530e+01],
        [-1.3951e+00, -1.8339e+01,  7.5559e-01, -1.2806e+00, -2.6674e+01,
          3.2382e+01, -1.4233e+00, -1.3934e+00, -3.7766e+01],
        [-6.6976e-01, -1.8371e+01,  1.0478e+00, -9.5608e-01, -2.7507e+01,
          3.2628e+01, -5.1471e-01,  2.9931e-02, -3.7839e+01],
        [-2.1951e+00, -1.8347e+01,  1.6641e+00, -1.1207e+00, -2.7895e+01,
          3.1903e+01,  1.4558e+00,  0.0000e+00, -3.6000e+01],
        [-1.0564e+00, -1.7971e+01,  2.3121e+00, -1.3465e+00, -2.8413e+01,
          3.0454e+01,  1.2127e+00, -1.2725e+00, -2.5087e+01],
        [-1.0466e+00, -1.7555e+01,  2.7820e+00, 

In [12]:
for i, batch in enumerate(datapipe_padded_func):

  xx, y, pad, = batch
  print(sum([len(x) for x in xx]))
  print(len(y))

  print()
  if i == 2:
      break

400
40

396
33

396
33



In [11]:
for i, ( features, truth) in enumerate(datapipe):
    print("Features:", features.shape)
    print("Truth:", truth.shape)
    print()
    if i == 2:
        break


Features: torch.Size([24, 9])
Truth: torch.Size([1, 1])

Features: torch.Size([35, 9])
Truth: torch.Size([1, 1])

Features: torch.Size([37, 9])
Truth: torch.Size([1, 1])



In [12]:
def blob():
    datapipe = IterableWrapper(range(10))
    datapipe = ShardingFilter(datapipe)
    # datapipe = datapipe.shuffle(buffer_size=2)
    # print(list(datapipe))
    train, test, valid = datapipe.random_split(total_length=len(datapipe), weights={"train": 0.4, "test": 0.3, "valid": 0.3}, seed=1)
    # return train.shuffle(), test.shuffle(), valid.shuffle()
    return train, test, valid


train, test, valid = blob()
print(list(train))
print(list(test))
print(list(valid))

[0, 3, 8, 9]
[4, 5, 6]
[1, 2, 7]


In [13]:
for i in range(200):

    train, test, valid = IterableWrapper(range(10)).shuffle().random_split(total_legnth = 10, weights={"train": 0.4, "test": 0.3, "valid": 0.3}, seed=1)
    if len(list(train))< len(list(test)):
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(i)

    # print()
    # print(list(train))
    # print(list(test))
    # print(list(valid))

TypeError: __new__() got an unexpected keyword argument 'total_legnth'

In [13]:
train, test, valid = IterableWrapper(range(10)).shuffle(buffer_size=10).random_split(total_length = 10, weights={"train": 0.4, "test": 0.3, "valid": 0.3}, seed=0)
print(list(train))
print(list(test))
print(list(valid))

[2, 8, 1, 9]
[9, 7, 6]
[1, 9, 4]


In [14]:
train, test, valid = IterableWrapper(range(10)).shuffle().round_robin_demux(3)
print(list(train))
print(list(test))
print(list(valid))

[9, 7, 0, 3]
[4, 5, 1]
[2, 6, 8]


In [15]:
for data in datapipe:
    print("length:", len_fn(data))

length: 24
length: 35
length: 37
length: 58
length: 61
length: 70
length: 13
length: 14
length: 102
length: 140
length: 32
length: 14
length: 85
length: 64
length: 29
length: 24
length: 89
length: 53
length: 47
length: 11
length: 177
length: 158
length: 27
length: 15
length: 87
length: 20
length: 38
length: 53
length: 12
length: 115
length: 9
length: 24
length: 18
length: 34
length: 15
length: 46
length: 32
length: 31
length: 21
length: 355
length: 15
length: 121
length: 47
length: 13
length: 19
length: 123
length: 76
length: 12
length: 18
length: 21
length: 38
length: 40
length: 46
length: 23
length: 189
length: 15
length: 50
length: 277
length: 23
length: 42
length: 638
length: 25
length: 13
length: 83
length: 22
length: 32
length: 44
length: 20
length: 25
length: 13
length: 201
length: 55
length: 10
length: 48
length: 30
length: 349
length: 29
length: 55
length: 220
length: 9
length: 55
length: 35
length: 48
length: 26
length: 19
length: 66
length: 109
length: 23
length: 25
length: 

KeyboardInterrupt: 

In [16]:
for i, ( features, truth) in enumerate(datapipe_transform):
    print("Features:", features)
    print("Truth:", truth)
    print()
    if i == 2:
        break

Features: tensor([[ 0.0210, -0.4338, -0.5813, -0.6148, -0.3193,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [-0.0272, -0.4408, -0.5813, -0.6148, -0.3874,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [-0.2408, -0.4000, -0.5813, -0.6148, -0.4215,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [ 0.1923, -0.4632, -0.4913, -0.3810, -0.4260,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [ 0.0528, -0.4626, -0.4913, -0.3810, -0.4260,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [-0.1724, -0.4608, -0.4913, -0.3810, -0.4260,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [-0.1281, -0.4588, -0.4913, -0.3810, -0.4260,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [-0.0760, -0.4635, -0.4913, -0.3810, -0.4601,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [-0.3990, -0.4631, -0.4913, -0.3810, -0.4601,  0.8880,  0.0000,  0.0000,
         -1.0000],
        [ 0.0158, -0.3814, -0.6488, -0.1869, -0.4976,  0.8880,  0.0000,  0.0000,
         

In [17]:
for i, ( features, truth) in enumerate(datapipe):
    print("Features:", features)
    print("Truth:", truth)
    print()
    if i == 2:
        break

Features: tensor([[ 1.1016e+00,  1.1324e+04, -2.9066e+02, -3.0738e+02, -1.5966e+02,
          4.4400e-02,  0.0000e+00,  0.0000e+00, -1.0000e+00],
        [ 8.8229e-01,  1.1184e+04, -2.9066e+02, -3.0738e+02, -1.9371e+02,
          4.4400e-02,  0.0000e+00,  0.0000e+00, -1.0000e+00],
        [ 3.2997e-01,  1.2000e+04, -2.9066e+02, -3.0738e+02, -2.1073e+02,
          4.4400e-02,  0.0000e+00,  0.0000e+00, -1.0000e+00],
        [ 2.4245e+00,  1.0737e+04, -2.4565e+02, -1.9049e+02, -2.1301e+02,
          4.4400e-02,  0.0000e+00,  0.0000e+00, -1.0000e+00],
        [ 1.2752e+00,  1.0748e+04, -2.4565e+02, -1.9049e+02, -2.1301e+02,
          4.4400e-02,  0.0000e+00,  0.0000e+00, -1.0000e+00],
        [ 4.5200e-01,  1.0785e+04, -2.4565e+02, -1.9049e+02, -2.1301e+02,
          4.4400e-02,  0.0000e+00,  0.0000e+00, -1.0000e+00],
        [ 5.5437e-01,  1.0825e+04, -2.4565e+02, -1.9049e+02, -2.1301e+02,
          4.4400e-02,  0.0000e+00,  0.0000e+00, -1.0000e+00],
        [ 7.0470e-01,  1.0729e+04, -2.

In [18]:
for i, batch in enumerate(datapipe_batch):
  (xx, y) = zip(*batch)
  print([len(x) for x in xx])
  print(len(y))
  print()
  if i == 2:
      break

[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
40

[10, 9, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12]
33

[12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 9, 12, 12, 12, 12, 12, 9, 12]
33



In [19]:
for i, batch in enumerate(datapipe_padded):

  xx, y, pad, = batch
  print(sum([len(x) for x in xx]))
  print(len(y))

  print()
  if i == 2:
      break

400
40

396
33

396
33



In [11]:
for i, batch in enumerate (datapipe_batch):
  (xx, y) = zip(*batch)
  print([x.shape for x in xx])
  print(len(y))
  # print(xx[:, 0, :].shape)
  print(torch.mean(xx, dim=1))
  print()
  if i == 2:
      break

[torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([9, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9]), torch.Size([10, 9])]
40


TypeError: mean() received an invalid combination of arguments - got (tuple, dim=int), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)


In [15]:
class SimpleTransformerEncoderPooling(nn.Module):
    def __init__(
        self,
        input_size: int = 9,
        d_model: int = 64, 
        nhead: int = 8,
        num_layers: int = 6,
        output_size = 1
    ):
        super(SimpleTransformerEncoderPooling, self).__init__()

        self.input_size = input_size
        self.d_model = d_model
        self.nhead = nhead

        self.fc_in = nn.Linear(input_size, d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model *3, output_size)

    def forward(self, x, pad_mask = None):
        x = self.fc_in(x)
        x = self.transformer_encoder(x,src_key_padding_mask = pad_mask)
        # x = x[:, 0, :] # change this line to min max and mean pooling
        # Min, max, and mean pooling
        x_min, _ = torch.min(x, dim=1)
        x_max, _ = torch.max(x, dim=1)
        x_mean = torch.mean(x, dim=1)

        # Concatenate the pooled features
        x_pooled = torch.cat((x_min, x_max, x_mean), dim=1)

        x = self.fc_out(x_pooled)
        return x.squeeze()

In [16]:
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

model = SimpleTransformerEncoderPooling()

rs = MultiProcessingReadingService(num_workers=2)
dl = DataLoader2(datapipe_padded,)#reading_service=rs

for i, (features, truth,pad) in enumerate(dl):
    print(truth.shape)
    pred = model(features)
    print(pred.shape)
    if i == 1:
        break
dl.shutdown()

torch.Size([40])
torch.Size([40])
torch.Size([33])
torch.Size([33])


In [26]:
model = SimpleTransformerEncoderPooling()
dl = DataLoader2(datapipe_padded)
criterion = nn.MSELoss()

optimizer = Adam(model.parameters(), lr=0.001)

num_epochs = 20

for epoch in range(num_epochs):
  running_loss = 0.0

  for i, (inputs, targets, pad_mask) in enumerate(dl):
    optimizer.zero_grad()

    outputs = model(inputs)

    loss = criterion(outputs, targets)

    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if i > 50:
        break

  epoch_loss = running_loss
  print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, epoch_loss))

dl.shutdown()

Epoch [1/20], Loss: 24.0419
Epoch [2/20], Loss: 16.1142
Epoch [3/20], Loss: 16.1286
Epoch [4/20], Loss: 15.9761
Epoch [5/20], Loss: 15.9594
Epoch [6/20], Loss: 16.0721
Epoch [7/20], Loss: 16.1323
Epoch [8/20], Loss: 15.7754
Epoch [9/20], Loss: 15.7852
Epoch [10/20], Loss: 15.6314
Epoch [11/20], Loss: 12.5647
Epoch [12/20], Loss: 12.4371
Epoch [13/20], Loss: 16.3622
Epoch [14/20], Loss: 15.5688
Epoch [15/20], Loss: 15.5552
Epoch [16/20], Loss: 15.6911
Epoch [17/20], Loss: 15.5472
Epoch [18/20], Loss: 15.6070
Epoch [19/20], Loss: 15.5024
Epoch [20/20], Loss: 15.5562


In [24]:
dl.shutdown()