In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch import optim
import torch.nn.functional as F
from glob import glob
import csv
import random
import re
import os
from PIL import Image
import numpy as np
import sklearn.metrics


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )   


class UNet(nn.Module):

    def __init__(self, n_class=2):
        super().__init__()
                
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)        

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        
        
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        
        self.conv_last = nn.Conv2d(64, n_class, 1)
        
        
    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)   
        
        x = self.dconv_down4(x)
        
        x = self.upsample(x)        
        x = torch.cat([x, conv3], dim=1)
        
        x = self.dconv_up3(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv2], dim=1)       

        x = self.dconv_up2(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv1], dim=1)   
        
        x = self.dconv_up1(x)
        
        out = self.conv_last(x)
        
        return out

In [2]:
data_transform = transforms.Compose([
        #transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
class image_seg(Dataset):
    def __init__(self, data_dir='' ,transform=data_transform):
        self.transform = data_transform
        self.dirs =glob(data_dir) #read all images at onece
    def __len__(self):
        return len(self.dirs)
    def __getitem__(self, idx):
        input = Image.open(self.dirs[idx])
        filename = os.path.basename(self.dirs[idx])
        common_dir = os.path.dirname(os.path.dirname(self.dirs[idx]))
        #ignore the last 4 digits in the file name [:-4] tiff
        label_dir = common_dir +'/masks/'+filename[:-4] + 'png' 
        label = Image.open(label_dir).convert('L')
        input = input.resize((224,224), Image.LINEAR)
        input = np.array(input).transpose(2,0,1) #batch size first
        label = label.resize((224,224), Image.NEAREST)

        label = np.array(label)
        label[label==255] = 1
        label = torch.from_numpy(label).long()
        input = torch.from_numpy(input).float()
        return input, label

Dataset_obj_train = image_seg(data_dir ='drive/My Drive/Dataset/Amazon_Forest/Training/images/**.tiff' ,transform=data_transform)
trainloader = DataLoader(Dataset_obj_train, batch_size=4,shuffle=True, num_workers=2)

Dataset_obj_valid = image_seg(data_dir ='drive/My Drive/Dataset/Amazon_Forest/Validation/images/**.tiff',transform=data_transform)
validloader = DataLoader(Dataset_obj_valid, batch_size=4,shuffle=False, num_workers=2)



In [3]:
class image_seg_test(Dataset):
    def __init__(self, data_dir='' ,transform=data_transform):
        self.transform = data_transform
        self.dirs =glob(data_dir)
    def __len__(self):
        return len(self.dirs)
    def __getitem__(self, idx):
        input = Image.open(self.dirs[idx])
        filename = os.path.basename(self.dirs[idx])
        common_dir = os.path.dirname(os.path.dirname(self.dirs[idx]))
        input = input.resize((224,224), Image.LINEAR)
        input = np.array(input).transpose(2,0,1)
        input = torch.from_numpy(input).float()
        return input, filename[:-4]

Dataset_obj_test = image_seg_test(data_dir ='drive/My Drive/Dataset/Amazon_Forest/Test/**.tiff',transform=data_transform)
testloader = DataLoader(Dataset_obj_test, batch_size=2,shuffle=False, num_workers=2)
print(len(Dataset_obj_test))

15


In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 
loss_function = torch.nn.CrossEntropyLoss()
net  = UNet(n_class=2)
optimizer = torch.optim.Adam(net.parameters())
net.to(device)

best_loss = float('inf')
best_epoch = 0
best_acc = 0
for epoch in range(50):
    totoal_loss = 0
    net.train()
    for inputs,labels in trainloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = net(inputs)
        loss = loss_function(logits,labels)
        totoal_loss += loss.item()
        loss.backward()
        optimizer.step()
        

        

    print (f'Epoch: {epoch+1:02} |Loss:{totoal_loss/len(trainloader)}')

    net.eval()
    with torch.no_grad():
        for inputs,labels in validloader:
            inputs = inputs.to(device)
            labels = labels.to(device)  
            logits = net(inputs)
            loss = loss_function(logits,labels)
            

        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch
            torch.save(net.state_dict(), 'best_model.pt')
            
        print("curr_val_loss:",loss.item(), "epoch:", epoch+1, "best_loss",best_loss.item(),"best_epoch",best_epoch)

Epoch: 01 |Loss:0.8186461701989174
curr_val_loss: 0.5211305022239685 epoch: 1 best_loss 0.5211305022239685 best_epoch 0
Epoch: 02 |Loss:0.4220226816833019
curr_val_loss: 0.2250683605670929 epoch: 2 best_loss 0.2250683605670929 best_epoch 1
Epoch: 03 |Loss:0.24862528964877129
curr_val_loss: 0.3878239095211029 epoch: 3 best_loss 0.2250683605670929 best_epoch 1
Epoch: 04 |Loss:0.2646574564278126
curr_val_loss: 0.2122773379087448 epoch: 4 best_loss 0.2122773379087448 best_epoch 3
Epoch: 05 |Loss:0.24206285551190376
curr_val_loss: 0.19166700541973114 epoch: 5 best_loss 0.19166700541973114 best_epoch 4
Epoch: 06 |Loss:0.20249907858669758
curr_val_loss: 0.1904958188533783 epoch: 6 best_loss 0.1904958188533783 best_epoch 5
Epoch: 07 |Loss:0.21933837048709393
curr_val_loss: 0.16270877420902252 epoch: 7 best_loss 0.16270877420902252 best_epoch 6
Epoch: 08 |Loss:0.18275585770606995
curr_val_loss: 0.18550977110862732 epoch: 8 best_loss 0.16270877420902252 best_epoch 6
Epoch: 09 |Loss:0.17152433330

In [5]:
def dice(pred, target):
    """This definition generalize to real valued pred and target vector.
This should be differentiable.
    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    """

    smooth = 1.

    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    A_sum = torch.sum(tflat * iflat)
    B_sum = torch.sum(tflat * tflat)
    
    return (2. * intersection + smooth) / (A_sum + B_sum + smooth) 


In [6]:
net.load_state_dict(torch.load('best_model.pt'))
dice_acc=0
for inputs,actual in validloader:
    inputs = inputs.to(device)
    actual = actual.to(device)
    #print(inputs.size())
    logit = net(inputs)
    pred = logit.argmax(dim=1)
    
    dice_ = dice(pred, actual)
    dice_acc += dice_
print (dice_acc/len(validloader))

tensor(0.9591, device='cuda:0')


In [None]:
dice_acc=0
base_dir = 'drive/My Drive/Dataset/Amazon_Forest/Test_masks'
for inputs,filename_list in testloader:
    inputs = inputs.to(device)
    logit = net(inputs)
    pred = logit.argmax(dim=1)
    for i in range (inputs.size()[0]):
        matplotlib.image.imsave(base_dir+'/'+filename_list[i]+'png', pred[i].cpu().numpy())
    print(filename_list)
    
    

('10.', '11.')
('1.', '8.')
('14.', '13.')
('9.', '3.')
('5.', '12.')
('2.', '4.')
('0.', '6.')
('7.',)


In [39]:
import matplotlib
test_img =glob('drive/My Drive/Dataset/Amazon_Forest/Test_masks/**.png')
for i in range (len(test_img)):
    test_im = Image.open(test_img[i])
    #print(test_im.size)
    re_testim = test_im.resize((512,512), Image.NEAREST)
    
    #print(re_testim.size)
    matplotlib.image.imsave('drive/My Drive/Dataset/Amazon_Forest/Test_masks_orig' + '/'+os.path.basename(test_img[i]), np.array(re_testim))