In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import time

# Define a simple linear model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(1000, 500)
        self.fc2 = nn.Linear(500, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Function to train the model and measure time
def train_model(device, epochs=10):
    # Create the model and move to device
    model = SimpleNN().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Generate dummy data
    x = torch.randn(10000, 1000).to(device)  # 10,000 samples with 1,000 features
    y = torch.randn(10000, 1).to(device)

    start_time = time.time()
    
    # Training loop
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

    elapsed_time = time.time() - start_time
    print(f"Training completed on {device} in {elapsed_time:.2f} seconds")
    return elapsed_time

# Compare CPU and GPU performance
if __name__ == "__main__":
    # Check if GPU is available
    gpu_available = torch.cuda.is_available()
    
    print("Training on CPU...")
    cpu_time = train_model(device=torch.device("cpu"))
    
    if gpu_available:
        print("\nTraining on GPU...")
        gpu_time = train_model(device=torch.device("cuda"))
        print(f"\nSpeedup (GPU vs CPU): {cpu_time / gpu_time:.2f}x")
    else:
        print("\nGPU is not available for comparison.")


Training on CPU...
Training completed on cpu in 0.94 seconds

Training on GPU...


RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
