In [171]:
from typing import Optional
import pandas as pd
import numpy as np
import gc
import time
import json
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
import os, glob
import joblib
import random
import math
from tqdm import tqdm 
from collections import OrderedDict

from scipy.interpolate import interp1d
from scipy import signal
from scipy.signal import argrelmax

from sklearn.metrics import mean_squared_error

from math import pi, sqrt, exp
import sklearn,sklearn.model_selection
import torch
from torch import nn,Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torch.optim import AdamW
# import pytorch_lightning as pl
# from pytorch_lightning import seed_everything
import lightning as L
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    RichModelSummary,
    RichProgressBar,
)

from sklearn.metrics import average_precision_score
import timm
from timm.scheduler import CosineLRScheduler

from transformers import get_cosine_schedule_with_warmup
from torchvision.transforms.functional import resize
plt.style.use("ggplot")

from pyarrow.parquet import ParquetFile
import pyarrow as pa 
import ctypes
import polars as pl

from pathlib import Path

In [2]:
import sys
sys.path.append("/project")
from nn_utils import loss

In [129]:
class CFG:
    # Fundamental config
    NOTDEBUG = True # False -> DEBUG, True -> normally train
    WORKERS = os.cpu_count() // 2
    N_FOLDS = 5
    TRAIN_FOLD = 0
    #一時間720step
    MAX_LEN = 12*720
    USE_AMP = False
    SEED = 86

    # Model config
    HIDDEN = 256 if NOTDEBUG else 16
    EMB_DIM = 16
    KS = 31 if NOTDEBUG else 7
    N_BLKS = 5 if NOTDEBUG else 2
    DROPOUT = 0.2

    # Optimizer config
    LR = 5e-4
    WD = 1e-2
    WARMUP_PROP = 0.1
    # LR_INIT = 1e-4
    # LR_MIN = 1e-5
    
    # Train config
    EPOCHS = 10
    batch_size = 32
    MAX_GRAD_NORM = 2.
    GRAD_ACC = 32 // batch_size
    check_val_every_n_epoch=1
    monitor='val_loss'
    monitor_mode='min'
    num_warmup_steps=0

L.seed_everything(CFG.SEED)

INFO: Seed set to 86


86

In [8]:
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold

class MakeKFold():
    def __init__(self, n_split=CFG.N_FOLDS, random_state=CFG.SEED, shuffle=True):
        self.n_split = n_split
        self.random_state = random_state
        self.shuffle = shuffle
        self.make_fold()

    def make_fold(self):
        print("making fold")
        skf = StratifiedKFold(n_splits=self.n_split, random_state=self.random_state, shuffle=True)
        metadata = pd.read_csv('/project/input/child-mind-institute-detect-sleep-states/train_events.csv')
        unique_ids = metadata['series_id'].unique()
        meta_cts = pd.DataFrame(unique_ids, columns=['series_id'])
        
        fold_train_ids = []
        fold_valid_ids = []
        
        for i, (train_index, valid_index) in enumerate(skf.split(X=meta_cts['series_id'], y=[1]*len(meta_cts))):
            # if i != TRAIN_FOLD:
            #     continue
            print(f"Fold = {i}")
            train_ids = meta_cts.loc[train_index, 'series_id']
            valid_ids = meta_cts.loc[valid_index, 'series_id']
            print(f"Length of Train = {len(train_ids)}, Length of Valid = {len(valid_ids)}")
            
            fold_train_ids.append(train_ids)
            fold_valid_ids.append(valid_ids)
        self.fold_train_ids = fold_train_ids
        self.fold_valid_ids = fold_valid_ids

    def train_fpaths(self, fold):
        train_fpaths = [f"/project/input/detect-sleep-states-dataprepare/train_csvs/{_id}.csv" for _id in self.fold_train_ids[fold]]
        return train_fpaths

    def valid_fpaths(self, fold):
        valid_fpaths = [f"/project/input/detect-sleep-states-dataprepare/train_csvs/{_id}.csv" for _id in self.fold_valid_ids[fold]]
        return valid_fpaths


makefold = MakeKFold()
makefold.train_fpaths(0)[:5]

making fold
Fold = 0
Length of Train = 221, Length of Valid = 56
Fold = 1
Length of Train = 221, Length of Valid = 56
Fold = 2
Length of Train = 222, Length of Valid = 55
Fold = 3
Length of Train = 222, Length of Valid = 55
Fold = 4
Length of Train = 222, Length of Valid = 55


['/kaggle/input/detect-sleep-states-dataprepare/train_csvs/038441c925bb.csv',
 '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/03d92c9f6f8a.csv',
 '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/0402a003dae9.csv',
 '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/04f547b8017d.csv',
 '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/062dbd4c95e6.csv']

In [9]:
# datasetを考える
#15時区切りと0時区切りを作る。
train_df = pl.read_csv('/project/input/child-mind-institute-detect-sleep-states/train_events.csv')
#train_df
train_df = train_df.with_columns(pl.col('timestamp').str.to_datetime('%Y-%m-%dT%H:%M:%S%z')
                                 .dt.hour()
                                 .alias("datetime_hour"))

In [10]:
fold=0
print(makefold.train_fpaths(fold)[:5])
train_fpathes = makefold.train_fpaths(fold)[:5]

['/kaggle/input/detect-sleep-states-dataprepare/train_csvs/038441c925bb.csv', '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/03d92c9f6f8a.csv', '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/0402a003dae9.csv', '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/04f547b8017d.csv', '/kaggle/input/detect-sleep-states-dataprepare/train_csvs/062dbd4c95e6.csv']


In [213]:
df.columns

['step', 'anglez', 'enmo', 'hour', 'onset', 'wakeup']

In [217]:
df[:3]

step,anglez,enmo,hour,onset,wakeup
i64,f64,f64,i64,f64,f64
0,2.6367,0.0217,15,0.0,0.0
1,2.6368,0.0215,15,0.0,0.0
2,2.637,0.0216,15,0.0,0.0


In [209]:
df = pl.read_csv(train_fpathes[0])
df[:5]

step,anglez,enmo,hour,onset,wakeup
i64,f64,f64,i64,f64,f64
0,2.6367,0.0217,15,0.0,0.0
1,2.6368,0.0215,15,0.0,0.0
2,2.637,0.0216,15,0.0,0.0
3,2.6368,0.0213,15,0.0,0.0
4,2.6368,0.0215,15,0.0,0.0


In [233]:
d = {1:2, 3:4, 5:6}
list(d.values())

[2, 4, 6]

In [239]:
list(d.keys())

[1, 3, 5]

In [None]:
SIGMA = 720 #average length of day is 24*60*12 = 17280 for comparison
SAMPLE_FREQ = 12 # 1 obs per minute
class SleepDataset(Dataset):
    def __init__(
        self,
        cfg,
        file_pathes,
        max_len=2**12,
        is_train=False,
        sample_per_epoch=10000
    ):
        self.cfg = cfg
        self.enmo_mean = np.load('/project/input/detect-sleep-states-dataprepare/enmo_mean.npy').item()
        self.enmo_std = np.load('/project/input/detect-sleep-states-dataprepare/enmo_std.npy').item()
        
        self.max_len = max_len
        
        self.is_train = is_train
        
        self.max_df_size = 0
        self.min_df_size = 1e9
        
        self.sample_per_epoch = sample_per_epoch
        
        self.feat_list = ['anglez','enmo']
        
        self.Xys = self.read_csvs(file_pathes)
        self.file_pathes = file_pathes
        
        self.label_list = ['onset', 'wakeup']
        
        self.hour_feat= ['hour']
        
    def read_csvs(self, folder):
        res = []
        features = {}
        
        if type(folder) is str:
            files = glob.glob(f'{folder}/*.csv')
        else:
            files = folder
        for i, f in tqdm(enumerate(files), total=len(files), leave=False):
            series_id = os.path.splitext(os.path.basename(self.file_pathes[f]))[0]
            df = pl.read_csv(f)
            df = self.norm_feat_eng(df, init=True if i==0 else False)

            features[f"{series_id}"] = df
            self.max_df_size = max(self.max_df_size, len(df))
            self.min_df_size = min(self.min_df_size, len(df))
            #chunk
            if not is_train:
                num_chunks = (len(df) // cfg.duration) + 1
                for i in range(num_chunks):
                    chunk_feature = df[i * cfg.duration : (i + 1) * cfg.duration]
                    chunk_feature = pad_if_needed(chunk_feature, duration, pad_value=0)
                    features[f"{series_id}_{i:07}"] = chunk_feature
        if is_train:
            return res
        else:
            return features

    def pad_if_needed(x: pl.DataFrame, max_len: int, pad_value: float = 0.0) -> np.ndarray:
        if len(x) == max_len:
            return x
        columns = x.columns
        num_pad = max_len - len(x)
        n_dim = len(x.shape)
        pad_widths = [(0, num_pad)] + [(0, 0) for _ in range(n_dim - 1)]
        print(pad_widths)
        padded = np.pad(x, pad_width=pad_widths, mode="constant", constant_values=pad_value)
        output = pl.DataFrame(padded)
        output.columns = columns
        return output

    def norm_feat_eng(self, X, init=False):
        X = X.with_columns(pl.col('anglez') / 90)
        X = X.with_columns((pl.col('enmo') - self.enmo_mean) / (self.enmo_std + 1e-12))

        X = X.fill_nan(0)
        
        return X.cast(pl.Float32).to_pandas()

    def gauss(self,n=SIGMA,sigma=SIGMA*0.15):
        # guassian distribution function
        r = range(-int(n/2),int(n/2)+1)
        return [1 / (sigma * sqrt(2*pi)) * exp(-float(x)**2/(2*sigma**2)) for x in r]
    
    def __len__(self):
        if is_train:
            return self.sample_per_epoch if self.is_train else len(self.Xys)
        else:
            return len(self.features)

    def __getitem__(self, index):
        #filenameの数==datasetの数
        if self.is_train:
            #augmentation ランダムに時間軸を変える
            ind = np.random.randint(0, len(self.Xys))
            Xy = self.Xys[ind]
            Xy[self.hour_feat] = Xy[self.hour_feat]/24
            X = Xy[self.feat_list+self.hour_feat].values.astype(np.float32)
            
            y = Xy[self.label_list].values.astype(np.float32)

            if len(Xy)+1<self.max_len:
                res = self.max_len - len(Xy) + 1
                X = np.pad(X, ((0, res), (0, 0)))
                y = np.pad(y, ((0, res), (0, 0)))

            start = np.random.randint(0, len(X)-self.max_len)

            X = X[start:start+self.max_len].T
            y = y[start:start+self.max_len].T

            series_id = os.path.splitext(os.path.basename(self.file_pathes[index]))[0]
            return {'series_id':series_id, 'feature':X.T,'label':y.T}

        else:
            key = list(self.Xys.keys())[index]
            Xy = list(self.Xys.values())[index]
            Xy[self.hour_feat] = Xy[self.hour_feat]/24
            X = Xy[self.feat_list+self.hour_feat].values.astype(np.float32)
            y = Xy[self.label_list].values.astype(np.float32)
            #t = Xy[self.hour_feat].values.astype(np.int32)
        #train shape (3, 8640) label shape (2, 8640)
        #channel 
        #(波形データ、時間データ、ラベル)

NOTDEBUG = False

fold = 0
train_fpaths = makefold.train_fpaths(fold)
valid_fpaths = makefold.valid_fpaths(fold)

train_fpaths = train_fpaths if NOTDEBUG else train_fpaths[:50]
valid_fpaths = valid_fpaths if NOTDEBUG else valid_fpaths[:10]
sample_per_epoch = 20_000 if NOTDEBUG else 1_000

train_ds = SleepDataset(train_fpaths, max_len=CFG.MAX_LEN, is_train=True, sample_per_epoch=sample_per_epoch)
val_ds = SleepDataset(valid_fpaths, is_train=False)

In [200]:
print(train_ds[0]['series_id'],'train shape', train_ds[0]['feature'].shape, 'label shape', train_ds[0]['label'].shape)

038441c925bb train shape (8640, 3) label shape (8640, 2)


In [139]:
# データモジュールを定義
from pytorch_lightning import LightningDataModule
class LitDataModule(LightningDataModule):
    def __init__(self,
                 cfg,
                 train_fpaths,
                 valid_fpaths,
                 sample_per_epoch):
        super(LitDataModule, self).__init__()
        self.cfg = cfg
        self.train_fpaths = train_fpaths
        self.valid_fpaths = valid_fpaths
        self.sample_per_epoch = sample_per_epoch
        self.batch_size = cfg.batch_size

    def train_dataloader(self):
        train_dataset = SleepDataset(self.train_fpaths, max_len=CFG.MAX_LEN, is_train=True, sample_per_epoch=self.sample_per_epoch)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=12)
        return train_loader
    
    def val_dataloader(self):
        valid_dataset = SleepDataset(self.valid_fpaths, max_len=CFG.MAX_LEN, is_train=True, sample_per_epoch=self.sample_per_epoch)
        valid_dataloader = DataLoader(valid_dataset, batch_size=self.batch_size, shuffle=False, num_workers=12)
        return valid_dataloader

In [190]:
class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, alpha=1., gamma=2.):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        #inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #first compute binary cross-entropy 
        BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = self.alpha * (1-BCE_EXP)**self.gamma * BCE
                       
        return focal_loss.mean()

def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or np.any([v in name.lower()  for v in skip_list]):
            # print(name, 'no decay')
            no_decay.append(param)
        else:
            # print(name, 'decay')
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
        self.batchnorm = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        return self.relu(x)

class ConvTransposeBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
        super(ConvTransposeBlock, self).__init__()
        self.convtranspose = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.batchnorm = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.convtranspose(x)
        x = self.batchnorm(x)
        return self.relu(x)

class OreOreUNet1D(nn.Module):
    def __init__(self):
        super(OreOreUNet1D, self).__init__()

        # Encoder
        self.enc_conv0 = ConvBlock(3, 64)
        self.pool0 = nn.MaxPool1d(2)  # Add pooling layers to downsample
        self.enc_conv1 = ConvBlock(64, 128)
        self.pool1 = nn.MaxPool1d(2)
        self.enc_conv2 = ConvBlock(128, 256)
        self.pool2 = nn.MaxPool1d(2)
        self.enc_conv3 = ConvBlock(256, 512)
        
        # Decoder
        self.dec_conv3 = ConvTransposeBlock(512, 256, stride=2, padding=1, output_padding=1)
        self.dec_conv2 = ConvTransposeBlock(256 * 2, 128, stride=2, padding=1, output_padding=1)
        self.dec_conv1 = ConvTransposeBlock(128 * 2, 64, stride=2, padding=1, output_padding=1)
        self.dec_conv0 = ConvBlock(64 * 2, 2, kernel_size=1, padding=0)  # Adjust the final layer
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x, labels):
        # Encoder pathway
        enc0 = self.enc_conv0(x)
        enc0p = self.pool0(enc0)
        enc1 = self.enc_conv1(enc0p)
        enc1p = self.pool1(enc1)
        enc2 = self.enc_conv2(enc1p)
        enc2p = self.pool2(enc2)
        enc3 = self.enc_conv3(enc2p)

        # Decoder pathway
        dec3 = self.dec_conv3(enc3)
        dec2_input = self.crop_and_concat(dec3, enc2)  # Crop enc2 to match dec3 if necessary
        dec2 = self.dec_conv2(dec2_input)
        dec1_input = self.crop_and_concat(dec2, enc1)  # Crop enc1 to match dec2 if necessary
        dec1 = self.dec_conv1(dec1_input)
        dec0_input = self.crop_and_concat(dec1, enc0)  # Crop enc0 to match dec1 if necessary
        logits = self.dec_conv0(dec0_input)

        #output
        output = {"logits": logits}
        if labels is not None:
            loss = self.loss_fn(logits, labels)
            output["loss"] = loss

        return output

    def crop_and_concat(self, upsampled, bypass):
        """ Crop the bypass to the same size as upsampled and concatenate """
        diff = bypass.size()[2] - upsampled.size()[2]
        bypass = F.pad(bypass, (-diff // 2, -diff - (-diff // 2)))
        return torch.cat((upsampled, bypass), 1)      

from pytorch_lightning import LightningModule
# PyTorch Lightningモジュールを定義
class LitModule(LightningModule):
    def __init__(self, cfg):
        super(LitModule, self).__init__()
        #cfg.table_feature_num = len(train_ds.feat_list)
        self.model = OreOreUNet1D()
        self.criterion = FocalLoss(alpha=1., gamma=2.)
        self.cfg = cfg
        self.duration = cfg.MAX_LEN
        self.validation_step_outputs: list = []
        self.__best_loss = np.inf

    def forward(
        self, x: torch.Tensor, labels: Optional[torch.Tensor] = None
    ) -> dict[str, Optional[torch.Tensor]]:
        return self.model(x, labels)

    def training_step(self, batch, batch_idx):
        return self.__share_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.__share_step(batch, "val")

    def __share_step(self, batch, mode: str) -> torch.Tensor:
        output = self.model(batch["feature"], batch["label"])
        loss: torch.Tensor = output["loss"]
        logits = output["logits"]  # (batch_size, n_timesteps, n_classes)

        if mode == "train":
            self.log(
                f"{mode}_loss",
                loss.detach().item(),
                on_step=False,
                on_epoch=True,
                logger=True,
                prog_bar=True,
            )
        elif mode == "val":
            resized_logits = resize(
                logits.sigmoid().detach().cpu(),
                size=[self.duration, logits.shape[2]],
                antialias=False,
            )
            resized_labels = resize(
                batch["label"].detach().cpu(),
                size=[self.duration, logits.shape[2]],
                antialias=False,
            )
            self.validation_step_outputs.append(
                (
                    batch["key"],
                    resized_labels.numpy(),
                    resized_logits.numpy(),
                    loss.detach().item(),
                )
            )
            self.log(
                f"{mode}_loss",
                loss.detach().item(),
                on_step=False,
                on_epoch=True,
                logger=True,
                prog_bar=True,
            )

        return loss

    def on_validation_epoch_end(self):
        keys = []
        for x in self.validation_step_outputs:
            keys.extend(x[0])
        labels = np.concatenate([x[1] for x in self.validation_step_outputs])
        preds = np.concatenate([x[2] for x in self.validation_step_outputs])
        losses = np.array([x[3] for x in self.validation_step_outputs])
        loss = losses.mean()

        val_pred_df = post_process_for_seg(
            keys=keys,
            preds=preds[:, :, [1, 2]],
            score_th=self.cfg.post_process.score_th,
            distance=self.cfg.post_process.distance,
        )
        score = event_detection_ap(self.val_event_df.to_pandas(), val_pred_df.to_pandas())
        self.log("val_score", score, on_step=False, on_epoch=True, logger=True, prog_bar=True)

        if loss < self.__best_loss:
            np.save("keys.npy", np.array(keys))
            np.save("labels.npy", labels)
            np.save("preds.npy", preds)
            val_pred_df.write_csv("val_pred_df.csv")
            torch.save(self.model.state_dict(), "best_model.pth")
            print(f"Saved best model {self.__best_loss} -> {loss}")
            self.__best_loss = loss

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        #lr
        optimizer_parameters = add_weight_decay(model, weight_decay=self.cfg.WD, skip_list=['bias'])
        optimizer = AdamW(optimizer_parameters, lr=self.cfg.LR, eps=1e-6, betas=(0.9, 0.999))
        
        #shcedular
        scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                    num_training_steps=self.trainer.max_steps,
                                                    num_warmup_steps=self.cfg.num_warmup_steps)

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
 

In [188]:
model = LitModule(CFG)

# Example input tensor of shape (batch_size, channels, length)
input_tensor = torch.randn(1, 3, 8640)
output_tensor = torch.randn(1, 2, 8640)

batch = {'feature':input_tensor, 'label':output_tensor}
# Forward pass
output_tensor = model(input_tensor, output_tensor)
print(output_tensor['loss'])
print(output_tensor['logits'])

tensor(0.9499, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor([[[0.0000, 1.1383, 0.0000,  ..., 0.0000, 0.0000, 1.1501],
         [0.0000, 0.0000, 0.0000,  ..., 0.0022, 0.0000, 0.0000]]],
       grad_fn=<ReluBackward0>)


In [191]:
NOTDEBUG = False

fold = 0
train_fpaths = makefold.train_fpaths(fold)
valid_fpaths = makefold.valid_fpaths(fold)

train_fpaths = train_fpaths if NOTDEBUG else train_fpaths[:50]
valid_fpaths = valid_fpaths if NOTDEBUG else valid_fpaths[:10]
sample_per_epoch = 20_000 if NOTDEBUG else 1_000
datamodule = LitDataModule(CFG, train_fpaths, valid_fpaths, sample_per_epoch)
model = LitModule(CFG)

checkpoint_cb = ModelCheckpoint(
    verbose=True,
    monitor=CFG.monitor,
    mode=CFG.monitor_mode,
    save_top_k=1,
    save_last=False,
)
lr_monitor = LearningRateMonitor("epoch")
progress_bar = RichProgressBar()
model_summary = RichModelSummary(max_depth=2)

trainer = Trainer(
    # env
    default_root_dir=Path.cwd(),
    # num_nodes=cfg.training.num_gpus,
    # accelerator=cfg.accelerator,
    precision=16 if CFG.USE_AMP else 32,
    # training
    # fast_dev_run=cfg.debug,  # run only 1 train batch and 1 val batch
    max_epochs=CFG.EPOCHS,
    max_steps=CFG.EPOCHS * len(datamodule.train_dataloader()),
    # gradient_clip_val=cfg.gradient_clip_val,
    # accumulate_grad_batches=cfg.accumulate_grad_batches,
    callbacks=[checkpoint_cb, lr_monitor, progress_bar, model_summary],
    # logger=pl_logger,
    # resume_from_checkpoint=resume_from,
    num_sanity_val_steps=0,
    log_every_n_steps=int(len(datamodule.train_dataloader()) * 0.1),
    sync_batchnorm=True,
    check_val_every_n_epoch=CFG.check_val_every_n_epoch,
)
trainer.fit(model, datamodule=datamodule)

                                   

                                   

Output()



KeyError: 'key'