In [4]:
!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch-xla==2.0
  Using cached https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl (115.7 MB)


In [5]:
!pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.13-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.14.1 timm-0.6.13


In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
import os
import timm
import time
import torch
import torch_xla.core.xla_model as xm

In [8]:
model_names = [
    'efficientnet_b0',
    'efficientnet_b1',
    'efficientnet_b2',
    'efficientnet_b3',
    'efficientnet_b4',
    'mobilenetv2_100',
    'nasnetalarge',
    'resnet18',
    'resnet26',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'vgg11',
    'vgg13',
    'vgg16',
    'vgg19',
    'xception'
]

batch_sizes = [16, 32, 64]
num_inference_steps = [32]
data_precisions = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64
}

In [25]:
cur_path = os.getcwd()
dir_path = os.path.join(cur_path, "drive/MyDrive/hml_project/tpu_results_pytorch")
if not os.path.exists(dir_path):
  os.mkdir(dir_path)

In [26]:
device = xm.xla_device()
print("Using device:", device)

Using device: xla:1


In [27]:
num_inference_steps = [10]

In [35]:
for model_name in model_names:
  for precision_name, precision in zip(data_precisions.keys(), data_precisions.values()):
    for batch_size in batch_sizes:
      for num_inference_step in num_inference_steps:

        output_dir = f"{model_name}_{precision_name}_{batch_size}_{num_inference_step}"
        output_path = os.path.join(dir_path, output_dir)
        if not os.path.exists(output_path):
          os.mkdir(output_path)
        # else:
        #   continue
        print(output_dir)
        
        model = timm.create_model(model_name, pretrained=True)
        model.to(device)
        
        if precision == torch.float16:
          model = model.half()
        elif precision == torch.bfloat16:
          model = model.to(dtype = torch.bfloat16)
        elif precision == torch.float32:
          model = model
        elif precision == torch.float64:
          model = model.double()

        total_inference_time = 0
        total_dev_to_host_time = 0
        total_host_to_dev_time = 0
        total_time = 0

        for step in range(num_inference_step):
            # generate random input tensor
            # inputs_shape = MODEL_INPUT_SHAPES[model_name]
            
            # inputs = torch.randn(batch_size, inputs_shape[2], inputs_shape[1], inputs_shape[0]).to(self.device)
            inputs = torch.randn(batch_size, 3, 224, 224, dtype = precision)
            
            # move the input tensor to the device
            try:
                start_time_devt = time.time()
                inputs = torch.Tensor.to(inputs, device)
                end_time_devt = time.time()

            except (AttributeError, RuntimeError, TypeError):
                print(f"There was a problem with the input tensor or device while passing input tensor to the device: {device}.")

            # run the model on the input tensor to get the output
            try:
                # record the start time of the inference
                start_time_inf = time.time()

                output = model(inputs)
                
                # record the end time of the inference
                end_time_inf = time.time()

                output = torch.Tensor.to(output, "cpu")

                end_time = time.time()

            except (ValueError, IndexError):
                print("There was a problem with the input data while passing it to the model to compute the output.")

            # calculate the inference time in seconds
            dev_to_host_time = end_time_devt - start_time_devt
            inference_time = end_time_inf - start_time_inf
            host_to_dev_time = end_time - end_time_inf
            
            total_inference_time += inference_time
            total_dev_to_host_time += dev_to_host_time
            total_host_to_dev_time += host_to_dev_time
            total_time += inference_time + dev_to_host_time + host_to_dev_time

        # calculate average inference time over all steps
        avg_inference_time = total_inference_time / num_inference_step
        avg_dev_to_host_time = total_dev_to_host_time / num_inference_step
        avg_host_to_dev_time = total_host_to_dev_time / num_inference_step
        avg_time = total_time / num_inference_step

        # Print the results
        fpath = os.path.join(output_path, 'mylog.txt')
        f = open(fpath, 'w')
        f.write(f"Model: {model_name}\n")
        f.write(f"Batch size: {batch_size}\n")
        f.write(f"Precision type: {precision_name}\n")
        f.write(f"Inference time: {avg_inference_time}\n")
        f.write(f"Host-to-device communication time: {avg_host_to_dev_time}\n")
        f.write(f"Device-to-host communication time: {avg_dev_to_host_time}\n")
        f.write(f"Total time: {avg_time}\n")
        f.close()
        break
      break
    break
  break

efficientnet_b0_float16_16_10
