In [None]:
import os
from pathlib import Path

import pytorch_lightning as pl
import ray
import torch
import torch.nn as nn
from pytorch_lightning.loggers import WandbLogger
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.schedulers import ASHAScheduler
from torch.optim import lr_scheduler, SGD
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder

In [None]:
# Number of GPUs requested by both the Lightning Trainer and each Ray Tune
# trial (0 means CPU only).
GPUS = 0
# Project root that holds the data/ and logs/ directories used below. Set the
# DL_PROJECT_PATH environment variable to run the notebook on another machine;
# when it is unset the original location is used.
PROJECT_PATH = os.environ.get("DL_PROJECT_PATH", '/home/maria/Documents/DL_Project')

In [None]:
def prepare_data(path: Path, batch_size: int = 4, num_workers: int = 2):
    """Build the train/val/test image datasets and their dataloaders.

    ``path`` must contain ``train``, ``val`` and ``test`` sub-directories, each
    of which is opened with ``torchvision.datasets.ImageFolder``.

    Args:
        path: Root directory of the dataset. A plain string is accepted and
            converted to ``Path``.
        batch_size: Number of samples per batch for every dataloader.
        num_workers: Number of worker processes for every dataloader.

    Returns:
        A tuple ``(dataloaders, dataset_sizes, class_names)``. The first two
        are dicts keyed by ``"train"``, ``"val"`` and ``"test"``;
        ``class_names`` is the class list of the training ``ImageFolder``.
    """
    path = Path(path)
    modes = ("train", "val", "test")

    # Per-channel mean/std used for normalization (the usual ImageNet values).
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    def make_eval_transform():
        # Deterministic pipeline shared by validation and test: resize, then
        # take the central 112x112 crop.
        return transforms.Compose([
            transforms.Resize(128),
            transforms.CenterCrop(112),
            transforms.ToTensor(),
            normalize,
        ])

    data_transforms = {
        # Training adds augmentation through a random resized crop; every
        # split ends up with 112x112 images.
        "train":
            transforms.Compose([
                transforms.Resize(128),
                transforms.RandomResizedCrop(112),
                transforms.ToTensor(),
                normalize,
            ]),
        "val": make_eval_transform(),
        "test": make_eval_transform(),
    }

    # One ImageFolder dataset per split, rooted at <path>/<split>.
    image_datasets = {
        mode: ImageFolder(path / mode, data_transforms[mode])
        for mode in modes
    }

    # Only the training split is shuffled: evaluation does not benefit from a
    # random order, and a fixed order keeps val/test batches reproducible.
    dataloaders = {
        mode: DataLoader(image_datasets[mode],
                         batch_size=batch_size,
                         shuffle=(mode == "train"),
                         num_workers=num_workers)
        for mode in modes
    }

    dataset_sizes = {mode: len(image_datasets[mode]) for mode in modes}
    class_names = image_datasets["train"].classes
    return dataloaders, dataset_sizes, class_names

In [None]:
class CNNLit(pl.LightningModule):
    """Two-block convolutional image classifier trained with SGD and StepLR.

    Args:
        config: Mapping that provides ``"lr"`` and ``"momentum"`` for the SGD
            optimizer and ``"step_size"`` and ``"gamma"`` for the StepLR
            scheduler.

    The network outputs 14 logits per image, and its first linear layer is
    sized for 3x112x112 inputs. Each epoch logs ``<stage>_accuracy`` and
    ``<stage>_loss`` for ``train``, ``val`` and ``test``.
    """

    def __init__(self, config):
        super().__init__()

        self.lr = config["lr"]
        self.momentum = config["momentum"]
        self.step_size = config["step_size"]
        self.gamma = config["gamma"]
        # Stores the constructor arguments (here: ``config``) with the module.
        self.save_hyperparameters()

        # Spatial size for a 112x112 input:
        # 112 -> conv(3) -> 110 -> pool -> 55 -> conv(5) -> 51 -> pool -> 25.
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # 64 channels * 25 * 25 = 40000 flattened features in, 14 logits out.
        self.linear_layers = nn.Sequential(
            nn.Linear(40000, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 14),
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.cnn_layers(x)
        # Keep the batch dimension and flatten everything else.
        x = torch.flatten(x, start_dim=1)
        return self.linear_layers(x)

    def loss_fn(self, out, target):
        """Cross-entropy between logits ``out`` and class indices ``target``."""
        # Functional form avoids building a new loss module on every call.
        return nn.functional.cross_entropy(out, target)

    def configure_optimizers(self):
        """Create the SGD optimizer and its StepLR learning-rate scheduler."""
        optimizer = SGD(self.parameters(), lr=self.lr, momentum=self.momentum)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        """Run one batch and return its loss, correct count and batch size."""
        x, y = batch
        outputs = self(x)
        loss = self.loss_fn(outputs, y)
        preds = outputs.argmax(dim=1)
        correct_preds = (preds == y).sum()
        return {"correct": correct_preds, "loss": loss, "total": len(y)}

    def _log_epoch_metrics(self, outputs, stage):
        """Aggregate the per-batch dicts and log ``<stage>_accuracy``/``<stage>_loss``."""
        total = sum(x["total"] for x in outputs)
        correct = sum(x["correct"] for x in outputs)
        # Each batch loss is a mean over that batch, so weight it by the batch
        # size; a plain mean of batch means over-weights a short final batch.
        loss_sum = sum(x["loss"].detach() * x["total"] for x in outputs)
        self.log(f"{stage}_accuracy", correct / total)
        self.log(f"{stage}_loss", loss_sum / total)

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def training_epoch_end(self, outputs):
        self._log_epoch_metrics(outputs, "train")

    def validation_epoch_end(self, outputs):
        self._log_epoch_metrics(outputs, "val")

    def test_epoch_end(self, outputs):
        self._log_epoch_metrics(outputs, "test")

In [None]:
def train_model(config, num_epochs=10):
    """Train one ``CNNLit`` model; used as the trainable of a Ray Tune trial.

    Args:
        config: Hyperparameter mapping forwarded unchanged to ``CNNLit``.
        num_epochs: Upper bound on training epochs for the Lightning Trainer.

    Validation loss and accuracy are reported to Tune as ``"loss"`` and
    ``"acc"`` after each validation run, and the checkpoint with the highest
    ``val_accuracy`` is kept under ``data/06_models/``.
    """
    data_root = Path(f"{PROJECT_PATH}/data/02_intermediate/sharks")
    # Only the dataloaders are needed here; sizes and class names are ignored.
    loaders, _sizes, _classes = prepare_data(data_root)

    net = CNNLit(config)

    # Keep just the single best checkpoint, ranked by validation accuracy.
    best_checkpoint = pl.callbacks.ModelCheckpoint(
        dirpath=f'{PROJECT_PATH}/data/06_models/',
        filename='cnn_model-{epoch:02d}-{val_accuracy:.2f}',
        monitor='val_accuracy',
        mode='max',
        save_top_k=1,
    )

    # Maps the names Tune sees (keys) to the metrics the module logs (values).
    report_to_tune = TuneReportCallback(
        {"loss": "val_loss", "acc": "val_accuracy"},
        on="validation_end",
    )

    wandb_logger = WandbLogger(
        save_dir=f"{PROJECT_PATH}/logs/",
        project="cnn_hyperparams_search",
    )

    trainer = pl.Trainer(
        logger=wandb_logger,
        gpus=GPUS,
        max_epochs=num_epochs,
        callbacks=[best_checkpoint, report_to_tune],
    )
    trainer.fit(net, loaders["train"], loaders["val"])

In [None]:
# Release cached CUDA memory before starting the search.
torch.cuda.empty_cache()

# Restart Ray so that re-running this cell starts from a fresh runtime.
ray.shutdown()
ray.init(log_to_driver=False, object_store_memory=10**9)

# Search space; each trial draws one value per key.
config = dict(
    lr=tune.uniform(0.0001, 0.01),
    momentum=tune.uniform(0.05, 0.5),
    step_size=tune.choice([1, 2, 4, 6, 8, 10]),
    gamma=tune.uniform(0.05, 0.5),
)

# Bind the epoch budget; Tune supplies ``config`` for each trial.
trainable = tune.with_parameters(train_model, num_epochs=30)

# 20 trials, scheduled by ASHA on the reported "acc" metric (maximised).
analysis = tune.run(
    trainable,
    name="tune_cnn",
    config=config,
    num_samples=20,
    resources_per_trial={"cpu": 6, "gpu": GPUS},
    scheduler=ASHAScheduler(metric="acc", mode="max"),
)

In [None]:
# Show the hyperparameters Tune ranks best for the "acc" metric (maximised).
best_config = analysis.get_best_config(metric="acc", mode="max")
print(best_config)

In [None]:
# Per-trial result tables exposed by the Tune analysis object; the plotting
# cell iterates over this mapping's values.
dfs = analysis.trial_dataframes

In [None]:
# Plot the reported validation accuracy of every trial on one shared set of
# axes, one line per trial.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
for trial_df in dfs.values():
    # Skip a trial whose table has no "acc" column (nothing was reported).
    if "acc" not in trial_df:
        continue
    # The x axis is the table's row index, i.e. one point per Tune report.
    ax.plot(trial_df["acc"])
ax.set(
    title="Validation accuracy per trial",
    xlabel="Report index (one per validation run)",
    ylabel="Validation accuracy (acc)",
)
plt.show()