In [None]:
# Install training dependencies and download the dataset archive by Google Drive file id
# (presumably data.zip, which the next cell extracts).
# NOTE(review): `--upgrade` installs the newest releases, but later cells rely on older APIs:
# `pl.Trainer(gpus=1)`, `torchmetrics.Accuracy()` with no `task=`, and `resnet50(pretrained=True)`.
# Verify these against the installed versions, or pin versions that still support them.
!pip install pytorch-lightning gdown wandb --upgrade
!gdown --id 1_bAXzdCRBjoPSkO_MrQ_FRoM_-npeJll

In [None]:
# Extract the archive quietly into the working directory.
# The CSVs under data/ and the image paths they list are read relative to PATH below.
!unzip -q data.zip

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim  
import torchvision.transforms as transforms
import torchvision
import os
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets,models
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torchvision.io import read_image
from pytorch_lightning.loggers import WandbLogger
import wandb

In [None]:
# Authenticate with Weights & Biases before any sweep or run is created below.
wandb.login()

In [None]:
# Root directory prepended by plain string concatenation to CSV and image paths,
# so keep the trailing slash.
PATH = './'
NUM_CLASSES = 10
BATCH_SIZE = 16
# NOTE(review): `lr` and `epochs` are not read by train() in the cells shown.
# The learning rate comes from the sweep, and train() caps each run at 10 epochs.
lr = 5e-5
epochs = 100

# Static run metadata. Learning rate and image size are not fixed here;
# they are tuned by `sweep_config` (keys `learning_rate` and `input_shape`).
# NOTE(review): in the cells shown this dict is never passed to wandb,
# because train() has its own `config` parameter that shadows it.
config = {
    'base_model': 'ResNet50',
    'num_classes': NUM_CLASSES,
    'batch_size': BATCH_SIZE,
    'frozen_layers': 0,
    'frozen_blocks': 0,
    'epochs': epochs,  # single source of truth instead of a duplicated literal
    'Augmentation': "Color Jitter"
}

In [None]:
# Search space for the W&B Bayesian sweep.
import math

# learning_rate and decay each span two orders of magnitude, so they are sampled log-uniformly.
# For W&B's `log_uniform`, min/max are natural logs: log(value) is uniform in [min, max].
# Uniform sampling would put ~90% of draws in the top decade of each range.
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'val_acc',   # logged once per epoch by Model._evaluate
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            'distribution': 'log_uniform',
            'min': math.log(1e-6),
            'max': math.log(1e-4)
        },
        'dropout': {
            'distribution': 'uniform',
            'min': 0,
            'max': 0.3
        },
        # AdamW weight decay
        'decay': {
            'distribution': 'log_uniform',
            'min': math.log(0.1),
            'max': math.log(10.0)
        },
        # (height, width) passed to transforms.Resize
        'input_shape': {
            'values': [
                (128, 128),
                (480, 640),
                (224, 224),
                (300, 300),
                (350, 400),
                (400, 400),
                (450, 500),
                (500, 500),
            ]
        },
        # ColorJitter strengths for the training pipeline (0 disables that jitter)
        'brightness': {
            'distribution': 'uniform',
            'min': 0,
            'max': 1.0
        },
        'contrast': {
            'distribution': 'uniform',
            'min': 0,
            'max': 1.0
        },
        'saturation': {
            'distribution': 'uniform',
            'min': 0,
            'max': 1.0
        }
    },
    # Hyperband stops underperforming runs early.
    # max_iter=10 matches train()'s 10-epoch cap, since val_acc is logged once per epoch.
    'early_terminate': {
        'type': 'hyperband',
        'max_iter': 10,
        's': 2
    }
}

In [None]:
class CreateDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of image paths and labels.

    Columns are read positionally:
      - column 0: image path, appended to ``PATH`` by string concatenation;
      - column 1: integer class label;
      - column 3: camera name.

    Images are loaded with ``read_image`` and scaled by 1/255 before ``transform``.
    Samples from ``"Camera 2"`` are then mirrored horizontally. Their labels are remapped
    1<->3 and 2<->4 so mirrored images keep a consistent label; all other labels are unchanged.

    Args:
        df: the source DataFrame (layout above).
        transform: optional callable applied to the image tensor. Falsy values
            (``None``, or ``False`` as the old default) disable it.
    """

    # Label remapping for horizontally mirrored ("Camera 2") samples.
    _MIRRORED_LABEL = {1: 3, 2: 4, 3: 1, 4: 2}

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``, with ``label`` as a Python int."""
        img_path = self.df.iloc[index, 0]
        image = read_image(PATH + img_path) / 255.0
        label = int(self.df.iloc[index, 1])

        if self.transform:
            image = self.transform(image)

        if self.df.iloc[index, 3] == "Camera 2":
            # Deterministic mirror along the last (width) axis. Avoids constructing a
            # RandomHorizontalFlip(p=1.0) object for every sample.
            image = torch.flip(image, dims=[-1])
            label = self._MIRRORED_LABEL.get(label, label)

        return image, label

In [None]:
# Load the train/val/test split tables from PATH + "data/<split>.csv".
# CreateDataset reads their columns positionally: 0 = image path, 1 = label, 3 = camera.
train_df, val_df, test_df = (
    pd.read_csv(f"{PATH}data/{split}.csv")
    for split in ("train", "val", "test")
)

In [None]:
import pytorch_lightning as pl
import torchmetrics
from torch import nn


class Model(pl.LightningModule):
    """Pretrained torchvision ResNet-50 with its ``fc`` layer replaced for ``output_units`` classes.

    Args:
        output_units: number of output logits (classes) of the new ``fc`` layer.
        learning_rate: AdamW learning rate; also the lower bound of the CyclicLR cycle.
        dropout: probability for the dropout applied to pooled features just before ``fc``.
        weight_decay: AdamW weight decay.
        max_lr: upper bound of the CyclicLR cycle. Defaults to 1e-4, the value that used
            to be hard-coded.
    """

    def __init__(self, output_units, learning_rate, dropout, weight_decay=0.1, max_lr=1e-4):
        super().__init__()
        self.base_model = torchvision.models.resnet50(pretrained=True)
        self.base_model.fc = torch.nn.Linear(in_features=self.base_model.fc.in_features, out_features=output_units)

        self.criterion = nn.CrossEntropyLoss()
        # One metric object per stage so val and test accumulate their state independently.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

        self.learning_rate = learning_rate
        self.max_lr = max_lr
        self.dropout = torch.nn.Dropout(p=dropout)
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def forward(self, input_data):
        """Return class logits for a batch of normalized image tensors."""
        # Same layer sequence as torchvision's ResNet forward, with dropout inserted
        # between global pooling and ``fc``. Before this, ``self.dropout`` was never applied,
        # so the swept dropout rate had no effect. Calling the layers explicitly keeps the
        # ``base_model.fc.*`` state-dict keys unchanged, so existing checkpoints still load.
        net = self.base_model
        x = net.maxpool(net.relu(net.bn1(net.conv1(input_data))))
        x = net.layer4(net.layer3(net.layer2(net.layer1(x))))
        x = torch.flatten(net.avgpool(x), 1)
        return net.fc(self.dropout(x))

    def training_step(self, batch, batch_nb):
        """Compute and log the training loss and per-step accuracy; return the loss."""
        input_data, targets = batch
        preds = self(input_data)
        loss = self.criterion(preds, targets)
        self.log('train_loss', loss)
        self.train_acc(preds, targets)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_nb):
        self._evaluate(batch, 'val')

    def test_step(self, batch, batch_nb):
        self._evaluate(batch, 'test')

    def _evaluate(self, batch, name):
        """Log epoch-level ``{name}_loss`` and ``{name}_acc``; ``name`` is 'val' or 'test'."""
        input_data, targets = batch
        preds = self(input_data)
        loss = self.criterion(preds, targets)
        self.log(f'{name}_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        # Use the stage's own metric (self.val_acc / self.test_acc) instead of sharing one.
        accuracy = getattr(self, f'{name}_acc')
        accuracy(preds, targets)
        self.log(f'{name}_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True)

    def predict_step(self, batch, batch_nb):
        """Return the arg-max class index for each sample in the batch."""
        input_data, _ = batch
        preds = self(input_data)
        return torch.argmax(preds, dim=1)

    def configure_optimizers(self):
        """AdamW plus a CyclicLR schedule that oscillates between learning_rate and max_lr."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=self.learning_rate,
            # Guard against an inverted cycle if learning_rate exceeds max_lr.
            max_lr=max(self.max_lr, self.learning_rate),
            # AdamW has no `momentum` parameter to cycle.
            cycle_momentum=False,
        )
        # CyclicLR's step_size_up counts scheduler steps (default 2000), so it must be
        # stepped every batch. The old list return was stepped once per epoch, which left
        # the LR nearly constant over a 10-epoch run.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

In [None]:
# NOTE(review): this cell is redundant. torch is already imported in the imports cell,
# and train() calls torch.cuda.empty_cache() itself at the start of every run.
import torch
torch.cuda.empty_cache()

In [1]:
# Hyperparameters used when train() is called directly (outside a sweep).
_DEFAULT_TRAIN_CONFIG = {
    "learning_rate": 1e-5,
    "dropout": 0.2,
    "decay": 0.1,
    "input_shape": (480, 640),
    "brightness": 0.75,
    "contrast": 0.75,
    "saturation": 0.75,
}

# ImageNet channel statistics, matching the ImageNet-pretrained ResNet-50 backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_dataloaders(cfg):
    """Build the (train, val) DataLoaders for one run from the resolved run config.

    Both pipelines resize to ``cfg.input_shape`` and normalize; only training adds ColorJitter.
    """
    normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
    eval_transform = transforms.Compose([
        transforms.Resize(cfg.input_shape),
        normalize,
    ])
    train_transform = transforms.Compose([
        transforms.Resize(cfg.input_shape),
        transforms.ColorJitter(brightness=cfg.brightness, contrast=cfg.contrast, saturation=cfg.saturation),
        normalize,
    ])
    train_loader = DataLoader(dataset=CreateDataset(train_df, train_transform), batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(dataset=CreateDataset(val_df, eval_transform), batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    return train_loader, val_loader


def train(config=None, max_epochs=10):
    """Train one ResNet-50 model inside a W&B run; this is the sweep agent's entry point.

    Args:
        config: dict with keys learning_rate, dropout, decay, input_shape, brightness,
            contrast and saturation. ``None`` uses a fresh copy of ``_DEFAULT_TRAIN_CONFIG``.
            The values actually used are read back from ``run.config``, which is where
            sweep-sampled values appear when launched by ``wandb.agent``.
        max_epochs: maximum training epochs per run (previously hard-coded to 10).
    """
    torch.cuda.empty_cache()
    if config is None:
        # A dict literal as the default would be one object shared by every call.
        config = dict(_DEFAULT_TRAIN_CONFIG)

    with wandb.init(job_type="train", config=config) as run:
        config = run.config
        wandb_logger = WandbLogger(project="Driver-Distraction", entity='graduation-project', config=config, experiment=run, log_model=True)

        train_loader, val_loader = _build_dataloaders(config)
        model = Model(NUM_CLASSES, config.learning_rate, config.dropout, weight_decay=config.decay)

        callbacks = [
            # Per-run directory, so checkpoints from different sweep runs no longer share
            # (and can overwrite each other in) the notebook's working directory.
            pl.callbacks.ModelCheckpoint(monitor='val_acc', dirpath=os.path.join(PATH, 'checkpoints', run.id), verbose=True, mode='max', filename='resnet50-t3-{val_acc:.4f}'),
            # NOTE(review): with max_epochs=10, patience=20 can never trigger. Early
            # termination is effectively left to the sweep's hyperband policy.
            pl.callbacks.EarlyStopping(monitor='val_acc', patience=20, verbose=True, mode='max'),
        ]

        trainer = pl.Trainer(
            logger=wandb_logger,
            gpus=1,
            max_epochs=max_epochs,
            callbacks=callbacks,
        )
        trainer.fit(model, train_loader, val_loader)

In [None]:
# Register sweep_config with W&B under the same project/entity that train()'s logger uses.
# The returned id is what the agent in the next cell attaches to.
sweep_id = wandb.sweep(sweep_config, project="Driver-Distraction", entity='graduation-project')

In [None]:
# Start the sweep agent. It repeatedly launches train(); inside each run, run.config holds
# the hyperparameters sampled according to sweep_config.
# NOTE(review): no `count=` is given, so with method='bayes' the agent presumably keeps
# launching runs until stopped manually. Confirm this is intended.
wandb.agent(sweep_id, function=train)