<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Efficient_Training_with_Gradient_Checkpointing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

class WrappedModel(nn.Module):
    """Run a language model's forward pass under activation checkpointing.

    The wrapped model is called with ``input_ids`` and ``attention_mask``
    keyword arguments and must return an object exposing ``.logits``.

    The whole model is checkpointed as a single segment: its forward pass is
    re-executed during ``backward()`` instead of keeping its intermediate
    activations alive.

    Args:
        model: The module to wrap.
        pad_token_id: Token id treated as padding when the attention mask is
            built automatically. When ``None`` (the default), the
            ``pad_token_id`` of a module-level ``tokenizer`` is used if such a
            name exists; otherwise every position is attended to.
    """

    def __init__(self, model, pad_token_id=None):
        super().__init__()
        self.model = model
        self.pad_token_id = pad_token_id

    def _resolve_pad_token_id(self):
        """Return the token id to mask out, or ``None`` when there is none."""
        if self.pad_token_id is not None:
            return self.pad_token_id
        try:
            # Backward compatibility: earlier versions of this class always
            # read the padding id from a module-level ``tokenizer``.
            return tokenizer.pad_token_id
        except NameError:
            return None

    def forward(self, input_ids, attention_mask=None):
        """Return the wrapped model's logits for ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids.
            attention_mask: Optional mask passed to the model unchanged. When
                omitted, positions equal to the padding id get 0 and all
                others get 1; without a padding id the mask is all ones.

        Returns:
            The ``.logits`` attribute of the wrapped model's output.
        """
        if attention_mask is None:
            pad_id = self._resolve_pad_token_id()
            if pad_id is None:
                attention_mask = torch.ones_like(input_ids)
            else:
                attention_mask = (input_ids != pad_id).to(dtype=torch.long)

        # The inputs here are integer tensors that can never require grad.
        # The reentrant checkpoint implementation would then return an output
        # detached from autograd (no gradients for the model parameters), so
        # the non-reentrant implementation is requested explicitly.
        return checkpoint(
            self.model_forward,
            input_ids,
            attention_mask,
            use_reentrant=False,
        )

    def model_forward(self, input_ids, attention_mask):
        """Call the wrapped model and return only its logits."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

# Load the tokenizer and the student network from the same pretrained checkpoint
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
student_model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Put the student behind the checkpointing wrapper
wrapped_model = WrappedModel(student_model)

# Tokenize one short sentence to serve as a smoke-test batch
dummy_batch = tokenizer("This is a dummy input", return_tensors="pt")
dummy_input = dummy_batch["input_ids"]

# Run a single checkpointed forward pass and display the returned logits
outputs = wrapped_model(dummy_input)
print(outputs)