In [1]:
# You only need to run this section once.

%load_ext autoreload
%autoreload 2

import os

print("original dir: ", os.getcwd())

if os.getcwd().endswith("NewMethod"):
    new_path = "../"
    os.chdir(new_path)
    print("changed dir: ", os.getcwd())
    
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

original dir:  e:\Nonlinear-Erasure-Code\src\NewMethod
changed dir:  e:\Nonlinear-Erasure-Code\src


In [3]:
import datetime

# Run-level identifiers. The path-building cell interpolates these values into
# the checkpoint and log locations.
TASK_CONFIG = dict(
    TASK="MNIST",
    DATE=f"{datetime.datetime.now():%Y_%m_%d}",  # today's date, e.g. "2024_03_05"
    MODEL="MyModel1",
)

读取数据集

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing pipeline: convert to a tensor, then normalize with
# mean 0.5 and std 0.5 on the single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST: image (1, 28, 28), 10 classes.
dataset_name = TASK_CONFIG["TASK"]

# Arguments shared by both splits; files are downloaded into ./data if missing.
_mnist_kwargs = dict(root="./data", download=True, transform=transform)

# Training split, reshuffled on every pass.
train_dataset = datasets.MNIST(train=True, **_mnist_kwargs)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Test split, iterated in a fixed order.
test_dataset = datasets.MNIST(train=False, **_mnist_kwargs)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("Data is ready!")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 10592724.11it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 508902.94it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3790476.63it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4350429.04it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Data is ready!





设置部分参数

In [12]:
import datetime

# K: number of slices the input is split into (passed as split_num later).
# R: number of extra outputs requested from the encoder (num_out=R later).
# N: total number of slices handed to the decoder (num_in=N later).
K = 4
R = 0
N = K + R
# Per-sample tensor shape, read from the first training example.
original_data_shape = tuple(train_dataset[0][0].size())
num_classes = 10
for label, value in (
    ("K", K),
    ("R", R),
    ("N", N),
    ("data_shape", original_data_shape),
    ("num_classes", num_classes),
):
    print(f"{label}: {value}")

# Pull the run identifiers once so the path templates below stay readable.
_task, _model_name, _date = (TASK_CONFIG[key] for key in ("TASK", "MODEL", "DATE"))
_stem = f"task_{_task}-basemodel_{_model_name}"

# The base-model checkpoint is pinned to a fixed date folder; the other paths
# follow the date stored in TASK_CONFIG.
base_model_path = f"./base_model/{_model_name}/{_task}/2024_03_05/model.pth"
encoder_path = f"./encoder/MLP/{_task}/{_date}/encoder-{_stem}-in_{K}-out_{R}.pth"
decoder_path = f"./decoder/MLP/{_task}/{_date}/decoder-{_stem}-in_{N}-out_{K}.pth"
save_dir = f"./save/{_task}/{_model_name}/{_date}/"
for label, value in (
    ("base_model_path", base_model_path),
    ("encoder_path", encoder_path),
    ("decoder_path", decoder_path),
    ("save_dir", save_dir),
):
    print(f"{label}: {value}")

# Number of passes over the training set for the encoder/decoder.
epoch_num = 4
print(f"epoch_num: {epoch_num}")

K: 4
R: 0
N: 4
data_shape: (1, 28, 28)
num_classes: 10
base_model_path: ./base_model/MyModel1/MNIST/2024_03_05/model.pth
encoder_path: ./encoder/MLP/MNIST/2024_03_05/encoder-task_MNIST-basemodel_MyModel1-in_4-out_0.pth
decoder_path: ./decoder/MLP/MNIST/2024_03_05/decoder-task_MNIST-basemodel_MyModel1-in_4-out_4.pth
save_dir: ./save/MNIST/MyModel1/2024_03_05/
epoch_num: 4


定义 base model

In [13]:
import torch

from base_model.MyModel1 import MyModel1

# Build the base model; the same instance is reused by every later stage.
assert TASK_CONFIG["MODEL"] == "MyModel1"
model = MyModel1(input_dim=original_data_shape, num_classes=num_classes)

# Restore the saved weights. map_location lets a checkpoint written on another
# device load onto the device chosen in the setup cell.
model.load_state_dict(torch.load(base_model_path, map_location=device))

conv_segment = model.get_conv_segment()
fc_segment = model.get_fc_segment()

# Sanity check: conv_segment -> flatten -> fc_segment must reproduce model(x).
# The probe is random noise, so run it in eval mode and without autograd:
# in train mode, normalization layers that track running statistics would be
# updated by this throwaway input, and stochastic layers would make the two
# paths disagree.
model.eval()
conv_segment.eval()
fc_segment.eval()

x = torch.randn(1, *original_data_shape)
with torch.no_grad():
    full_output = model(x)

    y = conv_segment(x)
    y = y.view(y.size(0), -1)  # flatten everything except the batch axis
    y = fc_segment(y)

print(full_output.data)
print(y.data)

# Fail loudly if the split pipeline ever stops matching the full model.
assert torch.allclose(full_output, y, atol=1e-5), (
    "conv_segment + fc_segment does not reproduce model(x)"
)

tensor([[ 0.4511,  2.8249,  4.0434, -2.8124,  2.7863, -0.4604,  0.9184,  1.5557,
          3.1938, -2.1021]])
tensor([[ 0.4511,  2.8249,  4.0434, -2.8124,  2.7863, -0.4604,  0.9184,  1.5557,
          3.1938, -2.1021]])


设置另一部分参数

In [38]:
from util.util import cal_input_shape


# Shape reported by the model for the conv stack's output on one full input.
conv_output_shape = model.calculate_conv_output(input_dim=original_data_shape)
print(f"conv_output_shape: {conv_output_shape}")

# The last axis is divided into K equal parts, so it has to divide evenly.
assert conv_output_shape[2] % K == 0

# Same leading two axes, last axis shrunk to one K-th.
split_conv_output_shape = (*conv_output_shape[:2], conv_output_shape[2] // K)
print(f"split_conv_output_shape: {split_conv_output_shape}")

# One entry per split. As used below and in the training cell, each entry is
# indexed as (dim0, dim1, start, end) with start/end along the last input axis.
split_data_range = cal_input_shape(
    model=conv_segment,
    original_input_shape=original_data_shape,
    original_output_shape=conv_output_shape,
    split_num=K,
)
print(f"split_data_range: {split_data_range}")

# print(conv_segment)
# Push a random input of each computed slice size through conv_segment and
# report the resulting output shapes.
probe_shapes = [
    tuple(conv_segment(torch.randn(1, rng[0], rng[1], rng[3] - rng[2])).shape)
    for rng in split_data_range
]
print(f"split_conv_output_data_shape from split_data_shape: {probe_shapes}")

# The first range is used as the representative slice description.
split_data_shape = split_data_range[0]

conv_output_shape: (64, 20, 20)
split_conv_output_shape: (64, 20, 5)
{'type': 'ReLU', 'layer': ReLU(inplace=True), 'input_shape': (64, 20, 20), 'output_shape': (64, 20, 20)}
(64, 20, 0, 5)
(64, 20, 0, 5)
--------------------------------------------------
{'type': 'BatchNorm2d', 'layer': BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False), 'input_shape': (64, 20, 20), 'output_shape': (64, 20, 20)}
(64, 20, 0, 5)
(64, 20, 0, 5)
--------------------------------------------------
{'type': 'Conv2d', 'layer': Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)), 'input_shape': (32, 22, 22), 'output_shape': (64, 20, 20)}
(64, 20, 0, 5)
(32, 22, 0, 7)
--------------------------------------------------
{'type': 'ReLU', 'layer': ReLU(inplace=True), 'input_shape': (32, 22, 22), 'output_shape': (32, 22, 22)}
(32, 22, 0, 7)
(32, 22, 0, 7)
--------------------------------------------------
{'type': 'BatchNorm2d', 'layer': BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True

定义 Encoder Decoder

In [42]:
from encoder.mlp_encoder import MLPEncoder
from encoder.conv_encoder import CatChannelConvEncoder, CatBatchSizeConvEncoder
from decoder.mlp_decoder import MLPDecoder
from decoder.conv_decoder import CatChannelConvDecoder, CatBatchSizeConvDecoder

# The encoder is built for K inputs and R outputs; the decoder for N inputs and
# K outputs. The decoder's in_dim is the per-slice conv output shape computed
# in the previous cell (`split_conv_output_shape`); the earlier name
# `split_conv_output_data_shape` was never assigned anywhere and raised a
# NameError on a fresh kernel.
# NOTE(review): in_dim semantics are inferred from how the shapes are computed;
# confirm against the encoder/decoder class definitions.

# Alternative architectures (swap in a matching encoder/decoder pair):
# encoder = MLPEncoder(num_in=K, num_out=R, in_dim=split_data_shape)
# decoder = MLPDecoder(num_in=N, num_out=K, in_dim=split_conv_output_shape)

encoder = CatChannelConvEncoder(num_in=K, num_out=R, in_dim=split_data_shape)
decoder = CatChannelConvDecoder(
    num_in=N, num_out=K, in_dim=split_conv_output_shape
)

# encoder = CatBatchSizeConvEncoder(num_in=K, num_out=R, in_dim=split_data_shape)
# decoder = CatBatchSizeConvDecoder(num_in=N, num_out=K, in_dim=split_conv_output_shape)

训练 Encoder Decoder

In [43]:
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
import datetime
import os

from dataset.image_dataset import ImageDataset
from util.split_data import split_vector

print(f"Train dataset: {len(train_dataset)}")
print("image size: ", train_dataset[0][0].size())

# Objective: MSE between the decoded conv features and the features the base
# model produces on the unsplit image. criterion2 is only referenced by the
# commented-out alternative objective further down.
criterion = nn.MSELoss()
criterion2 = nn.CrossEntropyLoss()
optimizer_encoder = optim.SGD(encoder.parameters(), lr=0.001, momentum=0.9)
optimizer_decoder = optim.SGD(decoder.parameters(), lr=0.001, momentum=0.9)

conv_segment.to(device)
fc_segment.to(device)
encoder.to(device)
decoder.to(device)

# Only the encoder and decoder are optimized; the base-model segments stay in
# eval mode throughout.
conv_segment.eval()
fc_segment.eval()
encoder.train()
decoder.train()

model.eval()
# NOTE(review): flips track_running_stats on every BatchNorm2d of the base
# model; presumably meant to stop running-stat updates -- confirm the intent.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False

# One list of per-batch losses for each epoch.
loss_list = [[] for _ in range(epoch_num)]

for epoch in range(epoch_num):
    train_loader_tqdm = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{epoch_num}",
        bar_format="{l_bar}{bar:20}{r_bar}",
    )
    correct = 0
    correct_truth = 0
    total = 0
    for images, labels in train_loader_tqdm:
        # Slice the batch along its last axis into K pieces, one per
        # (start, end) column range in split_data_range.
        images_list = [
            images[:, :, :, start:end].clone().to(device)
            for _, _, start, end in split_data_range
        ]
        labels = labels.to(device)
        images = images.to(device)

        # Reference features are a fixed regression target, so compute them
        # without autograd. Otherwise loss.backward() also back-propagates
        # through this branch into the base model for no benefit.
        with torch.no_grad():
            ground_truth = conv_segment(images)
            ground_truth = ground_truth.view(ground_truth.size(0), -1)

        # forward: append the encoder outputs to the K slices, run every
        # slice through the conv stack, then decode and stitch along dim 3.
        images_list += encoder(images_list)
        output_list = [conv_segment(images_list[i]) for i in range(N)]
        # losed_output_list = lose_something(output_list, self.lose_device_index)
        decoded_output_list = decoder(output_list)
        output = torch.cat(decoded_output_list, dim=3)
        output = output.view(output.size(0), -1)

        loss = criterion(output, ground_truth)
        # loss = criterion2(fc_segment(output), fc_segment(ground_truth))
        loss_value = loss.item()
        loss_list[epoch].append(loss_value)

        # backward
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        # Accuracy bookkeeping only: no graph is needed for these forwards.
        with torch.no_grad():
            _, predicted = torch.max(fc_segment(output), 1)
            _, predicted_truth = torch.max(fc_segment(ground_truth), 1)
        correct += (predicted == labels).sum().item()
        correct_truth += (predicted_truth == labels).sum().item()
        total += labels.size(0)

        train_loader_tqdm.set_postfix(loss=loss_value)

    # Reports the loss of the final batch of the epoch.
    print(f"Epoch: {epoch+1}, Loss: {loss_value}")
    print(f"Train Accuracy: {100 * correct / total}%")
    print(f"Original Accuracy: {100 * correct_truth / total}%")
    # NOTE(review): the original scratch note here read "27%" / "10%",
    # presumably accuracies seen in an earlier run.

Train dataset: 60000
image size:  torch.Size([1, 28, 28])


Epoch 1/4:  44%|████████▊           | 412/938 [00:58<01:14,  7.02it/s, loss=1.6] 


KeyboardInterrupt: 

保存模型

In [None]:
import os


# (module, destination) pairs for the two trained networks.
_checkpoints = [(encoder, encoder_path), (decoder, decoder_path)]

# Create both parent directories first, then write the two state dicts.
for _, _ckpt_path in _checkpoints:
    os.makedirs(os.path.dirname(_ckpt_path), exist_ok=True)

for _net, _ckpt_path in _checkpoints:
    torch.save(_net.state_dict(), _ckpt_path)

In [None]:
# Echo the per-epoch loss lists, then persist them as text:
# one line per epoch, values separated by single spaces.
print(*loss_list)

path = f"{save_dir}loss2.txt"
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f:
    for loss in loss_list:
        # print() joins the values with " " and terminates the line with "\n".
        print(*loss, file=f)

In [None]:
import os

# Reload the loss history from the file the preceding cell writes
# (save_dir + "loss2.txt"). The previous hard-coded "loss.txt" pointed at the
# working directory, which is a different file (or missing on a fresh run).
loss_path = save_dir + "loss2.txt"
with open(loss_path, "r") as f:
    # One line per epoch; an empty line yields an empty list for that epoch.
    loss_list = [list(map(float, line.split())) for line in f]

In [None]:
import matplotlib.pyplot as plt

# Concatenate the per-epoch lists into a single chronological series.
y = []
for epoch_losses in loss_list:
    y.extend(epoch_losses)

print(f"记录的loss数量: {len(y)}")
print(f"最后一个loss: {y[-1]}")

# Draw the loss curve on a fresh figure.
fig, ax = plt.subplots()
ax.plot(y)
plt.show()

在测试集上评估 Encoder Decoder

In [None]:
# Load the trained weights.
# The architectures must match the ones built, trained and saved earlier in
# this notebook (CatChannelConv*); loading those state dicts into the MLP
# variants would fail on mismatched keys.
encoder = CatChannelConvEncoder(num_in=K, num_out=R, in_dim=split_data_shape)
decoder = CatChannelConvDecoder(num_in=N, num_out=K, in_dim=split_conv_output_shape)

# map_location lets checkpoints written on another device load onto `device`.
encoder.load_state_dict(torch.load(encoder_path, map_location=device))
decoder.load_state_dict(torch.load(decoder_path, map_location=device))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
import datetime
import os

from dataset.image_dataset import ImageDataset
from util.split_data import split_vector

print(f"Test dataset: {len(test_dataset)}")
print("image size: ", test_dataset[0][0].size())

conv_segment.to(device)
fc_segment.to(device)
model.to(device)
encoder.to(device)
decoder.to(device)

# Pure evaluation: every module in eval mode, no gradients.
conv_segment.eval()
fc_segment.eval()
model.eval()
encoder.eval()
decoder.eval()

original_correct = 0  # full model, single forward pass
merge_correct = 0  # conv_segment -> flatten -> fc_segment
correct = 0  # split -> encoder -> conv_segment -> decoder -> fc_segment
total = 0
with torch.no_grad():
    test_loader_tqdm = tqdm(
        test_loader,
        desc="Test",
        bar_format="{l_bar}{bar:20}{r_bar}",
    )
    for images, labels in test_loader_tqdm:

        # Slice the batch with the same (start, end) column ranges used in
        # training. split_data_shape is the first entry of split_data_range,
        # so its index 2 is a start column rather than a slice width, and
        # must not be fed to split_vector as a length.
        images_list = [
            images[:, :, :, start:end].clone().to(device)
            for _, _, start, end in split_data_range
        ]
        labels = labels.to(device)
        images = images.to(device)

        # Baseline 1: the unmodified model.
        _, predicted = torch.max(model(images).data, 1)
        original_correct += (predicted == labels).sum().item()

        # Baseline 2: the two segments chained by hand.
        output = conv_segment(images)
        output = output.view(output.size(0), -1)
        output = fc_segment(output)
        _, predicted = torch.max(output.data, 1)
        merge_correct += (predicted == labels).sum().item()

        # Coded pipeline: the K slices plus whatever the encoder adds.
        imageDataset_list = [
            ImageDataset(part) for part in images_list + encoder(images_list)
        ]
        output_list = [
            conv_segment(imageDataset_list[i].images) for i in range(N)
        ]
        # losed_output_list = lose_something(output_list, self.lose_device_index)
        decoded_output_list = decoder(output_list)
        output = torch.cat(decoded_output_list, dim=3)
        output = output.view(output.size(0), -1)

        _, predicted = torch.max(fc_segment(output).data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"测试集总数: {total}")
print(
    f"原始模型 -> 预测正确数: {original_correct}, 预测准确率: {100 * original_correct / total}%"
)
print(
    f"合并模型(conv+fc) -> 预测正确数: {merge_correct}, 预测准确率: {100 * merge_correct / total}%"
)
print(
    f"使用Encoder和Decoder -> 预测正确数: {correct}, 预测准确率: {100 * correct / total}%"
)

训练 base model

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
import datetime
import os


def _evaluate_accuracy(net, loader):
    """Count predictions of `net` over `loader` without tracking gradients.

    Returns:
        (total, correct): number of samples seen and number whose arg-max
        prediction equals the label.
    """
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            _, predicted = torch.max(net(data), 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return total, correct


print(f"Train dataset: {len(train_dataset)}")
print(f"Test dataset: {len(test_dataset)}")
print("image size: ", train_dataset[0][0].size())

# Reuse the device chosen in the setup cell so every stage agrees on it.
print(f"Device: {device}")
model.to(device)

# Loss and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0)

t0 = datetime.datetime.now()
# Training loop.
model.train()
num_epochs = 10
for epoch in range(num_epochs):
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for data, target in train_loader_tqdm:
        data, target = data.to(device), target.to(device)

        # CrossEntropyLoss takes integer class labels directly. This gives the
        # same loss as a one-hot target, without hard-coding the class count.
        output = model(data)
        loss = criterion(output, target)

        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()
        optimizer.step()

        train_loader_tqdm.set_postfix(loss=loss.item())
    # Reports the loss of the final batch of the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
t1 = datetime.datetime.now()

# Evaluation.
model.eval()

total, correct = _evaluate_accuracy(model, train_loader)
print(f"训练集-> 总量: {total}, 正确数量: {correct}, 准确率: {100 * correct / total}%")
t2 = datetime.datetime.now()

total, correct = _evaluate_accuracy(model, test_loader)
print(f"测试集-> 总量: {total}, 正确数量: {correct}, 准确率: {100 * correct / total}%")
t3 = datetime.datetime.now()

print(f"训练时间: {t1 - t0}")
print(f"训练集评估时间: {t2 - t1}")
print(f"测试集评估时间: {t3 - t2}")

# Save the trained weights under the date stored in TASK_CONFIG.
filepath = f"base_model/{TASK_CONFIG['MODEL']}/{TASK_CONFIG['TASK']}/{TASK_CONFIG['DATE']}/model.pth"
os.makedirs(os.path.dirname(filepath), exist_ok=True)
torch.save(model.state_dict(), filepath)

# To reload later:
# model = MyModel1(input_dim=(1, 28, 28), num_classes=10)
# model.load_state_dict(torch.load(filepath, map_location=device))

Train dataset: 60000
Test dataset: 10000
image size:  torch.Size([1, 28, 28])
Device: cpu


Epoch 1/10: 100%|██████████| 938/938 [00:39<00:00, 24.00it/s, loss=0.0128] 


Epoch 1/10, Loss: 0.012757232412695885


Epoch 2/10: 100%|██████████| 938/938 [00:41<00:00, 22.78it/s, loss=0.0142] 


Epoch 2/10, Loss: 0.014194248244166374


Epoch 3/10: 100%|██████████| 938/938 [00:41<00:00, 22.42it/s, loss=0.0218] 


Epoch 3/10, Loss: 0.021787161007523537


Epoch 4/10: 100%|██████████| 938/938 [00:40<00:00, 23.21it/s, loss=0.00286]


Epoch 4/10, Loss: 0.002860612003132701


Epoch 5/10: 100%|██████████| 938/938 [00:40<00:00, 23.15it/s, loss=0.0123]  


Epoch 5/10, Loss: 0.012319878675043583


Epoch 6/10: 100%|██████████| 938/938 [00:42<00:00, 21.84it/s, loss=0.0264]  


Epoch 6/10, Loss: 0.02640596404671669


Epoch 7/10: 100%|██████████| 938/938 [00:42<00:00, 22.22it/s, loss=0.00395] 


Epoch 7/10, Loss: 0.003954804968088865


Epoch 8/10: 100%|██████████| 938/938 [00:37<00:00, 25.24it/s, loss=0.00779] 


Epoch 8/10, Loss: 0.007793494500219822


Epoch 9/10: 100%|██████████| 938/938 [00:39<00:00, 23.48it/s, loss=0.00105] 


Epoch 9/10, Loss: 0.0010497723706066608


Epoch 10/10: 100%|██████████| 938/938 [00:41<00:00, 22.44it/s, loss=0.00354] 


Epoch 10/10, Loss: 0.003543086349964142
训练集-> 总量: 60000, 正确数量: 59930, 准确率: 99.88333333333334%
测试集-> 总量: 10000, 正确数量: 9917, 准确率: 99.17%
训练时间: 0:06:47.128733
训练集评估时间: 0:00:13.520933
测试集评估时间: 0:00:02.320267
