# In [None]:
import torch

import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np

class BigModel(nn.Module):
    """Three-layer fully connected classifier for flattened 28x28 images.

    The network maps 784 input features through two ReLU hidden layers of
    512 and 256 units to 10 output logits (one per digit class).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        # nn.Linear needs explicit in/out sizes: 256 matches fc2's output,
        # 10 is the number of classes scored by the cross-entropy loss.
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return unnormalised class scores (logits) for a batch.

        Args:
            x: Tensor whose last dimension holds 784 features.

        Returns:
            Tensor of logits whose last dimension has size 10.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the output layer: the raw logits are returned.
        return self.fc3(x)
    
# Preprocessing pipeline: ToTensor followed by single-channel normalisation.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
# mean / std -- confirm before reusing with another dataset.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# MNIST training split, downloaded to ./data when it is not already present.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

def train(model, dataloader, criterion, optimizer, device='cpu', num_epochs=10):
    """Train ``model`` in place and print the mean batch loss of each epoch.

    Args:
        model: Module to optimise. It is switched to training mode and
            moved to ``device``.
        dataloader: Iterable yielding ``(inputs, targets)`` batches. Each
            input batch is flattened to ``(batch, -1)`` before the forward
            pass.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.
        optimizer: Optimizer already bound to the parameters of ``model``.
        device: Device on which the model and every batch are placed.
        num_epochs: Number of full passes over ``dataloader``.

    Returns:
        The same ``model`` object after training.
    """
    model.train()
    model.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        # Count batches while iterating so that loaders without ``__len__``
        # and empty loaders are handled too.
        num_batches = 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass on the flattened batch.
            outputs = model(inputs.view(inputs.size(0), -1))
            loss = criterion(outputs, targets)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the scalar loss so the epoch average is meaningful.
            running_loss += loss.item()
            num_batches += 1

        # An epoch without batches has no defined average loss.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")
    return model

# Train the dense baseline model for two epochs and save its weights.
# Fall back to the CPU when CUDA is unavailable so the script still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

big_model = BigModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(big_model.parameters(), lr=1e-3)
big_model = train(big_model, train_loader, criterion, optimizer, device=device, num_epochs=2)


torch.save(big_model.state_dict(), 'big_model.pth')

# Prune: zero out the low-magnitude weights of the trained model.

def prune_network(model, pruning_rate=0.5, method='global'):
    """Zero the smallest-magnitude entries of every weight tensor in place.

    For each parameter whose name contains ``'weight'``, the
    ``pruning_rate`` percentile of the absolute values is used as a
    threshold; every entry whose magnitude is not strictly greater than the
    threshold is set to zero. Other parameters (e.g. biases) are untouched.

    Args:
        model: Module whose weight parameters are pruned in place.
        pruning_rate: Fraction in ``[0, 1]`` used as the percentile
            (``pruning_rate * 100``) of absolute weight values.
        method: ``'global'`` computes one threshold per weight tensor, over
            all entries of that tensor (not across the whole model). Any
            other value computes the threshold along axis 1, i.e. one
            threshold per row of a 2-D weight matrix.

    Returns:
        None. The model is modified in place.
    """
    percentile = pruning_rate * 100
    # Pruning is not part of the autograd graph.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' not in name:
                continue
            magnitudes = np.abs(param.detach().cpu().numpy())
            if method == 'global':
                threshold = np.percentile(magnitudes, percentile)
            else:
                threshold = np.percentile(magnitudes, percentile, axis=1, keepdims=True)
            keep = torch.as_tensor(magnitudes > threshold)
            # Mask in place so the parameter keeps its dtype, device and
            # storage instead of being replaced by a new float32 tensor.
            param.mul_(keep.to(device=param.device, dtype=param.dtype))
    
# Fall back to the CPU when CUDA is unavailable so the script still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Reload the trained dense weights, prune half of each weight tensor by
# magnitude, and save the pruned model. map_location lets a checkpoint that
# was saved from a GPU run be loaded on a CPU-only machine.
big_model.load_state_dict(torch.load('big_model.pth', map_location=device))
prune_network(big_model, pruning_rate=0.5, method='global')

torch.save(big_model.state_dict(), 'big_model_pruned.pth')

# Fine-tune the pruned model with a smaller learning rate.
# NOTE: train() updates every parameter and no pruning mask is re-applied,
# so weights zeroed by prune_network are not kept at zero during fine-tuning.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(big_model.parameters(), lr=1e-4)

# train() returns the model it was given, so this is the same object as
# big_model.
fintuned_model = train(big_model, train_loader, criterion, optimizer, device=device, num_epochs=10)


torch.save(fintuned_model.state_dict(), 'big_model_pruned_finetuned.pth')

