In [5]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch import optim
import torch.nn.functional as F
from glob import glob
import csv
import random
import re
import os
from PIL import Image
import numpy as np
import sklearn.metrics


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )   


class UNet(nn.Module):

    def __init__(self, n_class=2):
        super().__init__()
                
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)        

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        
        
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        
        self.conv_last = nn.Conv2d(64, n_class, 1)
        
        
    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)   
        
        x = self.dconv_down4(x)
        
        x = self.upsample(x)        
        x = torch.cat([x, conv3], dim=1)
        
        x = self.dconv_up3(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv2], dim=1)       

        x = self.dconv_up2(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv1], dim=1)   
        
        x = self.dconv_up1(x)
        
        out = self.conv_last(x)
        
        return out

In [2]:
data_transform = transforms.Compose([
        #transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
class image_seg(Dataset):
    def __init__(self, data_dir='' ,transform=data_transform):
        self.transform = data_transform
        self.dirs =glob(data_dir)
    def __len__(self):
        return len(self.dirs)
    def __getitem__(self, idx):
        input = Image.open(self.dirs[idx])
        filename = os.path.basename(self.dirs[idx])
        common_dir = os.path.dirname(os.path.dirname(self.dirs[idx]))
        label_dir = common_dir +'/masks/'+filename[:-4] + 'png'
        label = Image.open(label_dir).convert('L')
        input = self.transform(input)
        input = torch.from_numpy(np.array(input)[:,0:224,0:224]).float()
        label = np.array(label)[0:224,0:224]
        label[label==255] = 1
        label = torch.from_numpy(label).long()
        return input, label

Dataset_obj_train = image_seg(data_dir ='drive/My Drive/Dataset/Amazon_Forest/Training/images/**.tiff' ,transform=data_transform)
trainloader = DataLoader(Dataset_obj_train, batch_size=4,shuffle=True, num_workers=2)

Dataset_obj_valid = image_seg(data_dir ='drive/My Drive/Dataset/Amazon_Forest/Validation/images/**.tiff',transform=data_transform)
validloader = DataLoader(Dataset_obj_valid, batch_size=4,shuffle=False, num_workers=2)


In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 
loss_function = torch.nn.CrossEntropyLoss()
net  = UNet(n_class=2)
optimizer = torch.optim.Adam(net.parameters())
net.to(device)

best_loss = float('inf')
best_epoch = 0
best_acc = 0
for epoch in range(50):
    totoal_loss = 0
    net.train()
    for inputs,labels in trainloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = net(inputs)
        loss = loss_function(logits,labels)
        totoal_loss += loss.item()
        loss.backward()
        optimizer.step()
        

        

    print (f'Epoch: {epoch+1:02} |Loss:{totoal_loss/len(trainloader)}')

    net.eval()
    with torch.no_grad():
        for inputs,labels in validloader:
            inputs = inputs.to(device)
            labels = labels.to(device)  
            logits = net(inputs)
            loss = loss_function(logits,labels)
            

        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch
            torch.save(net.state_dict(), 'best_model.pt')
            
        print("curr_val_loss:",loss.item(), "epoch:", epoch+1, "best_loss",best_loss.item(),"best_epoch",best_epoch)

Epoch: 01 |Loss:0.7110781520605087
curr_val_loss: 0.6905110478401184 epoch: 1 best_loss 0.6905110478401184 best_epoch 0
Epoch: 02 |Loss:0.6909127980470657
curr_val_loss: 0.6771504282951355 epoch: 2 best_loss 0.6771504282951355 best_epoch 1
Epoch: 03 |Loss:0.6782530024647713
curr_val_loss: 0.6047186851501465 epoch: 3 best_loss 0.6047186851501465 best_epoch 2
Epoch: 04 |Loss:0.5606506615877151
curr_val_loss: 0.47801125049591064 epoch: 4 best_loss 0.47801125049591064 best_epoch 3
Epoch: 05 |Loss:0.4991947151720524
curr_val_loss: 1.2758880853652954 epoch: 5 best_loss 0.47801125049591064 best_epoch 3
Epoch: 06 |Loss:0.483431413769722
curr_val_loss: 0.5682178139686584 epoch: 6 best_loss 0.47801125049591064 best_epoch 3
Epoch: 07 |Loss:0.46405326947569847
curr_val_loss: 0.40947771072387695 epoch: 7 best_loss 0.40947771072387695 best_epoch 6
Epoch: 08 |Loss:0.4857613258063793
curr_val_loss: 0.32758867740631104 epoch: 8 best_loss 0.32758867740631104 best_epoch 7
Epoch: 09 |Loss:0.47301811352372

In [18]:
net.load_state_dict(torch.load('best_model.pt'))

for inputs,actual in validloader:
    inputs = inputs.to(device)
    print(inputs.size())
    logit = net(inputs)
    pred = logit.argmax(dim=1)
    print(logit.size())
    print(actual.size())
    print(pred.size())
    break
    
    

torch.Size([4, 3, 224, 224])
torch.Size([4, 2, 224, 224])
torch.Size([4, 224, 224])
torch.Size([4, 224, 224])


In [17]:
actual[0].unique()

tensor([0, 1])

In [19]:
pred.unique()

tensor([0, 1], device='cuda:0')

In [20]:
A=pred[0]

torch.Size([224, 224])

In [36]:
import matplotlib
A = pred[1].cpu().numpy() 
B = actual[1].cpu().numpy()
C = inputs[1].cpu().numpy() 
C = np.transpose(C, (1, 2, 0))
print(C.shape) 
matplotlib.image.imsave('pred.png', A)
matplotlib.image.imsave('actual.png', B)
matplotlib.image.imsave('inputs.png', C)

(224, 224, 3)
