In [3]:
from src.data import DataModule

dm = DataModule(
    r"C:\Users\user\data\dl_radiomics\preprocessed_3d",
    r"C:\Users\user\data\tables\lesion_followup_curated_v4.csv",
    "lung",
    "mst",
)
dm.setup()

Loading dataset: 100%|██████████| 123/123 [00:11<00:00, 10.26it/s]
Loading dataset: 100%|██████████| 42/42 [00:04<00:00,  9.71it/s]
Loading dataset: 100%|██████████| 56/56 [00:05<00:00, 10.21it/s]


In [4]:
import torch
import torch.nn as nn
import wandb
from torchmetrics.classification import BinaryAUROC, Accuracy
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import LightningModule, Trainer, seed_everything
from monai.networks.nets.densenet import DenseNet121


class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1)
        self.train_auc = BinaryAUROC(pos_label=1)
        self.val_auc = BinaryAUROC(pos_label=1)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-2)

        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch["img"], batch["label"]
        y_hat = torch.sigmoid(self.model(x))

        loss = nn.BCELoss()(y_hat.squeeze(), y.float())
        self.train_auc.update(y_hat.squeeze(), y.int())

        self.log_dict(
            {"train_loss": loss, 'train_auc': self.train_auc.compute()},
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch["img"], batch["label"]
        y_hat = torch.sigmoid(self.model(x))

        loss = nn.BCELoss()(y_hat.squeeze(), y.float())
        self.val_auc.update(y_hat.squeeze(), y.int())

        self.log_dict(
            {"valid_loss": loss, 'valid_auc': self.val_auc.compute()},
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

model = Model()

In [6]:
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer, seed_everything


seed_everything(0)

logger = WandbLogger(
    name='hello1',
    project='project_skeleton_on_lung_lesions',
)

trainer = Trainer(
    max_epochs=50,
    gpus=1,
    deterministic=True,
    fast_dev_run=False,
    logger=logger
)
        
trainer.fit(model, dm)

wandb.finish()

Global seed set to 0


2022-10-27 17:17:14,813 - Global seed set to 0


GPU available: True, used: True


2022-10-27 17:17:14,815 - GPU available: True, used: True


TPU available: None, using: 0 TPU cores


2022-10-27 17:17:14,816 - TPU available: None, using: 0 TPU cores
2022-10-27 17:17:14,820 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name      | Type        | Params
------------------------------------------
0 | model     | DenseNet121 | 11.2 M
1 | train_auc | BinaryAUROC | 0     
2 | val_auc   | BinaryAUROC | 0     
------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.975    Total estimated model params size (MB)


2022-10-27 17:17:15,627 - 
  | Name      | Type        | Params
------------------------------------------
0 | model     | DenseNet121 | 11.2 M
1 | train_auc | BinaryAUROC | 0     
2 | val_auc   | BinaryAUROC | 0     
------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.975    Total estimated model params size (MB)
Epoch 49: 100%|██████████| 12/12 [00:06<00:00,  1.83it/s, loss=0.0164, v_num=wocf, valid_loss=0.432, valid_auc=0.869, train_loss=0.0155, train_auc=0.995]


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_auc,▁▂▃▄▅▅▆▆▆▇▇▇▇▇▇▇████████████████████████
train_loss,█▇▇▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_auc,▁▂▃▃▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇████████████████████
valid_loss,█▆▅▄▃▃▂▂▂▂▁▂▁▂▁▁▁▂▁▂▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄

0,1
epoch,49.0
train_auc,0.99476
train_loss,0.01546
valid_auc,0.86924
valid_loss,0.43165


In [17]:
wandb.finish()