In [None]:
#imports
import numpy as np
import torch
import torch.nn as nn
import pickle
import torchvision
from typing import Union, List, Dict, Any, cast
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import Caltech256
from torchvision.models import VGG

In [None]:
# Model class definitions and builders for a group-norm variant of VGG
class VGGNormModel(torchvision.models.VGG):
    """Placeholder subclass of torchvision's ``VGG``.

    It currently adds no behaviour of its own. The builders in this notebook
    (``make_vgg`` / ``vgg11_gn``) return plain ``VGG`` instances rather than
    this class.
    """

# Builds the convolutional feature layers of a VGG network from a layer configuration
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False, norm_layer = None, in_channels: int = 3) -> nn.Sequential:
    """Build the convolutional feature extractor of a VGG network.

    Args:
        cfg: layer specification, read left to right. An int ``v`` adds a 3x3
            convolution (padding 1) with ``v`` output channels followed by a
            ReLU. The string ``'M'`` adds a 2x2 max-pool with stride 2.
        batch_norm: if True, insert a normalization layer between each
            convolution and its ReLU. Despite the name, the layer type is
            decided by ``norm_layer``.
        norm_layer: factory called as ``norm_layer(v // 2, v)``. It must return
            a zero-argument callable that builds the norm module, e.g.
            ``get_group_norm_layer``. Required when ``batch_norm`` is True.
        in_channels: number of channels of the input images (3 for RGB).

    Returns:
        ``nn.Sequential`` containing the layers in order.

    Raises:
        ValueError: if ``batch_norm`` is True but ``norm_layer`` is None.
    """
    if batch_norm and norm_layer is None:
        # Previously `raise Error(...)`. `Error` is undefined, so this surfaced
        # as a confusing NameError instead of the intended message.
        raise ValueError("Please specify a norm layer")

    layers: List[nn.Module] = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        out_channels = cast(int, v)
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        if batch_norm:
            # With get_group_norm_layer this is GroupNorm(v // 2, v), i.e. two
            # channels per group for even v.
            layers.append(norm_layer(out_channels // 2, out_channels)())
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes this one's output channels.
        in_channels = out_channels
    return nn.Sequential(*layers)

def make_vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, norm_layer=None, num_classes = None, **kwargs: Any) -> VGG:
    """Build a torchvision ``VGG`` whose features come from ``make_layers``.

    Args:
        arch: architecture name (e.g. ``'vgg11_bn'``). Pretrained loading is
            not implemented, so it is only used in the error message.
        cfg: key into the layer-configuration table below ('A', 'B', 'D' or
            'E'); 'A' is the VGG11 layout.
        batch_norm: forwarded to ``make_layers``; enables the norm layers.
        pretrained: must be False; pretrained weights are not implemented.
        progress: unused; kept for signature compatibility with torchvision's
            builders.
        norm_layer: norm-layer factory forwarded to ``make_layers``.
        num_classes: number of classifier outputs. When None, ``VGG``'s own
            default is used.
        **kwargs: extra keyword arguments forwarded to ``VGG``.

    Raises:
        NotImplementedError: if ``pretrained`` is True.
        KeyError: if ``cfg`` is not a known configuration.
    """
    if pretrained:
        # Fail before doing the (expensive) model construction. There is no
        # weight-URL table or loader in this notebook.
        raise NotImplementedError(f"pretrained weights are not available for {arch!r}")

    cfgs: Dict[str, List[Union[str, int]]] = {
        'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }
    if num_classes is not None:
        # Forwarding None would break VGG's final Linear layer; omit it instead.
        kwargs['num_classes'] = num_classes
    features = make_layers(cfgs[cfg], batch_norm=batch_norm, norm_layer=norm_layer)
    return VGG(features, **kwargs)

def vgg11_gn(pretrained: bool = False, progress: bool = True, norm_layer = None, num_classes = None, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") with a norm layer after every convolution.

    Intended as the group-norm variant of VGG11. Pass ``get_group_norm_layer``
    as ``norm_layer`` to get ``nn.GroupNorm`` layers.
    Architecture from `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): must be False. Loading pretrained weights is not
            implemented, and ``make_vgg`` raises ``NotImplementedError``.
        progress (bool): forwarded to ``make_vgg``; only meaningful for
            downloading pretrained weights.
        norm_layer: factory for the normalization layers, called by
            ``make_layers`` as ``norm_layer(v // 2, v)()``. Required.
        num_classes (int): number of output classes of the classifier (match
            the dataset). Required.
        **kwargs: extra keyword arguments forwarded through ``make_vgg`` to
            ``torchvision.models.VGG``.

    Raises:
        AssertionError: if ``num_classes`` is None.
    """
    assert num_classes is not None, "give a number of class in accordance to dataset"
    # batch_norm=True switches the norm layers on; their type comes from norm_layer.
    return make_vgg('vgg11_bn', 'A', True, pretrained, progress, norm_layer = norm_layer, num_classes = num_classes, **kwargs)



In [None]:
# GroupNorm helpers: a norm-layer factory for make_layers, plus a quick nn.GroupNorm experiment
def get_group_norm_layer(in_channel, out_channel):
    """Return a zero-argument factory that builds ``nn.GroupNorm(in_channel, out_channel)``.

    Despite the parameter names, ``in_channel`` becomes GroupNorm's
    ``num_groups`` and ``out_channel`` becomes its ``num_channels``. This
    matches how ``make_layers`` calls it: ``norm_layer(v // 2, v)()``.
    Every call of the returned factory creates a fresh module.
    """
    return lambda: nn.GroupNorm(in_channel, out_channel)


# Sanity-check nn.GroupNorm(num_groups, num_channels) on random data.
# (Renamed from `input`, which shadowed the Python builtin.)
sample_input = torch.randn(20, 6, 10, 10)
# Separate 6 channels into 3 groups.
# Alternatives: nn.GroupNorm(6, 6) puts each channel in its own group
# (equivalent to InstanceNorm); nn.GroupNorm(1, 6) puts all channels in one
# group (equivalent to LayerNorm).
group_norm = nn.GroupNorm(3, 6)
output = group_norm(sample_input)
# GroupNorm keeps the shape. With its default affine init (weight=1, bias=0),
# every (sample, group) slice of the output has ~zero mean.
assert output.shape == sample_input.shape
assert output.view(20, 3, -1).mean(dim=-1).abs().max() < 1e-4
print(output.shape)



In [None]:
# dataset loading code
class GreyscaleToRGBTransform(object):
    """Expand a single-channel image tensor to three channels.

    Used after ``transforms.ToTensor()`` in the dataset pipeline, so that
    greyscale images match the 3-channel input the model and ``Normalize``
    expect. Assumes a channel-first ``(C, H, W)`` tensor (TODO confirm for
    other callers). Tensors with more than one channel are returned
    unchanged (same object).
    """

    def __call__(self, image):
        if image.shape[0] == 1:
            # Copy the single channel into 3 identical channels. This used to
            # build a transforms.Lambda on every call just to run this line.
            return image.repeat(3, 1, 1)
        return image
    
def get_torchvision_dataset(batch_size, root: str = "../data", train_fraction: float = 0.8, seed: int = 0):
    """Load Caltech256 and return ``(train_loader, val_loader)``.

    Images are resized to 224x224 and converted to tensors. Single-channel
    images are expanded to 3 channels, then everything is normalized with the
    ImageNet mean/std.

    Args:
        batch_size: batch size for both loaders.
        root: directory that already contains the Caltech256 data
            (``download=False`` because the download link is broken).
        train_fraction: fraction of images assigned to the training split.
            The default 0.8 reproduces the previous hard-coded 24487/6122
            split for 30609 images.
        seed: seed for the random train/val split, so it is reproducible.

    Returns:
        ``(train_loader, val_loader)`` DataLoaders. The training loader
        reshuffles every epoch; the validation loader does not.
    """
    # One transform for both splits. Training uses no augmentation, and the
    # Subsets produced by random_split share the parent dataset's transform
    # anyway. The old, unused val_transform also lacked the greyscale fix.
    transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         # Some images may be single-channel; the model expects 3 channels.
         GreyscaleToRGBTransform(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])])

    # download link is broken, so the data must already be present under `root`
    dataset = Caltech256(root=root, download=False, transform=transform)

    # Derive split sizes from the actual dataset length. Hard-coded sizes make
    # random_split fail whenever the length differs.
    n_train = int(train_fraction * len(dataset))
    n_val = len(dataset) - n_train
    generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = torch.utils.data.random_split(
        dataset, [n_train, n_val], generator=generator)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    # Bug fix: this previously wrapped the *training* loader a second time.
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


In [None]:
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        args: dict with ``"log_interval"`` (log every N batches) and
            ``"dry_run"`` (stop after the first logged batch).
        model: ``nn.Module`` to train. It is switched to train mode and moved
            to ``device``.
        device: ``torch.device`` or device string (e.g. ``"cuda"``).
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the log line.

    Returns:
        The trained model: the same module object, now on ``device``.
    """
    model = model.train().to(device)
    # The loss module is stateless, so build it once instead of per batch.
    criterion = nn.CrossEntropyLoss()
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        target = torch.as_tensor(target)  # caltech256 target is int
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        if batch_idx % args["log_interval"] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), n_samples,
                100. * batch_idx / n_batches, loss.item()))
            if args["dry_run"]:
                break
    return model

In [None]:
def run_tests(args):
    """Train every candidate model on every dataset for ``args["epoch"]`` epochs.

    Expects ``args`` to provide ``"batch_size"``, ``"device"``, ``"lr"`` and
    ``"epoch"`` (number of epochs), plus the keys ``train`` reads
    (``"log_interval"``, ``"dry_run"``).
    """
    device = torch.device(args["device"])
    # (num_classes, (train_loader, val_loader)) pairs; 257 classes for Caltech256 as configured here.
    experiments = [(257, get_torchvision_dataset(args["batch_size"]))]
    for num_classes, (train_loader, val_loader) in experiments:
        models = [vgg11_gn(norm_layer=get_group_norm_layer, num_classes=num_classes)]
        for model in models:
            # TODO: reload model parameters
            # Move to the device before building the optimizer, as torch.optim recommends.
            model = model.to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])
            # Bug fix: train() used to run only once, with args["epoch"] used only as the log label.
            for epoch in range(1, args["epoch"] + 1):
                train(args, model, device, train_loader, optimizer, epoch)
            # TODO: val_loader is not evaluated yet.
            


In [None]:
def main():
    """Set the experiment hyper-parameters and run the experiments.

    These would normally come from the command line; in a notebook they are
    set here in a plain dict.
    """
    # Reproducible weight init and any random splitting/shuffling that uses
    # torch's global RNG.
    torch.manual_seed(0)
    args = {
        # Fall back to CPU so the notebook still runs without a GPU.
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "lr": 1e-6,            # learning rate
        "epoch": 10,           # epoch count handed to run_tests
        "batch_size": 2,
        "log_interval": 100,   # log every N batches
        "dry_run": False,      # if True, train stops after the first logged batch
    }
    run_tests(args)

In [None]:
# Run the whole experiment: load Caltech256, build the VGG11 group-norm model and train it.
main()