In [1]:
%%shell
pip install pytorch-lightning==2.1 --quiet
pip install wandb --quiet



In [2]:
import os
import os.path as osp
import numpy as np
import gdown
import zipfile

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torchmetrics

import wandb

In [3]:
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Download and split the data

In [4]:
! gdown --id 19dSNIsEGoScG4AIxG0eOxjHxJu8AP5HZ

Downloading...
From: https://drive.google.com/uc?id=19dSNIsEGoScG4AIxG0eOxjHxJu8AP5HZ
To: /content/alien_predator.zip
100% 14.8M/14.8M [00:00<00:00, 30.7MB/s]


In [5]:
%%shell
mkdir alien_predator
unzip alien_predator.zip -d alien_predator

Archive:  alien_predator.zip
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/0.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/1.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/10.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/100.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/101.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/102.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/103.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/104.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/105.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/106.jpg  
  inflating: alien_predator/alien_vs_predator_thumbnails/data/train/alien/107.jpg  
  inflating: alien_predator/alien_vs_predator_thumbn



In [6]:
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

image_size = (256, 256)
imagenet_transform = Compose([Resize(image_size), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

In [7]:
train_dataset_path = '/content/alien_predator/data/train'
test_dataset_path = '/content/alien_predator/data/validation'

dataset = ImageFolder(train_dataset_path, transform=imagenet_transform)
test_dataset = ImageFolder(test_dataset_path, transform=imagenet_transform)

In [8]:
train_dataset_size = int(len(dataset) * 0.9)
train_dataset, val_dataset = random_split(dataset, [train_dataset_size, len(dataset) - train_dataset_size])

In [9]:
from torch.utils.data import DataLoader

batch_size = 16

train_gen = DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)
val_gen = DataLoader(dataset = val_dataset, batch_size = batch_size,  shuffle = False)
test_gen = DataLoader(dataset = test_dataset, batch_size = batch_size, shuffle = False)

# LightningDataModule

In [10]:
class AlienPredatorDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './alien_predator/', train_dataset_path='/content/alien_predator/data/train', test_dataset_path='/content/alien_predator/data/validation'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_dataset_path = train_dataset_path
        self.test_dataset_path = test_dataset_path
        self.image_size = (256, 256)
        self.imagenet_transform = Compose([Resize(image_size), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        self.num_classes = 2
        self.zip_name = 'alien_predator.zip'


    def prepare_data(self):
        if not osp.isfile(self.zip_name):
            gdown.download('https://drive.google.com/uc?id=19dSNIsEGoScG4AIxG0eOxjHxJu8AP5HZ', output=self.zip_name, quiet=False)

        if not osp.isdir(self.data_dir):
            with zipfile.ZipFile(self.zip_name, 'r') as zip_ref:
                zip_ref.extractall(self.data_dir)

    def setup(self, stage=None):
        # train/val
        if stage == 'fit' or stage is None:
            dataset = ImageFolder(self.train_dataset_path, transform=self.imagenet_transform)
            train_dataset_size = int(len(dataset) * 0.9)
            self.train_dataset, self.val_dataset = random_split(dataset, [train_dataset_size, len(dataset) - train_dataset_size])
        # test
        if stage == 'test' or stage is None:
            self.test_dataset = ImageFolder(self.test_dataset_path, transform=self.imagenet_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

# ImagePredictionLogger

In [11]:
class ImagePredictionLogger(Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                           for x, pred, y in zip(val_imgs[:self.num_samples],
                                                 preds[:self.num_samples],
                                                 val_labels[:self.num_samples])]
            })

# Earlystopping

In [12]:
early_stop_callback = EarlyStopping(
   monitor='val_loss',
   patience=3,
   verbose=False,
   mode='min'
)

# Model Checkpoint

In [13]:
MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model-{epoch:02d}-{val_loss:.2f}'

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=MODEL_CKPT_PATH,
    filename=MODEL_CKPT,
    save_top_k=3,
    mode='min')

LightningModule with ResNet18

In [17]:
from torchvision import models

class TransferResNet18LitModel(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.num_classes = num_classes

        self.model = models.resnet18(pretrained=True)

        for param in self.model.parameters():
            param.requires_grad = False

        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, x, y):
        return F.cross_entropy(x, y)

    def common_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = self.compute_loss(outputs,y)
        return loss, outputs, y

    def common_test_valid_step(self, batch, batch_idx):
        loss, outputs, y = self.common_step(batch, batch_idx)
        preds = torch.argmax(outputs, dim=1)
        acc = torchmetrics.functional.accuracy(preds, y, num_classes = self.num_classes, task="multiclass")
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self.common_test_valid_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.common_test_valid_step(batch, batch_idx)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self.common_test_valid_step(batch, batch_idx)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

# Training and evaluation

In [18]:
dm =  AlienPredatorDataModule(batch_size=32)
dm.prepare_data()
dm.setup()

val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape

(torch.Size([32, 3, 256, 256]), torch.Size([32]))

In [19]:
model = TransferResNet18LitModel(dm.num_classes)

wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

trainer = pl.Trainer(accelerator="gpu", devices=[0], max_epochs=50, logger=wandb_logger, callbacks=[checkpoint_callback, early_stop_callback, ImagePredictionLogger(val_samples)])

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 156MB/s]
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [20]:
trainer.fit(model, dm)

[34m[1mwandb[0m: Currently logged in as: [33madrian-barczuk[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
1.0 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (20) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [21]:
trainer.test(model=model, datamodule=dm)

wandb.finish()

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

VBox(children=(Label(value='32.070 MB of 33.797 MB uploaded\r'), FloatProgress(value=0.9489090332365746, max=1…

0,1
epoch,▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█
test_acc,▁
test_loss,▁
train_acc_epoch,▁▆▇▇▇▇▇█▇▇█▇▇
train_acc_step,▃▆▆▁█
train_loss_epoch,█▅▃▃▂▂▂▁▂▂▁▁▂
train_loss_step,█▃▅▅▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val_acc,▂▆▅▅▇▇▇▇██▅▅▁
val_loss,█▅▄▄▂▂▂▂▁▁▁▂▃

0,1
epoch,13.0
test_acc,0.9
test_loss,0.26688
train_acc_epoch,0.91827
train_acc_step,1.0
train_loss_epoch,0.20677
train_loss_step,0.08769
trainer/global_step,260.0
val_acc,0.84286
val_loss,0.26451


# Saving checkpoints as W&B artifacts

In [22]:
run = wandb.init(project='wandb-lightning', job_type='producer')

artifact = wandb.Artifact('model', type='model')
artifact.add_dir(MODEL_CKPT_PATH)

run.log_artifact(artifact)
run.join()

[34m[1mwandb[0m: Adding directory to artifact (./model)... Done. 0.4s


VBox(children=(Label(value='128.186 MB of 128.196 MB uploaded\r'), FloatProgress(value=0.9999217695629636, max…