In [4]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as utils
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Select the compute device once; finetune() and predict() move the model and
# every batch onto `device`.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
else:
    # Fall back to the CPU so that `device` is always defined. Without this
    # assignment, later cells raise NameError on machines without CUDA.
    device = torch.device("cpu")
    print("No Cuda")

Running on the GPU


In [3]:
# Notebook-level hyper-parameters.
# `batch_size` is read by get_dataloader() for the training loader.
# NOTE(review): `learning_rate`, `momentum` and `epoch` are not referenced by the
# cells visible in this notebook; finetune() carries its own values -- confirm
# whether they are still needed.
batch_size: int = 64
learning_rate: float = 1e-3
momentum: float = 0.5
epoch: int = 10

In [5]:
def get_MNIST(transform, root=r'D:\\InternWorkspace_wsr\\dataset'):
    """Build the MNIST train and test datasets with a 3-channel preprocessing pipeline.

    Every image goes through ``ToTensor``, then its channel dimension is
    repeated three times (``x.repeat(3, 1, 1)``), and finally the
    caller-supplied ``transform`` is applied.

    Args:
        transform: Callable applied after the channel repetition, e.g. the
            preset returned by ``ResNet18_Weights.DEFAULT.transforms()``.
            ``None`` skips this last step.
        root: Directory where MNIST is stored or downloaded to. The default is
            the path that used to be hard-coded in this function; pass another
            location to run the notebook on a different machine.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` of ``datasets.MNIST`` objects.
    """
    steps = [
        transforms.ToTensor(),
        # Repeat along dim 0 so the image has three identical channels.
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ]
    if transform is not None:
        steps.append(transform)
    # Use a separate name so the `transform` argument is not silently rebound.
    pipeline = transforms.Compose(steps)

    train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=pipeline)
    test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=pipeline)
    return train_dataset, test_dataset

In [6]:
def get_dataloader(train_dataset, test_dataset, train_batch_size=None,
                   test_batch_size=1, shuffle_test=True):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Args:
        train_dataset: Dataset used for training; always shuffled.
        test_dataset: Dataset used for evaluation.
        train_batch_size: Batch size of the training loader. ``None`` (default)
            falls back to the notebook-level ``batch_size`` constant, which is
            what this function always used before.
        test_batch_size: Batch size of the test loader. The default of 1 keeps
            the previous behaviour; a larger value makes evaluation much
            faster because fewer forward passes are needed.
        shuffle_test: Whether the test loader shuffles. Defaults to ``True``
            as before; ordering does not affect an accuracy total.

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``.
    """
    if train_batch_size is None:
        # Resolved lazily so the global is only required when no explicit size is given.
        train_batch_size = batch_size

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=train_batch_size,
                                  shuffle=True)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=test_batch_size,
                                 shuffle=shuffle_test)
    return train_dataloader, test_dataloader

In [7]:
# ResNet18: build the architecture, then fill it from the locally stored checkpoint
# (presumably the torchvision pretrained ResNet18 weights -- confirm the file's origin).
resnet18 = models.resnet18()
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved from; finetune()/predict() move the model to the target device later.
resnet18_weights = torch.load('./weights/resnet18-f37072fd.pth', map_location='cpu')
resnet18.load_state_dict(resnet18_weights)
# eval() returns the module, so as the last expression it also displays the model.
resnet18.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [8]:
# Approach 1: modify the network directly.
# Replace the classification head: fc output goes from the original 1000 down to 10.
resnet18_fc_features = resnet18.fc.in_features
resnet18.fc = torch.nn.Linear(in_features=resnet18_fc_features, out_features=10)

In [9]:
# Show the model repr (a cell's last expression is rendered as its output).
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [10]:
# Approach 2: freeze the parameters of the earlier layers and fine-tune only the last layer.
# Module.requires_grad_ flips the flag on every parameter of the module in one call.
resnet18.requires_grad_(False)
# Unfreeze the fc layer, printing each of its parameters before re-enabling gradients.
for param in resnet18.fc.parameters():
    print(param)
    param.requires_grad = True

Parameter containing:
tensor([[-0.0402, -0.0309,  0.0116,  ..., -0.0378, -0.0379,  0.0116],
        [-0.0386, -0.0397, -0.0121,  ...,  0.0410,  0.0208,  0.0428],
        [-0.0116, -0.0414, -0.0394,  ..., -0.0209,  0.0426,  0.0216],
        ...,
        [-0.0144,  0.0095,  0.0232,  ...,  0.0061, -0.0282,  0.0196],
        [-0.0380,  0.0353,  0.0054,  ...,  0.0231,  0.0108, -0.0369],
        [-0.0280,  0.0317, -0.0424,  ...,  0.0056, -0.0126, -0.0155]])
Parameter containing:
tensor([ 0.0105, -0.0392,  0.0408,  0.0058, -0.0313,  0.0023,  0.0305, -0.0157,
        -0.0292, -0.0186])


In [11]:
# Preprocessing: reuse the transform preset attached to the default ResNet18
# weights, then build the MNIST datasets and their loaders.
resnet18_trans = models.ResNet18_Weights.DEFAULT.transforms()
train_dataset, test_dataset = get_MNIST(transform=resnet18_trans)
train_dataloader, test_dataloader = get_dataloader(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
)

In [12]:
# Fine-tune a classifier (ResNet18 in this notebook), evaluate it, save it and
# plot the training loss.
def _count_correct(logits, targets):
    """Return how many rows of ``logits`` have their arg-max (dim 1) equal to ``targets``."""
    return (logits.argmax(dim=1) == targets).sum().item()


def _evaluate(model, dataloader, device):
    """Return the accuracy (percent) of ``model`` over ``dataloader`` with gradients disabled."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            correct += _count_correct(model(inputs), targets)
            total += inputs.shape[0]
    # An empty loader has no defined accuracy; report NaN instead of dividing by zero.
    return correct / total * 100 if total else float("nan")


def finetune(model, train_dataloader, test_dataloader, epochs=5, lr=0.001,
             save_path='./weights/checkpoints/Resnet18_ft.pth', device=None):
    """Train ``model`` with Adam and cross-entropy, then evaluate, save and plot.

    Args:
        model: ``torch.nn.Module`` producing one row of class scores per sample.
            Only parameters with ``requires_grad=True`` are optimised.
        train_dataloader: Iterable of ``(inputs, targets)`` training batches.
        test_dataloader: Iterable of ``(inputs, targets)`` evaluation batches.
        epochs: Number of passes over ``train_dataloader`` (previously fixed at 5).
        lr: Adam learning rate (previously fixed at 0.001).
        save_path: Where the whole model object is written with ``torch.save``.
            Missing parent directories are created.
        device: Device to train on. ``None`` picks ``cuda:0`` when CUDA is
            available and the CPU otherwise.

    Raises:
        ValueError: If ``train_dataloader`` yields no samples in an epoch.

    Side effects:
        Prints per-epoch loss/accuracy and the test accuracy, writes the
        checkpoint file and shows a matplotlib figure of the loss curve.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    # NOTE(review): train() also puts the BatchNorm layers of a frozen backbone
    # into training mode, so their running statistics keep updating -- confirm
    # this is intended when only the head is meant to change.
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    # Hand the optimizer only what can actually receive gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    loss_lists = []
    for e in tqdm(range(epochs)):
        loss_sum = 0.0
        total = 0
        correct = 0
        for inputs, targets in train_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size so
            # the epoch figure is a true per-sample average.
            batch_n = inputs.shape[0]
            loss_sum += loss.item() * batch_n
            total += batch_n
            correct += _count_correct(outputs, targets)

        if total == 0:
            raise ValueError("train_dataloader yielded no samples")

        # Record the epoch mean; previously only the last batch's loss tensor
        # was stored and printed while the accumulated sum went unused.
        epoch_loss = loss_sum / total
        loss_lists.append(epoch_loss)
        acc = correct / total * 100
        print("epoch {}: loss {} acc {}".format(e, epoch_loss, acc))

    # Evaluate on the held-out loader.
    acc = _evaluate(model, test_dataloader, device)
    print("Acc in test dataset: {}".format(acc))

    # Save the whole module object (not just a state_dict): the later notebook
    # cell reloads this file with torch.load and uses the result as a model.
    checkpoint_dir = os.path.dirname(save_path)
    if checkpoint_dir:
        # Avoid losing a finished training run to a missing directory.
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model, save_path)

    # Plot the per-epoch mean training loss.
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(loss_lists)), loss_lists)
    ax.set(title="Fine-tuning loss", xlabel="epoch", ylabel="mean training loss")
    plt.show()
    plt.close(fig)

In [None]:
# Run the fine-tuning loop on the prepared loaders.
finetune(
    model=resnet18,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
)

In [None]:
def predict(model, dataloader=None, device=None):
    """Print the accuracy (percent) of ``model`` on a test loader.

    Args:
        model: ``torch.nn.Module`` producing one row of class scores per sample.
        dataloader: Iterable of ``(inputs, targets)`` batches. ``None``
            (default) falls back to the notebook-level ``test_dataloader``,
            which is what this function always used before.
        device: Device to run on. ``None`` picks ``cuda:0`` when CUDA is
            available and the CPU otherwise.

    Side effects:
        Moves ``model`` to ``device``, switches it to eval mode and prints the
        accuracy. Nothing is returned.
    """
    if dataloader is None:
        # Resolved lazily so the global is only required when no loader is passed.
        dataloader = test_dataloader
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            # The predicted class is the index of the largest score in each row.
            predicted = outputs.argmax(dim=1)
            total += inputs.shape[0]
            correct += (targets == predicted).sum().item()

    # An empty loader has no defined accuracy; report NaN instead of dividing by zero.
    acc = correct / total * 100 if total else float("nan")
    print("Acc in test dataset: {}".format(acc))

In [None]:
# Reload the fine-tuned model written by finetune() and evaluate it.
# The file holds a fully pickled nn.Module (torch.save(model, ...)), so it has to
# be unpickled with weights_only=False; recent PyTorch releases default to
# weights_only=True and would refuse it. Unpickling can execute arbitrary code,
# so only load checkpoints you produced yourself.
# map_location='cpu' keeps loading independent of the device used for training;
# predict() moves the model to the target device.
model_ft = torch.load('./weights/checkpoints/Resnet18_ft.pth',
                      map_location='cpu',
                      weights_only=False)
predict(model_ft)