In [None]:
# Clone the project repository and make its root the working directory.
# NOTE(review): later cells import from a `utils` package, presumably shipped in
# this repository — confirm. The `/content` prefix is an absolute path.
!git clone https://github.com/LucStrater/Knowledge_Distillation_AD.git
%cd /content/Knowledge_Distillation_AD

Cloning into 'Knowledge_Distillation_AD'...
remote: Enumerating objects: 119, done.[K
remote: Counting objects: 100% (89/89), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 119 (delta 42), reused 64 (delta 29), pack-reused 30[K
Receiving objects: 100% (119/119), 19.87 MiB | 5.60 MiB/s, done.
Resolving deltas: 100% (45/45), done.
/content/Knowledge_Distillation_AD


In [None]:
!pip install transformers

In [None]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
import torch

# Sanity check: classify a single sample image with the pretrained ViT checkpoint.
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png'
response = requests.get(url, stream=True)
image = Image.open(response.raw)

# Processor and classifier are loaded from the same checkpoint; hidden states
# are requested so that later cells can inspect them.
checkpoint = 'nateraw/vit-base-patch16-224-cifar10'
processor = ViTImageProcessor.from_pretrained(checkpoint)
model = ViTForImageClassification.from_pretrained(checkpoint, output_hidden_states=True)

# Turn the PIL image into the tensor inputs the model expects.
inputs = processor(images=image, return_tensors="pt")

# Forward pass without building an autograd graph.
with torch.no_grad():
    outputs = model(**inputs)

# Index of the highest logit for each image in the batch.
preds = outputs.logits.argmax(dim=1)

# Class names in label-index order.
classes = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
predicted_index = preds[0].item()
predicted_class = classes[predicted_index]

print(predicted_class)

In [None]:
# Shape of the entry at index 12 of the hidden-state tuple returned by the model.
hidden_state_12 = outputs.hidden_states[12]
print(hidden_state_12.shape)

In [None]:
# Bare last expression: the notebook displays the model's repr.
model

In [None]:
# Experiment configuration, built as a single literal so every tunable value is
# visible in one place. Key order matches the grouping below.
config = {
    # Data parameters
    "experiment_name": 'local_equal_net',
    "dataset_name": 'cifar10',
    "last_checkpoint": 200,

    # Training parameters
    "num_epochs": 1,  # put 201 if you want to train from scratch
    "batch_size": 64,
    "learning_rate": 1e-3,
    "mvtec_img_size": 128,
    "normal_class": 3,
    "lamda": 0.01,
    "pretrain": True,  # True =use pre-trained vgg as source network --- False =use random initialize
    "use_bias": False,  # True =using bias term in neural network layer
    "equal_network_size": False,  # True =using equal network size for cloner and source network --- False =smaller network for cloner
    "direction_loss_only": False,
    "continue_train": True,

    # Test parameters
    "localization_test": False,  # True =For Localization Test --- False =For Detection
    "localization_method": 'gbp',  # gradients , smooth_grad , gbp
}

In [None]:
from utils.dataloader import load_data
from torch.autograd import Variable
from utils.loss_functions import MseDirectionLoss
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
import torch

In [None]:
# Function for setting the seed
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator.
    """
    # Function-scope imports: no cell in the notebook imports `random` or
    # `numpy`, so without these the function raises NameError on a fresh kernel.
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

In [None]:
def test(model, vit_pretrained, processor_pretrained, test_dataloader, config):
    """Score every test sample and return the detection ROC-AUC.

    For each batch the student (`model`) and the pretrained teacher
    (`vit_pretrained`) are run on the same images. Their outputs at indices 6, 9
    and 12 are compared with a cosine-direction term plus `config['lamda']`
    times a mean-squared-error term. The per-sample sum is the anomaly score.

    Args:
        model: Student network; called as `model(X)` and indexed at 6, 9 and 12.
        vit_pretrained: Teacher network; its `.hidden_states` output is indexed
            at 6, 9 and 12.
        processor_pretrained: Image processor that prepares teacher inputs.
        test_dataloader: Iterable yielding `(X, Y)` batches of images and labels.
        config: Dict read for "normal_class" and "lamda".

    Returns:
        ROC-AUC rounded to 4 decimals. Samples whose label differs from
        `config["normal_class"]` are the positive class (`pos_label=0`).

    NOTE(review): `roc_curve` and `auc` are not imported anywhere in the visible
    notebook; presumably they come from `sklearn.metrics` — confirm and import
    them in the imports cell.
    """
    # Function-scope import: the notebook never imports numpy at cell level.
    import numpy as np

    target_class = config["normal_class"]
    similarity_loss = torch.nn.CosineSimilarity()
    # Teacher inputs must live on the same device as the teacher weights.
    teacher_device = next(vit_pretrained.parameters()).device

    def _layer_terms(y_pred, y):
        """Return per-sample (direction, squared-error) terms for one layer."""
        y = y.to(y_pred.device)
        # Average over every non-batch dimension so that both 3-D token
        # features and 4-D feature maps are supported.
        abs_loss = ((y_pred - y) ** 2).flatten(start_dim=1).mean(dim=1)
        direction_loss = 1 - similarity_loss(
            y_pred.reshape(y_pred.shape[0], -1), y.reshape(y.shape[0], -1)
        )
        return direction_loss, abs_loss

    label_score = []
    model.eval()
    # Pure evaluation: no autograd graph is needed for either network.
    with torch.no_grad():
        for X, Y in test_dataloader:
            if X.shape[1] == 1:
                # Grayscale batch: replicate the single channel three times.
                X = X.repeat(1, 3, 1, 1)

            # Preprocess inputs for the pretrained model.
            # NOTE(review): the processor may rescale/normalize on its own —
            # confirm this matches the value range produced by the dataloader.
            inputs_pretrained = processor_pretrained(images=X, return_tensors="pt").to(teacher_device)

            # Teacher features come from the pretrained ViT, not the student.
            output_real = vit_pretrained(**inputs_pretrained).hidden_states

            output_pred = model(X.cuda())

            loss_1, abs_loss_1 = _layer_terms(output_pred[6], output_real[6])
            loss_2, abs_loss_2 = _layer_terms(output_pred[9], output_real[9])
            loss_3, abs_loss_3 = _layer_terms(output_pred[12], output_real[12])
            total_loss = loss_1 + loss_2 + loss_3 + config['lamda'] * (abs_loss_1 + abs_loss_2 + abs_loss_3)

            label_score += list(zip(Y.cpu().numpy().tolist(), total_loss.cpu().numpy().tolist()))

    labels, scores = zip(*label_score)
    # 1 = normal (target) class, 0 = everything else.
    labels = (np.array(labels) == target_class).astype(int)
    scores = np.array(scores)
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=0)
    roc_auc = auc(fpr, tpr)
    roc_auc = round(roc_auc, 4)
    return roc_auc

In [None]:
def train(config):
    """Train the student network to reproduce the pretrained ViT's hidden states.

    The loop runs `config["num_epochs"] + 1` epochs and evaluates with `test`
    on every 10th epoch, including epoch 0.

    Args:
        config: Dict read for "lamda", "learning_rate" and "num_epochs"; it is
            also passed on to `load_data` and `test`.

    Returns:
        Tuple `(model, vit_pretrained, processor_pretrained)`.

    NOTE(review): the student `model` is not created in this function. The name
    is resolved from the notebook's global scope, so whatever object is bound to
    `model` when this runs is what gets optimized. Define the student network
    at the marked spot below.
    """
    num_epochs = config["num_epochs"]

    # data prep
    train_dataloader, test_dataloader = load_data(config)

    # define model here above optimizer!

    # criteria / optimizers
    criterion = MseDirectionLoss(config["lamda"])
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

    # get pretrained ViT
    processor_pretrained = ViTImageProcessor.from_pretrained('nateraw/vit-base-patch16-224-cifar10')
    vit_pretrained = ViTForImageClassification.from_pretrained('nateraw/vit-base-patch16-224-cifar10', output_hidden_states=True)
    # The student inputs are moved to CUDA below, so keep the teacher there too
    # and in eval mode; this puts teacher and student features on one device.
    vit_pretrained = vit_pretrained.cuda().eval()
    teacher_device = next(vit_pretrained.parameters()).device

    # init logging
    losses = []
    roc_aucs = []

    for epoch in range(num_epochs + 1):
        model.train()
        epoch_loss = 0
        for data in train_dataloader:
            X = data[0]
            if X.shape[1] == 1:
                # Grayscale batch: replicate the single channel three times.
                X = X.repeat(1, 3, 1, 1)

            # Preprocess inputs for the pretrained model.
            # NOTE(review): the processor may rescale/normalize on its own —
            # confirm this matches the value range produced by the dataloader.
            inputs_pretrained = processor_pretrained(images=X, return_tensors="pt").to(teacher_device)

            # Teacher targets come from the pretrained ViT, never from the
            # student, and need no gradients.
            with torch.no_grad():
                output_real = vit_pretrained(**inputs_pretrained).hidden_states

            output_pred = model(X.cuda())

            total_loss = criterion(output_pred, output_real)

            # logging
            epoch_loss += total_loss.item()
            losses.append(total_loss.item())

            # standard pytorch
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, epoch_loss))
        if epoch % 10 == 0:
            roc_auc = test(model, vit_pretrained, processor_pretrained, test_dataloader, config)
            roc_aucs.append(roc_auc)
            print("RocAUC at epoch {}:".format(epoch), roc_auc)

    return model, vit_pretrained, processor_pretrained

In [None]:
# Seed the random number generators, then run training. The three returned
# objects are rebound at notebook level and reused by the evaluation cell.
set_seed(42)
model, vit_pretrained, processor_pretrained = train(config)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./Dataset/CIFAR10/train/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 92568537.52it/s] 


Extracting ./Dataset/CIFAR10/train/cifar-10-python.tar.gz to ./Dataset/CIFAR10/train
Cifar10 DataLoader Called...
All Train Data:  (50000, 32, 32, 3)
Normal Train Data:  (5000, 32, 32, 3)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./Dataset/CIFAR10/test/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 81135462.02it/s]


Extracting ./Dataset/CIFAR10/test/cifar-10-python.tar.gz to ./Dataset/CIFAR10/test
Test Train Data: (10000, 32, 32, 3)


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:04<00:00, 125MB/s]


layer : 0 Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 1 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 2 ReLU(inplace=True)
layer : 3 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 4 ReLU(inplace=True)
layer : 5 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 6 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 7 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 8 ReLU(inplace=True)
layer : 9 Conv2d(16, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 10 ReLU(inplace=True)
layer : 11 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 12 Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 13 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

KeyboardInterrupt: ignored

In [None]:
# Detection test
# Only the second loader returned by load_data is used; the first is discarded.
_, test_dataloader = load_data(config)
# Uses the model, teacher and processor returned by the training cell.
roc_auc = test(model, vit_pretrained, processor_pretrained, test_dataloader, config)
# The reported epoch number is read from the config, not from the training run.
last_checkpoint = config['last_checkpoint']
print("RocAUC after {} epoch:".format(last_checkpoint), roc_auc)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 107MB/s]


layer : 0 Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 1 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 2 ReLU(inplace=True)
layer : 3 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 4 ReLU(inplace=True)
layer : 5 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 6 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 7 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 8 ReLU(inplace=True)
layer : 9 Conv2d(16, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 10 ReLU(inplace=True)
layer : 11 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 12 Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 13 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

100%|██████████| 170498071/170498071 [00:08<00:00, 20973895.55it/s]


Extracting ./Dataset/CIFAR10/train/cifar-10-python.tar.gz to ./Dataset/CIFAR10/train
Cifar10 DataLoader Called...
All Train Data:  (50000, 32, 32, 3)
Normal Train Data:  (5000, 32, 32, 3)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./Dataset/CIFAR10/test/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:08<00:00, 20880714.93it/s]


Extracting ./Dataset/CIFAR10/test/cifar-10-python.tar.gz to ./Dataset/CIFAR10/test
Test Train Data: (10000, 32, 32, 3)
RocAUC after 200 epoch: 0.7703
