In [1]:
# | default_exp benchmark.cnn_classification

In [2]:
# | export
import json
from pathlib import Path

import pytorch_lightning as pl
import torch
import torchvision.transforms as T
import wandb
from pytorch_lightning.loggers import WandbLogger
from tqdm.notebook import tqdm

# | export
from ts.benchmark.tsdataset import TimeSeriesBenchmarkDataset
from ts.classification.cnnclassifer import TimeSeriesDataModule, TSNDTensorClassifier
from ts.tsfeatures.ts2image import transform_ts2img_tensor

# Allow float32 matrix multiplications to use a faster, reduced-precision internal
# path (bfloat16-based when a suitable kernel exists, otherwise as for "high").
# This trades a little numerical accuracy for throughput on supported GPUs.
torch.set_float32_matmul_precision(precision="medium")
import gc

In [3]:
# | export
# Augmentations intended to be safe for time-series-derived images:
#   * a 3x3 Gaussian blur on every sample, with sigma drawn uniformly from [0.1, 2.0]
#     on each call;
#   * on roughly half of the samples (RandomApply p=0.5), erasure of one random
#     rectangle covering 2-5% of the image area.
# Disabled: T.Normalize(mean=[0.5], std=[0.5])
_train_augmentations = [
    T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    T.RandomApply(
        transforms=[T.RandomErasing(p=1.0, scale=(0.02, 0.05))],
        p=0.5,
    ),
]
train_transforms = T.Compose(transforms=_train_augmentations)

In [4]:
# | export
benchmark = TimeSeriesBenchmarkDataset()

# Subset of the benchmark's classification tasks to run. The full list is
# available as benchmark.task_datasets["classification"].
ds_list = [
    "EthanolConcentration",
    "FaceDetection",
    "Handwriting",
    # "JapaneseVowels",      # excluded: shape mismatch
    # "PEMS-SF",             # excluded: CUDA out of memory
    "SelfRegulationSCP1",
    "SelfRegulationSCP2",
    # "SpokenArabicDigits",  # excluded: shape mismatch
    "UWaveGestureLibrary",
]


def _ensure_image_tensors(dataset: str) -> None:
    """Convert ``dataset`` to image tensors under ``{dataset}_classification/`` if needed.

    The conversion is expensive, so it is skipped when the two files the training
    cell reads (``classes.json`` and ``0.pt``) already exist. This keeps
    Restart & Run All working on a fresh checkout without redoing finished work.
    NOTE(review): a conversion interrupted after both files were written would be
    treated as complete. Delete the directory to force a rebuild.
    """
    data_dir = Path(f"{dataset}_classification")
    if (data_dir / "classes.json").exists() and (data_dir / "0.pt").exists():
        print("Cached >>  ", dataset)
        return

    print("Processing >>  ", dataset)
    df = benchmark.load_dataset(dataset)
    transform_ts2img_tensor(
        df, data_dir=f"{dataset}_classification", categorical_label=True, label_col="label"
    )
    # Release the raw frame before converting the next dataset.
    del df
    gc.collect()


for dataset in ds_list:
    _ensure_image_tensors(dataset)

In [None]:
# | export
# Configuration shared by every per-dataset training run.
SEED = 42
MODEL_NAME = "efficientnet_b0"
WANDB_PROJECT = "benchmark-ts-classification-finetune"
CHECKPOINT_DIR = "model_checkpoints"


def _probe_input_shape(data_dir: Path):
    """Return ``(in_channels, input_size)`` for the first saved sample in ``data_dir``.

    Loads ``data_dir / "0.pt"`` and reads its ``"image"`` tensor. ``in_channels``
    is dim 0 and ``input_size`` is the last dim. The file is mapped onto the CPU,
    so probing the shape never allocates GPU memory.
    """
    image = torch.load(data_dir / "0.pt", map_location="cpu")["image"]
    return image.shape[0], image.shape[-1]


def _train_and_evaluate(dataset: str) -> None:
    """Train, checkpoint, validate and test a CNN classifier on one dataset.

    Reads ``{dataset}_classification/classes.json`` and ``0.pt``, logs the run to
    Weights & Biases, and writes the final training state to
    ``model_checkpoints/{dataset}.ckpt``. Validation and test use the checkpoint
    with the lowest ``val_loss``, not the last epoch's weights.

    GPU memory and the wandb run are released even if training fails (e.g. OOM),
    so the kernel stays usable for the next dataset.
    """
    pl.seed_everything(SEED, workers=True)

    data_dir = Path(f"{dataset}_classification")
    with open(data_dir / "classes.json", "r") as file:
        classes = json.load(file)
    in_channels, input_size = _probe_input_shape(data_dir)

    # Long sequences give large inputs: shrink the batch (and the worker count)
    # to keep memory in check.
    batch_size = 4 if input_size > 550 else 64
    ds = TimeSeriesDataModule(
        data_dir=str(data_dir),
        # NOTE(review): confirm the data module applies `transform` only to the
        # training split; augmenting validation/test data would bias the metrics.
        transform=train_transforms,
        batch_size=batch_size,
        num_workers=6 if batch_size < 8 else 16,
    )

    model = TSNDTensorClassifier(
        model_name=MODEL_NAME,
        num_classes=len(classes),
        in_channels=in_channels,
        reduced_channels=3,
        input_size=input_size,
        output_size=min(255, input_size),
    )

    wandb_logger = WandbLogger(
        project=WANDB_PROJECT,
        name=f"cnn.model={MODEL_NAME}.ds={dataset}",
    )
    trainer = None
    try:
        wandb_logger.experiment.config.update({"model": MODEL_NAME, "finetune": False})
        wandb_logger.watch(model, log="all")

        # EarlyStopping halts up to `patience` epochs after the best val_loss, so the
        # in-memory weights at the end of fit() are not the best ones. Track the best
        # checkpoint explicitly and evaluate that one.
        best_checkpoint = pl.callbacks.ModelCheckpoint(monitor="val_loss", mode="min")
        trainer = pl.Trainer(
            logger=wandb_logger,
            accelerator="auto",
            # First visible device. Unlike devices=[0], this is also valid when
            # accelerator="auto" falls back to the CPU.
            devices=1,
            min_epochs=1,
            max_epochs=50,
            enable_checkpointing=True,
            precision="bf16-mixed",
            callbacks=[
                pl.callbacks.EarlyStopping("val_loss", patience=5, verbose=False),
                best_checkpoint,
            ],
        )

        trainer.fit(model, ds)
        # Snapshot of the final (last-epoch) training state.
        trainer.save_checkpoint(f"{CHECKPOINT_DIR}/{dataset}.ckpt")
        trainer.validate(model, ds, ckpt_path="best")
        trainer.test(model, ds, ckpt_path="best")
    finally:
        # Remove the hooks installed by `watch` and close the run, so the next
        # dataset starts a fresh wandb run even after a failure.
        wandb_logger.experiment.unwatch(model)
        wandb.finish()
        # The trainer and logger also reference the model; drop every holder
        # (including from this frame, which a stored traceback would keep alive)
        # before collecting, so GPU memory is actually released.
        del model, ds, trainer, wandb_logger
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


for dataset in ds_list:
    print("Processing >>  ", dataset)
    _train_and_evaluate(dataset)

Processing >>   EthanolConcentration


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpranav_jha[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                        | Params | Mode 
-------------------------------------------------------------------------
0 | preprocessor     | ChannelReducerAndDownscaler | 1.8 K  | train
1 | pretrained_model | EfficientNet                | 4.3 M  | train
2 | criterion        | CrossEntropyLoss            | 0      | train
3 | accuracy         | MulticlassAccuracy          | 0      | train
4 | f1_score         | MulticlassF1Score           | 0      | train
5 | auc              | MulticlassAUROC             | 0      | train
-------------------------------------------------------------------------
331 K     Trainable params
4.0 M     Non-trainable params
4.3 M     Total params
17.355  

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]