In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import math
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from heper_dataset import get_dataloaders_cifar10,UnNormalize
import matplotlib.pyplot as plt
from helper_evaluation import set_all_seeds, set_deterministic, compute_confusion_matrix
from helper_train import train_model
from helper_plotting import plot_training_loss, plot_accuracy, show_examples, plot_confusion_matrix


In [12]:
# Experiment-wide settings; tune these here rather than further down the notebook.
RANDOM_SEED = 123
BATCH_SIZE = 256
NUM_EPOCHS = 50

# Prefer GPU index 3 (the original choice), but only when the machine actually
# exposes that many devices. On a CUDA host with fewer GPUs fall back to the
# first one instead of building an invalid device that fails at `.to(DEVICE)`.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:3' if torch.cuda.device_count() > 3 else 'cuda:0')
else:
    DEVICE = torch.device('cpu')

In [13]:
# Seed the run from RANDOM_SEED before any data loading or model construction.
# What exactly gets seeded is defined by helper_evaluation.set_all_seeds.
set_all_seeds(RANDOM_SEED)
# NOTE(review): the call below was left disabled by the author; presumably it
# forces deterministic kernels — confirm its effect before re-enabling.
#set_deterministic()

In [14]:
##########################
### CIFAR-10 DATASET
##########################

# Training pipeline: upscale to 120x120, take a random 110x110 crop, convert
# to a tensor and normalize every channel with mean 0.5 / std 0.5.
train_transforms = transforms.Compose([
    transforms.Resize((120, 120)),
    transforms.RandomCrop((110, 110)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation pipeline: same as above except the crop is a fixed center crop.
test_transforms = transforms.Compose([
    transforms.Resize((120, 120)),
    transforms.CenterCrop((110, 110)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_loader, valid_loader, test_loader = get_dataloaders_cifar10(
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    batch_size=BATCH_SIZE,
    validation_fraction=0.1,
    num_workers=2,
)

# Sanity check: report the shapes of the first training batch only.
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    print('Class labels of 10 examples:', labels[:10])
    break

Files already downloaded and verified
Image batch dimensions: torch.Size([256, 3, 110, 110])
Image label dimensions: torch.Size([256])
Class labels of 10 examples: tensor([4, 7, 4, 6, 2, 6, 9, 7, 3, 0])


In [15]:
class SABlock(nn.Module):
    """Bottleneck block that aggregates features computed at several scales.

    Data flow in ``forward``:

    1. optional max-pool by ``stride`` (this is how the block downsamples),
    2. 1x1 conv + BN + ReLU down to ``planes`` channels,
    3. one branch per entry of ``self.scales``: max-pool by the scale (skipped
       for ``None``), conv, resize back to the pre-pool size, BN,
    4. concatenate the branches, ReLU, 1x1 conv + BN up to
       ``planes * expansion`` channels, add the residual, ReLU.

    Per-block branch widths come from ``structure``: each entry is a sequence
    whose leading values are the output channels of the branches (``0``
    disables a branch) and whose last value is ``side``, used only to pick a
    3x3 or 1x1 kernel per branch.

    Args:
        inplanes: Number of input channels.
        planes: Bottleneck width; the block outputs ``planes * expansion``.
        stride: Max-pool factor applied to the input when greater than 1.
        bias: Whether the convolutions carry a bias term.
        downsample: If true, the residual path is projected with 1x1 conv + BN.
        structure: Sequence of per-block entries as described above.

    Raises:
        IndexError: If ``structure`` has no entry left for this block.
    """

    # Class-wide cursor into ``structure``. Every constructed block advances
    # it by one, so it is shared by ALL SABlock instances in the process.
    layer_idx = -1
    # The block emits ``planes * expansion`` channels.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, bias=False, downsample=False, structure=[]):
        super(SABlock, self).__init__()

        # Advance the shared cursor, failing loudly (and without corrupting the
        # cursor) when the structure is exhausted. This typically happens when
        # a second network is built without resetting ``SABlock.layer_idx``.
        next_idx = SABlock.layer_idx + 1
        if not 0 <= next_idx < len(structure):
            raise IndexError(
                'SABlock.layer_idx would become {} but `structure` has only {} entries; '
                'reset SABlock.layer_idx = -1 before building another network.'.format(
                    next_idx, len(structure)))
        SABlock.layer_idx = next_idx

        entry = structure[SABlock.layer_idx]
        channels = entry[:-1]
        side = entry[-1]

        # Pooling factor of each branch; None means "no pooling".
        self.scales = [None, 2, 4, 7]

        self.stride = stride

        out_planes = planes * SABlock.expansion

        # The projection keeps stride 1: spatial reduction is done by the
        # max-pool at the top of ``forward``, before the residual is taken.
        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=1, bias=bias),
                nn.BatchNorm2d(out_planes))
        else:
            self.downsample = None

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(planes)

        branch_convs = []
        branch_bns = []
        for i in range(len(self.scales)):
            if channels[i] > 0:
                # kernel size == 1 if featuremap size == 1
                use_3x3 = side / 2**i > 1
                branch_convs.append(nn.Conv2d(planes, channels[i],
                                              kernel_size=3 if use_3x3 else 1,
                                              stride=1,
                                              padding=1 if use_3x3 else 0,
                                              bias=bias))
                branch_bns.append(nn.BatchNorm2d(channels[i]))
            else:
                # Keep a None placeholder so branch i always lines up with scales[i].
                branch_convs.append(None)
                branch_bns.append(None)
        self.conv2 = nn.ModuleList(branch_convs)
        self.bn2 = nn.ModuleList(branch_bns)

        self.conv3 = nn.Conv2d(sum(channels), out_planes, kernel_size=1, bias=bias)
        self.bn3 = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        if self.stride > 1:
            x = F.max_pool2d(x, self.stride, self.stride)

        residual = x if self.downsample is None else self.downsample(x)

        out1 = F.relu(self.bn1(self.conv1(x)))

        out2_list = []
        size = [out1.size(2), out1.size(3)]
        for i in range(len(self.scales)):
            conv = self.conv2[i]
            if conv is None:
                # Disabled branch: it contributes nothing, so skip the pooling
                # and resizing work as well.
                continue
            scale = self.scales[i]
            out2_i = out1
            if scale is not None:
                out2_i = F.max_pool2d(out2_i, scale, scale)
            out2_i = conv(out2_i)
            if scale is not None:
                # nearest mode is not suitable for upsampling on non-integer multiples
                if size[0] % out2_i.shape[2] == 0 and size[1] % out2_i.shape[3] == 0:
                    out2_i = F.interpolate(out2_i, size=size, mode='nearest')
                else:
                    # align_corners=False is the library default; stating it
                    # explicitly keeps the numerics and avoids the warning.
                    out2_i = F.interpolate(out2_i, size=size, mode='bilinear',
                                           align_corners=False)
            out2_list.append(self.bn2[i](out2_i))
        out2 = F.relu(torch.cat(out2_list, 1))

        out3 = self.bn3(self.conv3(out2))
        out3 += residual
        out3 = F.relu(out3)

        return out3

In [16]:
class ScaleNet(nn.Module):
    """ResNet-style classifier assembled from four stages of ``block`` modules.

    Args:
        block: Block class. It is called as
            ``block(inplanes, planes, stride, downsample=..., structure=...)``
            and must expose an ``expansion`` class attribute.
        layers: Four integers, the number of blocks in each stage.
        structure: Per-block configuration, passed unchanged to every block.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, layers, structure, num_classes=1000):
        super(ScaleNet, self).__init__()

        self.inplanes = 64
        self.structure = structure

        # Blocks such as SABlock walk ``structure`` through a class-level
        # cursor (``layer_idx``, initially -1) that survives across networks.
        # Rewind it so that building a second model in the same process (for
        # example re-running a notebook cell) starts from the first entry
        # again instead of running past the end of ``structure``.
        if hasattr(block, 'layer_idx'):
            block.layer_idx = -1

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Only the last stage reduces the spatial size (stride=2).
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # NOTE(review): a loop over ``self.modules()`` used to sit here, but its
        # body only re-assigned an unused local, so it has been dropped. If a
        # custom weight initialisation was intended, it needs to be restored;
        # as before, all layers keep PyTorch's default initialisation.

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and update ``self.inplanes``."""
        # The first block needs a projected residual whenever it changes the
        # spatial size or the channel count.
        downsample = stride != 1 or self.inplanes != planes * block.expansion
        layers = [block(self.inplanes, planes, stride, downsample=downsample, structure=self.structure)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, downsample=False, structure=self.structure))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem: conv + BN + ReLU followed by a 3x3 max-pool with stride 2.
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 3, 2, 1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool, flatten, classify.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


In [17]:

def scalenet50(structure_path, ckpt=None, **kwargs):
    """Build a ScaleNet with stage depths ``[3, 4, 6, 3]``.

    Args:
        structure_path: Path to a JSON file holding the per-block structure
            that is handed to ``ScaleNet``.
        ckpt: Optional path to a saved ``state_dict``; loaded onto the CPU and
            applied to the freshly built model when given.
        **kwargs: Forwarded to ``ScaleNet`` (for example ``num_classes``).

    Returns:
        The constructed ``ScaleNet`` instance.
    """
    layer = [3, 4, 6, 3]
    # Context manager so the file handle is closed deterministically.
    with open(structure_path) as structure_file:
        structure = json.load(structure_file)
    model = ScaleNet(SABlock, layer, structure, **kwargs)

    # pretrained
    if ckpt is not None:
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state_dict)

    return model


In [18]:
def scalenet101(structure_path, ckpt=None, **kwargs):
    """Build a ScaleNet with stage depths ``[3, 4, 23, 3]``.

    Args:
        structure_path: Path to a JSON file holding the per-block structure
            that is handed to ``ScaleNet``.
        ckpt: Optional path to a saved ``state_dict``; loaded onto the CPU and
            applied to the freshly built model when given.
        **kwargs: Forwarded to ``ScaleNet`` (for example ``num_classes``).

    Returns:
        The constructed ``ScaleNet`` instance.
    """
    layer = [3, 4, 23, 3]
    # Context manager so the file handle is closed deterministically.
    with open(structure_path) as structure_file:
        structure = json.load(structure_file)
    model = ScaleNet(SABlock, layer, structure, **kwargs)

    # pretrained
    if ckpt is not None:
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state_dict)

    return model


In [19]:
def scalenet152(structure_path, ckpt=None, **kwargs):
    """Build a ScaleNet with stage depths ``[3, 8, 36, 3]``.

    Args:
        structure_path: Path to a JSON file holding the per-block structure
            that is handed to ``ScaleNet``.
        ckpt: Optional path to a saved ``state_dict``; loaded onto the CPU and
            applied to the freshly built model when given.
        **kwargs: Forwarded to ``ScaleNet`` (for example ``num_classes``).

    Returns:
        The constructed ``ScaleNet`` instance.
    """
    layer = [3, 8, 36, 3]
    # Context manager so the file handle is closed deterministically.
    with open(structure_path) as structure_file:
        structure = json.load(structure_file)
    model = ScaleNet(SABlock, layer, structure, **kwargs)

    # pretrained
    if ckpt is not None:
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state_dict)

    return model

In [20]:
def get_num_correct(preds, labels):
    """Return how many rows of ``preds`` have their arg-max equal to ``labels``.

    The predicted class of a row is the index of its largest value along
    dimension 1; the result is a plain Python number.
    """
    predicted_classes = torch.argmax(preds, dim=1)
    hits = predicted_classes == labels
    return hits.sum().item()

In [21]:
@torch.no_grad()
def get_all_preds(model, loader):
    """Run ``model`` over every batch of ``loader`` and stack the outputs.

    Args:
        model: Module called on each image batch.
        loader: Iterable yielding ``(images, labels)`` pairs; labels are unused.

    Returns:
        A CPU tensor with the per-batch outputs concatenated along dim 0, or
        an empty tensor when ``loader`` yields nothing.

    Note:
        The model's train/eval mode is left untouched; call ``model.eval()``
        beforehand if layers such as batch norm should run in inference mode.
    """
    # Feed each batch on the device the model lives on; otherwise a GPU model
    # would be handed CPU inputs. A parameter-free model is assumed to be on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Collect per-batch outputs and concatenate once at the end, rather than
    # re-copying the growing result on every iteration.
    batch_outputs = []
    for batch in loader:
        images, labels = batch
        preds = model(images.to(device))
        # Bring results back to the CPU so the return value is always a CPU tensor.
        batch_outputs.append(preds.cpu())

    if not batch_outputs:
        return torch.tensor([])
    return torch.cat(batch_outputs, dim=0)

In [22]:
# CIFAR-10 has 10 classes; without num_classes the classifier head would fall
# back to ScaleNet's default of 1000 outputs.
model = scalenet50(structure_path='../structures/scalenet50.json', num_classes=10)

In [None]:
# Put the network on the configured device before creating the optimizer.
model = model.to(DEVICE)

# SGD with momentum; the learning rate is cut by 10x when the monitored
# metric (mode='max') stops improving.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.1,
    verbose=True,
)

# The helper returns per-minibatch losses and per-epoch accuracy lists,
# which the plotting calls further down consume.
minibatch_loss_list, train_acc_list, valid_acc_list = train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    scheduler_on='valid_acc',
    num_epochs=NUM_EPOCHS,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    device=DEVICE,
    logging_interval=100,
)

# Minibatch training loss over the whole run.
plot_training_loss(
    minibatch_loss_list=minibatch_loss_list,
    num_epochs=NUM_EPOCHS,
    iter_per_epoch=len(train_loader),
    averaging_iterations=200,
    results_dir=None,
)
plt.show()

# Training vs. validation accuracy, with the y-axis clipped to 60-100.
plot_accuracy(
    train_acc_list=train_acc_list,
    valid_acc_list=valid_acc_list,
    results_dir=None,
)
plt.ylim(60, 100)
plt.show()

