In [1]:
from torch import nn
import torch
from torch.utils.data import DataLoader
from DataModel import SegmentationData,base_image_path,base_imge_name,base_label_path,base_label_name,base_path

class Decoder(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(Decoder, self).__init__()

        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_relu = nn.Sequential(
            nn.Conv2d(middle_channels, out_channels, kernel_size=3,stride=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1),
            nn.ReLU(),
            )
  

    def forward(self, x1, x2):
        x1 = self.up(x1)
       
        x1 = torch.cat((x1, x2), dim=1)
        x1 = self.conv_relu(x1)
        return x1


    
    
class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels,use_maxpool=True):
        super(Encoder, self).__init__()
        
        layers=[]
        
        if use_maxpool==True:
            layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
        
        layers+=[
            nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1),
            nn.ReLU(),
            nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1),
            nn.ReLU(),
        ]
        self.net=nn.Sequential(*layers)
        
    def forward(self,X):
        
        return self.net(X)

    
encoder_params=((3,64,False),(64,128),(128,256),(256,512),(512,1024))    
decoder_params=((1024,1024,512),(512,512,256),(256,256,128),(128,128,64)) 
class U_net(nn.Module):
    def __init__(self,encoder_params,decoder_params):
        super().__init__()
        maxpools=[]
        encoders=[]
        for param in encoder_params:
            encoders.append(Encoder(*param))
            maxpools.append(nn.MaxPool2d(kernel_size=2,stride=2))
            
        decoders=[]
        for param in decoder_params:
            decoders.append(Decoder(*param))
        
        self.encoders=nn.Sequential(*encoders)
        self.decoders=nn.Sequential(*decoders)
        self.finalconv=nn.Conv2d(64,21,kernel_size=1)
       
        
        
    def forward(self,X):

        x0=self.encoders[0](X)
        #print(x0.shape)
        x1=self.encoders[1](x0)
        #print(x1.shape)
        x2=self.encoders[2](x1)
        #print(x2.shape)
        x3=self.encoders[3](x2)
        #print(x3.shape)
        x4=self.encoders[4](x3)
        #print(x4.shape)
        e3=x3[:,:,4:60,4:60]
        x5=self.decoders[0](x4,e3)
        #print(x5.shape)
        e2=x2[:,:,16:120,16:120]
        x6=self.decoders[1](x5,e2)
        #print(x6.shape)
        e1=x1[:,:,40:240,40:240]
        x7=self.decoders[2](x6,e1)
        e0=x0[:,:,98:490,98:490]
        #print(x7.shape)
        x8=self.decoders[3](x7,e0)
        return self.finalconv(x8)





In [2]:
data=SegmentationData(base_path+'train.txt',base_image_path,base_label_path)

In [3]:
train_iter=DataLoader(data,batch_size=1)

In [4]:
net=U_net(encoder_params,decoder_params)

In [5]:
def precise(test_iter,net,device):
    total=0
    correct=0
    with torch.no_grad():
        for X,y in test_iter:
            X,y=X.to(device),y.to(device)
            y_hat=net(X)
            y_hat=torch.argmax(y_hat,dim=-1).flatten()
            y=y.flatten()
            ans=(y_hat==y)
            total+=len(y)
            correct+=ans.sum().item()
        
    print(correct)
    print(f"accuracy :{correct/total*100:>3f}% ")

    
def train(data_iter,entroy_iter,net,optimizer,lr_scheduler,loss_fn,epochs,device,epoch_data_num):
    matrix_x,matrix_loss,entroy_loss,entroy_x=[0],[0],[],[]
    total_loss=0
    batchs=len(data_iter)
    for epoch in range(epochs):
        now_num=0
        for X,y in data_iter:
 
            now_num+=len(X)
    
            net.train()
            
            X,y=X.to(device),y.to(device)
            optimizer.zero_grad()
            y_hat=net(X)
            y_hat=y_hat.permute(0,2,3,1).flatten(start_dim=0,end_dim=-2)
            y=y.flatten()
            loss=loss_fn(y_hat,y)
            loss.sum().backward()
            optimizer.step()

            total_loss+=loss.item()

            matrix_x.append(matrix_x[-1]+1)
            matrix_loss.append(total_loss/(epoch*epoch_data_num+now_num))
            
            print(f"loss: {matrix_loss[-1]:>7f} now {matrix_x[-1]}/{batchs*epochs}",end='\r')
        '''
            lr_scheduler.step()
            with torch.no_grad():
                c_total_loss=0
                test_data_num=0
                for X,y in entroy_iter: 
                    net.eval()
                    test_data_num+=len(X)
                    X,y=X.to(device),y.to(device)
                    y_hat=net(X)
                    loss=loss_fn(y_hat,y)

                    c_total_loss+=loss.item()

                print(test_data_num)
                entroy_loss.append(c_total_loss/test_data_num)
                entroy_x.append((epoch+1)*batchs)
            print(f"cross entroy loss:{entroy_loss[-1]} now {epoch+1}/{epochs}")
            torch.save(net.state_dict(), f"Google_epoch{epoch+6}.bin")
            precise(test_iter,net,device)""

          '''
    
    return net,matrix_x,matrix_loss,entroy_x,entroy_loss



from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
device='cuda' if torch.cuda.is_available() else 'cpu'
net=U_net(encoder_params,decoder_params).to(device)
optimizer=Adam(net.parameters(),lr=0.0016)
lr_scheduler=LambdaLR(optimizer, lr_lambda=lambda epoch: 1/(2**epoch))
loss_fn=nn.CrossEntropyLoss()
'''for m in net.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
'''



'for m in net.modules():\n    if isinstance(m, (nn.Conv2d, nn.Linear)):\n        nn.init.xavier_uniform_(m.weight)\n'

In [6]:
_,matrix_x,matrix_loss,entroy_x,entroy_loss=train(train_iter,train_iter,net,optimizer,lr_scheduler,loss_fn,4,device,len(data))



loss: 1.364065 now 532/5856

KeyboardInterrupt: 