In [6]:
import os
import logging
import torch
import torch.nn as nn
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Import your existing model and dataset classes
from Train2_12 import ExperimentalGNN, SpinSystemDataset, PhysicalScaleAwareLoss

# Fine-tuning configuration
FINETUNE_CONFIG = {
    'pretrained_model_path': 'best_model_rung1_6_pre.pth',
    'processed_dir_larger': './processed_experimentalrung7-8_10k_r6',
    'processed_file_name': 'data.pt',
    'batch_size': 128,
    'learning_rate': 0.5e-4,
    'weight_decay': 1.5e-4,
    'num_epochs': 200,
    'patience': 50,
    'finetuned_model_path': 'finetuned_model.pth',
    'dropout_p': 0.3,
    'grad_clip': 0.5,
    'random_seed': 42
}

def setup_logging():
    """Configure the root logger for INFO-level, timestamped console output.

    Delegates to ``logging.basicConfig``, so the call has no effect if the
    root logger already has handlers attached.
    """
    log_options = {
        'level': logging.INFO,
        'format': '%(asctime)s [%(levelname)s] %(message)s',
        'datefmt': '%Y-%m-%d %H:%M:%S',
    }
    logging.basicConfig(**log_options)

def _run_epoch(model, loader, criterion, device, optimizer=None, grad_clip=None):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_mae)``.

    Args:
        model: Network called as ``model(batch)``.
        loader: Iterable of batches exposing ``y``, ``system_size``, ``nA``
            and ``num_graphs``.
        criterion: Called as ``criterion(pred, targets, system_size,
            subsystem_size)`` and expected to return a scalar loss.
        device: Device each batch is moved to.
        optimizer: If given, the pass trains the model (train mode, backward
            pass, optimizer step). If ``None``, the pass only evaluates
            (eval mode, gradients disabled).
        grad_clip: Max gradient norm applied before each optimizer step;
            ``None`` disables clipping. Ignored when evaluating.

    Returns:
        Tuple ``(mean_loss, mean_mae)``, both averaged per graph.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    abs_err_sum = 0.0
    num_samples = 0

    with torch.set_grad_enabled(is_training):
        for data in loader:
            # Moving the batch moves all of its tensor attributes, so the
            # per-attribute .to(device) calls are not needed.
            data = data.to(device)
            if is_training:
                optimizer.zero_grad()

            pred_s = model(data)
            # NOTE(review): squeeze() assumes pred_s and data.y end up with
            # matching shapes; a batch holding a single graph would squeeze
            # to a 0-d tensor - confirm against the model output shape.
            targets = data.y.squeeze()
            system_size = data.system_size.squeeze(-1)
            subsystem_size = data.nA.squeeze(-1)

            loss = criterion(pred_s, targets, system_size, subsystem_size)

            if is_training:
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            # The criterion value is treated as a per-batch mean, so it is
            # re-weighted by the batch size before averaging over the epoch.
            loss_sum += loss.item() * data.num_graphs
            abs_err_sum += torch.abs(pred_s.detach() - targets).sum().item()
            num_samples += data.num_graphs

    if num_samples == 0:
        raise ValueError("Data loader produced no samples")

    return loss_sum / num_samples, abs_err_sum / num_samples


def fine_tune_model():
    """Fine-tune the pretrained ExperimentalGNN on the larger-system dataset.

    Everything is driven by ``FINETUNE_CONFIG``:

    1. Build the model and load the pretrained state dict.
    2. Load ``SpinSystemDataset`` and split it 80/20 into train/validation
       sets with a seeded generator.
    3. Train with AdamW and ``CosineAnnealingWarmRestarts`` (stepped once per
       epoch), logging per-graph loss and MAE for both splits.
    4. Save the state dict whenever the validation loss improves and stop
       early after ``patience`` epochs without improvement.

    Raises:
        ValueError: if the dataset is too small to give non-empty training
            and validation splits.
    """
    setup_logging()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pretrained model. load_state_dict requires the architecture
    # built here to match the one stored in the checkpoint.
    model = ExperimentalGNN(
        hidden_channels=512,
        dropout_p=FINETUNE_CONFIG['dropout_p']
    ).to(device)

    pretrained_state_dict = torch.load(FINETUNE_CONFIG['pretrained_model_path'], map_location=device)
    model.load_state_dict(pretrained_state_dict)
    logging.info("Loaded pretrained model successfully")

    # Load the new dataset with larger system sizes
    dataset = SpinSystemDataset(root=FINETUNE_CONFIG['processed_dir_larger'])

    # Split dataset
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        # Fail early with a clear message instead of an obscure sampler or
        # division error later on.
        raise ValueError(
            f"Dataset has {len(dataset)} samples; need at least 2 for a train/validation split"
        )
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(FINETUNE_CONFIG['random_seed'])
    )

    train_loader = DataLoader(train_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=False)

    # Initialize loss and optimizer
    criterion = PhysicalScaleAwareLoss(physics_weight=0.5)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=FINETUNE_CONFIG['learning_rate'],
        weight_decay=FINETUNE_CONFIG['weight_decay']
    )

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=20,
        T_mult=2,
        eta_min=1e-7
    )

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(FINETUNE_CONFIG['num_epochs']):
        avg_train_loss, avg_train_mae = _run_epoch(
            model, train_loader, criterion, device,
            optimizer=optimizer,
            grad_clip=FINETUNE_CONFIG['grad_clip'],
        )
        avg_val_loss, avg_val_mae = _run_epoch(model, val_loader, criterion, device)
        scheduler.step()

        logging.info(f'Epoch {epoch+1}/{FINETUNE_CONFIG["num_epochs"]}:')
        logging.info(f'  Training Loss: {avg_train_loss:.6f}')
        logging.info(f'  Training MAE: {avg_train_mae:.6f}')
        logging.info(f'  Validation Loss: {avg_val_loss:.6f}')
        logging.info(f'  Validation MAE: {avg_val_mae:.6f}')

        # Save best model and early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), FINETUNE_CONFIG['finetuned_model_path'])
            logging.info(f'  Saved new best model (val_loss={best_val_loss:.6f})')
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= FINETUNE_CONFIG['patience']:
                logging.info('Early stopping triggered')
                break

    logging.info('Fine-tuning completed')
    logging.info(f'Best validation loss: {best_val_loss:.6f}')


if __name__ == "__main__":
    fine_tune_model()

2025-02-25 09:08:36 [INFO] Loaded pretrained model successfully
2025-02-25 09:08:49 [INFO] Epoch 1/200:
2025-02-25 09:08:49 [INFO]   Training Loss: 0.002567
2025-02-25 09:08:49 [INFO]   Training MAE: 0.041783
2025-02-25 09:08:49 [INFO]   Validation Loss: 0.000933
2025-02-25 09:08:49 [INFO]   Validation MAE: 0.024295
2025-02-25 09:08:50 [INFO]   Saved new best model (val_loss=0.000933)
2025-02-25 09:08:59 [INFO] Epoch 2/200:
2025-02-25 09:08:59 [INFO]   Training Loss: 0.001146
2025-02-25 09:08:59 [INFO]   Training MAE: 0.028328
2025-02-25 09:08:59 [INFO]   Validation Loss: 0.000707
2025-02-25 09:08:59 [INFO]   Validation MAE: 0.020978
2025-02-25 09:09:00 [INFO]   Saved new best model (val_loss=0.000707)
2025-02-25 09:09:09 [INFO] Epoch 3/200:
2025-02-25 09:09:09 [INFO]   Training Loss: 0.000948
2025-02-25 09:09:09 [INFO]   Training MAE: 0.026270
2025-02-25 09:09:09 [INFO]   Validation Loss: 0.000655
2025-02-25 09:09:09 [INFO]   Validation MAE: 0.020442
2025-02-25 09:09:09 [INFO]   Saved

In [7]:
import os
import logging
import glob
import torch
import torch.nn as nn
from torch.utils.data import random_split, ConcatDataset
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Import your existing model and dataset classes
from Train2_12 import ExperimentalGNN, SpinSystemDataset, PhysicalScaleAwareLoss

# Fine-tuning configuration
FINETUNE_CONFIG = {
    # --- Checkpoints -------------------------------------------------------
    # State dict loaded into the model before fine-tuning starts.
    'pretrained_model_path': 'best_model_rung1_6_pre.pth',
    # Destination of the state dict with the lowest validation loss so far.
    'finetuned_model_path': 'finetuned_model.pth',

    # --- Data --------------------------------------------------------------
    # Dataset root directories; each is expected to hold processed/data.pt.
    'data_dirs': [
        './processed_experimentalrung7-8_10k_r6',
        './processed_experimentalrung7-8_10k_r6_2',
    ],
    # Glob used to list candidate files in each directory (logging only).
    'processed_file_pattern': 'data*.pt',
    # Optional direct paths to individual .pt files, for example
    # './processed_data/data_v1.pt' or './processed_data/data_v2.pt'.
    'alternative_file_paths': [],
    'batch_size': 128,
    # Seeds the generator used for the train/validation split.
    'random_seed': 42,

    # --- Optimisation ------------------------------------------------------
    'learning_rate': 5e-5,
    'weight_decay': 1.5e-4,
    'num_epochs': 200,
    # Epochs without a validation-loss improvement before stopping early.
    'patience': 50,
    # Max gradient norm given to clip_grad_norm_; None disables clipping.
    'grad_clip': 0.5,

    # --- Model -------------------------------------------------------------
    'dropout_p': 0.3,

    # --- Diagnostics -------------------------------------------------------
    # True selects DEBUG-level logging in setup_logging, False selects INFO.
    'verbose_logging': True,
}

def setup_logging():
    """Configure console and file logging for the fine-tuning run.

    The level is DEBUG when ``FINETUNE_CONFIG['verbose_logging']`` is truthy
    and INFO otherwise. Records go to the handlers installed by
    ``logging.basicConfig`` and to ``finetuning.log`` in the current working
    directory.

    The function is safe to call repeatedly (for example when a notebook
    cell is re-run): the file handler is attached at most once per log file,
    and the root level is always refreshed from the current configuration.
    """
    level = logging.DEBUG if FINETUNE_CONFIG.get('verbose_logging', False) else logging.INFO
    log_format = '%(asctime)s [%(levelname)s] %(message)s'

    logging.basicConfig(
        level=level,
        format=log_format,
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    root_logger = logging.getLogger('')
    # basicConfig does nothing once the root logger has handlers, so the
    # level must be set explicitly for later calls to take effect.
    root_logger.setLevel(level)

    # FileHandler stores the absolute path in baseFilename; use it to detect
    # a handler left over from a previous call and avoid duplicated lines.
    log_path = os.path.abspath('finetuning.log')
    file_handler = next(
        (
            handler for handler in root_logger.handlers
            if isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
        ),
        None,
    )
    if file_handler is None:
        file_handler = logging.FileHandler('finetuning.log')
        file_handler.setFormatter(logging.Formatter(log_format))
        root_logger.addHandler(file_handler)
    file_handler.setLevel(level)

    logging.info("Logging initialized")

class DirectPTFileDataset(torch.utils.data.Dataset):
    """In-memory dataset built from objects stored in ``.pt`` files.

    Every path is read with ``torch.load``. A file holding a list contributes
    all of its elements; a file holding a single object with an ``x``
    attribute contributes that object. Missing files, files that fail to
    load and files in any other format are logged and skipped, so the
    resulting dataset may be empty - callers should check ``len()``.

    Attributes are ``file_paths`` (the paths as given) and ``data_list``
    (the loaded samples, in file order).
    """

    def __init__(self, file_paths):
        """Load all samples from ``file_paths`` into memory.

        Args:
            file_paths: Sized iterable of paths to ``.pt`` files.
        """
        self.file_paths = file_paths
        logging.info(f"Attempting to load {len(file_paths)} PT files directly")

        # Load all data from these files
        self.data_list = []

        for file_path in file_paths:
            if not os.path.exists(file_path):
                logging.error(f"File does not exist: {file_path}")
                continue

            try:
                logging.info(f"Loading file: {file_path}")
                # SECURITY: torch.load unpickles the file, which can run
                # arbitrary code - only point this at trusted files.
                # map_location='cpu' lets files saved from a GPU load on
                # CPU-only hosts and keeps the whole dataset off the GPU.
                data_obj = torch.load(file_path, map_location='cpu')

                if isinstance(data_obj, list):
                    logging.info(f"Loaded list of {len(data_obj)} objects from {file_path}")
                    self.data_list.extend(data_obj)
                elif hasattr(data_obj, 'x'):
                    logging.info(f"Loaded single data object from {file_path}")
                    self.data_list.append(data_obj)
                else:
                    # NOTE(review): a collated PyG (data, slices) tuple would
                    # presumably land here - confirm whether that format
                    # needs to be supported.
                    logging.warning(f"Unrecognized data format in {file_path}")
            except Exception as e:
                # Deliberate best effort: one bad file must not abort the
                # remaining ones.
                logging.error(f"Error loading {file_path}: {str(e)}", exc_info=True)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``."""
        return self.data_list[idx]

class CustomSpinSystemDataset(SpinSystemDataset):
    """SpinSystemDataset variant that only reads an existing processed file.

    The download and processing hooks are overridden so that constructing
    the dataset never downloads or regenerates data: it either finds
    ``<processed_dir>/data.pt`` or raises ``FileNotFoundError``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        """Initialise the dataset from ``root`` and log what was found.

        Args:
            root: Dataset root directory; must already exist.
            transform: Forwarded unchanged to ``SpinSystemDataset``.
            pre_transform: Forwarded unchanged to ``SpinSystemDataset``.

        Raises:
            FileNotFoundError: if ``root`` does not exist.
            Exception: anything raised by the base constructor is logged
                with its traceback and re-raised.
        """
        # Checked outside the try block so a missing root is logged once.
        if not os.path.exists(root):
            logging.error(f"Directory does not exist: {root}")
            raise FileNotFoundError(f"Directory does not exist: {root}")

        try:
            # Try to initialize with original SpinSystemDataset
            super().__init__(root=root, transform=transform, pre_transform=pre_transform)

            # Look for processed directory and check its contents
            processed_dir = os.path.join(root, 'processed')
            if os.path.exists(processed_dir):
                logging.info(f"Processed directory exists: {processed_dir}")
                processed_files = os.listdir(processed_dir)
                logging.info(f"Files in processed directory: {processed_files}")
            else:
                logging.warning(f"No 'processed' directory found in {root}")

        except Exception as e:
            logging.error(f"Error initializing CustomSpinSystemDataset: {str(e)}", exc_info=True)
            raise

    def download(self):
        """Do nothing, so that no download is ever attempted."""
        pass

    def process(self):
        """Check that the processed file exists instead of regenerating it.

        Raises:
            FileNotFoundError: if ``<processed_dir>/data.pt`` is missing.
        """
        processed_file_path = os.path.join(self.processed_dir, 'data.pt')
        if os.path.exists(processed_file_path):
            logging.info(f"Processed file already exists: {processed_file_path}")
        else:
            logging.warning(f"Processed file does not exist: {processed_file_path}")
            # We would need to implement the processing logic here if needed
            raise FileNotFoundError(f"Required processed file not found: {processed_file_path}")

    def _download(self):
        """Do nothing; replaces the base class's internal download step."""
        pass

    def _process(self):
        """Replace the base class's internal processing step with a file check.

        NOTE(review): this skips whatever else the base ``_process`` does
        (presumably bookkeeping such as pre_transform/pre_filter checks) -
        confirm that this is intended.

        Raises:
            FileNotFoundError: if ``<processed_dir>/data.pt`` is missing.
        """
        processed_file_path = os.path.join(self.processed_dir, 'data.pt')
        if not os.path.exists(processed_file_path):
            logging.error(f"Processed file not found at {processed_file_path}")
            raise FileNotFoundError("Required processed file not found")

    def _merge_data(self, data1, slices1, data2, slices2):
        """Merge two collated ``(data, slices)`` pairs into one.

        Args:
            data1: First data object; its keys decide which attributes are
                merged and its class is used for the result.
            slices1: Mapping from key to the slice-boundary tensor of ``data1``.
            data2: Second data object; must provide every key of ``data1``.
            slices2: Mapping from key to the slice-boundary tensor of ``data2``.

        Returns:
            Tuple ``(merged_data, merged_slices)``. Only keys present in both
            ``slices1`` and ``slices2`` appear in ``merged_slices``.
        """
        # Create new data object with combined attributes
        merged_data = data1.__class__()

        # `keys` is a property on older PyG releases and a method on newer
        # ones; support both.
        keys = data1.keys() if callable(data1.keys) else data1.keys

        for key in keys:
            item1, item2 = getattr(data1, key), getattr(data2, key)

            if torch.is_tensor(item1) and torch.is_tensor(item2):
                # Concatenate along the dimension the data class declares
                # for this attribute.
                merged_attr = torch.cat([item1, item2], dim=data1.__cat_dim__(key, item1))
            else:
                # Non-tensor attributes are combined with `+` (list
                # concatenation or numeric addition).
                merged_attr = item1 + item2

            setattr(merged_data, key, merged_attr)

        # Update the slices for the merged data
        merged_slices = {}
        for key in slices1.keys():
            if key in slices2:
                # The last boundary of the first dataset is where the second
                # dataset's elements start in the merged storage.
                offset = slices1[key][-1]

                # Drop the second dataset's leading boundary (it coincides
                # with `offset` after shifting) and shift the rest.
                second_slice_shifted = slices2[key][1:] + offset

                merged_slices[key] = torch.cat([slices1[key], second_slice_shifted])

        return merged_data, merged_slices

def load_multi_datasets():
    """Load every configured dataset and return them as one ``ConcatDataset``.

    Sources are tried in this order:

    1. Each directory in ``FINETUNE_CONFIG['data_dirs']`` as a
       ``CustomSpinSystemDataset``; if that fails, its ``processed/data.pt``
       is read directly with ``DirectPTFileDataset``.
    2. The explicit files in ``FINETUNE_CONFIG['alternative_file_paths']``.
    3. Only if nothing was loaded: a path typed in by the user, read
       directly and, if that yields no samples, as a PyG dataset rooted two
       directory levels above the file.

    Returns:
        ConcatDataset: concatenation of every dataset that loaded.

    Raises:
        ValueError: if no samples could be loaded from any source, the
            user-supplied path does not exist, or the fallback prompt cannot
            read any input.
    """
    datasets = []
    data_dirs = FINETUNE_CONFIG['data_dirs']
    file_pattern = FINETUNE_CONFIG['processed_file_pattern']

    # Diagnostics only: report processed/data.pt files (PyG default layout).
    for data_dir in data_dirs:
        pyg_file_path = os.path.join(data_dir, 'processed', 'data.pt')
        if os.path.exists(pyg_file_path):
            logging.info(f"Found PyG dataset file: {pyg_file_path}")

    # Diagnostics only: report files matching the configured pattern, both in
    # the directory itself and in its 'processed' subdirectory.
    for data_dir in data_dirs:
        if not os.path.exists(data_dir):
            continue
        main_files = glob.glob(os.path.join(data_dir, file_pattern))
        processed_files = glob.glob(os.path.join(data_dir, 'processed', file_pattern))

        all_files = main_files + processed_files
        logging.info(f"Found {len(all_files)} files in {data_dir} matching the pattern")
        for file in all_files:
            logging.info(f"  - {file}")

    # 1. Try loading as standard PyG SpinSystemDataset from each directory
    for data_dir in data_dirs:
        if not os.path.exists(data_dir):
            logging.error(f"Directory does not exist: {data_dir}")
            continue

        try:
            logging.info(f"Attempting to load PyG dataset from {data_dir}")
            dataset = CustomSpinSystemDataset(root=data_dir)
            datasets.append(dataset)
            logging.info(f"Successfully loaded PyG dataset from {data_dir} with {len(dataset)} samples")
        except Exception as e:
            logging.warning(f"Could not load PyG dataset from {data_dir}: {str(e)}")

            # Fall back to reading processed/data.pt directly.
            processed_file = os.path.join(data_dir, 'processed', 'data.pt')
            if os.path.exists(processed_file):
                try:
                    logging.info(f"Attempting to load direct data from {processed_file}")
                    direct_dataset = DirectPTFileDataset([processed_file])
                    if len(direct_dataset) > 0:
                        datasets.append(direct_dataset)
                        logging.info(f"Loaded {len(direct_dataset)} samples directly from {processed_file}")
                except Exception as direct_e:
                    logging.error(f"Failed to load direct file {processed_file}: {str(direct_e)}")

    # 2. Try loading from additional specified PT files.
    # FINETUNE_CONFIG is a dict, so the key must be looked up with .get();
    # hasattr() on a dict never sees its keys.
    alternative_paths = FINETUNE_CONFIG.get('alternative_file_paths')
    if alternative_paths:
        direct_dataset = DirectPTFileDataset(alternative_paths)
        if len(direct_dataset) > 0:
            datasets.append(direct_dataset)
            logging.info(f"Loaded {len(direct_dataset)} samples from specified PT files")

    # 3. If all else fails, request the proper file path
    if not datasets:
        logging.error("""
        No datasets could be loaded from the specified directories.
        
        Please check the following:
        1. Verify that your data directories exist
        2. Check that PyG dataset files are in a 'processed/data.pt' path
        3. Try specifying direct paths to PT files in 'alternative_file_paths'
        """)

        # Get user input for data file path
        print("\nNo datasets could be loaded from the specified directories.")
        print("Please enter the path to a PyTorch Geometric dataset file (data.pt):")
        try:
            file_path = input("Path: ").strip()
        except EOFError as exc:
            # Non-interactive runs have no stdin to answer the prompt.
            raise ValueError(
                "No datasets could be loaded and no interactive input is available"
            ) from exc

        if not os.path.exists(file_path):
            raise ValueError(f"Specified file does not exist: {file_path}")

        # DirectPTFileDataset logs and swallows its own load errors, so an
        # empty result (not an exception) is the failure signal here.
        direct_dataset = DirectPTFileDataset([file_path])
        if len(direct_dataset) > 0:
            datasets.append(direct_dataset)
            logging.info(f"Loaded {len(direct_dataset)} samples from user-specified {file_path}")
        else:
            logging.error(f"Failed to load user-specified file: no samples could be read from {file_path}")

            # Last resort - treat <root>/processed/<file> as a PyG layout
            # and load the dataset from <root>.
            parent_dir = os.path.dirname(os.path.dirname(file_path))
            try:
                dataset = CustomSpinSystemDataset(root=parent_dir)
                datasets.append(dataset)
                logging.info(f"Loaded PyG dataset from {parent_dir} with {len(dataset)} samples")
            except Exception as pe:
                logging.error(f"Failed to load from parent directory {parent_dir}: {str(pe)}")
                raise ValueError("No datasets could be loaded. Please check your data files.") from pe

    # Combine all datasets
    if not datasets:
        raise ValueError("No datasets could be loaded")
    combined_dataset = ConcatDataset(datasets)
    if len(combined_dataset) == 0:
        # Datasets loaded but hold no samples; nothing to train on.
        raise ValueError("No datasets could be loaded")
    logging.info(f"Combined dataset contains {len(combined_dataset)} samples total")

    return combined_dataset

def _run_epoch(model, loader, criterion, device, optimizer=None, grad_clip=None):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_mae)``.

    Args:
        model: Network called as ``model(batch)``.
        loader: Iterable of batches exposing ``y``, ``system_size``, ``nA``
            and ``num_graphs``.
        criterion: Called as ``criterion(pred, targets, system_size,
            subsystem_size)`` and expected to return a scalar loss.
        device: Device each batch is moved to.
        optimizer: If given, the pass trains the model (train mode, backward
            pass, optimizer step). If ``None``, the pass only evaluates
            (eval mode, gradients disabled).
        grad_clip: Max gradient norm applied before each optimizer step;
            ``None`` disables clipping. Ignored when evaluating.

    Returns:
        Tuple ``(mean_loss, mean_mae)``, both averaged per graph.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    abs_err_sum = 0.0
    num_samples = 0

    with torch.set_grad_enabled(is_training):
        for data in loader:
            # Moving the batch moves all of its tensor attributes, so the
            # per-attribute .to(device) calls are not needed.
            data = data.to(device)
            if is_training:
                optimizer.zero_grad()

            pred_s = model(data)
            # NOTE(review): squeeze() assumes pred_s and data.y end up with
            # matching shapes; a batch holding a single graph would squeeze
            # to a 0-d tensor - confirm against the model output shape.
            targets = data.y.squeeze()
            system_size = data.system_size.squeeze(-1)
            subsystem_size = data.nA.squeeze(-1)

            loss = criterion(pred_s, targets, system_size, subsystem_size)

            if is_training:
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            # The criterion value is treated as a per-batch mean, so it is
            # re-weighted by the batch size before averaging over the epoch.
            loss_sum += loss.item() * data.num_graphs
            abs_err_sum += torch.abs(pred_s.detach() - targets).sum().item()
            num_samples += data.num_graphs

    if num_samples == 0:
        raise ValueError("Data loader produced no samples")

    return loss_sum / num_samples, abs_err_sum / num_samples


def fine_tune_model():
    """Fine-tune the pretrained ExperimentalGNN on the combined datasets.

    Everything is driven by ``FINETUNE_CONFIG``:

    1. Log the working directory and the contents of each data directory.
    2. Build the model and load the pretrained state dict.
    3. Load all datasets via ``load_multi_datasets`` and split them 80/20
       into train/validation sets with a seeded generator.
    4. Train with AdamW and ``CosineAnnealingWarmRestarts`` (stepped once per
       epoch), logging per-graph loss and MAE for both splits.
    5. Save the state dict whenever the validation loss improves and stop
       early after ``patience`` epochs without improvement.

    Raises:
        FileNotFoundError: if the pretrained checkpoint does not exist.
        ValueError: if the combined dataset is too small to give non-empty
            training and validation splits.
    """
    setup_logging()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")

    # Print out current working directory and available files for debugging
    logging.info(f"Current working directory: {os.getcwd()}")
    for data_dir in FINETUNE_CONFIG['data_dirs']:
        if not os.path.exists(data_dir):
            logging.error(f"Directory {data_dir} does not exist!")
            continue
        logging.info(f"Directory {data_dir} exists")
        files = os.listdir(data_dir)
        logging.info(f"Files in {data_dir}: {files}")
        # Check if processed_dir exists
        processed_dir = os.path.join(data_dir, 'processed')
        if os.path.exists(processed_dir):
            logging.info(f"Processed directory exists: {processed_dir}")
            processed_files = os.listdir(processed_dir)
            logging.info(f"Files in processed dir: {processed_files}")

    # Checked outside the try block so a missing checkpoint is logged once.
    pretrained_path = FINETUNE_CONFIG['pretrained_model_path']
    if not os.path.exists(pretrained_path):
        logging.error(f"Pretrained model file not found: {pretrained_path}")
        raise FileNotFoundError(f"Pretrained model file not found: {pretrained_path}")

    # Load the pretrained model. load_state_dict requires the architecture
    # built here to match the one stored in the checkpoint.
    try:
        model = ExperimentalGNN(
            hidden_channels=512,
            dropout_p=FINETUNE_CONFIG['dropout_p']
        ).to(device)

        pretrained_state_dict = torch.load(pretrained_path, map_location=device)
        model.load_state_dict(pretrained_state_dict)
        logging.info("Loaded pretrained model successfully")
    except Exception as e:
        logging.error(f"Error loading model: {str(e)}", exc_info=True)
        raise

    # Load datasets from multiple directories
    combined_dataset = load_multi_datasets()

    # Split dataset
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    if train_size == 0 or val_size == 0:
        # Fail early with a clear message instead of an obscure sampler or
        # division error later on.
        raise ValueError(
            f"Dataset has {len(combined_dataset)} samples; need at least 2 for a train/validation split"
        )
    train_dataset, val_dataset = random_split(
        combined_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(FINETUNE_CONFIG['random_seed'])
    )

    train_loader = DataLoader(train_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=False)

    # Initialize loss and optimizer
    criterion = PhysicalScaleAwareLoss(physics_weight=0.5)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=FINETUNE_CONFIG['learning_rate'],
        weight_decay=FINETUNE_CONFIG['weight_decay']
    )

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=20,
        T_mult=2,
        eta_min=1e-7
    )

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(FINETUNE_CONFIG['num_epochs']):
        avg_train_loss, avg_train_mae = _run_epoch(
            model, train_loader, criterion, device,
            optimizer=optimizer,
            grad_clip=FINETUNE_CONFIG['grad_clip'],
        )
        avg_val_loss, avg_val_mae = _run_epoch(model, val_loader, criterion, device)
        scheduler.step()

        logging.info(f'Epoch {epoch+1}/{FINETUNE_CONFIG["num_epochs"]}:')
        logging.info(f'  Training Loss: {avg_train_loss:.6f}')
        logging.info(f'  Training MAE: {avg_train_mae:.6f}')
        logging.info(f'  Validation Loss: {avg_val_loss:.6f}')
        logging.info(f'  Validation MAE: {avg_val_mae:.6f}')

        # Save best model and early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), FINETUNE_CONFIG['finetuned_model_path'])
            logging.info(f'  Saved new best model (val_loss={best_val_loss:.6f})')
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= FINETUNE_CONFIG['patience']:
                logging.info('Early stopping triggered')
                break

    logging.info('Fine-tuning completed')
    logging.info(f'Best validation loss: {best_val_loss:.6f}')


if __name__ == "__main__":
    fine_tune_model()

2025-02-25 09:41:14 [INFO] Logging initialized
2025-02-25 09:41:14 [INFO] Using device: cuda
2025-02-25 09:41:14 [INFO] Current working directory: C:\Users\amssa\Documents\Codes\New\Von-Neumann-Entropy-GNN\Size14
2025-02-25 09:41:14 [INFO] Directory ./processed_experimentalrung7-8_10k_r6 exists
2025-02-25 09:41:14 [INFO] Files in ./processed_experimentalrung7-8_10k_r6: ['processed', 'raw']
2025-02-25 09:41:14 [INFO] Processed directory exists: ./processed_experimentalrung7-8_10k_r6\processed
2025-02-25 09:41:14 [INFO] Files in processed dir: ['data.pt', 'pre_filter.pt', 'pre_transform.pt', 'processed', 'raw']
2025-02-25 09:41:14 [INFO] Directory ./processed_experimentalrung7-8_10k_r6_2 exists
2025-02-25 09:41:14 [INFO] Files in ./processed_experimentalrung7-8_10k_r6_2: ['processed', 'raw']
2025-02-25 09:41:14 [INFO] Processed directory exists: ./processed_experimentalrung7-8_10k_r6_2\processed
2025-02-25 09:41:14 [INFO] Files in processed dir: ['data.pt', 'pre_filter.pt', 'pre_transfor

In [1]:
import os
import logging
import glob
import torch
import torch.nn as nn
from torch.utils.data import random_split, ConcatDataset
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Import your existing model and dataset classes
from Train2_12 import ExperimentalGNN, SpinSystemDataset, PhysicalScaleAwareLoss

# Fine-tuning configuration
FINETUNE_CONFIG = {
    # --- Checkpoints -------------------------------------------------------
    'pretrained_model_path': 'best_model_rung1_6_pre.pth',
    'finetuned_model_path': 'finetuned_model.pth',

    # --- Data --------------------------------------------------------------
    # Directories containing the .pt files.
    'data_dirs': [
        './processed_experimentalrung7-8_10k_r6',
        './processed_experimentalrung7-8_10k_r6_2',
    ],
    # Glob pattern used to match multiple files.
    'processed_file_pattern': 'data*.pt',
    # Optional direct paths to individual .pt files, for example
    # './processed_data/data_v1.pt' or './processed_data/data_v2.pt'.
    'alternative_file_paths': [],
    'batch_size': 128,
    'random_seed': 42,

    # --- Optimisation ------------------------------------------------------
    'learning_rate': 5e-5,
    'weight_decay': 1.5e-4,
    'num_epochs': 200,
    'patience': 50,
    'grad_clip': 0.5,

    # --- Model -------------------------------------------------------------
    'dropout_p': 0.3,

    # --- Diagnostics -------------------------------------------------------
    # True selects DEBUG-level logging in setup_logging, False selects INFO.
    'verbose_logging': True,
}

def setup_logging():
    """Configure console and file logging for the fine-tuning run.

    The level is DEBUG when ``FINETUNE_CONFIG['verbose_logging']`` is truthy
    and INFO otherwise. Records go to the handlers installed by
    ``logging.basicConfig`` and to ``finetuning.log`` in the current working
    directory.

    The function is safe to call repeatedly (for example when a notebook
    cell is re-run): the file handler is attached at most once per log file,
    and the root level is always refreshed from the current configuration.
    """
    level = logging.DEBUG if FINETUNE_CONFIG.get('verbose_logging', False) else logging.INFO
    log_format = '%(asctime)s [%(levelname)s] %(message)s'

    logging.basicConfig(
        level=level,
        format=log_format,
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    root_logger = logging.getLogger('')
    # basicConfig does nothing once the root logger has handlers, so the
    # level must be set explicitly for later calls to take effect.
    root_logger.setLevel(level)

    # FileHandler stores the absolute path in baseFilename; use it to detect
    # a handler left over from a previous call and avoid duplicated lines.
    log_path = os.path.abspath('finetuning.log')
    file_handler = next(
        (
            handler for handler in root_logger.handlers
            if isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
        ),
        None,
    )
    if file_handler is None:
        file_handler = logging.FileHandler('finetuning.log')
        file_handler.setFormatter(logging.Formatter(log_format))
        root_logger.addHandler(file_handler)
    file_handler.setLevel(level)

    logging.info("Logging initialized")

class DirectPTFileDataset(torch.utils.data.Dataset):
    """In-memory dataset built from objects stored in ``.pt`` files.

    Every path is read with ``torch.load``. A file holding a list contributes
    all of its elements; a file holding a single object with an ``x``
    attribute contributes that object. Missing files, files that fail to
    load and files in any other format are logged and skipped, so the
    resulting dataset may be empty - callers should check ``len()``.

    Attributes are ``file_paths`` (the paths as given) and ``data_list``
    (the loaded samples, in file order).
    """

    def __init__(self, file_paths):
        """Load all samples from ``file_paths`` into memory.

        Args:
            file_paths: Sized iterable of paths to ``.pt`` files.
        """
        self.file_paths = file_paths
        logging.info(f"Attempting to load {len(file_paths)} PT files directly")

        # Load all data from these files
        self.data_list = []

        for file_path in file_paths:
            if not os.path.exists(file_path):
                logging.error(f"File does not exist: {file_path}")
                continue

            try:
                logging.info(f"Loading file: {file_path}")
                # SECURITY: torch.load unpickles the file, which can run
                # arbitrary code - only point this at trusted files.
                # map_location='cpu' lets files saved from a GPU load on
                # CPU-only hosts and keeps the whole dataset off the GPU.
                data_obj = torch.load(file_path, map_location='cpu')

                if isinstance(data_obj, list):
                    logging.info(f"Loaded list of {len(data_obj)} objects from {file_path}")
                    self.data_list.extend(data_obj)
                elif hasattr(data_obj, 'x'):
                    logging.info(f"Loaded single data object from {file_path}")
                    self.data_list.append(data_obj)
                else:
                    # NOTE(review): a collated PyG (data, slices) tuple would
                    # presumably land here - confirm whether that format
                    # needs to be supported.
                    logging.warning(f"Unrecognized data format in {file_path}")
            except Exception as e:
                # Deliberate best effort: one bad file must not abort the
                # remaining ones.
                logging.error(f"Error loading {file_path}: {str(e)}", exc_info=True)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``."""
        return self.data_list[idx]

class CustomSpinSystemDataset(SpinSystemDataset):
    """SpinSystemDataset variant that never downloads or regenerates data.

    The download hooks are no-ops and the processing hooks only verify that
    ``<processed_dir>/data.pt`` already exists, raising ``FileNotFoundError``
    when it does not. Initialisation problems are logged with a traceback
    before being re-raised.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        """Initialise the dataset rooted at ``root``.

        Args:
            root: Dataset root directory; must already exist.
            transform: Forwarded unchanged to ``SpinSystemDataset``.
            pre_transform: Forwarded unchanged to ``SpinSystemDataset``.

        Raises:
            FileNotFoundError: If ``root`` does not exist.
            Exception: Anything raised by the base initialiser is logged and
                re-raised unchanged.
        """
        try:
            # Fail early with a clear message instead of letting the base
            # class trip over a missing directory.
            if not os.path.exists(root):
                logging.error(f"Directory does not exist: {root}")
                raise FileNotFoundError(f"Directory does not exist: {root}")

            super().__init__(root=root, transform=transform, pre_transform=pre_transform)

            # Diagnostics only: report what the processed directory contains.
            processed_dir = os.path.join(root, 'processed')
            if os.path.exists(processed_dir):
                logging.info(f"Processed directory exists: {processed_dir}")
                processed_files = os.listdir(processed_dir)
                logging.info(f"Files in processed directory: {processed_files}")
            else:
                logging.warning(f"No 'processed' directory found in {root}")

        except Exception as e:
            logging.error(f"Error initializing CustomSpinSystemDataset: {str(e)}", exc_info=True)
            raise

    def download(self):
        """Do nothing; this dataset never downloads raw data."""
        pass

    def process(self):
        """Verify that the processed file exists instead of building it.

        Raises:
            FileNotFoundError: If ``<processed_dir>/data.pt`` is missing.
        """
        processed_file_path = os.path.join(self.processed_dir, 'data.pt')
        if os.path.exists(processed_file_path):
            logging.info(f"Processed file already exists: {processed_file_path}")
        else:
            logging.warning(f"Processed file does not exist: {processed_file_path}")
            # No processing logic is implemented here, so a missing file is fatal.
            raise FileNotFoundError(f"Required processed file not found: {processed_file_path}")

    def _download(self):
        """Do nothing; overrides the internal download hook."""
        # NOTE(review): presumably the hook the PyG base initialiser calls - confirm.
        pass

    def _process(self):
        """Raise ``FileNotFoundError`` unless the processed file already exists."""
        # NOTE(review): presumably replaces the PyG base-class hook (which would
        # also create the processed directory) - confirm against the base class.
        processed_file_path = os.path.join(self.processed_dir, 'data.pt')
        if not os.path.exists(processed_file_path):
            logging.error(f"Processed file not found at {processed_file_path}")
            raise FileNotFoundError("Required processed file not found")

    def _merge_data(self, data1, slices1, data2, slices2):
        """Merge two collated (data, slices) pairs into one.

        Args:
            data1: First data object; its attribute names drive the merge.
            slices1: Mapping from attribute name to the slice-boundary tensor
                of ``data1``.
            data2: Second data object; must expose every attribute of ``data1``.
            slices2: Slice-boundary mapping of ``data2``.

        Returns:
            Tuple ``(merged_data, merged_slices)``. Keys that are present in
            ``slices1`` but absent from ``slices2`` are left out of
            ``merged_slices``.
        """
        merged_data = data1.__class__()

        # ``keys`` is a property in older PyTorch Geometric releases and a
        # method in newer ones; accept both so the merge works on either.
        attr_names = data1.keys() if callable(data1.keys) else data1.keys

        for key in attr_names:
            item1, item2 = getattr(data1, key), getattr(data2, key)

            if torch.is_tensor(item1) and torch.is_tensor(item2):
                # Tensors are joined along the dimension the data class
                # declares for this attribute.
                merged_attr = torch.cat([item1, item2], dim=data1.__cat_dim__(key, item1))
            else:
                # Non-tensor attributes (e.g. lists) are combined with ``+``.
                merged_attr = item1 + item2

            setattr(merged_data, key, merged_attr)

        merged_slices = {}
        for key in slices1.keys():
            if key in slices2:
                # The last boundary of the first dataset is where the second
                # dataset's entries start in the merged storage.
                offset = slices1[key][-1]

                # Skip the second dataset's leading boundary so the shared
                # boundary is not stored twice.
                second_slice_shifted = slices2[key][1:] + offset

                merged_slices[key] = torch.cat([slices1[key], second_slice_shifted])

        return merged_data, merged_slices

def load_multi_datasets():
    """Load every configured dataset and combine them into one ConcatDataset.

    Sources are tried in this order:

    1. Each directory in ``FINETUNE_CONFIG['data_dirs']`` as a
       ``CustomSpinSystemDataset``; if that fails and ``processed/data.pt``
       exists, the file is read with ``DirectPTFileDataset`` instead.
    2. The files listed in ``FINETUNE_CONFIG['alternative_file_paths']``,
       when that key is present and non-empty.
    3. Only if nothing was loaded so far: a path typed by the user on stdin,
       read directly and, failing that, via its grandparent directory as a
       ``CustomSpinSystemDataset``.

    Returns:
        A ``ConcatDataset`` over all datasets that loaded.

    Raises:
        ValueError: If the user-supplied path does not exist or no samples
            could be loaded from any source.
    """
    datasets = []
    data_dirs = FINETUNE_CONFIG['data_dirs']

    # Diagnostics only: report which standard PyG processed files are present.
    for data_dir in data_dirs:
        pyg_file_path = os.path.join(data_dir, 'processed', 'data.pt')
        if os.path.exists(pyg_file_path):
            logging.info(f"Found PyG dataset file: {pyg_file_path}")

    # Diagnostics only: list the files matching the configured pattern. The
    # matches are logged, not loaded.
    for data_dir in data_dirs:
        if os.path.exists(data_dir):
            pattern = os.path.join(data_dir, FINETUNE_CONFIG['processed_file_pattern'])
            main_files = glob.glob(pattern)

            processed_pattern = os.path.join(data_dir, 'processed', FINETUNE_CONFIG['processed_file_pattern'])
            processed_files = glob.glob(processed_pattern)

            all_files = main_files + processed_files
            logging.info(f"Found {len(all_files)} files in {data_dir} matching the pattern")
            for file in all_files:
                logging.info(f"  - {file}")

    # 1. Try loading as standard PyG SpinSystemDataset from each directory
    for data_dir in data_dirs:
        if not os.path.exists(data_dir):
            logging.error(f"Directory does not exist: {data_dir}")
            continue

        try:
            logging.info(f"Attempting to load PyG dataset from {data_dir}")
            dataset = CustomSpinSystemDataset(root=data_dir)
            datasets.append(dataset)
            logging.info(f"Successfully loaded PyG dataset from {data_dir} with {len(dataset)} samples")
        except Exception as e:
            logging.warning(f"Could not load PyG dataset from {data_dir}: {str(e)}")

            # Fall back to reading processed/data.pt directly
            processed_file = os.path.join(data_dir, 'processed', 'data.pt')
            if os.path.exists(processed_file):
                try:
                    logging.info(f"Attempting to load direct data from {processed_file}")
                    direct_dataset = DirectPTFileDataset([processed_file])
                    if len(direct_dataset) > 0:
                        datasets.append(direct_dataset)
                        logging.info(f"Loaded {len(direct_dataset)} samples directly from {processed_file}")
                except Exception as direct_e:
                    logging.error(f"Failed to load direct file {processed_file}: {str(direct_e)}")

    # 2. Try loading from additional specified PT files.
    # FINETUNE_CONFIG is a dict, so the key has to be looked up with .get();
    # hasattr() tests attributes and is always False for a dict key, which
    # would skip this step unconditionally.
    alternative_file_paths = FINETUNE_CONFIG.get('alternative_file_paths')
    if alternative_file_paths:
        direct_dataset = DirectPTFileDataset(alternative_file_paths)
        if len(direct_dataset) > 0:
            datasets.append(direct_dataset)
            logging.info(f"Loaded {len(direct_dataset)} samples from specified PT files")

    # 3. If all else fails, request the proper file path
    if not datasets:
        logging.error("""
        No datasets could be loaded from the specified directories.
        
        Please check the following:
        1. Verify that your data directories exist
        2. Check that PyG dataset files are in a 'processed/data.pt' path
        3. Try specifying direct paths to PT files in 'alternative_file_paths'
        """)

        # Get user input for data file path
        print("\nNo datasets could be loaded from the specified directories.")
        print("Please enter the path to a PyTorch Geometric dataset file (data.pt):")
        file_path = input("Path: ").strip()

        if not os.path.exists(file_path):
            raise ValueError(f"Specified file does not exist: {file_path}")

        direct_dataset = None
        try:
            direct_dataset = DirectPTFileDataset([file_path])
        except Exception as e:
            logging.error(f"Failed to load user-specified file: {str(e)}")

        if direct_dataset is not None and len(direct_dataset) > 0:
            datasets.append(direct_dataset)
            logging.info(f"Loaded {len(direct_dataset)} samples from user-specified {file_path}")
        else:
            # Last resort. DirectPTFileDataset logs and swallows its own load
            # errors, so an unusable file shows up as an empty dataset rather
            # than an exception; treat both cases the same and try the
            # directory two levels up as a PyG dataset root.
            parent_dir = os.path.dirname(os.path.dirname(file_path))
            try:
                dataset = CustomSpinSystemDataset(root=parent_dir)
                datasets.append(dataset)
                logging.info(f"Loaded PyG dataset from {parent_dir} with {len(dataset)} samples")
            except Exception as pe:
                logging.error(f"Failed to load from parent directory {parent_dir}: {str(pe)}")
                raise ValueError("No datasets could be loaded. Please check your data files.") from pe

    # Combine all datasets. A ConcatDataset with zero samples is falsy, so an
    # all-empty result is rejected as well.
    combined_dataset = ConcatDataset(datasets) if datasets else None
    if not combined_dataset:
        raise ValueError("No datasets could be loaded")

    logging.info(f"Combined dataset contains {len(combined_dataset)} samples total")
    return combined_dataset

def _run_epoch(model, loader, criterion, device, optimizer=None, grad_clip=None):
    """Run one pass over ``loader`` and return ``(mean loss, mean absolute error)``.

    The pass trains the model when ``optimizer`` is given and only evaluates
    it (no gradients, eval mode) otherwise. Both means are per graph.

    Args:
        model: Network called as ``model(batch)``.
        loader: Iterable of batches exposing ``y``, ``system_size``, ``nA``
            and ``num_graphs``.
        criterion: Called as ``criterion(pred, targets, system_size, nA)``.
        device: Device every batch is moved to.
        optimizer: Optimizer to step after each batch, or ``None`` to evaluate.
        grad_clip: Max gradient norm applied before each step; ``None``
            disables clipping.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0
    abs_err_sum = 0
    n_graphs = 0

    with torch.set_grad_enabled(training):
        for data in loader:
            # Moving the batch moves its tensor attributes with it.
            data = data.to(device)
            if training:
                optimizer.zero_grad()

            pred_s = model(data)
            targets = data.y.squeeze()
            system_size = data.system_size.squeeze(-1)
            subsystem_size = data.nA.squeeze(-1)

            loss = criterion(pred_s, targets, system_size, subsystem_size)

            if training:
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            # The loss is weighted by batch size so the epoch mean is per
            # graph even when the final batch is smaller.
            loss_sum += loss.item() * data.num_graphs
            abs_err_sum += torch.abs(pred_s - targets).sum().item()
            n_graphs += data.num_graphs

    return loss_sum / n_graphs, abs_err_sum / n_graphs


def fine_tune_model():
    """Fine-tune the pretrained ExperimentalGNN on the configured datasets.

    Loads the weights at ``FINETUNE_CONFIG['pretrained_model_path']``, splits
    the combined dataset 80/20 into train/validation with a seeded generator,
    and trains with AdamW and cosine warm restarts. The state dict with the
    lowest validation loss is written to
    ``FINETUNE_CONFIG['finetuned_model_path']``; training stops early after
    ``FINETUNE_CONFIG['patience']`` epochs without improvement.

    Raises:
        FileNotFoundError: If the pretrained model file does not exist.
        ValueError: Propagated from ``load_multi_datasets`` when no data loads.
    """
    setup_logging()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")

    # Log the working directory and data directory contents for debugging
    logging.info(f"Current working directory: {os.getcwd()}")
    for data_dir in FINETUNE_CONFIG['data_dirs']:
        if os.path.exists(data_dir):
            logging.info(f"Directory {data_dir} exists")
            files = os.listdir(data_dir)
            logging.info(f"Files in {data_dir}: {files}")
            processed_dir = os.path.join(data_dir, 'processed')
            if os.path.exists(processed_dir):
                logging.info(f"Processed directory exists: {processed_dir}")
                processed_files = os.listdir(processed_dir)
                logging.info(f"Files in processed dir: {processed_files}")
        else:
            logging.error(f"Directory {data_dir} does not exist!")

    # Load the pretrained model
    try:
        model = ExperimentalGNN(
            hidden_channels=512,
            dropout_p=FINETUNE_CONFIG['dropout_p']
        ).to(device)

        if not os.path.exists(FINETUNE_CONFIG['pretrained_model_path']):
            logging.error(f"Pretrained model file not found: {FINETUNE_CONFIG['pretrained_model_path']}")
            raise FileNotFoundError(f"Pretrained model file not found: {FINETUNE_CONFIG['pretrained_model_path']}")

        # The checkpoint is passed straight to load_state_dict, so it is a
        # plain state dict; weights_only=True avoids executing pickled code
        # from the file and silences torch.load's FutureWarning.
        pretrained_state_dict = torch.load(
            FINETUNE_CONFIG['pretrained_model_path'],
            map_location=device,
            weights_only=True
        )
        model.load_state_dict(pretrained_state_dict)
        logging.info("Loaded pretrained model successfully")
    except Exception as e:
        logging.error(f"Error loading model: {str(e)}", exc_info=True)
        raise

    # Load datasets from multiple directories
    combined_dataset = load_multi_datasets()

    # 80/20 train/validation split, reproducible through the configured seed
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    train_dataset, val_dataset = random_split(
        combined_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(FINETUNE_CONFIG['random_seed'])
    )

    train_loader = DataLoader(train_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=False)

    # Initialize loss and optimizer
    criterion = PhysicalScaleAwareLoss(physics_weight=0.5)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=FINETUNE_CONFIG['learning_rate'],
        weight_decay=FINETUNE_CONFIG['weight_decay']
    )

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=20,
        T_mult=2,
        eta_min=1e-7
    )

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(FINETUNE_CONFIG['num_epochs']):
        avg_train_loss, avg_train_mae = _run_epoch(
            model, train_loader, criterion, device,
            optimizer=optimizer,
            grad_clip=FINETUNE_CONFIG['grad_clip']
        )
        avg_val_loss, avg_val_mae = _run_epoch(model, val_loader, criterion, device)

        # The schedule advances once per epoch
        scheduler.step()

        logging.info(f'Epoch {epoch+1}/{FINETUNE_CONFIG["num_epochs"]}:')
        logging.info(f'  Training Loss: {avg_train_loss:.6f}')
        logging.info(f'  Training MAE: {avg_train_mae:.6f}')
        logging.info(f'  Validation Loss: {avg_val_loss:.6f}')
        logging.info(f'  Validation MAE: {avg_val_mae:.6f}')

        # Save best model and early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), FINETUNE_CONFIG['finetuned_model_path'])
            logging.info(f'  Saved new best model (val_loss={best_val_loss:.6f})')
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= FINETUNE_CONFIG['patience']:
                logging.info('Early stopping triggered')
                break

    logging.info('Fine-tuning completed')
    logging.info(f'Best validation loss: {best_val_loss:.6f}')

# Entry point: start fine-tuning only when this code runs as the main module,
# not when it is imported.
if __name__ == "__main__":
    fine_tune_model()



ERROR! Intel® Extension for PyTorch* needs to work with PyTorch 2.3.*, but PyTorch 2.5.1 is found. Please switch to the matching version and run again.


  pretrained_state_dict = torch.load(FINETUNE_CONFIG['pretrained_model_path'], map_location=device)
2025-03-08 11:48:31 [INFO] Loaded pretrained model successfully
2025-03-08 11:48:32 [INFO] Pretrained model performance:
2025-03-08 11:48:32 [INFO]   Validation Loss: 0.006081
2025-03-08 11:48:32 [INFO]   Validation MAE: 0.066555
2025-03-08 11:48:33 [INFO] Validation plot saved to ./validation_plots9-10_from7-8\val_plot_pretrained.png
2025-03-08 11:48:47 [INFO] Epoch 1/200:
2025-03-08 11:48:47 [INFO]   Training Loss: 0.004260
2025-03-08 11:48:47 [INFO]   Training MAE: 0.058869
2025-03-08 11:48:47 [INFO]   Validation Loss: 0.002060
2025-03-08 11:48:47 [INFO]   Validation MAE: 0.037451
2025-03-08 11:48:48 [INFO] Validation plot saved to ./validation_plots9-10_from7-8\val_plot_epoch_1.png
2025-03-08 11:48:48 [INFO]   Saved new best model (val_loss=0.002060)
2025-03-08 11:48:48 [INFO] Validation plot saved to ./validation_plots9-10_from7-8\val_plot_best_model.png
2025-03-08 11:49:02 [INFO] E

In [2]:
import os
import logging
import glob
import torch
import torch.nn as nn
from torch.utils.data import random_split, ConcatDataset
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Import your existing model and dataset classes
from Train2_12 import ExperimentalGNN, SpinSystemDataset, PhysicalScaleAwareLoss

# Fine-tuning configuration
# Single source of settings for this cell; the functions below read it by key.
FINETUNE_CONFIG = {
    # Checkpoint (state dict) loaded into the model before fine-tuning starts
    'pretrained_model_path': 'finetuned_model_0.5lr_1.5wd.pth',
    'data_dirs': [
        './processed_9-10_times5_r6'
    ],  # List of directories containing .pt files
    # Glob pattern applied to each data dir and its 'processed' subdirectory
    'processed_file_pattern': 'data*.pt',  # Pattern to match multiple files
    # Extra .pt files meant to be read directly, in addition to data_dirs
    'alternative_file_paths': [
        # Add direct paths to specific .pt files if needed
        # './processed_data/data_v1.pt',
        # './processed_data/data_v2.pt'
    ],
    # Optimisation hyperparameters
    'batch_size': 128,
    'learning_rate': 0.5e-4,
    'weight_decay': 1.5e-4,
    'num_epochs': 200,
    'patience': 50,
    'finetuned_model_path': 'finetuned_model_9-10_1.5k.pth',
    'dropout_p': 0.3,
    'grad_clip': 0.5,
    'random_seed': 42,
    # Read by setup_logging(): DEBUG level when truthy, INFO otherwise
    'verbose_logging': True  # Set to True for detailed debug information
}

def setup_logging():
    """Configure root logging to the console and to ``finetuning.log``.

    The level is DEBUG when ``FINETUNE_CONFIG['verbose_logging']`` is truthy
    and INFO otherwise. The function can be called repeatedly (for example
    when a notebook cell is re-run): the level is re-applied each time and
    the file handler is attached only once, so log lines are not duplicated.
    """
    level = logging.DEBUG if FINETUNE_CONFIG.get('verbose_logging', False) else logging.INFO
    logging.basicConfig(
        level=level,
        format='%(asctime)s [%(levelname)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    root_logger = logging.getLogger('')
    # basicConfig() does nothing once the root logger has handlers, so the
    # level has to be applied explicitly for repeated calls to take effect.
    root_logger.setLevel(level)

    # Reuse an existing handler for the same file instead of stacking a new
    # one on every call; FileHandler stores its target as an absolute path.
    log_path = os.path.abspath('finetuning.log')
    file_handler = next(
        (
            handler for handler in root_logger.handlers
            if isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
        ),
        None,
    )
    if file_handler is None:
        file_handler = logging.FileHandler('finetuning.log')
        file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
        root_logger.addHandler(file_handler)
    file_handler.setLevel(level)

    logging.info("Logging initialized")

class DirectPTFileDataset(torch.utils.data.Dataset):
    """Dataset that loads directly from specified .pt files.

    Every path is read once at construction time. A file may contain either a
    list of samples (all of them are added) or a single object exposing an
    ``x`` attribute (added as one sample). Missing files, files that fail to
    load and any other payload are logged and skipped, so construction does
    not raise for a bad file; callers should check ``len()`` afterwards.

    Args:
        file_paths: Paths of the ``.pt`` files to read, in order.
    """

    def __init__(self, file_paths):
        self.file_paths = file_paths
        logging.info(f"Attempting to load {len(file_paths)} PT files directly")

        # Samples gathered from every readable file, in file order.
        self.data_list = []

        for file_path in file_paths:
            if not os.path.exists(file_path):
                logging.error(f"File does not exist: {file_path}")
                continue

            try:
                logging.info(f"Loading file: {file_path}")
                # The payload is pickled Python objects, not bare tensors, so
                # weights_only must be False explicitly: newer PyTorch
                # releases default it to True and would reject these files,
                # leaving the dataset silently empty.
                # SECURITY: unpickling can execute arbitrary code - only load
                # files that come from a trusted source.
                # map_location='cpu' lets files written on a GPU machine load
                # on a machine without CUDA.
                data_obj = torch.load(file_path, map_location='cpu', weights_only=False)

                if isinstance(data_obj, list):
                    logging.info(f"Loaded list of {len(data_obj)} objects from {file_path}")
                    self.data_list.extend(data_obj)
                elif hasattr(data_obj, 'x'):
                    logging.info(f"Loaded single data object from {file_path}")
                    self.data_list.append(data_obj)
                else:
                    logging.warning(f"Unrecognized data format in {file_path}")
            except Exception as e:
                # Best effort by design: one bad file must not abort the rest.
                logging.error(f"Error loading {file_path}: {str(e)}", exc_info=True)

    def __len__(self):
        """Return the number of samples collected from all files."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data_list[idx]

class CustomSpinSystemDataset(SpinSystemDataset):
    """SpinSystemDataset variant that never downloads or regenerates data.

    The download hooks are no-ops and the processing hooks only verify that
    ``<processed_dir>/data.pt`` already exists, raising ``FileNotFoundError``
    when it does not. Initialisation problems are logged with a traceback
    before being re-raised.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        """Initialise the dataset rooted at ``root``.

        Args:
            root: Dataset root directory; must already exist.
            transform: Forwarded unchanged to ``SpinSystemDataset``.
            pre_transform: Forwarded unchanged to ``SpinSystemDataset``.

        Raises:
            FileNotFoundError: If ``root`` does not exist.
            Exception: Anything raised by the base initialiser is logged and
                re-raised unchanged.
        """
        try:
            # Fail early with a clear message instead of letting the base
            # class trip over a missing directory.
            if not os.path.exists(root):
                logging.error(f"Directory does not exist: {root}")
                raise FileNotFoundError(f"Directory does not exist: {root}")

            super().__init__(root=root, transform=transform, pre_transform=pre_transform)

            # Diagnostics only: report what the processed directory contains.
            processed_dir = os.path.join(root, 'processed')
            if os.path.exists(processed_dir):
                logging.info(f"Processed directory exists: {processed_dir}")
                processed_files = os.listdir(processed_dir)
                logging.info(f"Files in processed directory: {processed_files}")
            else:
                logging.warning(f"No 'processed' directory found in {root}")

        except Exception as e:
            logging.error(f"Error initializing CustomSpinSystemDataset: {str(e)}", exc_info=True)
            raise

    def download(self):
        """Do nothing; this dataset never downloads raw data."""
        pass

    def process(self):
        """Verify that the processed file exists instead of building it.

        Raises:
            FileNotFoundError: If ``<processed_dir>/data.pt`` is missing.
        """
        processed_file_path = os.path.join(self.processed_dir, 'data.pt')
        if os.path.exists(processed_file_path):
            logging.info(f"Processed file already exists: {processed_file_path}")
        else:
            logging.warning(f"Processed file does not exist: {processed_file_path}")
            # No processing logic is implemented here, so a missing file is fatal.
            raise FileNotFoundError(f"Required processed file not found: {processed_file_path}")

    def _download(self):
        """Do nothing; overrides the internal download hook."""
        # NOTE(review): presumably the hook the PyG base initialiser calls - confirm.
        pass

    def _process(self):
        """Raise ``FileNotFoundError`` unless the processed file already exists."""
        # NOTE(review): presumably replaces the PyG base-class hook (which would
        # also create the processed directory) - confirm against the base class.
        processed_file_path = os.path.join(self.processed_dir, 'data.pt')
        if not os.path.exists(processed_file_path):
            logging.error(f"Processed file not found at {processed_file_path}")
            raise FileNotFoundError("Required processed file not found")

    def _merge_data(self, data1, slices1, data2, slices2):
        """Merge two collated (data, slices) pairs into one.

        Args:
            data1: First data object; its attribute names drive the merge.
            slices1: Mapping from attribute name to the slice-boundary tensor
                of ``data1``.
            data2: Second data object; must expose every attribute of ``data1``.
            slices2: Slice-boundary mapping of ``data2``.

        Returns:
            Tuple ``(merged_data, merged_slices)``. Keys that are present in
            ``slices1`` but absent from ``slices2`` are left out of
            ``merged_slices``.
        """
        merged_data = data1.__class__()

        # ``keys`` is a property in older PyTorch Geometric releases and a
        # method in newer ones; accept both so the merge works on either.
        attr_names = data1.keys() if callable(data1.keys) else data1.keys

        for key in attr_names:
            item1, item2 = getattr(data1, key), getattr(data2, key)

            if torch.is_tensor(item1) and torch.is_tensor(item2):
                # Tensors are joined along the dimension the data class
                # declares for this attribute.
                merged_attr = torch.cat([item1, item2], dim=data1.__cat_dim__(key, item1))
            else:
                # Non-tensor attributes (e.g. lists) are combined with ``+``.
                merged_attr = item1 + item2

            setattr(merged_data, key, merged_attr)

        merged_slices = {}
        for key in slices1.keys():
            if key in slices2:
                # The last boundary of the first dataset is where the second
                # dataset's entries start in the merged storage.
                offset = slices1[key][-1]

                # Skip the second dataset's leading boundary so the shared
                # boundary is not stored twice.
                second_slice_shifted = slices2[key][1:] + offset

                merged_slices[key] = torch.cat([slices1[key], second_slice_shifted])

        return merged_data, merged_slices

def load_multi_datasets():
    """Load every configured dataset and combine them into one ConcatDataset.

    Sources are tried in this order:

    1. Each directory in ``FINETUNE_CONFIG['data_dirs']`` as a
       ``CustomSpinSystemDataset``; if that fails and ``processed/data.pt``
       exists, the file is read with ``DirectPTFileDataset`` instead.
    2. The files listed in ``FINETUNE_CONFIG['alternative_file_paths']``,
       when that key is present and non-empty.
    3. Only if nothing was loaded so far: a path typed by the user on stdin,
       read directly and, failing that, via its grandparent directory as a
       ``CustomSpinSystemDataset``.

    Returns:
        A ``ConcatDataset`` over all datasets that loaded.

    Raises:
        ValueError: If the user-supplied path does not exist or no samples
            could be loaded from any source.
    """
    datasets = []
    data_dirs = FINETUNE_CONFIG['data_dirs']

    # Diagnostics only: report which standard PyG processed files are present.
    for data_dir in data_dirs:
        pyg_file_path = os.path.join(data_dir, 'processed', 'data.pt')
        if os.path.exists(pyg_file_path):
            logging.info(f"Found PyG dataset file: {pyg_file_path}")

    # Diagnostics only: list the files matching the configured pattern. The
    # matches are logged, not loaded.
    for data_dir in data_dirs:
        if os.path.exists(data_dir):
            pattern = os.path.join(data_dir, FINETUNE_CONFIG['processed_file_pattern'])
            main_files = glob.glob(pattern)

            processed_pattern = os.path.join(data_dir, 'processed', FINETUNE_CONFIG['processed_file_pattern'])
            processed_files = glob.glob(processed_pattern)

            all_files = main_files + processed_files
            logging.info(f"Found {len(all_files)} files in {data_dir} matching the pattern")
            for file in all_files:
                logging.info(f"  - {file}")

    # 1. Try loading as standard PyG SpinSystemDataset from each directory
    for data_dir in data_dirs:
        if not os.path.exists(data_dir):
            logging.error(f"Directory does not exist: {data_dir}")
            continue

        try:
            logging.info(f"Attempting to load PyG dataset from {data_dir}")
            dataset = CustomSpinSystemDataset(root=data_dir)
            datasets.append(dataset)
            logging.info(f"Successfully loaded PyG dataset from {data_dir} with {len(dataset)} samples")
        except Exception as e:
            logging.warning(f"Could not load PyG dataset from {data_dir}: {str(e)}")

            # Fall back to reading processed/data.pt directly
            processed_file = os.path.join(data_dir, 'processed', 'data.pt')
            if os.path.exists(processed_file):
                try:
                    logging.info(f"Attempting to load direct data from {processed_file}")
                    direct_dataset = DirectPTFileDataset([processed_file])
                    if len(direct_dataset) > 0:
                        datasets.append(direct_dataset)
                        logging.info(f"Loaded {len(direct_dataset)} samples directly from {processed_file}")
                except Exception as direct_e:
                    logging.error(f"Failed to load direct file {processed_file}: {str(direct_e)}")

    # 2. Try loading from additional specified PT files.
    # FINETUNE_CONFIG is a dict, so the key has to be looked up with .get();
    # hasattr() tests attributes and is always False for a dict key, which
    # would skip this step unconditionally.
    alternative_file_paths = FINETUNE_CONFIG.get('alternative_file_paths')
    if alternative_file_paths:
        direct_dataset = DirectPTFileDataset(alternative_file_paths)
        if len(direct_dataset) > 0:
            datasets.append(direct_dataset)
            logging.info(f"Loaded {len(direct_dataset)} samples from specified PT files")

    # 3. If all else fails, request the proper file path
    if not datasets:
        logging.error("""
        No datasets could be loaded from the specified directories.
        
        Please check the following:
        1. Verify that your data directories exist
        2. Check that PyG dataset files are in a 'processed/data.pt' path
        3. Try specifying direct paths to PT files in 'alternative_file_paths'
        """)

        # Get user input for data file path
        print("\nNo datasets could be loaded from the specified directories.")
        print("Please enter the path to a PyTorch Geometric dataset file (data.pt):")
        file_path = input("Path: ").strip()

        if not os.path.exists(file_path):
            raise ValueError(f"Specified file does not exist: {file_path}")

        direct_dataset = None
        try:
            direct_dataset = DirectPTFileDataset([file_path])
        except Exception as e:
            logging.error(f"Failed to load user-specified file: {str(e)}")

        if direct_dataset is not None and len(direct_dataset) > 0:
            datasets.append(direct_dataset)
            logging.info(f"Loaded {len(direct_dataset)} samples from user-specified {file_path}")
        else:
            # Last resort. DirectPTFileDataset logs and swallows its own load
            # errors, so an unusable file shows up as an empty dataset rather
            # than an exception; treat both cases the same and try the
            # directory two levels up as a PyG dataset root.
            parent_dir = os.path.dirname(os.path.dirname(file_path))
            try:
                dataset = CustomSpinSystemDataset(root=parent_dir)
                datasets.append(dataset)
                logging.info(f"Loaded PyG dataset from {parent_dir} with {len(dataset)} samples")
            except Exception as pe:
                logging.error(f"Failed to load from parent directory {parent_dir}: {str(pe)}")
                raise ValueError("No datasets could be loaded. Please check your data files.") from pe

    # Combine all datasets. A ConcatDataset with zero samples is falsy, so an
    # all-empty result is rejected as well.
    combined_dataset = ConcatDataset(datasets) if datasets else None
    if not combined_dataset:
        raise ValueError("No datasets could be loaded")

    logging.info(f"Combined dataset contains {len(combined_dataset)} samples total")
    return combined_dataset

def _run_finetune_epoch(model, loader, criterion, device, optimizer=None, grad_clip=None):
    """Run one full pass over ``loader`` and return the accumulated metrics.

    When ``optimizer`` is given the pass is a training pass (train mode,
    backward, optional gradient-norm clipping, optimizer step). When it is
    ``None`` the pass is an evaluation pass (eval mode, gradients disabled).

    Args:
        model: Module called as ``model(data)`` on each batch.
        loader: Iterable of batches exposing ``y``, ``system_size``, ``nA``
            and ``num_graphs``.
        criterion: Callable invoked as
            ``criterion(pred, targets, system_size, subsystem_size)``.
        device: Device every batch is moved to before the forward pass.
        optimizer: Optimizer to step, or ``None`` to evaluate only.
        grad_clip: Max gradient norm; ``None`` disables clipping.

    Returns:
        Tuple ``(loss_sum, abs_err_sum, n_graphs)``: the per-batch loss
        weighted by ``num_graphs`` and summed, the summed absolute error, and
        the number of graphs seen.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    abs_err_sum = 0.0
    n_graphs = 0

    with torch.set_grad_enabled(training):
        for data in loader:
            # Moving the batch moves its tensor attributes, so y/system_size/nA
            # need no further .to(device) calls.
            data = data.to(device)
            if training:
                optimizer.zero_grad()

            pred_s = model(data)
            targets = data.y.squeeze()
            system_size = data.system_size.squeeze(-1)
            subsystem_size = data.nA.squeeze(-1)

            loss = criterion(pred_s, targets, system_size, subsystem_size)

            if training:
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            # NOTE(review): assumes pred_s and the squeezed targets share a
            # shape; otherwise this subtraction broadcasts — confirm against
            # the model's output shape.
            abs_err_sum += torch.abs(pred_s - targets).sum().item()
            n_graphs += data.num_graphs
            # Weight by batch size so the epoch average is per-sample.
            loss_sum += loss.item() * data.num_graphs

    return loss_sum, abs_err_sum, n_graphs


def fine_tune_model():
    """Fine-tune the pretrained ``ExperimentalGNN`` on the configured datasets.

    Steps: log diagnostics about each directory in
    ``FINETUNE_CONFIG['data_dirs']``, load the pretrained weights, load the
    combined dataset via ``load_multi_datasets()``, split it 80/20 with a
    seeded generator, then train with AdamW and cosine warm restarts. The
    state dict with the lowest validation loss is written to
    ``FINETUNE_CONFIG['finetuned_model_path']``; training stops early after
    ``FINETUNE_CONFIG['patience']`` epochs without improvement.

    Raises:
        FileNotFoundError: If the pretrained checkpoint path does not exist.
        ValueError: If the dataset is too small to yield non-empty train and
            validation splits.
    """
    setup_logging()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")

    # Print out current working directory and available files for debugging
    logging.info(f"Current working directory: {os.getcwd()}")
    for data_dir in FINETUNE_CONFIG['data_dirs']:
        if os.path.exists(data_dir):
            logging.info(f"Directory {data_dir} exists")
            files = os.listdir(data_dir)
            logging.info(f"Files in {data_dir}: {files}")
            # Check if processed_dir exists
            processed_dir = os.path.join(data_dir, 'processed')
            if os.path.exists(processed_dir):
                logging.info(f"Processed directory exists: {processed_dir}")
                processed_files = os.listdir(processed_dir)
                logging.info(f"Files in processed dir: {processed_files}")
        else:
            logging.error(f"Directory {data_dir} does not exist!")

    # Load the pretrained model
    try:
        model = ExperimentalGNN(
            hidden_channels=512,
            dropout_p=FINETUNE_CONFIG['dropout_p']
        ).to(device)

        # Check if pretrained model file exists
        if not os.path.exists(FINETUNE_CONFIG['pretrained_model_path']):
            logging.error(f"Pretrained model file not found: {FINETUNE_CONFIG['pretrained_model_path']}")
            raise FileNotFoundError(f"Pretrained model file not found: {FINETUNE_CONFIG['pretrained_model_path']}")

        # Load pretrained weights.
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        pretrained_state_dict = torch.load(FINETUNE_CONFIG['pretrained_model_path'], map_location=device)
        model.load_state_dict(pretrained_state_dict)
        logging.info("Loaded pretrained model successfully")
    except Exception as e:
        logging.error(f"Error loading model: {str(e)}", exc_info=True)
        raise

    # Load datasets from multiple directories
    combined_dataset = load_multi_datasets()

    # Split dataset (80% train / 20% validation, reproducible via the seed)
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    if train_size == 0 or val_size == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the epoch averages are computed.
        raise ValueError(
            f"Dataset too small to split: {len(combined_dataset)} samples "
            f"(train={train_size}, val={val_size})"
        )
    train_dataset, val_dataset = random_split(
        combined_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(FINETUNE_CONFIG['random_seed'])
    )

    train_loader = DataLoader(train_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=FINETUNE_CONFIG['batch_size'], shuffle=False)

    # Initialize loss and optimizer
    criterion = PhysicalScaleAwareLoss(physics_weight=0.5)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=FINETUNE_CONFIG['learning_rate'],
        weight_decay=FINETUNE_CONFIG['weight_decay']
    )

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=20,
        T_mult=2,
        eta_min=1e-7
    )

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(FINETUNE_CONFIG['num_epochs']):
        # Training phase
        train_loss_sum, train_abs_err, train_seen = _run_finetune_epoch(
            model,
            train_loader,
            criterion,
            device,
            optimizer=optimizer,
            grad_clip=FINETUNE_CONFIG['grad_clip'],
        )
        avg_train_loss = train_loss_sum / len(train_dataset)
        avg_train_mae = train_abs_err / train_seen

        # Validation phase (no optimizer => eval mode, gradients disabled)
        val_loss_sum, val_abs_err, val_seen = _run_finetune_epoch(
            model,
            val_loader,
            criterion,
            device,
        )
        avg_val_loss = val_loss_sum / len(val_dataset)
        avg_val_mae = val_abs_err / val_seen

        # The scheduler advances once per epoch, after validation.
        scheduler.step()

        logging.info(f'Epoch {epoch+1}/{FINETUNE_CONFIG["num_epochs"]}:')
        logging.info(f'  Training Loss: {avg_train_loss:.6f}')
        logging.info(f'  Training MAE: {avg_train_mae:.6f}')
        logging.info(f'  Validation Loss: {avg_val_loss:.6f}')
        logging.info(f'  Validation MAE: {avg_val_mae:.6f}')

        # Save best model and early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), FINETUNE_CONFIG['finetuned_model_path'])
            logging.info(f'  Saved new best model (val_loss={best_val_loss:.6f})')
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= FINETUNE_CONFIG['patience']:
                logging.info('Early stopping triggered')
                break

    logging.info('Fine-tuning completed')
    logging.info(f'Best validation loss: {best_val_loss:.6f}')

# Script entry point: start fine-tuning only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    fine_tune_model()

2025-03-10 22:38:54 [INFO] Logging initialized
2025-03-10 22:38:54 [INFO] Using device: cuda
2025-03-10 22:38:54 [INFO] Current working directory: C:\Users\amssa\Documents\Codes\New\Von-Neumann-Entropy-GNN\Size14
2025-03-10 22:38:54 [INFO] Directory ./processed_9-10_times5_r6 exists
2025-03-10 22:38:54 [INFO] Files in ./processed_9-10_times5_r6: ['processed', 'raw']
2025-03-10 22:38:54 [INFO] Processed directory exists: ./processed_9-10_times5_r6\processed
2025-03-10 22:38:54 [INFO] Files in processed dir: ['data.pt', 'pre_filter.pt', 'pre_transform.pt']
2025-03-10 22:38:55 [INFO] Loaded pretrained model successfully
2025-03-10 22:38:55 [INFO] Found PyG dataset file: ./processed_9-10_times5_r6\processed\data.pt
2025-03-10 22:38:55 [INFO] Found 1 files in ./processed_9-10_times5_r6 matching the pattern
2025-03-10 22:38:55 [INFO]   - ./processed_9-10_times5_r6\processed\data.pt
2025-03-10 22:38:55 [INFO] Attempting to load PyG dataset from ./processed_9-10_times5_r6
2025-03-10 22:38:55 [