In [None]:
# Mount Google Drive so the data files and the `utils` module stored under MyDrive are reachable
from google.colab import drive
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [6]:
import os
import pickle
import sys
import tarfile
from collections import namedtuple

import numpy as np
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
# DataLoader lives in torch.utils.data (capital L); the alias keeps the `Dataloader`
# name this notebook was originally written against.
from torch.utils.data import DataLoader as Dataloader

In [7]:
try:
  import cupy as cp
except ImportError:
    cp = np  # If CuPy is not available, fallback to NumPy

In [None]:
sys.path.append('/content/drive/MyDrive')
import utils

# Functions

In [None]:
def adaptive_hyperparameter_search(train_set, num_epochs, lr_range, weight_decay_range, conv2_filters_range, loss_weight_range, use_cuda = True, trials = 10, k = 5):
  """
  Adaptive random hyperparameter search for utils.RNAUnet_multi using k-fold cross-validation.

  Each trial samples one value per hyperparameter from the current search space,
  scores the combination with k-fold cross-validation on `train_set`, and keeps the
  best combination seen so far. When a trial scores worse than the previous trial,
  the values it used are pruned from the search space. Each candidate list always
  keeps at least one value, so later trials concentrate on values that did better.
  The caller's candidate lists are copied and never modified.

  Args:
  - train_set: PyTorch dataset yielding (input, output, label) triples; it is split into k folds
  - num_epochs: Number of epochs to train the model in each fold
  - lr_range: Candidate learning rates
  - weight_decay_range: Candidate weight decay values
  - conv2_filters_range: Candidate numbers of filters for the first hidden layer (passed as `channels`)
  - loss_weight_range: Candidate weights of the classification (cross-entropy) loss; the
    utils.dice_loss term is weighted by 1 - loss_weight
  - use_cuda: Whether to use CUDA if it is available
  - trials: Number of trials to perform
  - k: Number of folds for cross-validation

  Returns:
  - best_params: Dictionary with the best hyperparameters found (empty if trials == 0).
  """
  import random

  if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
  else:
    device = torch.device("cpu")

  def combined_loss(parameters, predicted, family, output, label):
    # Weighted sum of the dice loss on the predicted output and the family classification loss.
    # NOTE(review): assumes `label` holds class indices/targets accepted by F.cross_entropy — confirm.
    weight = parameters["loss_weight"]
    return (1 - weight) * utils.dice_loss(predicted, output) + weight * F.cross_entropy(family, label)

  def Kfold_cv(parameters: dict, k=5):
    """Return the mean validation loss (after the final epoch) over k folds for one combination."""
    val_losses = 0.0

    # KFold objects are not iterable; split() yields (train_idx, val_idx) index arrays.
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    for train_idx, val_idx in kf.split(np.arange(len(train_set))):
      fold_train_set = torch.utils.data.Subset(train_set, train_idx)
      fold_val_set = torch.utils.data.Subset(train_set, val_idx)

      train_fold_loader = torch.utils.data.DataLoader(fold_train_set, batch_size=1, shuffle=True)
      val_fold_loader = torch.utils.data.DataLoader(fold_val_set, batch_size=1)

      # Fresh model and optimizer per fold so folds do not share learned weights.
      model = utils.RNAUnet_multi(channels=parameters["conv2_filters"])
      model.to(device)
      optimizer = utils.adam_optimizer(model, parameters["lr"], parameters["weight_decay"])

      model.train()
      for _ in range(num_epochs):
        for x, target, label in train_fold_loader:
          x, target, label = x.to(device), target.to(device), label.to(device)
          optimizer.zero_grad()
          # Feed only the input, as in validation; passing the target would leak it into training.
          predicted, family = model(x)
          loss = combined_loss(parameters, predicted, family, target, label)
          loss.backward()
          optimizer.step()

      # Only the loss after the final epoch is used, so evaluate once after training.
      model.eval()
      val_loss = 0.0
      with torch.no_grad():
        for x, target, label in val_fold_loader:
          x, target, label = x.to(device), target.to(device), label.to(device)
          predicted, family = model(x)
          val_loss += combined_loss(parameters, predicted, family, target, label).item()
      val_losses += val_loss / len(val_fold_loader)

    return val_losses / k

  # Copy the candidate lists so pruning never mutates the caller's data.
  params = {
    "lr": list(lr_range),
    "weight_decay": list(weight_decay_range),
    "conv2_filters": list(conv2_filters_range),
    "loss_weight": list(loss_weight_range),
  }

  best_loss = float('inf')
  best_params = {}
  prev_val_loss = None

  for i in range(trials):
    # random.choice returns the plain Python values from the lists (not CuPy/NumPy scalars),
    # so they compare cleanly against the candidate lists and can be passed to the model.
    parameters = {name: random.choice(values) for name, values in params.items()}

    val_loss = Kfold_cv(parameters, k)

    # Update best hyperparameters if applicable
    if val_loss < best_loss:
      best_loss = val_loss
      best_params = parameters
      print("New best hyperparameters found: ")
      print(best_params)

    # Adapt the search space: drop the values used by a trial that scored worse than the
    # previous trial, but never empty a candidate list.
    if prev_val_loss is not None and val_loss > prev_val_loss:
      for name, value in parameters.items():
        if value in params[name] and len(params[name]) > 1:
          params[name].remove(value)
    prev_val_loss = val_loss

    print(f"Trial {i+1} completed. Best hyperparameters found: {best_params} with loss {best_loss}")

  return best_params



# Data

In [None]:
# Record type for one RNA example. It must be defined under this exact name before unpickling
# if the pickled data contains RNA_data instances.
# NOTE(review): presumably the .pkl files below were created with this type — confirm.
RNA_data = namedtuple(
    'RNA_data',
    ['input', 'output', 'length', 'family', 'name', 'pairs'],
)

In [None]:
# Path to the compressed (.tar.gz) experiment archive in Google Drive
tar_file_path = '/content/drive/MyDrive/data/experiment8.tar.gz'
# Local directory the archive is unpacked into
extract_dir = '/content'

# Extract the tar.gz archive
# (presumably it contains data/experiment8/... — the listing below assumes that layout)
with tarfile.open(tar_file_path, 'r:gz') as tar:
    tar.extractall(extract_dir)

# List files relative to the extraction directory itself, not the current working
# directory, and sort them so the order is reproducible across runs.
experiment_dir = os.path.join(extract_dir, 'data', 'experiment8')
file_list = [os.path.join(experiment_dir, file) for file in sorted(os.listdir(experiment_dir))]

In [None]:
# Load the pickled train/validation splits and the family map from the mounted Drive.
# Context managers close each file handle (the bare open() calls leaked them).
# NOTE(review): pickle.load can execute arbitrary code — only load files you created or trust.
with open('/content/drive/MyDrive/data/experiment_train.pkl', 'rb') as fh:
    train = pickle.load(fh)
with open('/content/drive/MyDrive/data/experiment_valid.pkl', 'rb') as fh:
    valid = pickle.load(fh)

with open('/content/drive/MyDrive/data/experiment_familymap.pkl', 'rb') as fh:
    family_map = pickle.load(fh)

In [None]:
# Wrap the train and validation records as datasets that share the same family_map.
# NOTE(review): validation_dataset is not used by the hyperparameter search below,
# which cross-validates on train_dataset only.
train_dataset, validation_dataset = (
    utils.ImageToImageDataset(split, family_map) for split in (train, valid)
)

# Test: hyperparameter search

In [None]:
# Candidate values for each hyperparameter explored by the search
params = dict(
    lr=[0.01, 0.005, 0.001],
    weight_decay=[0.01, 0.001, 0.0001, 0],
    conv2_filters=[32, 64],
    loss_weight=[0.25, 0.5, 0.75],
)

In [None]:
# Run the search on the training set and keep the winning hyperparameters
best_params = adaptive_hyperparameter_search(
    train_dataset,
    num_epochs=10,
    lr_range=params["lr"],
    weight_decay_range=params["weight_decay"],
    conv2_filters_range=params["conv2_filters"],
    loss_weight_range=params["loss_weight"],
    use_cuda=True,
    trials=10,
    k=5,
)
best_params