In [3]:
import torch
from torch import nn
from torchvision.models import resnet101,ResNet101_Weights
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os 
import glob
from PIL import Image
import numpy as np
from torch.functional import F
from torchvision import datasets

In [4]:

class MyDataset(Dataset):
    """Pascal VOC 2012 style semantic-segmentation dataset yielding (image, mask) pairs.

    Sample ids are read from ``ImageSets/Segmentation/train.txt`` or ``val.txt``
    under ``data_dir``; images come from ``JPEGImages/<id>.jpg`` and masks from
    ``SegmentationClass/<id>.png``.

    Args:
        data_dir: Root directory that contains ``ImageSets``, ``JPEGImages`` and
            ``SegmentationClass``.
        image_size: ``(height, width)`` that the default pipelines resize to.
        is_train: Selects the train split (and the random horizontal flip) when
            truthy, otherwise the val split.
        transform: Optional replacement pipeline. It is called as
            ``transform(image=..., mask=...)`` and must return a mapping with
            ``'image'`` and ``'mask'`` entries. It receives the mask with its
            on-disk values (255 not yet remapped).
    """

    # Mask value that gets remapped, and the value it is remapped to. The training
    # loop in this notebook builds its loss with ignore_dix=-1.
    VOID_LABEL = 255
    IGNORE_LABEL = -1

    def __init__(self, data_dir=r"data/VOCdevkit/VOC2012", image_size=(512,512), is_train=True, transform=None):
        super().__init__()
        self.transform = transform
        self.is_train = is_train
        self.image_size = image_size
        self.images_list, self.mask_list = self.read_data_list_from(data_dir, self.is_train)

        if self.transform is not None:
            self.aug = self.transform
            return

        # Default Albumentations pipeline. Both splits honour image_size; the
        # training split additionally applies a random horizontal flip.
        height, width = image_size
        steps = [A.Resize(height, width)]
        if self.is_train:
            steps.append(A.HorizontalFlip(p=0.5))
        steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225]))
        steps.append(ToTensorV2())
        self.aug = A.Compose(steps, additional_targets={'mask': 'mask'})

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.images_list)

    def __getitem__(self, index):
        """Load, augment and return the ``index``-th ``(image, mask)`` pair.

        Returns:
            image: ``augmented['image']`` as produced by the pipeline.
            mask: int64 tensor in which every 255 has been replaced by -1.
        """
        image_path = self.images_list[index]
        mask_path = self.mask_list[index]

        # Context managers close the underlying files as soon as the pixels are copied.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"), dtype=np.float32)
        with Image.open(mask_path) as mask_img:
            # Keep the mask in its on-disk dtype during augmentation: image-library
            # backed transforms do not reliably accept int64 arrays.
            # NOTE(review): assumes the PNG stores class indices directly — confirm.
            mask = np.array(mask_img)

        augmented = self.aug(image=image, mask=mask)
        image = augmented['image']

        # Convert to int64 only after augmentation, then remap 255 -> -1.
        mask = torch.as_tensor(augmented['mask']).long()
        mask[mask == self.VOID_LABEL] = self.IGNORE_LABEL

        return image, mask

    def read_data_list_from(self, root, is_train):
        """Return ``(image_paths, mask_paths)`` for the requested split.

        Blank lines in the split file are skipped and surrounding whitespace is
        stripped from each id.
        """
        split_name = "train.txt" if is_train else "val.txt"
        split_file = os.path.join(root, "ImageSets", "Segmentation", split_name)

        images = []
        masks = []
        with open(split_file) as fh:
            for line in fh:
                sample_id = line.strip()
                if not sample_id:
                    continue
                images.append(os.path.join(root, "JPEGImages", sample_id + ".jpg"))
                masks.append(os.path.join(root, "SegmentationClass", sample_id + ".png"))

        return images, masks

In [5]:
# Batched loaders over the train / val splits; only the training loader shuffles.
# Note: despite the "_ds" suffix these names hold DataLoader objects.
train_ds = DataLoader(
    MyDataset(is_train=True),
    batch_size=15,
    shuffle=True,
    num_workers=0,
)
val_ds = DataLoader(
    MyDataset(is_train=False),
    batch_size=15,
    num_workers=0,
)

In [6]:
def try_gpu(i=0):
    """Pick the compute device: CUDA device ``i`` if available, else MPS, else CPU.

    Args:
        i: Index used to build the ``cuda:{i}`` device string; ignored for MPS/CPU.

    Returns:
        A ``torch.device``.
    """
    if torch.cuda.is_available():
        device_name = f'cuda:{i}'
    elif torch.backends.mps.is_available():
        device_name = 'mps'
    else:
        device_name = 'cpu'
    return torch.device(device_name)

In [7]:
class ConvBlock(nn.Module):
    """Conv2d (reflect padding) -> BatchNorm2d -> ReLU, stored as ``self.block``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Amount of reflect padding applied by the convolution.
    """

    def __init__(self,in_channels,out_channels,kernel_size,stride,padding):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode="reflect",
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        # Positional order fixes the state_dict keys: block.0 (conv), block.1 (norm).
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self,x):
        """Apply convolution, batch normalisation and ReLU to ``x``."""
        features = self.block(x)
        return features

In [8]:
# ASPP: parallel 1x1 / dilated 3x3 / global-pooling branches, concatenated on channels
class ASPP(nn.Module):
    """Five parallel branches over the same input, concatenated along dim 1.

    Branches, in output order: a 1x1 convolution, three 3x3 convolutions with
    dilation 6, 12 and 18 (reflect padding equal to the dilation, so spatial
    size is preserved), and a global-average-pooled branch (1x1 conv + BN + ReLU)
    that is bilinearly resized back to the input's spatial size.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by each branch; the module outputs
            ``5 * out_channels`` channels.
    """

    def __init__(self,in_channels,out_channels):
        super().__init__()

        def dilated_conv(rate):
            # padding == dilation keeps H and W unchanged for a 3x3 kernel.
            return nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=rate,
                dilation=rate,
                padding_mode="reflect",
            )

        # Creation order is kept as conv, atrous6, atrous12, atrous18, pool.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.atrous6 = dilated_conv(6)
        self.atrous12 = dilated_conv(12)
        self.atrous18 = dilated_conv(18)
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self,x):
        """Return the channel-wise concatenation of all five branches."""
        spatial_size = x.shape[2:]
        branches = [
            self.conv(x),
            self.atrous6(x),
            self.atrous12(x),
            self.atrous18(x),
        ]
        # The pooled branch is 1x1 spatially; stretch it back over the feature map.
        global_context = F.interpolate(
            self.pool(x),
            size=spatial_size,
            mode="bilinear",
            align_corners=True,
        )
        branches.append(global_context)
        return torch.cat(branches, dim=1)

In [9]:
# DeepLabV3+ segmentation network on a frozen, dilated ResNet-101 backbone
class DeepLabV3P(nn.Module):
    """DeepLabV3+ style network: ResNet-101 encoder, ASPP, and a light decoder.

    The ImageNet-pretrained backbone is frozen (``requires_grad=False``) and
    modified in place by :meth:`set_dilation`. Only the ASPP head and the three
    ``ConvBlock`` layers are trainable by default.

    Args:
        C: Number of output channels (classes) produced per pixel.

    NOTE(review): freezing parameters does not freeze BatchNorm running
    statistics; they still update whenever the module is in ``train()`` mode.
    NOTE(review): ``cat_conv`` is a ``ConvBlock`` (Conv -> BN -> ReLU), so the
    returned class scores are non-negative. Changing that would alter the
    state_dict layout, so it is left as is.
    """

    def __init__(self,C=21):
        super().__init__()
        backbone=resnet101(weights=ResNet101_Weights.IMAGENET1K_V2)
        for param in backbone.parameters():
            param.requires_grad = False
        self.set_dilation(backbone)

        # torchvision's ResNet children are, in order:
        # conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc.
        stages=list(backbone.children())
        self.lower_f=nn.Sequential(*stages[:5])    # stem + layer1 (author's note: 256 x 56 x 56)
        self.up_f0=nn.Sequential(*stages[5:7])     # layer2 + layer3
        self.up_f1=nn.Sequential(*stages[7])       # blocks of layer4 (author's note: 2048 x 14 x 14)

        self.aspp=ASPP(2048,256)
        self.up_conv=ConvBlock(256*5,256,kernel_size=1,stride=1,padding=0)
        self.down_conv=ConvBlock(256,128,kernel_size=1,stride=1,padding=0)
        self.cat_conv=ConvBlock(256+128,C,kernel_size=3,stride=1,padding=1)

    def set_dilation(self,layer,d=2):
        """Rewrite strides/dilations of selected backbone convolutions in place.

        * ``layer3.0.conv2`` and ``layer3.0.downsample.0``: stride set to 1.
        * ``layer3.1`` .. ``layer3.22`` ``conv2`` and ``layer4.0.conv2``:
          dilation and padding set to ``d``.
        * ``layer4.1.conv2`` and ``layer4.2.conv2``: dilation and padding set to ``2*d``.

        Args:
            layer: Module whose ``named_modules()`` use torchvision ResNet-101 names.
            d: Base dilation rate.
        """
        p=d
        s=1
        # "layer4.0.conv2" was previously misspelled "layer.4.0.conv2", which never
        # matched a module name, so that convolution was silently left undilated.
        dilate_by_d={f"layer3.{i}.conv2" for i in range(1,23)} | {"layer4.0.conv2"}
        remove_stride={"layer3.0.conv2","layer3.0.downsample.0"}
        dilate_by_2d={"layer4.1.conv2","layer4.2.conv2"}
        for n,m in layer.named_modules():
            if n in dilate_by_d:
                m.dilation=(d,d)
                m.padding=(p,p)
            elif n in remove_stride:
                m.stride=(s,s)
            elif n in dilate_by_2d:
                m.dilation=(2*d,2*d)
                m.padding=(2*p,2*p)

    def forward(self,x):
        """Return per-pixel class scores with the same spatial size as ``x``."""
        lower_f=self.lower_f(x)
        up_f=self.up_f0(lower_f)
        up=self.up_f1(up_f)
        aspp=self.aspp(up)
        up_out=self.up_conv(aspp)
        # Resize to the low-level feature map's exact size rather than a fixed x4,
        # so the concatenation below also works when sizes do not divide evenly.
        # With align_corners=True this is numerically the same as scale_factor=4
        # whenever the sizes are exact multiples.
        up_out=F.interpolate(
            up_out,
            size=lower_f.shape[2:],
            mode="bilinear",
            align_corners=True
        )

        down_out=self.down_conv(lower_f)
        cat=torch.cat((up_out,down_out),dim=1)
        out=self.cat_conv(cat)
        # Upsample back to the input resolution so the output lines up with the mask.
        return F.interpolate(out,size=x.shape[2:],mode="bilinear",align_corners=True)

In [10]:
# Instantiate the network with its default number of output classes (C=21).
deeplab = DeepLabV3P(C=21)

In [11]:
## Loss function
class FocalLoss(nn.Module):
    """Multi-class focal loss over per-pixel logits, skipping ignored pixels.

    For every non-ignored pixel with target-class probability ``p_t`` the loss is
    ``-alpha * (1 - p_t) ** gamma * log2(p_t + 1e-12)``; the result is the mean
    over those pixels. The base-2 logarithm is kept from the original
    implementation so that loss values stay comparable with earlier runs.

    Args:
        alpha: Constant weighting factor.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        ignore_dix: Target value marking pixels that are excluded from the loss.
    """

    def __init__(self,alpha=0.25,gamma=2,ignore_dix=255, *args, **kwargs,):
        super().__init__(*args, **kwargs)
        self.alpha=alpha
        self.gamma=gamma
        self.ignore_dix=ignore_dix

    def forward(self,inputs,targets):
        """Compute the mean focal loss.

        Args:
            inputs: Logits indexed as (batch, class, h, w).
            targets: Integer class indices indexed as (batch, h, w).

        Returns:
            Scalar tensor. If every pixel is ignored, a zero that is still attached
            to the autograd graph is returned instead of NaN.
        """
        num_classes=inputs.size(1)
        probs=torch.softmax(inputs.permute(0,2,3,1).contiguous(),dim=-1)
        valid=targets!=self.ignore_dix  # (batch, h, w)
        probs=probs[valid].view(-1,num_classes)
        labels=targets[valid].view(-1).long()
        if labels.numel()==0:
            # Mean over zero pixels would be NaN and poison the optimizer step.
            return probs.sum()*0.0
        # Pick the target-class probability directly; this equals multiplying by a
        # one-hot matrix and summing over classes, without materialising (N, C).
        p_t=probs.gather(1,labels.unsqueeze(1)).squeeze(1)
        focal=-self.alpha*((1-p_t)**self.gamma)*torch.log2(p_t+1e-12)
        return focal.mean()

In [12]:
## Running-sum accumulator
class Accumulator():
    """Holds a list of running float totals in ``self.data``.

    Args:
        n: Number of totals to track; all start at 0.0.
    """

    def __init__(self,n):
        self.data = [0.0 for _ in range(n)]

    def add(self,*args):
        """Add each positional value (converted with ``float``) to its total.

        Values are paired with totals positionally via ``zip``.
        """
        updated = []
        for total, increment in zip(self.data, args):
            updated.append(total + float(increment))
        self.data = updated

    def reset(self):
        """Set every total back to 0.0, keeping the current number of totals."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, item):
        """Return the total(s) at position ``item``."""
        return self.data[item]

In [13]:
## Training function
def train(net,train_iter,val_iter,lr,num_epochs,device=None,patience=20,init_best_val_loss=5.67e-1):
    """Train ``net`` with focal loss and Adam, validating and checkpointing each epoch.

    After every epoch the mean validation loss is compared against the best value
    seen so far. An improvement saves ``net.state_dict()`` to
    ``model_best_deeplab_2.pth`` and resets the early-stopping counter; training
    stops once ``patience`` consecutive epochs fail to improve.

    Args:
        net: Model to train; called as ``net(X)``.
        train_iter: Iterable of ``(X, y)`` training batches; must support ``len()``.
        val_iter: Iterable of ``(X, y)`` validation batches.
        lr: Learning rate for Adam.
        num_epochs: Maximum number of epochs.
        device: Device that the model and batches are moved to.
        patience: Consecutive non-improving epochs tolerated before stopping.
        init_best_val_loss: Validation loss that must be beaten before the first
            checkpoint is written. The default keeps this notebook's previous
            hard-coded threshold; pass ``float("inf")`` for a fresh run so the
            first epoch is always checkpointed.

    Returns:
        List with the mean training loss of every completed epoch.
    """
    history=[]
    best_val_loss=init_best_val_loss
    stale_epochs=0
    net.to(device)
    print('training on',device)
    loss=FocalLoss(ignore_dix=-1)
    optimizer=torch.optim.Adam(net.parameters(),lr=lr,weight_decay=5e-4,betas=(0.9,0.99))
    num_batches=len(train_iter)
    # Log roughly five times per epoch; max() avoids a modulo-by-zero when the
    # loader has fewer than five batches.
    log_every=max(1,num_batches//5)
    for epoch in range(num_epochs):
        net.train()
        metric=Accumulator(2)  # (sum of batch losses, number of batches)
        print(f"epoch{epoch+1}")
        for i,(X,y) in enumerate(train_iter):
            X,y=X.to(device),y.to(device)
            y=y.squeeze(1).long()
            y_hat=net(X)
            l=loss(y_hat,y)
            optimizer.zero_grad()
            l.sum().backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l,l.numel())
            if (i+1)%log_every==0 or i==num_batches-1:
                print(f'\tloss {metric[0]/metric[1]:.5e}')
        train_loss=metric[0]/metric[1]
        # Record the epoch before any early-stop break so it is never dropped.
        history.append(train_loss)

        ## Validation pass
        net.eval()
        metric2=Accumulator(2)
        with torch.no_grad():
            for X,y in val_iter:
                X,y=X.to(device),y.to(device)
                y=y.squeeze(1).long()
                y_hat=net(X)
                l2=loss(y_hat,y)
                metric2.add(l2,l2.numel())
        val_loss=metric2[0]/metric2[1]
        print(f'epoch {epoch+1} summary: loss {train_loss:.5e}, val_loss {val_loss:.5e}')
        if val_loss<best_val_loss:
            best_val_loss=val_loss
            stale_epochs=0  # patience counts consecutive misses, so restart it here
            torch.save(net.state_dict(),'model_best_deeplab_2.pth')
        else:
            stale_epochs+=1
            if stale_epochs>=patience:
                print('early stops')
                break
        print(f"best val_loss {best_val_loss:.5e}")
    return history

In [None]:
## Start training!
# Optional fine-tuning setup, left disabled: reload an earlier checkpoint and
# unfreeze the backbone stages before calling train().
# deeplab.load_state_dict(torch.load("model_best_deeplab.pth",weights_only=False))
# for name, param in deeplab.named_parameters():
#     if "lower_f" in name or "up_f0" in name or "up_f1" in name:
#         param.requires_grad = True
history = train(
    deeplab,
    train_ds,
    val_ds,
    lr=1e-5,
    num_epochs=40,
    device=try_gpu(),
)

training on cuda:0
epoch1


In [None]:
# Training-loss curve: one point per entry of `history`.
fig, ax = plt.subplots()
ax.plot(history)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
plt.show()