In [1]:
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.loggers import WandbLogger
from torch.optim import lr_scheduler, SGD
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder

In [2]:
# Hardware / reproducibility settings for the whole notebook.
GPUS = 1  # number of GPUs handed to pl.Trainer
SEED = 42
# Seed Python's `random`, NumPy and torch up front so that weight
# initialisation of the new head and data shuffling/augmentation are
# repeatable across Restart & Run All (GPU kernels may still add some
# nondeterminism).
pl.seed_everything(SEED)

In [3]:
def prepare_data(path: Path, batch_size: int = 4, num_workers: int = 4):
    """Build ImageFolder datasets and DataLoaders for the train/val splits.

    ``path`` must contain ``train/`` and ``val/`` subdirectories, each laid
    out as ``ImageFolder`` expects (one subdirectory per class).

    Args:
        path: Root directory holding the ``train`` and ``val`` splits
            (a ``str`` is accepted as well and converted to ``Path``).
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        A tuple ``(dataloaders, dataset_sizes, class_names)``.
        ``dataloaders`` and ``dataset_sizes`` are dicts keyed by ``"train"``
        and ``"val"``. ``class_names`` is the list of class folder names
        found in the training split.
    """
    root = Path(path)
    splits = ("train", "val")

    # Channel mean/std used for both splits (the standard ImageNet values
    # the pretrained ResNet18 was trained with).
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    data_transforms = {
        # Random resized crops act as light augmentation during training.
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
        # Deterministic resize + center crop for evaluation.
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    image_datasets = {
        split: ImageFolder(root / split, data_transforms[split])
        for split in splits
    }
    # Only the training split is shuffled: validation metrics are aggregated
    # over the whole epoch, so a fixed order costs nothing and keeps
    # evaluation reproducible.
    dataloaders = {
        split: DataLoader(image_datasets[split],
                          batch_size=batch_size,
                          shuffle=(split == "train"),
                          num_workers=num_workers)
        for split in splits
    }
    dataset_sizes = {split: len(image_datasets[split]) for split in splits}
    class_names = image_datasets["train"].classes
    return dataloaders, dataset_sizes, class_names

In [9]:
class ResNet18Lit(pl.LightningModule):
    """ResNet18 transfer-learning classifier wrapped as a LightningModule.

    The pretrained backbone is frozen. Only the replacement final
    fully-connected layer is trained: it is the only parameter group handed
    to the optimizer.

    Args:
        num_classes: Output size of the new classification head. Defaults to
            14, the value this notebook was originally written for.
        lr: SGD learning rate for the head.
        momentum: SGD momentum.
        step_size: Epochs between learning-rate decays (StepLR).
        gamma: Multiplicative learning-rate decay factor (StepLR).
    """

    def __init__(self, num_classes: int = 14, lr: float = 0.001,
                 momentum: float = 0.9, step_size: int = 7,
                 gamma: float = 0.1):
        super().__init__()
        # Stores the init arguments in self.hparams (and in checkpoints) so
        # the module can be rebuilt with the same head size when reloaded.
        self.save_hyperparameters()
        self.model = models.resnet18(pretrained=True)
        # Freeze every pretrained weight; the backbone is used as a fixed
        # feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        # The new head is created after freezing, so its parameters stay
        # trainable.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        return self.model(x)

    def loss_fn(self, out, target):
        """Cross-entropy between logits ``out`` and integer labels ``target``."""
        return nn.functional.cross_entropy(out, target)

    def configure_optimizers(self):
        """SGD on the classifier head only, with a step learning-rate decay."""
        optimizer = SGD(self.model.fc.parameters(),
                        lr=self.hparams.lr,
                        momentum=self.hparams.momentum)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=self.hparams.step_size,
                                        gamma=self.hparams.gamma)
        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        """Run one batch; return its mean loss, #correct and batch size."""
        x, y = batch
        outputs = self(x)
        loss = self.loss_fn(outputs, y)
        preds = outputs.argmax(dim=1)
        correct_preds = (preds == y).sum()
        return {"correct": correct_preds, "loss": loss, "total": len(y)}

    def training_step(self, batch, batch_idx):
        """Per-batch training metrics; ``loss`` is what Lightning optimizes."""
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        """Per-batch validation metrics, aggregated in validation_epoch_end."""
        return self._shared_step(batch)

    def _log_epoch_metrics(self, outputs, prefix):
        """Log dataset-level ``<prefix>_accuracy`` and ``<prefix>_loss``.

        Each batch loss is a mean over that batch, so it is re-weighted by the
        batch size; a plain mean of batch means would over-weight a smaller
        final batch. Nothing is logged when there were no samples.
        """
        total = sum(x["total"] for x in outputs)
        if total == 0:
            return
        correct = sum(x["correct"] for x in outputs)
        loss_sum = sum(x["loss"] * x["total"] for x in outputs)
        self.log(f"{prefix}_accuracy", correct / total)
        self.log(f"{prefix}_loss", loss_sum / total)

    def training_epoch_end(self, outputs):
        self._log_epoch_metrics(outputs, "train")

    def validation_epoch_end(self, outputs):
        self._log_epoch_metrics(outputs, "val")

In [10]:
# Root of the sharks ImageFolder tree (must contain train/ and val/).
SHARKS_DIR = Path("../data/02_intermediate/sharks")
dataloaders, dataset_sizes, class_names = prepare_data(SHARKS_DIR)

In [11]:
model = ResNet18Lit()
# The classifier head size is decided inside ResNet18Lit, independently of
# the data; fail fast if it does not match the dataset's class count.
assert model.model.fc.out_features == len(class_names), (
    f"model predicts {model.model.fc.out_features} classes, "
    f"dataset has {len(class_names)}")

# Keep only the single checkpoint with the highest validation accuracy.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_accuracy',
    dirpath='../data/06_models/',
    filename='model-{epoch:02d}-{val_accuracy:.2f}',
    save_top_k=1,
    mode='max')

trainer = pl.Trainer(
    logger=WandbLogger(save_dir='../logs/'),
    gpus=GPUS,
    max_epochs=10,
    callbacks=[checkpoint_callback])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [12]:
# Train on the train split, validating on the val split after every epoch.
train_loader, val_loader = dataloaders["train"], dataloaders["val"]
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
7.2 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.735    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33mmaria_wyrzykowska[0m (use `wandb login --relogin` to force relogin)




Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

