In [1]:
# Lighnintmodule definiere
import pytorch_lightning as pl
from pytorch_lightning import Trainer 
from pytorch_lightning.loggers import WandbLogger

import torchmetrics as tm

## PyTorch Libraries
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn 
# import torch.nn.functional as F for activation functions
import torch.nn.functional as F

# import pytorch optimizer SGD 
import torch.optim as optim

import wandb

In [2]:
# daten einlesen
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
# modell definieren
class MLP_test(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.fc1 = nn.Linear(3072, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 10)

        self.hparams.update(hparams)

        self.accuracy = tm.Accuracy(task = 'multiclass', num_classes = 10)
        self.precision = tm.Precision(task = 'multiclass', num_classes = 10)
        self.recall = tm.Recall(task = 'multiclass', num_classes = 10)
        self.f1score = tm.F1Score(task = 'multiclass', num_classes = 10)


    def forward(self, x):
        x = x.view(-1, 3072)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x
    
    def loss_fn(self, output, target):
        return nn.CrossEntropyLoss()(output, target)
    
    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr = self.hparams.lr, momentum = self.hparams.momentum, weight_decay = self.hparams.weight_decay)
    
    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.loss_fn(output, target)

        acc = self.accuracy(output, target)
        prec = self.precision(output, target)
        rec = self.recall(output, target)
        f1 = self.f1score(output, target)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True)
        self.log('train_prec', prec, on_step=True, on_epoch=True)
        self.log('train_rec', rec, on_step=True, on_epoch=True)
        self.log('train_f1', f1, on_step=True, on_epoch=True)
        return loss
    

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.loss_fn(output, target)

        acc = self.accuracy(output, target)
        prec = self.precision(output, target)
        rec = self.recall(output, target)
        f1 = self.f1score(output, target)

        self.log('val_loss', loss, on_step=True, on_epoch=True)
        self.log('val_acc', acc, on_step=True, on_epoch=True)
        self.log('val_prec', prec, on_step=True, on_epoch=True)
        self.log('val_rec', rec, on_step=True, on_epoch=True)
        self.log('val_f1', f1, on_step=True, on_epoch=True)
        return loss



In [4]:
def train():
    run = wandb.init()

    # wandb.config is a variable that holds and saves hyperparameters and inputs
    config = run.config

    # define the model
    model = MLP_test(config)

    # trainer
    trainer = Trainer(accelerator='auto',
                        max_epochs=config.epochs,
                        logger=WandbLogger(),
                        log_every_n_steps=1,
                        enable_progress_bar=True)
    # transformer
    transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    # traindataloader
    train_loader = torch.utils.data.DataLoader(cifar10_train, batch_size=config.batch_size, num_workers=2)
    # testdataloader
    test_loader = torch.utils.data.DataLoader(cifar10_test, batch_size=config.batch_size, num_workers=2)

    # train the model
    trainer.fit(model, train_loader, test_loader)

    # finish wandb run
    run.finish()

    


In [5]:
sweep_config = {
    'method': 'bayes', # grid, random
    'metric': {
        'name': 'accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'lr': {
            'values': [0.001, 0.01, 0.1]
        },
        'momentum': {
            'values': [0.1, 0.5, 0.9]
        },
        'batch_size': {
            'values': [4, 8, 16]
        },
        'weight_decay': {
            'values': [0.0001, 0.001, 0.01]
        },
        'epochs': {
            'values': [10, 20, 30]
        }
    }
}

sweep_id = wandb.sweep(sweep_config, project="cifar10_test")

# run agent
wandb.agent(sweep_id, function=train, count=10)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: p2ugv0xg
Sweep URL: https://wandb.ai/7ben18/cifar10_test/sweeps/p2ugv0xg


[34m[1mwandb[0m: Agent Starting Run: ldrrg1hi with config:
[34m[1mwandb[0m: 	batch_size: 4
[34m[1mwandb[0m: 	epochs: 20
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	momentum: 0.9
[34m[1mwandb[0m: 	weight_decay: 0.0001
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m7ben18[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type                | Params
--------------------------------------------------
0 | fc1       | Linear              | 1.6 M 
1 | fc2       | Linear              | 65.7 K
2 | fc3       | Linear              | 1.3 K 
3 | accuracy  | MulticlassAccuracy  | 0     
4 | precision | MulticlassPrecision | 0     
5 | recall    | MulticlassRecall    | 0     
6 | f1score   | MulticlassF1Score   | 0     
--------------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.561     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]