In [None]:
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

dataset_path = Path("datasets/raw")

def default_loader(path):
    """Open the image at ``path`` and return an RGB copy of its pixels.

    The file is opened in a ``with`` block so its handle is released as soon as the
    converted copy exists, rather than leaving the close to Pillow / garbage
    collection.
    """
    with Image.open(path) as img:
        return img.convert('RGB')

class SkillDataset(Dataset):
    """Four-class image dataset read from the ``*.png`` files of one directory.

    The label of each image comes from the suffix of its filename stem:
    ``*_0.png`` -> 0, ``*_1.png`` -> 1, ``*_2.png`` -> 2, ``*_3.png`` -> 3.
    Files with any other suffix are ignored. All samples are loaded, transformed
    and cached in ``self.data`` during construction.
    """

    def __init__(self, path: Path) -> None:
        """Scan ``path`` for labelled images and eagerly load all of them.

        Args:
            path: directory containing the ``*.png`` images (not searched
                recursively). A ``str`` is accepted as well.
        """
        super().__init__()
        # Honor the `path` argument rather than the module-level `dataset_path`.
        # Sorting makes the sample order independent of filesystem listing order.
        all_img = sorted(Path(path).glob("*.png"))
        self.c0 = [p for p in all_img if p.stem.endswith("_0")]
        self.c1 = [p for p in all_img if p.stem.endswith("_1")]
        self.c2 = [p for p in all_img if p.stem.endswith("_2")]
        self.c3 = [p for p in all_img if p.stem.endswith("_3")]
        print("len c0: ", len(self.c0), " c1: ", len(self.c1), " c2: ", len(self.c2), " c3: ", len(self.c3))

        self.loader = default_loader
        self.transform = transforms.Compose([
            transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
            transforms.RandomPosterize(3),
            transforms.RandomAdjustSharpness(3),
            transforms.RandomAutocontrast(),
            transforms.ToTensor(),
        ])
        # The images are small, so load them all into memory at once.
        # NOTE(review): because the transform runs here, its random augmentations
        # are sampled once per image and then frozen. They are not re-drawn each
        # epoch, and they also apply to samples that end up in the test split.
        self.data = [self.get(i) for i in range(len(self))]

    def __len__(self):
        """Total number of labelled images across all four classes."""
        return len(self.c0) + len(self.c1) + len(self.c2) + len(self.c3)

    def get(self, index):
        """Load and transform the sample at flat ``index``; return ``(tensor, label)``.

        Indices are laid out class by class: first all of ``c0``, then ``c1``,
        ``c2`` and ``c3``.

        Raises:
            IndexError: if ``index`` is past the end of the dataset.
        """
        if index % 1000 == 0:
            print("load ", index, " / ", len(self))
        offset = index
        for label, paths in enumerate((self.c0, self.c1, self.c2, self.c3)):
            if offset < len(paths):
                return self.transform(self.loader(paths[offset])), label
            offset -= len(paths)
        raise IndexError(f"index {index} out of range for dataset of size {len(self)}")

    def __getitem__(self, index):
        """Return the cached ``(tensor, label)`` pair for ``index``."""
        return self.data[index][0], self.data[index][1]
        

# Build the full labelled dataset from the raw image directory (every image is loaded eagerly).
dataset = SkillDataset(path=dataset_path)

In [None]:

import torch

# Fixed seed so the train/test index split is reproducible across kernel restarts
# (given the same file order). Without it, a checkpoint reloaded in a new session
# could be evaluated on samples it was trained on.
SPLIT_SEED = 0

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# If GPU memory runs out, reduce the batch size.
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

In [None]:
import torch

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelM5(nn.Module):
    """Five-layer CNN classifier: five (5x5 conv -> BatchNorm -> ReLU) stages plus a linear head.

    Each unpadded 5x5 convolution shrinks the spatial size by 4 pixels per axis.
    An ``input_size x input_size`` image therefore leaves a 160-channel map of side
    ``input_size - 20``, which is flattened into ``fc1``.

    Args:
        channels: number of input image channels.
        classes: number of output classes.
        input_size: side length of the square images the model is built for. The
            default 96 yields the original ``fc1`` width of 160 * 76 * 76 = 924160.

    Raises:
        ValueError: if ``input_size`` is too small to survive the five convolutions.
    """

    def __init__(self, channels=3, classes=4, input_size=96):
        super(ModelM5, self).__init__()
        self.channels = channels
        self.classes = classes

        # Five unpadded 5x5 convolutions remove 4 pixels per layer along each axis.
        feature_side = input_size - 5 * 4
        if feature_side <= 0:
            raise ValueError(f"input_size must be greater than 20, got {input_size}")

        self.conv1 = nn.Conv2d(self.channels, 32, 5, bias=False)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 5, bias=False)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 96, 5, bias=False)
        self.conv3_bn = nn.BatchNorm2d(96)
        self.conv4 = nn.Conv2d(96, 128, 5, bias=False)
        self.conv4_bn = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 160, 5, bias=False)
        self.conv5_bn = nn.BatchNorm2d(160)
        self.fc1 = nn.Linear(160 * feature_side * feature_side, self.classes, bias=False)
        self.fc1_bn = nn.BatchNorm1d(self.classes)

    def get_logits(self, x):
        """Return batch-normalised class scores of shape (N, classes) for ``x``.

        Assumes ``x`` is (N, channels, input_size, input_size) with values in [0, 1]
        (as produced by ``transforms.ToTensor``). TODO confirm for other callers.
        It is rescaled to [-1, 1] first.
        """
        x = (x - 0.5) * 2.0
        conv1 = F.relu(self.conv1_bn(self.conv1(x)))
        conv2 = F.relu(self.conv2_bn(self.conv2(conv1)))
        conv3 = F.relu(self.conv3_bn(self.conv3(conv2)))
        conv4 = F.relu(self.conv4_bn(self.conv4(conv3)))
        conv5 = F.relu(self.conv5_bn(self.conv5(conv4)))
        # Flatten in (H, W, C) order; kept as-is so existing fc1 weights stay valid.
        flat5 = torch.flatten(conv5.permute(0, 2, 3, 1), 1)
        logits = self.fc1_bn(self.fc1(flat5))
        return logits

    def forward(self, x):
        """Return log-probabilities over the classes, shape (N, classes)."""
        logits = self.get_logits(x)
        return F.log_softmax(logits, dim=1)


In [None]:

from torch.cuda.amp import autocast, GradScaler

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if not use_cuda:
    print("WARNING: CPU will be used for training.")

model = ModelM5().to(device)

# ModelM5.forward already returns log-probabilities (log_softmax), so NLLLoss is
# the matching loss. CrossEntropyLoss would apply log_softmax a second time, which
# gives the same value here but is redundant work.
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Loss scaling only matters for CUDA mixed precision. Disable it explicitly on CPU
# instead of relying on GradScaler's warn-and-disable fallback.
scaler = GradScaler(enabled=use_cuda)

In [None]:
import time

start_time = time.time()
def train(epoch):
    """Run one training epoch over ``train_loader`` and log the mean training loss.

    Uses the notebook globals ``model``, ``train_loader``, ``optimizer``,
    ``criterion``, ``scaler`` and ``device``. The logged ``cost`` is the wall time
    since the previous call finished, or since ``start_time`` was set for the first
    call. It therefore also includes any evaluation run between epochs.

    Args:
        epoch: epoch number, used only for logging.

    Returns:
        The sample-weighted mean training loss for this epoch, or ``nan`` if
        ``train_loader`` yielded no batches.
    """
    global start_time
    model.train()
    total_loss = 0.0
    total_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        with autocast():
            output = model(data)
            loss = criterion(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # criterion averages over the batch; weight by batch size so a short final
        # batch does not skew the epoch mean. Kept as a tensor to avoid a device
        # sync on every step.
        batch_size = target.size(0)
        total_loss = total_loss + loss.detach().float() * batch_size
        total_samples += batch_size

    mean_loss = float(total_loss) / total_samples if total_samples else float("nan")
    cur_time = time.time()
    cost = cur_time - start_time
    print(f'Train Epoch: {epoch}, Loss: {mean_loss:.8f}, cost: {cost:.2f} s')
    start_time = cur_time
    return mean_loss
            
def test():
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            with autocast():
                output = model(data)
                test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = 100. * correct / len(test_loader.dataset)

    print(f'=== Test: Loss: {test_loss:.8f}, Acc: {acc:.4f} ===')
    return test_loss, acc

In [None]:
if use_cuda:
    torch.cuda.empty_cache()

# Directory for per-epoch checkpoints and the best model so far.
output = Path('checkpoints')
output.mkdir(exist_ok=True, parents=True)

model_name = model.__class__.__name__
best_model_path = output / f'{model_name}_best.pt'

# Best-so-far tracking, updated by pipeline(). best_loss starts at +inf so the
# first evaluated epoch always becomes the initial best. A finite sentinel such
# as 100.0 is never beaten by a larger loss, in which case no best checkpoint
# would ever be written.
best_epoch = 0
best_loss = float("inf")
best_acc = 0.0
default_interval = 1

def pipeline(start_epoch=0, test_interval=default_interval, max_epochs=1000, patience=100):
    """Train with periodic evaluation, checkpointing and early stopping.

    On every epoch where ``epoch % test_interval == 0`` the model is evaluated and
    pickled to ``output/<model_name>_<epoch>.pt``. If its test loss is no worse than
    the best so far, it is also saved to ``best_model_path``. Training stops when an
    evaluated epoch fails to improve and more than ``patience`` epochs have passed
    since the best one. The globals ``best_epoch``, ``best_loss`` and ``best_acc``
    record the best result.

    Args:
        start_epoch: first epoch number (e.g. to continue numbering after a restart).
        test_interval: evaluate every ``test_interval`` epochs.
        max_epochs: exclusive upper bound on the epoch number.
        patience: epochs without improvement tolerated before early stopping.
    """
    global best_epoch, best_loss, best_acc
    import math

    for epoch in range(start_epoch, max_epochs):
        train(epoch)
        if epoch % test_interval != 0:
            continue

        loss, acc = test()
        print(f'=== Pre best is {best_epoch}, Loss: {best_loss:.8f}, Acc: {best_acc:.4f} ===')
        torch.save(model, output / f'{model_name}_{epoch}.pt')
        # A NaN loss (e.g. diverged training) compares False against everything.
        # Without the explicit check it would be recorded as the new best, and every
        # later epoch would then also count as an "improvement".
        improved = not math.isnan(loss) and loss <= best_loss
        if not improved:
            if epoch - best_epoch > patience:
                print('No improvement for a long time, Early stop!')
                break
            continue
        best_epoch = epoch
        best_loss = loss
        best_acc = acc
        print(f'====== New best is {best_epoch}, Loss: {best_loss:.8f}, Acc: {best_acc:.4f} ======')
        torch.save(model, best_model_path)

pipeline()

In [None]:
import inspect

# Reload the best checkpoint (a fully pickled nn.Module written with
# torch.save(model, ...)) and re-evaluate it on the test split.
load_kwargs = {"map_location": device}
# torch>=2.6 defaults to weights_only=True, which refuses to unpickle a whole
# module. The file was written locally by this notebook, so full unpickling is
# acceptable. torch<1.13 has no weights_only argument at all.
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
# NOTE(review): `optimizer` still holds the previous model's parameters. Rebuild it
# before calling pipeline() again if training is to continue from this checkpoint.
model = torch.load(best_model_path, **load_kwargs)
test()

In [None]:
# Export the best checkpoint to ONNX

import torch.onnx
from pathlib import Path


def convert_onnx(path: Path, input_size: int = 96) -> Path:
    """Export a fully pickled PyTorch model checkpoint to an ONNX file next to it.

    Args:
        path: checkpoint written with ``torch.save(model, path)``.
        input_size: side length of the dummy (1, 3, input_size, input_size) input
            used for tracing. It must match the size the model was built for; the
            training notebook's model expects 96x96.

    Returns:
        Path of the written ``.onnx`` file (``path`` with its suffix replaced).
    """
    import inspect

    load_kwargs = {"map_location": torch.device("cpu")}
    # torch>=2.6 defaults torch.load to weights_only=True, which cannot unpickle a
    # whole nn.Module. The checkpoint is produced locally by this notebook, so full
    # unpickling is acceptable. torch<1.13 has no such argument.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(path, **load_kwargs)
    model.eval()
    dummy_input = torch.randn(1, 3, input_size, input_size)
    onnx_path = path.with_suffix(".onnx")
    torch.onnx.export(
        model,
        dummy_input,
        str(onnx_path),
        input_names=["input"],
        output_names=["output"],
    )
    return onnx_path


convert_onnx(best_model_path)


In [None]:
import onnx

# Validate the ONNX file written by the export cell (same stem as the best checkpoint).
onnx_model_path = best_model_path.with_suffix(".onnx")
onnx.checker.check_model(str(onnx_model_path))