## Week 3 Assignment: Implement a Quadratic Layer

In this week's programming exercise, you will build a custom quadratic layer which computes y = ax2 + bx + c. Similar to the ungraded lab, this layer will be plugged into a model that will be trained on the MNIST dataset. Let's get started!

## Imports

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)

device:  cuda


## Define the quadratic layer (TODO)

Implement a simple quadratic layer. It has 3 state variables: a, b and c. The computation returned is ax^2 + bx +c. Make sure it can also accept an activation function.

In [3]:
# ax**2 + bx + c
class SimpleQuadratic(nn.Module):
    def __init__(self, in_features, out_features, bias=True, activation=None, device=None, dtype=None):
        super(SimpleQuadratic, self).__init__()
        
        factory_kwargs = {'device': device, 'dtype': dtype}
        
        # Input
        self.in_features = in_features
        
        # Bias
        self.bias = bias
        
        self.apply_activation = False
        
        # Activation
        if activation is not None:
            self.apply_activation = True
            self.activation = getattr(nn.functional, activation)
        
        # Weight
        self.weight_a = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        self.weight_b = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        self.weight_c = Parameter(torch.zeros((out_features, in_features), **factory_kwargs))
        
        # Bias
        if bias:
            self.bias_a = Parameter(torch.empty(out_features, **factory_kwargs))
            self.bias_b = Parameter(torch.empty(out_features, **factory_kwargs))
            self.bias_c = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        
        # Weight and Biase initialization
        self._reset_parameters()
    
    def forward(self, input):
        x, y = input.shape
        if y != self.in_features:
            print(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
            return 0
        
        # output = input.matmul(self.weight.t())
        output = torch.matmul(torch.square(input), self.weight_a.t()) + torch.matmul(input, self.weight_b.t())
        
        if self.bias is not None:
            output += self.bias_a
            output += self.bias_b
            output += self.bias_c
        ret = output
        
        if self.apply_activation:
            return self.activation(ret)
        
        return ret
    
    def _reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight_a, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.weight_b, a=math.sqrt(5))
        
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight_a)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias_a, -bound, bound)
        
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight_b)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias_b, -bound, bound)
        
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight_c)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias_c, -bound, bound)

## Prepare the Dataset

In [4]:
# Image Transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

In [5]:
# Load Dataset
train_data = MNIST(root='./', train=True, download=True, transform=transform)
test_data = MNIST(root='./', train=False, download=True, transform=transform)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
# DataLoader
train_loader = DataLoader(dataset=train_data,
                          batch_size=64,
                          shuffle=True,
                          num_workers=2,
                          pin_memory=True)
val_loader = DataLoader(dataset=test_data,
                        batch_size=64,
                        shuffle=True,
                        num_workers=2,
                        pin_memory=True)

## Train the Model

In [7]:
# Build the Model
model = nn.Sequential(
    SimpleQuadratic(in_features=784, out_features=128, activation="relu"),
    nn.Dropout(0.2),
    # SimpleQuadratic(in_features=128, out_features=10),
    nn.Linear(in_features=128, out_features=10),
    nn.LogSoftmax(dim=1),
)
model.to(device)

Sequential(
  (0): SimpleQuadratic()
  (1): Dropout(p=0.2, inplace=False)
  (2): Linear(in_features=128, out_features=10, bias=True)
  (3): LogSoftmax(dim=1)
)

In [8]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

In [9]:
# Train the Model
EPOCHS = 5

model.train()

for epoch in range(EPOCHS):
    running_loss = 0
    correct = 0
    
    for data in train_loader:
        images, labels = data
        images = images.view(images.shape[0], -1)
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(labels.data.view_as(pred)).cpu().sum()
        
        running_loss += loss.item()
    
    print(f"Epoch: {epoch}, loss: {running_loss/len(train_loader)}, accuracy: {correct/len(train_loader.dataset)}")


# Evaluate Trained Model
running_loss = 0
correct = 0
    
model.eval()
for data in val_loader:
    images, labels = data
    images = images.view(images.shape[0], -1)
    images, labels = images.to(device), labels.to(device)

    output = model(images)
    loss = criterion(output, labels)

    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(labels.data.view_as(pred)).cpu().sum()

    running_loss += loss.item()

print(f"\nValidation - loss: {running_loss/len(val_loader)}, accuracy: {correct/len(val_loader.dataset)}")

Epoch: 0, loss: 0.4934732016882917, accuracy: 0.8494166731834412
Epoch: 1, loss: 0.2807143531612623, accuracy: 0.9145833253860474
Epoch: 2, loss: 0.23896014608585758, accuracy: 0.9272500276565552
Epoch: 3, loss: 0.21361781856906947, accuracy: 0.9348499774932861
Epoch: 4, loss: 0.19977718043619636, accuracy: 0.9381833076477051

Validation - loss: 0.1441410919710709, accuracy: 0.9562000036239624
