In [7]:
%load_ext autoreload
%autoreload 2

import os

# Show where the kernel started so path problems are easy to diagnose.
print("original dir: ", os.getcwd())

# When the kernel was started inside a folder whose path ends with "NewMethod",
# move one level up — presumably so that the relative paths used later
# (./data, ./base_model) and the project-local imports resolve; confirm.
if os.getcwd().endswith("NewMethod"):
    new_path = "../"
    os.chdir(new_path)
    print("changed dir: ", os.getcwd())
    
import torch
import torch.nn
from torch import optim
from torch.nn import CrossEntropyLoss
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from tqdm import tqdm


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
original dir:  d:\我的\大学\3春\学业\创新实践\repo\Nonlinear-Erasure-Code\src


In [8]:
import datetime

# Run-level configuration. "MODEL" and "TASK" are later combined into the path of
# the pretrained base-model checkpoint; "DATE" is the day this cell was executed.
TASK_CONFIG = {
    "TASK": "CIFAR10",  # ARGS
    "DATE": datetime.datetime.now().strftime("%Y_%m_%d"),
    "MODEL": "LeNet9",
}

读取数据集

In [14]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set up the data transform: convert to a tensor, then normalize with mean 0.5 / std 0.5.
# NOTE(review): a single-element mean/std is used for every dataset, including the
# 3-channel CIFAR10 images — presumably it is applied to all channels; confirm.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Set up the datasets (training set and test set)

"""
MNIST:
image: (1, 28, 28), label: (0-9)

FashionMNIST:
image: (1, 28, 28), label: (0-9)

CIFAR10:
image: (3, 32, 32), label: (0-9)
"""

print(f"当前任务为 {TASK_CONFIG['TASK']}")

# ARGS
# Exactly one of the dataset blocks below should be active; the commented ones are
# the alternatives for the other tasks.

# train_dataset = datasets.MNIST(
#     root="./data", train=True, download=True, transform=transform
# )
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# test_dataset = datasets.MNIST(
#     root="./data", train=False, download=True, transform=transform
# )
# test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# train_dataset = datasets.FashionMNIST(
#     root="./data", train=True, download=True, transform=transform
# )
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# test_dataset = datasets.FashionMNIST(
#     root="./data", train=False, download=True, transform=transform
# )
# test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Active task: CIFAR10. The data is stored under ./data and downloaded when missing.
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform,
)
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform,
)
# The training loader shuffles and uses batches of 128; the test loader keeps the
# dataset order and uses batches of 64.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("Data is ready!")

当前任务为 CIFAR10
Files already downloaded and verified
Files already downloaded and verified
Data is ready!


设置部分参数

In [15]:
# ARGS
# K: number of pieces each image (and the conv output) is split into.
K = 2
# R: number of extra inputs the encoder produces from the K pieces.
R = 1
# N: total number of inputs pushed through the conv segment per image (K pieces + R encoded).
N = K + R
# Shape of one sample, taken from the first training item.
original_data_shape = tuple(train_dataset[0][0].shape)
num_classes = 10
print(f"K: {K}")
print(f"R: {R}")
print(f"N: {N}")
print(f"data_shape: {original_data_shape}")
print(f"num_classes: {num_classes}")

K: 2
R: 1
N: 3
data_shape: (3, 32, 32)
num_classes: 10


定义 base model

In [16]:
import torch

from base_model.MyModel1 import MyModel1
from base_model.LeNet5 import LeNet5
from base_model.LeNet9 import LeNet9

# Instantiate the base model; this model is used throughout the rest of the notebook.
# ResNet
# Only LeNet9 is wired up here, so fail early if the configuration asks for another model.
assert TASK_CONFIG["MODEL"] == "LeNet9"
model = LeNet9(input_dim=original_data_shape, num_classes=num_classes)

fc_input_size: 4096


In [17]:
# Load the pretrained base-model weights.
base_model_path = (
    f"./base_model/{TASK_CONFIG['MODEL']}/{TASK_CONFIG['TASK']}/model.pth"
)
print(f"base_model_path: {base_model_path}")

# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(base_model_path, map_location=device))
# Handles to the convolutional part and the fully connected part of the model
# (presumably sub-modules that share their weights with `model` — confirm).
conv_segment = model.get_conv_segment()
fc_segment = model.get_fc_segment()
model.to(device)
model.eval()

print("Model is ready!")

base_model_path: ./base_model/LeNet9/CIFAR10/model.pth
Model is ready!


验证 base model 准确率

In [18]:
# Accuracy check of the pretrained base model on the training set and the test set.
def _count_correct(net, loader):
    """Return ``(correct, total)`` for ``net`` over every batch of ``loader``.

    The network is put into evaluation mode and run without gradient tracking;
    the predicted class of a sample is the index of its largest output.
    """
    net.eval()  # evaluation mode
    n_correct = 0
    n_total = 0
    with torch.no_grad():  # no gradients are needed while evaluating
        for data, target in loader:
            # move the batch to the device the model lives on
            data, target = data.to(device), target.to(device)
            predicted = torch.max(net(data), 1)[1]
            n_total += target.size(0)
            n_correct += (predicted == target).sum().item()
    return n_correct, n_total


correct, total = _count_correct(model, train_loader)
print(f"训练集-> 总量: {total}, 正确数量: {correct}, 准确率: {100 * correct / total}%")

correct, total = _count_correct(model, test_loader)
print(f"测试集-> 总量: {total}, 正确数量: {correct}, 准确率: {100 * correct / total}%")

训练集-> 总量: 50000, 正确数量: 43124, 准确率: 86.248%
测试集-> 总量: 10000, 正确数量: 7432, 准确率: 74.32%


验证 model 和 conv_segment, fc_segment 的输出是否一致

In [19]:
# Sanity check: running the conv segment, flattening, and then running the fc
# segment should reproduce the output of the full model on the same random input.
x = torch.randn(1, *original_data_shape).to(device)
model.to(device)
y = model(x)
print(y.data)

z = conv_segment(x)
z = z.flatten(1)  # keep the batch axis, flatten everything else for the fc segment
z = fc_segment(z)
print(z.data)
# True when both paths agree within torch.allclose's default tolerances.
print(torch.allclose(y, z))

tensor([[ 1.0358,  1.7525, -2.2849, -2.7033, -2.7184, -4.0310, -0.6132, -3.9575,
          0.8465,  0.6749]], device='cuda:0')
tensor([[ 1.0358,  1.7525, -2.2849, -2.7033, -2.7184, -4.0310, -0.6132, -3.9575,
          0.8465,  0.6749]], device='cuda:0')
True


设置另一部分参数

In [20]:
from util.util import cal_input_shape


# Shape of the conv segment's output for one full-size input.
conv_output_shape = model.calculate_conv_output(input_dim=original_data_shape)
print(f"conv_output_shape: {conv_output_shape}")
# The output is cut into K equal pieces along its last axis, so that axis must divide evenly.
assert conv_output_shape[2] % K == 0

split_conv_output_shape = (
    conv_output_shape[0],
    conv_output_shape[1],
    conv_output_shape[2] // K,
)
print(f"split_conv_output_shape: {split_conv_output_shape}")

# NOTE(review): the conv segment is moved to the CPU and switched to train mode here and
# is not switched back in this cell; later cells must call .to(device) / .eval() themselves.
# Presumably cal_input_shape needs this state — confirm.
conv_segment.to('cpu')
conv_segment.train()
# One 4-tuple per split; the code below and the later cells read each entry as
# (dim0, dim1, start, end), where start:end is a slice of the last image axis.
split_data_range = cal_input_shape(
    model=conv_segment,
    original_input_shape=original_data_shape,
    original_output_shape=conv_output_shape,
    split_num=K,
)
print(f"split_data_range: {split_data_range}")

# print(conv_segment)
# Push a random input of each split's size through the conv segment to show the
# output shape each split produces.
print(
    f"split_conv_output_data_shape from split_data_shape: {[tuple(conv_segment(torch.randn(1, _[0], _[1], _[3] - _[2])).shape) for _ in split_data_range]}"
)

# Input shape of every split: (dim0, dim1, end - start).
split_data_shapes = [
    (
        _[0],
        _[1],
        _[3] - _[2],
    )
    for _ in split_data_range
]
print(f"split_data_shapes: {split_data_shapes}")

# The splits need not have equal widths; the first one is taken as the shape the
# encoder is built for.
split_data_shape = split_data_shapes[0]
print(f"choose the first one as the split_data_shape: {split_data_shape}")

conv_output_shape: (256, 4, 4)
split_conv_output_shape: (256, 4, 2)
split_data_range: [(3, 32, 0, 27), (3, 32, 8, 32)]
split_conv_output_data_shape from split_data_shape: [(1, 256, 4, 2), (1, 256, 4, 2)]
split_data_shapes: [(3, 32, 27), (3, 32, 24)]
choose the first one as the split_data_shape: (3, 32, 27)


验证分割后的输入，能够恰好恢复出原始输出

In [21]:
# Check that running the conv segment on each input slice and concatenating the
# results along the last axis reproduces the conv segment's output on the full input.
#
# The check is run in evaluation mode and without autograd: an earlier cell leaves
# conv_segment in train mode, in which layers such as Dropout or BatchNorm (if the
# segment contains any) behave differently per call and would make the comparison
# meaningless. Every later cell uses conv_segment in eval mode as well.
conv_segment.to(device)
conv_segment.eval()

with torch.no_grad():
    x = torch.randn(1, *original_data_shape).to(device)
    y = conv_segment(x)

    # Each entry of split_data_range is (dim0, dim1, start, end); start:end slices the last axis.
    x_split = [x[:, :, :, start:end] for _c, _h, start, end in split_data_range]
    y_split = [conv_segment(_x) for _x in x_split]
    y_hat = torch.cat(y_split, dim=3)

print(f"y.shape: {y.shape}")
print(f"y_split.shape: {[tuple(_y.shape) for _y in y_split]}")
print(f"y_hat.shape: {y_hat.shape}")

# |A-B| <= atol + rtol * |B|
print(f"y和y_hat是否相等: {torch.allclose(y_hat, y, rtol=1e-08, atol=1e-05)}")

# Second, looser comparison with a plain absolute tolerance.
diff = torch.abs(y_hat - y)
epsilon = 0.0001
print(f"y和y_hat是否相等: {torch.all(diff <= epsilon)}")
# print(torch.allclose(y_split[0], y[:, :, :, 0:5]))
# print(torch.allclose(y_split[1], y[:, :, :, 5:10]))
# print(torch.allclose(y_split[2], y[:, :, :, 10:15]))
# print(torch.allclose(y_split[3], y[:, :, :, 15:20]))

# print(y[0][0][0] == y_hat[0][0][0])
# print(y[0][0][0])
# print(y_hat[0][0][0])
# y = x
# y_split = x_split
# for layer in conv_segment:
#     print(layer)
#     y = layer(y)
#     y_split = [layer(_y) for _y in y_split]
#     print(f"y.shape: {y.shape}")
#     print(f"y_split.shape: {[tuple(_.shape) for _ in y_split]}")
#     print(y[0][0][0])
#     print(y_split[0][0][0][0])
#     print(y_split[1][0][0][0])
#     print(y_split[2][0][0][0])
#     print(y_split[3][0][0][0])

y.shape: torch.Size([1, 256, 4, 4])
y_split.shape: [(1, 256, 4, 2), (1, 256, 4, 2)]
y_hat.shape: torch.Size([1, 256, 4, 4])
y和y_hat是否相等: False
y和y_hat是否相等: False


定义 Encoder Decoder

In [22]:
from encoder.mlp_encoder import MLPEncoder
from encoder.conv_encoder import CatChannelConvEncoder, CatBatchSizeConvEncoder
from decoder.mlp_decoder import MLPDecoder
from decoder.conv_decoder import CatChannelConvDecoder, CatBatchSizeConvDecoder

print(f"split_data_shape: {split_data_shape}")
print(f"split_conv_output_shape: {split_conv_output_shape}")

# ARGS
# Exactly one encoder/decoder pair should be active; the commented pairs are alternatives.
# The encoder takes K input pieces and produces R extra ones; the decoder takes the
# N conv outputs and produces K pieces.

# encoder = MLPEncoder(num_in=K, num_out=R, in_dim=split_data_shape)
# decoder = MLPDecoder(num_in=N, num_out=K, in_dim=split_conv_output_shape)

encoder = CatChannelConvEncoder(num_in=K, num_out=R, in_dim=split_data_shape)
decoder = CatChannelConvDecoder(
    num_in=N, num_out=K, in_dim=split_conv_output_shape
)

# NOTE(review): the commented decoder below refers to split_conv_output_data_shape,
# which is not defined anywhere in the visible cells.
# encoder = CatBatchSizeConvEncoder(num_in=K, num_out=R, in_dim=split_data_shape)
# decoder = CatBatchSizeConvDecoder(num_in=N, num_out=K, in_dim=split_conv_output_data_shape)

split_data_shape: (3, 32, 27)
split_conv_output_shape: (256, 4, 2)


In [23]:
# print(torch.cuda.memory_summary())

In [24]:
def getModelSize(model):
    """Print the in-memory size of ``model`` and return the detailed counts.

    Both the parameters and the registered buffers are counted.

    Returns:
        A tuple ``(param_size, param_sum, buffer_size, buffer_sum, all_size)``:
        bytes and element count of the parameters, bytes and element count of
        the buffers, and the combined size in MB (bytes / 1024 / 1024).
    """

    def _tally(tensors):
        # Accumulate (bytes, element count) over an iterable of tensors.
        byte_total = 0
        elem_total = 0
        for t in tensors:
            count = t.nelement()
            elem_total += count
            byte_total += count * t.element_size()
        return byte_total, elem_total

    param_size, param_sum = _tally(model.parameters())
    buffer_size, buffer_sum = _tally(model.buffers())

    all_size = (param_size + buffer_size) / 1024 / 1024
    print("模型总大小为：{:.3f}MB".format(all_size))
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)

# Report the size of the encoder, then of the decoder (each call prints one line).
getModelSize(encoder)
getModelSize(decoder)
print()

模型总大小为：0.284MB
模型总大小为：2.321MB



训练 Encoder Decoder

In [29]:
epoch_num = 10  # ARGS
print(f"epoch_num: {epoch_num}")

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm

print(f"Train dataset: {len(train_dataset)}")
print("image size: ", train_dataset[0][0].size())

# Loss: mean squared error between the decoded, re-assembled feature map and the
# feature map the conv segment produces for the whole image.
criterion = nn.MSELoss()

# optimizer_encoder = optim.SGD(encoder.parameters(), lr=1e-3, momentum=0.8, weight_decay=1e-5)
# optimizer_decoder = optim.SGD(decoder.parameters(), lr=1e-3, momentum=0.8, weight_decay=1e-5)

optimizer_encoder = optim.Adam(encoder.parameters(), lr=1e-4, weight_decay=1e-6)
optimizer_decoder = optim.Adam(decoder.parameters(), lr=1e-4, weight_decay=1e-6)

model.to(device)
conv_segment.to(device)
fc_segment.to(device)
encoder.to(device)
decoder.to(device)

# Only the encoder and the decoder have optimizers; the base model and its two
# segments stay in evaluation mode for the whole run.
model.eval()
conv_segment.eval()
fc_segment.eval()
encoder.train()
decoder.train()

# NOTE(review): kept from the original. Confirm that toggling this flag on
# already-constructed BatchNorm2d layers has the intended effect while the model is in eval mode.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False

# Per-epoch list of the per-batch loss values.
loss_list = [[] for _ in range(epoch_num)]

# The input slices need not have the same width, but every piece handed to the
# encoder must. Narrower pieces are zero-padded on the right up to the widest
# slice (this replaces a hard-coded pad of 3 on the last piece, which only
# matched one particular split configuration).
split_widths = [end - start for _c, _h, start, end in split_data_range]
target_width = max(split_widths)

for epoch in range(epoch_num):
    train_loader_tqdm = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{epoch_num}",
        bar_format="{l_bar}{bar:20}{r_bar}",
    )
    correct = 0
    correct_truth = 0
    total = 0
    for images, labels in train_loader_tqdm:
        images = images.to(device)
        labels = labels.to(device)

        # Split the batch along the last axis into K overlapping slices of equal
        # (padded) width: (B, C, H, W) -> K x (B, C, H, target_width).
        images_list = []
        for (_c, _h, start, end), width in zip(split_data_range, split_widths):
            piece = images[:, :, :, start:end].clone()
            if width < target_width:
                piece = F.pad(
                    piece, (0, target_width - width, 0, 0), "constant", value=0
                )
            images_list.append(piece)

        # Regression target: conv output of the unsplit image. It is a constant
        # with respect to the encoder/decoder, so no autograd graph is built for it.
        with torch.no_grad():
            ground_truth = conv_segment(images)
            ground_truth = ground_truth.view(ground_truth.size(0), -1)

        # forward: K pieces + R encoded pieces -> conv segment -> decoder -> re-assembled output
        images_list += encoder(images_list)
        output_list = []
        for i in range(N):
            output = conv_segment(images_list[i])
            output_list.append(output)
        # losed_output_list = lose_something(output_list, self.lose_device_index)
        decoded_output_list = decoder(output_list)
        output = torch.cat(decoded_output_list, dim=3)
        output = output.view(output.size(0), -1)

        loss = criterion(output, ground_truth)

        loss_list[epoch].append(loss.item())

        # backward
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        # Accuracy bookkeeping only — no gradients needed.
        with torch.no_grad():
            predicted = torch.max(fc_segment(output), 1)[1]
            predicted_truth = torch.max(fc_segment(ground_truth), 1)[1]
        correct += (predicted == labels).sum().item()
        correct_truth += (predicted_truth == labels).sum().item()
        total += labels.size(0)

        train_loader_tqdm.set_postfix(loss=loss.item())

    print(f"Original Accuracy: {100 * correct_truth / total}%, Train Accuracy: {100 * correct / total}%")


epoch_num: 10
Train dataset: 50000
image size:  torch.Size([3, 32, 32])


Epoch 1/10: 100%|████████████████████| 391/391 [02:55<00:00,  2.22it/s, loss=0.0339]


Original Accuracy: 86.248%, Train Accuracy: 62.364%


Epoch 2/10: 100%|████████████████████| 391/391 [02:55<00:00,  2.22it/s, loss=0.0302]


Original Accuracy: 86.248%, Train Accuracy: 62.866%


Epoch 3/10: 100%|████████████████████| 391/391 [02:56<00:00,  2.22it/s, loss=0.03]  


Original Accuracy: 86.248%, Train Accuracy: 63.35%


Epoch 4/10: 100%|████████████████████| 391/391 [02:56<00:00,  2.22it/s, loss=0.0289]


Original Accuracy: 86.248%, Train Accuracy: 63.922%


Epoch 5/10: 100%|████████████████████| 391/391 [02:56<00:00,  2.22it/s, loss=0.0326]


Original Accuracy: 86.248%, Train Accuracy: 64.51%


Epoch 6/10: 100%|████████████████████| 391/391 [02:56<00:00,  2.21it/s, loss=0.0303]


Original Accuracy: 86.248%, Train Accuracy: 64.952%


Epoch 7/10: 100%|████████████████████| 391/391 [02:56<00:00,  2.22it/s, loss=0.0294]


Original Accuracy: 86.248%, Train Accuracy: 65.378%


Epoch 8/10: 100%|████████████████████| 391/391 [02:56<00:00,  2.22it/s, loss=0.0281]


Original Accuracy: 86.248%, Train Accuracy: 65.666%


Epoch 9/10: 100%|████████████████████| 391/391 [02:56<00:00,  2.21it/s, loss=0.0322]


Original Accuracy: 86.248%, Train Accuracy: 66.032%


Epoch 10/10: 100%|████████████████████| 391/391 [02:55<00:00,  2.22it/s, loss=0.0292]

Original Accuracy: 86.248%, Train Accuracy: 66.348%





Evaluation

In [30]:
def lose_something(output_list, lose_num):
    """Simulate the loss of ``lose_num`` randomly chosen outputs.

    Args:
        output_list: list of tensors (one per worker output).
        lose_num: how many entries to erase. ``0`` returns ``output_list``
            itself, unchanged. Values of ``len(output_list)`` or more erase
            every entry.

    Returns:
        A list of the same length in which the chosen entries are replaced by
        all-zero tensors of the same shape/dtype/device; the remaining entries
        are the original tensor objects.

    Raises:
        ValueError: if ``lose_num`` is negative (a negative value would
            otherwise be interpreted as a slice bound and silently erase
            ``len(output_list) - |lose_num|`` entries).
    """
    if lose_num < 0:
        raise ValueError(f"lose_num must be non-negative, got {lose_num}")
    if lose_num == 0:
        return output_list

    # Draw the erased positions once and keep them as plain ints so the
    # membership test below is an ordinary set lookup rather than a tensor
    # comparison on every iteration.
    lose_index = set(torch.randperm(len(output_list))[:lose_num].tolist())

    return [
        torch.zeros_like(out) if i in lose_index else out
        for i, out in enumerate(output_list)
    ]

In [31]:
import torch

from tqdm import tqdm

from dataset.image_dataset import ImageDataset
from util.split_data import split_vector

# Put every module on the target device ...
conv_segment.to(device)
fc_segment.to(device)
model.to(device)
encoder.to(device)
decoder.to(device)

# ... and into evaluation mode (the training cell leaves encoder/decoder in train mode).
conv_segment.eval()
fc_segment.eval()
model.eval()
encoder.eval()
decoder.eval()

def evaluation(loader, loss_num):
    """Evaluate the base model and the encoder/decoder pipeline on ``loader``.

    For every batch three predictions are compared with the labels:
      1. the full ``model``;
      2. ``conv_segment`` followed by ``fc_segment``;
      3. the coded pipeline: split the images, append the encoder outputs, run
         ``conv_segment`` on all N inputs, erase ``loss_num`` of the N outputs
         with ``lose_something``, decode, re-assemble, and apply ``fc_segment``.

    The sample count and the three accuracies are printed; nothing is returned.
    The modules are used in whatever train/eval mode they are currently in.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        loss_num: number of conv outputs erased before decoding.
    """
    # Imported here so the function does not depend on the training cell
    # (the only other place that imports it) having been executed.
    import torch.nn.functional as F

    # The input slices need not have the same width, but every piece handed to
    # the encoder must. Narrower pieces are zero-padded on the right up to the
    # widest slice (this replaces a hard-coded pad of 3 on the last piece).
    split_widths = [end - start for _, _, start, end in split_data_range]
    target_width = max(split_widths)

    original_correct = 0
    merge_correct = 0
    correct = 0
    total = 0
    with torch.no_grad():
        loader_tqdm = tqdm(
            loader,
            desc="Evaluating...",
            bar_format="{l_bar}{bar:20}{r_bar}",
        )
        for images, labels in loader_tqdm:
            images = images.to(device)
            labels = labels.to(device)

            # (B, C, H, W) -> K x (B, C, H, target_width)
            images_list = []
            for (_, _, start, end), width in zip(split_data_range, split_widths):
                piece = images[:, :, :, start:end].clone()
                if width < target_width:
                    piece = F.pad(
                        piece, (0, target_width - width, 0, 0), "constant", value=0
                    )
                images_list.append(piece)

            # 1. full model
            predicted = torch.max(model(images), 1)[1]
            original_correct += (predicted == labels).sum().item()

            # 2. conv segment + fc segment
            output = conv_segment(images)
            output = output.view(output.size(0), -1)
            output = fc_segment(output)
            predicted = torch.max(output, 1)[1]
            merge_correct += (predicted == labels).sum().item()

            # 3. coded pipeline. Each input is wrapped in ImageDataset and read back
            # through its `.images` attribute, as in the original code.
            imageDataset_list = [
                ImageDataset(piece) for piece in images_list + encoder(images_list)
            ]
            output_list = []
            for i in range(N):
                imageDataset = imageDataset_list[i]
                output = conv_segment(imageDataset.images)
                output_list.append(output)
            losed_output_list = lose_something(output_list, loss_num)
            decoded_output_list = decoder(losed_output_list)
            output = torch.cat(decoded_output_list, dim=3)
            output = output.view(output.size(0), -1)

            predicted = torch.max(fc_segment(output), 1)[1]
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    print(f"样本总数: {total}")
    if total == 0:
        # Empty loader: there is no accuracy to report (and the percentages
        # below would divide by zero).
        return
    print(
        f"原始模型(model) -> 预测正确数: {original_correct}, 预测准确率: {(100 * original_correct / total):.2f}%"
    )
    print(
        f"原始模型(conv+fc) -> 预测正确数: {merge_correct}, 预测准确率: {(100 * merge_correct / total):.2f}%"
    )
    print(
        f"使用Encoder和Decoder -> 预测正确数: {correct}, 预测准确率: {(100 * correct / total):.2f}%"
    )

In [32]:
# Training set: evaluate with 0, 1, ..., N erased conv outputs.
print("训练")
for i in range(N + 1):
    print(f"loss_num: {i}")
    evaluation(train_loader, i)


训练
loss_num: 0


Evaluating...: 100%|████████████████████| 391/391 [01:11<00:00,  5.43it/s]


样本总数: 50000
原始模型(model) -> 预测正确数: 43124, 预测准确率: 86.25%
原始模型(conv+fc) -> 预测正确数: 43124, 预测准确率: 86.25%
使用Encoder和Decoder -> 预测正确数: 33135, 预测准确率: 66.27%
loss_num: 1


Evaluating...: 100%|████████████████████| 391/391 [01:11<00:00,  5.48it/s]


样本总数: 50000
原始模型(model) -> 预测正确数: 43124, 预测准确率: 86.25%
原始模型(conv+fc) -> 预测正确数: 43124, 预测准确率: 86.25%
使用Encoder和Decoder -> 预测正确数: 28507, 预测准确率: 57.01%
loss_num: 2


Evaluating...:  87%|█████████████████▍  | 340/391 [01:02<00:09,  5.43it/s]


KeyboardInterrupt: 

In [None]:
# Test set: evaluate with 0, 1, ..., N erased conv outputs.
print("测试")
for i in range(N + 1):
    print(f"loss_num: {i}")
    evaluation(test_loader, i)

loss_num: 0


Evaluating...: 100%|████████████████████| 157/157 [00:13<00:00, 11.90it/s]


样本总数: 10000
原始模型(model) -> 预测正确数: 7432, 预测准确率: 74.32%
原始模型(conv+fc) -> 预测正确数: 7432, 预测准确率: 74.32%
使用Encoder和Decoder -> 预测正确数: 7439, 预测准确率: 74.39%
loss_num: 1


Evaluating...: 100%|████████████████████| 157/157 [00:13<00:00, 11.71it/s]


样本总数: 10000
原始模型(model) -> 预测正确数: 7432, 预测准确率: 74.32%
原始模型(conv+fc) -> 预测正确数: 7432, 预测准确率: 74.32%
使用Encoder和Decoder -> 预测正确数: 7129, 预测准确率: 71.29%
loss_num: 2


Evaluating...: 100%|████████████████████| 157/157 [00:13<00:00, 11.65it/s]


样本总数: 10000
原始模型(model) -> 预测正确数: 7432, 预测准确率: 74.32%
原始模型(conv+fc) -> 预测正确数: 7432, 预测准确率: 74.32%
使用Encoder和Decoder -> 预测正确数: 6631, 预测准确率: 66.31%
loss_num: 3


Evaluating...: 100%|████████████████████| 157/157 [00:13<00:00, 11.62it/s]


样本总数: 10000
原始模型(model) -> 预测正确数: 7432, 预测准确率: 74.32%
原始模型(conv+fc) -> 预测正确数: 7432, 预测准确率: 74.32%
使用Encoder和Decoder -> 预测正确数: 6055, 预测准确率: 60.55%
loss_num: 4


Evaluating...: 100%|████████████████████| 157/157 [00:13<00:00, 11.68it/s]


样本总数: 10000
原始模型(model) -> 预测正确数: 7432, 预测准确率: 74.32%
原始模型(conv+fc) -> 预测正确数: 7432, 预测准确率: 74.32%
使用Encoder和Decoder -> 预测正确数: 5064, 预测准确率: 50.64%
loss_num: 5


Evaluating...: 100%|████████████████████| 157/157 [00:13<00:00, 11.74it/s]


样本总数: 10000
原始模型(model) -> 预测正确数: 7432, 预测准确率: 74.32%
原始模型(conv+fc) -> 预测正确数: 7432, 预测准确率: 74.32%
使用Encoder和Decoder -> 预测正确数: 3385, 预测准确率: 33.85%
loss_num: 6


Evaluating...: 100%|████████████████████| 157/157 [00:13<00:00, 11.75it/s]

样本总数: 10000
原始模型(model) -> 预测正确数: 7432, 预测准确率: 74.32%
原始模型(conv+fc) -> 预测正确数: 7432, 预测准确率: 74.32%
使用Encoder和Decoder -> 预测正确数: 1000, 预测准确率: 10.00%



