In [1]:
import torch

torch.manual_seed = 420
torch.backends.cudnn.deterministic = True

In [2]:
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, random_split

transform_train = transforms.Compose([
  transforms.RandomRotation(180),
  transforms.RandomHorizontalFlip(),
  transforms.RandomVerticalFlip(),
  transforms.ToTensor(), 
  transforms.Normalize((0.5,), (0.5,)),
  ])
transform_test = transforms.Compose([
  transforms.ToTensor(), 
  transforms.Normalize((0.5,), (0.5,))
  ])

# Load raw datasets without transforms first
dataset_raw = FashionMNIST(root="./data", train=True, download=True, transform=None)
dataset_test = FashionMNIST(root="./data", train=False, download=True, transform=transform_test)

train_size = int(0.8 * len(dataset_raw))
val_size = len(dataset_raw) - train_size
train_indices, val_indices = random_split(range(len(dataset_raw)), [train_size, val_size])

class TransformedSubset: 
  def __init__(self, dataset, indices, transform=None):
    self.dataset = dataset
    self.indices = indices
    self.transform = transform

  def __getitem__(self, idx): 
    image, label = self.dataset[self.indices[idx]]
    if self.transform:
      image = self.transform(image)
    return image, label
  
  def __len__(self):
    return len(self.indices)

dataset_train = TransformedSubset(dataset_raw, train_indices.indices, transform_train)
dataset_val = TransformedSubset(dataset_raw, val_indices.indices, transform_test)

# Update your DataLoader creation with optimizations
train_loader = DataLoader(dataset_train, batch_size=128, shuffle=True)
val_loader = DataLoader(dataset_val, batch_size=128, shuffle=False)
test_loader = DataLoader(dataset_test, batch_size=128, shuffle=False)

In [3]:
# Tensorboard
from torch.utils.tensorboard import SummaryWriter

# Initialize TensorBoard writer
writer = SummaryWriter(log_dir="./logs")

In [4]:
class CNN(torch.nn.Module):
  """
  A Convolutional Neural Network for image classification in PyTorch
  """
  def __init__(self):
    super(CNN, self).__init__()
    self.model = torch.nn.Sequential(
      torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 28x28x1 -> 28x28x32
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(kernel_size=2, stride=2), # 28x28x32 -> 14x14x32
      torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # 14x14x32 -> 14x14x64
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(kernel_size=2, stride=2), # 14x14x64 -> 7x7x64
      torch.nn.Flatten(), # 7x7x64 -> 3136
      torch.nn.Linear(3136, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    return self.model(x)
  
cnn_model = CNN()

In [5]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move model to device
cnn_model = cnn_model.to(device)

Using device: cuda


In [7]:
# Optimizer and Loss Function
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn_model.model.parameters(), lr=0.001)

In [8]:
from torch.profiler import profile, record_function, ProfilerActivity, schedule

# Validation function
def validate_model(model, data_loader, loss_fn, device):
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    correct = 0
    total = 0
    
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                schedule=schedule(wait=0, warmup=0, active=999999),
                on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),
                ) as prof:
        with torch.no_grad():  # Disable gradient computation for validation
            for images, labels in data_loader:
                # Move data to device
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = loss_fn(outputs, labels)
                prof.step()  # Step the profiler
                total_loss += loss.item()
                
                # Calculate accuracy
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        avg_loss = total_loss / len(data_loader)
        accuracy = 100 * correct / total

    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    return avg_loss, accuracy

In [9]:
import os

def save_best_model(model, epoch, val_loss, best_val_loss, out_dir):
  # Check if out_dir exists and create it of it does not
  os.makedirs(out_dir, exist_ok=True)

  torch.save(model.state_dict(), f"{out_dir}/model_epoch_{epoch+1}_{val_loss:.4f}.pth")
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model.state_dict(), f"{out_dir}/model_best.pth")
  return best_val_loss

In [10]:
from torch.utils import benchmark

# Training loop 

epochs = 5
best_val_loss = float("inf")

for epoch in range(epochs):
    cnn_model.train()  # Set model to training mode
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader):
        # Move data to device
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = cnn_model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    val_loss, val_accuracy = validate_model(cnn_model, val_loader, loss_fn, device)
    print(f"Epoch {epoch+1}/{epochs} completed, Average Loss: {avg_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    best_val_loss = save_best_model(cnn_model, epoch, val_loss, best_val_loss, out_dir="./trained/CNN_basic")

    writer.add_scalar("Loss/Train", avg_loss, epoch+1)
    writer.add_scalar("Loss/Val", val_loss, epoch+1)
    writer.add_scalar("Accuracy/Val", val_accuracy, epoch+1)

writer.flush()

  warn("Profiler won't be using warmup, this can skew profiler results")


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         0.00%       0.000us         0.00%       0.000us       0.000us        4.665s      3043.32%        4.665s      49.102ms           0 b           0 b           0 b           0 

In [None]:
# Final test on test set
test_loss, test_accuracy = validate_model(cnn_model, test_loader, loss_fn, device)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")