In [1]:
# main.py

from config import *
from utils.seed import set_seed
from data.splits import create_splits
from data.transforms import get_transforms
from data.datamodule import ImageDataModule
from models.efficientnet import EfficientNetLit

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

import torch
from sklearn.utils.class_weight import compute_class_weight
import numpy as np


def main():
    """Run the EfficientNet training pipeline end to end.

    Steps, in order: seed the RNGs, build the train/val/test splits, print
    split sizes and the training label distribution, derive "balanced" class
    weights from the training labels, build the data module and the model,
    then fit with CSV logging, checkpointing on ``val_loss`` and early
    stopping.

    Raises:
        ValueError: If the number of distinct labels in the training split
            differs from ``len(classes)``, since the class-weight vector
            would then not line up with the model's output classes.
    """
    set_seed(SEED)

    # NOTE(review): ``images_dir`` receives the same constant as
    # ``metadata_csv``; this looks like a copy/paste slip -- confirm whether
    # a dedicated images-directory constant from ``config`` was intended.
    train_df, val_df, test_df, classes = create_splits(
        metadata_csv=METADATA_DIR,
        images_dir=METADATA_DIR,
        seed=SEED,
        test_size=TEST_SIZE,
        val_size=VAL_SIZE,
    )

    # Row counts per split followed by unique lesion counts per split.
    print(
        len(train_df),
        len(val_df),
        len(test_df),
        train_df.lesion_id.nunique(),
        val_df.lesion_id.nunique(),
        test_df.lesion_id.nunique(),
    )

    # Compute the per-class counts once and reuse them for the report.
    counts = train_df["label"].value_counts().sort_index()
    print("TRAIN class distribution:")
    print(counts)

    labels = train_df["label"].values
    present_labels = np.unique(labels)

    # The weight vector is indexed by position in ``present_labels``. If a
    # class is absent from the training split, every later weight would be
    # applied to the wrong class index, so fail fast with a clear message.
    if len(present_labels) != len(classes):
        raise ValueError(
            f"Training split contains {len(present_labels)} distinct labels "
            f"but {len(classes)} classes were declared; class weights would "
            "not align with the model outputs."
        )

    class_weights = compute_class_weight(
        class_weight="balanced",
        classes=present_labels,
        y=labels,
    )
    class_weights = torch.tensor(class_weights, dtype=torch.float)

    # The third value is forwarded to the model as ``weights``
    # (presumably the pretrained-weights handle -- TODO confirm).
    train_tfms, val_tfms, weights = get_transforms()

    datamodule = ImageDataModule(
        train_df,
        val_df,
        test_df,
        train_tfms,
        val_tfms,
        BATCH_SIZE,
        NUM_WORKERS,
    )

    model = EfficientNetLit(
        num_classes=len(classes),
        lr=LR,
        min_lr=MIN_LR,
        weights=weights,
        class_weights=class_weights,
    )

    logger = CSVLogger("logs", name="efficientnet")

    # Use the directly imported callback classes consistently.
    checkpoint = ModelCheckpoint(monitor="val_loss", mode="min")
    early_stop = EarlyStopping(
        monitor="val_loss",
        patience=20,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        accelerator="auto",
        devices="auto",
        logger=logger,
        callbacks=[checkpoint, early_stop],
    )

    trainer.fit(model, datamodule=datamodule)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


11507 2657 3475 3218 805 1006
TRAIN class distribution:
label
0    2406
1     103
2     405
3    2639
4    3644
5     206
6     848
7     171
8     708
9     377
Name: count, dtype: int64


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()