In [1]:
from utils import *
from good_bad_teacher import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def test_run():
    import torch
    import logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    
    try:
        # Initialize configuration
        logger.info("Initializing configuration...")
        config_manager = ConfigManager()
        config = config_manager.config
        
        # Initialize model manager
        logger.info("Initializing model manager...")
        model_manager = ModelManager(config)
        
        # Try loading just one model first as a test
        logger.info("Testing model loading...")
        test_model = model_manager._load_model()
        logger.info("Successfully loaded test model")
        
        # If the test model loaded successfully, proceed with the rest
        logger.info("Initializing data manager...")
        data_manager = DataManager(config)
        
        # Load and prepare data
        logger.info("Loading data...")
        retain_train_loader, forget_train_loader, retain_val_loader, forget_val_loader = data_manager.create_dataloaders(batch_size=8)
        
        logger.info("Data loading completed successfully!")
        
        # Initialize all models
        logger.info("Initializing all models...")
        good_teacher, bad_teacher = model_manager.initialize_teachers()
        student = model_manager.initialize_student()
        
        logger.info("Models initialized successfully!")
        
        logger.info("Test completed successfully!")
        
        
    except Exception as e:
        logger.error(f"Error occurred: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise e

In [3]:
test_run()

INFO:__main__:Initializing configuration...
INFO:__main__:Initializing model manager...
INFO:__main__:Testing model loading...
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.91it/s]
INFO:__main__:Successfully loaded test model
INFO:__main__:Initializing data manager...
INFO:__main__:Loading data...
INFO:__main__:Data loading completed successfully!
INFO:__main__:Initializing all models...


Initializing good teacher...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.92it/s]


Initializing bad teacher...


INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Initializing student...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.90it/s]
INFO:__main__:Models initialized successfully!
INFO:__main__:Test completed successfully!


In [4]:
import torch
import logging
import traceback
from good_bad_teacher import (
    ConfigManager, 
    DataManager, 
    ModelManager,
    TeacherStudentUnlearning
)

def test_model_steps():
    """
    Test single forward pass and training step for the Good-Bad Teacher implementation.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    logger = logging.getLogger(__name__)

    try:
        # Initialize components
        logger.info("Initializing test components...")
        config_manager = ConfigManager()
        config = config_manager.config
        data_manager = DataManager(config)
        model_manager = ModelManager(config)

        # Get dataloaders
        retain_loader, forget_loader, _, _ = data_manager.create_dataloaders(batch_size=2)
        
        # Initialize models
        good_teacher, bad_teacher = model_manager.initialize_teachers()
        student = model_manager.initialize_student()
        
        # Set up device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        
        #######################
        # 1. Test Forward Pass
        #######################
        logger.info("\nSTEP 1: Testing forward pass...")
        try:
            # Move models to device
            good_teacher = good_teacher.to(device)
            bad_teacher = bad_teacher.to(device)
            student = student.to(device)
            
            # Set models to eval mode
            good_teacher.eval()
            bad_teacher.eval()
            student.eval()
            
            # Get a batch
            batch = next(iter(retain_loader))
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Test forward pass for each model
            with torch.no_grad():
                # Good teacher forward pass
                good_teacher_output = good_teacher(**batch)
                logger.info("\nGood Teacher forward pass results:")
                logger.info(f"Output keys: {good_teacher_output.keys()}")
                logger.info(f"Logits shape: {good_teacher_output.logits.shape}")
                logger.info(f"Loss value: {good_teacher_output.loss.item():.4f}")
                
                # Bad teacher forward pass
                bad_teacher_output = bad_teacher(**batch)
                logger.info("\nBad Teacher forward pass results:")
                logger.info(f"Output keys: {bad_teacher_output.keys()}")
                logger.info(f"Logits shape: {bad_teacher_output.logits.shape}")
                logger.info(f"Loss value: {bad_teacher_output.loss.item():.4f}")
                
                # Student forward pass
                student_output = student(**batch)
                logger.info("\nStudent forward pass results:")
                logger.info(f"Output keys: {student_output.keys()}")
                logger.info(f"Logits shape: {student_output.logits.shape}")
                logger.info(f"Loss value: {student_output.loss.item():.4f}")
            
            logger.info("✓ Forward passes completed successfully")
        
        except Exception as e:
            logger.error("✗ Error in forward pass testing:")
            logger.error(traceback.format_exc())
            raise

        #######################
        # 2. Test Training Step
        #######################
        logger.info("\nSTEP 2: Testing training step...")
        try:
            # Initialize unlearning system
            unlearning_system = TeacherStudentUnlearning(
                good_teacher=good_teacher,
                bad_teacher=bad_teacher,
                student=student,
                config=config
            )
            
            # Set student to training mode
            student.train()
            
            # Keep teachers in eval mode
            good_teacher.eval()
            bad_teacher.eval()
            
            # Initialize optimizer
            optimizer = torch.optim.Adam(
                student.parameters(),
                lr=config['training']['learning_rate']
            )
            
            # Test retain step
            logger.info("\nTesting retain step:")
            batch = next(iter(retain_loader))
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Record initial loss
            with torch.no_grad():
                initial_output = student(**batch)
                initial_loss = initial_output.loss.item()
            
            # Perform optimization step
            optimizer.zero_grad()
            retain_loss = unlearning_system.calculate_retain_loss(
                student(**batch).logits,
                good_teacher(**batch).logits,
                batch
            )
            retain_loss.backward()
            optimizer.step()
            
            # Record final loss
            with torch.no_grad():
                final_output = student(**batch)
                final_loss = final_output.loss.item()
            
            logger.info(f"Initial loss: {initial_loss:.4f}")
            logger.info(f"Final loss: {final_loss:.4f}")
            logger.info(f"Loss change: {initial_loss - final_loss:.4f}")
            
            # Test forget step
            logger.info("\nTesting forget step:")
            batch = next(iter(forget_loader))
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Record initial loss
            with torch.no_grad():
                initial_output = student(**batch)
                initial_loss = initial_output.loss.item()
            
            # Perform optimization step
            optimizer.zero_grad()
            forget_loss = unlearning_system.calculate_forget_loss(
                student(**batch).logits,
                bad_teacher(**batch).logits,
                batch
            )
            forget_loss.backward()
            optimizer.step()
            
            # Record final loss
            with torch.no_grad():
                final_output = student(**batch)
                final_loss = final_output.loss.item()
            
            logger.info(f"Initial loss: {initial_loss:.4f}")
            logger.info(f"Final loss: {final_loss:.4f}")
            logger.info(f"Loss change: {initial_loss - final_loss:.4f}")
            
            logger.info("✓ Training steps completed successfully")
            
        except Exception as e:
            logger.error("✗ Error in training step testing:")
            logger.error(traceback.format_exc())
            raise

        logger.info("\n🎉 All tests completed successfully!")
        return True

    except Exception as e:
        logger.error("\n❌ Testing failed with error:")
        logger.error(traceback.format_exc())
        return False

if __name__ == "__main__":
    success = test_model_steps()
    print("\n✨ Final Result:")
    print("🎉 All tests passed successfully!" if success else "❌ Tests failed. Check the logs above for details.")

INFO:__main__:Initializing test components...


Initializing good teacher...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.83it/s]
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Initializing bad teacher...
Initializing student...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.91it/s]
INFO:__main__:Using device: cuda
INFO:__main__:
STEP 1: Testing forward pass...
INFO:__main__:
Good Teacher forward pass results:
INFO:__main__:Output keys: odict_keys(['loss', 'logits', 'past_key_values'])
INFO:__main__:Logits shape: torch.Size([2, 512, 50304])
INFO:__main__:Loss value: 20.0908
INFO:__main__:
Bad Teacher forward pass results:
INFO:__main__:Output keys: odict_keys(['loss', 'logits', 'past_key_values'])
INFO:__main__:Logits shape: torch.Size([2, 512, 50304])
INFO:__main__:Loss value: 16.7031
INFO:__main__:
Student forward pass results:
INFO:__main__:Output keys: odict_keys(['loss', 'logits', 'past_key_values'])
INFO:__main__:Logits shape: torch.Size([2, 512, 50304])
INFO:__main__:Loss value: 20.0908
INFO:__main__:✓ Forward passes completed successfully
INFO:__main__:
STEP 2: Testing training step...
INFO:__main__:
Testing retain step:
INFO:__main__:Initial loss: 22.4500
INFO:__main__:Final loss: 21


✨ Final Result:
🎉 All tests passed successfully!


In [15]:
config = ConfigManager()
m = config.config['model']['good_teacher']['path']

model = AutoModelForCausalLM.from_pretrained(m)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 12.75it/s]


In [19]:
# At the start of your script, test this:
from transformers import AutoTokenizer

test_path = "/data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model"
print("Testing direct tokenizer load from:", test_path)
test_tokenizer = AutoTokenizer.from_pretrained(test_path, trust_remote_code=True)

Testing direct tokenizer load from: /data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model


OSError: Can't load tokenizer for '/data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure '/data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model' is the correct path to a directory containing all relevant files for a GPTNeoXTokenizerFast tokenizer.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
path = "/data1/malto/unlearning_llm/"

## Fetch and load model:
model_path = path + 'models/semeval25-unlearning-model-1B-model'
print(f"Loading model from {model_path}")
#snapshot_download(repo_id='llmunlearningsemeval2025organization/olmo-finetuned-semeval25-unlearning', token=hf_token, local_dir=model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

Loading model from /data1/malto/unlearning_llm/models/semeval25-unlearning-model-1B-model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.88it/s]
