# Load libraries

In [None]:
import os
import json
import numpy as np
import pandas as pd
import csv
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn.functional as F
from torch.nn import Linear, Conv2d, Dropout, AdaptiveAvgPool2d
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch.optim import Adam
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
%matplotlib inline

# Set the seed for reproducibility

In [None]:
# Fix the random seeds so repeated runs draw the same random numbers
SEED = 2
torch.manual_seed(SEED)
np.random.seed(SEED)

# Pick the compute device: CUDA when it is available, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Making info table

In [None]:
def load_npy_files(base_folder):
    """
    Collect the paths of all '.gjf.npy' files found under the given base folder.

    The folder is scanned as base_folder/<material>/<condition>/<file>, where
    <condition> is either 'CIP' or 'CM'. Entries of base_folder that are not
    directories, and condition folders that do not exist, are skipped.

    Parameters:
    - base_folder: str, the directory that contains one sub-folder per material.

    Returns:
    - pd.DataFrame with the columns 'material', 'condition', 'file_path' and
      'filename' (the file name up to its first dot). The columns are present
      even when no file is found, so callers can always index them.
    """
    columns = ['material', 'condition', 'file_path', 'filename']
    records = []

    # os.listdir returns entries in arbitrary order; sorting keeps the row order
    # (and therefore any seeded split made downstream) reproducible.
    for material in sorted(os.listdir(base_folder)):
        material_path = os.path.join(base_folder, material)
        if not os.path.isdir(material_path):
            continue

        for condition in ('CIP', 'CM'):
            condition_path = os.path.join(material_path, condition)
            if not os.path.isdir(condition_path):
                continue

            for file_name in sorted(os.listdir(condition_path)):
                if not file_name.endswith('.gjf.npy'):
                    continue
                records.append({
                    'material': material,
                    'condition': condition,
                    'file_path': os.path.join(condition_path, file_name),
                    'filename': file_name.split('.')[0],  # Name without any extension
                })

    return pd.DataFrame(records, columns=columns)


def load_csv_targets(base_folder):
    """
    Load the target values from every <material>/transfer_integrals.csv file.

    Each CSV must contain a 'Filename' column. Two columns are added to the
    rows read from it: 'material' (the name of the material folder) and
    'filename' (the 'Filename' value up to its first dot).

    Parameters:
    - base_folder: str, the directory that contains one sub-folder per material.

    Returns:
    - pd.DataFrame with the rows of all CSV files stacked together. When no
      CSV file is found, an empty frame is returned instead of raising; its
      columns are the ones the rest of this file reads ('Filename', 'NLUMO',
      'LUMO', 'HOMO', 'NHOMO') plus 'material' and 'filename'.
    """
    frames = []

    # Sorted so the row order does not depend on the filesystem's listing order
    for material in sorted(os.listdir(base_folder)):
        material_path = os.path.join(base_folder, material)
        if not os.path.isdir(material_path):
            continue

        csv_file_path = os.path.join(material_path, 'transfer_integrals.csv')
        if not os.path.isfile(csv_file_path):
            continue

        df = pd.read_csv(csv_file_path)
        df['material'] = material
        # str() guards against names that pandas parsed as numbers
        df['filename'] = df['Filename'].map(lambda name: str(name).split('.')[0])
        frames.append(df)

    if not frames:
        # pd.concat([]) raises; hand back an empty table with the expected columns
        return pd.DataFrame(
            columns=['Filename', 'NLUMO', 'LUMO', 'HOMO', 'NHOMO', 'material', 'filename']
        )

    return pd.concat(frames, ignore_index=True)


def find_missing_targets(base_folder):
    """
    Return the rows of the .npy file table whose filename has no target entry.

    The comparison uses the extension-less 'filename' only; the material is
    not taken into account here.

    Parameters:
    - base_folder: str, the directory that contains one sub-folder per material.

    Returns:
    - pd.DataFrame, the subset of load_npy_files(base_folder) without a
      matching 'filename' in load_csv_targets(base_folder).
    """
    sample_table = load_npy_files(base_folder)
    known_names = load_csv_targets(base_folder)['filename'].unique()

    # True for every sample whose name appears in at least one CSV
    has_target = sample_table['filename'].isin(known_names)
    return sample_table[~has_target]


def prepare_dataset(base_folder, epsilon):
    """
    Build the table of paired CIP/CM samples with their HOMO targets.

    Parameters:
    - base_folder: str, the directory that contains one sub-folder per material.
    - epsilon: float, offset added to |HOMO| before taking the logarithm.

    Returns:
    - pd.DataFrame with the columns 'material', 'filename', 'file_path_CIP',
      'target_HOMO', 'target_log_abs_HOMO' and 'file_path_CM'. Only samples
      that have a CIP file, a CM file and a HOMO value are kept.
    """
    # Read the folder tree and the CSV files once; the missing-target check
    # below reuses these tables instead of scanning the disk again.
    npy_data = load_npy_files(base_folder)
    target_data = load_csv_targets(base_folder)

    # A sample is "missing" when its filename appears in no CSV file
    has_target = npy_data['filename'].isin(target_data['filename'].unique())
    missing_targets = npy_data[~has_target]
    if not missing_targets.empty:
        print("Found .npy files with missing target values:")
        print(missing_targets.loc[missing_targets['condition'] == 'CIP', 'file_path'])
    else:
        print("All .npy files have corresponding target values.")

    # Attach the target values; the match is on material AND filename
    dataset = npy_data[has_target].merge(target_data, on=['material', 'filename'], how='left')

    # Safety net: a filename may exist in another material's CSV only
    dataset = dataset.dropna(subset=['HOMO'])

    # Compute the two target variables
    dataset = dataset.assign(
        target_HOMO=dataset['HOMO'],
        target_log_abs_HOMO=np.log(epsilon + np.abs(dataset['HOMO'])),
    )

    # Remove the raw CSV columns that are no longer needed
    dataset = dataset.drop(columns=['Filename', 'NLUMO', 'LUMO', 'HOMO', 'NHOMO'])

    # Separate CIP and CM rows so each sample ends up with one path per condition
    dataset_CIP = dataset[dataset['condition'] == 'CIP'].rename(columns={'file_path': 'file_path_CIP'})
    dataset_CM = dataset[dataset['condition'] == 'CM'].rename(columns={'file_path': 'file_path_CM'})

    # Inner merge: keep only samples available under both conditions
    merged_dataset = pd.merge(
        dataset_CIP[['material', 'filename', 'file_path_CIP', 'target_HOMO', 'target_log_abs_HOMO']],
        dataset_CM[['material', 'filename', 'file_path_CM']],
        on=['material', 'filename'],
    )

    return merged_dataset

def get_data(base_folder, target, input_type, test_material, SEED, epsilon):
    """
    Prepare the train/test tables for the selected target and input type.

    Parameters:
    - base_folder: str, the base directory containing data files.
    - target: str, the target column to use ('target_HOMO' or 'target_log_abs_HOMO').
    - input_type: str, the input type ('CIP', 'CM', or 'Multi').
    - test_material: str or None, material held out as the test set. When it is
      falsy, a stratified 80/20 random split over the materials is used instead.
    - SEED: int, random state of the random split (unused when test_material is given).
    - epsilon: float, offset used for the log-transformed target.

    Returns:
    - (train_data, test_data): two pd.DataFrames with the columns 'material',
      'filename', the path column(s) of the chosen input type, and the target.

    Raises:
    - ValueError: for an unknown target or input type, or when test_material
      matches no row of the dataset.
    """
    # Validate the cheap arguments first so a typo does not cost a full disk scan
    if target not in ['target_HOMO', 'target_log_abs_HOMO']:
        raise ValueError("Invalid target. Choose 'target_HOMO' or 'target_log_abs_HOMO'.")

    path_columns_by_input = {
        'CIP': ['file_path_CIP'],
        'CM': ['file_path_CM'],
        'Multi': ['file_path_CIP', 'file_path_CM'],
    }
    if input_type not in path_columns_by_input:
        raise ValueError("Invalid input type. Choose 'CIP', 'CM', or 'Multi'.")

    merged_dataset = prepare_dataset(base_folder, epsilon)
    merged_dataset = merged_dataset[
        ['material', 'filename', *path_columns_by_input[input_type], target]
    ]

    # Split data into train/test sets
    if test_material:
        is_test = merged_dataset['material'] == test_material
        if not is_test.any():
            # An empty test set would only fail later, far from the cause
            raise ValueError(
                f"Test material {test_material!r} not found in the dataset. "
                f"Available materials: {sorted(merged_dataset['material'].unique())}"
            )
        train_data = merged_dataset[~is_test]
        test_data = merged_dataset[is_test]
    else:
        train_data, test_data = train_test_split(
            merged_dataset,
            test_size=0.2,
            stratify=merged_dataset['material'],
            random_state=SEED
        )
    return train_data, test_data

# Plots and analysis

In [None]:
def plot_distribution(data, column, title):
    """
    Show a histogram (with a KDE curve) of one column of a table.

    Parameters:
    - data: table indexable by column name (e.g. a pd.DataFrame).
    - column: str, the column to plot; also used as the x-axis label.
    - title: str, the figure title.
    """
    values = data[column]

    plt.figure(figsize=(10, 6))
    sns.histplot(values, kde=True, color='blue')

    # Axis decoration
    plt.xlabel(column)
    plt.ylabel('Frequency')
    plt.title(title)
    plt.grid(True)

    plt.show()

def analyze_data(data, dataset):
    """
    Print descriptive statistics of a table and plot its material histogram.

    Parameters:
    - data: pd.DataFrame with a 'material' column.
    - dataset: str, label of the split, used in the plot title.
    """
    # Basic descriptive statistics
    summary = data.describe()
    print(summary)

    # How many samples each material contributes
    materials = data['material']
    plt.figure()
    plt.hist(materials, color='blue')
    plt.title(f'Material Distribution on {dataset}')
    plt.show()

# 'train_data' and 'test_data' are the two dataset splits to inspect
def analyze_target_distribution(train_data, test_data, target='target_log_abs_HOMO'):
    """
    Summarise both splits, then plot the target distribution of each.

    Parameters:
    - train_data: table of the training split.
    - test_data: table of the test split.
    - target: str, the target column whose distribution is plotted.
    """
    # Statistics and material histograms first (train, then test)
    for frame, split_label in ((train_data, 'Train Set'), (test_data, 'Test Set')):
        analyze_data(frame, split_label)

    # Target histograms afterwards, in the same order
    for frame, split_name in ((train_data, 'Training'), (test_data, 'Testing')):
        plot_distribution(frame, target, f'Distribution of {target} in {split_name} Data')


def _regression_metrics(actual, predicted):
    """Return (MSE, RMSE, MAE, R²) of the predictions against the actual values."""
    mse = mean_squared_error(actual, predicted)
    return mse, np.sqrt(mse), mean_absolute_error(actual, predicted), r2_score(actual, predicted)


def _scatter_panel(position, actual, predicted, quantity, panel_title, metrics):
    """Draw one actual-vs-estimated scatter panel with its metrics box."""
    mse, rmse, mae, r2 = metrics
    plt.subplot(1, 2, position)
    plt.scatter(actual, predicted, facecolors='none', edgecolors='b', alpha=0.5)
    plt.xlabel(f'Actual {quantity}')
    plt.ylabel(f'Estimated {quantity}')
    plt.title(panel_title)
    # Identity line: points on it are perfect predictions
    plt.plot([actual.min(), actual.max()], [actual.min(), actual.max()], 'k')
    plt.grid(True)
    plt.text(actual.min(), predicted.max(),
             f'MSE: {mse:.2f}\nRMSE: {rmse:.2f}\nMAE: {mae:.2f}\nR²: {r2:.2f}',
             verticalalignment='top', horizontalalignment='left', color='red')


def plot_predictions(data_loader, model_path, epsilon):
    """
    Plot predicted vs. actual values of a saved model on a data loader.

    The model type, input type and target are parsed from the run-folder name
    contained in model_path. Two panels are drawn: one for the target the
    model was trained on, one for its counterpart (log|HOMO| <-> HOMO).
    The metrics of both are printed as well.

    Parameters:
    - data_loader: iterable of batches accepted by predict_model.
    - model_path: str, path of the saved model file inside its run folder.
    - epsilon: float, offset used in the log transform log(epsilon + |HOMO|).

    Raises:
    - ValueError: if the target parsed from model_path is not supported, or
      if data_loader yields no batch.
    """
    extracted_features = extract_features_from_path(model_path)
    model_type = extracted_features['model_type']
    input_type = extracted_features['input_type']
    target_type = extracted_features['target']
    if target_type not in ('target_log_abs_HOMO', 'target_HOMO'):
        raise ValueError(f"Unsupported target parsed from model path: {target_type!r}")

    model = load_model(model_path)

    actual_batches, prediction_batches = [], []
    with torch.no_grad():  # Inference only: no autograd graph needed
        for data in data_loader:
            output, actual = predict_model(data, model, model_type, input_type)
            # Flatten so (batch, 1) outputs and (batch,) targets line up
            actual_batches.append(actual.detach().cpu().numpy().reshape(-1))
            prediction_batches.append(output.detach().cpu().numpy().reshape(-1))
    if not actual_batches:
        raise ValueError("data_loader produced no batches; nothing to plot.")
    actuals = np.concatenate(actual_batches)
    predictions = np.concatenate(prediction_batches)

    # Compute the counterpart of the trained target
    if target_type == "target_log_abs_HOMO":
        # Inverting log(epsilon + |HOMO|) recovers the magnitude only; the
        # sign of HOMO is lost, hence the '|HOMO|' label.
        actuals_transformed = np.exp(actuals) - epsilon
        predictions_transformed = np.exp(predictions) - epsilon
        transformed_title = '|HOMO|'
        title = "log_abs_HOMO"
    else:
        actuals_transformed = np.log(epsilon + np.abs(actuals))
        predictions_transformed = np.log(epsilon + np.abs(predictions))
        transformed_title = 'log_abs_HOMO'
        title = 'HOMO'

    metrics = _regression_metrics(actuals, predictions)
    metrics_transformed = _regression_metrics(actuals_transformed, predictions_transformed)

    panel_title = f"{model_type} Model Trained on {title}"
    plt.figure(figsize=(10, 5))
    _scatter_panel(1, actuals, predictions, title, panel_title, metrics)
    _scatter_panel(2, actuals_transformed, predictions_transformed,
                   transformed_title, panel_title, metrics_transformed)
    plt.tight_layout()
    plt.show()

    # Print metrics for both targets
    mse, rmse, mae, r2 = metrics
    mse_transformed, rmse_transformed, mae_transformed, r2_transformed = metrics_transformed
    print(f"Original ({title}) - MSE: {mse:.2f} | RMSE: {rmse:.2f} | MAE: {mae:.2f} | R²: {r2:.2f}")
    print(f"Transformed ({transformed_title}) - MSE: {mse_transformed:.2f} | RMSE: {rmse_transformed:.2f} | MAE: {mae_transformed:.2f} | R²: {r2_transformed:.2f}")


def plot_matrix(matrix, title="Matrix"):
    """
    Display a 2-D matrix as a colour-mapped image with a colour bar.

    Parameters:
    - matrix: array-like accepted by plt.imshow.
    - title: str, the figure title.
    """
    plt.figure(figsize=(6, 6))
    image = plt.imshow(matrix, cmap='viridis')
    plt.colorbar(image)
    plt.title(title)
    plt.show()

def find_actual_homo(base_folder, sample_file_path, epsilon):
    """
    Find the actual HOMO value for a given sample file path.

    The sample path is expected to end with <material>/<condition>/<file>;
    the material and the extension-less file name are used to look up the row
    in the target table.

    Args:
    base_folder (str): Base folder where data folders are located.
    sample_file_path (str): Full path to the sample file.
    epsilon (float): Offset used in the log transform log(epsilon + |HOMO|).

    Returns:
    tuple or None: (HOMO, log(epsilon + |HOMO|)) of the first matching row,
    or None if the sample has no entry in the target table.
    """
    # Load all targets into a DataFrame
    df_targets = load_csv_targets(base_folder)

    # Use os.path instead of splitting on '/' so paths built with
    # os.path.join are handled with the platform's own separator.
    normalized = os.path.normpath(sample_file_path)
    condition_folder = os.path.dirname(normalized)
    material = os.path.basename(os.path.dirname(condition_folder))
    filename = os.path.basename(normalized).split('.')[0]

    # Find the row in the DataFrame
    row = df_targets[(df_targets['material'] == material) & (df_targets['filename'] == filename)]
    if row.empty:
        return None

    homo = row['HOMO'].values[0]
    return homo, np.log(epsilon + np.abs(homo))

# The model is loaded from 'model_path'; its settings are parsed from the run-folder name
def predict_and_visualize(model_path, sample_path_cip, sample_path_cm, epsilon):
    """
    Plot one CIP/CM sample pair, predict its target and print it next to the
    actual value.

    The raw tensors are passed straight to the model, which matches the CNN
    forward path of this file; graph models expect graph objects instead.

    Parameters:
    - model_path: str, path of the saved model file inside its run folder.
    - sample_path_cip: str, path of the CIP sample, laid out as
      <base_folder>/<material>/<condition>/<file>.
    - sample_path_cm: str, path of the matching CM sample.
    - epsilon: float, offset used in the log transform log(epsilon + |HOMO|).

    Returns:
    - None; everything is reported through plots and prints.
    """
    # The data root is three levels above the sample file
    sample_dir = os.path.dirname(os.path.normpath(sample_path_cip))
    base_folder = os.path.dirname(os.path.dirname(sample_dir)) or '.'

    extracted_features = extract_features_from_path(model_path)
    input_type = extracted_features['input_type']
    target_type = extracted_features['target']

    # Load the sample data
    sample_cip = np.load(sample_path_cip)
    sample_cm = np.load(sample_path_cm)

    print('Sample Addresses: \n', sample_path_cip, '\n', sample_path_cm)
    # Visualize the matrices
    plot_matrix(sample_cip, title="CIP Data Sample")
    plot_matrix(sample_cm, title="CM Data Sample")

    model = load_model(model_path)
    print('Model Address: \n', model_path)
    # Ensure the model is in eval mode
    model.eval()

    # Build batched tensors (leading batch dimension of 1) on the model's device
    model_device = next(model.parameters()).device
    sample_cip_tensor = torch.tensor(sample_cip, dtype=torch.float32).unsqueeze(0).to(model_device)
    sample_cm_tensor = torch.tensor(sample_cm, dtype=torch.float32).unsqueeze(0).to(model_device)

    # Make a prediction
    with torch.no_grad():
        raw_prediction = model(sample_cip_tensor, sample_cm_tensor, input_type).item()

    if target_type == 'target_HOMO':
        # The model predicts HOMO directly; derive the log-magnitude from it
        predict_homo = raw_prediction
        predict_log_abs_homo = float(np.log(epsilon + np.abs(raw_prediction)))
    else:
        # The model predicts log(epsilon + |HOMO|); the inverse is
        # exp(x) - epsilon and yields the magnitude of HOMO (sign is lost).
        predict_log_abs_homo = raw_prediction
        predict_homo = float(np.exp(raw_prediction) - epsilon)

    print("Predicted HOMO Value:", predict_homo)
    print("Predicted Log Abs HOMO Value:", predict_log_abs_homo)

    # find_actual_homo returns None when the sample has no target entry
    actual = find_actual_homo(base_folder, sample_path_cip, epsilon)
    actual_homo, actual_log_abs_homo = actual if actual is not None else (None, None)
    print("Actual HOMO Value:", actual_homo)
    print("Actual Log Abs HOMO Value:", actual_log_abs_homo)

    return None

# Data loader, Models' architecture, Main loop

In [None]:
def create_unique_folder(base_folder, model_type, input_type, target, test_material, epochs, lr, batch_size, seed):
    """
    Create (if needed) the results folder that identifies one training run.

    The folder name encodes every parameter as
    '<model_type>_<input_type>_<target>_<test_material>_epochs<epochs>_lr<lr>_batch<batch_size>_seed<seed>'.

    Parameters:
    - base_folder: str, directory in which the run folder is created.
    - model_type, input_type, target, test_material, epochs, lr, batch_size,
      seed: the run parameters, formatted into the folder name as given.

    Returns:
    - (folder_name, folder_path): the generated name and its full path.
    """
    folder_name = f"{model_type}_{input_type}_{target}_{test_material}_epochs{epochs}_lr{lr}_batch{batch_size}_seed{seed}"
    folder_path = os.path.join(base_folder, folder_name)
    # exist_ok avoids the check-then-create race and makes reruns a no-op
    os.makedirs(folder_path, exist_ok=True)
    return folder_name, folder_path

def _parse_run_folder_name(folder_name):
    """
    Parse a folder name of the form
    '<model>_<input>_<target>_<material>_epochs<E>_lr<LR>_batch<B>_seed<S>'.

    Raises ValueError when one of the numeric suffixes is missing or malformed.
    """
    features = {
        'model_type': None,
        'input_type': None,
        'target': None,
        'test_material': None,
        'epochs': None,
        'lr': None,
        'batch_size': None,
        'SEED': None
    }

    # Peel the numeric suffixes off from the right, last one first
    remainder = folder_name
    for key, marker, cast in (
        ('SEED', '_seed', int),
        ('batch_size', '_batch', int),
        ('lr', '_lr', float),
        ('epochs', '_epochs', int),
    ):
        remainder, found, raw_value = remainder.rpartition(marker)
        if not found:
            raise ValueError(f"'{folder_name}' has no '{marker}' component")
        features[key] = cast(raw_value)

    # What is left holds model_type, input_type, target and test_material
    parts = remainder.split('_')
    if len(parts) >= 4:
        features['model_type'] = parts[0]
        features['input_type'] = parts[1]
        # The target may itself contain underscores (e.g. 'target_log_abs_HOMO')
        features['target'] = '_'.join(parts[2:-1])
        features['test_material'] = parts[-1]
    return features


def extract_features_from_path(folder_path):
    """
    Extract model parameters from a folder name created by create_unique_folder, handling complex target names.

    The path may point either at a file inside the run folder (e.g. the saved
    model) or at the run folder itself.

    Args:
    folder_path (str): The path to the run folder or to a file inside it.

    Returns:
    dict: Dictionary with the keys 'model_type', 'input_type', 'target',
    'test_material', 'epochs', 'lr', 'batch_size' and 'SEED'. The first four
    stay None when the name has fewer than four leading components.

    Raises:
    ValueError: if neither the parent folder nor the last path component is a
    valid run-folder name.
    """
    normalized = os.path.normpath(folder_path)

    # Try the parent folder first (path points at a file inside the run
    # folder), then the last component itself (path is the run folder).
    candidates = (
        os.path.basename(os.path.dirname(normalized)),
        os.path.basename(normalized),
    )
    for candidate in candidates:
        try:
            return _parse_run_folder_name(candidate)
        except ValueError:
            continue

    raise ValueError(f"Could not find a run-folder name in path: {folder_path!r}")


def save_results_to_csv(folder_name, folder_path, epoch, train_loss, mse_train_loss, mae_train_loss, mse_test_loss, mae_test_loss):
    """
    Append the metrics of one epoch to '<folder_path>/<folder_name>.csv'.

    The header row is written first when the file does not exist yet.
    """
    csv_file = os.path.join(folder_path, f"{folder_name}.csv")

    header = ["Epoch", "Train Loss", "MSE Train Loss", "MAE Train Loss", "MSE Test Loss", "MAE Test Loss"]
    record = [epoch, train_loss, mse_train_loss, mae_train_loss, mse_test_loss, mae_test_loss]

    # Decide before opening: opening in append mode creates the file
    needs_header = not os.path.isfile(csv_file)
    rows = [header, record] if needs_header else [record]

    with open(csv_file, mode='a', newline='') as handle:
        csv.writer(handle).writerows(rows)


def setup_tensorboard_log(folder_name, folder_path):
    """
    Create the TensorBoard SummaryWriter of a run.

    The events are written directly into folder_path; folder_name is accepted
    for symmetry with the other helpers but is not used.
    """
    return SummaryWriter(log_dir=folder_path)


def load_graph_or_image_data(row, target, input_type, model_type):
    """
    Load one sample in the representation required by the model type.

    'GNN' samples are loaded with load_graph_data, 'CNN' samples with
    load_image_data. Any other model type yields None.
    """
    if model_type == 'GNN':
        loader = load_graph_data
    elif model_type == 'CNN':
        loader = load_image_data
    else:
        return None
    return loader(row, target, input_type)

def load_image_data(row, target, input_type):
    """
    Load the matrix input(s) of one sample for the CNN, together with its target.

    Parameters:
    - row: mapping (e.g. pd.Series) holding 'file_path_CIP' and/or
      'file_path_CM' and the target column.
    - target: str, the name of the column containing the target value.
    - input_type: str, the type of input ('CIP', 'CM', or 'Multi').

    Returns:
    - (matrix, target) for 'CIP' and 'CM', (cip_matrix, cm_matrix, target)
      for 'Multi', all as float tensors; None for any other input type.
    """
    target_tensor = torch.tensor(row[target], dtype=torch.float)

    # Which path column(s) to read, in the order they are returned
    if input_type == 'CIP':
        path_columns = ('file_path_CIP',)
    elif input_type == 'CM':
        path_columns = ('file_path_CM',)
    elif input_type == 'Multi':
        path_columns = ('file_path_CIP', 'file_path_CM')
    else:
        return None

    matrices = [
        torch.tensor(np.load(row[column]), dtype=torch.float)
        for column in path_columns
    ]
    return (*matrices, target_tensor)


def _fully_connected_edge_index(num_nodes):
    """Return the (2, num_nodes**2) edge index linking every node to every node (self-loops included)."""
    node_ids = torch.arange(num_nodes, dtype=torch.long)
    sources = node_ids.repeat_interleave(num_nodes)  # 0,0,...,1,1,...
    targets = node_ids.repeat(num_nodes)             # 0,1,...,0,1,...
    return torch.stack([sources, targets], dim=0)


def _matrix_file_to_graph(matrix_path, target_value):
    """Load one matrix file and wrap it as a fully connected graph with target y."""
    x = torch.tensor(np.load(matrix_path), dtype=torch.float)
    # One node per matrix row; the row itself is the node's feature vector
    edge_index = _fully_connected_edge_index(x.size(0))
    y = torch.tensor([target_value], dtype=torch.float)
    return Data(x=x, edge_index=edge_index, y=y)


def load_graph_data(row, target, input_type):
    """
    Load graph data based on the selected input type.

    Every matrix becomes a fully connected graph (self-loops included) whose
    node count is taken from the matrix itself rather than being hard-coded,
    so matrices of any size are supported.

    Parameters:
    - row: pd.Series, a row from the dataset.
    - target: str, the name of the column containing the target value.
    - input_type: str, the input type ('CIP', 'CM', or 'Multi').

    Returns:
    - a Data object for 'CIP' or 'CM', a (data_cip, data_cm) pair for 'Multi',
      None for any other input type.
    """
    if input_type == 'CIP':
        return _matrix_file_to_graph(row['file_path_CIP'], row[target])

    if input_type == 'CM':
        return _matrix_file_to_graph(row['file_path_CM'], row[target])

    if input_type == 'Multi':
        # Both graphs carry the same target value
        data_cip = _matrix_file_to_graph(row['file_path_CIP'], row[target])
        data_cm = _matrix_file_to_graph(row['file_path_CM'], row[target])
        return data_cip, data_cm

    return None

def _forward_eval_batch(data, model, model_type, input_type):
    """
    Run the model on one batch and return (output, target) on the device,
    with the target reshaped to the output's shape.
    """
    if input_type not in ('CIP', 'CM', 'Multi'):
        raise ValueError(f"Invalid input type: {input_type!r}. Choose 'CIP', 'CM', or 'Multi'.")

    if model_type == 'GNN':
        if input_type == 'Multi':
            data_cip, data_cm = data
            data_cip = data_cip.to(device)
            data_cm = data_cm.to(device)
            output = model(data_cip, data_cm, input_type)
            target = data_cip.y
        else:
            data = data.to(device)
            if input_type == 'CIP':
                output = model(data, None, input_type)
            else:
                output = model(None, data, input_type)
            target = data.y
    elif model_type == 'CNN':
        if input_type == 'Multi':
            data_cip, data_cm, target = data
            output = model(data_cip.to(device), data_cm.to(device), input_type)
        else:
            data, target = data
            data = data.to(device)
            if input_type == 'CIP':
                output = model(data, None, input_type)
            else:
                output = model(None, data, input_type)
        target = target.to(device)
    else:
        raise ValueError(f"Invalid model type: {model_type!r}. Choose 'CNN' or 'GNN'.")

    # Give the target the output's shape. Comparing a (batch, 1) output with a
    # (batch,) target would broadcast to (batch, batch) and distort the loss.
    return output, target.reshape(output.shape)


def evaluate_model(data_loader, model, criterion, model_type, input_type):
    """
    Evaluate the model on a data loader without computing gradients.

    Parameters:
    - data_loader: iterable of batches in the layout of the chosen model/input type.
    - model: torch.nn.Module, called as model(data_cip, data_cm, input_type).
    - criterion: callable loss, applied as criterion(output, target).
    - model_type: str, the model type ('CNN' or 'GNN').
    - input_type: str, the input type ('CIP', 'CM', or 'Multi').

    Returns:
    - (avg_loss, avg_mae): the criterion value and the L1 loss, each averaged
      over the batches (every batch has the same weight).

    Raises:
    - ValueError: for an unknown model or input type, or if the loader yields
      no batch.
    """
    model.eval()  # Set model to evaluation mode
    mae_criterion = torch.nn.L1Loss()
    total_loss = 0.0
    total_mae = 0.0
    num_batches = 0

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for data in data_loader:
            output, target = _forward_eval_batch(data, model, model_type, input_type)
            total_loss += criterion(output, target).item()
            total_mae += mae_criterion(output, target).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader produced no batches; cannot compute average losses.")

    return total_loss / num_batches, total_mae / num_batches

def predict_model(data, model, model_type, input_type):
    """
    Run the model on one batch and return its output with the actual targets.

    Parameters:
    - data: one batch in the layout of the chosen model/input type.
    - model: torch.nn.Module, called as model(data_cip, data_cm, input_type).
    - model_type: str, the model type ('CNN' or 'GNN').
    - input_type: str, the input type ('CIP', 'CM', or 'Multi').

    Returns:
    - (output, actual): the model output and the batch targets. For CNN
      batches the targets are returned as received (not moved to the device).
    """
    model.eval()  # Set model to evaluation mode
    is_multi = input_type == 'Multi'

    if model_type == 'GNN':
        if is_multi:
            graph_cip, graph_cm = data
            graph_cip = graph_cip.to(device)
            graph_cm = graph_cm.to(device)
            output = model(graph_cip, graph_cm, input_type)
            actual = graph_cip.y
        else:
            graph = data.to(device)
            # The single graph goes into the slot of its own input type
            if input_type == 'CIP':
                output = model(graph, None, input_type)
            elif input_type == 'CM':
                output = model(None, graph, input_type)
            actual = graph.y
    elif model_type == 'CNN':
        if is_multi:
            matrix_cip, matrix_cm, actual = data
            matrix_cip = matrix_cip.to(device)
            matrix_cm = matrix_cm.to(device)
            output = model(matrix_cip, matrix_cm, input_type)
        else:
            matrix, actual = data
            matrix = matrix.to(device)
            if input_type == 'CIP':
                output = model(matrix, None, input_type)
            elif input_type == 'CM':
                output = model(None, matrix, input_type)

    return output, actual


def load_model(folder_path):
    """
    Load a saved model object and switch it to evaluation mode.

    Parameters:
    - folder_path: str, path of the saved model FILE (despite the name).

    Returns:
    - the loaded model, in eval mode, mapped onto the GPU when one is
      available and onto the CPU otherwise.

    NOTE: torch.load unpickles arbitrary Python objects here, which can run
    code. Only load model files that come from a trusted source.
    """
    # Map the stored tensors to a device that exists on this machine, so a
    # model saved on a GPU can still be loaded on a CPU-only host.
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # The file holds a whole pickled model, not just a state dict, so
        # weights-only loading has to be switched off explicitly.
        model = torch.load(folder_path, map_location=target_device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the 'weights_only' argument
        model = torch.load(folder_path, map_location=target_device)
    model.eval()
    return model



class MultiInputModel(torch.nn.Module):
    """
    CNN or GNN regressor over a CIP input, a CM input, or both.

    Each input has its own branch of (convolution, dropout) pairs followed by
    ReLU and global average pooling. The pooled features go through a linear
    head with one output; with input type 'Multi' the features of both
    branches are concatenated first.

    Parameters:
    - model_type: str, 'CNN' (2-D convolutions on matrices) or 'GNN' (GCNConv on graphs).
    - input_dim_cip: int, node feature size of the CIP graphs (GNN only; unused for CNN).
    - input_dim_cm: int, node feature size of the CM graphs (GNN only; unused for CNN).
    - input_type: str, 'CIP', 'CM', or 'Multi'.
    - layer_configs_cip / layer_configs_cm: list of per-layer settings,
      (channels, kernel_size, dropout) for CNN or (channels, dropout) for GNN.
      Defaults to [(32, 3, 0.5)] for CNN and [(128, 0.5)] for GNN.

    Note: the input_type argument of forward() is accepted for call
    compatibility, but the branch selection uses the input type given at
    construction time.
    """

    def __init__(self, model_type, input_dim_cip, input_dim_cm, input_type,
                 layer_configs_cip=None, layer_configs_cm=None):
        super().__init__()
        if model_type not in ('CNN', 'GNN'):
            raise ValueError(f"Invalid model type: {model_type!r}. Choose 'CNN' or 'GNN'.")
        if input_type not in ('CIP', 'CM', 'Multi'):
            raise ValueError(f"Invalid input type: {input_type!r}. Choose 'CIP', 'CM', or 'Multi'.")
        self.model_type = model_type
        self.input_type = input_type

        # Default layer configurations if none are provided:
        # (channels, kernel_size, dropout) for CNN or (channels, dropout) for GNN
        if layer_configs_cip is None:
            layer_configs_cip = [(32, 3, 0.5)] if model_type == 'CNN' else [(128, 0.5)]
        if layer_configs_cm is None:
            layer_configs_cm = [(32, 3, 0.5)] if model_type == 'CNN' else [(128, 0.5)]

        # Define model layers based on the type (CNN or GNN)
        if model_type == 'CNN':
            self.setup_cnn_layers(input_dim_cip, input_dim_cm, layer_configs_cip, layer_configs_cm)
        else:
            self.setup_gnn_layers(input_dim_cip, input_dim_cm, layer_configs_cip, layer_configs_cm)

    @staticmethod
    def _build_cnn_branch(layer_configs):
        """Build the Conv2d/Dropout stack of one branch; return (layers, output channels)."""
        layers = torch.nn.ModuleList()
        in_channels = 1  # The matrix is fed as a single-channel image
        for out_channels, kernel_size, dropout in layer_configs:
            layers.append(Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                 stride=1, padding=kernel_size // 2))
            layers.append(Dropout(dropout))
            in_channels = out_channels
        return layers, in_channels

    @staticmethod
    def _build_gnn_branch(input_dim, layer_configs):
        """Build the GCNConv/Dropout stack of one branch; return (layers, output channels)."""
        layers = torch.nn.ModuleList()
        in_channels = input_dim
        for out_channels, dropout in layer_configs:
            layers.append(GCNConv(in_channels, out_channels))
            layers.append(Dropout(dropout))
            in_channels = out_channels
        return layers, in_channels

    def _setup_heads(self, channels_cip, channels_cm):
        """Create the linear output heads; the fused head exists only for 'Multi'."""
        self.fc_cip = Linear(channels_cip, 1)  # Applied to the globally pooled features
        self.fc_cm = Linear(channels_cm, 1)
        if self.input_type == 'Multi':
            self.fc = Linear(channels_cip + channels_cm, 1)

    def setup_cnn_layers(self, input_dim_cip, input_dim_cm, layer_configs_cip, layer_configs_cm):
        """Create the CNN branches and heads (the input_dim arguments are not used)."""
        self.layers_cip, channels_cip = self._build_cnn_branch(layer_configs_cip)
        self.layers_cm, channels_cm = self._build_cnn_branch(layer_configs_cm)
        self._setup_heads(channels_cip, channels_cm)

    def setup_gnn_layers(self, input_dim_cip, input_dim_cm, layer_configs_cip, layer_configs_cm):
        """Create the GNN branches and heads."""
        self.layers_cip, channels_cip = self._build_gnn_branch(input_dim_cip, layer_configs_cip)
        self.layers_cm, channels_cm = self._build_gnn_branch(input_dim_cm, layer_configs_cm)
        self._setup_heads(channels_cip, channels_cm)

    def forward(self, data_cip, data_cm, input_type):
        """Dispatch to the CNN or GNN forward pass; returns a (batch, 1) tensor."""
        if self.model_type == 'CNN':
            return self.forward_cnn(data_cip, data_cm, input_type)
        elif self.model_type == 'GNN':
            return self.forward_gnn(data_cip, data_cm, input_type)

    @staticmethod
    def _run_cnn_branch(layers, matrices):
        """Apply one CNN branch to a (batch, H, W) tensor; return (batch, channels) features."""
        x = matrices.unsqueeze(1)  # Adding channel dimension
        for layer in layers:
            x = layer(x)
            if not isinstance(layer, Conv2d):
                # After each dropout layer: conv -> dropout -> ReLU, the same
                # order the GNN branch uses. The dropout layer must be
                # applied here, otherwise it has no effect during training.
                x = F.relu(x)
        return AdaptiveAvgPool2d((1, 1))(x).view(x.size(0), -1)

    @staticmethod
    def _run_gnn_branch(layers, graph):
        """Apply one GNN branch to a batched graph; return (batch, channels) features."""
        x = graph.x
        edge_index = graph.edge_index
        for layer in layers:
            x = layer(x, edge_index) if isinstance(layer, GCNConv) else F.relu(layer(x))
        return global_mean_pool(x, graph.batch)

    def _predict_from_features(self, x_cip, x_cm):
        """Apply the output head that matches the configured input type."""
        if self.input_type == 'CIP':
            return self.fc_cip(x_cip)
        elif self.input_type == 'CM':
            return self.fc_cm(x_cm)
        elif self.input_type == 'Multi':
            return self.fc(torch.cat([x_cip, x_cm], dim=1))

    def forward_cnn(self, data_cip, data_cm, input_type):
        """CNN forward pass; only the inputs needed by the configured input type are read."""
        x_cip = x_cm = None
        if self.input_type in ['CIP', 'Multi']:
            x_cip = self._run_cnn_branch(self.layers_cip, data_cip)
        if self.input_type in ['CM', 'Multi']:
            x_cm = self._run_cnn_branch(self.layers_cm, data_cm)
        return self._predict_from_features(x_cip, x_cm)

    def forward_gnn(self, data_cip, data_cm, input_type):
        """GNN forward pass; only the inputs needed by the configured input type are read."""
        x_cip = x_cm = None
        if self.input_type in ['CIP', 'Multi']:
            x_cip = self._run_gnn_branch(self.layers_cip, data_cip)
        if self.input_type in ['CM', 'Multi']:
            x_cm = self._run_gnn_branch(self.layers_cm, data_cm)
        return self._predict_from_features(x_cip, x_cm)


def _train_batch_loss(batch, model, criterion, model_type, input_type):
    """
    Run one forward pass on a training batch and return the loss tensor.

    Inputs are moved to the module-level `device` before the forward pass.
    `model_type` and `input_type` are expected to have been validated by the caller.
    """
    if model_type == 'GNN':
        if input_type == 'Multi':
            data_cip, data_cm = batch
            data_cip = data_cip.to(device)
            data_cm = data_cm.to(device)
            output = model(data_cip, data_cm, input_type)
            # The target is read from the CIP graph of the pair.
            return criterion(output, data_cip.y)

        graph = batch.to(device)
        if input_type == 'CIP':
            output = model(graph, None, input_type)
        else:  # 'CM'
            output = model(None, graph, input_type)
        return criterion(output, graph.y)

    # CNN batches carry the target as their last element.
    if input_type == 'Multi':
        data_cip, data_cm, target = batch
        output = model(data_cip.to(device), data_cm.to(device), input_type)
    else:
        images, target = batch
        images = images.to(device)
        if input_type == 'CIP':
            output = model(images, None, input_type)
        else:  # 'CM'
            output = model(None, images, input_type)

    # unsqueeze(1) gives the target a second axis before the loss
    # (assumes targets arrive as a 1-D batch -- TODO confirm against the loader).
    return criterion(output, target.to(device).unsqueeze(1))


def train_model(train_loader, test_loader, model, optimizer, criterion, model_type, input_type,
                epochs, folder_name, folder_path):
    """
    Train a CNN or GNN model and log, checkpoint and report every epoch.

    After each epoch the model is evaluated on both loaders with
    `evaluate_model`, the metrics are written to TensorBoard and to CSV
    (`save_results_to_csv`), a summary line is printed, and the model is saved
    to `{folder_path}/model.pt`. Whenever the test MAE is at or below the best
    value seen so far, the model is also saved to `{folder_path}/best_model.pt`.

    Parameters:
    - train_loader: DataLoader, the training data loader.
    - test_loader: DataLoader, the test data loader used for per-epoch evaluation.
    - model: torch.nn.Module, the CNN or GNN model.
    - optimizer: torch.optim.Optimizer, the optimizer.
    - criterion: torch.nn.Module, the loss function.
    - model_type: str, the model type ('CNN', or 'GNN').
    - input_type: str, the input type ('CIP', 'CM', or 'Multi').
    - epochs: int, the number of training epochs.
    - folder_name: str, run name passed to the TensorBoard/CSV helpers.
    - folder_path: str, directory that receives the logs and the saved models.

    Raises:
    - ValueError: if `model_type` or `input_type` is not one of the supported values.
    """
    # Fail fast, before any log folder or writer is touched.
    if model_type not in ('CNN', 'GNN'):
        raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'CNN' or 'GNN')")
    if input_type not in ('CIP', 'CM', 'Multi'):
        raise ValueError(f"Unsupported input_type: {input_type!r} (expected 'CIP', 'CM' or 'Multi')")

    writer = setup_tensorboard_log(folder_name, folder_path)
    best_mae = np.inf

    try:
        for epoch in range(epochs):
            # Re-enter training mode every epoch; evaluate_model presumably
            # switches the model to eval mode -- confirm in its definition.
            model.train()
            total_train_loss = 0
            for data in train_loader:
                optimizer.zero_grad()
                loss = _train_batch_loss(data, model, criterion, model_type, input_type)
                loss.backward()
                optimizer.step()
                total_train_loss += loss.item()
            avg_train_loss = total_train_loss / len(train_loader)

            # Train evaluation
            mse_train_loss, mae_train_loss = evaluate_model(train_loader, model, criterion, model_type, input_type)
            # Test evaluation
            mse_test_loss, mae_test_loss = evaluate_model(test_loader, model, criterion, model_type, input_type)

            # Log both training and test loss to TensorBoard
            writer.add_scalar('Loss/Train/AVG_MSE/', avg_train_loss, epoch)
            writer.add_scalar('Loss/Train/MSE/', mse_train_loss, epoch)
            writer.add_scalar('Loss/Train/MAE/', mae_train_loss, epoch)
            writer.add_scalar('Loss/Test/MSE/', mse_test_loss, epoch)
            writer.add_scalar('Loss/Test/MAE/', mae_test_loss, epoch)

            # Save results to CSV
            save_results_to_csv(folder_name, folder_path, epoch, avg_train_loss, mse_train_loss, mae_train_loss, mse_test_loss, mae_test_loss)

            # Print results for the current epoch
            print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss}, MSE Train Loss: {mse_train_loss}, MAE Train Loss: {mae_train_loss}, MSE Test Loss: {mse_test_loss}, MAE Test Loss: {mae_test_loss}")

            # Save the latest model
            torch.save(model, f"{folder_path}/model.pt")
            if mae_test_loss <= best_mae:
                best_mae = mae_test_loss
                # Replace the best model
                torch.save(model, f"{folder_path}/best_model.pt")
                print(f"Save the best model at Epoch [{epoch+1}/{epochs}]")
    finally:
        # Always release the TensorBoard writer, even when an epoch raises.
        writer.close()

    # Final save; also covers the case epochs == 0, where the loop never ran.
    torch.save(model, f"{folder_path}/model.pt")


def initialize_model(model_type, input_dim_cip, input_dim_cm, input_type, save_path):
    """
    Initialize a model based on the specified type and input dimensions.

    The layer configurations are fixed per model type. After the model is
    built, its specification is written to ``model_spec.json`` inside
    ``save_path`` (tuples are stored as JSON arrays).

    Args:
    model_type (str): 'CNN' or 'GNN' indicating the type of model.
    input_dim_cip (int): Dimension size for CIP input.
    input_dim_cm (int): Dimension size for CM input.
    input_type (str): 'CIP', 'CM', or 'Multi' indicating the type of input handling.
    save_path (str): Existing directory in which ``model_spec.json`` is written.

    Returns:
    torch.nn.Module: Configured model instance.

    Raises:
    ValueError: If ``model_type`` is neither 'CNN' nor 'GNN'.
    """
    if model_type == 'CNN':
        # Each entry: (filters, kernel size, dropout)
        layer_configs_cip = [
            (32, 3, 0.5),
            (64, 3, 0.4),
            (128, 3, 0.3),
            (256, 3, 0.2),
            (128, 3, 0.1),
        ]
        layer_configs_cm = [
            (32, 3, 0.2),
            (64, 3, 0.2),
        ]
    elif model_type == 'GNN':
        # Each entry: (channels, dropout)
        layer_configs_cip = [
            (32, 0.5),
            (64, 0.4),
            (128, 0.3),
            (256, 0.2),
            (128, 0.1),
        ]
        layer_configs_cm = [
            (64, 0.2),
            (128, 0.2),
        ]
    else:
        # Without this branch the configs below would be unbound and the
        # failure would surface as an unrelated UnboundLocalError.
        raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'CNN' or 'GNN')")

    # Initialize the model with the specified configurations
    model = MultiInputModel(
        model_type=model_type,
        input_dim_cip=input_dim_cip,
        input_dim_cm=input_dim_cm,
        input_type=input_type,
        layer_configs_cip=layer_configs_cip,
        layer_configs_cm=layer_configs_cm,
    )

    # Save model specifications to a JSON file
    spec_path = os.path.join(save_path, 'model_spec.json')
    specs = {
        "model_type": model_type,
        "input_dim_cip": input_dim_cip,
        "input_dim_cm": input_dim_cm,
        "input_type": input_type,
        "layer_configs_cip": layer_configs_cip,
        "layer_configs_cm": layer_configs_cm,
    }
    with open(spec_path, 'w', encoding='utf-8') as f:
        json.dump(specs, f, indent=4)

    return model

# Hyperparameters, info loading

In [None]:
# Run configuration
base_folder = './Data1/'
target = 'target_HOMO'  # alternative: 'target_log_abs_HOMO'
input_type = 'CM'  # one of 'CIP', 'CM', 'Multi'
test_material = None  # None, or a material name such as 'DMSO1'
epsilon = 1e-5
model_type = 'CNN'
input_dim_cip = 158
input_dim_cm = 316
epochs = 1000
lr = 0.001
batch_size = 64  # 256
shuffle = True
log_path = './logs'

# Load the train/test tables for the chosen input type ('CIP', 'CM', 'Multi').
# SEED is the value set in the seed cell near the top of the notebook.
train_data, test_data = get_data(base_folder, target, input_type,
                                 test_material, SEED, epsilon)

# Print explicitly: a bare expression is only rendered by the notebook when it
# is the last statement of the cell, which is not the case here.
print(train_data.head(5))

# Inspect the target distribution of the loaded train and test data
analyze_target_distribution(train_data, test_data, target)

# Model Training

In [None]:
# Create the unique folder for logging
folder_name, folder_path = create_unique_folder(log_path, model_type, input_type, target, test_material,
                                                epochs, lr, batch_size, SEED)

# Create the model and move it to the target device *before* building the
# optimizer, so the optimizer is constructed over the parameters in their
# final location (as the torch.optim documentation recommends).
model = initialize_model(model_type, input_dim_cip, input_dim_cm, input_type, folder_path)
model = model.to(device)
print(model)

optimizer = Adam(model.parameters(), lr)
criterion = torch.nn.MSELoss()

# Prepare data loaders
train_graphs = [load_graph_or_image_data(row, target, input_type, model_type) for _, row in train_data.iterrows()]
test_graphs = [load_graph_or_image_data(row, target, input_type, model_type) for _, row in test_data.iterrows()]

# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine just triggers a warning.
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=shuffle,
                          num_workers=4, pin_memory=(device.type == 'cuda'))
# The test loader is not shuffled so the evaluation batches stay the same from
# epoch to epoch.
test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False,
                         num_workers=4, pin_memory=(device.type == 'cuda'))

# Train the model
train_model(train_loader, test_loader, model, optimizer, criterion, model_type, input_type,
            epochs, folder_name, folder_path)

# Loading Model and Post-training Analysis

In [None]:
# Path of the saved model to analyse
model_path = './dir/model.pt'

test_material = None
epsilon = 1e-5

# The run settings (target, input type, seed, model type) are recovered from the model path
features = extract_features_from_path(model_path)
train_data, test_data = get_data(
    base_folder,
    features['target'],
    features['input_type'],
    test_material,
    features['SEED'],
    epsilon,
)

# Rebuild the per-sample inputs for both splits
train_graphs = [
    load_graph_or_image_data(sample, features['target'], features['input_type'], features['model_type'])
    for _, sample in train_data.iterrows()
]
test_graphs = [
    load_graph_or_image_data(sample, features['target'], features['input_type'], features['model_type'])
    for _, sample in test_data.iterrows()
]

# batch_size and shuffle are the values set in the hyperparameter cell
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=shuffle,
                          num_workers=4, pin_memory=True)
test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=shuffle,
                         num_workers=4, pin_memory=True)

# Plot predictions for the train split first, then the test split
for banner, loader in (('Results on Train Set:', train_loader),
                       ('\n Results on Test Set:', test_loader)):
    print(banner)
    plot_predictions(loader, model_path, epsilon)