In [None]:
# Install PyTorch Lightning (imported below as `pytorch_lightning`) into the kernel's environment.
%pip install pytorch-lightning



In [None]:
# Install the Weights & Biases client used for experiment logging.
%pip install wandb
# Authenticate the wandb CLI for this session.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mdlochmelis[0m ([33mdlhf[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import torchmetrics

import wandb
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Run configuration shared by the cells below.
NUM_CLASSES = 40       # width of the model's output layer (one unit per predicted label)
BATCH_SIZE = 256       # samples per batch for every DataLoader
NUM_EPOCHS = 5         # upper bound on training epochs (Trainer max_epochs)
LEARNING_RATE = 1e-3   # step size handed to the Adam optimizer
NUM_WORKERS = 0        # DataLoader worker processes; can be made higher

Load data

In [None]:
class DataModule(pl.LightningDataModule):
    """CelebA data module serving the dataset's train/valid/test splits.

    Every image is cropped to 160x160, resized to 128x128, converted to a
    tensor and normalized to [-1, 1]. Targets are the CelebA attribute
    vectors (``target_type='attr'``).

    Args:
        data_path: Root directory the dataset is downloaded to and read from.

    Attributes:
        transform: Training pipeline (random crop, acts as light augmentation).
        eval_transform: Validation/test pipeline (deterministic center crop).
    """

    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path

        # Transforms are built here rather than in prepare_data(): Lightning
        # calls prepare_data() from a single process only, so state assigned
        # there would be missing on the other ranks in distributed runs.
        shared_tail = [
            transforms.Resize([128, 128]),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range onto [-1, 1], so
            # the network receives zero-centered inputs.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        # input: (3, 218, 178); a 160x160 window fits inside the 178-px side.
        self.transform = transforms.Compose(
            [transforms.RandomCrop((160, 160)), *shared_tail]
        )
        # Evaluation must not be random, otherwise validation/test metrics
        # change from run to run for the very same weights.
        self.eval_transform = transforms.Compose(
            [transforms.CenterCrop((160, 160)), *shared_tail]
        )

    def prepare_data(self):
        """Download CelebA to ``data_path`` if it is not there yet."""
        CelebA(root=self.data_path, download=True)

    def _make_split(self, split, transform):
        """Open one CelebA split with attribute targets and the given transform."""
        return CelebA(root=self.data_path,
                      split=split,
                      target_type='attr',
                      transform=transform)

    def setup(self, stage=None):
        """Create the train/valid/test datasets (all three, whatever ``stage`` is)."""
        self.train = self._make_split('train', self.transform)
        self.valid = self._make_split('valid', self.eval_transform)
        self.test = self._make_split('test', self.eval_transform)

    def _make_loader(self, dataset, training):
        """Wrap ``dataset`` in a DataLoader configured for training or evaluation."""
        return DataLoader(dataset=dataset,
                          batch_size=BATCH_SIZE,
                          # Training drops the incomplete last batch to keep
                          # batch sizes uniform; evaluation keeps it so that
                          # every sample contributes to the reported metrics.
                          drop_last=training,
                          shuffle=training,
                          num_workers=NUM_WORKERS)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train, training=True)

    def val_dataloader(self):
        """Ordered loader over the complete validation split."""
        return self._make_loader(self.valid, training=False)

    def test_dataloader(self):
        """Ordered loader over the complete test split."""
        return self._make_loader(self.test, training=False)

In [None]:
# Seed Python's `random`, NumPy and PyTorch in one call (torch.manual_seed
# alone only covers PyTorch); workers=True also derives per-worker seeds
# should NUM_WORKERS be raised above 0.
pl.seed_everything(1, workers=True)
data_module = DataModule(data_path='./data')

Define model

In [None]:
class DlhfModel(pl.LightningModule):
    """Simple three-stage CNN (architecture suggested by ChatGPT).

    Each stage is Conv(3x3, padding 1) -> ReLU -> MaxPool(2x2): the spatial
    size halves per stage while the channels grow 3 -> 32 -> 64 -> 128.
    A two-layer fully connected head followed by a sigmoid then produces one
    independent value in [0, 1] per class.

    Args:
        num_classes: Number of output units.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one Conv(3x3) -> ReLU -> MaxPool(2x2) stage."""
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=1,
        )
        pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        return nn.Sequential(conv, nn.ReLU(), pool)

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Shapes below assume an input of (B, 3, 128, 128).
        self.block1 = self._conv_stage(3, 32)     # -> (B, 32, 64, 64)
        self.block2 = self._conv_stage(32, 64)    # -> (B, 64, 32, 32)
        self.block3 = self._conv_stage(64, 128)   # -> (B, 128, 16, 16)

        flat_features = 128 * 16 * 16
        head = [
            nn.Flatten(),                       # (B, 128 * 16 * 16)
            nn.Linear(flat_features, 128),      # (B, 128)
            nn.ReLU(),
            nn.Linear(128, self.num_classes),   # (B, num_classes)
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(
            self.block1, self.block2, self.block3, *head
        )

    def forward(self, x):
        """Return the per-class sigmoid outputs for the image batch ``x``."""
        return self.classifier(x)

In [None]:
class DlhfLightningModel(pl.LightningModule):
    """Lightning wrapper that trains a multilabel classifier with BCE loss.

    Args:
        model: Network whose outputs lie in [0, 1] (they are fed straight into
            ``F.binary_cross_entropy``); it must expose ``num_classes``.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, model, learning_rate):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model
        # The network object is not a hyperparameter, so keep it out of hparams.
        self.save_hyperparameters(ignore=['model'])

        # A separate metric instance per stage keeps their running states apart.
        self.train_acc = self._create_accuracy()
        self.valid_acc = self._create_accuracy()
        self.test_acc = self._create_accuracy()

    def _create_accuracy(self):
        """Return a fresh multilabel accuracy metric sized to the wrapped model."""
        n_labels = self.model.num_classes
        return torchmetrics.Accuracy(task='multilabel', num_labels=n_labels)

    def forward(self, x):
        """Delegate to the wrapped network."""
        return self.model(x)

    def _shared_step(self, batch):
        """Forward one ``(inputs, targets)`` batch and compute its BCE loss.

        Returns:
            Tuple ``(loss, targets, predictions)``.
        """
        inputs, targets = batch
        preds = self(inputs)
        # Both operands are cast to float32 before the loss is computed.
        bce = F.binary_cross_entropy(preds.float(), targets.float())
        return bce, targets, preds

    def training_step(self, batch, batch_idx):
        """Log the training loss, accumulate training accuracy, return the loss."""
        loss, targets, preds = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_acc.update(preds, targets)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and epoch-level validation accuracy."""
        loss, targets, preds = self._shared_step(batch)
        self.log("valid_loss", loss)
        self.valid_acc(preds, targets)
        self.log("valid_acc", self.valid_acc,
                 on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log epoch-level test accuracy (the loss is computed but not logged)."""
        _, targets, preds = self._shared_step(batch)
        self.test_acc(preds, targets)
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

Train the classifier

In [None]:
model = DlhfModel(num_classes=NUM_CLASSES)
lightning_model = DlhfLightningModel(model, learning_rate=LEARNING_RATE)

# log_model='all' uploads every checkpoint written during training to wandb.
wandb_logger = WandbLogger(project='fashion-celeba-test', log_model='all')
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE
wandb_logger.watch(lightning_model, log='all')

# Without a monitored metric the default checkpoint callback simply keeps the
# last epoch, so "best" would just mean "latest". Monitoring the validation
# accuracy makes `ckpt_path='best'` / `best_model_path` refer to the epoch
# with the highest `valid_acc`.
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='valid_acc', mode='max')

trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    accelerator="auto",
    devices="auto",
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=100
)

trainer.fit(model=lightning_model, datamodule=data_module)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type               | Params
-------------------------------------------------
0 | model     | DlhfModel          | 4.3 M 
1 | train_acc | MultilabelAccuracy | 0     
2 | valid_acc | MultilabelAccuracy | 0     
3 | test_acc  | MultilabelAccuracy | 0     
-------------------------------------------------
4.3 M     Trainable params
0         Non-trainable params
4.3 M     Total params
17.171    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [None]:
# Evaluate on the test split; 'best' resolves to the checkpoint that the
# trainer's checkpoint callback currently tracks as best.
trainer.test(
    model=lightning_model,
    datamodule=data_module,
    ckpt_path='best',
)

Files already downloaded and verified


INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./fashion-celeba-test/ncbmcs1w/checkpoints/epoch=4-step=3175.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at ./fashion-celeba-test/ncbmcs1w/checkpoints/epoch=4-step=3175.ckpt
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_acc': 0.8972923755645752}]

In [None]:
# Mark the active wandb run as finished; a later WandbLogger then starts a
# new run instead of reusing this one.
wandb.finish()

VBox(children=(Label(value='245.716 MB of 245.716 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇█
test_acc,▁
train_acc,▁▆▇██
train_loss,█▆▅▄▄▃▄▃▃▃▂▂▂▂▁▂▂▂▂▂▂▂▂▁▂▁▁▁▂▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
valid_acc,▁▅▇▇█
valid_loss,█▄▂▂▁

0,1
epoch,5.0
test_acc,0.89729
train_acc,0.90114
train_loss,0.22543
trainer/global_step,3175.0
valid_acc,0.90345
valid_loss,0.22105


In [None]:
# NOTE(review): this deliberately aborts "Run All" here, so the Inference
# section below only runs when executed by hand — presumably because it is
# still work in progress. Remove this cell once that section is final.
raise NotImplementedError

Inference

In [None]:
# Reload the checkpoint tracked by the trainer's checkpoint callback. The
# network has to be passed in again because it was excluded from the saved
# hyperparameters (`save_hyperparameters(ignore=['model'])`).
path = trainer.checkpoint_callback.best_model_path
lightning_model = DlhfLightningModel.load_from_checkpoint(
    path,
    # Map the stored tensors onto the CPU so that a checkpoint written during
    # a GPU run also loads on a machine without CUDA.
    map_location='cpu',
    model=DlhfModel(num_classes=NUM_CLASSES)
)
lightning_model.eval()

test_dataloader = data_module.test_dataloader()
acc = torchmetrics.Accuracy(task='multilabel', num_labels=NUM_CLASSES)

# A single batch is enough for this spot check.
batch = next(iter(test_dataloader))
x, y = batch

with torch.no_grad():
    y_pred = lightning_model(x)

print(f'acc: {acc(y_pred, y)}')

y_pred[:5]

acc: 0.897656261920929


tensor([[1.5149e-04, 6.0467e-01, 3.8571e-01, 3.5255e-01, 3.9867e-05, 1.3407e-01,
         1.5566e-01, 1.5678e-01, 8.5519e-03, 3.7461e-01, 2.4550e-02, 1.0793e-01,
         2.4127e-02, 9.1303e-03, 1.7901e-02, 1.0041e-05, 2.1145e-05, 3.7986e-02,
         6.1476e-01, 9.6524e-01, 9.7450e-04, 9.9743e-01, 1.0058e-05, 7.4596e-02,
         9.9988e-01, 2.5575e-01, 4.6392e-02, 2.0099e-01, 1.4437e-02, 1.9307e-02,
         3.7359e-05, 9.8863e-01, 1.5696e-01, 1.9796e-01, 6.3906e-01, 6.8665e-02,
         8.3751e-01, 1.1406e-01, 7.5097e-04, 3.9152e-01],
        [8.2037e-02, 1.3757e-01, 2.3664e-01, 1.2555e-01, 7.6048e-04, 1.0345e-02,
         3.6915e-01, 2.6705e-01, 7.3642e-01, 4.3978e-04, 1.1650e-01, 3.9211e-02,
         2.8525e-01, 5.6613e-02, 1.2597e-02, 1.0406e-02, 4.4820e-01, 2.7393e-04,
         1.4393e-01, 2.1928e-01, 3.2172e-01, 4.3526e-01, 2.0688e-01, 3.3683e-01,
         2.6928e-01, 1.2100e-01, 3.5393e-03, 2.0035e-01, 5.3659e-02, 5.1187e-03,
         2.2600e-01, 1.2504e-01, 1.2243e-01, 4.2181