In [1]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_from_disk
import numpy as np

# Paths
processed_dir = "/home/aac/project-hyperion/data/processed"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [2]:
# Load WikiText-2
wikitext2 = load_from_disk(os.path.join(processed_dir, "wikitext2_tokenized"))
class WikiText2TorchDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset, split="train"):
        self.data = hf_dataset[split]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)
        return input_ids, attention_mask

wikitext2_train_ds = WikiText2TorchDataset(wikitext2, split="train")
wikitext2_loader = DataLoader(wikitext2_train_ds, batch_size=32, shuffle=True, num_workers=2)

# Load CIFAR-10
cifar10_train = torch.load(os.path.join(processed_dir, "cifar10_train.pt"))
class CIFAR10TorchDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label

cifar10_train_ds = CIFAR10TorchDataset(cifar10_train)
cifar10_loader = DataLoader(cifar10_train_ds, batch_size=32, shuffle=True, num_workers=2)

# Load tokenizer and get vocab size
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size

# Define language model
class SimpleTransformerLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc = nn.Linear(emb_dim, vocab_size)
        self.max_len = max_len
    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        logits = self.fc(x)
        return logits

model_lm = SimpleTransformerLM(vocab_size).to(device)
model_cifar = models.resnet18(num_classes=10).to(device)
print("Models initialized.")




Models initialized.


In [13]:
from torch.cuda.amp import autocast, GradScaler

def train_language_model_amp(model, dataloader, device, epochs=3):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    scaler = GradScaler()
    for epoch in range(epochs):
        total_loss = 0
        for input_ids, attention_mask in dataloader:
            input_ids = input_ids.to(device)
            targets = input_ids[:, 1:].contiguous()
            inputs = input_ids[:, :-1].contiguous()
            optimizer.zero_grad()
            with autocast():  # No device_type argument for ROCm/CUDA
                outputs = model(inputs)
                logits = outputs.reshape(-1, outputs.size(-1))
                targets_flat = targets.reshape(-1)
                loss = loss_fn(logits, targets_flat)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}")

def train_cifar_model_amp(model, dataloader, device, epochs=3):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()
    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        total = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            with autocast():
                outputs = model(images)
                loss = loss_fn(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            _, preds = outputs.max(1)
            total_correct += preds.eq(labels).sum().item()
            total += labels.size(0)
        avg_loss = total_loss / len(dataloader)
        acc = 100 * total_correct / total
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, Acc = {acc:.2f}%")

In [15]:
import warnings
warnings.filterwarnings("ignore")

In [16]:
# Train language model with AMP
print("Training language model on WikiText-2 with AMP (25 epochs)...")
train_language_model_amp(model_lm, wikitext2_loader, device, epochs=25)

# Train ResNet-18 with AMP
print("Training ResNet-18 on CIFAR-10 with AMP (25 epochs)...")
train_cifar_model_amp(model_cifar, cifar10_loader, device, epochs=25)

Training language model on WikiText-2 with AMP (25 epochs)...
Epoch 1: Avg Loss = 5.1700
Epoch 2: Avg Loss = 4.9256
Epoch 3: Avg Loss = 4.7180
Epoch 4: Avg Loss = 4.5350
Epoch 5: Avg Loss = 4.3713
Epoch 6: Avg Loss = 4.2223
Epoch 7: Avg Loss = 4.0838
Epoch 8: Avg Loss = 3.9572
Epoch 9: Avg Loss = 3.8424
Epoch 10: Avg Loss = 3.7347
Epoch 11: Avg Loss = 3.6362
Epoch 12: Avg Loss = 3.5438
Epoch 13: Avg Loss = 3.4593
Epoch 14: Avg Loss = 3.3817
Epoch 15: Avg Loss = 3.3098
Epoch 16: Avg Loss = 3.2431
Epoch 17: Avg Loss = 3.1796
Epoch 18: Avg Loss = 3.1202
Epoch 19: Avg Loss = 3.0667
Epoch 20: Avg Loss = 3.0152
Epoch 21: Avg Loss = 2.9660
Epoch 22: Avg Loss = 2.9223
Epoch 23: Avg Loss = 2.8776
Epoch 24: Avg Loss = 2.8380
Epoch 25: Avg Loss = 2.7995
Training ResNet-18 on CIFAR-10 with AMP (25 epochs)...
Epoch 1: Loss = 0.7140, Acc = 75.54%
Epoch 2: Loss = 0.6040, Acc = 79.14%
Epoch 3: Loss = 0.5048, Acc = 82.61%
Epoch 4: Loss = 0.4257, Acc = 85.37%
Epoch 5: Loss = 0.3431, Acc = 88.11%
Epoch 6

In [17]:
# Save AMP-trained models and scaler state for reproducibility
from torch.cuda.amp import GradScaler

scaler_lm = GradScaler()
scaler_cifar = GradScaler()

torch.save({
    "model_state_dict": model_lm.state_dict(),
    "scaler_state_dict": scaler_lm.state_dict()
}, os.path.join(processed_dir, "simple_transformer_lm_amp.pt"))

torch.save({
    "model_state_dict": model_cifar.state_dict(),
    "scaler_state_dict": scaler_cifar.state_dict()
}, os.path.join(processed_dir, "resnet18_cifar10_amp.pt"))

print("AMP-trained models and scaler state saved.")

AMP-trained models and scaler state saved.


In [21]:
import time
from torch.cuda.amp import autocast, GradScaler

def profile_amp_training(model, dataloader, device, n_batches=10):
    model.eval()
    scaler = GradScaler()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    start = time.time()
    torch.cuda.reset_peak_memory_stats()
    for i, (input_ids, attention_mask) in enumerate(dataloader):
        if i >= n_batches:
            break
        input_ids = input_ids.to(device)
        targets = input_ids[:, 1:].contiguous()
        inputs = input_ids[:, :-1].contiguous()
        optimizer.zero_grad()
        with autocast():  # <-- FIXED: no device_type argument
            outputs = model(inputs)
            logits = outputs.reshape(-1, outputs.size(-1))
            targets_flat = targets.reshape(-1)
            loss = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)(logits, targets_flat)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    torch.cuda.synchronize()
    end = time.time()
    print(f"AMP training: {n_batches} batches in {end-start:.2f} sec")
    print(f"Peak memory used: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

print("Profiling AMP training (language model, 10 batches)...")
profile_amp_training(model_lm, wikitext2_loader, device, n_batches=10)

Profiling AMP training (language model, 10 batches)...
AMP training: 10 batches in 0.64 sec
Peak memory used: 3465.58 MB
