In [None]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torchmetrics
import wandb
import gc

from dataset import CoughDataset

In [None]:
class CoughDataModule(pl.LightningDataModule):
    """LightningDataModule that splits a single ``CoughDataset`` into train/val/test.

    Args:
        df: metadata frame forwarded to ``CoughDataset``.
        data_path: data directory forwarded to ``CoughDataset``.
        batch_size: batch size used by all three DataLoaders.
        num_workers: worker processes used by all three DataLoaders.
        train_size, val_size, test_size: split fractions; must sum to 1.
        duration, sample_rate, channels, n_mels, n_fft, top_db: forwarded
            unchanged to ``CoughDataset`` (their meaning is defined there).
        split_seed: seed of the generator used by ``random_split``, so the same
            items always land in the same split, no matter how often or on which
            process ``setup`` runs.

    Raises:
        ValueError: if the three split fractions do not sum to 1 (within 1e-6).
    """

    def __init__(self, 
                 df, 
                 data_path, 
                 batch_size=32, 
                 num_workers=4, 
                 train_size=0.8, 
                 val_size=0.1, 
                 test_size=0.1,
                 duration=10.0,
                 sample_rate=48000,
                 channels=1,
                 n_mels=64,
                 n_fft=1024, 
                 top_db=80,
                 split_seed=69):
        super().__init__()
        
        self.df = df
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.split_seed = split_seed
        
        self.duration = duration
        self.sample_rate = sample_rate
        self.channels = channels
        
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.top_db = top_db
        
        # Compare with a tolerance: exact float equality rejects valid splits,
        # e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
        if abs(self.train_size + self.val_size + self.test_size - 1.0) > 1e-6:
            raise ValueError('train_size + val_size + test_size must be equal to 1.0')

        # Filled in by setup(); None means "not split yet".
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        
    def setup(self, stage=None):
        """Build the dataset and split it once; subsequent calls are no-ops.

        Lightning calls ``setup`` once per stage (fit, validate, test). Re-splitting
        on every call would reshuffle items between the splits and leak test data
        into training, so the first split is kept.
        """
        if self.train_dataset is not None:
            return

        dataset = CoughDataset(df=self.df, 
                               data_path=self.data_path,
                               duration=self.duration,
                               sample_rate=self.sample_rate,
                               channels=self.channels,
                               n_mels=self.n_mels,
                               n_fft=self.n_fft,
                               top_db=self.top_db)

        # A dedicated seeded generator makes the split reproducible and independent
        # of whatever else has consumed the global torch RNG.
        generator = torch.Generator().manual_seed(self.split_seed)
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset,
            [self.train_size, self.val_size, self.test_size],
            generator=generator,
        )

    def _make_loader(self, dataset, shuffle):
        """DataLoader over ``dataset`` with the module's batch size and workers."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers)
            
    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)
    
    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)
    
    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
class CoughModel(nn.Module):
    """Four-block CNN classifier ending in global average pooling and a linear layer.

    Each block is Conv2d (stride 2) -> ReLU -> BatchNorm2d. Input is presumably
    a mel-spectrogram batch of shape ``(batch, in_channels, height, width)``
    produced by ``CoughDataset``; TODO confirm.

    ``forward`` returns raw, unnormalised logits of shape ``(batch, num_classes)``.
    ``nn.CrossEntropyLoss`` applies log-softmax internally, so applying a softmax
    here as well would normalise twice and shrink the gradients. Use
    ``predict_proba`` when probabilities are needed.

    Args:
        in_channels: number of input channels (default 1).
        num_classes: number of output classes (default 3).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 3) -> None:
        super().__init__()
        
        # First Convolution Block with ReLU and Batch Norm. Use Kaiming Initialization
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
        self.relu1 = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(8)
        self._init_conv(self.conv1)

        # Second Convolution Block
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu2 = nn.ReLU()
        self.bn2 = nn.BatchNorm2d(16)
        self._init_conv(self.conv2)

        # Third Convolution Block
        self.conv3 = nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu3 = nn.ReLU()
        self.bn3 = nn.BatchNorm2d(32)
        self._init_conv(self.conv3)

        # Fourth Convolution Block
        self.conv4 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu4 = nn.ReLU()
        self.bn4 = nn.BatchNorm2d(64)
        self._init_conv(self.conv4)

        # Linear Classifier
        self.ap = nn.AdaptiveAvgPool2d(output_size=1)
        self.lin = nn.Linear(in_features=64, out_features=num_classes)
        # Used only by predict_proba; forward deliberately returns logits.
        self.act = nn.Softmax(dim=1)

    @staticmethod
    def _init_conv(conv: nn.Conv2d) -> None:
        """Kaiming-normal weights with the gain for the ReLU that follows; zero bias."""
        nn.init.kaiming_normal_(conv.weight, nonlinearity='relu')
        nn.init.zeros_(conv.bias)
        
    def forward(self, x):
        """Map a batch ``x`` to class logits of shape ``(batch, num_classes)``."""
        blocks = (
            (self.conv1, self.relu1, self.bn1),
            (self.conv2, self.relu2, self.bn2),
            (self.conv3, self.relu3, self.bn3),
            (self.conv4, self.relu4, self.bn4),
        )
        for conv, relu, bn in blocks:
            x = bn(relu(conv(x)))

        x = self.ap(x)
        x = torch.flatten(x, start_dim=1)
        return self.lin(x)

    def predict_proba(self, x):
        """Return class probabilities: softmax over the logits from ``forward``."""
        return self.act(self(x))

In [None]:
class LitCoughClassifier(pl.LightningModule):
    """LightningModule that trains ``model`` with cross-entropy loss and Adam.

    Logs ``{stage}_loss``, ``{stage}_accuracy`` and ``{stage}_f1`` for the
    ``train``, ``val`` and ``test`` stages.

    Args:
        model: ``nn.Module`` mapping a batch ``x`` to class scores of shape
            ``(batch, num_classes)``. These are fed to ``nn.CrossEntropyLoss``,
            which expects raw logits (it applies log-softmax itself).
        learning_rate: Adam learning rate.
        num_classes: number of classes used by the metrics (default 3).
    """

    def __init__(self, model, learning_rate=1e-3, num_classes=3):
        super().__init__()
        
        # Model
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        
        # Hyperparameters
        self.learning_rate = learning_rate
        
        # Metrics: one instance per stage. A torchmetrics object accumulates state
        # on every call, so sharing one instance between training and validation
        # would mix both stages' predictions into the same running state.
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.train_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)
        self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.test_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)

    def _shared_step(self, batch, stage, accuracy, f1, prog_bar):
        """Compute the loss for ``batch``, update ``stage``'s metrics and log them.

        Returns the loss tensor.
        """
        # Lightning has already moved the batch to self.device.
        # TODO: input normalisation ((x - mean) / std) belongs in the dataset.
        x, y = batch
        logits = self.model(x)

        loss = self.criterion(logits, y)
        self.log(f'{stage}_loss', loss, on_epoch=True, logger=True)

        # Update the metric state, then log the metric object itself so Lightning
        # computes the epoch value over all batches and resets it afterwards.
        accuracy(logits, y)
        self.log(f'{stage}_accuracy', accuracy, on_epoch=True, prog_bar=prog_bar, logger=True)
        f1(logits, y)
        self.log(f'{stage}_f1', f1, on_epoch=True, prog_bar=prog_bar, logger=True)

        return loss
        
    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train', self.train_accuracy, self.train_f1, prog_bar=True)
    
    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val', self.val_accuracy, self.val_f1, prog_bar=False)

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test', self.test_accuracy, self.test_f1, prog_bar=False)
    
    def configure_optimizers(self):
        """Adam over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

In [None]:
# Input locations, relative to the notebook's working directory.
DATA_PATH = 'data/'  # root directory handed to CoughDataset via CoughDataModule
METADATA_FILE = 'data/metadata_compiled.csv'  # metadata table loaded below

metadata_df = pd.read_csv(filepath_or_buffer=METADATA_FILE)

In [None]:
# Keep rows that have a `status` label and a high `cough_detected` score.
COUGH_DETECTED_THRESHOLD = 0.9  # TODO: Set as a hyperparameter

not_nan_df = metadata_df[metadata_df['status'].notna()]
filtered_df = not_nan_df[not_nan_df['cough_detected'] > COUGH_DETECTED_THRESHOLD]
filtered_df[['uuid', 'cough_detected', 'SNR', 'age', 'gender', 'status']]

In [None]:
def _free_memory():
    """Collect unreachable Python objects and release cached, unused CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def _fit_one_run(config, seed):
    """Build the data module, model and trainer described by ``config`` and fit once.

    Every large object is a local here, so it becomes unreachable as soon as this
    function returns and the caller can free memory.
    """
    pl.seed_everything(seed)

    wandb_logger = WandbLogger(log_model=True)

    # sample_rate / n_fft / n_mels are swept by the sweep config; forward them when
    # the run config provides them, otherwise keep CoughDataModule's defaults.
    feature_kwargs = {
        key: config[key] for key in ('sample_rate', 'n_fft', 'n_mels') if key in config
    }
    data_module = CoughDataModule(
        df=filtered_df,
        data_path=DATA_PATH,
        batch_size=config.batch_size,
        **feature_kwargs,
    )

    classifier = LitCoughClassifier(model=CoughModel(), learning_rate=config.learning_rate)

    trainer = pl.Trainer(
        max_epochs=config.max_epochs,
        logger=wandb_logger,
    )
    trainer.fit(classifier, data_module)


def train(config=None, seed=69):
    """Run one W&B training run (the sweep agent's entry point).

    Args:
        config: optional config forwarded to ``wandb.init``; inside a sweep W&B
            supplies it. Must provide ``batch_size``, ``learning_rate`` and
            ``max_epochs``; ``sample_rate``, ``n_fft`` and ``n_mels`` are optional.
        seed: global seed applied with ``pl.seed_everything`` before building anything.

    Exceptions propagate unchanged. The ``wandb.init`` context manager then closes
    the run itself, so the run is recorded as failed rather than finished.
    """
    with wandb.init(config=config):
        _free_memory()
        try:
            _fit_one_run(wandb.config, seed)
        finally:
            # Best-effort cleanup on success and on failure alike; the helper's
            # locals (model, trainer, loaders) are no longer referenced here.
            _free_memory()

In [None]:
# W&B sweep: Bayesian search over training and feature-extraction hyperparameters.
# The optimised metric must be one that LitCoughClassifier actually logs
# ('val_loss', 'val_accuracy' or 'val_f1'); otherwise the Bayesian search has no
# observed values to model.
SWEEP_COUNT = 1  # number of runs this agent executes; raise for a real search

sweep_config = {
    "method": "bayes",
    "metric": {
        "name": "val_f1",
        "goal": "maximize"
    },
    "parameters": {
        "batch_size": {
            "values": [32, 64, 128]
        },
        "max_epochs": {
            "values": [30, 45, 60]
        },
        "learning_rate": {
            "min": 0.001,
            "max": 0.01
        },
        "sample_rate": {
            "values": [16000, 22050, 44100, 48000]
        },
        "n_fft": {
            "values": [512, 1024, 2048]
        },
        "n_mels": {
            "values": [32, 64, 128, 256]
        },
    }
}

sweep_id = wandb.sweep(sweep_config, project='cough-classifier', entity='dl-miniproject')

wandb.agent(sweep_id, function=train, count=SWEEP_COUNT)