<a href="https://colab.research.google.com/github/achanhon/AdversarialModel/blob/master/Untitled20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##quantification
L'objectif de ce notebook est d'essayer d'effectuer un apprentissage float32 propice à être caster en int8.
Pour cela, une loss supplémentaire sur les poids (forcés à être près d'un entier) et l'utilisation d'une double activation relu + "cyclique".

Ce test préliminaire est réalisé sur CIFAR10.

### activation
Il convient d'abord de selectionner une activation continue qui a le même comportement en float et en uint8

In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt

def cycle(x):
    assert x.dtype in [torch.half, torch.float,torch.uint8]
    if x.dtype != torch.uint8:
        x = x % 256
    return x - 2 * (x-128)*(x>128).to(dtype=x.dtype)

x = (torch.arange(1000)-500)
y = cycle(x.float())
z = cycle(x.to(dtype=torch.uint8)).float()
assert all(y==z)

y = cycle(x)
plt.figure(figsize=(8, 6))
plt.plot(x, y)
plt.grid(True)
plt.show()

AssertionError: ignored

### baseline
construisons maintenant un petit réseau qui pourrait absorber des poids (en partie) uint8

In [None]:
def channelPool(x):
    x1,x2 = x[:,::2,:,:],x[:,1::2,:,:]
    return torch.max(x1,x2)

def funconcat(x1,x2):
    n,c,h,w = x1.shape
    x = torch.zeros(n,2*c,h,w).to(dtype=x1.dtype)
    x[:,::2,:,:],x[:,1::2,:,:] = x1,x2
    return x

class Funblock(torch.nn.Module):
    def __init__(self,size):
        super(Funblock, self).__init__()
        self.conv = torch.nn.Conv2d(size, 2*size,kernel_size=3,padding=1,group=8)

    def forward(self,x):
        x1 = cycle(self.conv(x))
        x1 = channelPool(x1)
        return funconcat(x1,x)

In [None]:
class Baseline(torch.nn.Module):
    def __init__(self):
        super(Baseline, self).__init__()
        # very few operation: can be done in float
        self.proj1 = torch.nn.Conv2d(3, 16,kernel_size=3,padding=1)
        self.proj2 = torch.nn.Conv2d(16, 64,kernel_size=2,stride=2)
        self.proj3 = torch.nn.Conv2d(64, 32,kernel_size=1)

        # could be uint8
        self.c1 = Funblock(32)
        self.c2 = Funblock(64)
        self.c3 = Funblock(128)
        self.c4 = Funblock(256)
        self.c5 = Funblock(512)

        # very few operation: can be done in float
        self.f1 = torch.nn.Linear(512,1024)
        self.f2 = torch.nn.Linear(1024,2048)
        self.f3 = torch.nn.Linear(2048,2048)
        self.f4 = torch.nn.Linear(4096,10)

    def forward(self, x):
        p = torch.nn.functional.leaky_relu(self.proj1(x.float()))
        p = torch.nn.functional.leaky_relu(self.proj2(p))
        p = cycle(self.proj3(p)).to(dtype=x.dtype)

        x = cycle(self.c1(x))
        x = cycle(self.c2(x))
        x = torch.nn.functional.max_pool2d(x,kernel_size=2,stride=2)
        x = cycle(self.c3(x))
        x = cycle(self.c4(x))
        x = cycle(self.c5(x))
        x = torch.nn.functional.max_pool2d(x,kernel_size=8,stride=8)

        x = x[:,:,0,0].float()

        x = torch.nn.functional.leaky_relu(self.f1(x))
        x = torch.nn.functional.leaky_relu(self.f2(x))
        x = torch.cat([x,torch.nn.functional.leaky_relu(self.f3(x))],dim=1)
        return self.f4(x)

net = Baseline()
with torch.no_grad():
    print(net(torch.rand(2,3,32,32)).shape)