<a href="https://colab.research.google.com/github/MahdiNouraie/CNN-FashionMNIST/blob/main/Alexnet_FashionMNIST_8Cores.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# `os` is imported in this first cell; the TPU runtime check further down reads os.environ.
import os
# Installs pinned versions of cloud-tpu-client (0.10) and torch (1.11.0) together with the
# torch_xla 1.11 wheel. The wheel filename targets CPython 3.7 on linux x86_64, so this line
# presumably only works on a runtime with that interpreter version - verify before re-running.
!pip install cloud-tpu-client==0.10 torch==1.11.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch_xla.distributed.parallel_loader as pl
import time

In [None]:
# Fails fast with a readable hint when the notebook is not attached to a TPU runtime.
# .get() is used so that a missing variable triggers the assertion message below
# instead of a bare KeyError that hides the hint.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
def map_fn(index, flags):
  """Trains and evaluates AlexNet on FashionMNIST inside one spawned process.

  Meant to be used as the per-process entry point of `xmp.spawn`. Each process
  builds its own data loaders and model, trains for `flags['num_epochs']`
  epochs, evaluates on its shard of the test set and prints its timings and
  accuracy.

  Args:
    index: Process index handed in by the spawner; only used to label the
      printed messages.
    flags: Dict with the keys 'seed', 'batch_size', 'num_workers' and
      'num_epochs'.

  Returns:
    None. All results are reported through print().
  """
  # Sets a common random seed - both for initialization and ensuring graph is the same
  torch.manual_seed(flags['seed'])
  # Acquires the XLA device (Cloud TPU core) assigned to this process
  device = xm.xla_device()

  ## Dataloader construction
  # FashionMNIST images are grayscale; they are resized to 224 x 224 and
  # converted to RGB so they match the 3-channel input of torchvision's AlexNet.
  # ToTensor loads the pixels into [0, 1]; Normalize then computes
  # (x - 0.5) / 0.5, with the one-element mean/std applied to every channel.
  normalize = transforms.Normalize((0.5,), (0.5,))
  to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))
  resize = transforms.Resize((224, 224))
  my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])

  # NOTE(review): the space after '~/' looks like a typo, but the literal is kept
  # unchanged so that data downloaded by earlier runs is still found.
  data_root = '~/ .pytorch/F_MNIST_data'

  # Downloads train and test datasets
  # Note: master goes first and downloads the dataset only once (xm.rendezvous)
  #   all the other workers wait for the master to be done downloading.
  if not xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  train_dataset = datasets.FashionMNIST(
    data_root,
    train=True,
    download=True,
    transform=my_transform)

  test_dataset = datasets.FashionMNIST(
    data_root,
    train=False,
    download=True,
    transform=my_transform)

  if xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  # Creates the distributed samplers, which let this process only access
  # its own portion of the training and test datasets.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  test_sampler = torch.utils.data.distributed.DistributedSampler(
    test_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False)

  # Creates dataloaders, which load data in batches
  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=flags['batch_size'],
    sampler=train_sampler,
    num_workers=flags['num_workers'],
    drop_last=True)

  # Note: the test loader is sharded by test_sampler but not shuffled.
  #  drop_last=False so that every sample of this process's shard is evaluated;
  #  the final batch may therefore be smaller than flags['batch_size'].
  test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=flags['batch_size'],
    sampler=test_sampler,
    num_workers=flags['num_workers'],
    drop_last=False)

  ## Network, optimizer, and loss function creation

  # Creates AlexNet for 10 classes
  # Note: each process has its own identical copy of the model
  #  Even though each model is created independently, they're also
  #  created in the same way (same seed).
  net = torchvision.models.alexnet(num_classes=10).to(device).train()

  loss_fn = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters())

  ## Trains
  train_start = time.time()
  for epoch in range(flags['num_epochs']):
    # Without set_epoch the DistributedSampler reuses the same shuffled
    # ordering in every epoch.
    train_sampler.set_epoch(epoch)
    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    for data, targets in para_train_loader:
      # Acquires the network's best guesses at each class
      output = net(data)

      # Computes loss
      loss = loss_fn(output, targets)

      # Updates model
      optimizer.zero_grad()
      loss.backward()

      # Note: optimizer_step uses the implicit Cloud TPU context to
      #  coordinate and synchronize gradient updates across processes.
      #  This means that each process's network has the same weights after
      #  this is called.
      # Warning: this coordination requires the actions performed in each
      #  process are the same. In more technical terms, the graph that
      #  PyTorch/XLA generates must be the same across processes.
      xm.optimizer_step(optimizer)  # Note: barrier=True not needed when using ParallelLoader

  elapsed_train_time = time.time() - train_start
  print("Process", index, "finished training. Train time was:", elapsed_train_time)

  ## Evaluation
  # Sets net to eval and no grad context
  net.eval()
  eval_start = time.time()
  num_correct = 0
  total_guesses = 0
  with torch.no_grad():
    para_test_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
    for data, targets in para_test_loader:
      # Acquires the network's best guesses at each class
      output = net(data)
      best_guesses = torch.argmax(output, 1)

      # Updates running statistics; the real batch size is counted because
      # the last batch can be shorter than flags['batch_size'].
      num_correct += torch.eq(targets, best_guesses).sum().item()
      total_guesses += targets.size(0)

  elapsed_eval_time = time.time() - eval_start
  print("Process", index, "finished evaluation. Evaluation time was:", elapsed_eval_time)
  # Guards against a division by zero when this process received no test batch
  accuracy = num_correct / total_guesses * 100 if total_guesses else float('nan')
  print("Process", index, "guessed", num_correct, "of", total_guesses, "correctly for", accuracy, "% accuracy.")

In [None]:
# Training/evaluation settings; the same dict is handed to every spawned process
flags = {
    'batch_size': 32,
    'num_workers': 8,
    'num_epochs': 5,
    'seed': 1234,
}

# Starts 8 forked processes, each of which runs map_fn(index, flags)
xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))


Process 6 finished training. Train time was: 1708.569236755371
Process 4 finished training. Train time was: 1710.6655972003937
Process 7 finished training. Train time was: 1711.5464255809784
Process 5 finished training. Train time was: 1718.607072353363
Process 2 finished training. Train time was: 1713.4141731262207
Process 1 finished training. Train time was: 1708.0153498649597
Process 3 finished training. Train time was: 1708.4740352630615
Process 0 finished training. Train time was: 1721.9642391204834
Process 6 finished evaluation. Evaluation time was: 60.750757694244385
Process 6 guessed 1137 of 1248 correctly for 91.10576923076923 % accuracy.
Process 5 finished evaluation. Evaluation time was: 62.32980298995972
Process 5 guessed 1115 of 1248 correctly for 89.34294871794873 % accuracy.
Process 1 finished evaluation. Evaluation time was: 62.50452494621277
Process 1 guessed 1109 of 1248 correctly for 88.86217948717949 % accuracy.
Process 2 finished evaluation. Evaluation time was: 64