In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import torch
from torch import nn
from torchvision.datasets import CIFAR10 
from torchvision import transforms
from models.resnext import ResNeXt29_4x64d

In [None]:
# Report which GPU PyTorch can see; the cell prints nothing when CUDA is unavailable.
if torch.cuda.is_available():
    # Index 0 selects the first CUDA device.
    print(torch.cuda.get_device_name(0))

NVIDIA A100-SXM4-40GB


In [3]:
# NOTE(review): `resnest50` is not imported in the visible import cell (only
# `ResNeXt29_4x64d` is), so this line raises NameError on a fresh kernel unless
# the name is brought into scope somewhere else — confirm and add the import.
model = resnest50()
# Swap the final fully connected layer for a 10-output head that keeps the
# original layer's input width.
# (10 presumably matches the CIFAR10 classes imported above — TODO confirm.)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10)

In [5]:
# Restore the saved weights from disk into `model`.
# map_location="cpu" deserializes every tensor onto the CPU, so the checkpoint
# also loads on a machine without CUDA; load_state_dict then copies the values
# into the model's existing parameters, so the resulting weights are the same.
# NOTE(review): torch.load unpickles the file — only load checkpoints that come
# from a trusted source.
model.load_state_dict(
    torch.load("model_history/resnest_baseline.pt", map_location="cpu")
)

<All keys matched successfully>

In [6]:
# Put the model into evaluation mode. The call is the cell's last expression,
# so the module's repr is displayed as the cell output.
model.eval()

ResNet(
  (conv1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): SplAtConv2d(
        (conv): Conv2d

In [17]:
# Open the .npz archive; the following cells read its "loss", "acc" and
# "test_acc" entries.
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code — keep it only if the archive is trusted and actually
# contains object arrays.
# NOTE(review): the variable is named `baseline_data` but the file is
# "resnext_2_worker_fp16.npz" — confirm this is the intended run.
baseline_data = np.load("model_history/resnext_2_worker_fp16.npz", allow_pickle=True)

In [18]:
# Display the array stored under the "loss" key (presumably one value per epoch — TODO confirm).
baseline_data["loss"]

array([2.14070058, 1.53067776, 1.29189843, 1.16748973, 1.08763071,
       1.03469546, 1.00616927, 0.98995403, 0.97422721, 0.95867732,
       0.93864907, 0.93228199, 0.92666757, 0.92226353, 0.91506842,
       0.90121291, 0.90425939, 0.90047718, 0.89804242, 0.88843695,
       0.88560818, 0.8846927 , 0.88052943, 0.884715  , 0.88223018,
       0.8779879 , 0.87230854, 0.86613067, 0.86905197, 0.87051475,
       0.86893748, 0.85721281, 0.85546744, 0.86011983, 0.84744767,
       0.85180915, 0.85266313, 0.84710795, 0.8449118 , 0.84302348,
       0.8425541 , 0.84347692, 0.83443035, 0.83720869, 0.83906438,
       0.82762782, 0.83679529, 0.82946548, 0.82972193, 0.82535576,
       0.83572099, 0.83373728, 0.83606296, 0.83041573, 0.83182715,
       0.82722206, 0.82820432, 0.82855812, 0.83061915, 0.82978878,
       0.83138553, 0.82741861, 0.82593848, 0.82792137, 0.82581878,
       0.82429438, 0.82552342, 0.82441319, 0.81894006, 0.82261146,
       0.82287922, 0.81848722, 0.81934425, 0.82532804, 0.81161

In [19]:
# Display the array stored under the "acc" key (presumably training accuracy in percent — TODO confirm).
baseline_data["acc"]

array([25.886, 42.304, 52.616, 57.898, 60.882, 62.792, 63.814, 64.498,
       65.232, 65.578, 66.404, 66.504, 66.778, 66.836, 67.422, 67.62 ,
       67.524, 68.048, 67.858, 68.33 , 68.542, 68.324, 68.658, 68.294,
       68.424, 68.676, 69.052, 69.19 , 69.104, 68.88 , 69.124, 69.51 ,
       69.758, 69.494, 69.954, 69.874, 69.88 , 69.902, 69.82 , 70.034,
       70.194, 70.166, 70.384, 70.312, 70.232, 70.472, 70.308, 70.79 ,
       70.526, 70.748, 70.456, 70.308, 70.362, 70.906, 70.656, 70.934,
       70.712, 70.478, 70.606, 70.554, 70.682, 70.912, 70.636, 70.72 ,
       70.764, 70.774, 70.834, 70.77 , 70.982, 70.768, 70.792, 71.222,
       70.984, 70.75 , 71.336, 71.136, 71.202, 71.43 , 71.434, 71.486,
       71.634, 71.498, 71.616, 71.566, 71.656, 71.746, 71.588, 71.638,
       71.83 , 71.768, 71.798, 71.688, 71.76 , 71.95 , 71.68 , 71.826,
       71.506, 71.634, 71.644, 71.606, 72.01 , 71.746, 72.052, 71.734,
       71.762, 71.53 , 71.716, 71.668, 71.852, 71.74 , 71.886, 71.66 ,
      

In [20]:
# Display the array stored under the "test_acc" key (presumably test accuracy in percent — TODO confirm).
baseline_data["test_acc"]

array([21.78, 24.45, 44.65, 49.19, 51.94, 49.35, 57.9 , 56.75, 47.65,
       61.43, 63.83, 49.85, 64.75, 65.14, 58.77, 59.34, 51.65, 63.88,
       55.7 , 67.95, 59.28, 58.31, 62.46, 57.16, 63.98, 62.13, 60.01,
       51.97, 57.85, 34.67, 64.81, 59.08, 59.51, 61.35, 10.  , 53.55,
       53.07, 52.85, 56.56, 49.56, 37.44, 47.93, 53.02, 37.65, 48.59,
       42.99, 53.96, 59.87, 51.03, 53.25, 56.73, 58.09, 51.19, 54.47,
       30.45, 57.53, 65.36, 57.53, 63.32, 64.03, 53.01, 60.94, 53.32,
       61.69, 64.94, 64.78, 58.81, 44.29, 59.65, 59.08, 51.59, 56.78,
       55.73, 53.31, 55.55, 49.67, 48.84, 51.44, 57.06, 64.5 , 57.87,
       47.41, 54.83, 44.47, 44.07, 50.09, 55.28, 59.11, 34.25, 24.47,
       47.69, 59.23, 60.16, 51.64, 42.68, 64.98, 48.9 , 62.63, 61.03,
       40.57, 45.59, 18.78, 57.28, 59.66, 56.44, 61.1 , 57.99, 65.49,
       53.74, 69.34, 50.45, 49.14, 62.82, 66.31, 54.23, 65.32, 61.62,
       24.2 , 59.68, 45.71, 66.96, 60.43, 54.77, 56.24, 64.93, 63.67,
       29.92, 50.29,

In [3]:
# Open the checkpoint run's .npz archive and display its "loss" array.
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code — keep it only if the archive is trusted and actually
# contains object arrays.
half = np.load("model_history/resnext_2_worker_checkpoint.npz", allow_pickle=True)
half["loss"]

array([2.16599903, 1.50904563, 1.25347727, 1.14159202, 1.08257909])

In [4]:
# Display the array stored under the "acc" key of the checkpoint archive.
half["acc"]

array([26.798, 43.736, 54.56 , 58.808, 61.132])

In [6]:
# Display the array stored under the "test_acc" key of the checkpoint archive.
half["test_acc"]

array([32.03, 40.99, 48.22, 51.04, 32.29])