# Adjust the architecture of the pre-trained model for our task

In [1]:
import os
import json
import logging
import warnings
from datetime import datetime
from pathlib import Path
import shutil

import torch
import torch.nn as nn
from transformers import (
    ViTImageProcessor, 
    ViTForImageClassification,
    ViTConfig
)

from tqdm.auto import tqdm
import numpy as np

# Silence every Python warning (library deprecation notices included) to keep
# notebook output readable. NOTE: this filter is process-wide and also hides
# useful warnings; narrow it if something behaves unexpectedly.
warnings.filterwarnings('ignore')

# Log to both a file and the console. The log file is opened as UTF-8 because
# several messages in this notebook contain emoji, which a platform-default
# encoding such as cp1252 cannot represent (the write would fail).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('model_adjustment.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

class ModelAdjustmentConfig:
    """Configuration class for model architecture adjustment.

    Every setting defaults to the value this notebook was written with, so
    ``ModelAdjustmentConfig()`` behaves exactly as before. Pass keyword
    arguments to override individual values (e.g. model locations on another
    machine) without editing the class.

    Args:
        base_model_path: Directory holding the locally saved pre-trained model.
        output_model_path: Parent directory for the adjusted model.
        adjusted_model_name: Sub-directory (under ``output_model_path``) that
            receives the adjusted model.
        task_name: Identifier written to the logs.
        class_names: Output labels in class-index order. ``num_classes`` is
            derived from its length. Defaults to ``['normal', 'cancer']``.
        pretrained_model_name: Hugging Face Hub identifier of the pre-trained
            checkpoint, exposed for code that consumes this config.
        dropout_rate: Stored as-is. NOTE(review): not read by the adjustment
            code in this notebook; presumably intended for fine-tuning.
        label_smoothing: Stored as-is. NOTE(review): not read by the
            adjustment code in this notebook; presumably intended for
            fine-tuning.
        hidden_dropout_prob: Hidden-layer dropout probability for the model.
        attention_probs_dropout_prob: Attention dropout probability for the
            model.

    Raises:
        ValueError: If fewer than two class names are given, or a name repeats
            (a repeated name would collapse the label-to-id mapping).
    """

    def __init__(
        self,
        *,
        base_model_path='/Volumes/KODAK/folder 02/Brest_cancer_prediction/model/raw_model',
        output_model_path='/Volumes/KODAK/folder 02/Brest_cancer_prediction/model/fine_tuning_model',
        adjusted_model_name='breast_cancer_vit_adjusted',
        task_name='breast_cancer_detection',
        class_names=None,
        pretrained_model_name='google/vit-base-patch16-224',
        dropout_rate=0.1,
        label_smoothing=0.1,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
    ):
        # Model paths
        self.base_model_path = base_model_path
        self.output_model_path = output_model_path
        self.adjusted_model_name = adjusted_model_name
        self.pretrained_model_name = pretrained_model_name

        # Task-specific parameters. Copy the caller's sequence so later
        # mutation of it cannot silently change this config.
        self.task_name = task_name
        self.class_names = (
            list(class_names) if class_names is not None else ['normal', 'cancer']
        )
        if len(self.class_names) < 2:
            raise ValueError(
                f"At least two class names are required, got {self.class_names!r}"
            )
        if len(set(self.class_names)) != len(self.class_names):
            raise ValueError(f"Class names must be unique, got {self.class_names!r}")
        # Derived rather than stored separately so it can never disagree
        # with class_names (default: 2, binary cancer/no cancer).
        self.num_classes = len(self.class_names)

        # Model configuration
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        logger.info(f"Configuration initialized for {self.task_name}")
        logger.info(f"Target classes: {self.num_classes} ({', '.join(self.class_names)})")
        logger.info(f"Device: {self.device}")

class ViTModelAdjuster:
    """Adapt a pre-trained ``ViTForImageClassification`` to a new label set.

    ``run_adjustment_process`` drives the whole pipeline:

    1. Load the original model and processor.
    2. Build a model whose classifier has ``config.num_classes`` outputs.
    3. Copy every compatible pre-trained tensor into it.
    4. Validate a forward pass.
    5. Save the model, processor, metadata and README.
    6. Reload the saved copy to verify it.

    Args:
        config: Paths, label names and dropout settings for the adjustment.
    """

    # Hub checkpoint loaded when ``config.base_model_path`` does not exist.
    # It is also the base-model name written into the new config's
    # task_specific_params, the metadata JSON and the README. A
    # ``pretrained_model_name`` attribute on the config overrides it.
    DEFAULT_PRETRAINED_MODEL = 'google/vit-base-patch16-224'

    def __init__(self, config: ModelAdjustmentConfig):
        self.config = config
        self.processor = None       # image processor, set by load_original_model()
        self.original_model = None  # pre-trained model with its original head
        self.adjusted_model = None  # model with the task-specific head
        self.model_config = None    # ViTConfig of the original model

    def _pretrained_model_name(self):
        """Return the Hub checkpoint name (config override, else class default)."""
        return getattr(self.config, 'pretrained_model_name', self.DEFAULT_PRETRAINED_MODEL)

    @staticmethod
    def _image_hw(model_config):
        """Return ``(height, width)`` from a config's ``image_size`` (int or pair)."""
        size = model_config.image_size
        if isinstance(size, int):
            return size, size
        height, width = size
        return height, width

    @classmethod
    def _dummy_pixel_values(cls, model):
        """Return a random batch of one image shaped and placed for ``model``."""
        height, width = cls._image_hw(model.config)
        channels = getattr(model.config, 'num_channels', 3)
        device = next(model.parameters()).device
        return torch.randn(1, channels, height, width, device=device)

    def create_output_directory(self):
        """Create ``<output_model_path>/<adjusted_model_name>`` and return it as a str."""
        adjusted_model_path = Path(self.config.output_model_path) / self.config.adjusted_model_name
        # parents=True also creates output_model_path itself when it is missing.
        adjusted_model_path.mkdir(parents=True, exist_ok=True)

        logger.info(f"Created output directory: {adjusted_model_path}")
        return str(adjusted_model_path)

    def load_original_model(self):
        """Load the pre-trained processor, config and model.

        The local ``config.base_model_path`` is used when it exists; otherwise
        the Hub checkpoint is downloaded.

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            logger.info("Loading original ViT model and processor...")

            if os.path.exists(self.config.base_model_path):
                source = self.config.base_model_path
                logger.info(f"Loading from local path: {source}")
            else:
                source = self._pretrained_model_name()
                logger.info("Local model not found. Loading from Hugging Face...")

            self.processor = ViTImageProcessor.from_pretrained(source)
            self.model_config = ViTConfig.from_pretrained(source)
            self.original_model = ViTForImageClassification.from_pretrained(source)

            logger.info("Original model loaded successfully")
            logger.info(f"Original number of classes: {self.model_config.num_labels}")
            logger.info(f"Model architecture: {self.model_config.architectures}")

        except Exception as e:
            logger.error(f"Error loading original model: {str(e)}")
            raise

    def print_model_info(self, model, title="Model Information"):
        """Log the key configuration values and parameter counts of ``model``."""
        logger.info(f"\n{'='*50}")
        logger.info(f"{title}")
        logger.info(f"{'='*50}")

        config = model.config
        logger.info(f"Architecture: {config.architectures}")
        logger.info(f"Number of labels: {config.num_labels}")
        logger.info(f"Hidden size: {config.hidden_size}")
        logger.info(f"Number of attention heads: {config.num_attention_heads}")
        logger.info(f"Number of layers: {config.num_hidden_layers}")
        logger.info(f"Image size: {config.image_size}")
        logger.info(f"Patch size: {config.patch_size}")

        # A head without in/out feature sizes (such as nn.Identity) is skipped
        # instead of raising AttributeError.
        classifier = getattr(model, 'classifier', None)
        if hasattr(classifier, 'in_features') and hasattr(classifier, 'out_features'):
            logger.info(f"Classifier input features: {classifier.in_features}")
            logger.info(f"Classifier output features: {classifier.out_features}")

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")
        logger.info(f"{'='*50}\n")

    def adjust_model_architecture(self):
        """Build ``self.adjusted_model`` with a classifier for the target classes.

        The new config is a copy of the original model's config, with updated
        labels, problem type, dropout rates and task metadata. Pre-trained
        backbone weights are then copied in and the classifier is
        re-initialised.

        Raises:
            RuntimeError: If ``load_original_model()`` has not been run.
            Exception: Any other error is logged and re-raised.
        """
        import copy

        try:
            logger.info("Adjusting model architecture for breast cancer classification...")

            if self.original_model is None or self.model_config is None:
                raise RuntimeError(
                    "Original model is not loaded; call load_original_model() first"
                )

            # Start from the exact config the original model came from, rather
            # than re-reading it from disk or the Hub.
            new_config = copy.deepcopy(self.model_config)

            # Update configuration for the new task
            new_config.num_labels = self.config.num_classes
            new_config.id2label = {i: label for i, label in enumerate(self.config.class_names)}
            new_config.label2id = {label: i for i, label in enumerate(self.config.class_names)}
            new_config.problem_type = "single_label_classification"

            # Update dropout rates for regularization
            new_config.hidden_dropout_prob = self.config.hidden_dropout_prob
            new_config.attention_probs_dropout_prob = self.config.attention_probs_dropout_prob

            new_config.task_specific_params = {
                "task_name": self.config.task_name,
                "num_classes": self.config.num_classes,
                "class_names": self.config.class_names,
                "adjustment_date": datetime.now().isoformat(),
                "base_model": self._pretrained_model_name(),
            }

            logger.info("Creating adjusted model with new configuration...")
            self.adjusted_model = ViTForImageClassification(new_config)

            self.copy_pretrained_weights()
            self.initialize_classifier_layer()

            logger.info("Model architecture adjusted successfully!")

        except Exception as e:
            logger.error(f"Error adjusting model architecture: {str(e)}")
            raise

    def copy_pretrained_weights(self):
        """Copy every compatible non-classifier tensor into the adjusted model.

        A tensor is copied when its name exists in the adjusted model and its
        shape matches. Classifier tensors are always skipped. Anything not
        copied keeps its fresh initialisation.
        """
        logger.info("Copying pre-trained weights...")

        source_state = self.original_model.state_dict()
        target_state = self.adjusted_model.state_dict()

        transferable = {}
        skipped_layers = []
        for name, tensor in source_state.items():
            if name.startswith('classifier'):
                skipped_layers.append(f"{name} (classifier layer)")
            elif name not in target_state:
                skipped_layers.append(f"{name} (not found in target)")
            elif tensor.shape != target_state[name].shape:
                skipped_layers.append(f"{name} (shape mismatch)")
            else:
                transferable[name] = tensor

        # strict=False: tensors we deliberately did not copy (the classifier,
        # mismatches) keep their initialisation instead of raising.
        load_result = self.adjusted_model.load_state_dict(transferable, strict=False)

        # Anything outside the classifier that was not copied is still random,
        # which usually signals an architecture mismatch worth surfacing.
        uninitialised = [k for k in load_result.missing_keys if not k.startswith('classifier')]
        if uninitialised:
            logger.warning(
                f"{len(uninitialised)} non-classifier tensors were not initialised from "
                f"the pre-trained model (first few: {uninitialised[:5]})"
            )

        logger.info(f"Copied weights for {len(transferable)} layers")
        logger.info(f"Skipped {len(skipped_layers)} layers:")
        for layer in skipped_layers[:5]:  # Show first 5 skipped layers
            logger.info(f"  - {layer}")
        if len(skipped_layers) > 5:
            logger.info(f"  ... and {len(skipped_layers) - 5} more")

    def initialize_classifier_layer(self):
        """Initialise the classifier with Xavier-uniform weights and zero bias.

        A classifier that is not an ``nn.Linear`` is left untouched, with a
        warning.
        """
        logger.info("Initializing new classifier layer...")

        classifier = self.adjusted_model.classifier
        if not isinstance(classifier, nn.Linear):
            logger.warning(
                f"Classifier is {type(classifier).__name__}, not nn.Linear; "
                "skipping initialization"
            )
            return

        nn.init.xavier_uniform_(classifier.weight)
        if classifier.bias is not None:
            nn.init.constant_(classifier.bias, 0)

        logger.info(f"Classifier initialized - Input: {classifier.in_features}, Output: {classifier.out_features}")

    def validate_adjusted_model(self):
        """Run one random image through the adjusted model and check the output.

        Returns:
            bool: True if the logits have shape ``(1, num_classes)``. False on a
            wrong shape or on any error (errors are logged, not raised).
        """
        logger.info("Validating adjusted model...")

        try:
            # Sized from the model config (image_size / num_channels) and
            # placed on the model's device rather than assuming 3x224x224 CPU.
            dummy_input = self._dummy_pixel_values(self.adjusted_model)

            self.adjusted_model.eval()
            with torch.no_grad():
                logits = self.adjusted_model(pixel_values=dummy_input).logits

            expected_shape = (1, self.config.num_classes)
            actual_shape = logits.shape

            if actual_shape == expected_shape:
                logger.info(f"✓ Model validation successful! Output shape: {actual_shape}")
                probabilities = torch.softmax(logits, dim=-1)
                logger.info(f"✓ Sample output probabilities: {probabilities.squeeze().tolist()}")
                return True

            logger.error(f"✗ Model validation failed! Expected shape: {expected_shape}, Got: {actual_shape}")
            return False

        except Exception as e:
            logger.error(f"✗ Model validation failed with error: {str(e)}")
            return False

    def save_adjusted_model(self):
        """Save the model, processor, metadata JSON and README to the output directory.

        Returns:
            str: The directory everything was written to.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            output_dir = self.create_output_directory()
            logger.info(f"Saving adjusted model to: {output_dir}")

            self.adjusted_model.save_pretrained(output_dir)
            logger.info("✓ Model saved successfully")

            self.processor.save_pretrained(output_dir)
            logger.info("✓ Processor saved successfully")

            model_config = self.adjusted_model.config
            height, width = self._image_hw(model_config)
            metadata = {
                "adjustment_info": {
                    "original_model": self._pretrained_model_name(),
                    "task": self.config.task_name,
                    "num_classes": self.config.num_classes,
                    "class_names": self.config.class_names,
                    "adjustment_date": datetime.now().isoformat()
                },
                "model_config": {
                    "hidden_size": model_config.hidden_size,
                    "num_attention_heads": model_config.num_attention_heads,
                    "num_hidden_layers": model_config.num_hidden_layers,
                    "image_size": model_config.image_size,
                    "patch_size": model_config.patch_size,
                    "num_labels": model_config.num_labels
                },
                "usage_instructions": {
                    "loading": "Use ViTForImageClassification.from_pretrained() to load this model",
                    "processor": "Use ViTImageProcessor.from_pretrained() to load the processor",
                    "input_size": f"{height}x{width} RGB images",
                    "output": f"{self.config.num_classes} class probabilities"
                }
            }

            metadata_path = os.path.join(output_dir, 'adjustment_metadata.json')
            # Explicit UTF-8 so the output does not depend on the platform locale.
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)
            logger.info(f"✓ Metadata saved to: {metadata_path}")

            readme_content = self.create_readme_content(output_dir)
            readme_path = os.path.join(output_dir, 'README.md')
            with open(readme_path, 'w', encoding='utf-8') as f:
                f.write(readme_content)
            logger.info(f"✓ README created: {readme_path}")

            return output_dir

        except Exception as e:
            logger.error(f"Error saving adjusted model: {str(e)}")
            raise

    def create_readme_content(self, model_path):
        """Return Markdown usage notes for the adjusted model saved at ``model_path``.

        Architecture figures are read from ``self.adjusted_model``, which must
        already be set.
        """
        model_config = self.adjusted_model.config
        height, width = self._image_hw(model_config)
        task_kind = 'Binary' if self.config.num_classes == 2 else 'Multi-class'
        total_params = sum(p.numel() for p in self.adjusted_model.parameters())
        # repr() yields a correctly quoted/escaped Python string literal for any path.
        path_literal = repr(str(model_path))
        readme = f"""# Breast Cancer Detection ViT Model (Adjusted)

## Model Information
- **Base Model**: {self._pretrained_model_name()}
- **Task**: {task_kind} classification for breast cancer detection
- **Classes**: {self.config.num_classes} ({', '.join(self.config.class_names)})
- **Adjustment Date**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Model Architecture
- **Input Size**: {height}x{width} RGB images
- **Hidden Size**: {model_config.hidden_size}
- **Attention Heads**: {model_config.num_attention_heads}
- **Layers**: {model_config.num_hidden_layers}
- **Parameters**: {total_params:,}

## Usage

### Loading the Model
```python
from transformers import ViTForImageClassification, ViTImageProcessor

# Load model and processor
model = ViTForImageClassification.from_pretrained({path_literal})
processor = ViTImageProcessor.from_pretrained({path_literal})
```

### Inference Example
```python
import torch
from PIL import Image

# Load and preprocess image
image = Image.open('path/to/mri_image.jpg')
inputs = processor(images=image, return_tensors="pt")

# Make prediction
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

# Get predicted class
predicted_class_id = predictions.argmax().item()
predicted_class = {self.config.class_names}[predicted_class_id]
confidence = predictions[0][predicted_class_id].item()

print(f"Prediction: {{predicted_class}} (Confidence: {{confidence:.4f}})")
```

## Fine-tuning Ready
This model is ready for fine-tuning on your breast cancer dataset. The classifier layer has been properly initialized for {task_kind.lower()} classification.

## Notes
- The model uses the original ViT-Base architecture with an adjusted classifier layer
- All pre-trained weights from ImageNet are preserved except for the final classification layer
- The classifier layer has been randomly initialized and needs fine-tuning on your dataset
"""
        return readme

    def run_adjustment_process(self):
        """Run load -> adjust -> validate -> save -> verify.

        Returns:
            str: Directory of the saved adjusted model.

        Raises:
            RuntimeError: If forward-pass validation or saved-model verification
                fails.
            Exception: Any error from an individual step is logged and
                re-raised.
        """
        try:
            logger.info("Starting ViT model adjustment process...")

            with tqdm(total=5, desc="Model Adjustment Progress") as pbar:
                # Step 1: Load original model
                self.load_original_model()
                pbar.set_description("Original model loaded")
                pbar.update(1)
                self.print_model_info(self.original_model, "Original Model Information")

                # Step 2: Adjust architecture
                self.adjust_model_architecture()
                pbar.set_description("Architecture adjusted")
                pbar.update(1)
                self.print_model_info(self.adjusted_model, "Adjusted Model Information")

                # Step 3: Validate adjusted model
                if not self.validate_adjusted_model():
                    raise RuntimeError("Model validation failed!")
                pbar.set_description("Model validated")
                pbar.update(1)

                # Step 4: Save adjusted model
                output_path = self.save_adjusted_model()
                pbar.set_description("Model saved")
                pbar.update(1)

                # Step 5: Final verification. Its result must gate success;
                # previously a failed verification was silently ignored.
                if not self.verify_saved_model(output_path):
                    raise RuntimeError("Saved model verification failed!")
                pbar.set_description("Verification complete")
                pbar.update(1)

            logger.info("✅ Model adjustment completed successfully!")
            logger.info(f"Adjusted model saved to: {output_path}")

            return output_path

        except Exception as e:
            logger.error(f"❌ Model adjustment failed: {str(e)}")
            raise

    def verify_saved_model(self, model_path):
        """Reload the saved model and processor and check they work.

        Verification passes when the reloaded model produces logits of shape
        ``(1, num_classes)``. When the in-memory adjusted model is available,
        the reloaded model's logits must also match its logits on the same
        input, so weights lost on save are caught.

        Returns:
            bool: True on success. False on any failure (errors are logged, not
            raised).
        """
        try:
            logger.info("Verifying saved model...")

            loaded_model = ViTForImageClassification.from_pretrained(model_path)
            # Loading the processor checks that its files were written correctly.
            ViTImageProcessor.from_pretrained(model_path)

            dummy_input = self._dummy_pixel_values(loaded_model)

            loaded_model.eval()
            with torch.no_grad():
                logits = loaded_model(pixel_values=dummy_input).logits

            if logits.shape != (1, self.config.num_classes):
                logger.error("❌ Saved model verification failed!")
                return False

            if self.adjusted_model is not None:
                reference_device = next(self.adjusted_model.parameters()).device
                self.adjusted_model.eval()
                with torch.no_grad():
                    reference = self.adjusted_model(
                        pixel_values=dummy_input.to(reference_device)
                    ).logits
                # Small tolerance absorbs kernel-level float differences only.
                if not torch.allclose(logits, reference.to(logits.device), atol=1e-4):
                    logger.error(
                        "❌ Saved model verification failed! Reloaded weights do not "
                        "reproduce the in-memory model's outputs"
                    )
                    return False

            logger.info("✅ Saved model verification successful!")
            return True

        except Exception as e:
            logger.error(f"❌ Error verifying saved model: {str(e)}")
            return False

def main():
    """Run the full ViT adjustment pipeline with the default configuration.

    Returns:
        str: Directory the adjusted model, processor, metadata and README
        were written to.

    Raises:
        Exception: Any error from configuration or the pipeline is logged
            together with its traceback and then re-raised unchanged.
    """
    logger.info("🚀 Starting ViT Model Adjustment for Breast Cancer Detection")

    try:
        config = ModelAdjustmentConfig()
        adjuster = ViTModelAdjuster(config)
        output_path = adjuster.run_adjustment_process()
    except Exception as e:
        # logger.exception also records the traceback. The __main__ handler
        # prints only str(e) and sends the user to the logs for details, so
        # the details must actually be there.
        logger.exception(f"💥 Process failed: {str(e)}")
        raise

    logger.info("🎉 Process completed successfully!")
    logger.info(f"📁 Adjusted model location: {output_path}")
    logger.info("📋 Next steps:")
    logger.info("   1. Use this adjusted model for fine-tuning on your dataset")
    logger.info("   2. The model is ready for training with your breast cancer data")
    logger.info("   3. Check the README.md file for usage instructions")

    return output_path

# Entry point. In a Jupyter kernel the cell runs with __name__ == "__main__",
# so executing the cell triggers this block.
if __name__ == "__main__":
    separator = "=" * 60
    banner = (
        "🔬 ViT Model Adjustment for Breast Cancer Detection",
        separator,
        "This script will adjust the ViT model architecture for binary classification.",
        "The model will be ready for fine-tuning on your breast cancer dataset.",
        separator,
    )
    for banner_line in banner:
        print(banner_line)

    # Failures are reported on stdout instead of propagating, so the cell
    # finishes without an uncaught exception.
    try:
        saved_to = main()
        print(f"\n✅ SUCCESS: Model adjusted and saved to: {saved_to}")
    except Exception as err:
        print(f"\n❌ ERROR: {str(err)}")
        print("Please check the logs for more details.")

2025-07-01 08:47:16,605 - INFO - 🚀 Starting ViT Model Adjustment for Breast Cancer Detection
2025-07-01 08:47:16,606 - INFO - Configuration initialized for breast_cancer_detection
2025-07-01 08:47:16,607 - INFO - Target classes: 2 (normal, cancer)
2025-07-01 08:47:16,607 - INFO - Device: cpu
2025-07-01 08:47:16,607 - INFO - Starting ViT model adjustment process...


🔬 ViT Model Adjustment for Breast Cancer Detection
This script will adjust the ViT model architecture for binary classification.
The model will be ready for fine-tuning on your breast cancer dataset.


Model Adjustment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

2025-07-01 08:47:16,628 - INFO - Loading original ViT model and processor...
2025-07-01 08:47:16,637 - INFO - Loading from local path: /Volumes/KODAK/folder 02/Brest_cancer_prediction/model/raw_model
2025-07-01 08:47:16,884 - INFO - Original model loaded successfully
2025-07-01 08:47:16,884 - INFO - Original number of classes: 1000
2025-07-01 08:47:16,885 - INFO - Model architecture: ['ViTForImageClassification']
2025-07-01 08:47:16,886 - INFO - 
2025-07-01 08:47:16,886 - INFO - Original Model Information
2025-07-01 08:47:16,888 - INFO - Architecture: ['ViTForImageClassification']
2025-07-01 08:47:16,888 - INFO - Number of labels: 1000
2025-07-01 08:47:16,889 - INFO - Hidden size: 768
2025-07-01 08:47:16,889 - INFO - Number of attention heads: 12
2025-07-01 08:47:16,890 - INFO - Number of layers: 12
2025-07-01 08:47:16,890 - INFO - Image size: 224
2025-07-01 08:47:16,891 - INFO - Patch size: 16
2025-07-01 08:47:16,891 - INFO - Classifier input features: 768
2025-07-01 08:47:16,892 - IN


✅ SUCCESS: Model adjusted and saved to: /Volumes/KODAK/folder 02/Brest_cancer_prediction/model/fine_tuning_model/breast_cancer_vit_adjusted
