In [1]:
!git clone https://github.com/LucStrater/Knowledge_Distillation_AD.git

Cloning into 'Knowledge_Distillation_AD'...
remote: Enumerating objects: 76, done.[K
remote: Counting objects: 100% (46/46), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 76 (delta 29), reused 25 (delta 20), pack-reused 30[K
Receiving objects: 100% (76/76), 1.83 MiB | 8.43 MiB/s, done.
Resolving deltas: 100% (32/32), done.


In [2]:
import sys

# Make the cloned repository importable. It is placed at the very front of the
# module search path so that its top-level modules (this notebook imports
# `test`, `utils`, `dataloader`, ...) win over any same-named module found
# elsewhere on the path. Existing occurrences are dropped first, so re-running
# the cell never piles up duplicate entries.
REPO_DIR = '/content/Knowledge_Distillation_AD'
sys.path[:] = [REPO_DIR] + [entry for entry in sys.path if entry != REPO_DIR]

## Train model

In [3]:
from test import *
from utils.utils import *
from dataloader import *
from pathlib import Path
from torch.autograd import Variable
import pickle
from test_functions import detection_test
from loss_functions import *
from models.network import *

In [4]:
def train(config):
    """Optimise ``model`` so that its outputs match those of ``vgg``.

    Both networks come from ``get_networks(config)``; only ``model``'s
    parameters are handed to the optimizer, so ``vgg`` acts as a fixed
    reference whose outputs are the regression target.

    Args:
        config: Mapping providing at least the keys ``direction_loss_only``,
            ``normal_class``, ``learning_rate``, ``num_epochs``, ``lamda``,
            ``continue_train``, ``last_checkpoint``, ``experiment_name`` and
            ``dataset_name``. It is also forwarded unchanged to ``load_data``,
            ``get_networks`` and ``detection_test``.

    Epoch numbering:
        A fresh run trains epochs ``0 .. num_epochs`` inclusive. When
        ``continue_train`` is truthy, optimizer state and the ROC AUC history
        are restored from the files of epoch ``last_checkpoint`` and training
        continues with epochs ``last_checkpoint + 1 .. last_checkpoint +
        num_epochs``. Epoch numbers in log lines and checkpoint file names are
        therefore cumulative, and a resumed run no longer overwrites the
        checkpoints written by the run it continues.

    Side effects:
        * Creates ``./outputs/<experiment_name>/<dataset_name>/checkpoints/``.
        * Prints the summed batch loss after every epoch.
        * Every 10th epoch runs ``detection_test`` and prints the ROC AUC.
        * Every 50th epoch, and after the final epoch, writes the model state,
          the optimizer state and the pickled ROC AUC history.

    Returns:
        None.
    """
    direction_loss_only = config["direction_loss_only"]
    normal_class = config["normal_class"]
    learning_rate = float(config['learning_rate'])
    num_epochs = config["num_epochs"]
    lamda = config['lamda']
    continue_train = config['continue_train']
    last_checkpoint = config['last_checkpoint']

    checkpoint_path = "./outputs/{}/{}/checkpoints/".format(config['experiment_name'], config['dataset_name'])

    # Create the checkpoint directory (and any missing parents) up front.
    Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

    train_dataloader, test_dataloader = load_data(config)
    if continue_train:
        vgg, model = get_networks(config, load_checkpoint=True)
    else:
        vgg, model = get_networks(config)

    # Criteria And Optimizers
    if direction_loss_only:
        criterion = DirectionOnlyLoss()
    else:
        criterion = MseDirectionLoss(lamda)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    roc_aucs = []
    first_epoch = 0
    final_epoch = num_epochs
    if continue_train:
        optimizer.load_state_dict(
            torch.load('{}Opt_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, last_checkpoint)))
        # SECURITY: unpickling can execute arbitrary code. Only resume from
        # checkpoint files that were produced by a trusted run of this function.
        with open('{}Auc_{}_epoch_{}.pickle'.format(checkpoint_path, normal_class, last_checkpoint), 'rb') as f:
            roc_aucs = pickle.load(f)
        # Continue the epoch count where the restored checkpoint left off so
        # that new checkpoint files do not clobber the earlier ones.
        resumed_from = int(last_checkpoint)
        first_epoch = resumed_from + 1
        final_epoch = resumed_from + num_epochs

    for epoch in range(first_epoch, final_epoch + 1):
        model.train()
        epoch_loss = 0
        for data in train_dataloader:
            X = data[0]
            # Single-channel batches are tiled to three channels before being
            # fed to the networks.
            if X.shape[1] == 1:
                X = X.repeat(1, 3, 1, 1)
            X = X.cuda()

            output_pred = model(X)
            # vgg's parameters are not optimised here, so its forward pass
            # does not need an autograd graph.
            with torch.no_grad():
                output_real = vgg(X)

            total_loss = criterion(output_pred, output_real)
            epoch_loss += total_loss.item()

            # Clear the previous gradients
            optimizer.zero_grad()
            # Compute gradients
            total_loss.backward()
            # Adjust weights
            optimizer.step()

        print('epoch [{}/{}], loss:{:.4f}'.format(epoch, final_epoch, epoch_loss))
        if epoch % 10 == 0:
            # NOTE(review): the model is still in train mode at this point unless
            # detection_test switches it to eval mode itself - confirm.
            roc_auc = detection_test(model, vgg, test_dataloader, config)
            roc_aucs.append(roc_auc)
            print("RocAUC at epoch {}:".format(epoch), roc_auc)

        # Also checkpoint the final epoch so the end result is never lost when
        # the epoch count is not a multiple of 50.
        if epoch % 50 == 0 or epoch == final_epoch:
            torch.save(model.state_dict(),
                       '{}Cloner_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, epoch))
            torch.save(optimizer.state_dict(),
                       '{}Opt_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, epoch))
            with open('{}Auc_{}_epoch_{}.pickle'.format(checkpoint_path, normal_class, epoch),
                      'wb') as f:
                pickle.dump(roc_aucs, f)

In [5]:
# Read the experiment settings from the repository's YAML file, then start
# training with them.
config = get_config(
    '/content/Knowledge_Distillation_AD/configs/config.yaml'
)
train(config=config)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./Dataset/CIFAR10/train/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29604028.37it/s]


Extracting ./Dataset/CIFAR10/train/cifar-10-python.tar.gz to ./Dataset/CIFAR10/train
Cifar10 DataLoader Called...
All Train Data:  (50000, 32, 32, 3)
Normal Train Data:  (5000, 32, 32, 3)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./Dataset/CIFAR10/test/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29822176.87it/s]


Extracting ./Dataset/CIFAR10/test/cifar-10-python.tar.gz to ./Dataset/CIFAR10/test
Test Train Data: (10000, 32, 32, 3)


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:04<00:00, 118MB/s]


layer : 0 Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 1 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 2 ReLU(inplace=True)
layer : 3 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 4 ReLU(inplace=True)
layer : 5 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 6 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 7 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 8 ReLU(inplace=True)
layer : 9 Conv2d(16, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 10 ReLU(inplace=True)
layer : 11 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 12 Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 13 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [7]:
from utils.utils import get_config
from dataloader import load_data, load_localization_data
from test_functions import detection_test, localization_test
from models.network import get_networks

# Re-read the settings and rebuild both networks, asking get_networks to load
# a checkpoint.
config = get_config(
    '/content/Knowledge_Distillation_AD/configs/config.yaml'
)
vgg, model = get_networks(config, load_checkpoint=True)

# The `localization_test` flag picks the evaluation: localization when it is
# truthy, detection otherwise.
if not config['localization_test']:
    # Detection test: only the second loader returned by load_data is used.
    _, test_dataloader = load_data(config)
    roc_auc = detection_test(
        model=model,
        vgg=vgg,
        test_dataloader=test_dataloader,
        config=config,
    )
else:
    # Localization test: needs the ground truth returned alongside the loader.
    test_dataloader, ground_truth = load_localization_data(config)
    roc_auc = localization_test(
        model=model,
        vgg=vgg,
        test_dataloader=test_dataloader,
        ground_truth=ground_truth,
        config=config,
    )

# Report the score together with the checkpoint epoch named in the config.
last_checkpoint = config['last_checkpoint']
print("RocAUC after {} epoch:".format(last_checkpoint), roc_auc)

layer : 0 Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 1 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 2 ReLU(inplace=True)
layer : 3 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 4 ReLU(inplace=True)
layer : 5 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 6 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 7 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer : 8 ReLU(inplace=True)
layer : 9 Conv2d(16, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 10 ReLU(inplace=True)
layer : 11 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
layer : 12 Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer : 13 BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

FileNotFoundError: ignored