In [8]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import sys
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torchvision.datasets import ImageFolder
from torchvision import transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from mmdet.models.backbones.yololite_logo import YoloLiteLogo
from mmdet.models.backbones.yololite import YoloLite

class cls_model(nn.Module):
    def __init__(self):
        super(cls_model, self).__init__()
        self.backbone = YoloLiteLogo()
        #self.backbone = YoloLite()
        self.fc1 = nn.Linear(1024 * 4, 1024)
        #self.fc1 = nn.Linear(25600, 1024)
        self.fc2 = nn.Linear(1024, 5)
        self.dp = nn.Dropout(p = 0.5)
        
    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.view(x.shape[0], -1)
        x = self.fc2(self.dp(self.fc1(x)))
        return x

# params
lr = 0.05
batch_size = 128
epochs = 1000
checkpoint_step = 50
# transform
transform = transforms.Compose([
    transforms.Resize((320, 320)), # 缩放到 96 * 96 大小
    transforms.ToTensor()
])
# dataset
dataset = ImageFolder(root = '/data/zhaozhiyuan/ImageNet-5', transform = transform)
train_size = int(len(dataset) * 0.7)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(dataset = train_dataset, batch_size = batch_size, num_workers = 4, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size = batch_size, num_workers = 4, shuffle = True)
# model
model = cls_model()
model.train()
model.cuda()
# loss
creterion = nn.CrossEntropyLoss()
# optimizer
optimizer = torch.optim.SGD(params = model.parameters(), lr = lr)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[60], gamma=0.1)

if not os.path.exists('/data/zhaozhiyuan/mmdetection_checkpoints/yololite_logo'):
    os.mkdir('/data/zhaozhiyuan/mmdetection_checkpoints/yololite_logo')
#train
for epoch in range(epochs):
    loss_sum = 0
    for X, y in train_loader:
        X, y = X.cuda(), y.cuda()
        y_hat = model(X)
        loss = creterion(y_hat, y)
        loss_sum += float(loss.detach().cpu())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    # test
    model.eval()
    n, correct = 0, 0
    for X, y in test_loader:
        X, y = X.cuda(), y.cuda()
        y_hat = model(X)
        y_pred = torch.argmax(y_hat, dim = 1)
        correct += int(torch.sum(y_pred == y))
        n += y.shape[0]
    lr = optimizer.param_groups[0]['lr']
    print('epoch {}, test_acc {:.6f}, loss {:.9f}, lr {:.3f}'.format(epoch, correct / n, loss_sum / (len(train_loader) * batch_size), lr))
    model.train()
    if epoch % checkpoint_step == 0 and epoch != 0:
        torch.save(model, '/data/zhaozhiyuan/mmdetection_checkpoints/yololite_logo/model_{}.pkl'.format(epoch))

epoch 0, test_acc 0.251504, loss 0.012514437, lr 0.050
epoch 1, test_acc 0.251504, loss 0.012416666, lr 0.050
epoch 2, test_acc 0.251504, loss 0.012381208, lr 0.050
epoch 3, test_acc 0.251504, loss 0.012366347, lr 0.050
epoch 4, test_acc 0.251504, loss 0.012358800, lr 0.050
epoch 5, test_acc 0.251504, loss 0.012371713, lr 0.050
epoch 6, test_acc 0.251504, loss 0.012357952, lr 0.050
epoch 7, test_acc 0.251504, loss 0.012358654, lr 0.050
epoch 8, test_acc 0.251504, loss 0.012351892, lr 0.050
epoch 9, test_acc 0.251504, loss 0.012364510, lr 0.050
epoch 10, test_acc 0.251504, loss 0.012360363, lr 0.050
epoch 11, test_acc 0.251504, loss 0.012345033, lr 0.050
epoch 12, test_acc 0.251504, loss 0.012348635, lr 0.050
epoch 13, test_acc 0.251504, loss 0.012351416, lr 0.050
epoch 14, test_acc 0.251504, loss 0.012341646, lr 0.050
epoch 15, test_acc 0.252597, loss 0.012332086, lr 0.050
epoch 16, test_acc 0.251504, loss 0.012314558, lr 0.050
epoch 17, test_acc 0.261345, loss 0.012295325, lr 0.050
ep

  "type " + obj.__name__ + ". It won't be checked "


epoch 51, test_acc 0.782395, loss 0.004819695, lr 0.050
epoch 52, test_acc 0.751777, loss 0.004650043, lr 0.050
epoch 53, test_acc 0.702570, loss 0.005405847, lr 0.050
epoch 54, test_acc 0.750683, loss 0.005452016, lr 0.050
epoch 55, test_acc 0.746856, loss 0.004654499, lr 0.050
epoch 56, test_acc 0.682340, loss 0.004107485, lr 0.050
epoch 57, test_acc 0.780208, loss 0.003953539, lr 0.050
epoch 58, test_acc 0.734828, loss 0.003656791, lr 0.050
epoch 59, test_acc 0.716785, loss 0.003473797, lr 0.005
epoch 60, test_acc 0.810279, loss 0.002353216, lr 0.005
epoch 61, test_acc 0.809185, loss 0.002133037, lr 0.005
epoch 62, test_acc 0.806998, loss 0.002005464, lr 0.005
epoch 63, test_acc 0.804265, loss 0.001989276, lr 0.005
epoch 64, test_acc 0.806998, loss 0.001919769, lr 0.005
epoch 65, test_acc 0.806452, loss 0.001853482, lr 0.005
epoch 66, test_acc 0.807545, loss 0.001806680, lr 0.005
epoch 67, test_acc 0.802078, loss 0.001757102, lr 0.005
epoch 68, test_acc 0.804811, loss 0.001726719, l

epoch 196, test_acc 0.805905, loss 0.000036790, lr 0.005
epoch 197, test_acc 0.805358, loss 0.000037944, lr 0.005
epoch 198, test_acc 0.805358, loss 0.000034651, lr 0.005
epoch 199, test_acc 0.805905, loss 0.000035980, lr 0.005
epoch 200, test_acc 0.803718, loss 0.000034825, lr 0.005
epoch 201, test_acc 0.803718, loss 0.000035777, lr 0.005
epoch 202, test_acc 0.804811, loss 0.000031059, lr 0.005
epoch 203, test_acc 0.803718, loss 0.000033515, lr 0.005
epoch 204, test_acc 0.802624, loss 0.000030780, lr 0.005
epoch 205, test_acc 0.802078, loss 0.000030737, lr 0.005
epoch 206, test_acc 0.802078, loss 0.000028994, lr 0.005
epoch 207, test_acc 0.804265, loss 0.000032258, lr 0.005
epoch 208, test_acc 0.805905, loss 0.000029038, lr 0.005
epoch 209, test_acc 0.806998, loss 0.000028913, lr 0.005
epoch 210, test_acc 0.804265, loss 0.000029197, lr 0.005
epoch 211, test_acc 0.806452, loss 0.000027242, lr 0.005
epoch 212, test_acc 0.804265, loss 0.000027044, lr 0.005
epoch 213, test_acc 0.806452, l

epoch 340, test_acc 0.804811, loss 0.000006077, lr 0.005
epoch 341, test_acc 0.805358, loss 0.000005977, lr 0.005
epoch 342, test_acc 0.804811, loss 0.000006249, lr 0.005
epoch 343, test_acc 0.804265, loss 0.000005937, lr 0.005
epoch 344, test_acc 0.804265, loss 0.000005652, lr 0.005
epoch 345, test_acc 0.804811, loss 0.000006802, lr 0.005
epoch 346, test_acc 0.803718, loss 0.000006940, lr 0.005
epoch 347, test_acc 0.803171, loss 0.000006798, lr 0.005
epoch 348, test_acc 0.804265, loss 0.000005956, lr 0.005
epoch 349, test_acc 0.805358, loss 0.000005907, lr 0.005
epoch 350, test_acc 0.804811, loss 0.000005902, lr 0.005
epoch 351, test_acc 0.804811, loss 0.000006033, lr 0.005
epoch 352, test_acc 0.804265, loss 0.000005725, lr 0.005
epoch 353, test_acc 0.804265, loss 0.000005285, lr 0.005
epoch 354, test_acc 0.804811, loss 0.000006021, lr 0.005
epoch 355, test_acc 0.804811, loss 0.000004676, lr 0.005
epoch 356, test_acc 0.803718, loss 0.000005058, lr 0.005
epoch 357, test_acc 0.803718, l

epoch 484, test_acc 0.802078, loss 0.000002845, lr 0.005
epoch 485, test_acc 0.804265, loss 0.000002857, lr 0.005
epoch 486, test_acc 0.803718, loss 0.000003018, lr 0.005
epoch 487, test_acc 0.804811, loss 0.000002926, lr 0.005
epoch 488, test_acc 0.804811, loss 0.000002672, lr 0.005
epoch 489, test_acc 0.803171, loss 0.000002847, lr 0.005
epoch 490, test_acc 0.803718, loss 0.000002768, lr 0.005
epoch 491, test_acc 0.803718, loss 0.000002642, lr 0.005
epoch 492, test_acc 0.805358, loss 0.000002748, lr 0.005
epoch 493, test_acc 0.804811, loss 0.000002625, lr 0.005
epoch 494, test_acc 0.804811, loss 0.000002870, lr 0.005
epoch 495, test_acc 0.805358, loss 0.000002913, lr 0.005
epoch 496, test_acc 0.803718, loss 0.000002519, lr 0.005
epoch 497, test_acc 0.804265, loss 0.000003282, lr 0.005
epoch 498, test_acc 0.804811, loss 0.000002738, lr 0.005
epoch 499, test_acc 0.800984, loss 0.000003375, lr 0.005
epoch 500, test_acc 0.803718, loss 0.000002606, lr 0.005
epoch 501, test_acc 0.803718, l

epoch 628, test_acc 0.804811, loss 0.000001610, lr 0.005
epoch 629, test_acc 0.804265, loss 0.000001700, lr 0.005
epoch 630, test_acc 0.804265, loss 0.000001846, lr 0.005
epoch 631, test_acc 0.804265, loss 0.000001579, lr 0.005
epoch 632, test_acc 0.803718, loss 0.000001722, lr 0.005
epoch 633, test_acc 0.803171, loss 0.000001700, lr 0.005
epoch 634, test_acc 0.803718, loss 0.000001636, lr 0.005
epoch 635, test_acc 0.803171, loss 0.000001693, lr 0.005
epoch 636, test_acc 0.804265, loss 0.000002274, lr 0.005
epoch 637, test_acc 0.804265, loss 0.000001850, lr 0.005
epoch 638, test_acc 0.803171, loss 0.000001749, lr 0.005
epoch 639, test_acc 0.803171, loss 0.000002138, lr 0.005
epoch 640, test_acc 0.803718, loss 0.000001588, lr 0.005
epoch 647, test_acc 0.803171, loss 0.000002052, lr 0.005
epoch 648, test_acc 0.803171, loss 0.000001675, lr 0.005
epoch 649, test_acc 0.804265, loss 0.000001776, lr 0.005
epoch 650, test_acc 0.804811, loss 0.000001621, lr 0.005
epoch 651, test_acc 0.802624, l

epoch 778, test_acc 0.802624, loss 0.000001160, lr 0.005
epoch 779, test_acc 0.803171, loss 0.000001195, lr 0.005
epoch 780, test_acc 0.802624, loss 0.000001300, lr 0.005
epoch 781, test_acc 0.802624, loss 0.000001147, lr 0.005
epoch 782, test_acc 0.803171, loss 0.000001153, lr 0.005
epoch 783, test_acc 0.803171, loss 0.000001175, lr 0.005
epoch 784, test_acc 0.804811, loss 0.000001564, lr 0.005
epoch 785, test_acc 0.804811, loss 0.000001103, lr 0.005
epoch 786, test_acc 0.804265, loss 0.000001250, lr 0.005
epoch 787, test_acc 0.803171, loss 0.000001094, lr 0.005
epoch 788, test_acc 0.804265, loss 0.000001198, lr 0.005
epoch 789, test_acc 0.804265, loss 0.000001227, lr 0.005
epoch 790, test_acc 0.803718, loss 0.000001014, lr 0.005
epoch 791, test_acc 0.804265, loss 0.000001158, lr 0.005
epoch 792, test_acc 0.804265, loss 0.000001166, lr 0.005
epoch 793, test_acc 0.803718, loss 0.000001193, lr 0.005
epoch 794, test_acc 0.803171, loss 0.000001294, lr 0.005
epoch 795, test_acc 0.803171, l

epoch 922, test_acc 0.803171, loss 0.000000926, lr 0.005
epoch 923, test_acc 0.803171, loss 0.000001096, lr 0.005
epoch 924, test_acc 0.802624, loss 0.000000904, lr 0.005
epoch 925, test_acc 0.803171, loss 0.000000871, lr 0.005
epoch 926, test_acc 0.803718, loss 0.000000996, lr 0.005
epoch 927, test_acc 0.803718, loss 0.000000900, lr 0.005
epoch 928, test_acc 0.803171, loss 0.000001039, lr 0.005
epoch 929, test_acc 0.803718, loss 0.000000876, lr 0.005
epoch 930, test_acc 0.803171, loss 0.000000944, lr 0.005
epoch 931, test_acc 0.803171, loss 0.000000745, lr 0.005
epoch 932, test_acc 0.803718, loss 0.000000918, lr 0.005
epoch 933, test_acc 0.801531, loss 0.000000912, lr 0.005
epoch 934, test_acc 0.803171, loss 0.000001200, lr 0.005
epoch 935, test_acc 0.803171, loss 0.000000943, lr 0.005
epoch 936, test_acc 0.803171, loss 0.000000884, lr 0.005
epoch 937, test_acc 0.802624, loss 0.000000881, lr 0.005
epoch 938, test_acc 0.803718, loss 0.000000906, lr 0.005
epoch 939, test_acc 0.803171, l

In [9]:
model = cls_model()
model.train()
model = torch.load('/data/zhaozhiyuan/mmdetection_checkpoints/yololite_logo/model_950.pkl')
print(type(model))
model.cuda()
print(type(model.state_dict()))
torch.save(model.backbone.state_dict(), '/data/zhaozhiyuan/mmdetection_checkpoints/yololite_logo/pretrained_on_imagenet.pkl')

<class '__main__.cls_model'>
<class 'collections.OrderedDict'>
