<a href="https://colab.research.google.com/github/ayagup/stablediffusion/blob/main/hello_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.runtime as xr
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np



In [2]:
# Simple neural network model
class SimpleNet(nn.Module):
    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x


In [3]:

# Create synthetic dataset
def create_synthetic_data(num_samples=1000, input_size=784, num_classes=10):
    X = torch.randn(num_samples, input_size)
    y = torch.randint(0, num_classes, (num_samples,))
    return X, y


In [4]:

# Fixed training function
def train_model():
    print("Creating synthetic dataset...")
    X_train, y_train = create_synthetic_data(num_samples=1000)
    X_test, y_test = create_synthetic_data(num_samples=200)

    # Create datasets and dataloaders
    train_dataset = TensorDataset(X_train, y_train)
    test_dataset = TensorDataset(X_test, y_test)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Move data to TPU using parallel loader
    device = xm.xla_device()
    train_loader = pl.ParallelLoader(train_loader, [device])
    test_loader = pl.ParallelLoader(test_loader, [device])

    print("Initializing model...")
    model = SimpleNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("Starting training...")
    num_epochs = 3

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_correct = 0
        epoch_total = 0
        batch_count = 0

        for batch_idx, (data, target) in enumerate(train_loader.per_device_loader(device)):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # Use XLA optimizer step
            xm.optimizer_step(optimizer)

            # Accumulate statistics
            epoch_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            batch_total = target.size(0)
            batch_correct = (predicted == target).sum().item()

            epoch_total += batch_total
            epoch_correct += batch_correct
            batch_count += 1

            if batch_idx % 10 == 0:
                batch_acc = 100.0 * batch_correct / batch_total if batch_total > 0 else 0.0
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, '
                      f'Loss: {loss.item():.4f}, Batch Acc: {batch_acc:.2f}%')

        # Print epoch results
        avg_loss = epoch_loss / batch_count if batch_count > 0 else 0.0
        accuracy = 100.0 * epoch_correct / epoch_total if epoch_total > 0 else 0.0
        print(f'Epoch {epoch+1} completed - Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

        # Evaluation every epoch
        model.eval()
        test_correct = 0
        test_total = 0
        test_loss = 0
        test_batches = 0

        with torch.no_grad():
            for data, target in test_loader.per_device_loader(device):
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = torch.max(output.data, 1)
                test_total += target.size(0)
                test_correct += (predicted == target).sum().item()
                test_batches += 1

        if test_total > 0:
            test_accuracy = 100.0 * test_correct / test_total
            avg_test_loss = test_loss / test_batches if test_batches > 0 else 0.0
            print(f'Test - Loss: {avg_test_loss:.4f}, Accuracy: {test_accuracy:.2f}%')

        print("-" * 50)

    print("Training completed!")
    return model


In [5]:

# Test basic TPU operations
def test_tpu_operations():
    print("=== Testing Basic TPU Operations ===")

    device = xm.xla_device()
    print(f"Using device: {device}")

    # Basic tensor operations
    x = torch.randn(3, 3, device=device)
    y = torch.randn(3, 3, device=device)

    print(f"Created tensors on TPU")

    # Basic operations
    z = x + y
    print(f"Addition successful, result shape: {z.shape}")

    # Matrix multiplication
    a = torch.randn(4, 5, device=device)
    b = torch.randn(5, 3, device=device)
    c = torch.matmul(a, b)
    print(f"Matrix multiplication successful, result shape: {c.shape}")

    # Neural network layer test
    linear = nn.Linear(10, 5).to(device)
    input_tensor = torch.randn(8, 10, device=device)
    output = linear(input_tensor)
    print(f"Neural network layer successful, output shape: {output.shape}")

    print("✓ All TPU operations work correctly!")
    print("=" * 50)


In [6]:
test_tpu_operations()

=== Testing Basic TPU Operations ===
Using device: xla:0
Created tensors on TPU
Addition successful, result shape: torch.Size([3, 3])
Matrix multiplication successful, result shape: torch.Size([4, 3])
Neural network layer successful, output shape: torch.Size([8, 5])
✓ All TPU operations work correctly!


  device = xm.xla_device()


In [7]:
trained_model = train_model()

Creating synthetic dataset...
Initializing model...
Starting training...


  device = xm.xla_device()


Epoch 1/3, Batch 0, Loss: 2.3174, Batch Acc: 12.50%
Epoch 1/3, Batch 10, Loss: 2.2902, Batch Acc: 15.62%
Epoch 1/3, Batch 20, Loss: 2.2995, Batch Acc: 12.50%
Epoch 1/3, Batch 30, Loss: 2.3156, Batch Acc: 12.50%
Epoch 1 completed - Avg Loss: 2.3108, Accuracy: 10.40%
Test - Loss: 2.3155, Accuracy: 7.50%
--------------------------------------------------
Epoch 2 completed - Avg Loss: 0.0000, Accuracy: 0.00%
--------------------------------------------------
Epoch 3 completed - Avg Loss: 0.0000, Accuracy: 0.00%
--------------------------------------------------
Training completed!


In [8]:
device = xm.xla_device()
test_input = torch.randn(1, 784, device=device)
with torch.no_grad():
    prediction = trained_model(test_input)
    predicted_class = torch.argmax(prediction, dim=1)
    print(f"Sample prediction: class {predicted_class.item()}")

Sample prediction: class 6


  device = xm.xla_device()


In [None]:

# Main execution
if __name__ == "__main__":
    try:
        # Test basic operations first
        test_tpu_operations()

        # Train the model
        print("\n=== Starting Model Training ===")
        trained_model = train_model()

        print("\n✓ PyTorch TPU example completed successfully!")

        # Final test with the trained model
        device = xm.xla_device()
        test_input = torch.randn(1, 784, device=device)
        with torch.no_grad():
            prediction = trained_model(test_input)
            predicted_class = torch.argmax(prediction, dim=1)
            print(f"Sample prediction: class {predicted_class.item()}")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()