In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
import random
import torchaudio
from torchaudio import transforms

from torchsummary import summary
import gc
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score


from torchvision import models
from tqdm import tqdm

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'

config = {
    'epochs': 200,
    'batch_size' : 32,
    'context' : 48,
    'learning_rate' : 0.001,
    'architecture' : 'very-low-cutoff'
}

## Dataloader

In [3]:
class AudioDataset(torch.utils.data.Dataset):

    def __init__(self, data_path, am_path, gender = "female", phoneme_idx = 4, am_idx = 1, MAX_LEN = 128, partition = "train"):
        """
        :param data_path: the root path of phonemes
        :param am_path: the path of am (.csv)
        :param gender: female or male
        :param phoneme_idx: the phoneme index
        :param am_idx: the index of target AM, should be int within [1, 96]
        :param MAX_LEN: max length of voice seq, if less, pad, if more, slice
        :param partition: train / val1 / val2 / test
        """

        self.MAX_LEN = MAX_LEN
        # get phoneme list
        self.target_phoneme_path = "/".join([data_path, gender, str(int(phoneme_idx))])
        phoneme_list = sorted(os.listdir(self.target_phoneme_path))
        length = len(phoneme_list)
        # if partition == "train":
        #     self.phoneme_list = phoneme_list[:int(0.7 * length)]
        # elif partition == "val1":
        #     self.phoneme_list = phoneme_list[int(0.7 * length):int(0.8 * length)]
        # elif partition == "val2":
        #     self.phoneme_list = phoneme_list[int(0.8 * length):int(0.9 * length)]
        # elif partition == "test":
        #     self.phoneme_list = phoneme_list[int(0.9 * length):]
            
        if partition == "train":
            self.phoneme_list = phoneme_list[:int(0.7 * length)]
        elif partition == "val1":
            self.phoneme_list = phoneme_list[int(0.7 * length):]


        self.length = len(self.phoneme_list)

        # get_am data
        am_data = pd.read_csv(am_path)
        self.am_data = am_data[["ID", str(am_idx)]]

    def __len__(self):
        return self.length

    def spectro_gram(self, sig, n_mels=64, n_fft=1024, hop_len=None):
        top_db = 80

        # spec has shape [channel, n_mels, time], where channel is mono, stereo etc
        spec = transforms.MelSpectrogram(44100, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)

        # Convert to decibels
        spec = transforms.AmplitudeToDB(top_db=top_db)(spec)
        return spec

    def __getitem__(self, ind):
        item_filename = self.phoneme_list[ind]
        item_full_path = "/".join([self.target_phoneme_path, item_filename])
        phoneme = np.load(item_full_path)

        person_id = int(item_filename.split("_")[0][1:7])
        try:
            target_am = self.am_data[self.am_data["ID"] == person_id].values[0][-1]
        except:
            print("person id =", person_id)
            target_am = 0.

        # padding
        phoneme = torch.tensor(phoneme, dtype=torch.float) #.reshape(1, -1)
        # apply mel transform
        phoneme = self.spectro_gram(phoneme)

        std, mean = torch.std_mean(phoneme, unbiased=False, dim=0)
        phoneme = (phoneme - mean) / (std + 1e-6)

        if len(phoneme[0]) < MAX_LEN:
            phoneme = np.pad(phoneme, ((0, 0), (0, MAX_LEN - len(phoneme[0]))), 'symmetric')
            phoneme = torch.from_numpy(phoneme)
        else:
            phoneme = phoneme[:, :MAX_LEN]
        # phoneme = torch.from_numpy(phoneme)
        ##################################################################
        phoneme.unsqueeze_(0)
        ##################################################################
        target_am = torch.tensor(target_am).to(torch.float32)
        
        return phoneme, target_am


In [4]:
# class AudioDataset(torch.utils.data.Dataset):

#     def __init__(self, data_path, am_path, gender = "female", phoneme_idx = 4, am_idx = 1, MAX_LEN = 44100 * 2, partition = "train"):
#         """
#         :param data_path: the root path of phonemes
#         :param am_path: the path of am (.csv)
#         :param gender: female or male
#         :param phoneme_idx: the phoneme index
#         :param am_idx: the index of target AM, should be int within [1, 96]
#         :param MAX_LEN: max length of voice seq, if less, pad, if more, slice
#         :param partition: train / val1 / val2 / test
#         """

#         self.MAX_LEN = MAX_LEN
#         # get phoneme list
#         self.target_phoneme_path = "/".join([data_path, gender, str(int(phoneme_idx))])
#         phoneme_list = sorted(os.listdir(self.target_phoneme_path))
#         length = len(phoneme_list)
#         if partition == "train":
#             self.phoneme_list = phoneme_list[:int(0.7 * length)]
#         elif partition == "val1":
#             self.phoneme_list = phoneme_list[int(0.7 * length):int(0.8 * length)]
#         elif partition == "val2":
#             self.phoneme_list = phoneme_list[int(0.8 * length):int(0.9 * length)]
#         elif partition == "test":
#             self.phoneme_list = phoneme_list[int(0.9 * length):]

#         self.length = len(self.phoneme_list)

#         # get_am data
#         am_data = pd.read_csv(am_path)
#         self.am_data = am_data[["ID", str(am_idx)]]

#     def __len__(self):
#         return self.length

#     def spectro_gram(self, sig, n_mels=64, n_fft=1024, hop_len=None):
#         top_db = 80

#         # spec has shape [channel, n_mels, time], where channel is mono, stereo etc
#         spec = transforms.MelSpectrogram(44100, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)

#         # Convert to decibels
#         spec = transforms.AmplitudeToDB(top_db=top_db)(spec)
#         return spec

#     def padding(self, phoneme):
#         if len(phoneme) < self.MAX_LEN:
#             pad_begin_len = random.randint(0, self.MAX_LEN - len(phoneme))
#             pad_end_len = self.MAX_LEN - len(phoneme) - pad_begin_len

#             # Pad with 0s
#             pad_begin = np.zeros(pad_begin_len)
#             pad_end = np.zeros(pad_end_len)

#             phoneme = np.concatenate((pad_begin, phoneme, pad_end), 0)
#         else:
#             phoneme = phoneme[:self.MAX_LEN]
#         return phoneme

#     def __getitem__(self, ind):
#         item_filename = self.phoneme_list[ind]
#         item_full_path = "/".join([self.target_phoneme_path, item_filename])
#         phoneme = np.load(item_full_path)

#         person_id = int(item_filename.split("_")[0][1:7])
#         try:
#             target_am = self.am_data[self.am_data["ID"] == person_id].values[0][-1]
#         except:
#             print("person id =", person_id)
#             target_am = 0.

#         # padding
#         phoneme = self.padding(phoneme)
#         phoneme = torch.tensor(phoneme, dtype=torch.float) #.reshape(1, -1)
#         # apply mel transform
#         phoneme = self.spectro_gram(phoneme)
        
#         ################################### Normalization ######################################
#         std, mean = torch.std_mean(phoneme, unbiased=False, dim=0)
#         phoneme = (phoneme - mean) / (std + 1e-6)
#         # print(phoneme)
#         # ####################### convert phoneme from float32 to float64 ##################
#         # phoneme = phoneme.to(torch.float64)
#         # ##################################################################################

#         target_am = torch.tensor(target_am)
        
        
#         ####################################################################################
#         target_am = target_am.to(torch.float32)
#         # print(target_am)
#         ####################################################################################
        
#         # jia yi ge gui yi hua (phoneme)
        
#         return phoneme, target_am

In [5]:
# default_root_path = "./penstate_data/extract_phoneme"
default_root_path = "./penstate_data/extract_phoneme_processed"
gender = "female"
phoneme_idx = 10
# am_path = "./penstate_data/AMs_unnormalized.csv"
am_path = "./penstate_data/AMs_final.csv"

am_idx = 13
MAX_LEN = 32 # TODO: may be too small
batch_size = 64
batch_size = config['batch_size']
train_data = AudioDataset(data_path=default_root_path,
                            am_path = am_path,
                            gender = gender, phoneme_idx = phoneme_idx, am_idx = am_idx, MAX_LEN = MAX_LEN, partition="train")

######################################################################################################################################
val_data = AudioDataset(data_path=default_root_path,
                            am_path = am_path,
                            gender = gender, phoneme_idx = phoneme_idx, am_idx = am_idx, MAX_LEN = MAX_LEN, partition="val1")
test_data = AudioDataset(data_path=default_root_path,
                            am_path = am_path,
                            gender = gender, phoneme_idx = phoneme_idx, am_idx = am_idx, MAX_LEN = MAX_LEN, partition="val1")
######################################################################################################################################

train_loader = torch.utils.data.DataLoader(train_data, num_workers=0,
                                               batch_size=batch_size, shuffle=True)

######################################################################################################################################
val_loader = torch.utils.data.DataLoader(val_data, num_workers=0,
                                               batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_data, num_workers=0,
                                               batch_size=batch_size)
######################################################################################################################################

print("Batch size: ", config['batch_size'])

print("Train dataset samples = {}, batches = {}".format(train_data.__len__(), len(train_loader)))
print("Validation dataset samples = {}, batches = {}".format(val_data.__len__(), len(val_loader)))
print("Test dataset samples = {}, batches = {}".format(test_data.__len__(), len(test_loader)))

Batch size:  32
Train dataset samples = 3067, batches = 96
Validation dataset samples = 1315, batches = 42
Test dataset samples = 1315, batches = 42


In [6]:
print("Batch size: ", batch_size)
print("Train dataset samples = {}, batches = {}".format(train_data.__len__(), len(train_loader)))

# for i, data in enumerate(train_loader):
#     phoneme, target_am = data
#     print(phoneme.shape, target_am.shape)
#     ##########################################
#     # print(phoneme.dtype, target_am.dtype)
#     ##########################################
#     # break

Batch size:  32
Train dataset samples = 3067, batches = 96


## Model

## Model 1: CNN

In [7]:
# class CNNNetwork(nn.Module):

#     def __init__(self):
#         super().__init__()
#         self.conv1=nn.Sequential(
#             nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=2),
#             nn.BatchNorm2d(num_features=16),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2)
#         )
#         self.conv2=nn.Sequential(
#             nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=1,padding=2),
#             nn.BatchNorm2d(num_features=32),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2)
#         )
#         self.conv3=nn.Sequential(
#             nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,stride=1,padding=2),
#             nn.BatchNorm2d(num_features=64),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2)
#         )
#         self.conv4=nn.Sequential(
#             nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,stride=1,padding=2),
#             nn.BatchNorm2d(num_features=128),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2)
#         )
#         self.flatten=nn.Flatten()
#         self.linear1=nn.Linear(in_features=128*15,out_features=512)
#         self.linear2=nn.Linear(in_features=512,out_features=128)
#         self.linear3=nn.Linear(in_features=128,out_features=1)
#         # self.linear4=nn.Linear(in_features=1024,out_features=256)
#         # self.linear5=nn.Linear(in_features=256,out_features=128)
#         # self.linear6=nn.Linear(in_features=128,out_features=1)
#         # self.output=nn.Sigmoid()
#         self.pooling = nn.AdaptiveAvgPool2d((1,1))
#         self.output = nn.Tanh()
    
#     def forward(self,input_data):
#         # add one dimension
#         # input_data.unsqueeze_(1)
#         x=self.conv1(input_data)
#         x=self.conv2(x)
#         x=self.conv3(x)
#         x=self.conv4(x)
        
#         # x = self.pooling(x)
#         # print("After conv: ", x.shape)
#         x=self.flatten(x)
#         # print("After flatten: ", x.shape)
#         x=self.linear1(x)
#         # print("After linear: ",x.shape)
#         x=self.linear2(x)
#         # x=self.linear3(x)
#         # x=self.linear4(x)
#         # x=self.linear5(x)
        
#         logits=self.linear3(x)
#         output=self.output(logits)
#         # print(output)
#         return output

In [8]:
# model = CNNNetwork().to(device)
# phoneme, AM = next(iter(train_loader))
# # # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

## Model 2: Resnet50

In [9]:
# model = models.resnet50(weights=None).to(device) # may be too weak
model = models.resnet152(weights=None).to(device) # may be too weak

num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 1).to(device)
# print(model.conv1)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).to(device)

In [10]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

## Model 3: DenseNet

In [11]:
model = models.densenet121(weights=None).to(device) # may be too weak

# num_features = model.fc.in_features
# model.fc = nn.Linear(num_features, 1).to(device)
# print(model)
model.features.conv0 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.classifier = nn.Linear(in_features=1024, out_features=1, bias=True)

# print(model)

In [12]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

## Model 4: EfficientNetV2

MAE: 0.69??

In [13]:
model = models.efficientnet_v2_s(weights=None).to(device) # may be too weak
# print(model)
model.features[0][0] = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.classifier = nn.Linear(in_features=1280, out_features=1, bias=True)
# print(model)

In [14]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
summary(model, phoneme.to(device))

Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 1280, 2, 1]          --
|    └─Conv2dNormActivation: 2-1              [-1, 24, 32, 16]          --
|    |    └─Conv2d: 3-1                       [-1, 24, 32, 16]          216
|    |    └─BatchNorm2d: 3-2                  [-1, 24, 32, 16]          48
|    |    └─SiLU: 3-3                         [-1, 24, 32, 16]          --
|    └─Sequential: 2-2                        [-1, 24, 32, 16]          --
|    |    └─FusedMBConv: 3-4                  [-1, 24, 32, 16]          5,232
|    |    └─FusedMBConv: 3-5                  [-1, 24, 32, 16]          5,232
|    └─Sequential: 2-3                        [-1, 48, 16, 8]           --
|    |    └─FusedMBConv: 3-6                  [-1, 48, 16, 8]           25,632
|    |    └─FusedMBConv: 3-7                  [-1, 48, 16, 8]           92,640
|    |    └─FusedMBConv: 3-8                  [-1, 48, 16, 8]           92,640
|

Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 1280, 2, 1]          --
|    └─Conv2dNormActivation: 2-1              [-1, 24, 32, 16]          --
|    |    └─Conv2d: 3-1                       [-1, 24, 32, 16]          216
|    |    └─BatchNorm2d: 3-2                  [-1, 24, 32, 16]          48
|    |    └─SiLU: 3-3                         [-1, 24, 32, 16]          --
|    └─Sequential: 2-2                        [-1, 24, 32, 16]          --
|    |    └─FusedMBConv: 3-4                  [-1, 24, 32, 16]          5,232
|    |    └─FusedMBConv: 3-5                  [-1, 24, 32, 16]          5,232
|    └─Sequential: 2-3                        [-1, 48, 16, 8]           --
|    |    └─FusedMBConv: 3-6                  [-1, 48, 16, 8]           25,632
|    |    └─FusedMBConv: 3-7                  [-1, 48, 16, 8]           92,640
|    |    └─FusedMBConv: 3-8                  [-1, 48, 16, 8]           92,640
|

## Model 5: MobileNetV3

##### MAE: 0.68??

In [15]:
model = models.mobilenet_v3_large(weights=None).to(device)
# print(model)
model.features[0][0] = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.classifier[3] = nn.Linear(in_features=1280, out_features=1, bias=True)
# print(model)

In [16]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

## Model 6: ShuffleNetV2

In [17]:
model = models.shufflenet_v2_x1_0(weights=None).to(device)
# print(model)
model.conv1[0] = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.fc = nn.Linear(in_features=1024, out_features=1, bias=True)
# print(model)

In [18]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

## Model 7: SqueezeNet

In [19]:
model = models.squeezenet1_1(weights=None).to(device)
# print(model)
model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))
model.classifier[1] = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
# print(model)

In [20]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

## Model 8: MnasNet

##### MAE=0.68 ok???

In [21]:
model = models.mnasnet1_0(weights=None).to(device)
# print(model)
model.layers[0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.classifier[1] = nn.Linear(in_features=1280, out_features=1, bias=True)
# print(model)

In [22]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

## Model 9: Wide ResNet

In [23]:
model = models.mnasnet1_0(weights=None).to(device)
# print(model)
model.layers[0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.classifier[1] = nn.Linear(in_features=1280, out_features=1, bias=True)
# print(model)


In [24]:
model = model.to(device)
phoneme, AM = next(iter(train_loader))
# # summary(model,(64, 259)) # After conv: torch.Size([2, 128, 5, 18])
# summary(model, phoneme.to(device))

# Train and eval

In [25]:
torch.cuda.empty_cache()
gc.collect()

352

In [26]:
criterion = torch.nn.MSELoss() #Defining Loss function 
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate']) #Defining Optimizer
# optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.0001, last_epoch=-1)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[35,40,45,50,60,65,70,90,110,150,170,180], gamma=0.5) # add learning rate scheduler

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(len(train_loader) * config['epochs']))

In [27]:
def train(model, optimizer, criterion, dataloader):

    model.train()
    train_loss = 0.0 #Monitoring Loss
    
    #########################################################
    # AM_true_list = []
    # AM_pred_list = []
    #########################################################
    
    for iter, (phoneme, AM) in enumerate(dataloader):
        scheduler.step()
        ### Move Data to Device (Ideally GPU)
        phoneme = phoneme.to(device)
        AM = AM.to(device)

        ### Forward Propagation
        preds_AM = model(phoneme)

        ### Loss Calculation
        # print(AM.shape)
        preds_AM = torch.squeeze(preds_AM)
        # print(preds_AM)
        # print(preds_AM.shape)model = models.shufflenet_v2_x1_0(weights=None).to(device)
        loss = criterion(preds_AM, AM)
        train_loss += loss.item()
        
        #########################################################
        ### Store Pred and True Labels
        # AM_pred_list.extend(preds_AM.cpu().tolist())
        # AM_true_list.extend(AM.cpu().tolist())
        #########################################################

        ### Initialize Gradients
        optimizer.zero_grad()

        ### Backward Propagation
        loss.backward()

        ### Gradient Descent
        optimizer.step()
        # if iter % 20 == 0:
        #     print("iter =", iter, "loss =",loss.item())
    train_loss /= len(dataloader)
    print("Learning rate = ", scheduler.get_last_lr()[0])
    print("Train loss = ", train_loss)
    
    #########################################################
    # print(AM_pred_list)
    # print(AM_true_list)
    # print(len(AM_pred_list))
    # print(len(AM_true_list))
    # accuracy = mean_squared_error(AM_pred_list, AM_true_list)
    # print("Train MSE accuracy: ", accuracy)
    #########################################################
    
    # scheduler.step() # add schedule learning rate
    return train_loss

In [28]:
def eval(model, dataloader):

    model.eval() # set model in evaluation mode

    AM_true_list = []
    AM_pred_list = []

    for i, data in enumerate(dataloader):

        phoneme, AM = data
        ### Move data to device (ideally GPU)
        phoneme, AM = phoneme.to(device), AM.to(device) 

        with torch.inference_mode(): # makes sure that there are no gradients computed as we are not training the model now
            ### Forward Propagation
            ### Get Predictions
            predicted_AM = model(phoneme)
            # print(predicted_AM)
        
        ### Store Pred and True Labels
        AM_pred_list.extend(predicted_AM.cpu().tolist())
        AM_true_list.extend(AM.cpu().tolist())
        
        # Do you think we need loss.backward() and optimizer.step() here?
    
        del phoneme, AM, predicted_AM
        torch.cuda.empty_cache()

    ###############################################################################################
    # print(AM_pred_list[1000:3100])
    # print(AM_true_list)
    # print(len(AM_pred_list))
    # print(len(AM_true_list))
    ###############################################################################################
    
    # print("Number of equals between two list: ", sum(a == b for a,b in zip(AM_pred_list, AM_true_list)))
    
    ### Calculate Accuracy
    MSE = mean_squared_error(AM_pred_list, AM_true_list)
    r2_score_acc = r2_score(AM_pred_list, AM_true_list)
    MAE = mean_absolute_error(AM_pred_list, AM_true_list)
    print("Validation r2_score: ", r2_score_acc)
    print("Validation MAE: ", MAE)
    
    return MSE

# Experiment

In [None]:
# Iterate over number of epochs to train and evaluate your model
torch.cuda.empty_cache()

best_mse = 1.0 ### Monitor best accuracy in your run

for epoch in range(config['epochs']):
    print("\nEpoch {}/{}".format(epoch+1, config['epochs']))

    train_loss = train(model, optimizer, criterion, train_loader)
    MSE = eval(model, val_loader)

    print("\tTrain Loss: ", train_loss)
    print("\tValidation MSE: ", MSE)

    ### Save checkpoint if accuracy is better than your current best
    if MSE < best_mse:
        best_mse = MSE
    ### Save checkpoint with information you want
        torch.save({'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': train_loss,
              'learning rate': scheduler.get_last_lr()[0],
              'mse': MSE}, 
        './model_checkpoint.pth')


Epoch 1/200




Learning rate =  0.0009999383162408297
Train loss =  0.9854118501146635
Validation r2_score:  -16552928188587.42
Validation MAE:  0.7245533061875986
	Train Loss:  0.9854118501146635
	Validation MSE:  0.8776947049979213

Epoch 2/200
Learning rate =  0.000999753280182864
Train loss =  0.837199755012989
Validation r2_score:  -12157393210489.13
Validation MAE:  0.724489326079592
	Train Loss:  0.837199755012989
	Validation MSE:  0.8775509992818765

Epoch 3/200
Learning rate =  0.0009994449374809815
Train loss =  0.795664175413549
Validation r2_score:  -5230404930609.74
Validation MAE:  0.7189018611318145
	Train Loss:  0.795664175413549
	Validation MSE:  0.8653446138540797

Epoch 4/200
Learning rate =  0.0009990133642141313
Train loss =  0.7072865863641103
Validation r2_score:  -8497947508424.187
Validation MAE:  0.7221355557172584
	Train Loss:  0.7072865863641103
	Validation MSE:  0.8723710063550637

Epoch 5/200
Learning rate =  0.0009984586668665594
Train loss =  0.625629223883152
Validati

# Test

In [None]:
def test(model, test_loader):
  ### What you call for model to perform inference?
    model.eval()

  ### List to store predicted phonemes of test data
    test_predictions = []
    ground_truth = []

  ### Which mode do you need to avoid gradients?
    with torch.inference_mode():

        for i, data in enumerate(tqdm(test_loader)):

            phoneme, groundtruth_AM = data
            ### Move data to device (ideally GPU)
            phoneme, groundtruth_AM = phoneme.to(device), groundtruth_AM.to(device)         
          
            predicted_AM = model(phoneme)
            predicted_AM.squeeze_()
            # print(predicted_AM.shape)
            # print(groundtruth_AM.shape)

          ### How do you store predicted_phonemes with test_predictions? Hint, look at eval 
            test_predictions.extend(predicted_AM.cpu().tolist())
            ground_truth.extend(groundtruth_AM.cpu())
    
    # print(len(test_predictions))
    return test_predictions, ground_truth

In [None]:
predictions, ground_truth = test(model, test_loader)

In [None]:
### Create CSV file with predictions
with open("./phoneme%s"%phoneme_idx +  "_AM%s.csv"%am_idx, "w+") as f:
    f.write("person, label, prediction\n")
    for i in range(len(predictions)):
        f.write("{},{},{}\n".format(i, ground_truth[i], predictions[i]))

## 