<a href="https://colab.research.google.com/github/Sairam954/QuantizationAwareTrainingPCM/blob/master/QATPCM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
pip install git+https://github.com/Xilinx/brevitas.git

Collecting git+https://github.com/Xilinx/brevitas.git
  Cloning https://github.com/Xilinx/brevitas.git to /tmp/pip-req-build-wsyhbjtw
  Running command git clone -q https://github.com/Xilinx/brevitas.git /tmp/pip-req-build-wsyhbjtw
Collecting docrep
  Downloading https://files.pythonhosted.org/packages/dd/4a/ac09d6e07713e22baa4ab4e6f422d25e53425f3dc042616387dfbc272504/docrep-0.2.7.tar.gz
Building wheels for collected packages: Brevitas, docrep
  Building wheel for Brevitas (setup.py) ... [?25l[?25hdone
  Created wheel for Brevitas: filename=Brevitas-0.2.0a0-cp36-cp36m-linux_x86_64.whl size=1748363 sha256=879092cd9e31e8b8fcc77e1cd031c15ba0bac0091352199234b526a92d5946a1
  Stored in directory: /tmp/pip-ephem-wheel-cache-igs3wal0/wheels/7b/ba/1b/b3bebdeb51db39fc118c4d60ef8556d8a9ab0f1bfda8767a3d
  Building wheel for docrep (setup.py) ... [?25l[?25hdone
  Created wheel for docrep: filename=docrep-0.2.7-cp36-none-any.whl size=23003 sha256=7192537d9a6e10ac52deb9b33c8ec33aeea694c2ab9868240

In [0]:
import torchvision
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.quantization

In [0]:
from torch.nn import Module
import torch.nn.functional as F
import brevitas.nn as qnn
from brevitas.core.quant import QuantType

class QuantLeNet(Module):
    """LeNet-style CNN assembled from Brevitas quantization-aware layers.

    Layout: conv(1->6, 5x5) -> ReLU -> 2x2 max-pool -> conv(6->16, 5x5) -> ReLU
    -> 2x2 max-pool -> flatten -> fc(256->120) -> ReLU -> fc(120->84) -> ReLU
    -> fc(84->10, no bias).

    The first linear layer expects 256 flattened features, which is what the
    two conv/pool stages produce for single-channel 28x28 inputs (16 * 4 * 4).

    NOTE(review): every layer is configured with ``QuantType.FP``; presumably
    this selects floating-point (i.e. no quantization), which would make the
    4-bit widths inert -- confirm against the Brevitas documentation.
    """

    def __init__(self):
        super().__init__()
        # Settings shared by every weight-carrying layer.
        weight_cfg = dict(weight_quant_type=QuantType.FP, weight_bit_width=4)

        def activation():
            # Each call returns a fresh, independently configured ReLU.
            return qnn.QuantReLU(quant_type=QuantType.FP, bit_width=4, max_val=6)

        # Modules are registered in the same order as the forward pass so that
        # state-dict keys and parameter iteration order stay stable.
        self.conv1 = qnn.QuantConv2d(1, 6, 5, **weight_cfg)
        self.relu1 = activation()
        self.conv2 = qnn.QuantConv2d(6, 16, 5, **weight_cfg)
        self.relu2 = activation()
        self.fc1 = qnn.QuantLinear(256, 120, bias=True, **weight_cfg)
        self.relu3 = activation()
        self.fc2 = qnn.QuantLinear(120, 84, bias=True, **weight_cfg)
        self.relu4 = activation()
        self.fc3 = qnn.QuantLinear(84, 10, bias=False, **weight_cfg)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the 10 output scores per sample."""
        out = self.relu1(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # Flatten everything except the batch dimension for the linear layers.
        out = out.view(out.size(0), -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out

    def printcheck(self):
        """Print the weight-tensor shapes of the two convolution layers."""
        # The layer modules themselves have no ``shape`` attribute; the shape
        # lives on their ``weight`` parameter.
        print(self.conv1.weight.shape)
        print(self.conv2.weight.shape)
        


    
      


In [0]:
def evaluation(dataloader, model, device=None):
    """Return the classification accuracy of ``model`` on ``dataloader`` in percent.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        model: module mapping an input batch to per-class scores.
        device: device the batches are moved to before the forward pass.
            When omitted, the device of the model's first parameter is used so
            that inputs and weights always live on the same device; a model
            without parameters leaves the batches where they are.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the dataloader
        yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    total, correct = 0, 0
    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for scoring; skipping them saves memory.
        with torch.no_grad():
            for inputs, labels in dataloader:
                if device is not None:
                    inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, pred = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100 * correct / total


In [0]:
def updateWeightsWithTimePCM(model,t):
        """Build a state dict whose parameters have drifted for time ``t``.

        Each weight ``w`` is mapped linearly from ``[minweight, maxweight]`` to
        a conductance in ``[minconductance, maxconductance]``. A per-element
        drift exponent ``v = v0 + k * log(g0 / G(w))`` is drawn, with ``v0``
        uniform in ``[vmin, vmax)`` from NumPy's global RNG. The conductance is
        scaled by ``(t / t0) ** (-v)`` and mapped back to a weight.
        For ``t == t0`` the scale factor is exactly 1, so weights are unchanged.

        NOTE(review): weights far enough below ``minweight`` make the relative
        conductance non-positive, so the logarithm yields NaN/inf; the inputs
        are not clamped here.

        Args:
            model: module whose ``named_parameters()`` are drifted. The model
                itself is not modified.
            t: elapsed time, in the same unit as ``t0``; must be positive.

        Returns:
            The model's ``state_dict()`` with every parameter entry replaced by
            its drifted value, on the parameter's device and in its dtype.

        Raises:
            ValueError: if ``t`` is not strictly positive.
        """
        if t <= 0:
            raise ValueError("t must be a positive time, got %r" % (t,))

        vmin = 0.1
        vmax =0.15
        t0= 1e-3
        k = 0.2
        g0 = 0.5e-06
        #weight and resistance drift PCM model
        minweight = -1
        maxweight= 1
        minmaxweightdiff= maxweight-minweight
        minconductance = 0.554e-06
        maxconductance = 4.762e-06
        minmaxconductancedratio = minconductance/maxconductance
        newmodelstate = model.state_dict()
        timeratio = float(t) / t0

        with torch.no_grad():
            for name,currentweight in model.named_parameters():
                # Compute in float64 on the parameter's own device, so the
                # result does not depend on where the model lives.
                w = currentweight.detach().to(torch.float64)
                v0 = torch.as_tensor(np.random.uniform(vmin,vmax,tuple(w.shape)),
                                     dtype=torch.float64, device=w.device)
                # Conductance relative to maxconductance for the current weight.
                relconductance = ((w-maxweight)/minmaxweightdiff)*(1-minmaxconductancedratio)+1
                v_of_w = v0 + k*torch.log((g0/maxconductance)/relconductance)
                newweight = w + (w+(minmaxweightdiff/(1-minmaxconductancedratio)- maxweight))*(torch.pow(timeratio, -v_of_w)-1)
                # Cast back so the state dict keeps the parameter's dtype.
                newmodelstate[name] = newweight.to(currentweight.dtype)
        return newmodelstate

In [0]:
# MNIST train/test splits converted to tensors and served in mini-batches.
batch_size = 128

trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    shuffle=True,  # new sample order every epoch
    batch_size=batch_size,
)

testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
testloader = torch.utils.data.DataLoader(
    testset,
    shuffle=False,  # fixed order for evaluation
    batch_size=batch_size,
)

In [0]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# The model must live on the same device as the batches fed to it; move it
# before creating the optimizer so the optimizer tracks the moved parameters.
intmodel = QuantLeNet().to(device)
loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(intmodel.parameters())

cpu


In [0]:
%%time
max_epochs = 5


for epoch in range(max_epochs):

    for i, data in enumerate(trainloader, 0):

        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        opt.zero_grad()

        outputs = intmodel(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()
        #net.apply(clamp)
        
    print('Epoch: %d/%d' % (epoch, max_epochs))





Epoch: 0/5
Epoch: 1/5
Epoch: 2/5
Epoch: 3/5
Epoch: 4/5
CPU times: user 1min 25s, sys: 2.12 s, total: 1min 27s
Wall time: 1min 28s


In [0]:
# Accuracy of the trained model: test split first, then train split.
print('Test acc: %0.2f, Train acc: %0.2f'
      % tuple(evaluation(loader, intmodel) for loader in (testloader, trainloader)))

Test acc: 98.73, Train acc: 98.76


In [0]:
# A fresh network receives the drifted copy of the trained weights. It is
# placed on `device` so that it sits where the evaluation batches are sent.
driftedmodel = QuantLeNet().to(device)
# NOTE(review): the drift time passed here equals the t0 = 1e-3 used inside
# updateWeightsWithTimePCM, so (t / t0) == 1 and the weights come back
# unchanged; pass a larger t to actually observe drift.
driftedmodelstatedict = updateWeightsWithTimePCM(intmodel,1e-03)
driftedmodel.load_state_dict(driftedmodelstatedict)


<All keys matched successfully>

In [0]:
# List the parameter names, then show the last layer's weights before and
# after drift, separated by a divider line.
print(driftedmodel.state_dict().keys())
print(
    intmodel.state_dict()['fc3.weight'],
    "====================================================",
    driftedmodel.state_dict()['fc3.weight'],
    sep="\n",
)

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight'])
tensor([[ 5.7502e-02, -4.7653e-02, -8.4562e-02, -2.5373e-02,  5.7231e-02,
         -2.3991e-01, -5.8743e-02,  1.3909e-01,  3.4454e-02, -1.2096e-02,
         -3.5502e-02, -1.2316e-01, -6.2097e-03, -7.2736e-02,  7.9141e-02,
          1.6031e-02,  1.3602e-01, -1.0817e-01, -1.3844e-01,  2.8183e-02,
         -1.0384e-01, -2.7578e-01,  1.7877e-01, -1.6349e-01, -1.2095e-01,
          2.6431e-02, -1.1178e-01, -2.9428e-02,  6.2005e-02, -1.5256e-01,
          6.1777e-02, -7.3393e-02, -1.9656e-01,  6.7498e-02, -1.2652e-02,
          1.1367e-01, -3.6797e-02, -2.3679e-02,  1.5724e-01,  1.0304e-01,
         -1.5883e-01, -7.2645e-02, -9.6742e-02,  6.9922e-02,  2.8612e-02,
          7.6805e-02,  5.4769e-02,  5.9486e-02, -3.2537e-02,  1.4885e-01,
         -4.3388e-02,  2.4472e-02, -4.7109e-02,  2.8444e-02,  2.4384e-03,
         -1.3648e-01,  1.1871e-02, -6.9804e-02,

In [0]:
# Accuracy of the drifted model: test split first, then train split.
print('Test acc: %0.2f, Train acc: %0.2f'
      % tuple(evaluation(loader, driftedmodel) for loader in (testloader, trainloader)))

Test acc: 98.73, Train acc: 98.76


In [0]:
# Re-check the original model for comparison: test split first, then train split.
print('Test acc: %0.2f, Train acc: %0.2f'
      % tuple(evaluation(loader, intmodel) for loader in (testloader, trainloader)))

Test acc: 98.92, Train acc: 99.69
