In [1]:
'''
This notebook trains one of several models on the input data and saves the SWAG model to the checkpoint directory.
The saved model is then used by the calibration notebook to generate entropy.
'''

'\nThis notebook trains different models on input data and and saves SWAG model in checkpoint directory to be used. This model is\nthen used in calibration notebook to generate entropy\n'

In [1]:
import argparse
import os, sys
import time
import tabulate

import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import os

In [2]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [3]:

from swag import data, models, utils, losses
from swag.posteriors import SWAG


In [4]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import imageio
import glob


In [5]:
from sklearn.model_selection import train_test_split

In [6]:
from PIL import Image
# import skimage.io as io


## Model init:

In [None]:
'''
Model initialization. We can use either a custom model defined in this notebook or one of the predefined models
in the folder swag/models; in this case we are using the PreResNet56 model.'''

In [7]:
# Number of output classes of the classifier (the labels built below are 0 and 1).
num_classes = 2
# Mini-batch size handed to both DataLoaders.
batch_size = 16

In [8]:
# Pick the compute device: the GPU when CUDA is available, the CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# SWAG checkpoints are saved in the ./checkpoints directory; the launching
# command line is recorded next to them.
print(r"Preparing directory './checkpoints'")
os.makedirs(r'./checkpoints', exist_ok=True)
with open(os.path.join(r'./checkpoints', "command.sh"), "w") as f:
    f.write(" ".join(sys.argv) + "\n")

# Seed the PyTorch generators (CPU always, CUDA when present).
seed = 1
torch.backends.cudnn.benchmark = True
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed(seed)

Preparing directory './checkpoints'


In [10]:
# A custom model can be defined here. Change the first linear layer (fc1) depending on the size of the input image used.

class VGG16(nn.Module):
    """VGG16-style convolutional classifier.

    Thirteen 3x3 convolutions arranged in five stages (each stage followed by
    a 2x2 max-pool), then three fully connected layers.

    Args:
        num_classes: size of the final output layer.

    Note:
        ``fc1`` expects 8192 input features, i.e. 512 channels * 4 * 4, which
        is what five 2x2 poolings produce for a 3x128x128 input. Change
        ``fc1``'s input size if a different image size is used.
    """

    def __init__(self, num_classes=2):
        super(VGG16, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        # Stateless, so one pooling module is shared by all five stages.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 8192 = 512 * 4 * 4 for a 3x128x128 input.
        self.fc1 = nn.Linear(8192, 4096)
        self.fc2 = nn.Linear(4096, 2048)
        self.fc3 = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: image batch with 3 channels whose flattened feature map after
                the five pooling stages has 8192 elements (e.g. 128x128 images).
        """
        # Stage 1
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.maxpool(x)
        # Stage 2
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.maxpool(x)
        # Stage 3
        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))
        x = F.relu(self.conv3_3(x))
        x = self.maxpool(x)
        # Stage 4
        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x))
        x = F.relu(self.conv4_3(x))
        x = self.maxpool(x)
        # Stage 5
        x = F.relu(self.conv5_1(x))
        x = F.relu(self.conv5_2(x))
        x = F.relu(self.conv5_3(x))
        x = self.maxpool(x)

        # Flatten everything except the batch dimension.
        x = x.reshape(x.shape[0], -1)

        # Dropout was included to combat overfitting. F.dropout defaults to
        # training=True, so the module's mode must be passed explicitly;
        # otherwise dropout stays active after model.eval().
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5, training=self.training)
        x = self.fc3(x)
        return x

class ConvNetSimple:
    """Configuration bundle for the custom VGG16 network.

    Provides the ``base``, ``args``, ``kwargs`` and ``transform_test``
    attributes that this notebook reads from ``model_cfg``.
    """

    # Model class plus the positional / keyword arguments used to build it.
    base = VGG16
    args = []
    kwargs = dict()

    # Preprocessing pipeline: resize, convert to tensor, normalize per channel.
    # NOTE(review): the normalization constants are presumably the dataset's
    # per-channel mean / std - confirm.
    transform_test = transforms.Compose([
        transforms.Resize(128),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5394, 0.5354, 0.5504),
            std=(0.3623, 0.3620, 0.3465),
        ),
    ])

    # transform_train = transforms.Compose(
    #     [
    #         transforms.RandomHorizontalFlip(),
    #         transforms.Resize(256),
    #         transforms.RandomCrop(256, padding=4),
    #         transforms.Normalize((0.5394, 0.5354, 0.5504), (0.3623, 0.3620, 0.3465)),
    #         # transforms.Normalize((0.4376821 , 0.4437697 , 0.47280442), (0.19803012, 0.20101562, 0.19703614))
    #     ]
    # )


In [None]:
# model_cfg  = ConvNetSimple
# print("Preparing model")
# model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
# model.to(device)

In [None]:
# To test the custom model for parameters
# summary(model, (3, 128, 128))

In [9]:
# Model configuration taken from the predefined ones in swag.models.
model_cfg = models.PreResNet56
# model_cfg = models.VGG16

print("Preparing model")
# Instantiate the network from the configuration's base class and arguments.
model = model_cfg.base(
    *model_cfg.args,
    num_classes=num_classes,
    **model_cfg.kwargs
)

# Move the model to the selected device; being the cell's last expression,
# this also displays the module structure.
model.to(device)


Preparing model
<class 'swag.models.preresnet.Bottleneck'>


PreResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1

### Dataset/Dataloader Preparation

In [12]:
def load_pngs(good_pattern="./data/bottle/good/*.png",
              bad_pattern="./data/toothbrush/train/good/*.png"):
    """Load the two image classes from disk as RGB PIL images.

    Args:
        good_pattern: glob pattern of the images returned in the first list.
        bad_pattern: glob pattern of the images returned in the second list.
            NOTE(review): the default points at the toothbrush ``train/good``
            folder, so the second class is a different object category rather
            than defective bottles - confirm this is intended.

    Returns:
        Tuple ``(good, bad)`` of lists of PIL images. Either list is empty
        when its pattern matches no file.
    """
    def _load_all(pattern):
        loaded = []
        # Sort the matches: glob order is OS-dependent, and a stable order is
        # required for a seeded train/test split to be reproducible.
        for im_path in sorted(glob.glob(pattern)):
            # Image.open is lazy and keeps the file open; convert() decodes the
            # pixels into a new in-memory image so the handle can be closed.
            # "RGB" matches the 3-channel input the models/normalization expect.
            with Image.open(im_path) as im:
                loaded.append(im.convert("RGB"))
        return loaded

    return _load_all(good_pattern), _load_all(bad_pattern)

In [13]:
good_imgs, bad_imgs = load_pngs()

In [15]:
# p=transforms.Compose(
#     [transforms.Resize(size=128)]
# )

In [16]:
# p(good_imgs[-1])

In [16]:
# One flat sample list: good images first (label 1), then bad images (label 0).
images = [*good_imgs, *bad_imgs]
labels = [1] * len(good_imgs) + [0] * len(bad_imgs)

In [17]:
len(bad_imgs)

60

In [18]:
len(images)

151

In [19]:
# images = []
# labels = []
# for i in good_imgs:
#     images.append(i)
#     labels += [1] 

# for i in bad_imgs:
#     images.append(i)
#     labels += [0]

In [20]:
len(images), len(labels)

(151, 151)

In [21]:
# transform = transforms.Compose([
#     transforms.ToTensor()
# ])

In [22]:
# img_tr = transform(images[0])

# mean, std = img_tr.mean([1,2]), img_tr.std([1,2])


In [23]:
# print(labels[30])
# p(images[30])


In [24]:
class MvTecDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing in-memory images with their labels.

    Args:
        imgs: indexable collection of images.
        labels: indexable collection of labels, same length as ``imgs``.
        transform: optional callable applied to an image on access; when
            ``None`` the image is returned unchanged.

    Raises:
        ValueError: if ``imgs`` and ``labels`` differ in length.
    """

    def __init__(self, imgs, labels, transform=None):
        # A length mismatch would otherwise surface later as an IndexError or
        # as silently unused images, since __len__ is driven by the labels.
        if len(imgs) != len(labels):
            raise ValueError(
                "imgs and labels must have the same length, got %d and %d"
                % (len(imgs), len(labels))
            )
        self.imgs = imgs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(transformed image, label)`` for sample ``idx``."""
        img = self.imgs[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

In [25]:
# 70/30 train/test split. random_state makes the split reproducible across
# kernel restarts and stratify keeps the class ratio equal in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.3,
    shuffle=True,
    random_state=seed,
    stratify=labels,
)

# Both datasets use transform_test, the only transform the model config provides.
train_dataset = MvTecDataset(X_train, y_train, model_cfg.transform_test)
test_dataset = MvTecDataset(X_test, y_test, model_cfg.transform_test)

loaders = {
    # Training: reshuffle every epoch and drop the trailing partial batch.
    "train": torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
    ),
    # Evaluation: keep every sample and a fixed order. The training loop
    # averages test predictions element-wise across epochs, so each row must
    # refer to the same sample every time; drop_last would also exclude part
    # of the test set from the reported metrics.
    "test": torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, drop_last=False
    ),
}
print(len(X_train), len(X_test))

105 46


In [26]:
# dataset = MvTecDataset(images, labels, model_cfg.transform_test)
# train_size = int(0.8 * len(dataset))
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# loaders = {
#     "train": torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True),    
#     "test": torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# }
# print(train_size, test_size)

## SWAG pipeline

In [27]:
# Directory the checkpoints are written to.
ch_dir = './checkpoints'

# Number of training epochs.
epochs = 20
# Save a checkpoint every `save_freq` epochs.
save_freq = 20
# Evaluate on the test set every `eval_freq` epochs.
eval_freq = 5

# SGD optimizer settings: initial learning rate, momentum, weight decay.
lr_init = 0.1
momentum = 0.9
wd = 1e-4

# Enable the SWA / SWAG procedure.
swa = True
# Snapshots for SWAG are collected once more than `swa_start` epochs have run.
swa_start = 10
# Learning rate for the SWA phase.
swa_lr = 0.01
# Collect a snapshot every `swa_c_epochs` epochs during the SWA phase.
swa_c_epochs = 1

# Use the covariance matrix in the SWAG procedure.
cov_mat = True
# Passed to SWAG as `max_num_models`.
# NOTE(review): presumably the maximum number of snapshots kept for the
# covariance estimate - confirm against swag.posteriors.SWAG.
max_num_models = 20

# Checkpoint paths to resume from (None = start from scratch).
swa_resume = None
resume = None

# Loss selector: "CE" or "adv_CE".
loss = "CE"
# Random seed.
seed = 1
# When True the learning-rate schedule is skipped and lr_init is used throughout.
no_schedule = False

In [28]:
# SWAG takes the inverse of the notebook's `cov_mat` switch.
no_cov_mat = not cov_mat

if not swa:
    print("SGD training")
else:
    print("SWAG training")
    # SWAG wrapper built from the same base class and arguments as `model`.
    swag_model = SWAG(
        model_cfg.base,
        *model_cfg.args,
        no_cov_mat=no_cov_mat,
        max_num_models=max_num_models,
        num_classes=num_classes,
        **model_cfg.kwargs
    )
    swag_model.to(device)



SWAG training
<class 'swag.models.preresnet.Bottleneck'>


In [29]:
def schedule(epoch, lr_init=0.1, swa_lr=0.02, swa_start=10, swa=True, epochs=20):
    """Piecewise-linear learning-rate schedule.

    With ``t = epoch / horizon`` (``horizon`` is ``swa_start`` when ``swa`` is
    true, otherwise ``epochs``):

    * ``t <= 0.5``: the rate stays at ``lr_init``;
    * ``0.5 < t <= 0.9``: the rate decays linearly from ``lr_init`` to the
      final rate;
    * ``t > 0.9``: the rate stays at the final rate.

    The final rate is ``swa_lr`` when ``swa`` is true, else ``0.01 * lr_init``.

    Args:
        epoch: zero-based epoch index.
        lr_init: initial learning rate.
        swa_lr: learning rate reached for the SWA phase.
        swa_start: epoch count used as the horizon when ``swa`` is true.
        swa: whether the SWA variant of the schedule is used.
        epochs: total epoch count, used as the horizon when ``swa`` is false.

    Returns:
        The learning rate for ``epoch``.

    Note:
        The defaults are the values this function previously hard-coded, so
        ``schedule(epoch)`` is unchanged. They are independent of any
        same-named notebook variables; pass those explicitly to have them
        take effect.
    """
    horizon = swa_start if swa else epochs
    t = epoch / horizon
    lr_ratio = swa_lr / lr_init if swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        # Linear interpolation between 1.0 (at t=0.5) and lr_ratio (at t=0.9).
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio
    return lr_init * factor


In [30]:
# Use the slightly modified loss functions from swag.losses that allow input of the model.
if loss == "CE":
    criterion = losses.cross_entropy
    # criterion = F.cross_entropy
elif loss == "adv_CE":
    criterion = losses.adversarial_cross_entropy
else:
    # Fail here rather than with a NameError on `criterion` further down.
    raise ValueError("Unsupported loss %r; expected 'CE' or 'adv_CE'" % (loss,))

# Plain SGD with momentum and weight decay over the base model's parameters.
optimizer = torch.optim.SGD(
    model.parameters(), lr=lr_init, momentum=momentum, weight_decay=wd
)

In [31]:
# if resume is not None:
#     print("Resume training from %s" % args.resume)
#     checkpoint = torch.load(args.resume)
#     start_epoch = checkpoint["epoch"]
#     model.load_state_dict(checkpoint["state_dict"])
#     optimizer.load_state_dict(checkpoint["optimizer"])


In [32]:
# if swa and swa_resume is not None:
#     checkpoint = torch.load(swa_resume)
#     swag_model = SWAG(
#         model_cfg.base,
#         no_cov_mat=no_cov_mat,
#         max_num_models=max_num_models,
#         loading=True,
#         *model_cfg.args,
#         num_classes=num_classes,
#         **model_cfg.kwargs
#     )
#     swag_model.to(device)
#     swag_model.load_state_dict(checkpoint["state_dict"])


In [33]:
# Training starts from epoch 0.
start_epoch = 0

# Headers of the per-epoch progress table; the SWAG test metrics sit between
# the SGD test metrics and the time / memory columns.
columns = ["ep", "lr", "tr_loss", "tr_acc", "te_loss", "te_acc"]
if swa:
    columns += ["swa_te_loss", "swa_te_acc"]
    # Placeholder shown until the first SWAG evaluation.
    swag_res = {"loss": None, "accuracy": None}
columns += ["time", "mem_usage"]


In [34]:
# Save the model weights and optimizer state as the checkpoint for
# `start_epoch`, i.e. before any training epoch of this run has executed.
utils.save_checkpoint(
    ch_dir,
    start_epoch,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)

In [36]:

# Running mean of the test-set predictions of the SGD iterates collected during
# the SWA phase, the targets from the most recent collection, and the number of
# iterates averaged so far.
sgd_ens_preds = None
sgd_targets = None
n_ensembled = 0.0

for epoch in range(start_epoch, epochs):
    time_ep = time.time()

    # Learning rate for this epoch: scheduled, or constant when disabled.
    if not no_schedule:
        lr = schedule(epoch)
        utils.adjust_learning_rate(optimizer, lr)
    else:
        lr = lr_init

    # One pass over the training set.
    train_res = utils.train_epoch(loaders["train"], model, criterion, optimizer, cuda=use_cuda)

    # Evaluate on the first epoch, on every eval_freq-th epoch and on the last one.
    is_eval_epoch = (
        epoch == 0
        or epoch % eval_freq == eval_freq - 1
        or epoch == epochs - 1
    )

    if is_eval_epoch:
        test_res = utils.eval(loaders["test"], model, criterion, cuda=use_cuda)
    else:
        test_res = {"loss": None, "accuracy": None}

    # SWA phase: every swa_c_epochs epochs after swa_start, fold the current
    # iterate into the SGD ensemble and into the SWAG statistics.
    if (
        swa
        and (epoch + 1) > swa_start
        and (epoch + 1 - swa_start) % swa_c_epochs == 0
    ):
        sgd_res = utils.predict(loaders["test"], model)
        sgd_preds = sgd_res["predictions"]
        sgd_targets = sgd_res["targets"]
        print("updating sgd_ens")
        if sgd_ens_preds is None:
            sgd_ens_preds = sgd_preds.copy()
        else:
            # Incremental mean: mean_new = mean_old + (x - mean_old) / (n + 1).
            sgd_ens_preds = sgd_ens_preds + (sgd_preds - sgd_ens_preds) / (n_ensembled + 1)
        n_ensembled += 1
        swag_model.collect_model(model)
        if is_eval_epoch:
            # NOTE(review): sample(0.0) presumably selects the SWA mean weights
            # and bn_update presumably refreshes BatchNorm statistics on the
            # training data - confirm against the swag package.
            swag_model.sample(0.0)
            utils.bn_update(loaders["train"], swag_model)
            # cuda flag passed explicitly, consistent with the SGD evaluation above.
            swag_res = utils.eval(loaders["test"], swag_model, criterion, cuda=use_cuda)
        else:
            swag_res = {"loss": None, "accuracy": None}

    # Periodic checkpoints of the SGD model (and of the SWAG model when enabled).
    if (epoch + 1) % save_freq == 0:
        utils.save_checkpoint(
            ch_dir,
            epoch + 1,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        )
        if swa:
            # Save SWAG weights
            utils.save_checkpoint(
                ch_dir, epoch + 1, name="swag", state_dict=swag_model.state_dict()
            )

    time_ep = time.time() - time_ep

    # Allocated GPU memory in GiB; not available on CPU.
    if use_cuda:
        memory_usage = torch.cuda.memory_allocated() / (1024.0 ** 3)
    else:
        memory_usage = None

    # One table row per epoch, in the same order as `columns`.
    values = [
        epoch + 1,
        lr,
        train_res["loss"],
        train_res["accuracy"],
        test_res["loss"],
        test_res["accuracy"],
        time_ep,
        memory_usage,
    ]
    if swa:
        values = values[:-2] + [swag_res["loss"], swag_res["accuracy"]] + values[-2:]
    table = tabulate.tabulate([values], columns, tablefmt="simple", floatfmt="8.4f")
    if epoch % 40 == 0:
        # Every 40 epochs print the header too, framed by the separator line.
        table = table.split("\n")
        table = "\n".join([table[1]] + table)
    else:
        # Otherwise print only the data row.
        table = table.split("\n")[2]
    print(table)

# Final checkpoint when the last epoch did not coincide with a periodic save.
if epochs % save_freq != 0:
    utils.save_checkpoint(
        ch_dir,
        epochs,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    if swa and epochs > swa_start:
        # Write to the checkpoint directory (previously an undefined name was
        # passed here, which raised NameError).
        utils.save_checkpoint(
            ch_dir, epochs, name="swag", state_dict=swag_model.state_dict()
        )

# Persist the averaged SGD predictions together with their targets.
if swa:
    np.savez(
        os.path.join(ch_dir, "sgd_ens_preds.npz"),
        predictions=sgd_ens_preds,
        targets=sgd_targets,
    )


In [None]:
swag_model.state_dict()