In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet2_mnist(nn.Module):
    """Small CNN classifier: two strided 3x3 convolutions, 2x2 average pooling, one linear layer.

    The linear layer takes 576 input features (64 channels x 3 x 3), which is
    what the two stride-2 convolutions followed by 2x2 average pooling produce
    for a 28x28 input.

    Args:
        in_channels: Number of channels of the input images.
        init: ``'gaussian'`` re-initialises every conv/linear weight and bias
            from N(0, 0.01^2); any other value keeps PyTorch's default
            initialisation.
    """

    def __init__(self, in_channels, init='random'):
        super(ConvNet2_mnist, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.linear = nn.Linear(576, 10)

        if init == 'gaussian':
            self._gaussian_initialization()

    def forward(self, x):
        """Return the 10 class logits for a batch of images ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.avg_pool2d(x, 2)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

    def _gaussian_initialization(self):
        """Draw all conv/linear weights and biases from N(0, 0.01^2).

        Previously missing from this definition, so ``init='gaussian'`` raised
        AttributeError.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.normal_(m.bias, mean=0, std=0.01)
    
# Count the trainable parameters of a freshly built network and list each
# parameter's shape (printing the full tensors would dump every weight value).
param_num = 0
model = ConvNet2_mnist(1)
for name, layer in model.named_parameters():
    print(name, layer.shape)
    # Accumulate the element count; previously param_num stayed at 0.
    param_num += layer.numel()
print("total parameters:", param_num)

In [11]:
import torch
model_pth = "/scratch/yufan/nn_diff/gen_param_2_gen_param_2/"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only point this at checkpoints from a trusted source.
# The loaded object is iterated as a name -> tensor mapping below.
model = torch.load(model_pth + "seed_200.pth", map_location="cpu")

# Flatten every tensor in mapping order and concatenate once at the end,
# instead of re-copying the growing vector on each iteration. The list is
# seeded with an empty float tensor so the result matches the previous
# incremental torch.cat (including for an empty mapping).
flat_chunks = [torch.Tensor([])]
for key, value in model.items():
    print(key, "-", value.shape)
    flat_chunks.append(value.flatten())
data = torch.cat(flat_chunks, 0)
len(data)

conv1.weight - torch.Size([32, 1, 3, 3])
conv1.bias - torch.Size([32])
conv2.weight - torch.Size([64, 32, 3, 3])
conv2.bias - torch.Size([64])
linear.weight - torch.Size([10, 576])
linear.bias - torch.Size([10])


24586

In [13]:
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn.functional as F

class ConvNet2_mnist(nn.Module):
    """Compact CNN: conv(3x3, stride 2) -> conv(3x3, stride 2) -> avg-pool(2) -> linear(576, 10).

    Pass ``init='gaussian'`` to overwrite the default initialisation of all
    conv/linear weights and biases with samples from a normal distribution
    (mean 0, std 0.01); any other ``init`` value leaves the defaults in place.
    """

    def __init__(self, in_channels, init='random'):
        super().__init__()
        # Layers are created in the same order as before so parameter order
        # (conv1, conv2, linear) is unchanged.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.linear = nn.Linear(576, 10)

        if init == 'gaussian':
            self._gaussian_initialization()

    def forward(self, x):
        """Map a batch of images to a (batch, 10) tensor of logits."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        pooled = F.avg_pool2d(features, 2)
        # Keep the batch dimension, collapse the rest.
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

    def _gaussian_initialization(self):
        """Re-draw weight then bias of every Conv2d/Linear from N(0, 0.01^2)."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            for tensor in (module.weight, module.bias):
                nn.init.normal_(tensor, mean=0, std=0.01)
                
def test(model, test_loader):
    total_correct = 0
    model.eval()
    for images, labels in test_loader:
        images, labels = images.cuda(), labels.cuda()
        outputs = model(images)
        pred_labels = torch.argmax(outputs, dim=1)

        matches = pred_labels.eq(labels).float()
        correct = matches.sum().item()

        total_correct += correct

    accuracy = total_correct / len(test_loader.dataset)
    return accuracy

# Preprocessing for evaluation: convert to tensor, then normalise the single
# channel with mean 0.1307 and std 0.3081.
test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# MNIST split selected by train=False, downloaded to the root directory if absent.
test_dataset = datasets.MNIST(
    root='/scratch/datasets/mnist',
    train=False,
    download=True,
    transform=test_transform,
)
# Fixed order, 1000 samples per batch.
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [15]:
from copy import deepcopy
# Rebuild a network from the flat parameter vector `data` and evaluate it.
net = ConvNet2_mnist(1)

# The flat vector must contain exactly one value per model parameter;
# otherwise surplus values would be silently ignored.
expected_numel = sum(p.numel() for p in net.parameters())
if data.numel() != expected_numel:
    raise ValueError(
        f"flat parameter vector has {data.numel()} values, "
        f"but the model expects {expected_numel}"
    )

# Consume `data` in named_parameters() order. copy_ writes into each
# parameter's own storage, so the parameters never alias `data` and no
# defensive copy of the vector is needed.
offset = 0
with torch.no_grad():
    for name, layer in net.named_parameters():
        print(name, layer.shape)
        layer_param = data[offset:offset + layer.numel()]
        print(layer_param.shape)
        layer.copy_(layer_param.reshape(layer.shape))
        offset += layer.numel()
net = net.cuda()

test(net, test_loader)

conv1.weight torch.Size([32, 1, 3, 3])
torch.Size([288])
conv1.bias torch.Size([32])
torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
torch.Size([18432])
conv2.bias torch.Size([64])
torch.Size([64])
linear.weight torch.Size([10, 576])
torch.Size([5760])
linear.bias torch.Size([10])
torch.Size([10])


0.9667