In [1]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_from_disk
import numpy as np

# Paths
processed_dir = "/home/aac/project-hyperion/data/processed"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [2]:
# Load tokenized WikiText-2
wikitext2 = load_from_disk(os.path.join(processed_dir, "wikitext2_tokenized"))
class WikiText2TorchDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset, split="train"):
        self.data = hf_dataset[split]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)
        return input_ids, attention_mask

wikitext2_train_ds = WikiText2TorchDataset(wikitext2, split="train")
wikitext2_loader = DataLoader(wikitext2_train_ds, batch_size=32, shuffle=True, num_workers=2)

# Load CIFAR-10
cifar10_train = torch.load(os.path.join(processed_dir, "cifar10_train.pt"))
class CIFAR10TorchDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label

cifar10_train_ds = CIFAR10TorchDataset(cifar10_train)
cifar10_loader = DataLoader(cifar10_train_ds, batch_size=32, shuffle=True, num_workers=2)

# Load models
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size

class SimpleTransformerLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc = nn.Linear(emb_dim, vocab_size)
        self.max_len = max_len
    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        logits = self.fc(x)
        return logits

model_lm = SimpleTransformerLM(vocab_size).to(device)
model_lm.load_state_dict(torch.load(os.path.join(processed_dir, "simple_transformer_lm.pt"), map_location=device))

model_cifar = models.resnet18(num_classes=10).to(device)
model_cifar.load_state_dict(torch.load(os.path.join(processed_dir, "resnet18_cifar10.pt"), map_location=device))

print("Models loaded.")



Models loaded.


In [3]:
def print_memory(prefix=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        print(f"{prefix} CUDA memory allocated: {allocated:.2f} MB, reserved: {reserved:.2f} MB")
    else:
        print(f"{prefix} CUDA not available; memory stats not shown.")

# Baseline for language model
print("Baseline memory usage: Language model")
print_memory("Before forward pass:")
input_ids, attention_mask = next(iter(wikitext2_loader))
input_ids = input_ids.to(device)
with torch.no_grad():
    outputs = model_lm(input_ids[:, :-1])
print_memory("After forward pass:")

# Baseline for vision model
print("\nBaseline memory usage: CIFAR-10 model")
print_memory("Before forward pass:")
images, labels = next(iter(cifar10_loader))
images = images.to(device)
with torch.no_grad():
    outputs = model_cifar(images)
print_memory("After forward pass:")

Baseline memory usage: Language model
Before forward pass: CUDA memory allocated: 152.93 MB, reserved: 230.00 MB
After forward pass: CUDA memory allocated: 932.96 MB, reserved: 1042.00 MB

Baseline memory usage: CIFAR-10 model
Before forward pass: CUDA memory allocated: 932.96 MB, reserved: 1042.00 MB


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


After forward pass: CUDA memory allocated: 153.34 MB, reserved: 958.00 MB


In [5]:
from torch.utils.checkpoint import checkpoint_sequential

class SimpleTransformerLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, n_heads=4, n_layers=2, max_len=128, use_checkpoint=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc = nn.Linear(emb_dim, vocab_size)
        self.max_len = max_len
        self.use_checkpoint = use_checkpoint

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)
        x = x.permute(1, 0, 2)  # (seq, batch, emb)
        if self.use_checkpoint:
            # Use checkpointing on transformer layers
            x = checkpoint_sequential(self.transformer.layers, len(self.transformer.layers), x)
        else:
            x = self.transformer(x)
        x = x.permute(1, 0, 2)
        logits = self.fc(x)
        return logits

# Instantiate with checkpointing enabled
model_lm_ckpt = SimpleTransformerLM(
    vocab_size,
    emb_dim=256,
    n_heads=4,
    n_layers=2,
    max_len=128,
    use_checkpoint=True
).to(device)
model_lm_ckpt.load_state_dict(model_lm.state_dict())
print("Checkpointed language model ready.")

Checkpointed language model ready.


In [6]:
def checkpoint_resnet_blocks(model):
    def forward_with_ckpt(x):
        x = model.conv1(x)
        x = model.bn1(x)
        x = model.relu(x)
        x = model.maxpool(x)
        layers = [model.layer1, model.layer2, model.layer3, model.layer4]
        x = checkpoint_sequential(layers, len(layers), x)
        x = model.avgpool(x)
        x = torch.flatten(x, 1)
        x = model.fc(x)
        return x
    model.forward = forward_with_ckpt
    return model

# Apply checkpointing to your already-loaded model_cifar
model_cifar_ckpt = checkpoint_resnet_blocks(model_cifar)
print("Checkpointed ResNet-18 ready.")

Checkpointed ResNet-18 ready.


In [7]:
# For language model with checkpointing
print("Memory usage: Language model with activation checkpointing")
print_memory("Before forward pass:")
input_ids, attention_mask = next(iter(wikitext2_loader))
input_ids = input_ids.to(device)
with torch.no_grad():
    outputs = model_lm_ckpt(input_ids[:, :-1])
print_memory("After forward pass:")

# For ResNet-18 with checkpointing
print("\nMemory usage: ResNet-18 with activation checkpointing")
print_memory("Before forward pass:")
images, labels = next(iter(cifar10_loader))
images = images.to(device)
with torch.no_grad():
    outputs = model_cifar_ckpt(images)
print_memory("After forward pass:")

Memory usage: Language model with activation checkpointing
Before forward pass: CUDA memory allocated: 370.23 MB, reserved: 962.00 MB
After forward pass: CUDA memory allocated: 1150.23 MB, reserved: 1742.00 MB

Memory usage: ResNet-18 with activation checkpointing
Before forward pass: CUDA memory allocated: 1150.23 MB, reserved: 1742.00 MB




After forward pass: CUDA memory allocated: 370.23 MB, reserved: 1744.00 MB


In [8]:
def train_one_batch(model, dataloader, device):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    input_ids, attention_mask = next(iter(dataloader))
    input_ids = input_ids.to(device)
    targets = input_ids[:, 1:].contiguous()
    inputs = input_ids[:, :-1].contiguous()
    optimizer.zero_grad()
    outputs = model(inputs)
    logits = outputs.reshape(-1, outputs.size(-1))
    targets = targets.reshape(-1)
    loss = loss_fn(logits, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

# Clear memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Training one batch (no checkpointing):")
print_memory("Before:")
loss = train_one_batch(model_lm, wikitext2_loader, device)
print_memory("After:")
print(f"Loss: {loss:.4f}")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("\nTraining one batch (with checkpointing):")
print_memory("Before:")
loss = train_one_batch(model_lm_ckpt, wikitext2_loader, device)
print_memory("After:")
print(f"Loss: {loss:.4f}")

Training one batch (no checkpointing):
Before: CUDA memory allocated: 370.23 MB, reserved: 962.00 MB
After: CUDA memory allocated: 478.68 MB, reserved: 4090.00 MB
Loss: 2.8528

Training one batch (with checkpointing):
Before: CUDA memory allocated: 478.68 MB, reserved: 964.00 MB
After: CUDA memory allocated: 587.06 MB, reserved: 4092.00 MB
Loss: 2.7724
