In [1]:
import torch 
import numpy as np
import transformers, datasets, accelerate, tensorboard, evaluate
from models import Student
import torch.nn as nn
import torch.optim as optim
from util import *
from loss import LossCalulcator
from pretrained_kd import *
from datasets import Array3D, ClassLabel, Features, load_dataset, Image
from matplotlib import pyplot
from tqdm import tqdm
from transformers import AdamW, ViTFeatureExtractor, ViTModel


In [22]:
# define hyperparameters
temperature = 10
distillw = 0.1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 100
lr = 0.1
lr_stepsize = 20
batch_size = 1000
test_batch_size = 100

In [3]:
# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")
teacher = AutoModelForImageClassification.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")



In [4]:
student = Student(3, 32, 10, 0.2)
student.load_state_dict(torch.load('model/cifar10_github/epoch_99.bin'))
# model.eval()

<All keys matched successfully>

In [6]:
dataset = load_dataset("cifar10")
split = dataset['train'].train_test_split(test_size=(5000.0/50000))
dataset['splitted_train'] = split['train']
dataset['validation'] = split['test']
transform = torchvision.transforms.ToTensor()

def preprocess_images(examples):
    # get batch of images
    images = examples['img']
    examples['img'] = [np.array(image) for image in examples['img']]
    # convert to list of NumPy arrays of shape (C, H, W)
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    # preprocess and add pixel_values
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']
    return examples

features = Features({
    'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
    # 'img': Image(decode=True, id=None),
    # could probably change img to int for faster inference
    'img': Array3D(dtype="float32", shape=(3,32,32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)), })

# preprocessed_train = dataset['train'].map(preprocess_images, batched=True, features=features)
preprocessed_val = dataset['validation'].map(preprocess_images, batched=True, features=features)
preprocessed_test = dataset['test'].map(preprocess_images, batched=True, features=features)
preprocessed_splitted_train = dataset['splitted_train'].map(preprocess_images, batched=True, features=features)

# set format to PyTorch
# preprocessed_train.set_format('torch', columns=['img', 'pixel_values', 'label'])
preprocessed_val.set_format('torch', columns=['img', 'pixel_values', 'label'])
preprocessed_test.set_format('torch', columns=['img', 'pixel_values', 'label'])
preprocessed_splitted_train.set_format('torch', columns=['img', 'pixel_values', 'label'])

Downloading readme:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/120M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/45000 [00:00<?, ? examples/s]

In [23]:
preprocessed_dataloaders = {}
preprocessed_dataloaders['splitted_train'] = torch.utils.data.DataLoader(preprocessed_splitted_train, batch_size=batch_size, shuffle=True)
preprocessed_dataloaders['validation'] = torch.utils.data.DataLoader(preprocessed_val, batch_size=batch_size)
preprocessed_dataloaders['test'] = torch.utils.data.DataLoader(preprocessed_test, batch_size=test_batch_size)

In [8]:
# preprocessed_dataloaders = construct_dataloaders((preprocessed_train, preprocessed_test, preprocessed_splitted_train, preprocessed_val), batch_size, shuffle_train=True)

In [35]:
import time
teacher_input = processor(images=dataset['train']['img'][0], return_tensors="pt")
start = time.time()
with torch.no_grad():
    teacher_output = teacher(**teacher_input)
print(time.time()-start)
print(teacher_output.logits)

0.12718605995178223
tensor([[ 3.3682, -0.3160, -0.2798, -0.5006, -0.5529, -0.5625, -0.6144, -0.4671,
          0.2807, -0.3066]])


In [36]:
import time
student_input = next(iter(preprocessed_dataloaders['test']))  
start = time.time()
student_output = student(student_input['img'])
print("student: ", time.time()-start)
print(student_output)

{'label': tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3, 6, 2, 1, 2, 3, 7, 2, 6,
        8, 8, 0, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 5, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
        6, 0, 0, 7]), 'img': tensor([[[[158., 112.,  49.,  ...,  41., 161., 116.],
          [ 41., 160., 111.,  ..., 109.,  44., 149.],
          [107.,  45., 150.,  ..., 116.,  85.,  33.],
          ...,
          [133., 122., 133.,  ..., 132., 103.,  57.],
          [183., 183., 175.,  ..., 226., 220., 191.],
          [188., 164., 135.,  ..., 165., 154., 157.]],

         [[147., 137., 155.,  ..., 104.,  81.,  46.],
          [188., 191., 189.,  ..., 188., 199., 171.],
          [164., 170., 142.,  ..., 201., 192., 160.],
          ...,
          [ 87.,  80.,  77.,  ...,  61.,  57.,  46.],
          [ 52.,  54.,  38.,  ..., 120., 105.,  63.],
          [117.

In [24]:
loss_calculator = LossCalulcator(temperature, distillw)
optimizer = optim.AdamW(
    student.parameters(),
    lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer = optimizer,
                                    step_size = lr_stepsize)

In [25]:
train(student = student,
    teacher = teacher,
    dataloader = preprocessed_dataloaders['splitted_train'],
    val_dataloader = preprocessed_dataloaders['validation'],
    optimizer = optimizer,
    scheduler = scheduler,
    loss_calculator = loss_calculator,
    epochs = epochs,
    device = device,
    )

  0%|          | 0/100 [07:14<?, ?it/s]


KeyboardInterrupt: 

In [None]:
measure_accuracy(student, preprocessed_dataloaders['test'], device)