In [2]:
root = './train_data'

In [3]:
import torch
import torchvision
import time
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter

In [6]:
def get_device():
    ''' Get device (if GPU is available, use GPU) '''
    return 'cuda' if torch.cuda.is_available() else 'cpu'

Setup Hyper-parameters 

In [7]:
writer = SummaryWriter('./log')

In [8]:
config = {
    'n_epochs': 100,                # maximum number of epochs
    'batch_size': 128,               # mini-batch size for dataloader
    'optimizer': 'SGD',              # optimization algorithm (optimizer in torch.optim)
    'optim_hparas': {                # hyper-parameters for the optimizer (depends on which optimizer you are using)
        'lr': 0.1,                 # learning rate of SGD
        'momentum': 0.4              # momentum for SGD
    },
    'save_path': './models_ReLU3/model.pth'  # your model will be saved here
}

In [9]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(40),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.4914, 0.4822,0.4465], [0.2023,0.1994,0.2010])
])


cifar10 = torchvision.datasets.CIFAR10(root= root ,
                            train=True,
                            download=False,
                            transform=transform)

cifar10_test = torchvision.datasets.CIFAR10(root= root ,
                            train=False,
                            download=False,
                            transform=transform)

In [10]:
train_loader= torch.utils.data.DataLoader(cifar10, batch_size=config['batch_size'],shuffle=True)

test_loader= torch.utils.data.DataLoader(cifar10_test, batch_size=config['batch_size'], shuffle=True)

In [11]:
class MLP(torch.nn.Module):
    def __init__(self, num_i, num_h1,num_h2, num_h3,num_o):
        super(MLP, self).__init__()
        
        self.linear1 = torch.nn.Linear(num_i, num_h1) #输入层到第一层隐藏层的线性转换
        self.linear2 = torch.nn.Linear(num_h1, num_h2)
        self.linear3 = torch.nn.Linear(num_h2, num_h3)
        self.linear4 = torch.nn.Linear(num_h3, num_o)
        self.relu = torch.nn.ReLU()
               
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        
        x = self.linear2(x)
        x = self.relu(x)
        
        x =self.linear3(x)
        x = self.relu(x)
        
        x = self.linear4(x) 
        return x

In [12]:
def train(model, train_loader,device, epoch, epochs):
    loss_func = torch.nn. CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters())
   
    # Setup optimizer
    optimizer = getattr(torch.optim, config['optimizer'])(
        model.parameters(), **config['optim_hparas'])
    
    epoch = epoch
    epochs =epochs
    sum_loss =0
    train_correct = 0
    
    for data in train_loader:
        inputs, labels= data
        inputs = torch.flatten(inputs, start_dim=1)
        inputs , labels= inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
     
        _, id = torch.max(outputs.data ,1)
        sum_loss += loss.data
        train_correct += torch.sum(id == labels.data)
               
    print('[%d/%d] loss:%.3f, correct:%.3f%%, time:%s' %
        (epoch+1 , epochs , sum_loss/len(train_loader),
        100 * train_correct / len(cifar10),
        time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())))
    
    writer.add_scalar('Loss/train', sum_loss/len(train_loader), epoch)
    writer.add_scalar('Accuracy/train', 100 * train_correct / len(cifar10), epoch)
        

In [13]:
def test(model, test_loader,device,epoch):
    test_correct = 0
    
    epoch = epoch
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            inputs, labels =data 
            inputs, labels = inputs.to(device),labels.to(device)
            inputs = torch.flatten(inputs, start_dim=1)
            outputs = model(inputs)
            
            _,id = torch.max(outputs.data, 1)
            test_correct += torch.sum(id == labels.data)
            
        print(f'Accuracy on test set:{100*test_correct / len(cifar10_test):.3f}%')
     
    writer.add_scalar('Accuracy/validation', 100*test_correct / len(cifar10_test), epoch)
    return test_correct

In [14]:
def main(model, train_loader, test_loader, device, n_epochs, save_path):
    max_test_correct = 0
    
    epochs =  n_epochs
    save_path = save_path
    
    for epoch in range(epochs):
        train(model, train_loader,device, epoch ,epochs)
    
        test_correct = test(model, test_loader, device, epoch)
        test_accuracy = 100*test_correct / len(cifar10_test)
        test_accuracy_str = "{:.3f}".format(test_accuracy)
    
        if test_correct > max_test_correct :      #选取相较之前在测试集上表现较好的模型进行保存
            max_test_correct = test_correct
            save_path_with_accuracy = save_path[:-4]+'_'+ test_accuracy_str + '%'+save_path[-4:]
            torch.save(model.state_dict(),save_path_with_accuracy)
    

In [15]:
device = get_device()
model = MLP(3072, 1024, 128,64, 10).to(device)

In [None]:
main(model, train_loader,test_loader, device, config['n_epochs'],config['save_path'])

In [17]:
trained_model = MLP(3072, 1024, 128,64, 10).to(device)
trained_model.load_state_dict(torch.load('./model_pth/model_58.980%.pth'))  

training_labels = []
model_pre_train = []

testing_labels = []
model_pre_test = []

test_correct = 0
with torch.no_grad():
    for data in train_loader:
        inputs, labels =data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = torch.flatten(inputs, start_dim=1)
        outputs = trained_model(inputs)
        _, id = torch.max(outputs.data ,1)
        training_labels = training_labels + labels.tolist()
        model_pre_train = model_pre_train + id.tolist()
        
    for data in test_loader:
        inputs, labels =data 
        inputs, labels = inputs.to(device),labels.to(device)
        inputs = torch.flatten(inputs, start_dim=1)
        outputs = trained_model(inputs)        
        _,id = torch.max(outputs.data, 1)          
        testing_labels = testing_labels + labels.tolist()
        model_pre_test = model_pre_test + id.tolist()
        
        test_correct += torch.sum(id == labels.data)
            
    print(f'Accuracy on test set:{100*test_correct / len(cifar10_test):.3f}%')

Accuracy on test set:59.250%


In [None]:
##可视化在训练数据和验证数据上的混淆矩阵
train_confm = confusion_matrix(training_labels, model_pre_train)
val_confm = confusion_matrix(testing_labels, model_pre_test)

plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
sns.heatmap(train_confm, square=True, annot=True, fmt='d',
           cbar=False, cmap="YlGnBu")

plt.xlabel("Real label")              #真实的标签
plt.ylabel("Predicted label")               #预测的标签
plt.title("Confusion matrix (training dataset)")            #混淆矩阵（训练集）
plt.subplot(1,2,2)
sns.heatmap(val_confm, square=True, annot=True, fmt='d',
           cbar=False, cmap="YlGnBu")

plt.xlabel("Real label")                     #真实的标签
plt.ylabel("Predicted label")                          #预测的标签
plt.title("Confusion matrix (validation set)")         #混淆矩阵（验证集）
plt.tight_layout()
plt.show()

In [18]:
print(trained_model)

MLP(
  (linear1): Linear(in_features=3072, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=64, bias=True)
  (linear4): Linear(in_features=64, out_features=10, bias=True)
  (relu): ReLU()
)
