In [None]:
# header files
import numpy as np
import torch
import torch.nn as nn
import torchvision

In [None]:
from google.colab import drive
# Mount Google Drive so datasets and checkpoints persist across Colab sessions.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# define transforms
# Per-channel normalization statistics shared by the train and val pipelines.
# NOTE(review): these std values are the widely copied (0.2023, 0.1994, 0.2010);
# the per-pixel std over the CIFAR-10 training set is closer to
# (0.2470, 0.2435, 0.2616) -- confirm which one is intended.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

_tv_transforms = torchvision.transforms

# Training: random 32x32 crop from a 4-pixel padded image + horizontal flip, then normalize.
train_transforms = _tv_transforms.Compose([
    _tv_transforms.RandomCrop(32, padding=4),
    _tv_transforms.RandomHorizontalFlip(),
    _tv_transforms.ToTensor(),
    _tv_transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Validation: no augmentation, only tensor conversion and normalization.
val_transforms = _tv_transforms.Compose([
    _tv_transforms.ToTensor(),
    _tv_transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

In [None]:
# dataset
# NOTE(review): train and val use different roots, so the CIFAR-10 archive is
# presumably downloaded into Drive twice; a shared root would avoid that.
train_dataset = torchvision.datasets.CIFAR10(
    root="/content/drive/My Drive/CIFAR10_train",
    train=True,
    transform=train_transforms,
    download=True,
)
val_dataset = torchvision.datasets.CIFAR10(
    root="/content/drive/My Drive/CIFAR10_val",
    train=False,
    transform=val_transforms,
    download=True,
)

In [None]:
# dataloader
import os

# Never request more worker processes than there are CPUs: Colab usually exposes
# only a couple, and DataLoader warns (and can slow down or freeze) when asked
# for far more workers than the machine suggests.
NUM_WORKERS = min(16, os.cpu_count() or 1)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# header files
import math
import torch
import torch.nn as nn


class ARMA2d(nn.Module):
    """
        2D-ARMA layer: a moving-average part (a regular nn.Conv2d) followed by an
        autoregressive part (AutoRegressive2d) applied to the conv output.

        Parameters prefixed with ``w_`` configure the moving-average convolution;
        parameters prefixed with ``a_`` configure the autoregressive layer, which
        operates on ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, w_kernel_size=3, w_padding_mode='zeros', w_padding=0, w_stride=1, w_dilation=1, w_groups=1, bias=False, a_kernel_size=3, a_padding_mode='circular', a_padding=0, a_stride=1, a_dilation=1):
        """
            Initialization of 2D-ARMA layer.
        """
        super(ARMA2d, self).__init__()
        # w_padding_mode used to be silently ignored; forward it to the convolution
        # (Conv2d's own default is 'zeros', so the default behavior is unchanged).
        self.moving_average = nn.Conv2d(in_channels, out_channels, w_kernel_size, padding=w_padding, padding_mode=w_padding_mode, stride=w_stride, dilation=w_dilation, groups=w_groups, bias=bias)
        self.autoregressive = AutoRegressive2d(out_channels, a_kernel_size, padding=a_padding, padding_mode=a_padding_mode, stride=a_stride, dilation=a_dilation)

    def forward(self, x):
        """
            Computation of 2D-ARMA layer: convolution, then autoregressive filtering.
        """
        x = self.moving_average(x)
        x = self.autoregressive(x)
        return x


class AutoRegressive2d(nn.Module):
    """
        2D-AutoRegressive layer: dispatches to a padding-mode-specific implementation
        stored in ``self.a``.

        Supported ``padding_mode`` values: "circular" and "reflect".
        NOTE(review): AutoRegressive_reflect is not defined in the visible cells --
        confirm it exists before using padding_mode="reflect".
    """

    def __init__(self, channels, kernel_size=3, padding=0, padding_mode='circular', stride=1, dilation=1):
        """
            Initialization of 2D-AutoRegressive layer.

            Raises:
                NotImplementedError: if padding_mode is not "circular" or "reflect".
        """
        super(AutoRegressive2d, self).__init__()

        if padding_mode == "circular":
            self.a = AutoRegressive_circular(channels, kernel_size, padding, stride, dilation)
        elif padding_mode == "reflect":
            self.a = AutoRegressive_reflect(channels, kernel_size, padding, stride, dilation)
        else:
            # Name the offending mode so configuration mistakes are easy to diagnose.
            raise NotImplementedError(
                "AutoRegressive2d: padding_mode {!r} is not supported "
                "(expected 'circular' or 'reflect')".format(padding_mode))

    def forward(self, x):
        """
            Computation of 2D-AutoRegressive layer.
        """
        x = self.a(x)
        return x

   
class AutoRegressive_circular(nn.Module):
    """
        2D-AutoRegressive layer with circular padding.

        Holds one learnable tensor ``alpha`` of shape [channels, kernel_size // 2, 4],
        initialized to zeros. ``padding``, ``stride`` and ``dilation`` are accepted
        for interface compatibility but are not used.
    """

    def __init__(self, channels, kernel_size=3, padding=0, stride=1, dilation=1):
        super().__init__()
        num_taps = kernel_size // 2
        self.alpha = nn.Parameter(torch.empty(channels, num_taps, 4))
        self.set_parameters()

    def set_parameters(self):
        """
            (Re)initialize the learnable coefficients to zero.
        """
        with torch.no_grad():
            self.alpha.zero_()

    def forward(self, x):
        """
            Apply circular autoregressive filtering with the current coefficients.
        """
        return autoregressive_circular(x, self.alpha)


def autoregressive_circular(x, alpha):
    """
        Computation of a 2D-AutoRegressive layer (with circular padding).

        Args:
            x: input of shape [M, T, I1, I2].
            alpha: raw coefficients of shape [T, P, 4]: P taps per channel, four
                coefficients per tap (two for the I1-axis filter, two for the I2-axis filter).

        Returns:
            Tensor shaped like x. x itself is returned unchanged when there are no
            taps (P == 0) or when I1 or I2 is smaller than 2 * P + 1.
    """
    num_taps = alpha.size(1)
    height, width = x.size(-2), x.size(-1)
    # P == 0 (kernel_size < 2) means there is nothing to apply; torch.chunk below
    # would also reject a chunk count of 0, so short-circuit to the identity.
    if num_taps == 0 or height < 2 * num_taps + 1 or width < 2 * num_taps + 1:
        return x

    # Squash into (-1/sqrt(2), 1/sqrt(2)); split into 4 chunks, each [T, P, 1].
    alpha = alpha.tanh() / math.sqrt(2)
    chunks = torch.chunk(alpha, alpha.size()[-1], -1)

    # Rotate each coefficient pair by -pi/4 to obtain the left/right taps of the
    # two 1-D filters; each result has size [T, P, 1].
    cos_t = math.cos(-math.pi / 4)
    sin_t = math.sin(-math.pi / 4)
    A_x_left = chunks[0] * cos_t - chunks[1] * sin_t
    A_x_right = chunks[0] * sin_t + chunks[1] * cos_t
    A_y_left = chunks[2] * cos_t - chunks[3] * sin_t
    A_y_right = chunks[2] * sin_t + chunks[3] * cos_t

    # Zero padding + circulant shift (centre tap moved to index 0):
    # [A_left 1 A_right] -> [1 A_right 0 0 ... 0 A_left]
    # size: [T, P, 3] -> [T, P, I1] (x filter) or [T, P, I2] (y filter)
    factory = dict(device=alpha.device, dtype=alpha.dtype)
    T, P = chunks[0].size(0), chunks[0].size(1)
    A_x = torch.cat((torch.ones(T, P, 1, **factory), A_x_right,
                     torch.zeros(T, P, height - 3, **factory), A_x_left), -1)
    A_y = torch.cat((torch.ones(T, P, 1, **factory), A_y_right,
                     torch.zeros(T, P, width - 3, **factory), A_y_left), -1)

    # Separable 2-D kernels: [T, P, I1] x [T, P, I2] -> [T, P, I1, I2]
    A = torch.einsum('tzi,tzj->tzij', A_x, A_y)

    # Apply the P kernels one after another; ar_circular divides by each
    # kernel's 2-D spectrum (Fourier-domain deconvolution).
    for kernel in A.unbind(1):
        x = ar_circular.apply(x, kernel)
    return x


def complex_division(x, A, trans_deno=False):
    """
        Element-wise complex division with [real, imag] stored as the two halves
        of the last dimension.

        trans_deno=False: (a + bj) / (c + dj) = [(ac + bd) + (bc - ad)j] / (c^2 + d^2)
        trans_deno=True:  (a + bj) / (c - dj) = [(ac - bd) + (bc + ad)j] / (c^2 + d^2)
            (i.e. divide by the complex conjugate of A)
    """
    re_num, im_num = torch.chunk(x, 2, -1)
    re_den, im_den = torch.chunk(A, 2, -1)
    modulus_sq = re_den * re_den + im_den * im_den

    if not trans_deno:
        real = (re_num * re_den + im_num * im_den) / modulus_sq
        imag = (im_num * re_den - re_num * im_den) / modulus_sq
    else:
        real = (re_num * re_den - im_num * im_den) / modulus_sq
        imag = (im_num * re_den + re_num * im_den) / modulus_sq
    return torch.cat((real, imag), -1)


def complex_multiplication(x, A, trans_deno=False):
    """
        Element-wise complex multiplication with [real, imag] stored as the two
        halves of the last dimension.

        trans_deno=False: (a + bj) * (c + dj) = (ac - bd) + (bc + ad)j
        trans_deno=True:  (a + bj) * (c - dj) = (ac + bd) + (bc - ad)j
            (i.e. multiply by the complex conjugate of A)
    """
    re_lhs, im_lhs = torch.chunk(x, 2, -1)
    re_rhs, im_rhs = torch.chunk(A, 2, -1)

    if not trans_deno:
        real = re_lhs * re_rhs - im_lhs * im_rhs
        imag = im_lhs * re_rhs + re_lhs * im_rhs
    else:
        real = re_lhs * re_rhs + im_lhs * im_rhs
        imag = im_lhs * re_rhs - re_lhs * im_rhs
    return torch.cat((real, imag), -1)

    
class ar_circular(torch.autograd.Function):
    """
        Circular 2-D deconvolution computed in the Fourier domain:
            y = IFFT2( FFT2(x) / FFT2(a) )
        with a hand-written backward.

        Shapes:
            x: [M, T, I1, I2], a: [T, I1, I2], y: [M, T, I1, I2]

        Implemented with the torch.fft module (PyTorch >= 1.8). The legacy
        torch.rfft / torch.irfft functions were removed in PyTorch 1.8.
    """

    @staticmethod
    def forward(ctx, x, a):
        signal_size = x.shape[-2:]
        X = torch.fft.fft2(x)  # complex [M, T, I1, I2]
        A = torch.fft.fft2(a)  # complex [T, I1, I2]
        Y = X / A              # broadcasts over the batch dimension M
        # x and a are real, so Y is Hermitian; irfft2 rebuilds the real signal
        # from the non-redundant half of the last dimension.
        y = torch.fft.irfft2(Y, s=signal_size)  # real [M, T, I1, I2]

        ctx.save_for_backward(A, Y)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        """
            In the Fourier domain (G = FFT2(grad_y), conj = complex conjugate):
                grad_x = IFFT2( G / conj(A) )
                grad_a = IFFT2( -sum_M( G / conj(A) * conj(Y) ) )
        """
        A, Y = ctx.saved_tensors
        signal_size = grad_y.shape[-2:]

        grad_Y = torch.fft.fft2(grad_y)
        intermediate = grad_Y / A.conj()  # [M, T, I1, I2]
        grad_x = torch.fft.irfft2(intermediate, s=signal_size)

        # a is shared across the batch, so its gradient sums over dimension M.
        grad_a = torch.fft.irfft2(-(intermediate * Y.conj()).sum(0), s=signal_size)  # [T, I1, I2]
        return grad_x, grad_a

In [None]:
# header files
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


# required functions
# Builder names a `from <module> import *` would export if this cell were a module.
# NOTE(review): the dicts below also list 'resnet152', but no resnet152 builder is
# defined in the visible cells.
__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101']

# Checkpoint locations for ARMA-pretrained weights, looked up by architecture name
# in _resnet when pretrained_with_arma=True.
# NOTE(review): every value is the placeholder '.pth' -- fill in real paths before
# enabling pretrained loading.
model_with_arma_files = {
    'resnet18': '.pth',
    'resnet34': '.pth',
    'resnet50': '.pth',
    'resnet101': '.pth',
    'resnet152': '.pth',
}

# Download URLs of the plain (non-ARMA) torchvision checkpoints, presumably the
# ImageNet-pretrained weights. NOTE(review): not referenced by any active code in
# the visible cells.
model_without_arma_files = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, arma=False):
    """
        Build a bias-free 3x3 layer whose padding equals its dilation.

        Args:
            in_planes: number of input channels
            out_planes: number of output channels
            stride: convolution stride
            groups: number of channel groups
            dilation: spacing between kernel taps (also used as padding)
            arma: if True return an ARMA2d layer, otherwise a plain nn.Conv2d
    """
    if not arma:
        return nn.Conv2d(
            in_planes, out_planes,
            kernel_size=3, stride=stride, padding=dilation,
            groups=groups, bias=False, dilation=dilation,
        )
    return ARMA2d(
        in_planes, out_planes,
        w_kernel_size=3, w_padding=dilation, w_stride=stride,
        w_groups=groups, w_dilation=dilation, bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1, arma=False):
    """
        Build a bias-free 1x1 (pointwise) layer without padding.

        Args:
            in_planes: number of input channels
            out_planes: number of output channels
            stride: convolution stride
            arma: if True return an ARMA2d layer, otherwise a plain nn.Conv2d
    """
    if not arma:
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return ARMA2d(in_planes, out_planes, w_kernel_size=1, w_stride=stride, w_padding=0, bias=False)


class BasicBlock(nn.Module):
    """
        Residual block with two 3x3 convolutions (ResNet-18/34).

        Args:
            inplanes: input channels; planes: output channels (expansion = 1).
            stride: stride of the first 3x3 convolution.
            downsample: optional module applied to the block input so it matches
                the residual branch's shape before the addition.
            groups, base_width, dilation: only the defaults (1, 64, 1) are supported.
            norm_layer: normalization constructor (defaults to nn.BatchNorm2d).
            arma: use ARMA2d layers instead of plain convolutions.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None, arma=False):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation>1 not supported in BasicBlock")

        # Construction order is kept stable (it fixes weight-init RNG order and state_dict keys).
        self.conv1 = conv3x3(inplanes, planes, stride, arma=arma)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, arma=arma)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)


class Bottleneck(nn.Module):
    """
        Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (ResNet-50/101).

        Args:
            inplanes: input channels.
            planes: bottleneck channels; the block outputs planes * expansion channels.
            stride: stride of the 3x3 convolution.
            downsample: optional module applied to the block input so it matches
                the residual branch's shape before the addition.
            groups, base_width: control the 3x3 width, planes * (base_width / 64) * groups.
            dilation: dilation of the 3x3 convolution.
            norm_layer: normalization constructor (defaults to nn.BatchNorm2d).
            arma: use ARMA2d layers instead of plain convolutions.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None, arma=False):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes*(base_width/64.))*groups
        # Derive the output width from `expansion` instead of repeating the literal 4,
        # so the class attribute stays the single source of truth.
        out_planes = planes * self.expansion

        self.conv1 = conv1x1(inplanes, width, arma=arma)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation, arma=arma)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes, arma=arma)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    """
        ResNet with a 3x3 stride-1 stem (no max-pool), four residual stages,
        a 4x4 adaptive average pool and a linear classifier.

        Args:
            block: BasicBlock or Bottleneck.
            layers: number of blocks in each of the four stages.
            num_classes: classifier outputs.
            zero_init_residual: zero the last BN weight of every residual branch.
            groups, width_per_group: forwarded to the blocks.
            replace_stride_with_dilation: three flags; stage i+2 uses dilation
                instead of stride 2 when flag i is True.
            norm_layer: normalization constructor (defaults to nn.BatchNorm2d).
            arma: use ARMA2d layers instead of plain convolutions.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, arma=False):
        super(ResNet, self).__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        norm_layer = self._norm_layer

        self.inplanes = 64
        self.dilation = 1
        dilate_flags = [False, False, False] if replace_stride_with_dilation is None else replace_stride_with_dilation
        if len(dilate_flags) != 3:
            raise ValueError("replace_stride_with_dilation should be None or a 3-element tuple, got {}".format(dilate_flags))
        self.groups = groups
        self.base_width = width_per_group

        # Submodules are created in a fixed order: it determines weight-init RNG
        # consumption and state_dict key order.
        if arma:
            stem = ARMA2d(3, self.inplanes, w_kernel_size=3, w_stride=1, w_padding=1, bias=False)
        else:
            stem = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = stem
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # (planes, stride, dilate) for layer1..layer4
        stage_specs = [
            (64, 1, False),
            (128, 2, dilate_flags[0]),
            (256, 2, dilate_flags[1]),
            (512, 2, dilate_flags[2]),
        ]
        for stage_idx, (planes, stride, dilate) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[stage_idx], stride=stride, dilate=dilate, arma=arma)
            setattr(self, "layer{}".format(stage_idx + 1), stage)

        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
        # The pooled map is 4x4, hence 512 * 16 * expansion input features.
        self.fc = nn.Linear(512*16*block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN in each residual branch so every block starts
        # as an identity mapping (https://arxiv.org/abs/1706.02677).
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, arma=False):
        """
            Build one stage of `blocks` residual blocks; only the first block may
            change resolution/width (and gets a 1x1 downsample shortcut if needed).
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride, arma=arma),
                norm_layer(out_planes),
            )

        stage_blocks = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, arma=arma)]
        self.inplanes = out_planes
        stage_blocks.extend(
            block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, arma=arma)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


def _resnet(arch, block, layers, arma=False, pretrained_with_arma=False, **kwargs):
    """
        Build a ResNet and optionally load ARMA-pretrained weights.

        Args:
            arch: architecture key into model_with_arma_files (e.g. 'resnet18').
            block: BasicBlock or Bottleneck.
            layers: blocks per stage.
            arma: use ARMA2d layers.
            pretrained_with_arma: load the checkpoint for `arch`; only honoured
                when `arma` is also True.
            **kwargs: forwarded to ResNet.
    """
    model = ResNet(block, layers, arma=arma, **kwargs)

    if pretrained_with_arma and arma:
        # load_state_dict expects a mapping of tensors, not a file path, so the
        # checkpoint must be deserialized first. torch.load unpickles the file:
        # only load checkpoints from trusted sources.
        state_dict = torch.load(model_with_arma_files[arch], map_location="cpu")
        model.load_state_dict(state_dict)

    # NOTE(review): the plain checkpoints in model_without_arma_files are not used;
    # their classifier expects a different input size than this variant's
    # fc (512 * 16 * expansion features), so they would not load as-is.
    return model


def resnet18(arma=False, pretrained_with_arma=False, **kwargs):
    r"""ResNet-18 model (BasicBlock, [2, 2, 2, 2]) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    with a 3x3 stride-1 stem.
    Args:
        arma (bool): If True, use ARMA2d layers instead of plain convolutions
        pretrained_with_arma (bool): If True (and ``arma`` is True), load the
            checkpoint listed under ``'resnet18'`` in ``model_with_arma_files``
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``zero_init_residual``)
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], arma=arma, pretrained_with_arma=pretrained_with_arma, **kwargs)


def resnet34(arma=False, pretrained_with_arma=False, **kwargs):
    r"""ResNet-34 model (BasicBlock, [3, 4, 6, 3]) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    with a 3x3 stride-1 stem.
    Args:
        arma (bool): If True, use ARMA2d layers instead of plain convolutions
        pretrained_with_arma (bool): If True (and ``arma`` is True), load the
            checkpoint listed under ``'resnet34'`` in ``model_with_arma_files``
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``zero_init_residual``)
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], arma=arma, pretrained_with_arma=pretrained_with_arma, **kwargs)


def resnet50(arma=False, pretrained_with_arma=False, **kwargs):
    r"""ResNet-50 model (Bottleneck, [3, 4, 6, 3]) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    with a 3x3 stride-1 stem.
    Args:
        arma (bool): If True, use ARMA2d layers instead of plain convolutions
        pretrained_with_arma (bool): If True (and ``arma`` is True), load the
            checkpoint listed under ``'resnet50'`` in ``model_with_arma_files``
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``zero_init_residual``)
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], arma=arma, pretrained_with_arma=pretrained_with_arma, **kwargs)


def resnet101(arma=False, pretrained_with_arma=False, **kwargs):
    r"""ResNet-101 model (Bottleneck, [3, 4, 23, 3]) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    with a 3x3 stride-1 stem.
    Args:
        arma (bool): If True, use ARMA2d layers instead of plain convolutions
        pretrained_with_arma (bool): If True (and ``arma`` is True), load the
            checkpoint listed under ``'resnet101'`` in ``model_with_arma_files``
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``zero_init_residual``)
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], arma=arma, pretrained_with_arma=pretrained_with_arma, **kwargs)

In [None]:
# model: plain (non-ARMA) ResNet-50, moved to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet50(arma=False).to(device)  # Module.to moves parameters in place and returns the module
print(model)

In [None]:
# Cross-Entropy loss with Label Smoothing
class CrossEntropyLabelSmoothingLoss(nn.Module):
    """
        Cross-entropy with label smoothing.

        The target class gets weight (1 - smoothing) and each of the other C - 1
        classes gets smoothing / (C - 1). The loss is the weighted negative
        log-probability summed over classes and averaged over the batch; with
        smoothing=0.0 it equals the standard cross-entropy.

        Args:
            smoothing: probability mass spread over the non-target classes.
    """

    def __init__(self, smoothing=0.0):
        super(CrossEntropyLabelSmoothingLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """
            Args:
                pred: logits with classes on the last dimension, e.g. [N, C].
                target: integer class indices shaped like pred without its last
                    dimension, e.g. [N].
            Returns:
                Scalar loss tensor.
        """
        log_prob = torch.nn.functional.log_softmax(pred, dim=-1)
        num_classes = pred.size(-1)
        # With a single class there is no off-target class (and C - 1 would be 0).
        off_value = self.smoothing / (num_classes - 1.) if num_classes > 1 else 0.0
        # Allocate from `pred` so dtype and device always match the logits.
        weight = torch.full_like(pred, off_value)
        weight.scatter_(-1, target.unsqueeze(-1), 1. - self.smoothing)
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss

In [None]:
# define loss: smoothing=0.0 makes this equivalent to plain cross-entropy.
LABEL_SMOOTHING = 0.0
criterion = CrossEntropyLabelSmoothingLoss(LABEL_SMOOTHING)

In [None]:
# optimizer: SGD with momentum and L2 weight decay; the LR is divided by 10 every 150 epochs.
BASE_LR = 0.05
MOMENTUM = 0.95
WEIGHT_DECAY = 5e-4
LR_STEP_EPOCHS = 150
LR_DECAY = 0.1

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=BASE_LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_EPOCHS, gamma=LR_DECAY)

In [None]:
NUM_EPOCHS = 500
MIN_EPOCH_FOR_CHECKPOINT = 10
BEST_MODEL_PATH = "/content/drive/My Drive/best_model_resnet50_cifar10.pth"

train_losses = []
train_acc = []
val_losses = []
val_acc = []
best_metric = -1
best_metric_epoch = -1

# train and validate
for epoch in range(0, NUM_EPOCHS):

    # train
    model.train()
    training_loss = 0.0
    total = 0
    correct = 0
    # NOTE(review): `input` shadows the Python builtin. It is kept as a notebook-level
    # global because other cells may read it; TODO rename once nothing depends on it.
    for i, (input, target) in enumerate(train_loader):
        input = input.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        training_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    training_loss = training_loss / float(len(train_loader))
    # Keep accuracies numeric: storing them as str made plt.plot treat them as
    # categorical labels and required float() round-trips for comparisons.
    training_accuracy = 100.0 * (float(correct) / float(total))
    train_losses.append(training_loss)
    train_acc.append(training_accuracy)

    # validate
    model.eval()
    valid_loss = 0.0
    total = 0
    correct = 0
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input = input.to(device)
            target = target.to(device)

            output = model(input)
            loss = criterion(output, target)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            valid_loss += loss.item()
    valid_loss = valid_loss / float(len(val_loader))
    valid_accuracy = 100.0 * (float(correct) / float(total))
    val_losses.append(valid_loss)
    val_acc.append(valid_accuracy)
    scheduler.step()

    # store best model (only considered from epoch MIN_EPOCH_FOR_CHECKPOINT onwards)
    if valid_accuracy > best_metric and epoch >= MIN_EPOCH_FOR_CHECKPOINT:
        best_metric = valid_accuracy
        best_metric_epoch = epoch
        torch.save(model.state_dict(), BEST_MODEL_PATH)

    print()
    print("Epoch" + str(epoch) + ":")
    print("Training Accuracy: " + str(training_accuracy) + "    Validation Accuracy: " + str(valid_accuracy))
    print("Training Loss: " + str(training_loss) + "    Validation Loss: " + str(valid_loss))
    print("Best Validation Accuracy: " + str(best_metric))
    print()

In [None]:
import matplotlib.pyplot as plt

# x-axis for the history plots: one point per recorded epoch. Deriving the length
# from train_acc keeps x and y aligned even if training stopped before 500 epochs.
e = list(range(len(train_acc)))

In [None]:
fig, ax = plt.subplots()
# float() makes the curve numeric even if accuracies were recorded as strings.
ax.plot(e, [float(acc) for acc in train_acc])
ax.set(xlabel="Epoch", ylabel="Training accuracy (%)", title="Training accuracy per epoch");

In [None]:
fig, ax = plt.subplots()
# float() makes the curve numeric even if accuracies were recorded as strings.
ax.plot(e, [float(acc) for acc in val_acc])
ax.set(xlabel="Epoch", ylabel="Validation accuracy (%)", title="Validation accuracy per epoch");