In [1]:
from utils import *
from good_bad_teacher import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def test_run():
    import torch
    import logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    
    try:
        # Initialize configuration
        logger.info("Initializing configuration...")
        config_manager = ConfigManager()
        config = config_manager.config
        
        # Initialize model manager
        logger.info("Initializing model manager...")
        model_manager = ModelManager(config)
        
        # Try loading just one model first as a test
        logger.info("Testing model loading...")
        test_model = model_manager._load_model()
        logger.info("Successfully loaded test model")
        
        # If the test model loaded successfully, proceed with the rest
        logger.info("Initializing data manager...")
        data_manager = DataManager(config)
        
        # Load and prepare data
        logger.info("Loading data...")
        retain_train_loader, forget_train_loader, retain_val_loader, forget_val_loader = data_manager.create_dataloaders(batch_size=8)
        
        logger.info("Data loading completed successfully!")
        
        # Initialize all models
        logger.info("Initializing all models...")
        good_teacher, bad_teacher = model_manager.initialize_teachers()
        student = model_manager.initialize_student()
        
        logger.info("Models initialized successfully!")
        
        logger.info("Test completed successfully!")
        
        
    except Exception as e:
        logger.error(f"Error occurred: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise e

In [4]:
test_run()

INFO:__main__:Initializing configuration...
INFO:__main__:Initializing model manager...
INFO:__main__:Testing model loading...
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.71it/s]
INFO:__main__:Successfully loaded test model
INFO:__main__:Initializing data manager...
INFO:__main__:Loading data...
INFO:__main__:Data loading completed successfully!
INFO:__main__:Initializing all models...


Initializing good teacher...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.73it/s]


Initializing bad teacher...


INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Initializing student...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.74it/s]
INFO:__main__:Models initialized successfully!
INFO:__main__:Test completed successfully!


In [5]:
import torch
import logging
import traceback
from good_bad_teacher import (
    ConfigManager, 
    DataManager, 
    ModelManager,
    TeacherStudentUnlearning
)

def test_model_steps():
    """
    Test single forward pass and training step for the Good-Bad Teacher implementation.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    logger = logging.getLogger(__name__)

    try:
        # Initialize components
        logger.info("Initializing test components...")
        config_manager = ConfigManager()
        config = config_manager.config
        data_manager = DataManager(config)
        model_manager = ModelManager(config)

        # Get dataloaders
        retain_loader, forget_loader, _, _ = data_manager.create_dataloaders(batch_size=2)
        
        # Initialize models
        good_teacher, bad_teacher = model_manager.initialize_teachers()
        student = model_manager.initialize_student()
        
        # Set up device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        
        #######################
        # 1. Test Forward Pass
        #######################
        logger.info("\nSTEP 1: Testing forward pass...")
        try:
            # Move models to device
            good_teacher = good_teacher.to(device)
            bad_teacher = bad_teacher.to(device)
            student = student.to(device)
            
            # Set models to eval mode
            good_teacher.eval()
            bad_teacher.eval()
            student.eval()
            
            # Get a batch
            batch = next(iter(retain_loader))
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Test forward pass for each model
            with torch.no_grad():
                # Good teacher forward pass
                good_teacher_output = good_teacher(**batch)
                logger.info("\nGood Teacher forward pass results:")
                logger.info(f"Output keys: {good_teacher_output.keys()}")
                logger.info(f"Logits shape: {good_teacher_output.logits.shape}")
                logger.info(f"Loss value: {good_teacher_output.loss.item():.4f}")
                
                # Bad teacher forward pass
                bad_teacher_output = bad_teacher(**batch)
                logger.info("\nBad Teacher forward pass results:")
                logger.info(f"Output keys: {bad_teacher_output.keys()}")
                logger.info(f"Logits shape: {bad_teacher_output.logits.shape}")
                logger.info(f"Loss value: {bad_teacher_output.loss.item():.4f}")
                
                # Student forward pass
                student_output = student(**batch)
                logger.info("\nStudent forward pass results:")
                logger.info(f"Output keys: {student_output.keys()}")
                logger.info(f"Logits shape: {student_output.logits.shape}")
                logger.info(f"Loss value: {student_output.loss.item():.4f}")
            
            logger.info("✓ Forward passes completed successfully")
        
        except Exception as e:
            logger.error("✗ Error in forward pass testing:")
            logger.error(traceback.format_exc())
            raise

        #######################
        # 2. Test Training Step
        #######################
        logger.info("\nSTEP 2: Testing training step...")
        try:
            # Initialize unlearning system
            unlearning_system = TeacherStudentUnlearning(
                good_teacher=good_teacher,
                bad_teacher=bad_teacher,
                student=student,
                config=config
            )
            
            # Set student to training mode
            student.train()
            
            # Keep teachers in eval mode
            good_teacher.eval()
            bad_teacher.eval()
            
            # Initialize optimizer
            optimizer = torch.optim.Adam(
                student.parameters(),
                lr=config['training']['learning_rate']
            )
            
            # Test retain step
            logger.info("\nTesting retain step:")
            batch = next(iter(retain_loader))
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Record initial loss
            with torch.no_grad():
                initial_output = student(**batch)
                initial_loss = initial_output.loss.item()
            
            # Perform optimization step
            optimizer.zero_grad()
            retain_loss = unlearning_system.calculate_retain_loss(
                student(**batch).logits,
                good_teacher(**batch).logits,
                batch
            )
            retain_loss.backward()
            optimizer.step()
            
            # Record final loss
            with torch.no_grad():
                final_output = student(**batch)
                final_loss = final_output.loss.item()
            
            logger.info(f"Initial loss: {initial_loss:.4f}")
            logger.info(f"Final loss: {final_loss:.4f}")
            logger.info(f"Loss change: {initial_loss - final_loss:.4f}")
            
            # Test forget step
            logger.info("\nTesting forget step:")
            batch = next(iter(forget_loader))
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Record initial loss
            with torch.no_grad():
                initial_output = student(**batch)
                initial_loss = initial_output.loss.item()
            
            # Perform optimization step
            optimizer.zero_grad()
            forget_loss = unlearning_system.calculate_forget_loss(
                student(**batch).logits,
                bad_teacher(**batch).logits,
                batch
            )
            forget_loss.backward()
            optimizer.step()
            
            # Record final loss
            with torch.no_grad():
                final_output = student(**batch)
                final_loss = final_output.loss.item()
            
            logger.info(f"Initial loss: {initial_loss:.4f}")
            logger.info(f"Final loss: {final_loss:.4f}")
            logger.info(f"Loss change: {initial_loss - final_loss:.4f}")
            
            logger.info("✓ Training steps completed successfully")
            
        except Exception as e:
            logger.error("✗ Error in training step testing:")
            logger.error(traceback.format_exc())
            raise

        logger.info("\n🎉 All tests completed successfully!")
        return True

    except Exception as e:
        logger.error("\n❌ Testing failed with error:")
        logger.error(traceback.format_exc())
        return False

if __name__ == "__main__":
    success = test_model_steps()
    print("\n✨ Final Result:")
    print("🎉 All tests passed successfully!" if success else "❌ Tests failed. Check the logs above for details.")

INFO:__main__:Initializing test components...


Initializing good teacher...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.66it/s]
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Initializing bad teacher...
Initializing student...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.71it/s]
INFO:__main__:Using device: cuda
INFO:__main__:
STEP 1: Testing forward pass...
INFO:__main__:
Good Teacher forward pass results:
INFO:__main__:Output keys: odict_keys(['loss', 'logits', 'past_key_values'])
INFO:__main__:Logits shape: torch.Size([2, 512, 50304])
INFO:__main__:Loss value: 21.6610
INFO:__main__:
Bad Teacher forward pass results:
INFO:__main__:Output keys: odict_keys(['loss', 'logits', 'past_key_values'])
INFO:__main__:Logits shape: torch.Size([2, 512, 50304])
INFO:__main__:Loss value: 24.7344
INFO:__main__:
Student forward pass results:
INFO:__main__:Output keys: odict_keys(['loss', 'logits', 'past_key_values'])
INFO:__main__:Logits shape: torch.Size([2, 512, 50304])
INFO:__main__:Loss value: 21.6610
INFO:__main__:✓ Forward passes completed successfully
INFO:__main__:
STEP 2: Testing training step...
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-cor

INFO:__main__:
Testing retain step:
INFO:__main__:Initial loss: 21.8941
INFO:__main__:Final loss: 21.5921
INFO:__main__:Loss change: 0.3020
INFO:__main__:
Testing forget step:
INFO:__main__:Initial loss: 20.1088
INFO:__main__:Final loss: 20.4955
INFO:__main__:Loss change: -0.3867
INFO:__main__:✓ Training steps completed successfully
INFO:__main__:
🎉 All tests completed successfully!



✨ Final Result:
🎉 All tests passed successfully!


In [15]:
config = ConfigManager()
m = config.config['model']['good_teacher']['path']

model = AutoModelForCausalLM.from_pretrained(m)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 12.75it/s]


In [19]:
# At the start of your script, test this:
from transformers import AutoTokenizer

test_path = "/data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model"
print("Testing direct tokenizer load from:", test_path)
test_tokenizer = AutoTokenizer.from_pretrained(test_path, trust_remote_code=True)

Testing direct tokenizer load from: /data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model


OSError: Can't load tokenizer for '/data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure '/data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model' is the correct path to a directory containing all relevant files for a GPTNeoXTokenizerFast tokenizer.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
path = "/data1/malto/unlearning_llm/"

## Fetch and load model:
model_path = path + 'models/semeval25-unlearning-model-1B-model'
print(f"Loading model from {model_path}")
#snapshot_download(repo_id='llmunlearningsemeval2025organization/olmo-finetuned-semeval25-unlearning', token=hf_token, local_dir=model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

Loading model from /data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.88it/s]


In [16]:
config_manager = ConfigManager()
config = config_manager.config
data_manager = DataManager(config)
retain_loader, forget_loader, _, _ = data_manager.create_dataloaders(batch_size=2)
len(retain_loader)
first_batch = next(iter(retain_loader))
print("Shape of input_ids:", first_batch['input_ids'].shape)
print("Shape of attention_mask:", first_batch['attention_mask'].shape)
print("Shape of labels:", first_batch['labels'].shape)

# Assuming 'first_batch' is the batch from your DataLoader
input_ids = first_batch['input_ids']  # Tensor of input_ids

# Print out sequence lengths for each item in the batch
print("Sequence lengths in 'input_ids':")
for idx, seq in enumerate(input_ids):
    seq_length = (seq != 1).sum().item()  # Count non-padding tokens (assuming padding token=1)
    print(f"Sequence {idx + 1}: Length = {seq_length}")


Shape of input_ids: torch.Size([2, 512])
Shape of attention_mask: torch.Size([2, 512])
Shape of labels: torch.Size([2, 512])
Sequence lengths in 'input_ids':
Sequence 1: Length = 15
Sequence 2: Length = 73


In [None]:
from good_bad_teacher import (
    ConfigManager, 
    DataManager, 
    ModelManager,
    TeacherStudentUnlearning
)
import torch
import logging
import os
from datetime import datetime

def setup_logging(output_dir):
    """Set up logging configuration"""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        
    log_file = os.path.join(output_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

def train_and_evaluate():
    # Initialize configuration
    config_manager = ConfigManager()
    config = config_manager.config
    
    # Set up output directory and logging
    output_dir = os.path.join(config['checkpoints_dir'], datetime.now().strftime("%Y%m%d_%H%M%S"))
    logger = setup_logging(output_dir)
    
    try:
        # Initialize components
        logger.info("Initializing components...")
        data_manager = DataManager(config)
        model_manager = ModelManager(config)
        
        # Load data
        logger.info("Loading data...")
        retain_train_loader, forget_train_loader, retain_val_loader, forget_val_loader = \
            data_manager.create_dataloaders(batch_size=config['training'].get('batch_size', 8))
        
        # Initialize models
        logger.info("Initializing models...")
        good_teacher, bad_teacher = model_manager.initialize_teachers()
        student = model_manager.initialize_student()
        
        # Freeze teachers
        logger.info("Freezing teacher models...")
        model_manager.freeze_teachers(good_teacher, bad_teacher)
        
        # Initialize unlearning system
        logger.info("Setting up unlearning system...")
        unlearning_system = TeacherStudentUnlearning(
            good_teacher=good_teacher,
            bad_teacher=bad_teacher,
            student=student,
            config=config
        )
        
        # Training
        logger.info("Starting training...")
        training_history = unlearning_system.train_student(
            retain_loader=retain_train_loader,
            forget_loader=forget_train_loader,
            validation_loader=retain_val_loader,
            num_epochs=config['training']['num_epochs']
        )
        
        # Final evaluation
        logger.info("Performing final evaluation...")
        final_results = unlearning_system.evaluate(
            retain_loader=retain_val_loader,
            forget_loader=forget_val_loader
        )
        
        # Save final model and results
        final_checkpoint_path = os.path.join(output_dir, 'final_model.pt')
        logger.info(f"Saving final model to {final_checkpoint_path}")
        unlearning_system.save_checkpoint(
            path=final_checkpoint_path,
            epoch=config['training']['num_epochs'],
            val_loss=training_history['val_losses'][-1],
            metrics=final_results
        )
        
        return True, training_history, final_results
        
    except Exception as e:
        logger.error(f"Training failed with error: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        return False, None, None

if __name__ == "__main__":
    success, history, results = train_and_evaluate()
    
    if success:
        print("\n✨ Training completed successfully!")
        print("\nTraining History:")
        for key, value in history.items():
            print(f"{key}: {value}")
        print("\nFinal Results:")
        for key, value in results.items():
            print(f"{key}: {value}")
    else:
        print("\n❌ Training failed. Check the logs for details.")

In [8]:
config_manager = ConfigManager()
config = config_manager.config
config["loss"]["gamma"] = 0
config["retain"]["gamma"] = 0
config["forget"]["gamma"] = 0

config

{'model': {'good_teacher': {'path': '/data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model',
   'type': 'base_model'},
  'bad_teacher': {'model_id': 'EleutherAI/pythia-70m',
   'torch_dtype': 'float16'}},
 'training': {'num_epochs': 10,
  'learning_rate': 1e-05,
  'checkpoint_frequency': 5,
  'validation_frequency': 1},
 'checkpoints_dir': '/path/to/checkpoints',
 'validation': {'batch_size': 32,
  'metrics': ['perplexity', 'agreement', 'divergence']},
 'retain': {'alpha': 1.0, 'gamma': 0},
 'forget': {'beta': 0.5, 'gamma': 0},
 'loss': {'alpha': 1.0, 'beta': 0.5, 'gamma': 0},
 'data': {'base_path': '/data1/malto/unlearning_llm/datasets/semeval25-unlearning-data',
  'max_length': 512}}

In [2]:
def main():
    # Initialize configuration
    config_manager = ConfigManager()
    config = config_manager.config
    config["loss"]["gamma"] = 0
    # Initialize wandb
    wandb.init(
        project="llm-unlearning",
        config=config
    )
    
    model_manager = ModelManager(config)
    data_manager = DataManager(config)
    
    # Load and prepare data
    retain_train_loader, forget_train_loader, retain_val_loader, forget_val_loader = data_manager.create_dataloaders(batch_size=8)
    
    # Initialize models
    good_teacher, bad_teacher = model_manager.initialize_teachers()
    student = model_manager.initialize_student()
    
    # Freeze teachers
    model_manager.freeze_teachers(good_teacher, bad_teacher)
    
    # Initialize unlearning system
    unlearning = TeacherStudentUnlearning(good_teacher, bad_teacher, student, config)
    
    # Train student
    training_history = unlearning.train_student(retain_train_loader, forget_train_loader, retain_val_loader, config['training']['num_epochs'])
    
    # Final evaluation
    final_results = unlearning.evaluate(retain_val_loader, forget_val_loader)
    
    # Log final results summary
    wandb.run.summary.update({
        "final_retain_regurgitation": np.mean(final_results['retain']['regurgitation-score']),
        "final_retain_knowledge": np.mean(final_results['retain']['knowledge-score']),
        "final_forget_regurgitation": np.mean(final_results['forget']['regurgitation-score']),
        "final_forget_knowledge": np.mean(final_results['forget']['knowledge-score']),
        "training_epochs": config['training']['num_epochs'],
        "best_validation_loss": min(filter(None, training_history['val_losses']))
    })

    # Save final results to a file
    with open(os.path.join(config['checkpoints_dir'], 'final_results.json'), 'w') as f:
        json.dump(final_results, f, indent=4)
        
    # Log final results file as artifact
    final_results_artifact = wandb.Artifact(
        "final_results", 
        type="results",
        description="Final evaluation results"
    )
    final_results_artifact.add_file(os.path.join(config['checkpoints_dir'], 'final_results.json'))
    wandb.log_artifact(final_results_artifact)
    
    # Save config
    config_manager.save_config("config.json")
    
    # Clean up wandb
    unlearning.finish()

main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33merfanbyt13[0m ([33merfanbyt13-politecnico-di-torino[0m). Use [1m`wandb login --relogin`[0m to force relogin


Initializing good teacher...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.60it/s]


Initializing bad teacher...
Initializing student...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.70it/s]


Freezing teacher models...

Epoch 1/2


Training Retain: 100%|██████████| 142/142 [01:59<00:00,  1.19it/s]
Training Forget: 100%|██████████| 139/139 [01:33<00:00,  1.49it/s]


Running validation at epoch 1
Validation Loss: 8.0211
Validation Metrics: {'perplexity': 3068.4190220424107, 'good_teacher_agreement': 0.11278599330357143, 'bad_teacher_divergence': 0.42303757497242517}
Retain Loss: 2038.6643
Forget Loss: 418.2347
Validation Loss: 8.0211
Validation Metrics: {'perplexity': 3068.4190220424107, 'good_teacher_agreement': 0.11278599330357143, 'bad_teacher_divergence': 0.42303757497242517}

Epoch 2/2


Training Retain: 100%|██████████| 142/142 [02:04<00:00,  1.14it/s]
Training Forget: 100%|██████████| 139/139 [01:35<00:00,  1.46it/s]


Running validation at epoch 2
Validation Loss: 7.9593
Validation Metrics: {'perplexity': 2886.243540736607, 'good_teacher_agreement': 0.11278599330357143, 'bad_teacher_divergence': 0.42303757497242517}
Retain Loss: 1821.0704
Forget Loss: 397.6603
Validation Loss: 7.9593
Validation Metrics: {'perplexity': 2886.243540736607, 'good_teacher_agreement': 0.11278599330357143, 'bad_teacher_divergence': 0.42303757497242517}

Evaluating on retain data...


Retain Evaluation:   0%|          | 0/35 [00:00<?, ?it/s]


KeyError: 'max_new_tokens'