In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

In [None]:
# Timing results per configuration, keyed by a short tag derived from the file name:
#   leading 'c' / 'i' -> CIFAR-100 / ImageNet run, 'a' -> AMP, 'x' -> XLA,
#   trailing digits -> the 'r' value in the file name (presumably input resolution).
# Each .npy file wraps a pickled object (unwrapped with .item()), so allow_pickle=True
# is required -- only load files produced by your own timing runs, since unpickling
# can execute arbitrary code.
MEASUREMENT_FILES = {
    'c24': '../npy/time_cifar100_resnet18_r24_10_601_10.npy',
    'c32': '../npy/time_cifar100_resnet18_r32_10_571_10.npy',
    'cx24': '../npy/time_cifar100_resnet18_r24_20_2561_20_xla.npy',
    'cx32': '../npy/time_cifar100_resnet18_r32_20_1461_20_xla.npy',
    'ia160': '../npy/time_imagenet_resnet18_r160_10_341_10_amp.npy',
    'ia224': '../npy/time_imagenet_resnet18_r224_10_161_10_amp.npy',
    'ia288': '../npy/time_imagenet_resnet18_r288_10_141_10_amp.npy',
    'iax160': '../npy/time_imagenet_resnet18_r160_20_1281_20_amp_xla.npy',
    'iax224': '../npy/time_imagenet_resnet18_r224_20_621_20_amp_xla.npy',
    'iax288': '../npy/time_imagenet_resnet18_r288_20_301_20_amp_xla.npy',
}

measurement = {
    tag: np.load(path, allow_pickle=True).item()
    for tag, path in MEASUREMENT_FILES.items()
}

In [None]:
# List the fields recorded in the 'c24' timing file (presumably the same for every file).
# A single join avoids the dangling ", " and missing newline of a per-key print loop.
print(', '.join(map(str, measurement['c24'])))

In [None]:
# Sanity-check every timing file: its recorded batch sizes must be exactly the
# arithmetic progression described by the stored (start, stop, step) tuple "sss".
for key, value in measurement.items():
    batch_sizes = list(value['batch_size'])
    # The step is inferred from the first two entries, so at least two are needed.
    if len(batch_sizes) < 2:
        raise ValueError(
            f'for the file "{key}", at least two batch sizes are required to infer a step, '
            f'got {batch_sizes}'
        )
    validation_tuple = (
        batch_sizes[0],
        batch_sizes[-1] + 1,
        batch_sizes[1] - batch_sizes[0],
    )
    if validation_tuple != value['sss']:
        raise ValueError(
            f'for the file "{key}", '
            f'"validation_tuple" {validation_tuple} and "sss" {value["sss"]} are mismatched'
        )
    # Matching endpoints and step cannot reveal gaps or duplicates in the middle
    # of the sweep (e.g. from an interrupted run), so compare every entry.
    if batch_sizes != list(range(*validation_tuple)):
        raise ValueError(
            f'for the file "{key}", "batch_size" is not the contiguous range {validation_tuple}'
        )
    print(key, value['sss'])

In [None]:
# Fit one ordinary-least-squares line per configuration:
#   avg_train_time ~= intercept_ + coef_[0] * batch_size
reg_model = {}

for name, run in measurement.items():
    # sklearn expects a 2-D feature matrix: one row per sample, one column (batch size).
    features = np.array(run['batch_size']).reshape(-1, 1)
    model = LinearRegression().fit(features, run['avg_train_time'])
    reg_model[name] = model
    print(name, model.intercept_, model.coef_)

In [None]:
# Inspect the Python type of the fitted slope for the 'c24' configuration.
slope_c24 = reg_model['c24'].coef_[0]
type(slope_c24)

In [None]:
# Training-set size (images per epoch) per dataset, keyed by the first character
# of the measurement key: 'c...' = CIFAR-100 runs, 'i...' = ImageNet runs.
TRAIN_SET_SIZE = {
    'c': 50000,    # CIFAR-100 training split
    'i': 1281167,  # ImageNet training split
}

prediction = {}

for key, value in reg_model.items():
    # Predict for every batch size from 1 up to the largest measured one; sizes
    # below the first measured point are an extrapolation of the linear fit.
    sizes = np.arange(1, measurement[key]['batch_size'][-1] + 1)
    times = value.predict(sizes.reshape(-1, 1))
    # Dispatch on the key prefix rather than a substring test, so a key that
    # happens to contain both 'c' and 'i' cannot be assigned the wrong dataset.
    n_train = TRAIN_SET_SIZE.get(key[:1])
    prediction[key] = {
        'batch_size': sizes,
        'batch_time': times,
        # Every batch, including a final partial one, is charged the full
        # predicted batch time. Unknown dataset prefixes get None.
        'epoch_time': None if n_train is None else times * np.ceil(n_train / sizes),
    }

In [None]:
DPI = 72  # figure resolution for all plots below; other candidates: 150, 240, 300

In [None]:
# Measured vs. linearly-fitted per-batch training time for the CIFAR-100 runs.
# Each prediction is drawn dashed in the same color as its measurement so the
# pairs are easy to match.
fig, ax = plt.subplots(dpi=DPI)
for key, value in measurement.items():
    if not key.startswith('c'):  # 'c...' keys are the CIFAR-100 runs
        continue
    (line,) = ax.plot(value['batch_size'], value['avg_train_time'], label=key + ', measurement')
    ax.plot(
        prediction[key]['batch_size'],
        prediction[key]['batch_time'],
        '--',
        color=line.get_color(),
        label=key + ', prediction',
    )
ax.set(
    title='Training CIFAR-100 on ResNet-18',
    xlabel='Batch Size',
    ylabel='Training Time for a Batch (sec)',
)
ax.legend()
plt.show()

In [None]:
# Measured vs. linearly-fitted per-batch training time for the ImageNet runs.
# Each prediction is drawn dashed in the same color as its measurement so the
# pairs are easy to match.
fig, ax = plt.subplots(dpi=DPI)
for key, value in measurement.items():
    if not key.startswith('i'):  # 'i...' keys are the ImageNet runs
        continue
    (line,) = ax.plot(value['batch_size'], value['avg_train_time'], label=key + ', measurement')
    ax.plot(
        prediction[key]['batch_size'],
        prediction[key]['batch_time'],
        '--',
        color=line.get_color(),
        label=key + ', prediction',
    )
ax.set(
    title='Training ImageNet on ResNet-18',
    xlabel='Batch Size',
    ylabel='Training Time for a Batch (sec)',
)
ax.legend()
plt.show()

In [None]:
# Predicted time for one full CIFAR-100 epoch as a function of batch size.
fig, ax = plt.subplots(dpi=DPI)
for key, value in prediction.items():
    if key.startswith('c'):  # prefix check: 'c...' keys are the CIFAR-100 runs
        ax.plot(value['batch_size'], value['epoch_time'], label=key + ', prediction')
ax.set(
    title='Training CIFAR-100 on ResNet-18',
    xlabel='Batch Size',
    ylabel='Training Time for an Epoch (sec)',
)
ax.legend()
plt.show()

In [None]:
# Predicted time for one full ImageNet epoch as a function of batch size.
fig, ax = plt.subplots(dpi=DPI)
for key, value in prediction.items():
    if key.startswith('i'):  # prefix check: 'i...' keys are the ImageNet runs
        ax.plot(value['batch_size'], value['epoch_time'], label=key + ', prediction')
ax.set(
    title='Training ImageNet on ResNet-18',
    xlabel='Batch Size',
    ylabel='Training Time for an Epoch (sec)',
)
ax.legend()
plt.show()

In [None]:
# For the ImageNet r160 AMP+XLA configuration, report the batch size that minimizes
# the predicted epoch time, alongside the largest batch size in the prediction range.
iax160_pred = prediction['iax160']
min_time_index = np.argmin(iax160_pred['epoch_time'])

print('prediction training time (batch_size, epoch_time)')
for label, idx in (('min_time', min_time_index), ('last_time', -1)):
    print(label, iax160_pred['batch_size'][idx], iax160_pred['epoch_time'][idx])