In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
        
from tqdm import tqdm

!pip install torchviz

Collecting torchviz
  Downloading torchviz-0.0.2.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: torchviz
  Building wheel for torchviz (setup.py) ... [?25ldone
[?25h  Created wheel for torchviz: filename=torchviz-0.0.2-py3-none-any.whl size=4131 sha256=5c653786d9557f67b4b82eb399c45aa694ec80d512349524e06682be5271580a
  Stored in directory: /root/.cache/pip/wheels/4c/97/88/a02973217949e0db0c9f4346d154085f4725f99c4f15a87094
Successfully built torchviz
Installing collected packages: torchviz
Successfully installed torchviz-0.0.2


In [2]:
!nvidia-smi

Mon Apr 22 15:45:43 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla T4                       Off | 00000000:00:05.0 Off |  

# Wandb

In [3]:
!pip install wandb
!wandb login # Rajouter la clé

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Custom model

In [4]:
class Conv2D_QuantFunction_Binary(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, X, weight, Wp, Wn, threshold, stride, padding):
            
        # Max abs weight
        max_abs_weight = torch.max(torch.abs(weight)).cuda()
        normalized_weight = weight / max_abs_weight
        normalized_weight = (weight / max_abs_weight).cuda()
        
        # binary weight 
        threshold = threshold.to(normalized_weight.device)
        mask = torch.where(normalized_weight > threshold, 1, normalized_weight)
        mask = torch.where(mask < threshold, -1, mask).cuda()
        
        quantized_weight = torch.where(mask == 1, Wp, mask)
        quantized_weight = torch.where(mask == -1, Wn, mask).cuda()
        
        # sauvegarde des variables pour le backward
        ctx.save_for_backward(X, 
                              quantized_weight,
                              Wp.clone(),
                              Wn.clone(),
                              mask,
                              stride,
                              padding)
        
        return F.conv2d(X, quantized_weight, stride=stride.int().item(), padding=padding.int().item())

    @staticmethod
    def backward(ctx, grad_out):
        
        # récupération des variables sauvegardées
        X, quantized_weight, Wp, Wn, mask, stride, padding = ctx.saved_tensors
        
        # calcul gradient poids quantifiés et outpu
        grad_input = torch.nn.grad.conv2d_input(X.shape, quantized_weight, grad_out, stride=stride, padding=padding)
        grad_weight = torch.nn.grad.conv2d_weight(X, quantized_weight.shape, grad_out, stride=stride, padding=padding)
        
        # scale le gradient
        grad_weight = torch.where(mask == 1, Wp * grad_weight, grad_weight)
        grad_weight = torch.where(mask == -1, Wn * grad_weight, grad_weight) 
        
        # calcul gradient Wp et Wn
        grad_Wp = torch.sum(torch.where(mask == 1, grad_weight, torch.zeros_like(grad_weight))).unsqueeze(0)
        grad_Wn = torch.sum(torch.where(mask == -1, grad_weight, torch.zeros_like(grad_weight))).unsqueeze(0)
        
        return grad_input, grad_weight, grad_Wp, grad_Wn, grad_Wn, None, None # None car pas de gradient pour le threshold

class Conv2D_QuantFunction_Ternary(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, X, weight, Wp, Wn, threshold, stride, padding):
        
        # Max abs weight
        max_abs_weight = torch.max(torch.abs(weight))
        normalized_weight = (weight / max_abs_weight)
        
        # Ternary weight
        threshold = threshold.to(normalized_weight.device)
        mask = torch.where((normalized_weight > -threshold) & (normalized_weight <= threshold), 0, normalized_weight)
        mask = torch.where(mask > threshold, 1, mask)
        mask = torch.where(mask < -threshold, -1, mask)
        
        quantized_weight = torch.where(mask == 0, 0, mask)
        quantized_weight = torch.where(mask == 1, Wp, mask)
        quantized_weight = torch.where(mask == -1, Wn, mask)

        # sauvegarde des variables pour le backward
        ctx.save_for_backward(X, 
                              quantized_weight,
                              Wp.clone(),
                              Wn.clone(),
                              mask,
                              stride,
                              padding)
        return F.conv2d(X, quantized_weight, stride=stride.int().item(), padding=padding.int().item())

    @staticmethod
    def backward(ctx, grad_out):
        
        # récupération des variables sauvegardées
        X, quantized_weight, Wp, Wn, mask, stride, padding = ctx.saved_tensors
        
        # calcul gradient poids quantifiés et outpu
        grad_input = torch.nn.grad.conv2d_input(X.shape, quantized_weight, grad_out, stride=stride, padding=padding)
        grad_weight = torch.nn.grad.conv2d_weight(X, quantized_weight.shape, grad_out, stride=stride, padding=padding)
        
        # scale le gradient
        grad_weight = torch.where(mask == 1, Wp * grad_weight, grad_weight)
        grad_weight = torch.where(mask == -1, Wn * grad_weight, grad_weight) 
        
        # calcul gradient Wp et Wn
        grad_Wp = torch.sum(torch.where(mask == 1, grad_weight, torch.zeros_like(grad_weight))).unsqueeze(0)
        grad_Wn = torch.sum(torch.where(mask == -1, grad_weight, torch.zeros_like(grad_weight))).unsqueeze(0)
        
        size_Wp = torch.sum(torch.where(mask == 1, 1, 0))
        size_Wn = torch.sum(torch.where(mask == -1, 1, 0))
        
        grad_Wp = grad_Wp / size_Wp
        grad_Wn = grad_Wn / size_Wn
        
        return grad_input, grad_weight, grad_Wp, grad_Wn, None, None, None # None car pas de gradient pour le threshold
        
class Conv2D_QuantModule(nn.Module):
    
    def __init__(self, ternary, in_channels, out_channels, kernel_size, stride, padding, threshold=0.05, Wp_init=1, Wn_init=-1):
        super().__init__()
        # initialisation des poids et threshold
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.xavier_normal_(self.weight)
        
        # Initialiser Wp avec la moyenne des poids positifs et Wn avec la moyenne des poids négatifs
        self.Wn = nn.Parameter(torch.tensor([Wn_init]).float())
        self.Wp = nn.Parameter(torch.tensor([Wp_init]).float())
        
        # on doit manuellement passer tensor sur cuda, TODO rjaouter device dans Conv2D_QuantModule
        self.threshold = torch.tensor([threshold], requires_grad=False).cuda()
        self.stride = torch.tensor([stride]).float().cuda()
        self.padding = torch.tensor([padding]).float().cuda()

        self.function = Conv2D_QuantFunction_Ternary() if ternary else Conv2D_QuantFunction_Binary()
            
    def forward(self, X):
        output =  self.function.apply(X, self.weight, self.Wp, self.Wn, self.threshold, self.stride, self.padding)
        return output

In [5]:
import torch.nn as nn
import torch.nn.functional as F

class Residual(nn.Module):
    def __init__(self, ternary, quant, in_channels, out_channels, increase_dim=False, first=False):
        super(Residual, self).__init__()
        self.first = first
        self.increase_dim = increase_dim

        if increase_dim:
            out_channels = in_channels * 2
            stride1 = 2
        else:
            out_channels = in_channels
            stride1 = 1

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = Conv2D_QuantModule(ternary, in_channels, out_channels, kernel_size=3, stride=stride1, padding=1) if quant else nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = Conv2D_QuantModule(ternary, out_channels, out_channels, kernel_size=3, stride=1, padding=1) if quant else nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        if increase_dim:
            self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        
        self.in_channels = in_channels
        
    def forward(self, x):
        if not self.first:
            x = F.relu(self.bn1(x))
        out = self.conv1(x)
        out = F.relu(self.bn2(out))
        out = self.conv2(out)

        if self.increase_dim:
            x = self.avgpool(x)
            x = F.pad(x, (0, 0, 0, 0, self.in_channels // 2, self.in_channels // 2))

        out += x
        return out

class Model(nn.Module):
    def __init__(self, n, quant, ternary):
        super(Model, self).__init__()
        self.n = n
        
        # jamais quantifier le premier conv 
        self.conv0 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(16)
        
        self.res1 = self._make_layer(ternary, quant, 16, 16, n, first=True)
        self.res2 = self._make_layer(ternary, quant, 16, 32, n, increase_dim=True)
        self.res3 = self._make_layer(ternary, quant, 32, 64, n, increase_dim=True)
        self.bnlast = nn.BatchNorm2d(64)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def _make_layer(self, ternary, quant, input_channels, output_channels, num_blocks, increase_dim=False, first=False):
        layers = []
        layers.append(Residual(ternary, quant, input_channels, output_channels, increase_dim=increase_dim, first=first))
        for _ in range(1, num_blocks):
            layers.append(Residual(ternary, quant, output_channels, output_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv0(x)
        x = F.relu(self.bn0(x))
        
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = F.relu(self.bnlast(x))
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [6]:
# Fonction pour charger et prétraiter les données CIFAR-10
def load_cifar10(batch_size):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Redimensionner les images pour ResNet-18
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalisation des images
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader

# Fonction pour entraîner le modèle
def train(model, trainloader, testloader, criterion, optimizer, num_epochs, device='cuda', quant=False):
    
    device = next(model.parameters()).device
    
    # Learning rate decay 
    scheduler = StepLR(optimizer, step_size=20, gamma=0.1)  
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        model.train()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 100 == 99:    # Afficher la perte tous les 100 mini-lots
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 100))
                wandb.log({"Training Loss": running_loss / 100})  # Enregistrer la perte
                running_loss = 0.0
                
                if quant:
                    #wandb.log({"Wn": model.module.layer1.layer[0].conv1.Wn.item(), "Wp": model.module.layer1.layer[0].conv1.Wp.item()})
                    wandb.log({"l1.0.conv1.Wn": model.module.res1[0].conv1.Wn.item(), "l1.0.conv1.Wp": model.module.res1[0].conv1.Wp.item()})
                    wandb.log({"l3.2.conv2.Wn": model.module.res3[1].conv2.Wn.item(), "l3.2.conv2.Wp": model.module.res3[1].conv2.Wp.item()})
        
        model.eval()
        train_accuracy = evaluate_accuracy(model, trainloader)
        test_accuracy = evaluate_accuracy(model, testloader)
        wandb.log({"Train accuracy": train_accuracy, "Test accuracy": test_accuracy})
        
        scheduler.step()  
        
    print('Finished Training')

# Fonction pour calculer l'accuracy
def evaluate_accuracy(model, testloader):
    # Get the device where the model resides
    device = next(model.parameters()).device

    # Iterate through the test data and perform inference
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in testloader:
            # Move data to the same device as the model
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Perform inference
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            # Update counts
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    # Calculate accuracy
    accuracy = total_correct / total_samples

    print('Accuracy on the test set: {:.2f}%'.format(accuracy * 100))
    return accuracy

In [None]:
import wandb
from torch.optim.lr_scheduler import StepLR

torch.cuda.empty_cache()

config = {
    "batch_size": 128,
    "num_epochs": 20,
    "learning_rate": 0.1,
    "momentum": 0.9,
    "n": 5,  ,             # 3 -> Resnet20, 5 -> Resnet32, 7 -> Resnet42
    "quant": True,         # False if we need to train without quantization
    "ternary": True        # only if config.quant = True, True for ternary quantization, False for binary quantization
}

run = wandb.init(
    # Set the project where this run will be logged
    project="quantization-aware-training",
    group="quant-Resnet-training",
    config=config
)

# Charger les données CIFAR-10
trainloader, testloader = load_cifar10(batch_size=config["batch_size"])

# Charger le modèle ResNet-18 
model = Model(n=config["n"], quant=config["quant"], ternary=config["ternary"])

if config["quant"]:
    weights = torch.load('/kaggle/input/resnet32/pytorch/acc80/1/resnet_32_accuracy_80.pth', map_location='cuda')

    # ajouter les Wn, Wp
    for k,v in model.state_dict().items():
        if "Wn" in k or "Wp" in k:
            weights[k] = v

    # Load the weights into the model
    model.load_state_dict(weights)

# Activer la surveillance WandB pour le modèle
wandb.watch(model, log='all', log_freq=10)

# Définir la fonction de perte et l'optimiseur
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=config["learning_rate"], momentum=config["momentum"], weight_decay=0.0002)

# Mettre sur plusieurs GPU
if torch.cuda.device_count() == 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model.cuda()
elif torch.cuda.device_count() == 2:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model.to('cuda:0')
    model = nn.DataParallel(model)  # Utiliser DataParallel pour utiliser plusieurs GPU
else:
    print("Using cpu")
    
# Entraîner le modèle
model.eval()
train_accuracy = evaluate_accuracy(model, trainloader)
test_accuracy = evaluate_accuracy(model, testloader)
wandb.log({"Train accuracy": train_accuracy, "Test accuracy": test_accuracy})

train(model, trainloader, testloader, criterion, optimizer, num_epochs=config["num_epochs"], quant=config["quant"])
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33md-zhu66[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 71761632.37it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Using 2 GPUs!
Accuracy on the test set: 10.00%
Accuracy on the test set: 10.00%
[1,   100] loss: 2.137
[1,   200] loss: 1.659
[1,   300] loss: 1.472
Accuracy on the test set: 36.81%
Accuracy on the test set: 36.56%
[2,   100] loss: 1.267
[2,   200] loss: 1.201
[2,   300] loss: 1.135
Accuracy on the test set: 59.28%
Accuracy on the test set: 58.24%
[3,   100] loss: 0.986
[3,   200] loss: 0.965
[3,   300] loss: 0.942
Accuracy on the test set: 62.51%
Accuracy on the test set: 61.39%
[4,   100] loss: 0.877
[4,   200] loss: 0.882
[4,   300] loss: 0.866
Accuracy on the test set: 53.87%
Accuracy on the test set: 52.88%
[5,   100] loss: 0.828
[5,   200] loss: 0.826
[5,   300] loss: 0.812
Accuracy on the test set: 64.72%
Accuracy on the test set: 63.52%


In [None]:
train_accuracy = evaluate_accuracy(quant_resnet18, trainloader)
test_accuracy = evaluate_accuracy(quant_resnet18, testloader)

In [None]:
torch.save(quant_resnet18.state_dict(), '/kaggle/working/resnet_32_accuracy_80.pth')