In [33]:
import os
import glob
import random
import csv
import time

from PIL import Image
import torch
import torchvision.transforms as transforms
import visdom
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
import torch.nn as nn
import torch.optim as optim

In [34]:
class Pokemon(Dataset):

    def __init__(self, root, resize=32, mode="train", csv_filename='images.csv'):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        self.images = []
        self.labels = []

        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        self.load_csv(csv_filename)
        
        if mode == 'train':
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode == 'val':
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        elif mode == 'test':
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]
        else:
            print('请重新输入mode')

    def load_csv(self, filename):
        # 保存在CSV文件中
        if not os.path.exists(os.path.join(self.root, filename)):
            for name in self.name2label.keys():
                self.images += glob.glob(os.path.join(self.root, name, '*.png'))
                self.images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                self.images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            random.shuffle(self.images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in self.images:
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                self.images.append(img)
                self.labels.append(label)
        assert len(self.images) == len(self.labels)
        
    def __len__(self):
        return len(self.images)
    
    def denormalize(self, x_hat):
        
        mean=[0.435, 0.456, 0.406]
        std=[0.229, 0.224, 0.225]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        
        x = x_hat * std + mean
        
        return x
    
    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),
            transforms.Resize((self.resize, self.resize)),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.435, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(int(label))
        return img,label

In [4]:
db = Pokemon('D:\MyCode\dataset\pokeman', resize=224, mode="train")
viz = visdom.Visdom()
x, y = next(iter(db))
print('sample:', x.shape, y.shape, y)
viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

# loader = DataLoader(db, batch_size=32, shuffle=True)
# for x, y in loader:
#     viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
#     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
#     time.sleep(10)

Setting up a new session...


sample: torch.Size([3, 224, 224]) torch.Size([]) tensor(1)


'sample_x'

In [35]:
def evalute(model, loader):
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct/total

In [36]:
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

In [None]:
batchsz = 32
lr = 1e-3
epochs = 10
# device = torch.device('cuda')
torch.manual_seed(1234)


def main():
    train_db = Pokemon('D:\MyCode\dataset\pokeman', 224, mode='train')
    val_db = Pokemon('D:\MyCode\dataset\pokeman', 224, mode='val')
    test_db = Pokemon('D:\MyCode\dataset\pokeman', 224, mode='test')
    
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True) # num_workers=2
    val_loader = DataLoader(val_db, batch_size=batchsz)
    test_loader = DataLoader(test_db, batch_size=batchsz)
    
    
    trained_model = resnet18(pretrained=True)#.to(device)
    model = nn.Sequential(*list(trained_model.children())[:-1],# 测试一下输出维度[b, 512, 1, 1]
                          Flatten(), #
                          nn.Linear(512, 5)
                         )
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    
    best_acc = 0
    best_epoch = 0
    for epoch in range(epochs):
        
        for step, (x, y) in enumerate(train_loader):
            
            # x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                # 保存参数
                torch.save(model.state_dict(), 'best.mdl')
                
                
    # 加载参数               
    model.load_state_dict(torch.load('best.mdl'))   
    test_acc = evalute(model, test_loader)
main()

In [20]:
        train(train_db)

if epoch%10 == 0:
    val_acc = evaluate(val_db)

    if val_ass is the best:
        save_ckpt()

    if out_of_pratience():
        break
load_ckpt()
test_acc = evaluate(test_db)

SyntaxError: invalid syntax (3642243474.py, line 6)