In [1]:
import os 
import time
import warnings

import numpy as np
from sklearn.cluster import KMeans

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchsummary import summary

warnings.filterwarnings("ignore")

In [2]:
EPOCHS = 20
BATCH_SIZE = 128
LEARNING_RATE = 0.001

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load data

In [3]:
train_transforms = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465],
                std=[0.2023, 0.1994, 0.2010])
])

val_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465],
                std=[0.2023, 0.1994, 0.2010])
])

In [4]:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transforms)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


# Init model

In [5]:
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.block1 = self._make_conv_block(in_ch=3,   out_ch=64)
        self.block2 = self._make_conv_block(in_ch=64,  out_ch=128)
        self.block3 = self._make_conv_block(in_ch=128, out_ch=256)
        self.block4 = self._make_conv_block(in_ch=256, out_ch=512)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 2 * 2, 2048),  # из 512x2x2 -> 2048
            nn.ReLU(inplace=True),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes)
        )

    def _make_conv_block(self, in_ch, out_ch):
        block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=2)
        )
        return block

    def forward(self, x):
        x = self.block1(x) 
        x = self.block2(x)  
        x = self.block3(x)  
        x = self.block4(x) 

        x = x.view(x.size(0), -1)  
        x = self.classifier(x)
        return x

model = CNN(num_classes=10)
model.load_state_dict(torch.load('./model.pt', weights_only=True))

_ = summary(model, input_size=(BATCH_SIZE, 3, 32, 32), device=DEVICE, depth=4)

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Conv2d: 2-1                       1,728
|    └─BatchNorm2d: 2-2                  128
|    └─ReLU: 2-3                         --
|    └─Conv2d: 2-4                       36,864
|    └─BatchNorm2d: 2-5                  128
|    └─ReLU: 2-6                         --
|    └─MaxPool2d: 2-7                    --
├─Sequential: 1-2                        --
|    └─Conv2d: 2-8                       73,728
|    └─BatchNorm2d: 2-9                  256
|    └─ReLU: 2-10                        --
|    └─Conv2d: 2-11                      147,456
|    └─BatchNorm2d: 2-12                 256
|    └─ReLU: 2-13                        --
|    └─MaxPool2d: 2-14                   --
├─Sequential: 1-3                        --
|    └─Conv2d: 2-15                      294,912
|    └─BatchNorm2d: 2-16                 512
|    └─ReLU: 2-17                        --
|    └─Conv2d: 2-18                      589,

In [6]:
def cluster_and_save(model, num_clusters=16, outfile="compressed_model.npz"):
    compressed_dict = {}
    device = next(model.parameters()).device 

    with torch.no_grad():
        for name, param in model.named_parameters():
            w_np = param.cpu().numpy() 
            shape_ = w_np.shape
            w_flat = w_np.ravel().reshape(-1, 1)  

            effective_clusters = min(num_clusters, w_flat.shape[0])

            kmeans = KMeans(n_clusters=effective_clusters, n_init=5, random_state=42)
            kmeans.fit(w_flat)

            centroids = kmeans.cluster_centers_.flatten() 
            labels = kmeans.labels_  

            compressed_dict[f"{name}_shape"] = shape_
            compressed_dict[f"{name}_centroids"] = centroids.astype(np.float32)
            compressed_dict[f"{name}_labels"] = labels.astype(np.int32)
    np.savez_compressed(outfile, **compressed_dict)

In [7]:
cluster_and_save(model, 16)
os.path.getsize("compressed_model.npz")

7347355

In [8]:
def load_clustered_model(model_class, infile="compressed_model.npz", num_classes=10):
    data = np.load(infile)
    model = model_class(num_classes=num_classes)

    state_dict = model.state_dict()  

    for name, param in model.named_parameters():
        shape_ = data[f"{name}_shape"]
        centroids = data[f"{name}_centroids"]
        labels = data[f"{name}_labels"]

        w_rec = centroids[labels] 
        w_rec = w_rec.reshape(shape_)

        state_dict[name] = torch.from_numpy(w_rec)

    model.load_state_dict(state_dict)
    return model

In [9]:
model = load_clustered_model(CNN, infile="compressed_model.npz")

# Eval

In [10]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

params = count_parameters(model)
print("Params:", params)

Params: 10992074


In [11]:
inp = torch.randn(1, 3, 32, 32)

num_samples = 100
start_time = time.time()
for _ in range(num_samples):
    output = model(inp)
end_time = time.time()

infer_time = ((end_time - start_time) / num_samples) * 1000
print(f'CPU Avg inference time: {infer_time:.4f} ms')

CPU Avg inference time: 2.5893 ms


In [12]:
num_samples = 100

model.to(DEVICE)
start_time = time.time()
for _ in range(num_samples):
    output = model(inp.to(DEVICE))
end_time = time.time()

infer_time = ((end_time - start_time) / num_samples) * 1000
print(f'GPU Avg inference time: {infer_time:.4f} ms')

GPU Avg inference time: 1.1630 ms


In [13]:
correct = 0
total = 0
with torch.no_grad():
    for X_, y_ in val_loader:
        X_, y_ = X_.to(DEVICE), y_.to(DEVICE)
        outputs = model(X_)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_).sum().item()
        total += y_.size(0)

accuracy = 100.0 * correct / total
print("Accuracy:", accuracy)

Accuracy: 86.84
