In [1]:
import sys
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mping
import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.data as data
from torchvision import transforms

from data_loader import get_loader
from data_loader import get_encoder_loader
from model import *

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

### Set hyperparameters

In [2]:
# Hyperparameters shared by the cells below.
batch_size, num_epochs = 128, 200
# Both values are later handed to MultiClassify(extract_size, class_size).
extract_size, class_size = 2048, 61


# Image preprocessing pipelines.
# The same per-channel normalisation is used for both splits, so build it once.
_channel_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: resize, then a random 224 crop and a random horizontal
# flip before converting to a normalised tensor.
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _channel_normalize,
])

# Validation pipeline: resize directly to 224x224 with no random steps.
# (The name keeps its original spelling because later cells reference it.)
transform_vaild = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    _channel_normalize,
])

### Load encoded data

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder used to extract embedding tensors from the images.
encoder = Encoder().to(device)

# NOTE(review): the `file=` argument presumably makes the loader read embeddings
# that were saved earlier instead of re-encoding every image -- confirm in
# data_loader.py. To regenerate that file, the original author used:
# data_loader = get_encoder_loader(transform_train, encoder, device, mode='train', batch_size=batch_size)
# data_loader.dataset.save_to("./bottle_neck/resnet152_train.h")
data_loader = get_encoder_loader(
    transform_train,
    encoder,
    device,
    mode='train',
    batch_size=batch_size,
    file='./bottle_neck/resnet152_train.h',
)

### Initialize the model

In [4]:
# Number of training steps per epoch: dataset size divided by the batch size,
# rounded down.
total_step = len(data_loader.dataset) // batch_size

# Classification head, created and moved to the selected device in one go.
classify_model = MultiClassify(extract_size, class_size).to(device)

# Cross-entropy loss, placed on the same device as the model.
criterion = nn.CrossEntropyLoss().to(device)

# Plain SGD with a fixed learning rate (note: this is SGD, not RMSprop).
optimizer = torch.optim.SGD(classify_model.parameters(), lr=0.001)

### Load the best model trained so far

In [7]:
# Restore the weights written by the training cell's final torch.save call.
# map_location=device lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine (and vice versa) instead of raising a deserialisation error.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
classify_model.load_state_dict(
    torch.load('./models/class_multi_model_last.pkl', map_location=device)
)

### Time to train the model

In [11]:
# Train the classifier on the pre-extracted embeddings.
#
# Every step draws a fresh random subset of indices from the dataset, runs one
# optimisation step on it, and every 20 steps reports the loss and the accuracy
# of that batch. `best_acc` is the best *batch* accuracy seen so far, not a
# validation metric.
classify_model = classify_model.train()
best_acc = 0

# Create the checkpoint directory up front so a long run cannot fail at the
# first torch.save call just because the folder is missing.
os.makedirs('./models', exist_ok=True)

for epoch in range(1, num_epochs+1):
    start = time.time()
    for i_step in range(1, total_step+1):

        # Randomly pick the samples for this step and install them as the
        # sampler of the existing loader.
        indices = data_loader.dataset.get_train_indices()
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader.batch_sampler.sampler = new_sampler

        embeds, targets = next(iter(data_loader))

        # Put inputs and labels on the same device as the model; `.to` is a
        # no-op when the tensor is already there.
        embeds = embeds.squeeze(1).to(device)
        # Flatten the labels once so the loss and the accuracy below both see
        # the same 1-D tensor (a (B, 1) tensor would otherwise broadcast in `==`).
        targets = targets.to(device=device, dtype=torch.long).view(-1)

        optimizer.zero_grad()

        outputs = classify_model(embeds)

        loss = criterion(outputs, targets)

        loss.backward()

        optimizer.step()

        if i_step%20 == 0:
            # Accuracy of the current batch, in percent, using the real number
            # of labels in the batch rather than the nominal batch_size.
            predict_result = outputs.argmax(1)
            accuracy = torch.sum(predict_result == targets).item() / targets.size(0) * 100
            best_acc = max(best_acc, accuracy)

            # The line continuation inside the literal is intentional: it pads
            # the message with spaces so '\r' rewrites fully cover older text.
            stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Accuracy: %.2f%%, Best_acc: %.2f%%\
            ' % (epoch, num_epochs, i_step, total_step, loss.item(), accuracy, best_acc)
            print('\r' + stats, end='')
            sys.stdout.flush()
    print('\n Epoch {}, spent time:{:.2f}s'.format(epoch, time.time()-start))
    # Keep a snapshot every 50 epochs.
    if epoch%50 == 0:
        torch.save(classify_model.state_dict(), os.path.join('./models', 'class_multi_model_%d.pkl' % epoch))
# Always save the final weights; the "load" cell reads this file.
torch.save(classify_model.state_dict(), os.path.join('./models', 'class_multi_model_last.pkl'))

Epoch [1/200], Step [240/255], Loss: 0.3731, Accuracy: 89.06%, Best_acc: 89.84%            
 Epoch 1, spent time:5.88s
Epoch [2/200], Step [240/255], Loss: 0.4057, Accuracy: 85.94%, Best_acc: 91.41%            
 Epoch 2, spent time:5.96s
Epoch [3/200], Step [240/255], Loss: 0.4038, Accuracy: 82.03%, Best_acc: 91.41%            
 Epoch 3, spent time:5.90s
Epoch [4/200], Step [240/255], Loss: 0.4624, Accuracy: 86.72%, Best_acc: 91.41%            
 Epoch 4, spent time:5.88s
Epoch [5/200], Step [240/255], Loss: 0.4584, Accuracy: 81.25%, Best_acc: 91.41%            
 Epoch 5, spent time:5.89s
Epoch [6/200], Step [240/255], Loss: 0.3620, Accuracy: 89.84%, Best_acc: 91.41%            
 Epoch 6, spent time:5.89s
Epoch [7/200], Step [240/255], Loss: 0.4471, Accuracy: 82.03%, Best_acc: 91.41%            
 Epoch 7, spent time:5.85s
Epoch [8/200], Step [240/255], Loss: 0.4367, Accuracy: 82.03%, Best_acc: 91.41%            
 Epoch 8, spent time:5.81s
Epoch [9/200], Step [240/255], Loss: 0.4756, Acc

Epoch [69/200], Step [240/255], Loss: 0.4013, Accuracy: 84.38%, Best_acc: 92.97%            
 Epoch 69, spent time:5.82s
Epoch [70/200], Step [240/255], Loss: 0.4597, Accuracy: 84.38%, Best_acc: 92.97%            
 Epoch 70, spent time:5.86s
Epoch [71/200], Step [240/255], Loss: 0.4558, Accuracy: 83.59%, Best_acc: 92.97%            
 Epoch 71, spent time:5.82s
Epoch [72/200], Step [240/255], Loss: 0.4772, Accuracy: 80.47%, Best_acc: 92.97%            
 Epoch 72, spent time:5.78s
Epoch [73/200], Step [240/255], Loss: 0.4228, Accuracy: 88.28%, Best_acc: 93.75%            
 Epoch 73, spent time:6.01s
Epoch [74/200], Step [240/255], Loss: 0.3976, Accuracy: 83.59%, Best_acc: 93.75%            
 Epoch 74, spent time:5.94s
Epoch [75/200], Step [240/255], Loss: 0.3613, Accuracy: 87.50%, Best_acc: 93.75%            
 Epoch 75, spent time:5.89s
Epoch [76/200], Step [240/255], Loss: 0.4790, Accuracy: 83.59%, Best_acc: 93.75%            
 Epoch 76, spent time:5.92s
Epoch [77/200], Step [240/255], 

Epoch [136/200], Step [240/255], Loss: 0.4249, Accuracy: 86.72%, Best_acc: 94.53%            
 Epoch 136, spent time:6.12s
Epoch [137/200], Step [240/255], Loss: 0.4283, Accuracy: 89.06%, Best_acc: 94.53%            
 Epoch 137, spent time:6.18s
Epoch [138/200], Step [240/255], Loss: 0.3769, Accuracy: 89.06%, Best_acc: 94.53%            
 Epoch 138, spent time:6.15s
Epoch [139/200], Step [240/255], Loss: 0.2840, Accuracy: 89.06%, Best_acc: 94.53%            
 Epoch 139, spent time:6.16s
Epoch [140/200], Step [240/255], Loss: 0.3137, Accuracy: 90.62%, Best_acc: 94.53%            
 Epoch 140, spent time:6.05s
Epoch [141/200], Step [240/255], Loss: 0.5513, Accuracy: 79.69%, Best_acc: 94.53%            
 Epoch 141, spent time:6.04s
Epoch [142/200], Step [240/255], Loss: 0.4182, Accuracy: 86.72%, Best_acc: 94.53%            
 Epoch 142, spent time:6.10s
Epoch [143/200], Step [240/255], Loss: 0.3999, Accuracy: 83.59%, Best_acc: 94.53%            
 Epoch 143, spent time:6.13s
Epoch [144/200],

### How good is the model?

In [5]:
# Loader for the validation split, built with the validation transform and the
# embeddings file given by `file=`.
# The original author's commands for regenerating an embeddings file were:
# valid_data_loader = get_encoder_loader(transform_vaild, encoder, device, mode='valid', batch_size=batch_size)
# valid_data_loader.dataset.save_to("./bottle_neck/resnet152_valid_new.h")
valid_data_loader = get_encoder_loader(
    transform_vaild,
    encoder,
    device,
    mode='valid',
    batch_size=batch_size,
    file="./bottle_neck/resnet152_valid.h",
)

In [12]:
# Evaluate the classifier on the validation set, one sample at a time.
classify_model = classify_model.eval()
predict = []
# Gradients are not needed for inference; no_grad avoids building autograd
# graphs for every sample.
with torch.no_grad():
    for embed, _ in valid_data_loader.dataset:
        # Make sure the input lives on the same device as the model.
        logits = classify_model(embed.to(device))
        predict.append(logits.argmax().item())

# NOTE: this writes the `predict` and `correct` columns into the dataset's own
# reference table (no copy is made); re-running the cell overwrites them.
df_refer = valid_data_loader.dataset.refer
df_refer["predict"] = predict
df_refer['correct'] = df_refer.predict == df_refer.disease_class
# Mean of a boolean column is the fraction of correct predictions.
accuracy = df_refer['correct'].mean()
print('The final accuracy is %.2f%%.' % (accuracy*100))

The final accuracy is 64.65%.
