# 猫狗分类问题

## 关于实现的问题

本文使用 PyTorch 实现。之前已经介绍过 PyTorch 及其视觉工具包 torchvision 的用法，这里综合之前写过的相关代码来完成整个流程。  
  
一般需要实现下面几个东西：
1. 模型定义
2. 数据处理和预加载
3. 训练模型
4. 训练过程的可视化
5. 测试

In [20]:
import os
from PIL import Image
import torch as t
from torch.utils import data
import numpy as np
from torchvision import transforms as T

In [13]:
# Sanity check for the sort key used by the dataset class below: pull the
# numeric part out of an image path so files can be ordered by that number.
x = 'data/train/001.jpg'
# Drop the extension first, then everything up to the last path separator.
key = x.rsplit('.', 2)[-2].rsplit('/', 1)[-1]
print(key)

001


# 定义猫狗数据类，获取数据集

In [16]:
class DogCat(data.Dataset):
    """Dogs-vs-cats image dataset backed by a flat directory of image files.

    Every file name must carry an integer id as its second-to-last
    dot-separated field, e.g. ``cat.0.jpg`` / ``dog.12.jpg`` for labelled
    images or ``7.jpg`` for test images.

    Args:
        root: Directory containing the image files.
        transforms: Callable applied to each opened PIL image. When ``None``
            a default torchvision pipeline is built (with random augmentation
            for the training split only).
        train: Ignored when ``test`` is true. Otherwise a truthy value selects
            the first 70% of the sorted files (training split) and a falsy
            value (the default) selects the remaining 30% (validation split).
        test: When true, use every file in ``root`` and return the numeric
            file id in place of a class label.
    """

    def __init__(self, root, transforms=None, train=None, test=False):
        self.test = test

        # imgs is the list of full paths of every file found in root.
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        # Order by numeric id. The file name breaks ties (cat.3 vs dog.3) so
        # the train/validation split does not depend on os.listdir() order.
        imgs = sorted(
            imgs, key=lambda p: (self._image_id(p), os.path.basename(p))
        )

        # Total number of images, used to place the 70/30 split point.
        imgs_num = len(imgs)
        split = int(0.7 * imgs_num)

        if self.test:
            # Test set: keep everything.
            self.imgs = imgs
        elif train:
            # Training set: first 70%.
            self.imgs = imgs[:split]
        else:
            # Validation set: remaining 30%.
            self.imgs = imgs[split:]

        if transforms is None:
            augment = bool(train) and not self.test
            transforms = self._default_transforms(augment)
        # Always store the pipeline, including one supplied by the caller;
        # __getitem__ relies on this attribute.
        self.transforms = transforms

    @staticmethod
    def _image_id(path):
        """Return the integer id embedded in the file name of ``path``."""
        return int(os.path.basename(path).split('.')[-2])

    @staticmethod
    def _default_transforms(augment):
        """Build the default pipeline; ``augment`` enables random crop/flip."""
        # Per-channel normalization constants kept from the original code.
        normalize = T.Normalize(mean=[0.485, .456, .406],
                                std=[.229, .224, .225])

        if augment:
            return T.Compose([
                T.Resize(256),
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize,
            ])

        return T.Compose([
            T.Resize(224),
            T.CenterCrop(224),
            T.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the file at ``index``.

        In test mode the second element is the numeric file id rather than a
        class label; otherwise it is 1 for dog images and 0 for everything
        else.
        """
        img_path = self.imgs[index]
        if self.test:
            # Test images have no label, so hand back their id instead.
            label = self._image_id(img_path)
        else:
            # Look at the file name only, so a directory called e.g.
            # "dogs-vs-cats" cannot influence the label.
            label = 1 if 'dog' in os.path.basename(img_path) else 0

        image = Image.open(img_path)
        image = self.transforms(image)
        return image, label

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.imgs)
        

# 配置文件
  
这个单元格可以看作一个配置文件，集中存放数据路径和加载参数。

In [19]:

# Loader settings: samples per batch and number of DataLoader worker processes.
batch_size, num_workers = 16, 2
# Dataset locations, relative to the notebook's working directory.
test_data_root = '../../data/dogs-vs-cats/test1/'
train_data_root = '../../data/dogs-vs-cats/train/'

# 加载数据集并训练

In [None]:
# Build the training split of the dataset and wrap it in a shuffling loader.
train_dataset = DogCat(train_data_root, train=True)

trainloader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# The batch variable is deliberately not called `data`: rebinding that name
# would shadow the `torch.utils.data` module alias and break every later use
# of `data.Dataset` / `data.DataLoader`.
for ii, (inputs, label) in enumerate(trainloader):
    # NOTE(review): train() is not defined anywhere in the visible source and
    # receives neither the batch nor the labels - confirm the intended call.
    train()
    

# 模型接口封装

In [21]:
class BasicModule(t.nn.Module):
    """Thin ``torch.nn.Module`` base class adding ``load``/``save`` helpers."""

    def __init__(self):
        super(BasicModule, self).__init__()
        # Default model name; save() uses it as the checkpoint name prefix.
        self.module_name = str(type(self))

    def load(self, path):
        """Load the parameters stored in the state-dict file at ``path``."""
        self.load_state_dict(t.load(path))

    def save(self, name=None):
        """Write the module's state dict to disk and return the file name.

        Args:
            name: Target file name. When ``None`` a name of the form
                ``<module_name>_<mmdd>_<HH:MM:SS>.pth`` is generated from the
                current local time.

        Returns:
            The file name the state dict was written to.
        """
        if name is None:
            # Imported here because this file only imports `time` further
            # down, so the name may not exist yet when save() is called.
            import time

            # Only the timestamp goes through strftime; the prefix is
            # concatenated afterwards so it is never read as a format string.
            name = self.module_name + '_' + time.strftime('%m%d_%H:%M:%S.pth')
        t.save(self.state_dict(), name)

        return name

# 导入其他模型

In [23]:
from torchvision.models import AlexNet
from torchvision.models import resnet34



# 工具函数区

In [25]:
#coding: utf-8
import visdom
import time
import numpy as np

In [27]:
class Visualizer(object):
    """Small convenience wrapper around a ``visdom.Visdom`` client.

    It adds helpers for incrementally plotted scalars, image grids and an
    accumulating text log. Any attribute not defined here is looked up on the
    wrapped client, so the native visdom API stays reachable, e.g.::

        vis.text('hello visdom')
        vis.histogram(t.randn(1000))
        vis.line(t.arange(0, 10), t.arange(1, 11))
    """

    def __init__(self, env='default', **kwargs):
        """Create the visdom client for ``env``; ``kwargs`` go to ``Visdom``."""
        self.vis = visdom.Visdom(env=env, **kwargs)

        # Next x position for every scalar series, keyed by series name.
        self.index = {}
        # HTML text accumulated by log().
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """Replace the visdom client with a new one and return ``self``."""
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        """Plot several scalars at once.

        Args:
            d: Mapping of series name to value, e.g. ``{'loss': 0.11}``.
        """
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        """Show several images at once; ``d`` maps window name to image."""
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y, **kwargs):
        """Append the value ``y`` to the line plot called ``name``.

        The x coordinate is a per-series counter starting at 0. The first
        point creates the window; later points are appended to it.
        """
        x = self.index.get(name, 0)

        self.vis.line(
            Y=np.array([y]),
            X=np.array([x]),
            win=name,
            # The visdom keyword is `opts`, not `opt`.
            opts=dict(title=name),
            update=None if x == 0 else 'append',
            **kwargs
        )
        self.index[name] = x + 1

    def img(self, name, img_, **kwargs):
        """Show ``img_`` in the window ``name``.

        ``img_`` must provide ``.cpu().numpy()``, both of which are called on
        it before it is handed to ``vis.images``.
        """
        self.vis.images(
            img_.cpu().numpy(),
            win=name,
            opts=dict(title=name),
            **kwargs
        )

    def log(self, info, win='log_text'):
        """Append a timestamped line to the text log and redraw it in ``win``."""
        self.log_text += ('[{time}]{info}<br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info
        ))
        self.vis.text(self.log_text, win)

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped visdom client."""
        # __getattr__ only runs when normal lookup fails. If `vis` itself is
        # missing (instance built without __init__), looking it up again
        # would recurse without bound, so fail with a plain AttributeError.
        if name == 'vis':
            raise AttributeError(name)
        return getattr(self.vis, name)
    