In [65]:
import torch

import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from utils.model_util import *
from utils.data_util import *
from utils.lib_util import *
from utils.train_util import *

plt.style.use('default')

device = 'cuda'

In [66]:
train_data = CIFAR10(
    root='./data/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=False)
test_data = CIFAR10(
    root='./data/',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=False)
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=160,
    shuffle=True,
    pin_memory=True,
    num_workers=4)
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=200,
    shuffle=False,
    pin_memory=True,
    num_workers=4)

In [67]:
epochs = 2
model = torchvision.models.resnet18(weights=None, num_classes=10).to(device).train()
optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss().to(device)
for epoch in range(epochs):
    for data, target in test_dataloader:
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = loss_func(output, target.to(device))
        loss.backward()
        optimizer.step()

In [69]:
acc = eval_model(model, test_dataloader, device)
print(acc)
print(output.shape)
print(model)

0.48019999265670776
torch.Size([200, 10])
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [87]:
parameters = deepcopy(model.state_dict())
print(len(parameters))
for i, key in enumerate(parameters):
    if parameters[key].shape == torch.Size([]):
        print(i, key, parameters[key].shape)

122
5 bn1.num_batches_tracked torch.Size([])
11 layer1.0.bn1.num_batches_tracked torch.Size([])
17 layer1.0.bn2.num_batches_tracked torch.Size([])
23 layer1.1.bn1.num_batches_tracked torch.Size([])
29 layer1.1.bn2.num_batches_tracked torch.Size([])
35 layer2.0.bn1.num_batches_tracked torch.Size([])
41 layer2.0.bn2.num_batches_tracked torch.Size([])
47 layer2.0.downsample.1.num_batches_tracked torch.Size([])
53 layer2.1.bn1.num_batches_tracked torch.Size([])
59 layer2.1.bn2.num_batches_tracked torch.Size([])
65 layer3.0.bn1.num_batches_tracked torch.Size([])
71 layer3.0.bn2.num_batches_tracked torch.Size([])
77 layer3.0.downsample.1.num_batches_tracked torch.Size([])
83 layer3.1.bn1.num_batches_tracked torch.Size([])
89 layer3.1.bn2.num_batches_tracked torch.Size([])
95 layer4.0.bn1.num_batches_tracked torch.Size([])
101 layer4.0.bn2.num_batches_tracked torch.Size([])
107 layer4.0.downsample.1.num_batches_tracked torch.Size([])
113 layer4.1.bn1.num_batches_tracked torch.Size([])
119 lay