# Local self-attention (practice)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn.modules.linear import Linear
import torch.nn as nn
from torchsummary import summary
import math

### unfold

In [None]:
# Sliding-window demo: Tensor.unfold extracts every length-K window along one
# dimension and stores the window contents in a new trailing dimension.
B, C, H, W = 2, 4, 5, 6
K = 3

x = torch.randn((B, C, H, W))
# Windows of K consecutive rows (dim 2 = height) taken with step 1,
# giving a tensor of shape (B, C, H - K + 1, W, K).
out = x.unfold(dimension=2, size=K, step=1)
print(out.size())

## LocalSelfAttention module

In [None]:
class LocalSelfAttention(nn.Module):
    """Local (windowed) self-attention over a 2-D feature map.

    Every output pixel attends to the ``kernel_size x kernel_size``
    neighbourhood centred on it. Queries, keys and values are per-pixel
    linear projections (1x1 convolutions); keys additionally receive a
    learnable relative positional encoding, with half of the channels
    encoding the row offset and the other half the column offset inside the
    window. Attention weights are computed independently for each channel.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of output channels. Must be even because the
            channels are split between the two positional encodings.
        kernel_size: Side length ``K`` of the square attention window.
        padding: Zero padding added to each spatial border before the
            windows are extracted. ``2 * padding == kernel_size - 1`` keeps
            the spatial size unchanged, which the layer requires.
        bias: Whether the three projection layers use a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False):
        super(LocalSelfAttention, self).__init__()
        assert out_channels % 2 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding

        # Projection layers: a 1x1 convolution is a linear layer applied at
        # every spatial position (i.e. a linear layer with a sliding window).
        self.query_proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.key_proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.value_proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

        # Learnable relative positional encodings, drawn from N(0, 1).
        # rel_h varies along the window's row axis, rel_w along its column
        # axis; the singleton dimensions broadcast over H, W and the other
        # window axis.
        half = out_channels // 2
        self.rel_h = nn.Parameter(torch.randn(half, 1, 1, kernel_size, 1), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn(half, 1, 1, 1, kernel_size), requires_grad=True)

    def forward(self, x):
        """Apply local self-attention.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)``.

        Raises:
            ValueError: If ``kernel_size`` and ``padding`` do not yield
                exactly one window per input pixel.
        """
        batch, channels, height, width = x.size()
        k = self.kernel_size
        pad = self.padding

        # Only keys and values need the border: every query pixel must find
        # a full K x K neighbourhood around itself.
        padded = F.pad(x, [pad, pad, pad, pad])

        q_out = self.query_proj(x)       # (B, C, H, W)
        k_out = self.key_proj(padded)    # (B, C, H + 2*pad, W + 2*pad)
        v_out = self.value_proj(padded)  # (B, C, H + 2*pad, W + 2*pad)

        # Extract the K x K window around every pixel.
        # After both unfolds: (B, C, H', W', K, K); k_out[:, :, i, j] is the
        # window whose top-left corner sits at (i, j) of the padded map.
        k_out = k_out.unfold(2, k, 1).unfold(3, k, 1)
        v_out = v_out.unfold(2, k, 1).unfold(3, k, 1)

        if tuple(k_out.shape[2:4]) != (height, width):
            raise ValueError(
                "kernel_size=%d with padding=%d yields %dx%d windows for a "
                "%dx%d input; use 2 * padding == kernel_size - 1"
                % (k, pad, k_out.shape[2], k_out.shape[3], height, width)
            )

        # Add the positional encodings: the first half of the key channels
        # receives the row encoding, the second half the column encoding.
        k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1)
        k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1)

        # Flatten each window: (B, C, H, W, K*K). reshape copies as needed
        # because unfold returns non-contiguous views.
        k_out = k_out.reshape(batch, self.out_channels, height, width, k * k)
        v_out = v_out.reshape(batch, self.out_channels, height, width, k * k)
        q_out = q_out.unsqueeze(-1)      # (B, C, H, W, 1)

        # Per-channel attention logits, scaled to keep the softmax from
        # saturating as the channel count grows.
        scores = (q_out * k_out) / math.sqrt(self.out_channels)
        attn = F.softmax(scores, dim=-1)  # (B, C, H, W, K*K)

        # Attention-weighted sum of the values inside each window.
        out = (attn * v_out).sum(dim=-1)  # (B, C, H, W)
        return out

In [None]:
# Smoke test: push one random batch through a 3x3 local-attention layer that
# maps C input channels to C * C output channels.
B, C, H, W = 2, 4, 5, 6
layer = LocalSelfAttention(in_channels=C, out_channels=C * C, kernel_size=3, padding=1)
x = torch.randn(B, C, H, W)
out = layer(x)

## Simple Nets

### Block

In [None]:
class SimpleBlock(nn.Module):
    """Residual block: local self-attention -> skip connection -> ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (must be even, as required
            by ``LocalSelfAttention``).
        stride: Spatial down-sampling factor. Values above 1 are realised by
            average pooling the block output.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(SimpleBlock, self).__init__()
        self.stride = stride
        # 3x3 window with padding 1 keeps the spatial size unchanged.
        self.attn = LocalSelfAttention(in_channels, out_channels, kernel_size=3, padding=1)
        # The skip path must match the attention output's channel count;
        # project with a 1x1 convolution only when the counts differ.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``relu(attn(x) + shortcut(x))``, pooled when ``stride > 1``."""
        out = self.attn(x) + self.shortcut(x)
        out = F.relu(out)
        if self.stride > 1:
            # Down-sample by the stride factor (kernel size == stride).
            out = F.avg_pool2d(out, self.stride)
        return out

### Simple Nets

In [None]:
class SimpleNet(nn.Module):
    """Small image classifier: conv stem -> four SimpleBlocks -> linear head.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super(SimpleNet, self).__init__()
        # The stem honours in_channels instead of assuming RGB input.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        # Four stride-2 blocks; each is expected to halve H and W, giving
        # 1/16 of the input resolution overall.
        self.net = nn.Sequential(
            SimpleBlock(16, 16, stride=2),
            SimpleBlock(16, 16, stride=2),
            SimpleBlock(16, 16, stride=2),
            SimpleBlock(16, 32, stride=2)
        )
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, x):
        """Map images ``(B, in_channels, H, W)`` to logits ``(B, num_classes)``."""
        out = self.stem(x)
        out = self.net(out)  # e.g. (B, 32, 14, 14) for 224x224 input
        # Global average pooling. For a 14x14 map this equals
        # avg_pool2d(out, 14) but also works for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)  # (B, 32)
        out = self.classifier(out)

        return out

### Load Train, Test Dataset and Define Train, Test Loader

In [None]:
# CIFAR-10: download the data, define the preprocessing and build the loaders
import torchvision
import torchvision.transforms as transforms

# PIL image -> tensor, resize to 224, then normalise every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: shuffled mini-batches of 4 images.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2)

# Test split: fixed order, mini-batches of 2 images.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=2, shuffle=False, num_workers=2)

# Report the number of samples in each split.
print(len(trainset))
print(len(testset))

# Human-readable names, indexed by the integer class label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

### For Visualization

In [None]:
# display some images

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    """Show a channel-first image tensor that was normalised with mean 0.5 / std 0.5."""
    restored = img / 2 + 0.5  # undo the normalisation
    # matplotlib expects channel-last (H, W, C) arrays.
    channel_last = np.transpose(restored.numpy(), (1, 2, 0))
    plt.imshow(channel_last)
    plt.show()
    
    
# get some random training images
dataiter = iter(trainloader)
# Use the built-in next(): DataLoader iterators follow the standard iterator
# protocol, and the legacy `.next()` method is absent in recent PyTorch.
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))

# print one label per image in the batch, whatever the batch size is
print(' '.join('%5s' % classes[label] for label in labels))

### Define model

In [None]:
# Prefer the second GPU as before, but fall back to the first GPU or the CPU
# so the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Sanity check: a random batch of two 224x224 RGB images should yield
# one row of 10 logits per image.
x = torch.randn(2,3,224,224).to(device)
model = SimpleNet(3, 10).to(device)
output = model(x)
print(output.shape)

### Define Loss and optimizer

In [None]:
import torch.optim as optim

# Multi-class classification objective computed on the raw logits.
criterion = nn.CrossEntropyLoss()
# Adam over every model parameter with a learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

### Training

In [None]:
### Train the network
print('Start Training ')
for epoch in range(2):
    running_loss = 0.0
    # Number of batches accumulated in running_loss since the last report,
    # so the printed value is a true mean regardless of the logging interval.
    batches_since_report = 0
    for i, data in enumerate(trainloader):
        # get the inputs and move them to the model's device
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_since_report += 1

        # print statistics every 50 iterations (including the first one)
        if i % 50 == 0:
            print('[%d, %5d] loss: %.6f' %
                 (epoch + 1, i + 1, running_loss / batches_since_report))
            running_loss = 0.0
            batches_since_report = 0

print('Finished Training')

In [None]:
dataiter = iter(testloader)
# Use the built-in next(): DataLoader iterators follow the standard iterator
# protocol, and the legacy `.next()` method is absent in recent PyTorch.
images, labels = next(dataiter)

# show the images and one ground-truth label per image in the batch
imshow(torchvision.utils.make_grid(images))
print('GrondTruth: ', ' '.join('%5s' % classes[label] for label in labels))

In [None]:
images = images.to(device)

# Switch to inference behaviour: BatchNorm then uses its running statistics
# instead of the statistics of this tiny batch. The model stays in eval mode
# afterwards; call model.train() before any further training.
model.eval()
# No gradients are needed for prediction, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(images)

# Index of the largest logit per row is the predicted class.
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[label] for label in predicted))