In [None]:
"""
Goal: build an autoencoder for audio files
    1. Get the network to train
    2. Show low reconstruction error on the STFTs (and inspect them visually)
    3. Apply the iSTFT and compare the audio by ear

Lower the input dimension of the audio
Use a low number of channels, and pooling/stride in the network

Build out a simple network, train it on Colab, then slowly make it more complex

"""

from sklearn.preprocessing import normalize
import torch
import torch.nn as nn
import pickle
import scipy.io.wavfile as wavfile
from scipy import signal
import torch.optim as optim
from torchvision.transforms import ToTensor
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import time
from google.colab import drive
# Mount Google Drive; DATA_FILE below points at a path under this mount.
drive.mount('/content/drive')

# --- Reporting ---
TEST_LOSS_EPOCH_FREQ = 4  # evaluate on the test loader every this many epochs
NUM_SPECTROGRAM_VISUALIZATION_PLOTS = 3
RUN = 2

# --- Model / training ---
NUM_EPOCHS = 5
BOTTLENECK_SIZE = 1000

# --- Dataset selection: these values are encoded in the pickle's file name ---
MODELNAME = "SecondTry"
ROOT_FILENAME = 'Dataloader'
SEED = 1
BATCH_SIZE = 20
DATA_PER_FILE = 50

# Path of the pickled dataloaders; `pickle_file` is kept as an alias of the
# same string.
DATA_FILE = pickle_file = (
    f'/content/drive/MyDrive/Dataloaders/{ROOT_FILENAME}'
    f'_BS{BATCH_SIZE}_DPF{DATA_PER_FILE}_S{SEED}.pkl'
)

class CustomDataset(Dataset):
    """Map-style dataset that pairs ``X[index]`` with ``Y[index]``.

    NOTE(review): presumably this class must stay defined here so that the
    pickled dataloaders can be loaded -- confirm against the script that
    wrote the pickle.
    """

    def __init__(self, X, Y):
        """Store the two indexable collections without copying them."""
        self.X, self.Y = X, Y

    def __len__(self):
        """Return the number of items, taken from ``X`` only."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``."""
        return self.X[index], self.Y[index]
    
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load dataloader pickles that you created yourself.
with open(DATA_FILE, mode='rb') as file:
    loaders = pickle.load(file)

# The pickle must unpack into exactly two objects: train loader, then test loader.
train_loader, test_loader = loaders


class AutoEncoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder applies two 3x3 'same'-padded convolutions (16 then 32
    channels), flattens the result and projects it to ``bottleneck_size``
    features. The decoder mirrors this: a linear layer back to the
    convolutional feature map, two 3x3 transposed convolutions, and a final
    sigmoid so that every output value lies in (0, 1).

    Args:
        input_shape: 3-tuple ``(height, width, channels)`` describing one
            sample, e.g. ``(260, 90, 1)``. The network itself consumes
            tensors laid out as ``(batch, channels, height, width)``.
        bottleneck_size: number of features in the 1-D code between the
            encoder and the decoder.
    """

    def __init__(self, input_shape, bottleneck_size):
        super(AutoEncoder, self).__init__()
        self.input_shape = input_shape
        height, width, in_channels = input_shape
        max_channels = 32
        # Stride-1 'same' convolutions keep the spatial size, so the feature
        # map that gets flattened is max_channels x height x width.
        conv_features = max_channels * height * width

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, stride=(1, 1), kernel_size=(3, 3), padding='same'),
            nn.LeakyReLU(0.01),
            nn.Conv2d(16, max_channels, stride=(1, 1), kernel_size=(3, 3), padding='same'),
            nn.LeakyReLU(0.01),
            nn.Flatten(),
            nn.Linear(conv_features, bottleneck_size),
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_size, conv_features),
            nn.Unflatten(1, (max_channels, height, width)),
            nn.LeakyReLU(0.01),
            # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
            nn.ConvTranspose2d(max_channels, 16, stride=(1, 1), kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(16, in_channels, stride=(1, 1), kernel_size=(3, 3), padding=1),
            # No LeakyReLU in front of the sigmoid: it would scale negative
            # pre-activations by 0.01 and effectively floor the output near
            # 0.5. Dropping it does not shift any parameter index, so
            # existing state-dict keys still match.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` to the bottleneck code and decode it back.

        Args:
            x: tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Reconstruction with the same shape as ``x`` and values in (0, 1).
        """
        return self.decoder(self.encoder(x))


def _to_model_input(batch, device):
    """Reshape a loader batch to (N, C, 260, 90) and move it to ``device``.

    The leading dimension comes from the batch itself rather than from
    BATCH_SIZE, so a shorter final batch is reshaped correctly as well.
    """
    return batch.view(batch.shape[0], -1, 260, 90).to(device)


def _mean_reconstruction_loss(model, loader, loss_fn, device):
    """Return the per-batch reconstruction loss of ``model`` averaged over ``loader``.

    The model is put in eval mode for the pass (no gradients are recorded)
    and switched back to train mode before returning.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch, _ in loader:
            inputs = _to_model_input(batch, device)
            running_loss += loss_fn(model(inputs), inputs).item()
    model.train()
    return running_loss / len(loader)


# Initialize the Autoencoder
autoencoder = AutoEncoder((260, 90, 1), BOTTLENECK_SIZE)

# Move the model to its device before building the optimizer, as the
# torch.optim documentation recommends.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is ', device)
autoencoder.to(device)

# Define the loss function
reconstruction_loss_fn = nn.MSELoss()

# Define the optimizer
optimizer = optim.Adam(autoencoder.parameters(), lr=0.01)


# Training loop
num_epochs = NUM_EPOCHS

print("hello")
start = time.time()

for epoch in range(num_epochs):
    total_loss = 0.0

    # Iterate over the training dataset; labels are ignored because the
    # reconstruction target is the input itself.
    for X_batch, _ in tqdm(train_loader):
        X_batch = _to_model_input(X_batch, device)

        # Forward pass
        X_pred = autoencoder(X_batch)

        # Compute the loss
        loss = reconstruction_loss_fn(X_pred, X_batch)
        total_loss += loss.item()

        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Print average loss for the epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss}")

    # Calculate and print test loss every TEST_LOSS_EPOCH_FREQ epochs
    if (epoch+1) % TEST_LOSS_EPOCH_FREQ == 0:
        test_loss = _mean_reconstruction_loss(
            autoencoder, test_loader, reconstruction_loss_fn, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss}")

    # Checkpoint before the time-limit check so that the weights of the last
    # completed epoch are saved even when training stops early.
    weight_file_name = f'model_{epoch}_{len(train_loader)}'
    torch.save(autoencoder.state_dict(), weight_file_name)

    # Stop once more than an hour of wall-clock time has been spent.
    end = time.time()
    dur = end - start
    if dur > 3600:
        # `epoch` is zero-based, so epoch + 1 epochs have completed here.
        print(f"Elapsed {epoch + 1} number of epochs.")
        num_epochs = epoch + 1
        break


# Calculate final test loss
test_loss = _mean_reconstruction_loss(
    autoencoder, test_loader, reconstruction_loss_fn, device)
print(f"Final Test Loss: {test_loss}")


In [None]:
import matplotlib.pyplot as plt
import os
import numpy as np 

# Save side-by-side "original vs. reconstructed" images for the first
# NUM_SPECTROGRAM_VISUALIZATION_PLOTS test samples into ./autoencoder/.
os.makedirs('autoencoder', exist_ok=True)
NUM_SPECTROGRAM_VISUALIZATION_PLOTS = 10
autoencoder.eval()
i = 0  # index of the next figure to write
with torch.no_grad():
    for X_test_batch, _ in test_loader:
        # Stop reading the loader as soon as enough figures exist.
        if i >= NUM_SPECTROGRAM_VISUALIZATION_PLOTS:
            break

        # Run the model once per batch; the leading dimension is taken from
        # the batch itself so a shorter final batch reshapes correctly.
        X_test_batch = X_test_batch.view(X_test_batch.shape[0], -1, 260, 90).to(device)
        X_pred_test = autoencoder(X_test_batch)

        plots_left = NUM_SPECTROGRAM_VISUALIZATION_PLOTS - i
        for k in range(min(X_test_batch.shape[0], plots_left)):
            print('The values for i = ', i, '  ', X_pred_test[k])

            # Select the k-th example from the batch and drop
            # single-dimensional entries so imshow receives a 2-D array.
            original = np.squeeze(X_test_batch[k].cpu().numpy())
            reconstructed = np.squeeze(X_pred_test[k].cpu().numpy())

            fig, axs = plt.subplots(1, 2, figsize=(10, 5))

            axs[0].imshow(original, aspect='auto', cmap='jet')
            axs[0].set_title('Original')

            axs[1].imshow(reconstructed, aspect='auto', cmap='jet')
            axs[1].set_title('Reconstructed')

            datafile = (
                f'autoencoder/SET_{MODELNAME}_BS{BATCH_SIZE}'
                f'_DPF{DATA_PER_FILE}_S{SEED}example_{i}.png'
            )
            plt.savefig(datafile)
            plt.close(fig)
            i += 1

autoencoder.train()