##Test

In [None]:
!pip install pytorch_lightning
!pip install segmentation-models-pytorch
!pip install torchmetrics

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
import zipfile
audio = zipfile.ZipFile('/content/gdrive/MyDrive/data.zip', 'r')
audio.extractall('/content')

In [4]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import segmentation_models_pytorch as smp
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import pandas as pd
from pathlib import Path
import os
from tqdm.notebook import tqdm
import numpy as np
from torchmetrics import MeanSquaredError

In [5]:
class SoundDataset(Dataset):

    def __init__(self, meta, source_folder, transforms):

        self.source_folder = source_folder
        self.transforms = transforms
        self.meta = meta
        self.files = self.meta.path.unique()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):

        sound = self.files[i]
        sound_df = self.meta.loc[self.meta.path == sound]
        
        noisy_df = sound_df.loc[sound_df.type == 'noisy'].iloc[0]
        clean_df = sound_df.loc[sound_df.type == 'clean'].iloc[0]

        noisy_path = os.path.join(self.source_folder, noisy_df.get('folder'), noisy_df.get('type'), str(noisy_df.get('id')), sound)
        clean_path = os.path.join(self.source_folder, clean_df.get('folder'), clean_df.get('type'), str(clean_df.get('id')), sound)
        
        noisy = np.expand_dims(np.load(noisy_path).astype(float), 2)
        clean = np.expand_dims(np.load(clean_path).astype(float), 2)

        augmented = self.transforms(image=noisy, mask=clean)

        return {'noisy': augmented['image'][0].unsqueeze(0).float(), 'clean': augmented['mask'][:, :, 0].unsqueeze(0).float()}
        

In [6]:
class LightningPredictor(nn.Module):

    def __init__(self, test_df, source_folder):
        super().__init__()
  
        self.unet = smp.Unet(encoder_name='resnet18', in_channels=1)
        self.loss = {'MSE': MeanSquaredError()}
        self.transforms = albu.Compose([albu.PadIfNeeded(480, 80),
                                        albu.RandomCrop(480, 80),
                                        albu.Resize(576, 96),
                                        ToTensorV2()])
        self.testset = SoundDataset(test_df, source_folder, self.transforms)
        self.testloader = DataLoader(self.testset, batch_size=1, shuffle=False)

    def calculate_metrics(self):

        MSE = list()

        with torch.no_grad():
            for batch in tqdm(self.testloader):
                noisy, gt_clean = batch['noisy'], batch['clean']
                pr_clean = self.unet(noisy)
                MSE.append(self.loss['MSE'](pr_clean, gt_clean).detach().numpy())

        return np.mean(MSE)

In [7]:
test_df = pd.read_csv('/content/meta_val.csv') # Specify test meta
model = LightningPredictor(test_df, '/content/data')
model.load_state_dict(torch.load('/content/gdrive/MyDrive/unet_11.pth')) # Specify weights
model.eval()
metrics = model.calculate_metrics()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [8]:
print('MSE =', metrics)

MSE = 0.027986113
