In [1]:
from utils import *
from good_bad_teacher import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def test_run():
    import torch
    import logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    
    try:
        # Initialize configuration
        logger.info("Initializing configuration...")
        config_manager = ConfigManager()
        config = config_manager.config
        
        # Initialize model manager
        logger.info("Initializing model manager...")
        model_manager = ModelManager(config)
        
        # Try loading just one model first as a test
        logger.info("Testing model loading...")
        test_model = model_manager._load_model()
        logger.info("Successfully loaded test model")
        
        # If the test model loaded successfully, proceed with the rest
        logger.info("Initializing data manager...")
        data_manager = DataManager(config)
        
        # Load and prepare data
        logger.info("Loading data...")
        retain_loader, forget_loader = data_manager.create_dataloaders()
        
        logger.info("Data loading completed successfully!")
        
        # Initialize all models
        logger.info("Initializing all models...")
        good_teacher, bad_teacher = model_manager.initialize_teachers()
        student = model_manager.initialize_student()
        
        logger.info("Models initialized successfully!")
        
        # Continue with the rest of the process...
        logger.info("Starting training process...")
        
        # Train teachers
        logger.info("Training teachers...")
        teacher_trainer = TeacherTrainer(config)
        teacher_trainer.train_good_teacher(good_teacher, retain_loader)
        teacher_trainer.train_bad_teacher(bad_teacher, forget_loader)
        
        # Freeze teachers
        logger.info("Freezing teachers...")
        model_manager.freeze_teachers(good_teacher, bad_teacher)
        
        # Initialize unlearning system
        logger.info("Initializing unlearning system...")
        unlearning = TeacherStudentUnlearning(good_teacher, bad_teacher, student, config)
        
        # Train student
        logger.info("Training student...")
        unlearning.train(retain_loader, config['training']['num_epochs'])
        
        logger.info("Test completed successfully!")
        
        
    except Exception as e:
        logger.error(f"Error occurred: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise e

In [3]:
# Execute the full pipeline defined by test_run(); it has no return value, so
# the cell shows only its log output.
# NOTE(review): the saved output ends at "Models initialized successfully!" —
# the teacher/student training stages never logged completion in this run
# (interrupted or still running?). Re-run before relying on any results.
# NOTE(review): the cell defining test_run is marked `In [None]` while this one
# is `In [3]`, so this output may come from an earlier version of test_run —
# Restart Kernel & Run All to confirm.
test_run()

INFO:__main__:Initializing configuration...
INFO:__main__:Initializing model manager...
INFO:__main__:Testing model loading...
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.87it/s]
INFO:__main__:Successfully loaded test model
INFO:__main__:Initializing data manager...
Generating train split: 100%|██████████| 3600/3600 [00:00<00:00, 1003022.08 examples/s]
Generating train split: 100%|██████████| 40/40 [00:00<00:00, 40108.09 examples/s]
INFO:__main__:Loading data...
INFO:__main__:Data loading completed successfully!
INFO:__main__:Initializing all models...


Initializing good teacher...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.96it/s]


Initializing bad teacher...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.98it/s]


Initializing student...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.18it/s]
INFO:__main__:Models initialized successfully!
