In [1]:
import os
import time
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau,CosineAnnealingLR
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from PIL import Image
from glob import glob
import sys
from torch.utils.data import DataLoader, Dataset
from albumentations import (Resize, RandomCrop,VerticalFlip, HorizontalFlip, Normalize, Compose, Crop, PadIfNeeded, RandomBrightness, Rotate)
from albumentations.pytorch import ToTensor
import cv2
from torch.nn import functional as F
from tqdm import tqdm
import segmentation_models_pytorch as smp


In [2]:
def provider(
    image_path,
    phase,
    mean=None,
    std=None,
    batch_size=8,
    num_workers=0,
):
    assert phase in ("train", "val", "test")

    image_list = glob(os.path.join(image_path, "*"))
    print("total images: {}".format(len(image_list)))

    index = range(len(image_list))

    dataset = CatDataset(index, image_list, phase=phase)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        shuffle=False,
    )

    return dataloader

def get_transforms(phase, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    list_transforms = []
    list_transforms.extend(
        [
            Resize(256, 256),
            Normalize(mean=mean, std=std, p=1),
            ToTensor(),
        ]
    )
    list_trfms = Compose(list_transforms)
    return list_trfms

class CatDataset(Dataset):
    def __init__(self, idx, image_list, phase="train"):
        assert phase in ( "test")
        self.idx = idx
        self.image_list = image_list
        self.phase = phase

        self.transform = get_transforms(phase)

    def __getitem__(self, index):
        real_idx = self.idx[index]
        image_path = self.image_list[real_idx]

        image = cv2.imread(image_path)
        augmented = self.transform(image=image)

        return augmented["image"], image_path

    def __len__(self):
        return len(self.idx)
    
 

In [3]:
class Trainer(object):
    '''This class takes care of training and validation of our model'''
    def __init__(self, model):
        self.num_workers = 0
        self.batch_size = {"test":8}
        self.accumulation_steps = 32 // self.batch_size['test']
        self.phases = ["test"]
        self.device = torch.device("cuda:0")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
        self.net = model

        self.net = self.net.to(self.device)
        cudnn.benchmark = True
        self.dataloaders = {
            phase: provider(
                image_path=IMAGEPATH,
                phase=phase,
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                batch_size=self.batch_size[phase],
                num_workers=self.num_workers,
            )                                                   
            for phase in self.phases
        }
        self.losses = {phase: [] for phase in self.phases}
        
    def forward(self, images):
        images = images.to(self.device)
        outputs = self.net(images)
        
        return  outputs


    def iterate(self, phase):
        start = time.strftime("%H:%M:%S")
        print(f"Starting epoch: 0 | phase: {phase} | ⏰: {start}")
        self.net.train(phase == "train")
        dataloader = self.dataloaders[phase]
        for batch in tqdm(dataloader):
            images, pathes = batch
            with torch.no_grad():
                outputs = self.forward(images)
            
            batch_preds = torch.sigmoid(outputs)  
            # 预测结果以图片的形式存在输入图片的相同路径下，后面带有 label 后缀
            # 预测结果的分辨率统一为 256*256，若需恢复原分辨率需要调用albumentations 中的 Resize
            for i in range(batch_preds.shape[0]):
                numpy_output = batch_preds[i].squeeze(0).detach().cpu().numpy()
                r = np.where(numpy_output > 0.5, 255, 0).astype("uint8")
                cv2.imwrite(os.path.splitext(pathes[i])[0]+"_label.jpg", r)
            
                    
    def start(self):
        self.iterate("test")
        

In [4]:
MODELPATH = ""  # 模型路径
IMAGEPATH = r".\cat\200" # 图片路径

In [5]:
if os.path.exists(MODELPATH):
        
    model = smp.Unet('resnet50', classes=1, activation=None)
    state = torch.load(MODELPATH, map_location=lambda storage, loc: storage)

    model.load_state_dict(state["state_dict"])
else:
    model = smp.Unet('resnet50', classes=1, activation=None)

device = torch.device("cuda")
model.to(device)
model_trainer = Trainer(model)
model_trainer.start()

total images: 16
Starting epoch: 0 | phase: test | ⏰: 20:35:47


  0%|                                                                                            | 0/2 [00:00<?, ?it/s]

.\cat\200\cat-6.jpg
('.\\cat\\200\\cat-6', '.jpg')
.\cat\200\cat-6_label.jpg
('.\\cat\\200\\cat-6_label', '.jpg')
.\cat\200\cat_0.jpg
('.\\cat\\200\\cat_0', '.jpg')
.\cat\200\cat_0_label.jpg
('.\\cat\\200\\cat_0_label', '.jpg')
.\cat\200\cat_1.jpg
('.\\cat\\200\\cat_1', '.jpg')
.\cat\200\cat_1_label.jpg
('.\\cat\\200\\cat_1_label', '.jpg')
.\cat\200\cat_2.jpg
('.\\cat\\200\\cat_2', '.jpg')
.\cat\200\cat_2_label.jpg
('.\\cat\\200\\cat_2_label', '.jpg')


 50%|██████████████████████████████████████████                                          | 1/2 [00:02<00:02,  2.98s/it]

.\cat\200\cat_3.jpg
('.\\cat\\200\\cat_3', '.jpg')
.\cat\200\cat_3_label.jpg
('.\\cat\\200\\cat_3_label', '.jpg')
.\cat\200\cat_4.jpg
('.\\cat\\200\\cat_4', '.jpg')
.\cat\200\cat_4_label.jpg
('.\\cat\\200\\cat_4_label', '.jpg')
.\cat\200\cat_5.jpg
('.\\cat\\200\\cat_5', '.jpg')
.\cat\200\cat_5_label.jpg
('.\\cat\\200\\cat_5_label', '.jpg')
.\cat\200\cat_7.jpg
('.\\cat\\200\\cat_7', '.jpg')
.\cat\200\cat_7_label.jpg
('.\\cat\\200\\cat_7_label', '.jpg')


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.54s/it]
