# 1. Dependencies

In [None]:
# Mount Google Drive at /content/drive so that later cells can reach the
# project folder stored there. `google.colab` is only importable inside Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Switch the working directory to the project folder on the mounted Drive;
# the relative './Data/...' paths used when loading the data resolve from here.
%cd /content/drive/MyDrive/AICovidVN

In [None]:
!pip install torch torchvision torchaudio

# 2. Packages

In [None]:
import torch, torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
import time
import torchaudio.transforms as T
from torch.utils.data import DataLoader
import torch.utils.data.dataset as dataset
import pandas as pd
import os
import torchaudio
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# 3. Dataloader

In [None]:
class AICovidVNDataset(dataset.Dataset):
    """Map-style dataset yielding ``(waveform, target)`` pairs.

    The annotation CSV must provide a ``file_path`` column (audio file path
    relative to ``root_dir``) and an ``assessment_result`` column (the label).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory that the ``file_path`` entries are
                joined onto, i.e. where the audio files live.
            transform (callable, optional): Applied to the loaded waveform
                before it is returned. Skipped when falsy.
        """
        self.root_dir = root_dir
        self.transform = transform

        annotations = pd.read_csv(csv_file)
        self.aicovidvn_data = annotations
        # Cache both columns as arrays so item access does not go through pandas.
        self.file_path = annotations['file_path'].values
        self.assessment_result = annotations['assessment_result'].values

    def __len__(self):
        """Number of rows in the annotation table."""
        return len(self.aicovidvn_data.index)

    def __getitem__(self, idx):
        """Load one audio file and return ``(waveform, target)``.

        Both tensors are placed on the module-level ``device``; the target is
        built as ``float32``. The sample rate reported by ``torchaudio.load``
        is discarded.
        """
        # Samplers may hand over the index as a tensor; convert it to plain Python.
        idx = idx.tolist() if torch.is_tensor(idx) else idx

        wav_path = os.path.join(self.root_dir, self.file_path[idx])
        waveform, _ = torchaudio.load(wav_path)
        waveform = waveform.to(device)
        if self.transform:
            waveform = self.transform(waveform)

        target = torch.tensor(self.assessment_result[idx], dtype=torch.float32, device=device)
        return (waveform, target)

# 4. Training

### 4.1. Applying MFCC transforms to the Data

In [None]:
# MFCC feature extractor applied to every waveform.
# sample_rate is fixed at 8000 Hz (the training audio directory is named
# '..._8k', so this presumably matches the files — confirm if the data changes).
# n_mfcc equals n_mels, i.e. all 256 cepstral coefficients are kept.
mfcc_transform = T.MFCC(
    n_mfcc=256,
    sample_rate=8000,
    # Settings forwarded to the underlying mel-spectrogram computation.
    melkwargs={'mel_scale': 'htk', 'hop_length': 512, 'n_mels': 256, 'n_fft': 2048},
)

### 4.2. Load data

In [None]:
# Full labelled dataset. Every waveform is turned into MFCC features and then
# resized / centre-cropped (Resize(256) followed by CenterCrop(224)) before it
# is handed to the network.
train_dataset = AICovidVNDataset(csv_file='./Data/aicv115m_public_train/metadata_train_challenge.csv',
                                 root_dir='./Data/aicv115m_public_train/train_audio_files_8k',
                                 transform=transforms.Compose([
                                     mfcc_transform.to(device),
                                     transforms.Resize(256).to(device),
                                     transforms.CenterCrop(224).to(device)
                                 ]))

# 80/20 train/held-out split. The held-out size is the remainder, so both
# parts always sum to len(train_dataset). The seeded generator keeps the split
# identical across runs.
train_size = int(len(train_dataset) * 0.8)
lengths = [train_size, len(train_dataset) - train_size]
train_data, test_data = torch.utils.data.random_split(dataset=train_dataset, lengths=lengths,
                                                      generator=torch.Generator().manual_seed(42))

batch_size = 64
# Training batches are reshuffled every epoch; the last incomplete batch is dropped.
train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluation gains nothing from shuffling; a fixed order keeps per-batch
# results reproducible and comparable between epochs.
test_data_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

# NOTE(review): because of drop_last=True the training loader may yield fewer
# than train_data_size samples per epoch — keep that in mind if this value is
# used as a denominator for averaged metrics.
train_data_size = len(train_data)
test_data_size = len(test_data)

### 4.3. Model

In [None]:
# Build a ResNet50 with randomly initialised weights (no pretrained checkpoint
# is loaded). Calling the constructor without arguments gives the same result
# as the deprecated `pretrained=False` keyword on both old and new torchvision.
resnet50 = models.resnet50()

# Replace the stem convolution so the network accepts 1-channel input instead
# of 3-channel RGB (assumes the MFCC features have a single channel, i.e. mono
# audio — TODO confirm).
resnet50.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Replace the ImageNet classification head with a small binary head that
# outputs one probability per sample.
fc_inputs = resnet50.fc.in_features

resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.6),
    nn.Linear(256, 1),
    nn.Sigmoid()
)

# Move the complete model (including the freshly created conv1 and fc layers)
# to the target device. This is done once, after all layers are in place.
resnet50 = resnet50.to(device)

# Define Optimizer and Loss Function.
# BCELoss expects probabilities, which the Sigmoid at the end of the head provides.
loss_func = nn.BCELoss()
num_epochs = 500
optimizer = optim.Adam(resnet50.parameters(), lr=0.01)
# Learning rate is multiplied by 0.1 at epoch 10 and again at epoch 20.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

# NOTE(review): train_and_validate is not defined in the cells shown here; it
# must be defined (and run) before this cell on a fresh kernel.
# NOTE(review): the head outputs shape (N, 1); verify that train_and_validate
# brings targets to the same shape before calling BCELoss.
trained_model, history, best_epoch = train_and_validate(resnet50, loss_func, optimizer, scheduler, num_epochs)
torch.save(history, 'history.pt')