In [1]:
import os
import torch
import numpy as np

# Check versions for compatibility
print("torch:", torch.__version__)
print("numpy:", np.__version__)

# Check for AMD GPU (ROCm)
if torch.cuda.is_available():
    print("CUDA/ROCm device count:", torch.cuda.device_count())
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(0))
else:
    print("No CUDA/ROCm device detected.")

torch: 2.3.0a0+git96dd291
numpy: 1.24.4
CUDA/ROCm device count: 1
Current device: 0
Device name: AMD Instinct MI250X/MI250


In [2]:
from datasets import load_from_disk

processed_dir = "/home/aac/project-hyperion/data/processed"
wikitext2_path = os.path.join(processed_dir, "wikitext2_tokenized")
wikitext2 = load_from_disk(wikitext2_path)

print("Loaded tokenized WikiText-2 splits:", wikitext2)
print("Sample input_ids:", wikitext2["train"][0]["input_ids"][:20])
print("Sample attention_mask:", wikitext2["train"][0]["attention_mask"][:20])

Loaded tokenized WikiText-2 splits: DatasetDict({
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 2891
    })
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 23767
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 2461
    })
})
Sample input_ids: [796, 569, 18354, 7496, 17740, 6711, 796, 220, 198, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]
Sample attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [3]:
import torch

cifar10_train = torch.load(os.path.join(processed_dir, "cifar10_train.pt"))
cifar10_test = torch.load(os.path.join(processed_dir, "cifar10_test.pt"))

print("CIFAR-10 train samples:", len(cifar10_train))
print("CIFAR-10 test samples:", len(cifar10_test))
img, lbl = cifar10_train[0]
print("Sample image shape:", img.shape, "Label:", lbl)

CIFAR-10 train samples: 50000
CIFAR-10 test samples: 10000
Sample image shape: torch.Size([3, 32, 32]) Label: 6


In [4]:
from torch.utils.data import Dataset, DataLoader

# For WikiText-2, wrap Hugging Face dataset in PyTorch Dataset for DataLoader
class WikiText2TorchDataset(Dataset):
    def __init__(self, hf_dataset, split="train"):
        self.data = hf_dataset[split]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)
        return input_ids, attention_mask

# For CIFAR-10, data is already a list of (img, label) tuples
class CIFAR10TorchDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label

# Instantiate datasets
wikitext2_train_ds = WikiText2TorchDataset(wikitext2, split="train")
cifar10_train_ds = CIFAR10TorchDataset(cifar10_train)

# DataLoaders
BATCH_SIZE = 32

wikitext2_loader = DataLoader(wikitext2_train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
cifar10_loader = DataLoader(cifar10_train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

print("DataLoaders ready.")

DataLoaders ready.


In [5]:
import torch.nn as nn

class SimpleTransformerLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc = nn.Linear(emb_dim, vocab_size)
        self.max_len = max_len

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)  # (batch, seq, emb)
        x = x.permute(1, 0, 2)        # (seq, batch, emb)
        x = self.transformer(x)        # (seq, batch, emb)
        x = x.permute(1, 0, 2)        # (batch, seq, emb)
        logits = self.fc(x)            # (batch, seq, vocab)
        return logits

# Get vocab size from tokenizer config
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size

model_lm = SimpleTransformerLM(vocab_size).to("cuda" if torch.cuda.is_available() else "cpu")
print("SimpleTransformerLM initialized.")



SimpleTransformerLM initialized.


In [6]:
import torchvision.models as models

# Use a small ResNet for CIFAR-10 (ResNet-18)
model_cifar = models.resnet18(num_classes=10)
model_cifar = model_cifar.to("cuda" if torch.cuda.is_available() else "cpu")
print("ResNet-18 for CIFAR-10 initialized.")

ResNet-18 for CIFAR-10 initialized.


In [7]:
def train_language_model(model, dataloader, device, epochs=1):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    for epoch in range(epochs):
        total_loss = 0
        for input_ids, attention_mask in dataloader:
            input_ids = input_ids.to(device)
            # Shift targets for language modeling
            targets = input_ids[:, 1:].contiguous()
            inputs = input_ids[:, :-1].contiguous()
            outputs = model(inputs)
            logits = outputs[:, :, :]
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            loss = loss_fn(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: Avg Loss = {total_loss / len(dataloader):.4f}")

def train_cifar_model(model, dataloader, device, epochs=1):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        total = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, preds = outputs.max(1)
            total_correct += preds.eq(labels).sum().item()
            total += labels.size(0)
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(dataloader):.4f}, Acc = {100 * total_correct / total:.2f}%")

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Training language model on WikiText-2 (25 epoch, sanity check)...")
train_language_model(model_lm, wikitext2_loader, device, epochs=25)

print("Training ResNet-18 on CIFAR-10 (25 epoch, sanity check)...")
train_cifar_model(model_cifar, cifar10_loader, device, epochs=25)

Training language model on WikiText-2 (25 epoch, sanity check)...
Epoch 1: Avg Loss = 5.6753
Epoch 2: Avg Loss = 5.3459
Epoch 3: Avg Loss = 5.0855
Epoch 4: Avg Loss = 4.8599
Epoch 5: Avg Loss = 4.6623
Epoch 6: Avg Loss = 4.4856
Epoch 7: Avg Loss = 4.3231
Epoch 8: Avg Loss = 4.1774
Epoch 9: Avg Loss = 4.0440
Epoch 10: Avg Loss = 3.9218
Epoch 11: Avg Loss = 3.8100
Epoch 12: Avg Loss = 3.7063
Epoch 13: Avg Loss = 3.6117
Epoch 14: Avg Loss = 3.5219
Epoch 15: Avg Loss = 3.4402
Epoch 16: Avg Loss = 3.3657
Epoch 17: Avg Loss = 3.2932
Epoch 18: Avg Loss = 3.2273
Epoch 19: Avg Loss = 3.1675
Epoch 20: Avg Loss = 3.1089
Epoch 21: Avg Loss = 3.0561
Epoch 22: Avg Loss = 3.0041
Epoch 23: Avg Loss = 2.9584
Epoch 24: Avg Loss = 2.9118
Epoch 25: Avg Loss = 2.8712
Training ResNet-18 on CIFAR-10 (25 epoch, sanity check)...
Epoch 1: Loss = 1.0457, Acc = 63.47%
Epoch 2: Loss = 0.8336, Acc = 71.07%
Epoch 3: Loss = 0.6990, Acc = 75.95%
Epoch 4: Loss = 0.5893, Acc = 79.56%
Epoch 5: Loss = 0.4909, Acc = 83.03%

In [11]:
# Save models
torch.save(model_lm.state_dict(), os.path.join(processed_dir, "simple_transformer_lm.pt"))
torch.save(model_cifar.state_dict(), os.path.join(processed_dir, "resnet18_cifar10.pt"))
print("Models saved after 25 epochs.")

Models saved after 25 epochs.
