In [1]:
import argparse
import os, sys
import time
import datetime
import numpy as np
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from Resnet20model import Resnet20

In [2]:
#############################################
# sanity check for the correctness of Resnet20 architecture
dummy_input = torch.ones((1,3,32,32))
dummy_instance = Resnet20()
print(dummy_instance.forward(dummy_input))
total_params = 1
for p in dummy_instance.parameters():
    temp = 1
    for s in p.size():
        temp *= s
    total_params += temp
    #print(p.size())
print(f"Total Number of Parameters: {total_params}")
#############################################

tensor([[-0.1929, -0.0139, -0.0794,  0.0647,  0.3185, -0.0417,  0.4083,  0.5119,
         -0.0675, -0.3102]], grad_fn=<AddmmBackward0>)
Total Number of Parameters: 273067


In [3]:
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10 as torchCIFAR10

mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
rndimg_mean = (136.289, 129.273,122.668)
rndimg_std  = (73.788, 73.672, 76.487)

# specify preprocessing function
transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std),
                                     transforms.RandomCrop(size=32, padding=4),
                                     transforms.RandomHorizontalFlip()])
tensor_transform_train = transforms.Compose([transforms.Normalize(rndimg_mean, rndimg_std),
                                     transforms.RandomCrop(size=32, padding=4),
                                     transforms.RandomHorizontalFlip()])

transform_val = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

transform_visulaizer = transforms.Compose([transforms.ToTensor()])

In [4]:
# Custom dataset class for tensor datasets. (Allows transforms)

class OODTensorDataset:
    def __init__(self, tensor_data, transforms=None):
        self.tensors = tensor_data
        self.transforms = transforms
    
    def __getitem__(self, index):   
        x = self.tensors[index]
        if self.transforms:
            x = self.transforms(x)
        return x, -1                 # returing -1 for OOD labels
    
    def __len__(self):
        return self.tensors.size()[0]


##### Set up Data loaders

In [5]:
from tools.dataset import CIFAR10
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100

DATA_ROOT = "./data"
RANDOM_IMGS = "./data/300K_random_images.npy"
TRAIN_BATCH_SIZE = 256
VAL_BATCH_SIZE = 100
OOD_BATCH_SIZE = 2

# construct dataset
train_set_in = CIFAR10(
    root=DATA_ROOT, 
    mode='train', 
    download=True,
    transform= transform_train
)
random_images_data = torch.permute(torch.from_numpy(np.load(RANDOM_IMGS)), (0,3,1,2)).float()  
rand_img_set = OODTensorDataset(
    random_images_data,
    tensor_transform_train
)
val_set = CIFAR10(
    root=DATA_ROOT, 
    mode='val', 
    download=True,
    transform= transform_val
)
test_set = CIFAR10(
    root=DATA_ROOT,
    mode='test',
    download=True,
    transform=transform_test
)
torch_test_set = torchCIFAR10(
    root=DATA_ROOT,
    train=False,
    transform=transform_test,
    download=True
)
torch_vis_set = torchCIFAR10(
    root=DATA_ROOT,
    train=False,
    transform=transform_visulaizer,
    download=True
)


# construct dataloader
train_loader_in = DataLoader(
    train_set_in, 
    batch_size= TRAIN_BATCH_SIZE,  
    shuffle= True,   
    num_workers=4
)
train_loader_ood = DataLoader(
    rand_img_set,
    batch_size = OOD_BATCH_SIZE,
    shuffle = True
)
val_loader = DataLoader(
    val_set, 
    batch_size= VAL_BATCH_SIZE,  
    shuffle= False,  
    num_workers=4
)
test_loader = DataLoader(
    test_set
)
torch_test_loader = DataLoader(
    torch_test_set
)
torch_visualizer_loader = DataLoader(
    torch_vis_set
)

Using downloaded and verified file: ./data\cifar10_trainval_F22.zip
Extracting ./data\cifar10_trainval_F22.zip to ./data
Files already downloaded and verified
Using downloaded and verified file: ./data\cifar10_trainval_F22.zip
Extracting ./data\cifar10_trainval_F22.zip to ./data
Files already downloaded and verified
Using downloaded and verified file: ./data\cifar10_test_F22.zip
Extracting ./data\cifar10_test_F22.zip to ./data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [6]:
def estimate_mean_and_std(npdata, N=1000):
    pavg    = np.array([0.0, 0.0, 0.0])
    pavg_sq = np.array([0.0, 0.0, 0.0])
    for img in npdata[0:N]:
        scale = 1/(img.shape[0]*img.shape[1])
        pavg += np.sum(img*scale,(0,1))*(1/N)
        pavg_sq += np.sum((img**2)*scale,(0,1))*(1/N)

    std = np.sqrt(pavg_sq-(pavg**2))
    print(pavg)
    print(std)
#estimate_mean_and_std(np.load(RANDOM_IMGS).astype("int32"))

In [28]:
# specify the device for computation
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [38]:
net = Resnet20().to(device)
#net.load_state_dict(torch.load("saved_model/resnet20.pth")["state_dict"])

In [39]:
import torch.nn as nn
import torch.optim as optim

WARMUP_LR = .01
MOMENTUM = 0.9
REG = 5e-4
optimizer = optim.SGD(net.parameters(), lr=WARMUP_LR, momentum=MOMENTUM, weight_decay=REG)

EPOCHS = 150
CHECKPOINT_FOLDER = "./saved_model"
INITAL_LR = .1
DECAY_EPOCHS = 60
DECAY = .1

In [None]:
def train():
    avg_loss = 0
    best_val_acc = 0
    current_learning_rate = WARMUP_LR
    train_loader_ood.dataset.offset = np.random.randint(len(train_loader_ood.dataset))
    #Learning rate scheduler
    for i in range(EPOCHS):
        if i == 5:
            current_learning_rate = INITAL_LR
        if i % DECAY_EPOCHS == 0 and i != 0:
            current_learning_rate = current_learning_rate * DECAY
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            #print("Current learning rate has decayed to %f" %current_learning_rate)
    
        net.train()
        
        print(f"Epoch: {i}/{EPOCHS}")
        for in_set, out_set in zip(train_loader_in, train_loader_ood):
            
            data = torch.cat((in_set[0], out_set[0]), 0)
            target = in_set[1].type(torch.LongTensor)
            data, target = data.to(device), target.to(device)
            
            # Forward Pass
            x = net.forward(data)
            
            loss = F.cross_entropy(x[:len(in_set[0])], target)
            # cross-entropy from softmax distribution to uniform distribution
            loss += 0.5 * -(x[len(in_set[0]):].mean(1) - torch.logsumexp(x[len(in_set[0]):], dim=1)).mean()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            avg_loss = .8 * avg_loss + float(loss) * 0.2
       
        print(f"avg training loss: {avg_loss}")
        
        
        net.eval()

        total_examples = 0
        correct_examples = 0

        val_loss = 0

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(val_loader):
                # copy inputs to device
                inputs = inputs.to(device)
                targets = targets.type(torch.LongTensor)
                targets = targets.to(device)

                # compute the output and loss
                output = net.forward(inputs)
                loss = F.cross_entropy(output, targets)
                val_loss += loss

                # count the number of correctly predicted samples in the current batch
                total_examples += targets.size()[0]
                batch_preds = torch.argmax(output, dim=1)
                correct_examples += (batch_preds == targets).int().sum().item()

        avg_loss = val_loss / len(val_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))
        
        #save the model checkpoint
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            if not os.path.exists(CHECKPOINT_FOLDER):
               os.makedirs(CHECKPOINT_FOLDER)
            print("Saving ...")
            state = {'state_dict': cnn.state_dict(),
                    'epoch': i,
                    'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'OE_resnet20.pth'))

    
train()

Epoch: 0/150
avg training loss: 2.8940900718910263
Validation loss: 1.7796, Validation accuracy: 0.3586
Saving ...
Epoch: 1/150
avg training loss: 2.738651990890503
Validation loss: 1.5484, Validation accuracy: 0.4634
Saving ...
Epoch: 2/150
avg training loss: 2.5535314083099365
Validation loss: 1.5653, Validation accuracy: 0.4942
Saving ...
Epoch: 3/150
avg training loss: 2.4773104190826416
Validation loss: 1.2715, Validation accuracy: 0.5618
Saving ...
Epoch: 4/150
avg training loss: 2.38065242767334
Validation loss: 1.2027, Validation accuracy: 0.5826
Saving ...
Epoch: 5/150
avg training loss: 2.333941698074341
Validation loss: 1.1613, Validation accuracy: 0.6106
Saving ...
Epoch: 6/150
avg training loss: 2.30107045173645
Validation loss: 1.0690, Validation accuracy: 0.6346
Saving ...
Epoch: 7/150
avg training loss: 2.242527484893799
Validation loss: 1.0273, Validation accuracy: 0.6630
Saving ...
Epoch: 8/150
avg training loss: 2.228848934173584
Validation loss: 1.0129, Validation a