In [None]:
# workspace/notebooks/gpu/simple_cnn_test.ipynb

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time

import os

# Banner identifying this test run in the notebook output.
print("=== Simple CNN Test (No BatchNorm Issues) ===")

# ROCm workaround settings, applied to this process's environment in one go.
# NOTE(review): presumably these must be in place before the first GPU call
# for the runtime to pick them up - confirm against ROCm/PyTorch docs.
os.environ.update({
    'MIOPEN_DISABLE_CACHE': '1',
    'PYTORCH_HIP_ALLOC_CONF': 'max_split_size_mb:512',
})

class SimpleCNN(nn.Module):
    """A simple CNN without BatchNorm for ROCm compatibility.

    The feature extractor is three Conv2d -> ReLU -> MaxPool2d stages
    (3 -> 64 -> 128 -> 256 channels). The 3x3 convolutions use padding=1 and
    therefore keep the spatial size; each 2x2/stride-2 pooling halves it.
    The classifier's first linear layer expects 256 * 4 * 4 features, so the
    input must reduce to a 4x4 feature map, e.g. 3-channel 32x32 images.

    Args:
        num_classes: Number of output logits produced per sample.
    """
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor and classifier on a batch of images.

        Args:
            x: Batch of images with 3 channels whose spatial size reduces to
                4x4 after the three pooling stages (e.g. 32x32).

        Returns:
            Logits with one row per sample and ``num_classes`` columns.
        """
        x = self.features(x)
        # flatten() falls back to a copy when the feature map is not
        # contiguous (e.g. channels_last inputs); view() would raise there.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Benchmark one training step of SimpleCNN at increasing batch sizes on the
# first GPU and report throughput and memory use for each size that runs.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    total_vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Using device: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {total_vram_gb:.2f} GB")

    # Load CIFAR-10
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # The dataset does not depend on the batch size, so it is built once
    # instead of being re-read for every batch size. A failure here (e.g. the
    # download) is not batch-size specific and is allowed to propagate.
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    )

    num_iterations = 10
    # Largest batch size that completed without error; None until one does.
    max_stable_batch_size = None

    # Test different batch sizes
    for batch_size in [128, 256, 512, 1024, 2048, 4096]:
        print(f"\n{'='*50}")
        print(f"Testing batch size: {batch_size}")
        print(f"{'='*50}")

        # Pre-bind every GPU-holding name so the finally block can release
        # them no matter where the attempt fails.
        model = optimizer = inputs = labels = outputs = loss = None

        try:
            # Without this reset the reported peak would be the running
            # maximum across all previously tested batch sizes.
            torch.cuda.reset_peak_memory_stats()

            trainloader = DataLoader(
                trainset, batch_size=batch_size, shuffle=True, num_workers=0  # 0 workers for stability
            )

            # Create model
            model = SimpleCNN(num_classes=10).to(device)

            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

            # Get one batch
            inputs, labels = next(iter(trainloader))
            inputs, labels = inputs.to(device), labels.to(device)

            # Warmup with one full training step so that one-time costs of
            # the backward pass and the lazily created Adam state are not
            # billed to the timed loop.
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            # Benchmark
            torch.cuda.synchronize()
            start = time.perf_counter()  # monotonic clock for interval timing

            for _ in range(num_iterations):
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # Wait for queued GPU work so the elapsed time covers execution.
            torch.cuda.synchronize()
            total_time = time.perf_counter() - start

            samples_per_sec = (batch_size * num_iterations) / total_time
            memory_used = torch.cuda.memory_allocated() / 1e9

            print("✓ Success!")
            print(f"  Time for {num_iterations} batches: {total_time:.3f}s")
            print(f"  Throughput: {samples_per_sec:.0f} samples/sec")
            print(f"  Memory: {memory_used:.2f} GB")
            print(f"  Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

            if max_stable_batch_size is None or batch_size > max_stable_batch_size:
                max_stable_batch_size = batch_size

        except RuntimeError as e:
            print(f"✗ Failed at batch size {batch_size}: {str(e)[:100]}...")
        except Exception as e:
            print(f"✗ Error at batch size {batch_size}: {type(e).__name__}: {str(e)[:100]}")
        finally:
            # Cleanup: drop every reference to GPU tensors (including the
            # optimizer, which holds the parameters and its own state) before
            # asking the caching allocator to release unused blocks.
            model = optimizer = inputs = labels = outputs = loss = None
            torch.cuda.empty_cache()

    # Report what was actually measured rather than assuming success.
    print(f"\n{'='*60}")
    if max_stable_batch_size is None:
        print("No batch size completed successfully.")
    else:
        print(f"Maximum stable batch size: {max_stable_batch_size}")
        print(f"Total VRAM: {total_vram_gb:.2f} GB")
    print(f"{'='*60}")

else:
    print("No GPU available!")

=== Simple CNN Test (No BatchNorm Issues) ===
Using device: AMD Radeon Graphics
VRAM: 68.72 GB

Testing batch size: 128
