In [1]:
# install PyTorch Lightning and Weights & Biases (wandb)
!pip install pytorch-lightning --quiet
!pip install wandb -Uq

In [2]:
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader,random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR10
from pytorch_lightning.loggers import WandbLogger
import wandb


In [3]:
# create one class to deal with data
class CifarDataModule(pl.LightningDataModule):
    """LightningDataModule around torchvision's ``CIFAR10`` dataset.

    Downloads the dataset, splits the training set into train/validation
    subsets with ``random_split`` and serves train/val/test ``DataLoader``s.

    Args:
        batch_size: Batch size shared by all three dataloaders.
        data_dir: Root directory passed to ``CIFAR10``.
    """

    def __init__(self, batch_size, data_dir="./"):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_classes = 10

        # Convert images to tensors, then normalise each of the three
        # channels with mean 0.5 and std 0.5.
        channel_mean = (0.5, 0.5, 0.5)
        channel_std = (0.5, 0.5, 0.5)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    def prepare_data(self):
        """Download both the train and the test portion of CIFAR10."""
        for is_train in (True, False):
            CIFAR10(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'fit'`` builds the train/validation subsets, ``'test'`` builds the
        test set, and ``None`` builds all of them.
        """
        build_fit = stage is None or stage == 'fit'
        build_test = stage is None or stage == 'test'

        if build_fit:
            full_train = CIFAR10(self.data_dir, train=True, transform=self.transform)
            # 45000 samples for training, 5000 held out for validation.
            train_part, val_part = random_split(full_train, [45000, 5000])
            self.cifar_train = train_part
            self.cifar_val = val_part

        if build_test:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` with the module's batch size."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=2,
        )

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return self._make_loader(self.cifar_train, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        return self._make_loader(self.cifar_val, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test set."""
        return self._make_loader(self.cifar_test, shuffle=False)




In [9]:
class CIFAR10LitModel(pl.LightningModule):
    """Small convolutional classifier trained with negative log-likelihood loss.

    Layer order:
    conv1 -> relu -> conv2 -> relu -> maxpool -> conv3 -> relu -> conv4 -> relu
    -> maxpool -> fc1 -> relu -> fc2 -> relu -> fc3 -> relu -> fc4 -> log_softmax.

    Args:
        input_shape: Shape of one input sample without the batch dimension,
            e.g. ``(3, 32, 32)``. ``conv1`` is built for 3 input channels.
        num_classes: Number of output classes; sets the width of ``fc4`` and
            the class count of the accuracy metrics.
        config: Object exposing ``fc1``, ``fc2`` and ``fc3`` attributes that
            give the widths of the three hidden fully connected layers
            (e.g. a wandb sweep config). A ``fc4`` entry, if present, is not
            used: the last layer always has ``num_classes`` outputs.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, input_shape, num_classes, config, learning_rate=3e-4):
        super().__init__()
        self.save_hyperparameters()
        self.input_shape = input_shape
        self.learning_rate = learning_rate
        self.activation = F.relu

        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)

        # Number of flattened features the conv stack produces for one sample.
        n_sizes = self._get_output_shape(input_shape)

        # Hidden layer widths are hyperparameters supplied through `config`.
        fc1_neurons = config.fc1
        fc2_neurons = config.fc2
        fc3_neurons = config.fc3
        self.fc1 = nn.Linear(n_sizes, fc1_neurons)
        self.fc2 = nn.Linear(fc1_neurons, fc2_neurons)
        self.fc3 = nn.Linear(fc2_neurons, fc3_neurons)
        self.fc4 = nn.Linear(fc3_neurons, num_classes)

        # Metrics follow `num_classes` instead of a hard-coded 10 so they stay
        # consistent with the output layer.
        self.train_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.test_acc = Accuracy(task='multiclass', num_classes=num_classes)

    def _get_output_shape(self, shape):
        """Return the flattened feature count the conv layers emit per sample.

        A random dummy batch of size 1 with the given ``shape`` is pushed
        through the feature extractor; no gradients are tracked for it.
        """
        batch_size = 1
        with torch.no_grad():
            dummy = torch.rand(batch_size, *shape)
            features = self._feature_extractor(dummy)
        return features.flatten(start_dim=1).size(1)

    def _feature_extractor(self, x):
        """Apply conv1, relu, conv2, relu, maxpool, conv3, relu, conv4, relu, maxpool."""
        x = self.activation(self.conv1(x))
        x = self.pool1(self.activation(self.conv2(x)))
        x = self.activation(self.conv3(x))
        x = self.pool2(self.activation(self.conv4(x)))
        return x

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs."""
        x = self._feature_extractor(x)
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.activation(self.fc3(x))
        return F.log_softmax(self.fc4(x), dim=1)

    def _shared_step(self, batch, metric):
        """Compute the NLL loss and the batch accuracy using ``metric``."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = metric(preds, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss/accuracy and return the loss."""
        loss, acc = self._shared_step(batch, self.train_acc)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log ``val_loss``/``val_acc`` and return the loss."""
        loss, acc = self._shared_step(batch, self.val_acc)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Run one test batch, log ``test_loss``/``test_acc`` and return the loss."""
        loss, acc = self._shared_step(batch, self.test_acc)
        self.log('test_loss', loss, on_epoch=True)
        self.log('test_acc', acc, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), self.learning_rate)



In [5]:
# class for visualizing one batch of validation images along with the predicted and real class labels
class ImagePredictionLogger(pl.Callback):
    """Callback that logs a fixed set of validation images with predictions.

    Args:
        val_samples: ``(images, labels)`` pair, e.g. one batch taken from a
            validation dataloader.
        num_samples: Maximum number of leading samples to keep and log.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.val_imgs, self.val_labels = val_samples
        self.val_imgs = self.val_imgs[:num_samples]
        self.val_labels = self.val_labels[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module):
        """Predict on the stored images and log them to the trainer's logger."""
        val_imgs = self.val_imgs.to(device=pl_module.device)
        # The predictions are only used for display; no gradients are needed.
        with torch.no_grad():
            logits = pl_module(val_imgs)
        preds = torch.argmax(logits, 1)

        # int(...) turns the 0-dim tensors into plain numbers so captions read
        # "Pred:3, Label:3" rather than "Pred:tensor(3, device=...)".
        examples = [
            wandb.Image(img, caption=f"Pred:{int(pred)}, Label:{int(label)}")
            for img, pred, label in zip(val_imgs, preds, self.val_labels)
        ]
        trainer.logger.experiment.log({
            "examples": examples,
            "global_step": trainer.global_step
            })

In [6]:
# Stand-alone data module used to pull a fixed batch of validation samples
cifar = CifarDataModule(data_dir="./", batch_size=32)
cifar.prepare_data()
cifar.setup(stage=None)
# one validation batch (images, labels) to log predictions on
samples = next(iter(cifar.val_dataloader()))

Files already downloaded and verified
Files already downloaded and verified


In [12]:
from wandb.env import CONFIG_DIR
### WandB: you need an account (create one if you don't have one)
# NOTE(review): the API key literal is empty here (presumably redacted);
# prefer supplying it through the environment or the interactive prompt
# rather than hard-coding it in the notebook.
wandb.login(key='')

# Metric the sweep ranks runs by. It must match a name the model logs;
# the LightningModule logs 'train_loss', 'val_loss' and 'test_loss', so the
# validation loss is used (a bare 'loss' is never logged).
metric = {
    'name': 'val_loss',
    'goal': 'minimize'
    }

# Hyperparameters sampled for every run of the sweep.
parameters_dict = {
    # widths of the fully connected layers
    'fc1': {
        'values': [128, 256, 512]
        },
    'fc2': {
        'values': [128, 256, 512]
        },
    'fc3': {
        'values': [128, 256, 512]
        },
    'fc4': {
        'values': [128, 256, 512]
        },
    'dropout': {
          'values': [0.3, 0.4, 0.5]
        },
    # fixed (not sampled) number of epochs
    'epochs': {
        'value': 1
        },
    'batch_size': {
        # integers between 32 and 256
        # with evenly-distributed logarithms
        'distribution': 'q_log_uniform_values',
        'q': 8,
        'min': 32,
        'max': 256,
      },
    }

sweep_config = {
    'method': 'random',
    'metric': metric,
    'parameters': parameters_dict,
    }

wandb_logger = WandbLogger(project='lastt', job_type='train', log_model="all")
sweep_id = wandb.sweep(sweep=sweep_config, project="lastt")
def train_model(learning_rate=1e-3):
    """Train and test one model for a single wandb sweep run.

    Intended to be passed to ``wandb.agent``: the hyperparameters
    (``batch_size``, ``epochs``, ``fc1``..``fc3``) are read from the run's
    config, which the sweep controller fills in.

    Args:
        learning_rate: Learning rate forwarded to ``CIFAR10LitModel``.

    Relies on the module-level ``samples`` batch for the image-logging
    callback.
    """
    # Used as a context manager, the run is finished automatically on exit,
    # so training AND testing must both happen inside this block.
    with wandb.init() as run:
        # If called by wandb.agent this config is set by the sweep controller.
        config = wandb.config

        # Data for this run, using the sampled batch size.
        dm = CifarDataModule(config.batch_size)
        dm.prepare_data()
        dm.setup()

        model = CIFAR10LitModel(
            (3, 32, 32),
            dm.num_classes,
            config=config,
            learning_rate=learning_rate,
        )

        # A logger bound to THIS run: a logger created once at module level
        # would keep pointing at the first run of the sweep.
        run_logger = WandbLogger(experiment=run, log_model="all")
        run_logger.watch(model)

        # Initialize callbacks
        checkpoint_callback = pl.callbacks.ModelCheckpoint()
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="val_acc", patience=3, verbose=False, mode="max"
        )

        trainer = pl.Trainer(
            max_epochs=config.epochs,
            logger=run_logger,
            callbacks=[
                checkpoint_callback,
                early_stop_callback,
                ImagePredictionLogger(samples),
            ],
        )

        # Train the model
        trainer.fit(model, datamodule=dm)

        # Evaluate the trained model on the test split of the same data module
        trainer.test(model, datamodule=dm)



Create sweep with ID: 1xjzw1l3
Sweep URL: https://wandb.ai/msc_bme/lastt/sweeps/1xjzw1l3


In [14]:

# Launch an agent that executes up to 10 runs of the sweep with train_model
wandb.agent(
    sweep_id,
    function=train_model,
    count=10,
)

[34m[1mwandb[0m: Agent Starting Run: yn6ioku0 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	dropout: 0.5
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	fc1: 128
[34m[1mwandb[0m: 	fc2: 128
[34m[1mwandb[0m: 	fc3: 128
[34m[1mwandb[0m: 	fc4: 256
[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
