## 移除 ResNet 的最后一个 conv 层

把 ResNet 的最后一层再拼接回去，重新开始训练（finetune）。


In [1]:
import sys
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mping
import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.data as data
from torchvision import transforms

from data_loader import *
from model import *

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

### Set hyperparameters

In [3]:
# Training / model hyper-parameters.
batch_size, num_epochs = 256, 50
extract_size = 2048   # passed to ConvClassify as `in_features`
class_size = 61       # passed to ConvClassify as `class_size`


# Image preprocessing pipelines.
# Both pipelines resize to 224x224, convert to a tensor and normalise with the
# same per-channel mean / std; only the training pipeline adds a random
# horizontal flip.
_input_size = [224, 224]
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

transform_train = transforms.Compose([
    transforms.Resize(_input_size),
    # transforms.RandomCrop(224),  # currently disabled
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])
transform_vaild = transforms.Compose([
    transforms.Resize(_input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])

### Load encoded data

In [9]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Encoder used to extract the images into embedding tensors
# (per the original notes: with the last conv layer removed).
encoder = EncoderClearCut().to(device)

# Training-set loader backed by the training bottleneck folder.
fileFold = './bottle_neck/resnet152_train_clear_remove_last_conv'
data_loader = get_encoder_loader_fold(
    transform_train, encoder, device, fileFold,
    load=False, mode='train', batch_size=batch_size,
)

# Validation-set loader backed by its own bottleneck folder.
validFileFold = './bottle_neck/resnet152_valid_clear_remove_last_conv'
valid_data_loader = get_encoder_loader_fold(
    transform_vaild, encoder, device, validFileFold,
    load=False, mode='valid', batch_size=batch_size,
)

encoding 4981/4982->99.98%, spent_time:7:51.123.00

### Define the classify network

In [26]:
class ConvClassify(nn.Module):
    """Classification head: one strided conv stage followed by a linear layer.

    Layer order in ``forward``: Dropout2d -> Conv2d(3x3, stride 2) ->
    AvgPool2d(6) -> flatten -> BatchNorm1d -> Linear.

    Args:
        in_features: Width of the flattened vector fed to the final linear
            layer. It is also used as the number of conv output channels and
            as the BatchNorm width, so the three layers always agree.
            (Previously the conv and BatchNorm widths were hard-coded to 2048,
            which only worked when ``in_features == 2048``.)
        class_size: Number of output logits.
        in_channels: Channel count of the incoming feature map. Defaults to
            1024, the value that used to be hard-coded.
    """

    def __init__(self, in_features, class_size, in_channels=1024):
        super().__init__()
        # Attribute names are unchanged so existing state dicts still load.
        self.drop = nn.Dropout2d(0.4)
        self.conv = nn.Conv2d(in_channels, in_features, 3, stride=2)
        self.pool = nn.AvgPool2d(6)
        self.bn = nn.BatchNorm1d(in_features)
        self.fc = nn.Linear(in_features, class_size)

    def forward(self, features):
        """Map a feature map of shape (N, in_channels, H, W) to class logits.

        The spatial size must collapse to 1x1 after the conv and the 6x6
        average pool (e.g. a 14x14 input becomes 6x6 after the conv), otherwise
        the flattened width will not match ``in_features``.

        Returns:
            Tensor of shape (N, class_size) with unnormalised scores.
        """
        y = self.drop(features)
        y = self.conv(y)
        y = self.pool(y)
        # Flatten everything except the batch dimension.
        y = y.view(y.size(0), -1)
        y = self.bn(y)
        y = self.fc(y)
        return y

In [27]:
def valid_class_acc(classify_model, valid_data_loader):
    """Return the classification accuracy, in percent, over a validation set.

    Every item of ``valid_data_loader.dataset`` is pushed through the model one
    at a time; the arg-max of each output is compared with the
    ``disease_class`` column of ``valid_data_loader.dataset.refer``.

    Side effects:
        * the model is switched to eval mode and left in that mode;
        * ``predict`` and ``correct`` columns are written onto
          ``dataset.refer``;
        * a progress counter is printed to stdout.

    Args:
        classify_model: Module to evaluate.
        valid_data_loader: Object whose ``dataset`` supports ``len()``, integer
            indexing returning ``(embedding, label)`` pairs, and exposes a
            ``refer`` DataFrame with a ``disease_class`` column.

    Returns:
        Accuracy as a percentage (0-100).
    """
    classify_model = classify_model.eval()
    dataset = valid_data_loader.dataset
    total = len(dataset)
    predict = []

    # Pure inference: turning autograd off avoids building a graph for every
    # sample, which saves memory and time without changing the predictions.
    with torch.no_grad():
        for idx in range(total):
            embed, _ = dataset[idx]
            predict.append(classify_model(embed).argmax().item())
            print('\r %d / %d' % (idx, total), end='')
            sys.stdout.flush()

    df_refer = dataset.refer
    df_refer["predict"] = predict
    df_refer['correct'] = df_refer.predict == df_refer.disease_class
    accuracy = df_refer.correct.sum() / len(df_refer)
    return accuracy * 100

### init model

拼接resnet 最后一个 conv layer，用来训练

In [28]:
# set the total number of training steps per epoch
total_step = int(len(data_loader.dataset)/batch_size)

classify_model = ConvClassify(extract_size, class_size)
classify_model = classify_model.to(device)

criterion = nn.CrossEntropyLoss().cuda() if torch.cuda.is_available() else nn.CrossEntropyLoss()

# with RMSprop to slow the desent gradient progress
optimizer = torch.optim.Adam(classify_model.parameters(), lr=0.001)

### Load the best trained model so far

In [8]:
# Restore previously saved weights. `map_location` remaps the stored tensors
# onto the active device, so a checkpoint written on a GPU machine still loads
# on a CPU-only one.
classify_model.load_state_dict(
    torch.load('./models/class_train_last_conv_layer_last.pkl', map_location=device)
)

### time to train model

In [9]:
# Rebind `optimizer` to plain SGD over the head's parameters with lr=0.0001.
optimizer = torch.optim.SGD(
    classify_model.parameters(),
    lr=0.0001,
)

In [None]:
# Highest training-batch accuracy seen so far (shown in the progress line).
best_acc = 0

# Epoch numbering starts at 6 -- presumably continuing the count of the
# checkpoint loaded earlier; confirm before changing.
for epoch in range(6, num_epochs+6):
    classify_model = classify_model.train()
    start = time.time()
    for i_step in range(1, total_step+1):

        # Draw a fresh random set of indices and point the loader's batch
        # sampler at it, then pull a single batch from the loader.
        indices = data_loader.dataset.get_train_indices()
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader.batch_sampler.sampler = new_sampler

        embeds, targets = next(iter(data_loader))

        # Drop the singleton dimension at position 1 of the embeddings.
        embeds = embeds.squeeze(1)
        # Flatten the labels once so that the loss AND the accuracy below both
        # compare against a 1-D tensor. Comparing (B,) predictions with
        # un-flattened (B, 1) labels would broadcast to (B, B) and silently
        # report a wrong accuracy.
        targets = targets.type(torch.LongTensor).to(device).view(-1)

        classify_model.zero_grad()

        outputs = classify_model(embeds)

        loss = criterion(outputs, targets)

        loss.backward()

        optimizer.step()

        if i_step % 2 == 0:
            # Report progress every second step.
            predict_result = outputs.argmax(1)
            # Divide by the real number of samples in this batch rather than
            # the nominal batch size, so a short batch is not under-reported.
            accuracy = torch.sum(predict_result == targets).item() / targets.size(0) * 100
            best_acc = max(best_acc, accuracy)
            # The trailing blanks overwrite leftovers of a longer previous line.
            stats = (
                'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Accuracy: %.2f%%, Best_acc: %.2f%%'
                '            '
            ) % (epoch, num_epochs, i_step, total_step, loss.item(), accuracy, best_acc)
            print('\r' + stats, end='')
            sys.stdout.flush()

    # End of epoch: measure accuracy on the validation loader.
    valid_acc = valid_class_acc(classify_model, valid_data_loader)
    print('\n Epoch {}, spent time:{:.2f}s, valid: {:.2f}%'.format(epoch, time.time()-start, valid_acc))
    # Keep a checkpoint every fifth epoch.
    if epoch % 5 == 0:
        torch.save(classify_model.state_dict(), os.path.join('./models', 'class_train_last_conv_layer_%d.pkl' % epoch))
# torch.save(classify_model.state_dict(), os.path.join('./models', 'class_train_last_conv_layer_last.pkl'))

 4981 / 4982, Step [126/127], Loss: 0.9728, Accuracy: 64.45%, Best_acc: 69.92%            
 Epoch 6, spent time:319.92s, valid: 69.93%
 4981 / 4982, Step [126/127], Loss: 1.0193, Accuracy: 67.19%, Best_acc: 73.05%            
 Epoch 7, spent time:326.24s, valid: 70.75%
 4981 / 4982, Step [126/127], Loss: 0.7152, Accuracy: 72.66%, Best_acc: 73.83%            
 Epoch 8, spent time:309.71s, valid: 71.44%
 4981 / 4982, Step [126/127], Loss: 0.8822, Accuracy: 69.92%, Best_acc: 76.17%            
 Epoch 9, spent time:320.46s, valid: 72.48%
 4981 / 4982], Step [126/127], Loss: 0.8656, Accuracy: 71.48%, Best_acc: 76.17%            
 Epoch 10, spent time:309.17s, valid: 71.56%
 4981 / 4982], Step [126/127], Loss: 0.9334, Accuracy: 67.97%, Best_acc: 76.17%            
 Epoch 11, spent time:317.99s, valid: 72.72%
 4981 / 4982], Step [126/127], Loss: 0.8210, Accuracy: 72.27%, Best_acc: 76.56%            
 Epoch 12, spent time:363.16s, valid: 72.58%
 4981 / 4982], Step [126/127], Loss: 0.7148, Accu

### How good is the model?

In [12]:
!mkdir ./bottle_neck/resnet152_valid_clear_remove_last_conv

In [8]:
# Rebuild the validation loader, this time with load=True
# (presumably reads the encodings already stored in the folder -- confirm
# against get_encoder_loader_fold).
fileFold = './bottle_neck/resnet152_valid_clear_remove_last_conv'
valid_data_loader = get_encoder_loader_fold(
    transform_vaild, encoder, device, fileFold,
    load=True, mode='valid', batch_size=batch_size,
)

In [41]:
# Load the saved weights for evaluation. `map_location` remaps the stored
# tensors onto the active device, so a GPU-written checkpoint also loads on a
# CPU-only machine.
classify_model.load_state_dict(
    torch.load('./models/class_train_last_conv_layer_last.pkl', map_location=device)
)

In [None]:
# Evaluate on the validation loader and print the accuracy (percent) with two
# decimals.
print(
    'The final accuracy is %.2f'
    % (valid_class_acc(classify_model, valid_data_loader),)
)

## With 100 epochs the result is the best so far.

In [23]:
# Shell command (IPython `!`): rename the 100-epoch checkpoint to a "good" name.
# NOTE(review): this file name does not match the checkpoints written by the
# training cell in this notebook ('class_train_last_conv_layer_*.pkl') --
# confirm it is the intended file.
!mv ./models/class_multify_rm1layer_100.pkl ./models/class_multify_rm1layer_good.pkl