In [1]:
from datetime import datetime

import polars as pl

from src.preprocessing import PolarsDataPreprocessor

In [2]:
# Reference timestamp: month_ind, day_ind and the rewritten event_time (seconds) are all measured from here.
min_dt = datetime(year=2021, month=1, day=1)

In [3]:
# Load train transactions and derive calendar indices relative to min_dt.
# Every expression inside a single with_columns is evaluated against the *input* frame, so month_ind/day_ind
# still read the original datetime event_time while event_time itself is replaced by seconds since min_dt.
train_trx = pl.read_parquet("F:/chunks/trx_train.pq").with_columns(
    (pl.col("event_time") - min_dt).dt.seconds().alias("event_time"),
    month_ind=12 * (pl.col("event_time").dt.year() - min_dt.year)
    + (pl.col("event_time").dt.month() - min_dt.month),
    day_ind=(pl.col("event_time") - min_dt).dt.days(),
)
train_trx.head()

event_time,amount,client_id,event_type,event_subtype,currency,src_type11,src_type12,dst_type11,dst_type12,src_type21,src_type22,src_type31,src_type32,month_ind,day_ind
i64,f32,str,i32,i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64
35499473,39204.261719,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,13,410
38062618,77238.382812,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,869.0,31488.0,22823.0,48.0,942.0,4.0,14,440
35018297,14293.958008,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,13,405
32446147,2569.062988,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,12,375
34413708,62966.214844,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,13,398


In [4]:
# Load train geo events; same time features as for trx, computed in one with_columns pass
# (all expressions see the original datetime event_time before it is replaced by seconds since min_dt).
train_geo = pl.read_parquet("F:/chunks/geo_train.pq").with_columns(
    (pl.col("event_time") - min_dt).dt.seconds().alias("event_time"),
    month_ind=12 * (pl.col("event_time").dt.year() - min_dt.year)
    + (pl.col("event_time").dt.month() - min_dt.month),
    day_ind=(pl.col("event_time") - min_dt).dt.days(),
)
train_geo.head()

client_id,event_time,geohash_4,geohash_5,geohash_6,month_ind,day_ind
str,i64,i32,i32,i32,i64,i64
"""009c52bd099cbb…",57046408,32892,35465,1028609,21,660
"""009c52bd099cbb…",56294845,32892,35465,461846,21,651
"""009c52bd099cbb…",56912877,32892,35465,1028609,21,658
"""009c52bd099cbb…",55770461,32892,35465,461846,21,645
"""009c52bd099cbb…",54736642,32892,35465,1028609,20,633


In [5]:
# Load train dialog events; same time features as for trx, computed in one with_columns pass
# (all expressions see the original datetime event_time before it is replaced by seconds since min_dt).
train_dial = pl.read_parquet("F:/chunks/dial_train.pq").with_columns(
    (pl.col("event_time") - min_dt).dt.seconds().alias("event_time"),
    month_ind=12 * (pl.col("event_time").dt.year() - min_dt.year)
    + (pl.col("event_time").dt.month() - min_dt.month),
    day_ind=(pl.col("event_time") - min_dt).dt.days(),
)
train_dial.head()

client_id,event_time,embedding,month_ind,day_ind
str,i64,list[f32],i64,i64
"""a08c690dd972d2…",42023542,"[-0.009235, -0.069714, … -0.062696]",16,486
"""a341aebbdf1301…",49797689,"[0.28782, -0.391491, … 0.344232]",18,576
"""a37e4caaa1c741…",59122873,"[0.114006, -0.112189, … 0.180738]",22,684
"""a37e4caaa1c741…",40232523,"[0.23755, -0.309908, … 0.215162]",15,465
"""a37e4caaa1c741…",51355955,"[0.370291, -0.336272, … 0.254396]",19,594


In [6]:
# Pack the four binary labels into one list column "target", ordered target_1 .. target_4.
train_target = pl.read_parquet("F:/chunks/train_target.pq").with_columns(
    pl.concat_list([pl.col(f"target_{k}") for k in (1, 2, 3, 4)]).alias("target")
)
train_target.head()

mon,target_1,target_2,target_3,target_4,client_id,target
str,i32,i32,i32,i32,str,list[i32]
"""2022-12-31""",0,0,0,0,"""06a2ce26f19242…","[0, 0, … 0]"
"""2022-12-31""",0,0,0,0,"""06d250bda1fe78…","[0, 0, … 0]"
"""2022-12-31""",0,0,0,0,"""1d0fd54040602e…","[0, 0, … 0]"
"""2022-12-31""",0,0,0,0,"""85c233cac30252…","[0, 0, … 0]"
"""2022-12-31""",0,0,0,0,"""16d2d3fbdef66b…","[0, 0, … 0]"


In [7]:
%%time

# Fit the trx category vocabularies on the train split and aggregate events per client.
trx_preprocessor = PolarsDataPreprocessor(
    prefix="trx",
    col_id="client_id",
    col_event_time="event_time",
    cols_numerical=None,
    cols_category=[
        "month_ind",
        "day_ind",
        *(f"src_type{n}" for n in (21, 22, 31, 32)),
    ],
)

train_trx = trx_preprocessor.fit_transform(train_trx)
trx_preprocessor.get_category_dictionary_sizes()

CPU times: total: 2min 14s
Wall time: 23.5 s


{'month_ind': 14,
 'day_ind': 356,
 'src_type21': 8220,
 'src_type22': 87,
 'src_type31': 1617,
 'src_type32': 88}

In [8]:
%%time

# Geo events only contribute the two calendar categories; vocabularies are fitted on train_geo.
geo_preprocessor = PolarsDataPreprocessor(
    prefix="geo",
    col_id="client_id",
    col_event_time="event_time",
    cols_numerical=None,
    cols_category=["month_ind", "day_ind"],
)

train_geo = geo_preprocessor.fit_transform(train_geo)
geo_preprocessor.get_category_dictionary_sizes()

CPU times: total: 3min 48s
Wall time: 32.3 s


{'month_ind': 14, 'day_ind': 356}

In [9]:
%%time

# Dialog events: calendar categories only (cols_numerical=None, so the embedding column is not kept as a feature).
dial_preprocessor = PolarsDataPreprocessor(
    prefix="dial",
    col_id="client_id",
    col_event_time="event_time",
    cols_numerical=None,
    cols_category=["month_ind", "day_ind"],
)

train_dial = dial_preprocessor.fit_transform(train_dial)
dial_preprocessor.get_category_dictionary_sizes()

CPU times: total: 2.47 s
Wall time: 2.06 s


{'month_ind': 13, 'day_ind': 365}

In [10]:
def _with_time_features(frame):
    """Add month_ind/day_ind relative to min_dt and replace event_time by seconds since min_dt.

    All expressions are evaluated in one with_columns against the input frame, so the calendar
    indices are computed from the original datetime event_time.
    """
    event_time = pl.col("event_time")
    return frame.with_columns(
        (event_time - min_dt).dt.seconds().alias("event_time"),
        month_ind=12 * (event_time.dt.year() - min_dt.year) + (event_time.dt.month() - min_dt.month),
        day_ind=(event_time - min_dt).dt.days(),
    )


val_trx = _with_time_features(pl.read_parquet("F:/chunks/trx_val.pq"))
val_geo = _with_time_features(pl.read_parquet("F:/chunks/geo_val.pq"))
val_dial = _with_time_features(pl.read_parquet("F:/chunks/dial_val.pq"))

In [11]:
# Same label packing as for train: one list column "target" = [target_1, ..., target_4].
val_target = pl.read_parquet("F:/chunks/val_target.pq").with_columns(
    pl.concat_list([pl.col(f"target_{k}") for k in (1, 2, 3, 4)]).alias("target")
)

In [12]:
%%time

# Encode validation transactions with the vocabulary fitted on train_trx (transform only, no refit).
val_trx = trx_preprocessor.transform(val_trx)

CPU times: total: 1min 4s
Wall time: 5.34 s


In [13]:
%%time

# Encode validation geo events with the vocabulary fitted on train_geo (transform only, no refit).
val_geo = geo_preprocessor.transform(val_geo)

CPU times: total: 1min 28s
Wall time: 8.51 s


In [14]:
%%time

# Encode validation dialogs with the *dial* preprocessor fitted on train_dial. Its vocabulary
# (13 months / 365 days) and its "dial" column prefix differ from geo's; using geo_preprocessor here
# would mis-map categories and produce geo_* columns that collide with val_geo in the join.
val_dial = dial_preprocessor.transform(val_dial)

CPU times: total: 1.77 s
Wall time: 1.36 s


In [15]:
# Attach every modality and the label list to the per-client trx rows. Left joins keep clients that
# have no geo/dial events (those columns become null).
train_joined = (
    train_trx
    .join(train_geo, on="client_id", how="left")
    .join(train_dial, on="client_id", how="left")
    .join(train_target.select(("client_id", "target")), on="client_id", how="left")
)
# A left join only adds rows when a client_id repeats on the right-hand side (e.g. several target
# months for one client). That would silently duplicate training samples, so fail loudly instead.
assert len(train_joined) == len(train_trx), "left join added rows: some client_id is duplicated on the right-hand side"

In [16]:
# Bounded preview of the joined frame (per-source list columns plus the target list).
train_joined.head()

client_id,trx_event_time,trx_month_ind,trx_day_ind,trx_src_type21,trx_src_type22,trx_src_type31,trx_src_type32,geo_event_time,geo_month_ind,geo_day_ind,dial_event_time,dial_month_ind,dial_day_ind,target
str,list[i64],list[i32],list[i32],list[i32],list[i32],list[i32],list[i32],list[i64],list[i32],list[i32],list[i64],list[i32],list[i32],list[i32]
"""947d6a26b3009a…","[31525212, 31616831, … 60864316]","[13, 3, … 12]","[355, 343, … 206]","[430, 430, … 430]","[7, 7, … 7]","[95, 95, … 95]","[45, 45, … 45]","[33201116, 33745421, … 48718034]","[11, 11, … 6]","[256, 292, … 99]","[43842022, 46344739, … 57738890]","[1, 6, … 7]","[255, 211, … 143]","[0, 0, … 0]"
"""c92c144da66bf2…","[47471022, 47503182, … 62052592]","[5, 5, … 12]","[74, 74, … 348]","[126, 126, … 126]","[7, 7, … 7]","[265, 265, … 265]","[59, 59, … 59]",,,,[54630589],[9],[310],"[0, 1, … 0]"
"""3632d37cfd1e14…","[31525215, 31572144, … 61359432]","[13, 3, … 12]","[355, 343, … 190]","[407, 407, … 407]","[25, 25, … 25]","[650, 650, … 650]","[2, 2, … 2]","[31583389, 31930640, … 61251365]","[11, 11, … 12]","[159, 143, … 353]","[51713719, 62514384]","[5, 11]","[152, 40]","[0, 0, … 0]"
"""9983a7acd58065…","[31525217, 31604544, … 61098765]","[13, 3, … 12]","[355, 343, … 205]","[2213, 2213, … 2213]","[25, 25, … 25]","[791, 791, … 791]","[15, 15, … 15]",[47575137],[6],[295],,,,"[0, 0, … 0]"
"""4d5e2ffa5bc0e8…","[31525218, 31813485, … 60730330]","[13, 3, … 12]","[355, 312, … 230]","[286, 286, … 286]","[22, 22, … 22]","[7, 7, … 7]","[7, 7, … 7]",,,,"[43398757, 43494215, 46090119]","[1, 1, 6]","[162, 44, 84]","[0, 0, … 0]"
"""b1594d7f4a4c73…","[34582808, 34584617, … 61955899]","[1, 1, … 12]","[325, 325, … 120]","[577, 577, … 577]","[37, 37, … 37]","[454, 454, … 454]","[2, 2, … 2]",,,,,,,"[0, 0, … 0]"
"""68994d58660a3b…","[31525223, 31576174, … 61863422]","[13, 3, … 12]","[355, 343, … 189]","[3632, 3632, … 3632]","[40, 40, … 40]","[262, 262, … 262]","[50, 50, … 50]",,,,"[36235994, 60075919]","[8, 3]","[277, 19]","[0, 0, … 0]"
"""972569de488cb5…","[31525228, 31583123, … 61837852]","[13, 3, … 12]","[355, 343, … 354]","[967, 967, … 967]","[27, 27, … 27]","[162, 162, … 162]","[12, 12, … 12]","[31865020, 31869792, … 61742402]","[11, 11, … 12]","[145, 145, … 352]",[32536832],[12],[218],"[0, 0, … 0]"
"""30699c8dbd18e1…","[48527275, 48649066, … 60474274]","[5, 5, … 12]","[212, 314, … 134]","[491, 491, … 491]","[17, 17, … 17]","[141, 141, … 141]","[2, 2, … 2]",,,,,,,"[0, 0, … 0]"
"""03860004127bbe…","[31525231, 31540852, … 60171464]","[13, 3, … 9]","[355, 343, … 204]","[299, 299, … 299]","[24, 24, … 24]","[929, 929, … 929]","[4, 4, … 4]","[37461416, 37478499, … 61944054]","[3, 3, … 12]","[286, 286, … 346]",,,,"[0, 0, … 0]"


In [17]:
# Same assembly as for train: per-client trx rows, left-joined with geo, dial and the label list.
val_joined = (
    val_trx
    .join(val_geo, on="client_id", how="left")
    .join(val_dial, on="client_id", how="left")
    .join(val_target.select(("client_id", "target")), on="client_id", how="left")
)
# Guard against join fan-out from duplicated client_ids on the right-hand side.
assert len(val_joined) == len(val_trx), "left join added rows: some client_id is duplicated on the right-hand side"

In [18]:
from tqdm.auto import tqdm

In [19]:
import torch

In [20]:
# One sample dict per row of train_joined. List feature columns become tensors (null -> empty list,
# i.e. clients with no events of that source), "target" stays a Python list, client_id is skipped.
train_dict = [{} for _ in range(len(train_joined))]

for col, dtype in zip(train_joined.columns, train_joined.dtypes):
    if col == "client_id":
        continue
    if col == "target":
        for row, value in enumerate(tqdm(train_joined[col], desc=col)):
            train_dict[row][col] = value.to_list()
    elif dtype == pl.List:
        for row, value in enumerate(tqdm(train_joined[col].fill_null([]), desc=col)):
            train_dict[row][col] = torch.tensor(value.to_numpy())
    else:
        # `assert False` is stripped under `python -O`; raise so an unexpected column is never silently dropped.
        raise TypeError(f"unexpected non-list column {col!r} with dtype {dtype}")

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

  0%|          | 0/70917 [00:00<?, ?it/s]

In [21]:
# One sample dict per row of val_joined, built exactly like train_dict: list columns -> tensors
# (null -> empty), "target" -> Python list, client_id skipped.
val_dict = [{} for _ in range(len(val_joined))]

for col, dtype in zip(val_joined.columns, val_joined.dtypes):
    if col == "client_id":
        continue
    if col == "target":
        for row, value in enumerate(tqdm(val_joined[col], desc=col)):
            val_dict[row][col] = value.to_list()
    elif dtype == pl.List:
        for row, value in enumerate(tqdm(val_joined[col].fill_null([]), desc=col)):
            val_dict[row][col] = torch.tensor(value.to_numpy())
    else:
        # `assert False` is stripped under `python -O`; raise so an unexpected column is never silently dropped.
        raise TypeError(f"unexpected non-list column {col!r} with dtype {dtype}")

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

  0%|          | 0/34511 [00:00<?, ?it/s]

In [22]:
# Feature names (without the source prefix) that the dataset reads for each modality.
_shared_time_features = ["event_time", "month_ind", "day_ind"]
source_features = {
    "trx": _shared_time_features + [f"src_type{n}" for n in (21, 22, 31, 32)],
    "geo": list(_shared_time_features),
    "dial": list(_shared_time_features),
}

In [23]:
from ptls.frames.coles.multimodal_supervised_dataset import MultiModalSupervisedIterableDataset

In [24]:
# Supervised multimodal dataset over the per-client sample dicts; "target" is the label key.
train_multimodal_data = MultiModalSupervisedIterableDataset(
    data=train_dict,
    source_features=source_features,
    source_names=source_features.keys(),
    target_name="target",
    col_id="client_id",
    col_time="event_time",
)

In [25]:
# Peek at the first sample: a dict keyed by source name (the printed tensors below are truncated).
train_multimodal_data[0]

{'trx': [{'event_time': tensor([31525212, 31616831, 31959543, 32167040, 32442901, 32572970, 32862376,
           32897657, 32898247, 32961263, 32998517, 33117113, 33163777, 33288092,
           33372746, 33378610, 33418187, 33541386, 33633857, 33691210, 33895955,
           33994765, 34204995, 34256049, 34351251, 34383191, 34491077, 34635565,
           34671312, 34804328, 34854685, 34889357, 34959836, 34982551, 35109893,
           35122146, 35151133, 35237479, 35409282, 35454314, 35617820, 35667957,
           35704539, 35836009, 35884795, 35964083, 35973571, 36027979, 36065229,
           36090267, 36141634, 36413699, 36502577, 36671946, 37276397, 37317446,
           37495343, 37611697, 37612973, 37749373, 38001833, 38014404, 38040378,
           38045092, 38050794, 38134785, 38139463, 38182918, 38338898, 38339712,
           38345338, 38483974, 38498753, 38502932, 38523914, 38582127, 38654667,
           38697099, 38737733, 38752137, 38876770, 38912306, 38945600, 38997727,
       

In [26]:
# Validation counterpart of train_multimodal_data, with identical source/feature configuration.
val_multimodal_data = MultiModalSupervisedIterableDataset(
    data=val_dict,
    source_features=source_features,
    source_names=source_features.keys(),
    target_name="target",
    col_id="client_id",
    col_time="event_time",
)

In [27]:
from ptls.frames import PtlsDataModule

In [28]:
# Data module holding both splits (despite the name, it also serves the validation loader).
# num_workers=0 keeps loading in the main process.
train_loader = PtlsDataModule(
    train_data=train_multimodal_data,
    valid_data=val_multimodal_data,
    train_batch_size=64,
    valid_batch_size=64,
    train_num_workers=0,
    valid_num_workers=0,
)

In [29]:
# Smoke-test collation: pull the first training batch and keep it in `batch` for inspection.
for batch in train_loader.train_dataloader():
    break

In [30]:
# Vocabulary sizes of the fitted trx categorical features; the embedding "in" sizes must agree with these.
trx_preprocessor.get_category_dictionary_sizes()

{'month_ind': 14,
 'day_ind': 356,
 'src_type21': 8220,
 'src_type22': 87,
 'src_type31': 1617,
 'src_type32': 88}

In [31]:
# Embedding output size per trx categorical feature. The vocabulary ("in") sizes are read from the fitted
# preprocessor instead of being hard-coded, so they cannot drift from the actual category encoding when
# the data chunk or the preprocessing changes.
trx_embedding_dims = {
    "month_ind": 7,
    "day_ind": 64,
    "src_type21": 32,
    "src_type22": 32,
    "src_type31": 32,
    "src_type32": 32,
}
trx_vocab_sizes = trx_preprocessor.get_category_dictionary_sizes()
trx_encoder_params = dict(
    embeddings_noise=0.003,
    linear_projection_size=128,
    embeddings={
        name: {"in": trx_vocab_sizes[name], "out": dim}
        for name, dim in trx_embedding_dims.items()
    },
)

In [32]:
# Vocabulary sizes of the fitted geo categorical features; the embedding "in" sizes must agree with these.
geo_preprocessor.get_category_dictionary_sizes()

{'month_ind': 14, 'day_ind': 356}

In [33]:
# Embedding output size per geo feature; vocabulary sizes come from the fitted geo preprocessor.
geo_embedding_dims = {"month_ind": 7, "day_ind": 64}
geo_vocab_sizes = geo_preprocessor.get_category_dictionary_sizes()
geo_encoder_params = dict(
    embeddings_noise=0.003,
    linear_projection_size=128,
    embeddings={
        name: {"in": geo_vocab_sizes[name], "out": dim}
        for name, dim in geo_embedding_dims.items()
    },
)

In [34]:
# Vocabulary sizes of the fitted dial categorical features; they differ from geo's despite the same feature names.
dial_preprocessor.get_category_dictionary_sizes()

{'month_ind': 13, 'day_ind': 365}

In [35]:
# Embedding output size per dial feature; vocabulary sizes come from the fitted dial preprocessor.
dial_embedding_dims = {"month_ind": 7, "day_ind": 64}
dial_vocab_sizes = dial_preprocessor.get_category_dictionary_sizes()
dial_encoder_params = dict(
    embeddings_noise=0.003,
    linear_projection_size=128,
    embeddings={
        name: {"in": dial_vocab_sizes[name], "out": dim}
        for name, dim in dial_embedding_dims.items()
    },
)

In [36]:
from ptls.nn import TrxEncoder

In [37]:
# One TrxEncoder per modality, built in trx -> geo -> dial order; each projects to 128 dims.
trx_encoder, geo_encoder, dial_encoder = (
    TrxEncoder(**params)
    for params in (trx_encoder_params, geo_encoder_params, dial_encoder_params)
)

In [73]:
import torch
from ptls.data_load.padded_batch import PaddedBatch

class MultiModalSortTimeSeqEncoderContainer(torch.nn.Module):
    """Multimodal sequence encoder: per-source encoders -> time-sorted merge -> one sequence encoder.

    Every source in ``trx_encoders`` is encoded separately. The per-event embeddings of all sources that
    have events in the batch are concatenated along the time axis and reordered. After reordering, each row
    holds its real events in ascending event time, followed by padding. The result is passed to
    ``seq_encoder_cls`` as a ``PaddedBatch`` whose lengths are the summed per-source lengths.

    Args:
        trx_encoders: mapping ``source name -> encoder``; every key must also be a key of the input batch dict.
            All encoders must emit the same embedding size, because their outputs are concatenated.
        seq_encoder_cls: sequence encoder class, built as
            ``seq_encoder_cls(input_size=..., is_reduce_sequence=..., **seq_encoder_params)``.
        input_size: embedding size produced by the source encoders.
        is_reduce_sequence: forwarded to the sequence encoder.
        col_time: name of the event-time field in each source payload.
        **seq_encoder_params: extra keyword arguments for ``seq_encoder_cls``.
    """

    def __init__(self,
                 trx_encoders,
                 seq_encoder_cls,
                 input_size,
                 is_reduce_sequence=True,
                 col_time='event_time',
                 **seq_encoder_params
                 ):
        super().__init__()

        self.trx_encoders = torch.nn.ModuleDict(trx_encoders)
        self.seq_encoder = seq_encoder_cls(
            input_size=input_size,
            is_reduce_sequence=is_reduce_sequence,
            **seq_encoder_params,
        )

        self.col_time = col_time

    @property
    def is_reduce_sequence(self):
        """Proxy for the wrapped sequence encoder's ``is_reduce_sequence`` flag."""
        return self.seq_encoder.is_reduce_sequence

    @is_reduce_sequence.setter
    def is_reduce_sequence(self, value):
        self.seq_encoder.is_reduce_sequence = value

    @property
    def embedding_size(self):
        """Proxy for the wrapped sequence encoder's ``embedding_size``."""
        return self.seq_encoder.embedding_size

    def get_tensor_by_indices(self, tensor, indices):
        """Return ``out`` with ``out[b, t, :] == tensor[b, indices[b, t], :]``.

        Args:
            tensor: (batch, length, dim) tensor.
            indices: (batch, new_length) integer positions along dim 1.
        """
        # torch.gather performs the per-row lookup directly. Advanced indexing ``tensor[:, indices, :]``
        # would first materialise a (batch, batch, new_length, dim) tensor and then keep only its diagonal.
        index = indices.long().unsqueeze(-1).expand(-1, -1, tensor.shape[-1])
        return torch.gather(tensor, 1, index)

    def merge_by_time(self, x):
        """Concatenate the source embeddings and order every row by event time.

        Args:
            x: dict ``source -> (event_times, encoded)`` as built by ``multimodal_trx_encoder``.
                ``event_times`` is a (batch, length) tensor. ``encoded`` exposes ``payload`` (batch, length, dim)
                and ``seq_lens`` (batch,). Sources without events in the batch carry the string ``'None'``
                instead and are skipped.

        Returns:
            (batch, total_length, dim) tensor. Each row holds the real events of all sources in ascending time
            order, followed by the padding positions.

        Raises:
            ValueError: if no source has any event in the batch.
        """
        embeddings, times = [], []
        for source_time, source_encoded in x.values():
            if isinstance(source_encoded, str):  # 'None' sentinel from trx_encoder_wrapper
                continue
            payload = source_encoded.payload
            # Identify padding by position (>= seq_len) rather than by a zero timestamp. Real events at t=0 are
            # then kept, and the padding of every source is pushed to the end of the merged row.
            positions = torch.arange(payload.shape[1], device=payload.device)
            is_padding = positions.unsqueeze(0) >= source_encoded.seq_lens.to(payload.device).unsqueeze(1)
            # float64 represents integer timestamps exactly (float32 rounds values above 2**24 ~ 1.7e7 seconds).
            source_time = source_time.to(device=payload.device, dtype=torch.float64)
            times.append(source_time.masked_fill(is_padding, float('inf')))
            embeddings.append(payload)

        if not embeddings:
            raise ValueError('merge_by_time: no source has events in this batch')

        batch = torch.cat(embeddings, dim=1)
        batch_time = torch.cat(times, dim=1)
        # A stable sort keeps simultaneous events in source order; padding (+inf) ends up last.
        order = torch.sort(batch_time, dim=1, stable=True).indices
        return self.get_tensor_by_indices(batch, order)

    def trx_encoder_wrapper(self, x_source, trx_encoder, col_time):
        """Encode one source and return ``(seq_lens, event_times, encoded)``.

        If no row of the batch has an event of this source, the encoder is not called. In that case the string
        ``'None'`` is returned in place of both ``event_times`` and ``encoded``.
        """
        if not bool(x_source.seq_lens.any()):
            return x_source.seq_lens, 'None', 'None'
        return x_source.seq_lens, x_source.payload[col_time], trx_encoder(x_source)

    def multimodal_trx_encoder(self, x):
        """Run every source encoder.

        Returns:
            ``({source: (event_times, encoded)}, total_lengths)``, where ``total_lengths`` is the per-row sum
            of the sources' ``seq_lens``.
        """
        res = {}
        first_source = list(x.values())[0]

        batch_size = first_source.payload[self.col_time].shape[0]
        length = torch.zeros(batch_size, device=first_source.device).int()

        for source, trx_encoder in self.trx_encoders.items():
            source_length, source_time, source_encoded = self.trx_encoder_wrapper(
                x[source], trx_encoder, self.col_time
            )
            res[source] = (source_time, source_encoded)
            length = length + source_length
        return res, length

    def forward(self, x):
        """Encode a dict ``source -> padded batch`` into the sequence encoder's output."""
        encoded, length = self.multimodal_trx_encoder(x)
        merged = self.merge_by_time(encoded)
        return self.seq_encoder(PaddedBatch(payload=merged, length=length))

In [74]:
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder

In [75]:
# GRU over the time-merged sequence of all three modalities. input_size must equal the encoders'
# linear_projection_size (128); hidden_size sets the width of the classifier input below.
seq_encoder = MultiModalSortTimeSeqEncoderContainer(
    trx_encoders={"trx": trx_encoder, "geo": geo_encoder, "dial": dial_encoder},
    seq_encoder_cls=RnnEncoder,
    input_size=128,
    hidden_size=256,
    type="gru",
)

In [76]:
from torch import nn
# Two-layer MLP head: 256-dim sequence embedding -> 4 logits (one per target).
classifier = nn.Sequential(*[
    nn.Linear(in_features=256, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=4),
])

In [77]:
from functools import partial
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.frames import PtlsDataModule
import torch.nn as nn
import torchmetrics

In [78]:
class AUROC(nn.Module):
    """``nn.Module`` wrapper around ``torchmetrics.AUROC`` set up for 4-label multilabel targets.

    ``forward`` casts the target to int before delegating; ``compute`` and ``reset`` delegate unchanged.
    """

    def __init__(self):
        super().__init__()
        self.metric = torchmetrics.AUROC(task='multilabel', num_labels=4)

    def forward(self, preds, target):
        int_target = target.int()
        return self.metric(preds, int_target)

    def compute(self):
        wrapped = self.metric
        return wrapped.compute()

    def reset(self):
        wrapped = self.metric
        return wrapped.reset()

In [79]:
# Multilabel training on raw logits (BCEWithLogitsLoss), AUROC as the metric, AdamW at a fixed
# learning rate (ConstantLR with factor=1.0 leaves the LR unchanged).
sup_module = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=classifier,
    loss=nn.BCEWithLogitsLoss(),
    metric_list=AUROC(),
    optimizer_partial=partial(torch.optim.AdamW, lr=0.0001),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ConstantLR, factor=1.),
)

In [80]:
# Print the module tree to check that every source encoder and embedding table is wired as expected.
sup_module

SequenceToTarget(
  (seq_encoder): MultiModalSortTimeSeqEncoderContainer(
    (trx_encoders): ModuleDict(
      (trx): TrxEncoder(
        (embeddings): ModuleDict(
          (month_ind): NoisyEmbedding(
            14, 7, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (day_ind): NoisyEmbedding(
            356, 64, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (src_type21): NoisyEmbedding(
            8220, 32, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (src_type22): NoisyEmbedding(
            87, 32, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (src_type31): NoisyEmbedding(
            1617, 32, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (src_type32): NoisyEmbedding(
            88, 32, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
        )
   

In [81]:
import pytorch_lightning

In [82]:
# Single-GPU trainer for 8 epochs, with no logger and with the progress bar enabled.
pl_trainer = pytorch_lightning.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=8,
    logger=False,
    enable_progress_bar=True,
)

In [83]:
# Train; train_loader is a PtlsDataModule that carries both train_data and valid_data.
pl_trainer.fit(sup_module, train_loader)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]