# In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; the Results/ paths used below live under it.
drive.mount(mountpoint="/content/gdrive")

import time
import os
import glob
import random
import json
import subprocess
import sys
import gc

import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import xml.etree.ElementTree

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

def set_deterministic():
    """Ask PyTorch (and cuDNN, when CUDA is available) to use deterministic algorithms.

    ``torch.set_deterministic`` was deprecated in favour of
    ``torch.use_deterministic_algorithms`` and is not present in current PyTorch
    releases, so the newer API is used when available and the old one is only a
    fallback for very old installs.

    Note: on CUDA >= 10.2, deterministic cuBLAS additionally requires the
    ``CUBLAS_WORKSPACE_CONFIG`` environment variable (see the PyTorch
    reproducibility notes); this function does not set it.
    """
    if torch.cuda.is_available():
        # Autotuning picks kernels non-deterministically; disable it.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    if hasattr(torch, "use_deterministic_algorithms"):
        torch.use_deterministic_algorithms(True)
    else:  # legacy PyTorch (< 1.8)
        torch.set_deterministic(True)
    
    
def set_all_seeds(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and all CUDA
    devices, and records the seed in the ``PL_GLOBAL_SEED`` environment variable.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def compute_accuracy(model_diff, model_conv, data_loader_diff, data_loader_conv, data_loader_base, device):
    """Score the summed reconstruction against the targets over *all* batches.

    The reconstruction is ``model_diff(x_diff) + model_conv(x_conv)`` and the score is
    ``1 - sum((y - recon)**2) / sum(y**2)``, where both sums run over every element of
    every batch. The three loaders are zipped, so iteration stops at the shortest one.

    Args:
        model_diff, model_conv: modules whose outputs are added; both are put in eval mode.
        data_loader_diff, data_loader_conv: iterables of ``(features, _)`` batches.
        data_loader_base: iterable of ``(_, targets)`` batches.
        device: device the feature tensors are moved to before the forward pass.

    Returns:
        NumPy float rounded to 5 decimals. If every target is zero the division yields
        inf/nan (with a NumPy RuntimeWarning).

    Raises:
        ValueError: if the loaders yield no batches.
    """
    model_diff.eval()
    model_conv.eval()
    # Accumulate in float64 so large sums over many batches stay accurate.
    total_err = np.float64(0.0)
    total_energy = np.float64(0.0)
    num_batches = 0
    with torch.no_grad():
        for (features_diff, _), (features_conv, _), (_, targets_base) in zip(
                data_loader_diff, data_loader_conv, data_loader_base):
            recon = model_diff(features_diff.to(device)) + model_conv(features_conv.to(device))

            targets_np = targets_base.cpu().numpy().astype(np.float64)
            recon_np = recon.cpu().numpy().astype(np.float64)

            total_err += np.sum(np.square(targets_np - recon_np))
            total_energy += np.sum(np.square(targets_np))
            num_batches += 1

    if num_batches == 0:
        raise ValueError("compute_accuracy: the data loaders yielded no batches")

    return np.round(1.0 - total_err / total_energy, 5)
    

def compute_epoch_loss_autoencoder(model_diff, model_conv, data_loader_diff, data_loader_conv, data_loader_base, loss_fn, device):
    """Return the total loss of the summed reconstruction divided by the number of examples.

    Both models are put in eval mode. For every zipped batch, the prediction
    ``model_diff(x_diff) + model_conv(x_conv)`` is compared with the targets from
    ``data_loader_base`` using ``loss_fn(..., reduction='sum')``. The summed loss is
    divided by the total of ``targets.size(0)`` over all batches.

    Args:
        model_diff, model_conv: modules whose outputs are added.
        data_loader_diff, data_loader_conv: iterables of ``(features, _)`` batches.
        data_loader_base: iterable of ``(_, targets)`` batches.
        loss_fn: callable accepting ``(pred, target, reduction='sum')``, e.g. ``F.mse_loss``.
        device: device the batch tensors are moved to.

    Returns:
        0-dim tensor with the per-example loss (callers use ``.item()``).

    Raises:
        ValueError: if the loaders yield no examples.
    """
    model_diff.eval()
    model_conv.eval()
    curr_loss, num_examples = 0., 0
    with torch.no_grad():
        for (features_diff, _), (features_conv, _), (_, targets_base) in zip(
                data_loader_diff, data_loader_conv, data_loader_base):
            features_diff = features_diff.to(device)
            features_conv = features_conv.to(device)
            targets_base = targets_base.to(device)

            predictions = model_diff(features_diff) + model_conv(features_conv)
            # Sum (not mean) per batch so batches of different sizes weigh correctly.
            curr_loss += loss_fn(predictions, targets_base, reduction='sum')
            num_examples += targets_base.size(0)

            # Drop device references before the next batch is transferred.
            del features_diff, features_conv, targets_base, predictions

    if num_examples == 0:
        raise ValueError("compute_epoch_loss_autoencoder: the data loaders yielded no examples")
    return curr_loss / num_examples


def train_autoencoder_v1(num_epochs, model_diff, model_conv, optimizer_diff, optimizer_conv, 
                         train_loader_diff, train_loader_conv, train_loader_base, device, loss_fn=None, 
                         skip_epoch_stats=False,
                         save_model=None):
    """Jointly train ``model_diff`` and ``model_conv`` so their summed output fits the target.

    Each step takes one batch from every (zipped) loader. Features come from
    ``train_loader_diff`` / ``train_loader_conv`` and targets from ``train_loader_base``.
    The loss is ``loss_fn(model_diff(x_diff) + model_conv(x_conv), y)``, and both
    optimizers step on the resulting gradients.

    Args:
        num_epochs: number of passes over the zipped loaders.
        model_diff, model_conv: modules whose outputs are added.
        optimizer_diff, optimizer_conv: optimizers for the respective models.
        train_loader_diff, train_loader_conv, train_loader_base: iterables of
            ``(features, targets)`` pairs. ``len(train_loader_diff)`` is used for logging.
        device: device the batch tensors are moved to.
        loss_fn: callable ``(pred, target) -> scalar tensor``; defaults to ``F.mse_loss``.
            When epoch stats are enabled it must also accept ``reduction='sum'``.
        skip_epoch_stats: if False, recompute the full training loss after every epoch.
        save_model: optional path. If given, ``{'model_diff': state_dict,
            'model_conv': state_dict}`` is written there with ``torch.save`` after training.

    Returns:
        dict with ``'train_loss_per_batch'`` and ``'train_loss_per_epoch'`` lists of floats.
    """
    log_dict = {'train_loss_per_batch': [],
                'train_loss_per_epoch': []}

    if loss_fn is None:
        loss_fn = F.mse_loss

    start_time = time.time()
    for epoch in range(num_epochs):

        model_diff.train()
        model_conv.train()

        batches = zip(train_loader_diff, train_loader_conv, train_loader_base)
        for batch_idx, ((features_diff, _), (features_conv, _), (_, targets_base)) in enumerate(batches, start=1):

            features_diff = features_diff.to(device)
            features_conv = features_conv.to(device)
            targets_base = targets_base.to(device)

            # FORWARD AND BACK PROP
            predictions = model_diff(features_diff) + model_conv(features_conv)
            loss = loss_fn(predictions, targets_base)
            optimizer_diff.zero_grad()
            optimizer_conv.zero_grad()
            loss.backward()

            # UPDATE MODEL PARAMETERS
            optimizer_diff.step()
            optimizer_conv.step()

            # Release device tensors before the next batch is transferred.
            del features_diff, features_conv, targets_base, predictions

            # LOGGING
            batch_loss = loss.item()
            log_dict['train_loss_per_batch'].append(batch_loss)
            print('Epoch: %03d/%03d | Batch %04d/%04d | Loss: %.4f'
                  % (epoch + 1, num_epochs, batch_idx,
                     len(train_loader_diff), batch_loss))

        if not skip_epoch_stats:
            model_diff.eval()
            model_conv.eval()

            with torch.set_grad_enabled(False):  # save memory during inference
                train_loss = compute_epoch_loss_autoencoder(
                    model_diff, model_conv, train_loader_diff, train_loader_conv, train_loader_base, loss_fn, device)
                print('***Epoch: %03d/%03d | Loss: %.3f' % (
                      epoch + 1, num_epochs, train_loss))
                log_dict['train_loss_per_epoch'].append(train_loss.item())

        print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60))

    print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))
    if save_model is not None:
        # Two models are trained together, so persist both state dicts in one file.
        torch.save({'model_diff': model_diff.state_dict(),
                    'model_conv': model_conv.state_dict()}, save_model)

    return log_dict



def plot_training_loss(minibatch_losses, num_epochs, averaging_iterations=100, custom_label=''):
    """Plot per-minibatch losses in a new figure, overlaid with their running average.

    Args:
        minibatch_losses: sequence of per-batch loss values.
        num_epochs: accepted for interface compatibility; not used by the plot.
        averaging_iterations: running-average window. It is clamped to the number of
            losses, because ``np.convolve(..., mode='valid')`` with a window longer than
            the data silently swaps the roles of its inputs.
        custom_label: suffix appended to both legend labels.
    """
    losses = np.asarray(minibatch_losses, dtype=float)

    plt.figure()
    ax1 = plt.subplot(1, 1, 1)
    ax1.plot(range(len(losses)), losses, label=f'Minibatch Loss{custom_label}')
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')

    window = min(averaging_iterations, len(losses))
    if window > 0:  # np.convolve raises on empty input
        ax1.plot(np.convolve(losses, np.ones(window) / window, mode='valid'),
                 label=f'Running Average{custom_label}')
    ax1.legend()

def plot_accuracy(train_acc, valid_acc):
    """Draw training and validation accuracy against 1-based epoch numbers on the current axes.

    The number of epochs is taken from ``len(train_acc)``.
    """
    epochs = np.arange(1, len(train_acc) + 1)
    for series, name in ((train_acc, 'Training'), (valid_acc, 'Validation')):
        plt.plot(epochs, series, label=name)

    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()



def scientific_to_string(Ra):
    """Render ``Ra`` in 3-digit scientific notation suitable for directory names.

    The decimal point becomes ``'p'`` and any ``'+'`` is dropped,
    e.g. ``1e7 -> '1p000e07'`` and ``2.5e-3 -> '2p500e-03'``.
    """
    return '{:.3e}'.format(Ra).translate(str.maketrans({'.': 'p', '+': None}))

def create_result_dirs(Rai, Raf, ti, tf, root_path='/content/gdrive/My Drive/Project/Results/'):
    """Create the results directory tree for one (Rai, Raf, ti, tf) run.

    The run directory is ``<root_path>/<Rai_s>_<Raf_s>_<ti>_<tf>``, where ``Rai_s`` and
    ``Raf_s`` come from ``scientific_to_string``. Inside it are:

    - ``snapshots/``, ``segmentations/``, ``masks/diffusion/`` and ``masks/convection/``,
      each holding ``<ti>_0/`` and ``<tf>_1/`` with ``arrays/`` and ``images/`` children;
    - ``models/`` with ``diffusion``, ``convection``, ``base`` and ``diff_conv_casc``.

    Args:
        Rai, Raf: numbers formatted with ``scientific_to_string``.
        ti, tf: stage identifiers (converted with ``str``).
        root_path: parent directory; defaults to the original Drive location.

    Returns:
        ``(Rai_s, Raf_s)``, the formatted strings used in the directory name.

    Raises:
        FileExistsError: if the run directory (or any sub-directory) already exists.
    """
    Rai_s = scientific_to_string(Rai)
    Raf_s = scientific_to_string(Raf)
    run_dir = os.path.join(root_path, Rai_s + '_' + Raf_s + '_' + str(ti) + '_' + str(tf))

    def stage_dirs(parent):
        # <parent>/<ti>_0 and <parent>/<tf>_1, each followed by its arrays/ and images/
        dirs = []
        for stage in (str(ti) + '_0', str(tf) + '_1'):
            stage_dir = parent + '/' + stage
            dirs += [stage_dir, stage_dir + '/arrays', stage_dir + '/images']
        return dirs

    folders = (['snapshots'] + stage_dirs('snapshots')
               + ['segmentations'] + stage_dirs('segmentations')
               + ['masks', 'masks/diffusion'] + stage_dirs('masks/diffusion')
               + ['masks/convection'] + stage_dirs('masks/convection')
               + ['models', 'models/diffusion', 'models/convection', 'models/base',
                  'models/diff_conv_casc'])

    # Parents always precede children in `folders`, so plain os.mkdir is enough.
    # It deliberately fails instead of silently reusing an existing run directory.
    os.mkdir(run_dir)
    for folder in folders:
        os.mkdir(os.path.join(run_dir, folder))

    return Rai_s, Raf_s


def load_data(Rai, Raf, ti, tf):
    """Load the diffusion-mask, convection-mask and target-snapshot arrays of a run.

    File discovery is delegated to ``print_files``, which uses the same directory
    layout and index rules, so the two functions cannot drift apart. Every listed
    ``.npy`` file is then loaded with ``np.load``.

    Args:
        Rai, Raf: formatted Rayleigh-number strings used in the run directory name.
        ti: stage folder holding the input masks.
        tf: stage folder holding the target snapshots.

    Returns:
        ``(data_x_diff, data_x_conv, data_y)``: lists of loaded arrays, in file-list order.
    """
    file_list_x_diff, file_list_x_conv, file_list_y = print_files(Rai, Raf, ti, tf)

    # NOTE(review): allow_pickle=True can execute code embedded in a crafted .npy
    # file; only point this at trusted data.
    data_x_diff = [np.load(file_path, allow_pickle=True) for file_path in file_list_x_diff]
    data_x_conv = [np.load(file_path, allow_pickle=True) for file_path in file_list_x_conv]
    data_y = [np.load(file_path, allow_pickle=True) for file_path in file_list_y]

    return data_x_diff, data_x_conv, data_y


def print_files(Rai, Raf, ti, tf, root_path='/content/gdrive/My Drive/Project/Results/', max_index=2140):
    """Return the ``.npy`` file lists for diffusion masks, convection masks and targets.

    Despite its name, nothing is printed. Under ``<root_path>/<Rai>_<Raf>_13_1_19``:

    - the number of entries in ``snapshots/<ti>/arrays`` sets how many indices are
      scanned (capped at ``max_index + 1``);
    - for each index ``i``, the files ``<i>_*.npy`` are collected from
      ``masks/diffusion/<ti>/arrays``, ``masks/convection/<ti>/arrays`` and
      ``snapshots/<tf>/arrays``.

    Indices with zero or several matches are handled, which the old ``np.ravel``
    flattening could not do. Matches within one index are sorted for determinism.

    Args:
        Rai, Raf: strings used in the run directory name.
        ti: stage folder of the inputs; tf: stage folder of the targets.
        root_path: parent results directory (defaults to the original Drive location).
        max_index: largest file index considered (previously hard-coded as 2140).

    Returns:
        ``(file_list_x_diff, file_list_x_conv, file_list_y)`` as NumPy arrays of paths.

    Raises:
        FileNotFoundError: if ``snapshots/<ti>/arrays`` does not exist.
    """
    base = os.path.join(root_path, Rai + '_' + Raf + '_13_1_19')
    count_dir = os.path.join(base, 'snapshots', str(ti), 'arrays')
    if not os.path.isdir(count_dir):
        raise FileNotFoundError(count_dir)

    _, _, files = next(os.walk(count_dir))
    num_indices = min(len(files), max_index + 1)

    def matches(subdir, stage):
        # Flatten per-index glob results in index order.
        array_dir = os.path.join(base, subdir, str(stage), 'arrays')
        return [path
                for i in range(num_indices)
                for path in sorted(glob.glob(os.path.join(array_dir, str(i) + '_*.npy')))]

    file_list_x_diff = np.array(matches('masks/diffusion', ti))
    file_list_x_conv = np.array(matches('masks/convection', ti))
    file_list_y = np.array(matches('snapshots', tf))

    return file_list_x_diff, file_list_x_conv, file_list_y


def split_files(file_list_x_diff, file_list_x_conv, file_list_y, train_size, seed = 0):
    """Shuffle the three file lists in lock-step and return only their held-out part.

    Each list is shuffled by a fresh ``random.Random(seed)``, so lists of equal length
    receive the same permutation and stay aligned. The first
    ``int(np.round(len * train_size))`` entries (the training part) are discarded.

    Returns:
        ``(test_diff, test_conv, test_y)`` as Python lists.
    """
    shuffled = [random.Random(seed).sample(list(files), len(files))
                for files in (file_list_x_diff, file_list_x_conv, file_list_y)]

    cut = int(np.round(len(shuffled[0]) * train_size))

    test_diff, test_conv, test_y = (files[cut:] for files in shuffled)
    return test_diff, test_conv, test_y


def swap(i,j):
    """Return ``(i, j)`` with axes 0 and 1 of each array/tensor exchanged."""
    return tuple(arr.swapaxes(0, 1) for arr in (i, j))

def unsqueeze(i,j):
    """Return ``(i, j)`` with a new leading size-1 dimension added to each tensor."""
    return tuple(torch.unsqueeze(tensor, 0) for tensor in (i, j))


def split_x_y_list(data_x_diff, data_x_conv, data_y, train_size, seed):
    """Shuffle three aligned sample lists identically and split them into train/test parts.

    Every list is shuffled by a fresh ``random.Random(seed)``, so equal-length inputs
    stay aligned. The first ``int(np.round(len * train_size))`` items form the
    training part and the rest the test part.

    Returns:
        ``(x_train_diff, x_train_conv, y_train, x_test_diff, x_test_conv, y_test)``.
    """
    shuffled = [random.Random(seed).sample(samples, len(samples))
                for samples in (data_x_diff, data_x_conv, data_y)]

    cut = int(np.round(len(shuffled[0]) * train_size))

    train_parts = [samples[:cut] for samples in shuffled]
    test_parts = [samples[cut:] for samples in shuffled]
    return (*train_parts, *test_parts)


def generate_loader(data_x_diff, data_x_conv, data_y, TRAIN_SIZE, BATCH_SIZE, SEED):
    """Split the samples, batch them and return the four batch lists used for training.

    The data is split with ``split_x_y_list`` (fraction TRAIN_SIZE, shuffle seed SEED).
    Each split is stacked into a float tensor and batched in order (no shuffling) by a
    ``DataLoader`` with batch size BATCH_SIZE. Every batch tensor then gets a new
    leading axis that is swapped with the batch axis, so ``(B, *rest)`` becomes
    ``(B, 1, *rest)``. The shapes of the first batch of each list are printed.

    Returns:
        ``(train_loader_diff, test_loader_diff, train_loader_conv, test_loader_conv)``:
        lists of ``(x, y)`` tensor pairs. The diff and conv lists share the same ``y`` data.
    """
    (x_train_diff, x_train_conv, y_train,
     x_test_diff, x_test_conv, y_test) = split_x_y_list(data_x_diff, data_x_conv, data_y, TRAIN_SIZE, SEED)

    def to_tensor(samples):
        # Stack the list of arrays into one ndarray, then convert to a float tensor.
        return torch.Tensor(np.array(samples))

    y_train_tensor = to_tensor(y_train)
    y_test_tensor = to_tensor(y_test)

    def to_batches(x_samples, y_tensor):
        loader = DataLoader(TensorDataset(to_tensor(x_samples), y_tensor), batch_size=BATCH_SIZE)
        # (B, *rest) -> (1, B, *rest) -> (B, 1, *rest)
        return [
            (torch.unsqueeze(x, 0).swapaxes(0, 1), torch.unsqueeze(y, 0).swapaxes(0, 1))
            for x, y in loader
        ]

    train_loader_diff = to_batches(x_train_diff, y_train_tensor)
    test_loader_diff = to_batches(x_test_diff, y_test_tensor)
    train_loader_conv = to_batches(x_train_conv, y_train_tensor)
    test_loader_conv = to_batches(x_test_conv, y_test_tensor)

    # Sanity check: report the shapes of the first batch of each list.
    sections = (
        ('Training Set:\n', train_loader_diff),
        ('\nTesting Set:', test_loader_diff),
        ('Training Set:\n', train_loader_conv),
        ('\nTesting Set:', test_loader_conv),
    )
    for header, batches in sections:
        print(header)
        for image_x, image_y in batches[:1]:
            print('Image x batch dimensions:', image_x.size())
            print('Image y batch dimensions:', image_y.size())

    return train_loader_diff, test_loader_diff, train_loader_conv, test_loader_conv



##########################
### MODEL
##########################


class Reshape(nn.Module):
    """Reshape the input with ``Tensor.view`` to the shape given at construction."""

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class Trim(nn.Module):
    """Crop the last two dimensions of a 4-D tensor to at most 256 x 256."""

    def __init__(self, *args):
        # Positional arguments are accepted and ignored, for compatibility.
        super().__init__()

    def forward(self, x):
        return x[:, :, :256, :256]


class AutoEncoder(nn.Module):
    """Convolutional autoencoder with an ``m``-dimensional linear bottleneck.

    The encoder has four unpadded 4x4 stride-1 convolutions
    (1 -> 4 -> 16 -> 16 -> 16 channels), and each shrinks the spatial side by 3.
    An ``(N, 1, S, S)`` input therefore reaches a ``16 x (S-12) x (S-12)`` map, which
    is flattened into ``Linear(..., m)``. The decoder mirrors this with transposed
    convolutions and returns ``(N, 1, S, S)``. Layer order and indices are fixed so
    saved state dicts keep loading.

    Args:
        m: bottleneck width.
        input_size: side length S of the square input. The default of 256 gives the
            original ``16 * 244 * 244 = 952576`` flattened features.

    Raises:
        ValueError: if ``input_size`` is too small for the four convolutions.
    """

    def __init__(self, m, input_size=256):
        super().__init__()
        latent_side = input_size - 4 * 3  # four 4x4 valid convolutions, -3 each
        if latent_side < 1:
            raise ValueError(f'input_size must be at least 13, got {input_size}')
        flat_features = 16 * latent_side * latent_side

        self.encoder = nn.Sequential(  # (N, 1, S, S) -> (N, m)
                nn.Conv2d(1, 4, stride=(1, 1), kernel_size=(4, 4), padding=0),
                nn.LeakyReLU(0.01),
                nn.Conv2d(4, 16, stride=(1, 1), kernel_size=(4, 4), padding=0),
                nn.LeakyReLU(0.01),
                nn.Conv2d(16, 16, stride=(1, 1), kernel_size=(4, 4), padding=0),
                nn.LeakyReLU(0.01),
                nn.Conv2d(16, 16, stride=(1, 1), kernel_size=(4, 4), padding=0),
                nn.Flatten(),
                nn.Linear(flat_features, m)
                )

        self.decoder = nn.Sequential(  # (N, m) -> (N, 1, S, S)
                torch.nn.Linear(m, flat_features),
                Reshape(-1, 16, latent_side, latent_side),
                nn.ConvTranspose2d(16, 16, stride=(1, 1), kernel_size=(4, 4), padding=0),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(16, 16, stride=(1, 1), kernel_size=(4, 4), padding=0),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(16, 4, stride=(1, 1), kernel_size=(4, 4), padding=0),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(4, 1, stride=(1, 1), kernel_size=(4, 4), padding=0),
                # No output trimming or squashing: the transposed convolutions already
                # restore S x S, and outputs are left unbounded.
                )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x



def run_cae(Rai, Raf, ti, tf, TRAIN_SIZE, BATCH_SIZE, SEED, RANDOM_SEED, LEARNING_RATE, NUM_EPOCHS, DEVICE, MIN_M, MAX_M, NUM_M):
    """Train and save one diffusion/convection autoencoder pair per latent-size combination.

    For each ``(md, mc)`` pair from the hard-coded grids below, this function:

    1. builds ``AutoEncoder(md)`` and ``AutoEncoder(mc)``;
    2. trains them jointly with ``train_autoencoder_v1`` (epoch stats skipped);
    3. scores them on the held-out split with ``compute_accuracy``;
    4. writes both state dicts, the test accuracy and the training log under
       ``Results/<Rai>_<Raf>_<ti>_<tf>/models/diff_conv_casc/<md>_<mc>_RS<RANDOM_SEED>/``.

    The targets passed as the "base" loader are the conv loaders, which carry the
    same ``y`` data as the diff loaders. MIN_M, MAX_M and NUM_M are accepted but unused
    (they fed a commented-out random grid).
    """
    data_x_diff, data_x_conv, data_y = load_data(Rai, Raf, ti, tf)
    loaders = generate_loader(data_x_diff, data_x_conv, data_y, TRAIN_SIZE, BATCH_SIZE, SEED)
    train_loader_diff, test_loader_diff, train_loader_conv, test_loader_conv = loaders

    m_values_diff = np.array([35, 37, 40, 42, 44, 45, 47, 55, 56, 57, 60, 62, 65, 67])
    m_values_conv = np.array([65, 63, 60, 58, 56, 55, 53, 45, 44, 43, 40, 38, 35, 33])

    print(m_values_diff)
    print(m_values_conv)

    models_dir = ('/content/gdrive/My Drive/Project/Results/' + Rai + '_' + Raf + '_'
                  + str(ti) + '_' + str(tf) + '/models/diff_conv_casc/')

    for i, (md, mc) in enumerate(zip(m_values_diff, m_values_conv)):
        md, mc = int(md), int(mc)

        print(md, mc)
        print()

        if i > 0:
            # Free the previous pair's GPU memory before building the next one;
            # the 30 s sleeps around the cleanup are kept exactly as before.
            time.sleep(30)
            model_diff = model_diff.cpu()
            model_conv = model_conv.cpu()
            del model_diff, model_conv, optimizer_diff, optimizer_conv, test_acc
            gc.collect()
            torch.cuda.empty_cache()
            time.sleep(30)

        set_all_seeds(RANDOM_SEED)

        # NOTE(review): this first pair is discarded immediately. Constructing it still
        # consumes the freshly seeded RNG, so the weights actually trained below depend
        # on it; removing it would change results relative to earlier runs.
        model_diff = AutoEncoder(md).to(DEVICE)
        model_conv = AutoEncoder(mc).to(DEVICE)
        optimizer_diff = torch.optim.Adam(model_diff.parameters(), lr=LEARNING_RATE)
        optimizer_conv = torch.optim.Adam(model_conv.parameters(), lr=LEARNING_RATE)
        model_diff = model_diff.cpu()
        model_conv = model_conv.cpu()
        del model_diff, model_conv, optimizer_diff, optimizer_conv
        gc.collect()
        torch.cuda.empty_cache()

        model_diff = AutoEncoder(md).to(DEVICE)
        model_conv = AutoEncoder(mc).to(DEVICE)
        optimizer_diff = torch.optim.Adam(model_diff.parameters(), lr=LEARNING_RATE)
        optimizer_conv = torch.optim.Adam(model_conv.parameters(), lr=LEARNING_RATE)

        log_dict = train_autoencoder_v1(num_epochs=NUM_EPOCHS, model_diff=model_diff, model_conv=model_conv,
                                        optimizer_diff=optimizer_diff, optimizer_conv=optimizer_conv,
                                        train_loader_diff=train_loader_diff, train_loader_conv=train_loader_conv,
                                        train_loader_base=train_loader_conv, device=DEVICE,
                                        skip_epoch_stats=True)

        test_acc = compute_accuracy(model_diff, model_conv, test_loader_diff, test_loader_conv, test_loader_conv, DEVICE)

        run_dir = models_dir + str(md) + '_' + str(mc) + '_RS' + str(RANDOM_SEED)
        os.mkdir(run_dir)
        torch.save(model_diff.state_dict(), run_dir + '/model_diff.pt')
        torch.save(model_conv.state_dict(), run_dir + '/model_conv.pt')
        np.save(run_dir + '/test_acc.npy', test_acc, allow_pickle=True)
        np.save(run_dir + '/log_dict.npy', log_dict, allow_pickle=True)

        print()
  