# テーマE：量子化と枝刈りの同時実行時のビット幅とスパーシティの最適比


## モジュールの読み込み

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Function

## MNISTのデータセット/精度評価関数の作成

In [2]:
# Execution device: prefer 'cuda:2' as originally configured, but fall back
# when that GPU does not exist so the notebook still runs on machines with
# fewer GPUs (or none).
if torch.cuda.device_count() > 2:
    device = 'cuda:2'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Standard transform: convert to tensor, then normalize with mean 0.5 / std 0.5
transform_normal = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Test data must use the standard transform (no augmentation)
transform_for_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_normal)  # dataset used for training
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_for_test)  # dataset used for evaluation
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

def compute_accuracy(model, test_loader, device='cuda:0'):
    """Evaluate the classification accuracy of ``model`` on ``test_loader``.

    The model is switched to eval mode and moved to ``device``; it is left in
    that state on return.

    Args:
        model: module whose output has class scores along dimension 1.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device on which evaluation runs.

    Returns:
        Accuracy in percent (0-100) as a float. It is also printed.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    model.to(device)
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            # Predicted class = index of the largest score per sample
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    if total == 0:
        # Previously surfaced as an opaque ZeroDivisionError
        raise ValueError('test_loader yielded no samples')
    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy

def train(model, lr=0.05, epochs=5, device='cuda:0', loader=None):
    """Train ``model`` in place with plain SGD and cross-entropy loss.

    Args:
        model: module to train. It is moved to ``device`` and put in training
            mode (``compute_accuracy`` leaves models in eval mode).
        lr: SGD learning rate.
        epochs: number of passes over ``loader``.
        device: device on which training runs.
        loader: iterable of ``(images, labels)`` batches that supports
            ``len()``. Defaults to the module-level ``train_loader``.

    Returns:
        The same ``model`` object, after training.

    Raises:
        ValueError: if ``loader`` contains no batches.
    """
    if loader is None:
        loader = train_loader
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError('loader yielded no batches')

    criterion = nn.CrossEntropyLoss()
    # Move the model before building the optimizer (recommended order)
    model.to(device)
    # Undo any eval() left behind by a previous evaluation
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        loss_sum = 0.0
        for images, labels in loader:
            optimizer.zero_grad()
            outputs = model(images.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

        # Report the mean batch loss of this epoch
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss_sum/num_batches:.4f}')
    return model




## 通常モデルの学習

In [4]:
class SimpleModel(nn.Module):
    """Three-layer MLP: 784 -> 128 -> 64 -> 10, ReLU between layers.

    Each input is flattened to 28 * 28 = 784 features; the output is the raw
    (un-normalized) score vector of the final linear layer.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Flatten each 28x28 image to a 784-vector; reshape (unlike view)
        # also accepts non-contiguous inputs.
        x = x.reshape(-1, 28 * 28)
        # Functional ReLU avoids constructing a new nn.ReLU module per call
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate the float baseline on the execution device and train it
model = train(SimpleModel().to(device), lr=0.1, epochs=5, device=device)

Epoch [1/5], Loss: 0.4596
Epoch [2/5], Loss: 0.1787
Epoch [3/5], Loss: 0.1324
Epoch [4/5], Loss: 0.1033
Epoch [5/5], Loss: 0.0859


精度の確認

In [5]:
accuracy = compute_accuracy(model=model, test_loader=test_loader, device=device)  # test-set accuracy of the float baseline

Accuracy: 95.40%


## スカラー量子化（一様対称量子化）の実行

### プロセス：Linear層を量子化層に置換 → 量子化認識学習（QAT）

ここでは簡便に、量子化パラメータ（スケール $s$）をmin-maxスケーリングで決定する。
対称量子化なので、行列 $X$ の最大値と最小値の差の2分の1を、$p$-bitの数値範囲の最大値 $q_{max}$ で割った値をスケール $s$ とする。

$q_{max} = 2^{(p-1)} - 1$

$s = \frac{\max(X) - \min(X)}{2 q_{max}}$

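$X_{q} = s \cdot \text{clip}\left(\text{round}\left(\frac{X}{s}\right), -q_{max}, q_{max}\right)$

なお $p = 1$ の場合は、実装上 $s = \text{mean}(|X|)$、$X_{q} = s \cdot \text{sign}(X)$ としている。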

In [6]:

class SymQuantSTE(Function):
    """Symmetric uniform fake-quantization with a straight-through estimator.

    forward(input, scale, num_bits):
        num_bits == 1: |scale| * sgn(input)
        otherwise:     step * clip(round(input / step), -qmax, qmax), where
                       step = max(|scale|, 1e-8) and qmax = 2**(num_bits-1) - 1
    backward:
        num_bits == 1: incoming gradient clipped elementwise to [-1, 1]
        otherwise:     incoming gradient where |input| <= qmax * step, else 0
    No gradient is returned for ``scale`` or ``num_bits``.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, scale: torch.Tensor, num_bits: int):
        binary = num_bits == 1
        # The step size is floored only in the multi-bit case (it is divided by)
        step = scale.abs() if binary else scale.abs().clamp_min(1e-8)
        if binary:
            result = step * torch.sgn(input)
        else:
            limit = 2 ** (num_bits - 1) - 1
            result = torch.round(input / step).clamp(-limit, limit) * step
        ctx.num_bits = num_bits
        ctx.save_for_backward(input, step)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, step = ctx.saved_tensors
        bits = ctx.num_bits
        if bits == 1:
            # NOTE(review): this clips the gradient values themselves rather
            # than masking by the input range (hardtanh-style STE); confirm
            # this is intended.
            return grad_output.clamp(-1, 1), None, None
        limit = 2 ** (bits - 1) - 1
        # Pass gradients only where the input was not clipped by quantization
        inside = saved_input.abs() <= limit * step
        return grad_output * inside.to(grad_output.dtype), None, None




class SymQuantLinear(nn.Linear):
    """nn.Linear whose weight, and optionally input, is fake-quantized.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        weight_bits: bit width for the weight.
        act_bits: bit width for the input activation; ``None`` disables
            activation quantization.

    Scales are recomputed on every forward call:
        1 bit : mean absolute value of the tensor
        p bits: (max - min) / (2 * (2**(p-1) - 1))
    The quantization itself is delegated to ``SymQuantSTE``.
    """

    def __init__(self, in_features, out_features, bias=True, weight_bits=8, act_bits=None):
        super().__init__(in_features, out_features, bias)
        self.weight_bits = weight_bits
        self.act_bits = act_bits

    @staticmethod
    def _scale(tensor, bits):
        """Return the symmetric min-max (or mean-|x| for 1 bit) scale of *tensor*."""
        if bits == 1:
            return tensor.abs().sum() / tensor.numel()
        return (tensor.max() - tensor.min()) / (2 * (2 ** (bits - 1) - 1))

    def forward(self, input):
        weight_scale = self._scale(self.weight, self.weight_bits)
        if self.act_bits is not None:
            input = SymQuantSTE.apply(input, self._scale(input, self.act_bits), self.act_bits)
        quantized_weight = SymQuantSTE.apply(self.weight, weight_scale, self.weight_bits)
        return F.linear(input, quantized_weight, self.bias)



def replace_linear_with_quantizedlinear(module, weight_bits=8, act_bits=None):
    """Recursively replace every ``nn.Linear`` in *module* with ``SymQuantLinear``.

    *module* is modified in place and also returned. Children that already are
    ``SymQuantLinear`` are left untouched. Each replacement copies the original
    weight/bias and is placed on the original layer's device and dtype (and
    training mode), so the model never ends up split across devices.

    Args:
        module: root module to rewrite.
        weight_bits: weight bit width for the new layers.
        act_bits: activation bit width for the new layers (``None`` = off).

    Returns:
        *module* itself.
    """
    for name, child in module.named_children():
        if isinstance(child, SymQuantLinear):
            continue
        if not isinstance(child, nn.Linear):
            replace_linear_with_quantizedlinear(child, weight_bits, act_bits)
            continue
        qlinear = SymQuantLinear(
            child.in_features,
            child.out_features,
            bias=(child.bias is not None),
            weight_bits=weight_bits,
            act_bits=act_bits,
        ).to(device=child.weight.device, dtype=child.weight.dtype)
        qlinear.train(child.training)
        # Copy the trained parameters without recording autograd history
        with torch.no_grad():
            qlinear.weight.copy_(child.weight)
            if child.bias is not None:
                qlinear.bias.copy_(child.bias)
        setattr(module, name, qlinear)
    return module

# Build a fresh model and warm it up with ordinary float training
model = SimpleModel().to(device)
print('warming up by no-quantized training...')
model = train(model, lr=0.1, epochs=5, device=device)
# Replace nn.Linear layers with SymQuantLinear (4-bit weights / activations);
# this modifies `model` in place
model_q = replace_linear_with_quantizedlinear(model, weight_bits=4, act_bits=4)
print('quantization aware training...')
model_q = train(model_q, lr=1e-3, epochs=5, device=device)
# Evaluate on the configured device (the function's default is 'cuda:0')
accuracy = compute_accuracy(model_q, test_loader, device=device)

warming up by no-quantized training...
Epoch [1/5], Loss: 0.4444
Epoch [2/5], Loss: 0.1771
Epoch [3/5], Loss: 0.1291
Epoch [4/5], Loss: 0.1031
Epoch [5/5], Loss: 0.0871
quantization aware training...
Epoch [1/5], Loss: 0.0786
Epoch [2/5], Loss: 0.0643
Epoch [3/5], Loss: 0.0602
Epoch [4/5], Loss: 0.0582
Epoch [5/5], Loss: 0.0575
Accuracy: 97.43%


## 枝刈りの実行

### 実行プロセス：学習中の各forwardで重みの絶対値をスコアとし、スコアが小さい方から全体の `prune_ratio` の割合の重みを0にする（枝刈り）

In [7]:
class PrunedLinear(nn.Linear):
    """nn.Linear that applies magnitude pruning to its weight on every forward.

    The ``int(prune_ratio * numel)`` weights with the smallest absolute values
    (plus any ties with the k-th smallest) are masked to zero in the forward
    pass. ``self.weight`` itself is not modified. The mask is recomputed on each
    call, so the pruned set can change during training; masked weights receive
    zero gradient from the loss.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        prune_ratio: fraction of weights to zero, in [0, 1)
            (0.2 zeroes 20% of all weights).

    Raises:
        ValueError: if ``prune_ratio`` is outside [0, 1).
    """

    def __init__(self, in_features, out_features, bias=True, prune_ratio=0.2):
        # Explicit check instead of `assert`, which is stripped under `python -O`
        if not 0.0 <= prune_ratio < 1.0:
            raise ValueError(f'prune_ratio must be in [0, 1), got {prune_ratio}')
        super().__init__(in_features, out_features, bias)
        self.prune_ratio = prune_ratio

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        W = self.weight  # a reference, not a copy; masking below makes a new tensor
        k = int(self.prune_ratio * W.numel())
        if k > 0:
            # Threshold = k-th smallest |W|; the mask needs no gradient,
            # so work on a detached tensor.
            magnitude = W.detach().abs()
            threshold = torch.kthvalue(magnitude.flatten(), k).values
            # Keep the mask in W's dtype so half/bfloat16 layers stay consistent
            mask = (magnitude > threshold).to(W.dtype)
            W = W * mask
        return F.linear(input, W, self.bias)
    
def replace_linear_with_prunedlinear(module, prune_ratio=0.2):
    """Recursively replace every ``nn.Linear`` in *module* with ``PrunedLinear``.

    *module* is modified in place and also returned. Children that already are
    ``PrunedLinear`` are left untouched. Each replacement copies the original
    weight/bias and is placed on the original layer's device and dtype (and
    training mode), so the model never ends up split across devices.

    Args:
        module: root module to rewrite.
        prune_ratio: fraction of each layer's weights to zero
            (0.2 zeroes 20% of the weights).

    Returns:
        *module* itself.
    """
    for name, child in module.named_children():
        if isinstance(child, PrunedLinear):
            continue
        if not isinstance(child, nn.Linear):
            replace_linear_with_prunedlinear(child, prune_ratio=prune_ratio)
            continue
        plinear = PrunedLinear(
            child.in_features,
            child.out_features,
            bias=(child.bias is not None),
            prune_ratio=prune_ratio,
        ).to(device=child.weight.device, dtype=child.weight.dtype)
        plinear.train(child.training)
        # Copy the trained parameters without recording autograd history
        with torch.no_grad():
            plinear.weight.copy_(child.weight)
            if child.bias is not None:
                plinear.bias.copy_(child.bias)
        setattr(module, name, plinear)
    return module


# Build a fresh model and warm it up with ordinary float training
model = SimpleModel().to(device)
print('warming up by no-quantized training...')
model = train(model, lr=0.1, epochs=5, device=device)
# Replace nn.Linear layers with PrunedLinear (80% of weights masked);
# this modifies `model` in place
model_p = replace_linear_with_prunedlinear(model, prune_ratio=0.8)
print('pruning...')
model_p = train(model_p, lr=1e-3, epochs=5, device=device)
# Evaluate on the configured device (the function's default is 'cuda:0')
accuracy = compute_accuracy(model_p, test_loader, device=device)


warming up by no-quantized training...
Epoch [1/5], Loss: 0.4493
Epoch [2/5], Loss: 0.1750
Epoch [3/5], Loss: 0.1285
Epoch [4/5], Loss: 0.1021
Epoch [5/5], Loss: 0.0830
pruning...
Epoch [1/5], Loss: 0.2021
Epoch [2/5], Loss: 0.1361
Epoch [3/5], Loss: 0.1230
Epoch [4/5], Loss: 0.1138
Epoch [5/5], Loss: 0.1078
Accuracy: 96.65%


## 枝刈りと量子化の同時実行

In [None]:


class SymQuantSTE(Function):
    """Symmetric uniform fake-quantizer with a straight-through gradient.

    Re-defines the identically named class from the scalar-quantization
    section; it uses ``torch.sign`` where that one uses ``torch.sgn``
    (identical for real tensors).

    forward(input, scale, num_bits):
        num_bits == 1: |scale| * sign(input)
        otherwise:     s * clip(round(input / s), -qmax, qmax), where
                       s = max(|scale|, 1e-8) and qmax = 2**(num_bits-1) - 1
    backward:
        num_bits == 1: incoming gradient clipped elementwise to [-1, 1]
        otherwise:     incoming gradient where |input| <= qmax * s, else 0
    No gradient is returned for ``scale`` or ``num_bits``.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, scale: torch.Tensor, num_bits: int):
        ctx.num_bits = num_bits
        if num_bits != 1:
            s = scale.abs().clamp_min(1e-8)
            bound = 2 ** (num_bits - 1) - 1
            levels = torch.round(input / s).clamp(min=-bound, max=bound)
            ctx.save_for_backward(input, s)
            return levels * s
        s = scale.abs()
        ctx.save_for_backward(input, s)
        return torch.sign(input) * s

    @staticmethod
    def backward(ctx, grad_output):
        input, s = ctx.saved_tensors
        if ctx.num_bits == 1:
            # NOTE(review): clips gradient values rather than masking by input
            # range; confirm this is intended.
            grad_input = grad_output.clamp(min=-1, max=1)
        else:
            bound = 2 ** (ctx.num_bits - 1) - 1
            # Zero the gradient wherever quantization clipped the input
            keep = input.abs().le(bound * s)
            grad_input = grad_output * keep.to(grad_output.dtype)
        return grad_input, None, None


class PrunedQuantLinear(nn.Linear):
    """nn.Linear with magnitude pruning followed by fake quantization.

    Forward pass:
      1. Zero the ``int(numel * prune_ratio)`` smallest-magnitude weights
         (ties with the k-th smallest are zeroed too). ``self.weight`` itself
         is not modified.
      2. Optionally fake-quantize the input with ``act_bits``.
      3. Fake-quantize the pruned weight with ``weight_bits``; the scale is
         computed from the pruned weight.
    Scales: mean |x| for 1 bit, (max - min) / (2 * (2**(p-1) - 1)) for p bits.
    Quantization is delegated to ``SymQuantSTE``.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        weight_bits: bit width for the weight.
        act_bits: bit width for the input activation (``None`` = off).
        prune_ratio: fraction of weights to zero, in [0, 1].

    Raises:
        ValueError: if ``prune_ratio`` is outside [0, 1].
    """

    def __init__(self, in_features, out_features, bias=True,
                 weight_bits=8, act_bits=None, prune_ratio=0.0):
        # Ratios > 1 would otherwise crash later inside forward; negative
        # ratios would silently disable pruning.
        if not 0.0 <= prune_ratio <= 1.0:
            raise ValueError(f'prune_ratio must be in [0, 1], got {prune_ratio}')
        super().__init__(in_features, out_features, bias)
        self.weight_bits = weight_bits
        self.act_bits = act_bits
        self.prune_ratio = prune_ratio

    def _pruned_weight(self):
        """Return ``self.weight`` with the smallest-magnitude fraction masked to zero."""
        k = int(self.weight.numel() * self.prune_ratio)
        if k <= 0:
            return self.weight
        # k-th smallest |w| is the threshold; the mask needs no gradient
        magnitude = self.weight.detach().abs()
        threshold = torch.kthvalue(magnitude.flatten(), k).values
        mask = (magnitude > threshold).to(self.weight.dtype)
        return self.weight * mask

    @staticmethod
    def _scale(tensor, bits):
        """Return the symmetric quantization scale of *tensor* for *bits* bits."""
        if bits == 1:
            # NOTE(review): for pruned weights this mean includes the zeroed
            # entries, shrinking the scale by (kept / total); averaging over
            # only the kept weights may fit them better. Left unchanged to
            # preserve the experiment.
            return tensor.abs().sum() / tensor.numel()
        qmax = 2 ** (bits - 1) - 1
        return (tensor.max() - tensor.min()) / (2 * qmax)

    def forward(self, input):
        pruned_weight = self._pruned_weight()
        weight_scale = self._scale(pruned_weight, self.weight_bits)

        if self.act_bits is not None:
            act_scale = self._scale(input, self.act_bits)
            input = SymQuantSTE.apply(input, act_scale, self.act_bits)

        w_q = SymQuantSTE.apply(pruned_weight, weight_scale, self.weight_bits)
        return F.linear(input, w_q, self.bias)


def replace_linear_with_prunedquantlinear(module, weight_bits=8, act_bits=None, prune_ratio=0.2):
    """Recursively replace every ``nn.Linear`` in *module* with ``PrunedQuantLinear``.

    *module* is modified in place and also returned. Children that already are
    ``PrunedQuantLinear`` are left untouched. Each replacement copies the
    original weight/bias and is placed on the original layer's device and
    dtype (and training mode), so the model never ends up split across devices.

    Args:
        module: root module to rewrite.
        weight_bits: weight bit width for the new layers.
        act_bits: activation bit width for the new layers (``None`` = off).
        prune_ratio: fraction of each layer's weights to zero
            (0.2 zeroes 20% of the weights).

    Returns:
        *module* itself.
    """
    for name, child in module.named_children():
        if isinstance(child, PrunedQuantLinear):
            continue
        if not isinstance(child, nn.Linear):
            # Recurse into containers
            replace_linear_with_prunedquantlinear(
                child, weight_bits=weight_bits, act_bits=act_bits, prune_ratio=prune_ratio
            )
            continue
        pqlinear = PrunedQuantLinear(
            child.in_features,
            child.out_features,
            bias=(child.bias is not None),
            weight_bits=weight_bits,
            act_bits=act_bits,
            prune_ratio=prune_ratio,
        ).to(device=child.weight.device, dtype=child.weight.dtype)
        pqlinear.train(child.training)
        # Copy the trained parameters without recording autograd history
        with torch.no_grad():
            pqlinear.weight.copy_(child.weight)
            if child.bias is not None:
                pqlinear.bias.copy_(child.bias)
        setattr(module, name, pqlinear)
    return module


# Build a fresh model and warm it up with ordinary float training
model = SimpleModel().to(device)
print('warming up by no-quantized training...')
model = train(model, lr=0.1, epochs=5, device=device)
# Replace nn.Linear layers with PrunedQuantLinear (1-bit weights,
# 8-bit activations, 50% pruning); this modifies `model` in place
model_pq = replace_linear_with_prunedquantlinear(model, weight_bits=1, act_bits=8, prune_ratio=0.5)
# This phase trains with pruning AND quantization, not pruning alone
print('pruning + quantization aware training...')
model_pq = train(model_pq, lr=1e-3, epochs=5, device=device)
# Evaluate on the configured device (the function's default is 'cuda:0')
accuracy = compute_accuracy(model_pq, test_loader, device=device)

## 課題
### ・活性化（アクティベーション）のビット幅と重みのビット幅をそれぞれ変化させたときの精度変化を観察し、両者の感度の違いを評価する
### ・低ビットでもより高い精度を達成するには、量子化パラメータをどのように設定するとよいかを考察・実験する