In [None]:
from __future__ import print_function, division
import os
import torch
import numpy as np
from barbar import Bar
from torch.utils.data import Dataset, DataLoader

import torch
from torch.autograd import Variable

In [None]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Define Dataset Class

class TrivialDataset(Dataset):
    """The S_trivial set: only one bird vocalizing, no radio noise.

    Each item pairs the entry of ``all_facc`` with the entry of ``all_mic``
    stored at the same position.
    """

    def __init__(self, all_facc, all_mic):
        """
        Args:
            all_facc (sequence): Samples returned under the ``'facc'`` key.
            all_mic (sequence): Samples returned under the ``'mic'`` key,
                aligned one-to-one with ``all_facc``.

        Raises:
            ValueError: If the two collections differ in length.
        """
        # __len__ reports len(all_facc) only, so a mismatch would otherwise
        # surface later as an out-of-range lookup on one of the collections.
        if len(all_facc) != len(all_mic):
            raise ValueError(
                "all_facc and all_mic must have the same length, got {} and {}".format(
                    len(all_facc), len(all_mic)
                )
            )
        self.all_facc = all_facc
        self.all_mic = all_mic

    def __len__(self):
        """Return the number of paired samples."""
        return len(self.all_facc)

    def __getitem__(self, idx):
        """Return ``{'facc': ..., 'mic': ...}`` for the sample(s) at ``idx``.

        A tensor index is converted to plain Python values before it is used
        to index the underlying collections.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return {'facc': self.all_facc[idx], 'mic': self.all_mic[idx]}

In [None]:
# Define linear regression model, take in input/output size at initialization
class LinearRegression(torch.nn.Module):
    """A single affine layer mapping ``inputSize`` features to ``outputSize`` outputs."""

    def __init__(self, inputSize, outputSize):
        super(LinearRegression, self).__init__()
        # The attribute name `linear` determines the parameter names in the state dict
        self.linear = torch.nn.Linear(in_features=inputSize, out_features=outputSize)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

In [None]:
# --- Data split ---
DATA_SIZE = 333   # number of sample files to load
TRAIN_SIZE = 300  # samples assigned to training; the remainder is used for validation

# --- Optimisation ---
max_epochs = 50
LR = 1e-4  # learning rate, can be changed
BS = 16    # batch size, can be changed

# --- Data dimensions ---
DIMF = 513  # frequency magnitudes per timestep
DIMT = 372  # timesteps per sample

# --- Context window ---
# Timesteps after / before the current one that are fed to the model
# together with it in order to fit a single output timestep.
n_forward = 1
n_backward = 1

# Input: magnitudes of every timestep in the window (acc_female channel)
input_dim = (n_backward + 1 + n_forward) * DIMF
# Output: magnitudes of one timestep (microphone channel)
output_dim = DIMF

In [None]:
# Linear model, mean-squared-error objective and Adam optimizer
model = LinearRegression(input_dim, output_dim)
# For GPU only
# if torch.cuda.is_available():
#     model = LinearRegression(input_dim, output_dim).cuda(0)
criterion = torch.nn.MSELoss()
#optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Take the learning rate from the LR constant so that changing it in the
# parameters cell actually changes the optimizer (it was hard-coded here before).
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)

In [None]:
# Directory that the sample paths are resolved against.
# Note: '__file__' is a quoted string literal here, so it is resolved relative to
# the current working directory; the parent of that path is what gets printed.
fileDir = os.path.split(os.path.realpath('__file__'))[0]
print(fileDir)

In [None]:
# Load the separate sample files and combine them all together.
# The files are compressed numpy archives; each one, when loaded, behaves like a dictionary.
# Keys in the dictionary: 'mic' & 'female_acc' are the names of the data channels.

X = []
Y = []

for i in range(DATA_SIZE):
    filename = os.path.join(fileDir, 'datat/trivial_sample_' + str(i) +'.npz')
    # np.load keeps the archive's file handle open until it is closed; the
    # context manager releases it as soon as both channels have been read,
    # instead of leaking one open handle per sample file.
    with np.load(filename) as specs:
        micr = specs['mic']
        facc = specs['female_acc']
    X.append(facc)
    Y.append(micr)

# X holds the 'female_acc' channel, Y the matching 'mic' channel, in file order
X = np.array(X)
Y = np.array(Y)

In [None]:
# Both partitions must be non-empty: the training loop averages the losses of each.
assert 0 < TRAIN_SIZE < DATA_SIZE, "TRAIN_SIZE must leave at least one validation sample"

# Shuffle the sample indices with a dedicated, seeded generator so that the
# train/validation partition is the same on every run of the notebook.
SPLIT_SEED = 0
indices = np.random.RandomState(SPLIT_SEED).permutation(DATA_SIZE)
train_indices = indices[:TRAIN_SIZE]
valid_indices = indices[TRAIN_SIZE:]

# Training batches are reshuffled by the loader; validation is read one sample at a time in order.
trainset = TrivialDataset(X[train_indices], Y[train_indices])
trainloader = DataLoader(trainset, batch_size=BS, shuffle=True)
validset = TrivialDataset(X[valid_indices], Y[valid_indices])
validloader = DataLoader(validset, batch_size=1, shuffle=False)

# Lookup tables keyed by phase name
datasets, dataloaders = {'train': trainset, 'valid':validset}, {'train': trainloader, 'valid':validloader}

In [None]:
# import early stopping
from trainutils import EarlyStopping

In [None]:
def _stack_context_frames(batch, timestep, n_back, n_fwd, n_steps):
    """Concatenate the frames around ``timestep`` along the feature axis.

    Args:
        batch: Tensor indexed as (batch, feature, time).
        timestep: Index of the centre frame.
        n_back: Number of preceding frames placed before the centre frame.
        n_fwd: Number of following frames placed after the centre frame.
        n_steps: Length of the time axis used for wrap-around indexing.

    Returns:
        Tensor of shape (batch, feature * (n_back + n_fwd + 1)), ordered from
        the earliest frame to the latest. Neighbours that fall outside
        ``[0, n_steps)`` wrap around to the other end of the time axis.
    """
    frames = [batch[:, :, (timestep + offset) % n_steps]
              for offset in range(-n_back, n_fwd + 1)]
    return torch.cat(frames, dim=1)


train_loss = []
valid_loss = []
phases = ['train', 'valid']

# initialize early stopping
early_stopping = EarlyStopping(patience=5, verbose=True)

for e in range(max_epochs):
    epoch_loss_train = []
    epoch_loss_valid = []

    for phase in phases:
        is_training = phase == 'train'
        if is_training:
            model.train()
        else:
            model.eval()

        # Build the autograd graph only while training; validation just needs the loss value.
        with torch.set_grad_enabled(is_training):
            for i_batch, sample_batched in enumerate(Bar(dataloaders[phase])):
                # data is in shape (batchsize, freq_magnitude_size, timesteps)
                batched_acc = sample_batched['facc']
                batched_mic = sample_batched['mic']

                for timestep in range(DIMT):
                    # Input: the current frame plus n_backward / n_forward neighbours,
                    # so its width always agrees with input_dim for any window size.
                    series = _stack_context_frames(
                        batched_acc, timestep, n_backward, n_forward, DIMT
                    ).float()
                    # Target: the microphone frame at the current timestep
                    target = batched_mic[:, :, timestep].float()

                    outputs = model(series)
                    loss = criterion(outputs, target)

                    if is_training:
                        # Clear old gradients first so each step uses only this loss
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        epoch_loss_train.append(loss.detach())
                    else:
                        epoch_loss_valid.append(loss.detach())

    # print progress metric (mean of the per-timestep losses of this epoch)
    eploss_train = torch.mean(torch.stack(epoch_loss_train))
    train_loss.append(eploss_train)
    eploss_valid = torch.mean(torch.stack(epoch_loss_valid))
    valid_loss.append(eploss_valid)

    progress_str = '[epoch {}/{}] - Train Loss: {:.4f} Valid Loss: {:.4f}'.format(e + 1, max_epochs, eploss_train, eploss_valid)
    print(progress_str)

    # early_stopping needs the validation loss to check if it has decreased,
    # and if it has, it will make a checkpoint of the current model
    early_stopping(eploss_valid, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# load the last checkpoint with the best model
model.load_state_dict(torch.load('checkpoint.pt'))