In [None]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torch.utils.data.dataset as Dataset

from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from PIL import Image
import torchvision.models as models

import os
os.environ["KMP_DUPLICATE_LIB_OK"]  =  "TRUE"
import numpy as np
import matplotlib.pyplot as plt

from libtiff import TIFF

from CBAM_Attention_Module import AttachAttentionModule

In [None]:
class SRNet_dataset(Dataset.Dataset):
    """Dataset yielding an image / label / picture triple per CSV row.

    Every non-blank line of the index file is expected to contain three
    comma-separated paths in this order::

        <tiff image path>,<label image path>,<picture path>

    The TIFF is read with ``libtiff``; the label and picture are read with PIL.
    """

    def __init__(self,csv_dir):
        """Load the list of sample rows from the CSV index file.

        Args:
            csv_dir: Path of the CSV index file (despite the name, a file path).

        Raises:
            OSError: Propagated from ``open`` when the file cannot be read
                (a diagnostic line is printed first if the file is missing).
        """
        self.csv_dir = csv_dir

        self.names_list = []
        self.size = 0
        self.transform = transforms.ToTensor()
        # Read the sample paths listed in the CSV file.
        if not os.path.isfile(self.csv_dir):
            print(self.csv_dir + ':text file does not exist!')
        # The context manager guarantees the index file is closed again.
        with open(self.csv_dir) as index_file:
            # Blank lines (e.g. a trailing newline at the end of the file) are
            # skipped; counting them would make __getitem__ fail later with an
            # IndexError when it looks for the second/third column.
            self.names_list = [row for row in index_file if row.strip()]
        self.size = len(self.names_list)

    def __len__(self):
        """Return the number of sample rows read from the CSV file."""
        return self.size

    def __getitem__(self,idx):
        """Load the sample described by row ``idx``.

        Args:
            idx: Index into the list of CSV rows.

        Returns:
            dict with keys ``'image'`` (TIFF data passed through ``ToTensor``),
            ``'label'`` (label image as a tensor built with
            ``torch.from_numpy``, no scaling applied) and ``'pic'`` (picture
            passed through ``ToTensor``).
        """
        # Split the row once and strip surrounding whitespace / line endings
        # from every field so stray spaces or '\r' cannot corrupt a path.
        fields = [field.strip() for field in self.names_list[idx].split(',')]
        image_path, label_path, pic_path = fields[0], fields[1], fields[2]

        # Read the TIFF image and release the libtiff handle afterwards.
        tiff_handle = TIFF.open(image_path,mode='r')
        try:
            image = tiff_handle.read_image()
        finally:
            tiff_handle.close()

        # Read the label image; the pixel data is copied into a numpy array
        # before the file is closed.
        with Image.open(label_path) as label_img:
            label = torch.from_numpy(np.array(label_img))

        # Read the picture and convert it while the file is still open.
        with Image.open(pic_path) as pic_img:
            pic = self.transform(pic_img)

        # Return everything as tensors in a single dictionary.
        sample = {'image':self.transform(image),'label':label, 'pic':pic}
        return sample

In [None]:
class focal_loss(nn.Module):
    """Multi-class focal loss: ``-alpha_c * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        alpha: Either a list with one weight per class, or a scalar ``< 1``.
            A scalar is expanded to ``[alpha, 1 - alpha, 1 - alpha, ...]``,
            i.e. class 0 gets ``alpha`` and every other class ``1 - alpha``.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        num_classes: Number of classes (length of the weight vector).
        size_average: If True the per-sample losses are averaged, otherwise
            they are summed.
    """

    def __init__(self, alpha, gamma=0, num_classes = 2, size_average=True):
        super(focal_loss,self).__init__()
        self.tempalpha = alpha
        self.size_average = size_average
        if isinstance(alpha,list):
            # alpha given as a list of size [num_classes]: one weight per class.
            assert len(alpha)==num_classes
            print(" --- Focal_loss alpha = {}, 将对每一类权重进行精细化赋值 --- ".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            # Scalar alpha: down-weight the first class.
            assert alpha<1
            print(" --- Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用 --- ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1-alpha) # alpha ends up as [alpha, 1-alpha, 1-alpha, ...], size [num_classes]

        self.gamma = gamma

    def forward(self, preds, labels):
        """Compute the focal loss.

        Args:
            preds: Raw scores of shape ``[N, num_classes]`` (softmax is taken
                over dim 1 and classes are gathered along dim 1).
            labels: Class indices with ``N`` elements in total (any shape; they
                are flattened).

        Returns:
            Scalar tensor. For a scalar ``alpha`` the reduced loss is divided
            by ``alpha`` as in the original implementation. For a list
            ``alpha`` no rescaling is applied (the original expression
            ``1 / list`` raised ``TypeError``).
        """
        # Work on whatever device the predictions live on instead of forcing
        # CUDA, so the loss also runs on CPU.
        device = preds.device
        target = labels.reshape(-1, 1).to(device=device, dtype=torch.long)

        log_probs = F.log_softmax(preds, dim=1)
        # Picking the log-probability of the true class is the nll_loss part
        # (cross entropy = log_softmax + nll).
        log_pt = log_probs.gather(1, target).view(-1)
        pt = torch.exp(log_pt)

        # Per-sample class weight. self.alpha must stay the [num_classes]
        # lookup table: overwriting it with the gathered weights (as the
        # previous code did) corrupts every later call.
        class_weight = self.alpha.to(device).gather(0, target.view(-1))

        # (1 - p_t) ** gamma is the focal modulation term.
        loss = -torch.mul(torch.pow((1 - pt), self.gamma), log_pt)
        loss = torch.mul(class_weight, loss)

        reduced = loss.mean() if self.size_average else loss.sum()
        if isinstance(self.tempalpha, list):
            return reduced
        return reduced * (1 / self.tempalpha)

In [None]:
class Srnet(nn.Module):
    """Convolutional encoder/decoder producing a 2-channel score map.

    Structure as defined below:

    * Layers 1-2: 3x3 conv + BatchNorm + ReLU, each followed by an
      ``AttachAttentionModule`` (12 -> 64 -> 16 channels).
    * Layers 3-7: residual blocks (conv-BN-ReLU-conv-BN plus identity
      shortcut) at 16 channels, each followed by an ``AttachAttentionModule``.
    * Layers 8-11: residual blocks whose main branch ends in a stride-2
      average pool and whose shortcut is a stride-2 1x1 conv + BatchNorm
      (16 -> 16 -> 64 -> 128 -> 256 channels). Each halves the spatial size.
    * Layer 12: stride-2 3x3 conv + BN + ReLU, then 3x3 conv + BN
      (256 -> 512 channels).
    * Decoder: five stride-2 transposed convolutions, each doubling the
      spatial size (512 -> 512 -> 256 -> 128 -> 64 -> 32 channels), followed
      by a 1x1 classifier conv to 2 channels.

    With five halvings and five doublings, the output has the input's height
    and width only when both are multiples of 32.

    NOTE(review): ``AttachAttentionModule`` comes from the project-local
    ``CBAM_Attention_Module``; it is assumed to return a tensor of the same
    shape as its input -- TODO confirm.
    """

    def __init__(self):
        """Create all layers; attribute names are what checkpoints refer to."""
        super(Srnet, self).__init__()
        # Layer 1
        self.layer1 = nn.Conv2d(in_channels=12, out_channels=64,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.attention1 = AttachAttentionModule(64)
        
        # Layer 2
        self.layer2 = nn.Conv2d(in_channels=64, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(16)
        self.attention2 = AttachAttentionModule(16)
        
        # Layer 3
        self.layer31 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn31 = nn.BatchNorm2d(16)
        self.layer32 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn32 = nn.BatchNorm2d(16)
        self.attention3 = AttachAttentionModule(16)
        
        # Layer 4
        self.layer41 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn41 = nn.BatchNorm2d(16)
        self.layer42 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn42 = nn.BatchNorm2d(16)
        self.attention4 = AttachAttentionModule(16)
        
        # Layer 5
        self.layer51 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn51 = nn.BatchNorm2d(16)
        self.layer52 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn52 = nn.BatchNorm2d(16)
        self.attention5 = AttachAttentionModule(16)
        
        # Layer 6
        self.layer61 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn61 = nn.BatchNorm2d(16)
        self.layer62 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn62 = nn.BatchNorm2d(16)
        self.attention6 = AttachAttentionModule(16)
        
        # Layer 7
        self.layer71 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn71 = nn.BatchNorm2d(16)
        self.layer72 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn72 = nn.BatchNorm2d(16)
        self.attention7 = AttachAttentionModule(16)
        
        # Layer 8: layer81/bn81 form the stride-2 1x1 shortcut; layer82/83 are
        # the main branch, downsampled afterwards by pool1.
        self.layer81 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=1, stride=2, padding=0, bias=False)
        self.bn81 = nn.BatchNorm2d(16)
        self.layer82 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn82 = nn.BatchNorm2d(16)
        self.layer83 = nn.Conv2d(in_channels=16, out_channels=16,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn83 = nn.BatchNorm2d(16)
        self.pool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        
        # Layer 9 (same layout as layer 8, 16 -> 64 channels)
        self.layer91 = nn.Conv2d(in_channels=16, out_channels=64,
            kernel_size=1, stride=2, padding=0, bias=False)
        self.bn91 = nn.BatchNorm2d(64)
        self.layer92 = nn.Conv2d(in_channels=16, out_channels=64,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn92 = nn.BatchNorm2d(64)
        self.layer93 = nn.Conv2d(in_channels=64, out_channels=64,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn93 = nn.BatchNorm2d(64)
        self.pool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        
        # Layer 10 (same layout, 64 -> 128 channels)
        self.layer101 = nn.Conv2d(in_channels=64, out_channels=128,
            kernel_size=1, stride=2, padding=0, bias=False)
        self.bn101 = nn.BatchNorm2d(128)
        self.layer102 = nn.Conv2d(in_channels=64, out_channels=128,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn102 = nn.BatchNorm2d(128)
        self.layer103 = nn.Conv2d(in_channels=128, out_channels=128,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn103 = nn.BatchNorm2d(128)
        self.pool3 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        
        # Layer 11 (same layout, 128 -> 256 channels)
        self.layer111 = nn.Conv2d(in_channels=128, out_channels=256,
            kernel_size=1, stride=2, padding=0, bias=False)
        self.bn111 = nn.BatchNorm2d(256)
        self.layer112 = nn.Conv2d(in_channels=128, out_channels=256,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn112 = nn.BatchNorm2d(256)
        self.layer113 = nn.Conv2d(in_channels=256, out_channels=256,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn113 = nn.BatchNorm2d(256)
        # NOTE(review): this re-assigns pool3 (already set in layer 10) rather
        # than creating a pool4. forward() uses pool1 for layers 10 and 11, and
        # all pools are configured identically, so the result is unaffected.
        # Renaming would change the attribute set that models pickled with
        # torch.save(model, ...) were stored with.
        self.pool3 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        # Layer 12: stride-2 3x3 conv (no shortcut), 256 -> 512 channels
        self.layer121 = nn.Conv2d(in_channels=256, out_channels=512,
            kernel_size=3, stride=2, padding=1, bias=False)
        self.bn121 = nn.BatchNorm2d(512)
        self.layer122 = nn.Conv2d(in_channels=512, out_channels=512,
            kernel_size=3, stride=1, padding=1, bias=False)
        self.bn122 = nn.BatchNorm2d(512)
        
        # Transposed conv applied to the last feature map [1/32 --> 1/16].
        # NOTE(review): debn2 and debn1 are registered here but never called
        # in forward().
        self.deconv2 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.debn2   = nn.BatchNorm2d(512)
        # Transposed conv for the second-to-last level [1/16 --> 1/8]
        self.deconv1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.debn1   = nn.BatchNorm2d(256)
        # Transposed conv for the third-to-last level [1/8 --> 1/4]
        self.deconv0_1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.debn0_1   = nn.BatchNorm2d(128)
        # Transposed conv for the fourth-to-last level [1/4 --> 1/2]
        self.deconv0_2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.debn0_2   = nn.BatchNorm2d(64)
        # Transposed conv for the fifth-to-last level [1/2 --> 1/1]
        self.deconv0_3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.debn0_3   = nn.BatchNorm2d(32)

        # 1x1 conv mapping the 32 decoder channels to 2 output channels.
        self.classifier = nn.Conv2d(32, 2, kernel_size=1)

    def forward(self, inputs):
        """Run the network.

        Args:
            inputs: Tensor with 12 channels in dim 1 (required by ``layer1``);
                presumably ``(batch, 12, H, W)`` -- TODO confirm against the
                data pipeline.

        Returns:
            Raw (un-normalised) 2-channel score map from ``classifier``.
        """
        # Layer 1
        conv = self.layer1(inputs)
        actv = F.relu(self.bn1(conv))
        actv = self.attention1(actv)
        
        # Layer 2
        conv = self.layer2(actv)
        actv = F.relu(self.bn2(conv))
        actv = self.attention2(actv)
        
        # Layer 3: residual block, shortcut is the layer-2 activation
        conv1 = self.layer31(actv)
        actv1 = F.relu(self.bn31(conv1))
        conv2 = self.layer32(actv1)
        bn = self.bn32(conv2)
        res = torch.add(actv, bn)
        res = self.attention3(res)
        
        # Layer 4
        conv1 = self.layer41(res)
        actv1 = F.relu(self.bn41(conv1))
        conv2 = self.layer42(actv1)
        bn = self.bn42(conv2)
        res = torch.add(res, bn)
        res = self.attention4(res)
        
        # Layer 5
        conv1 = self.layer51(res)
        actv1 = F.relu(self.bn51(conv1))
        conv2 = self.layer52(actv1)
        bn = self.bn52(conv2)
        res = torch.add(res, bn)
        res = self.attention5(res)
        
        # Layer 6
        conv1 = self.layer61(res)
        actv1 = F.relu(self.bn61(conv1))
        conv2 = self.layer62(actv1)
        bn = self.bn62(conv2)
        res = torch.add(res, bn)
        res = self.attention6(res)
        
        # Layer 7
        conv1 = self.layer71(res)
        actv1 = F.relu(self.bn71(conv1))
        conv2 = self.layer72(actv1)
        bn = self.bn72(conv2)
        res = torch.add(res, bn)
        res = self.attention7(res)
        
        # Layer 8: strided 1x1 shortcut (convs) + pooled main branch
        convs = self.layer81(res)
        convs = self.bn81(convs)
        conv1 = self.layer82(res)
        actv1 = F.relu(self.bn82(conv1))
        conv2 = self.layer83(actv1)
        bn = self.bn83(conv2)
        pool = self.pool1(bn)
        res = torch.add(convs, pool)
        
        # Layer 9
        convs = self.layer91(res)
        convs = self.bn91(convs)
        conv1 = self.layer92(res)
        actv1 = F.relu(self.bn92(conv1))
        conv2 = self.layer93(actv1)
        bn = self.bn93(conv2)
        pool = self.pool2(bn)
        res = torch.add(convs, pool)
        
        # Layer 10 (uses pool1, not pool3; the pools are configured identically)
        convs = self.layer101(res)
        convs = self.bn101(convs)
        conv1 = self.layer102(res)
        actv1 = F.relu(self.bn102(conv1))
        conv2 = self.layer103(actv1)
        bn = self.bn103(conv2)
        pool = self.pool1(bn)
        res = torch.add(convs, pool)
        
        # Layer 11 (also uses pool1)
        convs = self.layer111(res)
        convs = self.bn111(convs)
        conv1 = self.layer112(res)
        actv1 = F.relu(self.bn112(conv1))
        conv2 = self.layer113(actv1)
        bn = self.bn113(conv2)
        pool = self.pool1(bn)
        res = torch.add(convs, pool)
        
        # Layer 12
        conv1 = self.layer121(res)
        actv1 = F.relu(self.bn121(conv1))
        conv2 = self.layer122(actv1)
        bn = self.bn122(conv2)
        
        # Upsample the last feature map to the size of the previous level
        # (1/16 of the input). No BatchNorm is applied here (debn2 is unused).
        x2_1 = F.relu(self.deconv2(bn)) 
        # Upsample again to the size of the level before that (1/8 of the
        # input). No BatchNorm is applied here either (debn1 is unused).
        x1_0 = F.relu(self.deconv1(x2_1))
        
        # Three more transposed convolutions bring the map from 1/8 back to
        # full resolution; BatchNorm is applied after the ReLU in each step.
        x0_image = self.debn0_1(F.relu(self.deconv0_1(x1_0)))
        x0_image = self.debn0_2(F.relu(self.deconv0_2(x0_image)))
        x0_image = self.debn0_3(F.relu(self.deconv0_3(x0_image)))
        outputs = self.classifier(x0_image)
        return outputs

In [None]:
# One dataset per CSV index file.
data1 = SRNet_dataset('train_data1.csv')
data2 = SRNet_dataset('train_data2.csv')
data3 = SRNet_dataset('train_data3.csv')

# One shuffled loader per dataset, 16 samples per batch.
SRNet_dataloader1, SRNet_dataloader2, SRNet_dataloader3 = (
    DataLoader(dataset, batch_size=16, shuffle=True)
    for dataset in (data1, data2, data3)
)

# Resume from a whole-model checkpoint. To train from scratch instead, use:
#     model = Srnet().cuda()
# NOTE(review): the checkpoint path is an empty placeholder and must be filled
# in before this cell can run. torch.load unpickles arbitrary objects, so only
# load checkpoints from a trusted source.
model = torch.load('')

In [None]:
# Three focal-loss criteria that differ only in the class-0 weight alpha.
Loss_function1, Loss_function2, Loss_function3 = (
    focal_loss(alpha=class0_weight).cuda()
    for class0_weight in (0.05, 0.2, 0.5)
)

# Plain SGD with momentum over all model parameters.
optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=0.0001)

# To hide all GPUs from this process (must be set before CUDA is initialised):
#     os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [None]:
# Stage-1 step indices after which a full pass over the stage-2 loader is run.
STAGE2_TRIGGER_STEPS = (70, 140, 210)


def _train_on_batch(model, optimizer, loss_fn, sample):
    """Run one optimisation step on ``sample``.

    Args:
        model: Network mapping an image batch to per-pixel class scores with
            the class axis in dim 1.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(scores[P, C], labels[P, 1]) -> scalar``.
        sample: Batch dict with ``'image'`` and ``'label'`` tensors.

    Returns:
        Tuple ``(loss_value, accuracy)`` as Python floats: the batch loss
        (mean of the per-image losses) and the fraction of positions where the
        arg-max prediction equals the label.
    """
    # Put the batch on the same device as the model's weights.
    device = next(model.parameters()).device
    images = sample['image'].to(device)
    labels = sample['label'].to(device)

    outputs = model(images)
    _, predicts = torch.max(outputs, 1)

    matches = predicts == labels
    correct = matches.sum().item()
    total = matches.numel()

    optimizer.zero_grad()

    # Use the real batch size: the last batch of an epoch may hold fewer
    # samples than the loader's nominal batch size.
    batch_size = outputs.shape[0]
    num_classes = outputs.shape[1]
    loss = 0
    for image_scores, image_labels in zip(outputs, labels):
        # [C, H, W] -> [H*W, C]: one row of class scores per pixel.
        flat_scores = image_scores.permute(1, 2, 0).reshape(-1, num_classes)
        flat_labels = image_labels.reshape(-1, 1).long()
        loss = loss + loss_fn(flat_scores, flat_labels)

    loss = loss / batch_size
    loss.backward()
    optimizer.step()

    return loss.item(), correct / total


Lossvalue = 0
Accvalue = 0
count = 0
for epoch in range(1):
    for i, sample in enumerate(SRNet_dataloader1, 0):
        # NOTE(review): the original code called an undefined `Loss_function`.
        # Pairing stage 1 with Loss_function1 and stage 2 with Loss_function2
        # is an assumption -- confirm the intended criterion per stage.
        batch_loss, Acc = _train_on_batch(model, optimizer, Loss_function1, sample)

        Lossvalue = Lossvalue + batch_loss
        count = count + 1
        Accvalue = Accvalue + Acc

        print('epoch is %d , state is 1 , %d times , loss is %f , acc is %f'%(epoch,i,batch_loss,Acc))

        if i in STAGE2_TRIGGER_STEPS:
            for j, sample in enumerate(SRNet_dataloader2, 0):
                batch_loss, Acc = _train_on_batch(model, optimizer, Loss_function2, sample)

                Lossvalue = Lossvalue + batch_loss
                count = count + 1
                Accvalue = Accvalue + Acc

                print('epoch is %d , state is 2 , %d times , loss is %f , acc is %f'%(epoch,j,batch_loss,Acc))

# Average over all optimisation steps; skip the division if no batch ran.
if count:
    Lossvalue = Lossvalue / count
    Accvalue = Accvalue / count
print('epoch is %d , All_loss is %f , All_Acc is %f'%(epoch,Lossvalue,Accvalue))
# NOTE(review): the output path is an empty placeholder and must be filled in.
torch.save(model,'')