In [1]:
import os
import pickle
import copy
import torch
import torch.nn as nn
import pickle
import cv2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import models

In [2]:
class RES34(nn.Module):
    def __init__(self):
        super(RES34, self).__init__()
        self.num_cls = 11
        self.base = models.resnet34(pretrained=True)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        x = self.base(x)
        return x
    
    def save(self, name='res34'):
        name += '.pth'
        torch.save(self.state_dict(), name)

    def loadIfExist(self, name='res34'):
        name += '.pth'
        fileList = os.listdir("./")
        # print(fileList)
        if name in fileList:
            self.load_state_dict(torch.load(name))
            print("the latest model has been load")


In [3]:
ROOT = 'E:\\abag\\Course\\Third semester 2\\project\\anime_dataset\\'


color_list = ['black',  'red', 'orange', 'yellow', 'green', 'cyan', 'blue', 'purple', 'pink', 'brown', 'gray']
COLOR2LABEL = dict(zip(color_list, range(len(color_list))))
LABEL2COLOR = dict(zip(range(len(color_list)), color_list))


class Anime_Dataset(Dataset):
    
    def __init__(self, train=True):
        
        dir_paths = [os.path.join(ROOT, s) for s in color_list]
        self.paths = []
        for dir_path in dir_paths:
            files = os.listdir(dir_path)
            if train:
                file_paths = [os.path.join(dir_path,file) for file in files if file.endswith('.jpg')][:-200]
                self.paths.extend(file_paths)
            else:
                file_paths = [os.path.join(dir_path,file) for file in files if file.endswith('.jpg')][-200:]
                self.paths.extend(file_paths)
        
    
    def __getitem__(self, index):
        image = cv2.imread(self.paths[index])
        image = cv2.resize(image, (64, 64))  # shape=[128, 128, 3]
        image = image.transpose([2, 0, 1])  # shape=[3, 128, 128]
        image = torch.tensor(image, dtype=torch.float32)
        image = image / 128. - 1
        
        color = self.paths[index].split(os.sep)[-2]
        label = torch.tensor(COLOR2LABEL[color], dtype=torch.int64)
        return image, label

    def __len__(self):
        return len(self.paths)

In [4]:
def calc_acc(labels, output):
    output = torch.argmax(output, dim=1)
    labels = labels.to(torch.int64)
    return torch.mean(torch.eq(output, labels).to(torch.float32)).item()

In [5]:
batch_size = 32

train_dataset = Anime_Dataset()
test_dataset = Anime_Dataset(False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

FileNotFoundError: [WinError 3] 系统找不到指定的路径。: 'E:\\abag\\Course\\Third semester 2\\project\\anime_dataset\\purple'

In [6]:
lr = 0.0001

model = RES34()
model.loadIfExist()
if torch.cuda.is_available():
    model = model.cuda()
# print(model)
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()

the latest model has been load
the latest model has been load


In [18]:
epochs = 3

for e in range(epochs):
    for batch_images, batch_labels in train_loader:
        x = batch_images.cuda()
        y = batch_labels.cuda()
        output = model(x)
        loss = loss_fn(output, y)
        print('loss %5.4f' % loss.item(), end='\r')
        optim.zero_grad()
        loss.backward()
        optim.step()

    model.train(False)
    with torch.no_grad():
        outputs = []
        labels = []
        for batch_images, batch_labels in test_loader:
            output = model(batch_images.cuda())
            outputs.append(output.data.cpu())
            labels.append(batch_labels)
        acc = calc_acc(torch.cat(labels, dim=0), torch.cat(outputs, dim=0))
    print('epoch %3d, acc %5.4f' % (e + 1, acc))
    model.train()

epoch   1, acc 0.8372
epoch   2, acc 0.8621
epoch   3, acc 0.8792


In [20]:
model.save()

In [12]:
PATH = 'E:\\abag\\Course\\Third semester 2\\project\\anime_face'

class Anime_verify_Dataset(Dataset):
    
    def __init__(self):
        
        files = os.listdir(PATH)
        self.paths = [os.path.join(PATH,image) for image in files if image.endswith('.jpg')]
    
    def __getitem__(self, index):
        path = self.paths[index]
        image = cv2.imread(self.paths[index])
        image = cv2.resize(image, (64, 64))  # shape=[128, 128, 3]
        image = image.transpose([2, 0, 1])  # shape=[3, 128, 128]
        image = torch.tensor(image, dtype=torch.float32)
        image = image / 128. - 1
        
        return image, path

    def __len__(self):
        return len(self.paths)

In [13]:
verify_dataset = Anime_verify_Dataset()

verify_loader = DataLoader(verify_dataset, batch_size=batch_size, shuffle=False)

In [14]:
model.train(False)
with torch.no_grad():
    colors = []
    paths = []
    for batch_images, batch_paths in verify_loader:
        output = model(batch_images.cuda())
        output = torch.argmax(output, dim=1)
        colors.append(output)
        paths.append(batch_paths)

In [15]:
colors[0],paths[0]

(tensor([3, 9, 9, 3, 9, 7, 3, 9, 3, 4, 4, 3, 7, 9, 7, 3, 9, 9, 9, 7, 9, 2, 9, 3,
         9, 7, 9, 6, 9, 9, 5, 9], device='cuda:0'),
 ('E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000003.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000018.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000021.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000024.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000027.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000033.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000034.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000041.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000050.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000054.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime_face\\0000058.jpg',
  'E:\\abag\\Course\\Third semester 2\\project\\anime

In [16]:
length = len(colors)
for i in range(length):
    color = colors[i].cpu().numpy()
    color = [LABEL2COLOR[c] for c in color]
    path = paths[i]
    for j in range(len(color)):
        img = cv2.imread(path[j])
        name = 'E:\\abag\\Course\\Third semester 2\\project\\anime_dataset\\' + color[j]  + '\\face'+ str(i*16+j) + '.jpg'
        cv2.imwrite(name,img)
        os.remove(path[j])
    per = i / length
    print('进度：%4.3f'%per,  end='\r')

进度：0.999