In [2]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm


import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable


In [2]:
TRAIN_DIR = "data/raw/deepglobe-2018-dataset/train"
VALID_DIR = "data/raw/deepglobe-2018-dataset/valid"
TEST_DIR = "data/raw/deepglobe-2018-dataset/test"
COLOR_CODES = "data/raw/deepglobe-2018-dataset/class_dict.csv"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
])

In [3]:
df = pd.read_csv(COLOR_CODES)
label_map = {}
for index, row in df.iterrows():
    label_map[index] = [row["r"],row["g"],row["b"]]
label_map

{0: [0, 255, 255],
 1: [255, 255, 0],
 2: [255, 0, 255],
 3: [0, 255, 0],
 4: [0, 0, 255],
 5: [255, 255, 255],
 6: [0, 0, 0]}

In [4]:
class Segmentation_Dataset(Dataset):
    def __init__(self, image_dir, label_map, transform):
        self.image_dir = image_dir
        self.transform = transform
        self.label_map = label_map
        self.images_name = sorted([filename for filename in os.listdir(self.image_dir) if filename.endswith('_sat.jpg')])
        self.targets_name = sorted([filename for filename in os.listdir(self.image_dir) if filename.endswith('_mask.png')])
        
    def __len__(self):
        return len(self.images_name)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images_name[idx])
        mask_path = os.path.join(self.image_dir, self.targets_name[idx])
        
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # convert RGB
        
        mask = cv2.imread(mask_path)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = self.colormap_to_labelmap(mask)
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
            
        return image, mask
    
    def colormap_to_labelmap(self, mask):
        label_image = np.zeros_like(mask[:,:,0], dtype=np.uint8)

        for label, color in self.label_map.items():
            color_array = np.array(color)
            mask_condition = np.all(mask == color_array, axis=-1)
            label_image[mask_condition] = label

        return label_image.astype(np.float32)

In [5]:
"""img, mask, bla = Segmentation_Dataset(TRAIN_DIR, label_map).__getitem__(0)
print(bla)
sample = [img, mask]
for i in range(len(sample)):
    plt.subplot(1, 2, i+1)
    plt.imshow(sample[i])
    plt.axis('off')
plt.show()"""

"img, mask, bla = Segmentation_Dataset(TRAIN_DIR, label_map).__getitem__(0)\nprint(bla)\nsample = [img, mask]\nfor i in range(len(sample)):\n    plt.subplot(1, 2, i+1)\n    plt.imshow(sample[i])\n    plt.axis('off')\nplt.show()"

In [6]:
class ResidualBlock(nn.Module):
    def __init__(self, in_c=3, out_c=1):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(negative_slope=0.1, inplace=True)
        )
        self.skip = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(negative_slope=0.1, inplace=True)
        )

    def forward(self, x):
        conv = self.conv(x)
        skip = self.skip(x)

        plus = conv + skip
        return plus

class UNet(nn.Module):

    def __init__(self, n_class):
        super().__init__()
                
        self.dconv_down1 = ResidualBlock(3, 64)
        self.dconv_down2 = ResidualBlock(64, 128)
        self.dconv_down3 = ResidualBlock(128, 256)
        self.dconv_down4 = ResidualBlock(256,512)
        self.bottleneck = ResidualBlock(512, 1024)        

        self.maxpool = nn.MaxPool2d(2)

        self.dconv1 =  nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)        
        self.dconv2 =  nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dconv3 =  nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dconv4 =  nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        self.dconv_up4 = ResidualBlock(1024, 512) 
        self.dconv_up3 = ResidualBlock(512, 256)
        self.dconv_up2 = ResidualBlock(256, 128)
        self.dconv_up1 = ResidualBlock(128, 64)
        self.conv_last = nn.Conv2d(64, n_class, 1)
        
        
    def forward(self, x):

        # encoder
        conv1 = self.dconv_down1(x)
        x1 = self.maxpool(conv1) 

        conv2 = self.dconv_down2(x1)
        x2 = self.maxpool(conv2) 
        
        conv3 = self.dconv_down3(x2)
        x3 = self.maxpool(conv3) 

        conv4 = self.dconv_down4(x3)
        x4 = self.maxpool(conv4)         

        x5 = self.bottleneck(x4) 


        x = self.dconv1(x5)        
        x = torch.cat([x, conv4], dim=1)
        x = self.dconv_up4(x)

        x = self.dconv2(x)        
        x = torch.cat([x, conv3], dim=1)       
        x = self.dconv_up3(x)

        x = self.dconv3(x)        
        x = torch.cat([x, conv2], dim=1)   
        x = self.dconv_up2(x)

        x = self.dconv4(x)        
        x = torch.cat([x, conv1], dim=1)   
        x = self.dconv_up1(x)
        
        seg = self.conv_last(x)

        return seg


In [7]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.reduction = reduction
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int, torch.LongTensor)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction == "none":
            return loss
        elif self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()

In [8]:
class Train():
    def __init__(self, model, dataloader, optimizer, num_epoch, device, loss):
        self.model = model
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.device = device
        self.loss = loss
    def train(self):
        self.model.to(self.device)
        self.model.train()
        for epoch in range(self.num_epoch):
            total_loss = 0.0
            total_samples = 0
            with tqdm(self.dataloader, unit="batch") as dl:
                for inputs, labels in dl:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    predictions = self.model(inputs)
                    labels = labels.squeeze().long()
                    loss = self.loss(predictions, labels)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    total_loss += loss.item() * inputs.size(0)
                    total_samples += inputs.size(0)
                    dl.set_postfix({"Epoch": epoch + 1, "Loss": total_loss / total_samples})
            epoch_loss = total_loss / total_samples
            print(f"Epoch [{epoch + 1}/{self.num_epoch}], Loss: {epoch_loss:.4f}")

In [9]:
num_classes = len(label_map.keys())
lr = 0.001
batch_size = 4
model = UNet(n_class=num_classes)
train_dataset = Segmentation_Dataset(TRAIN_DIR, label_map, transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr)
num_epochs = 200
device =  torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = FocalLoss()

In [10]:
trainer = Train(model, train_dataloader, optimizer, num_epochs, device, loss)
trainer.train()

  logpt = F.log_softmax(input)
100%|██████████| 201/201 [17:56<00:00,  5.35s/batch, Epoch=1, Loss=1.15]


Epoch [1/200], Loss: 1.1470


100%|██████████| 201/201 [15:52<00:00,  4.74s/batch, Epoch=2, Loss=0.999]


Epoch [2/200], Loss: 0.9989


100%|██████████| 201/201 [15:50<00:00,  4.73s/batch, Epoch=3, Loss=0.954]


Epoch [3/200], Loss: 0.9536


100%|██████████| 201/201 [24:01<00:00,  7.17s/batch, Epoch=4, Loss=0.896] 


Epoch [4/200], Loss: 0.8960


100%|██████████| 201/201 [41:41<00:00, 12.45s/batch, Epoch=5, Loss=0.861]


Epoch [5/200], Loss: 0.8606


100%|██████████| 201/201 [21:06<00:00,  6.30s/batch, Epoch=6, Loss=0.836]


Epoch [6/200], Loss: 0.8365


100%|██████████| 201/201 [17:53<00:00,  5.34s/batch, Epoch=7, Loss=0.806]


Epoch [7/200], Loss: 0.8064


100%|██████████| 201/201 [17:36<00:00,  5.25s/batch, Epoch=8, Loss=0.772]


Epoch [8/200], Loss: 0.7715


100%|██████████| 201/201 [17:03<00:00,  5.09s/batch, Epoch=9, Loss=0.755]


Epoch [9/200], Loss: 0.7546


100%|██████████| 201/201 [16:47<00:00,  5.01s/batch, Epoch=10, Loss=0.722]


Epoch [10/200], Loss: 0.7224


100%|██████████| 201/201 [16:41<00:00,  4.98s/batch, Epoch=11, Loss=0.723]


Epoch [11/200], Loss: 0.7231


100%|██████████| 201/201 [15:34<00:00,  4.65s/batch, Epoch=12, Loss=0.7]  


Epoch [12/200], Loss: 0.7000


100%|██████████| 201/201 [14:38<00:00,  4.37s/batch, Epoch=13, Loss=0.664]


Epoch [13/200], Loss: 0.6637


100%|██████████| 201/201 [14:31<00:00,  4.33s/batch, Epoch=14, Loss=0.672]


Epoch [14/200], Loss: 0.6720


100%|██████████| 201/201 [14:24<00:00,  4.30s/batch, Epoch=15, Loss=0.649]


Epoch [15/200], Loss: 0.6493


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=16, Loss=0.659]


Epoch [16/200], Loss: 0.6586


100%|██████████| 201/201 [14:51<00:00,  4.43s/batch, Epoch=17, Loss=0.64] 


Epoch [17/200], Loss: 0.6401


100%|██████████| 201/201 [14:31<00:00,  4.34s/batch, Epoch=18, Loss=0.645]


Epoch [18/200], Loss: 0.6447


100%|██████████| 201/201 [14:33<00:00,  4.34s/batch, Epoch=19, Loss=0.62] 


Epoch [19/200], Loss: 0.6201


100%|██████████| 201/201 [14:30<00:00,  4.33s/batch, Epoch=20, Loss=0.615]


Epoch [20/200], Loss: 0.6146


100%|██████████| 201/201 [14:30<00:00,  4.33s/batch, Epoch=21, Loss=0.62] 


Epoch [21/200], Loss: 0.6203


100%|██████████| 201/201 [14:35<00:00,  4.35s/batch, Epoch=22, Loss=0.61] 


Epoch [22/200], Loss: 0.6101


100%|██████████| 201/201 [14:36<00:00,  4.36s/batch, Epoch=23, Loss=0.588]


Epoch [23/200], Loss: 0.5879


100%|██████████| 201/201 [14:37<00:00,  4.37s/batch, Epoch=24, Loss=0.612]


Epoch [24/200], Loss: 0.6123


100%|██████████| 201/201 [14:38<00:00,  4.37s/batch, Epoch=25, Loss=0.564]


Epoch [25/200], Loss: 0.5643


100%|██████████| 201/201 [14:37<00:00,  4.37s/batch, Epoch=26, Loss=0.587]


Epoch [26/200], Loss: 0.5868


100%|██████████| 201/201 [14:37<00:00,  4.37s/batch, Epoch=27, Loss=0.597]


Epoch [27/200], Loss: 0.5970


100%|██████████| 201/201 [14:54<00:00,  4.45s/batch, Epoch=28, Loss=0.584]


Epoch [28/200], Loss: 0.5842


100%|██████████| 201/201 [14:57<00:00,  4.47s/batch, Epoch=29, Loss=0.595]


Epoch [29/200], Loss: 0.5954


100%|██████████| 201/201 [14:55<00:00,  4.46s/batch, Epoch=30, Loss=0.554]


Epoch [30/200], Loss: 0.5543


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=31, Loss=0.55] 


Epoch [31/200], Loss: 0.5504


100%|██████████| 201/201 [14:45<00:00,  4.41s/batch, Epoch=32, Loss=0.574]


Epoch [32/200], Loss: 0.5737


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=33, Loss=0.54] 


Epoch [33/200], Loss: 0.5403


100%|██████████| 201/201 [14:37<00:00,  4.36s/batch, Epoch=34, Loss=0.552]


Epoch [34/200], Loss: 0.5524


100%|██████████| 201/201 [14:36<00:00,  4.36s/batch, Epoch=35, Loss=0.54] 


Epoch [35/200], Loss: 0.5395


100%|██████████| 201/201 [14:37<00:00,  4.37s/batch, Epoch=36, Loss=0.56] 


Epoch [36/200], Loss: 0.5603


100%|██████████| 201/201 [14:36<00:00,  4.36s/batch, Epoch=37, Loss=0.547]


Epoch [37/200], Loss: 0.5471


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=38, Loss=0.52] 


Epoch [38/200], Loss: 0.5204


100%|██████████| 201/201 [14:32<00:00,  4.34s/batch, Epoch=39, Loss=0.54] 


Epoch [39/200], Loss: 0.5401


100%|██████████| 201/201 [14:51<00:00,  4.44s/batch, Epoch=40, Loss=0.528]


Epoch [40/200], Loss: 0.5281


100%|██████████| 201/201 [14:39<00:00,  4.37s/batch, Epoch=41, Loss=0.518]


Epoch [41/200], Loss: 0.5180


100%|██████████| 201/201 [14:41<00:00,  4.38s/batch, Epoch=42, Loss=0.512]


Epoch [42/200], Loss: 0.5117


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=43, Loss=0.517]


Epoch [43/200], Loss: 0.5170


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=44, Loss=0.516]


Epoch [44/200], Loss: 0.5161


100%|██████████| 201/201 [14:39<00:00,  4.38s/batch, Epoch=45, Loss=0.516]


Epoch [45/200], Loss: 0.5159


100%|██████████| 201/201 [14:41<00:00,  4.38s/batch, Epoch=46, Loss=0.528]


Epoch [46/200], Loss: 0.5278


100%|██████████| 201/201 [14:38<00:00,  4.37s/batch, Epoch=47, Loss=0.488]


Epoch [47/200], Loss: 0.4879


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=48, Loss=0.503]


Epoch [48/200], Loss: 0.5034


100%|██████████| 201/201 [14:43<00:00,  4.40s/batch, Epoch=49, Loss=0.498]


Epoch [49/200], Loss: 0.4984


100%|██████████| 201/201 [14:52<00:00,  4.44s/batch, Epoch=50, Loss=0.5]  


Epoch [50/200], Loss: 0.5001


100%|██████████| 201/201 [15:17<00:00,  4.56s/batch, Epoch=51, Loss=0.482]


Epoch [51/200], Loss: 0.4821


100%|██████████| 201/201 [15:15<00:00,  4.56s/batch, Epoch=52, Loss=0.492]


Epoch [52/200], Loss: 0.4917


100%|██████████| 201/201 [15:19<00:00,  4.58s/batch, Epoch=53, Loss=0.488]


Epoch [53/200], Loss: 0.4879


100%|██████████| 201/201 [15:12<00:00,  4.54s/batch, Epoch=54, Loss=0.48] 


Epoch [54/200], Loss: 0.4796


100%|██████████| 201/201 [14:48<00:00,  4.42s/batch, Epoch=55, Loss=0.459]


Epoch [55/200], Loss: 0.4595


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=56, Loss=0.471]


Epoch [56/200], Loss: 0.4714


100%|██████████| 201/201 [14:29<00:00,  4.33s/batch, Epoch=57, Loss=0.471]


Epoch [57/200], Loss: 0.4710


100%|██████████| 201/201 [14:30<00:00,  4.33s/batch, Epoch=58, Loss=0.467]


Epoch [58/200], Loss: 0.4666


100%|██████████| 201/201 [14:29<00:00,  4.32s/batch, Epoch=59, Loss=0.448]


Epoch [59/200], Loss: 0.4476


100%|██████████| 201/201 [14:30<00:00,  4.33s/batch, Epoch=60, Loss=0.468]


Epoch [60/200], Loss: 0.4678


100%|██████████| 201/201 [14:32<00:00,  4.34s/batch, Epoch=61, Loss=0.445]


Epoch [61/200], Loss: 0.4453


100%|██████████| 201/201 [14:31<00:00,  4.34s/batch, Epoch=62, Loss=0.493]


Epoch [62/200], Loss: 0.4930


100%|██████████| 201/201 [14:30<00:00,  4.33s/batch, Epoch=63, Loss=0.454]


Epoch [63/200], Loss: 0.4544


100%|██████████| 201/201 [14:32<00:00,  4.34s/batch, Epoch=64, Loss=0.435]


Epoch [64/200], Loss: 0.4352


100%|██████████| 201/201 [14:31<00:00,  4.34s/batch, Epoch=65, Loss=0.438]


Epoch [65/200], Loss: 0.4379


100%|██████████| 201/201 [14:35<00:00,  4.36s/batch, Epoch=66, Loss=0.473]


Epoch [66/200], Loss: 0.4730


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=67, Loss=0.447]


Epoch [67/200], Loss: 0.4473


100%|██████████| 201/201 [14:36<00:00,  4.36s/batch, Epoch=68, Loss=0.426]


Epoch [68/200], Loss: 0.4264


100%|██████████| 201/201 [14:38<00:00,  4.37s/batch, Epoch=69, Loss=0.428]


Epoch [69/200], Loss: 0.4282


100%|██████████| 201/201 [14:37<00:00,  4.36s/batch, Epoch=70, Loss=0.425]


Epoch [70/200], Loss: 0.4254


100%|██████████| 201/201 [14:35<00:00,  4.35s/batch, Epoch=71, Loss=0.417]


Epoch [71/200], Loss: 0.4171


100%|██████████| 201/201 [14:38<00:00,  4.37s/batch, Epoch=72, Loss=0.406]


Epoch [72/200], Loss: 0.4061


100%|██████████| 201/201 [14:34<00:00,  4.35s/batch, Epoch=73, Loss=0.409]


Epoch [73/200], Loss: 0.4094


100%|██████████| 201/201 [14:36<00:00,  4.36s/batch, Epoch=74, Loss=0.412]


Epoch [74/200], Loss: 0.4125


100%|██████████| 201/201 [14:37<00:00,  4.36s/batch, Epoch=75, Loss=0.392]


Epoch [75/200], Loss: 0.3925


100%|██████████| 201/201 [14:37<00:00,  4.36s/batch, Epoch=76, Loss=0.419]


Epoch [76/200], Loss: 0.4186


100%|██████████| 201/201 [14:52<00:00,  4.44s/batch, Epoch=77, Loss=0.391]


Epoch [77/200], Loss: 0.3907


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=78, Loss=0.382]


Epoch [78/200], Loss: 0.3822


100%|██████████| 201/201 [15:04<00:00,  4.50s/batch, Epoch=79, Loss=0.399]


Epoch [79/200], Loss: 0.3990


100%|██████████| 201/201 [14:55<00:00,  4.45s/batch, Epoch=80, Loss=0.396]


Epoch [80/200], Loss: 0.3958


100%|██████████| 201/201 [14:48<00:00,  4.42s/batch, Epoch=81, Loss=0.37] 


Epoch [81/200], Loss: 0.3704


100%|██████████| 201/201 [15:06<00:00,  4.51s/batch, Epoch=82, Loss=0.401]


Epoch [82/200], Loss: 0.4015


100%|██████████| 201/201 [14:56<00:00,  4.46s/batch, Epoch=83, Loss=0.434]


Epoch [83/200], Loss: 0.4344


100%|██████████| 201/201 [14:53<00:00,  4.45s/batch, Epoch=84, Loss=0.389]


Epoch [84/200], Loss: 0.3892


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=85, Loss=0.39] 


Epoch [85/200], Loss: 0.3896


100%|██████████| 201/201 [14:45<00:00,  4.41s/batch, Epoch=86, Loss=0.359]


Epoch [86/200], Loss: 0.3594


100%|██████████| 201/201 [14:41<00:00,  4.38s/batch, Epoch=87, Loss=0.375]


Epoch [87/200], Loss: 0.3746


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=88, Loss=0.35] 


Epoch [88/200], Loss: 0.3499


100%|██████████| 201/201 [14:48<00:00,  4.42s/batch, Epoch=89, Loss=0.349]


Epoch [89/200], Loss: 0.3486


100%|██████████| 201/201 [14:52<00:00,  4.44s/batch, Epoch=90, Loss=0.38] 


Epoch [90/200], Loss: 0.3803


100%|██████████| 201/201 [14:48<00:00,  4.42s/batch, Epoch=91, Loss=0.336]


Epoch [91/200], Loss: 0.3365


100%|██████████| 201/201 [14:47<00:00,  4.42s/batch, Epoch=92, Loss=0.342]


Epoch [92/200], Loss: 0.3421


100%|██████████| 201/201 [14:55<00:00,  4.45s/batch, Epoch=93, Loss=0.35] 


Epoch [93/200], Loss: 0.3502


100%|██████████| 201/201 [14:58<00:00,  4.47s/batch, Epoch=94, Loss=0.334]


Epoch [94/200], Loss: 0.3338


100%|██████████| 201/201 [14:49<00:00,  4.43s/batch, Epoch=95, Loss=0.339]


Epoch [95/200], Loss: 0.3387


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=96, Loss=0.309]


Epoch [96/200], Loss: 0.3089


100%|██████████| 201/201 [14:41<00:00,  4.38s/batch, Epoch=97, Loss=0.311]


Epoch [97/200], Loss: 0.3113


100%|██████████| 201/201 [14:44<00:00,  4.40s/batch, Epoch=98, Loss=0.315]


Epoch [98/200], Loss: 0.3155


100%|██████████| 201/201 [14:44<00:00,  4.40s/batch, Epoch=99, Loss=0.332]


Epoch [99/200], Loss: 0.3319


100%|██████████| 201/201 [14:44<00:00,  4.40s/batch, Epoch=100, Loss=0.306]


Epoch [100/200], Loss: 0.3061


100%|██████████| 201/201 [14:45<00:00,  4.41s/batch, Epoch=101, Loss=0.308]


Epoch [101/200], Loss: 0.3076


100%|██████████| 201/201 [14:50<00:00,  4.43s/batch, Epoch=102, Loss=0.302]


Epoch [102/200], Loss: 0.3019


100%|██████████| 201/201 [14:51<00:00,  4.43s/batch, Epoch=103, Loss=0.301]


Epoch [103/200], Loss: 0.3013


100%|██████████| 201/201 [14:44<00:00,  4.40s/batch, Epoch=104, Loss=0.279]


Epoch [104/200], Loss: 0.2792


100%|██████████| 201/201 [14:44<00:00,  4.40s/batch, Epoch=105, Loss=0.299]


Epoch [105/200], Loss: 0.2993


100%|██████████| 201/201 [14:45<00:00,  4.40s/batch, Epoch=106, Loss=0.343]


Epoch [106/200], Loss: 0.3431


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=107, Loss=0.348]


Epoch [107/200], Loss: 0.3480


100%|██████████| 201/201 [14:39<00:00,  4.38s/batch, Epoch=108, Loss=0.265]


Epoch [108/200], Loss: 0.2652


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=109, Loss=0.278]


Epoch [109/200], Loss: 0.2784


100%|██████████| 201/201 [14:54<00:00,  4.45s/batch, Epoch=110, Loss=0.262]


Epoch [110/200], Loss: 0.2619


100%|██████████| 201/201 [14:56<00:00,  4.46s/batch, Epoch=111, Loss=0.257]


Epoch [111/200], Loss: 0.2566


100%|██████████| 201/201 [14:48<00:00,  4.42s/batch, Epoch=112, Loss=0.255]


Epoch [112/200], Loss: 0.2553


100%|██████████| 201/201 [14:43<00:00,  4.40s/batch, Epoch=113, Loss=0.265]


Epoch [113/200], Loss: 0.2653


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=114, Loss=0.323]


Epoch [114/200], Loss: 0.3229


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=115, Loss=0.295]


Epoch [115/200], Loss: 0.2952


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=116, Loss=0.25] 


Epoch [116/200], Loss: 0.2504


100%|██████████| 201/201 [14:41<00:00,  4.38s/batch, Epoch=117, Loss=0.26] 


Epoch [117/200], Loss: 0.2603


100%|██████████| 201/201 [14:41<00:00,  4.38s/batch, Epoch=118, Loss=0.253]


Epoch [118/200], Loss: 0.2533


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=119, Loss=0.274]


Epoch [119/200], Loss: 0.2739


100%|██████████| 201/201 [15:07<00:00,  4.51s/batch, Epoch=120, Loss=0.266]


Epoch [120/200], Loss: 0.2657


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=121, Loss=0.23] 


Epoch [121/200], Loss: 0.2301


100%|██████████| 201/201 [14:44<00:00,  4.40s/batch, Epoch=122, Loss=0.226]


Epoch [122/200], Loss: 0.2260


100%|██████████| 201/201 [14:48<00:00,  4.42s/batch, Epoch=123, Loss=0.235]


Epoch [123/200], Loss: 0.2353


100%|██████████| 201/201 [14:54<00:00,  4.45s/batch, Epoch=124, Loss=0.241]


Epoch [124/200], Loss: 0.2412


100%|██████████| 201/201 [14:49<00:00,  4.43s/batch, Epoch=125, Loss=0.222]


Epoch [125/200], Loss: 0.2224


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=126, Loss=0.312]


Epoch [126/200], Loss: 0.3124


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=127, Loss=0.282]


Epoch [127/200], Loss: 0.2820


100%|██████████| 201/201 [14:39<00:00,  4.37s/batch, Epoch=128, Loss=0.226]


Epoch [128/200], Loss: 0.2263


100%|██████████| 201/201 [14:39<00:00,  4.37s/batch, Epoch=129, Loss=0.205]


Epoch [129/200], Loss: 0.2052


100%|██████████| 201/201 [14:55<00:00,  4.46s/batch, Epoch=130, Loss=0.211]


Epoch [130/200], Loss: 0.2112


100%|██████████| 201/201 [14:49<00:00,  4.43s/batch, Epoch=131, Loss=0.21] 


Epoch [131/200], Loss: 0.2101


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=132, Loss=0.194]


Epoch [132/200], Loss: 0.1941


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=133, Loss=0.19] 


Epoch [133/200], Loss: 0.1895


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=134, Loss=0.183]


Epoch [134/200], Loss: 0.1830


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=135, Loss=0.184]


Epoch [135/200], Loss: 0.1841


100%|██████████| 201/201 [15:02<00:00,  4.49s/batch, Epoch=136, Loss=0.191]


Epoch [136/200], Loss: 0.1908


100%|██████████| 201/201 [15:02<00:00,  4.49s/batch, Epoch=137, Loss=0.214]


Epoch [137/200], Loss: 0.2142


100%|██████████| 201/201 [14:51<00:00,  4.44s/batch, Epoch=138, Loss=0.208]


Epoch [138/200], Loss: 0.2081


100%|██████████| 201/201 [14:47<00:00,  4.41s/batch, Epoch=139, Loss=0.19] 


Epoch [139/200], Loss: 0.1905


100%|██████████| 201/201 [14:50<00:00,  4.43s/batch, Epoch=140, Loss=0.22] 


Epoch [140/200], Loss: 0.2203


100%|██████████| 201/201 [14:53<00:00,  4.44s/batch, Epoch=141, Loss=0.204]


Epoch [141/200], Loss: 0.2036


100%|██████████| 201/201 [14:50<00:00,  4.43s/batch, Epoch=142, Loss=0.168]


Epoch [142/200], Loss: 0.1677


100%|██████████| 201/201 [14:53<00:00,  4.45s/batch, Epoch=143, Loss=0.203]


Epoch [143/200], Loss: 0.2030


100%|██████████| 201/201 [14:53<00:00,  4.45s/batch, Epoch=144, Loss=0.172]


Epoch [144/200], Loss: 0.1724


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=145, Loss=0.174]


Epoch [145/200], Loss: 0.1743


100%|██████████| 201/201 [14:45<00:00,  4.40s/batch, Epoch=146, Loss=0.151]


Epoch [146/200], Loss: 0.1513


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=147, Loss=0.16] 


Epoch [147/200], Loss: 0.1605


100%|██████████| 201/201 [14:45<00:00,  4.41s/batch, Epoch=148, Loss=0.194]


Epoch [148/200], Loss: 0.1936


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=149, Loss=0.222]


Epoch [149/200], Loss: 0.2222


100%|██████████| 201/201 [14:53<00:00,  4.45s/batch, Epoch=150, Loss=0.146]


Epoch [150/200], Loss: 0.1461


100%|██████████| 201/201 [14:51<00:00,  4.43s/batch, Epoch=151, Loss=0.143]


Epoch [151/200], Loss: 0.1432


100%|██████████| 201/201 [14:51<00:00,  4.43s/batch, Epoch=152, Loss=0.154]


Epoch [152/200], Loss: 0.1540


100%|██████████| 201/201 [14:57<00:00,  4.47s/batch, Epoch=153, Loss=0.141]


Epoch [153/200], Loss: 0.1408


100%|██████████| 201/201 [14:51<00:00,  4.44s/batch, Epoch=154, Loss=0.134]


Epoch [154/200], Loss: 0.1343


100%|██████████| 201/201 [14:50<00:00,  4.43s/batch, Epoch=155, Loss=0.237]


Epoch [155/200], Loss: 0.2366


100%|██████████| 201/201 [14:51<00:00,  4.44s/batch, Epoch=156, Loss=0.188]


Epoch [156/200], Loss: 0.1883


100%|██████████| 201/201 [14:52<00:00,  4.44s/batch, Epoch=157, Loss=0.143]


Epoch [157/200], Loss: 0.1432


100%|██████████| 201/201 [14:53<00:00,  4.45s/batch, Epoch=158, Loss=0.197]


Epoch [158/200], Loss: 0.1965


100%|██████████| 201/201 [14:50<00:00,  4.43s/batch, Epoch=159, Loss=0.226]


Epoch [159/200], Loss: 0.2259


100%|██████████| 201/201 [14:43<00:00,  4.40s/batch, Epoch=160, Loss=0.145]


Epoch [160/200], Loss: 0.1451


100%|██████████| 201/201 [14:54<00:00,  4.45s/batch, Epoch=161, Loss=0.145]


Epoch [161/200], Loss: 0.1451


100%|██████████| 201/201 [14:43<00:00,  4.40s/batch, Epoch=162, Loss=0.132]


Epoch [162/200], Loss: 0.1316


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=163, Loss=0.144]


Epoch [163/200], Loss: 0.1440


100%|██████████| 201/201 [14:43<00:00,  4.39s/batch, Epoch=164, Loss=0.123]


Epoch [164/200], Loss: 0.1232


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=165, Loss=0.119]


Epoch [165/200], Loss: 0.1189


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=166, Loss=0.123]


Epoch [166/200], Loss: 0.1226


100%|██████████| 201/201 [14:46<00:00,  4.41s/batch, Epoch=167, Loss=0.115]


Epoch [167/200], Loss: 0.1154


100%|██████████| 201/201 [14:45<00:00,  4.40s/batch, Epoch=168, Loss=0.114]


Epoch [168/200], Loss: 0.1143


100%|██████████| 201/201 [14:45<00:00,  4.40s/batch, Epoch=169, Loss=0.116]


Epoch [169/200], Loss: 0.1155


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=170, Loss=0.119]


Epoch [170/200], Loss: 0.1189


100%|██████████| 201/201 [14:43<00:00,  4.39s/batch, Epoch=171, Loss=0.116]


Epoch [171/200], Loss: 0.1163


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=172, Loss=0.225]


Epoch [172/200], Loss: 0.2246


100%|██████████| 201/201 [14:44<00:00,  4.40s/batch, Epoch=173, Loss=0.287]


Epoch [173/200], Loss: 0.2868


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=174, Loss=0.15] 


Epoch [174/200], Loss: 0.1495


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=175, Loss=0.134]


Epoch [175/200], Loss: 0.1337


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=176, Loss=0.175]


Epoch [176/200], Loss: 0.1752


100%|██████████| 201/201 [14:48<00:00,  4.42s/batch, Epoch=177, Loss=0.297]


Epoch [177/200], Loss: 0.2967


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=178, Loss=0.128]


Epoch [178/200], Loss: 0.1281


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=179, Loss=0.113]


Epoch [179/200], Loss: 0.1130


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=180, Loss=0.111]


Epoch [180/200], Loss: 0.1108


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=181, Loss=0.112]


Epoch [181/200], Loss: 0.1118


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=182, Loss=0.112]


Epoch [182/200], Loss: 0.1119


100%|██████████| 201/201 [14:53<00:00,  4.44s/batch, Epoch=183, Loss=0.106]


Epoch [183/200], Loss: 0.1059


100%|██████████| 201/201 [15:37<00:00,  4.67s/batch, Epoch=184, Loss=0.103]


Epoch [184/200], Loss: 0.1033


100%|██████████| 201/201 [15:46<00:00,  4.71s/batch, Epoch=185, Loss=0.0972]


Epoch [185/200], Loss: 0.0972


100%|██████████| 201/201 [16:06<00:00,  4.81s/batch, Epoch=186, Loss=0.156]


Epoch [186/200], Loss: 0.1565


100%|██████████| 201/201 [14:54<00:00,  4.45s/batch, Epoch=187, Loss=0.121]


Epoch [187/200], Loss: 0.1206


100%|██████████| 201/201 [14:52<00:00,  4.44s/batch, Epoch=188, Loss=0.101] 


Epoch [188/200], Loss: 0.1009


100%|██████████| 201/201 [14:56<00:00,  4.46s/batch, Epoch=189, Loss=0.0968]


Epoch [189/200], Loss: 0.0968


100%|██████████| 201/201 [14:51<00:00,  4.44s/batch, Epoch=190, Loss=0.156]


Epoch [190/200], Loss: 0.1555


100%|██████████| 201/201 [14:42<00:00,  4.39s/batch, Epoch=191, Loss=0.127]


Epoch [191/200], Loss: 0.1266


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=192, Loss=0.0975]


Epoch [192/200], Loss: 0.0975


100%|██████████| 201/201 [14:38<00:00,  4.37s/batch, Epoch=193, Loss=0.0997]


Epoch [193/200], Loss: 0.0997


100%|██████████| 201/201 [14:40<00:00,  4.38s/batch, Epoch=194, Loss=0.0911]


Epoch [194/200], Loss: 0.0911


100%|██████████| 201/201 [14:49<00:00,  4.42s/batch, Epoch=195, Loss=0.0943]


Epoch [195/200], Loss: 0.0943


100%|██████████| 201/201 [14:43<00:00,  4.40s/batch, Epoch=196, Loss=0.107]


Epoch [196/200], Loss: 0.1074


100%|██████████| 201/201 [14:43<00:00,  4.39s/batch, Epoch=197, Loss=0.0934]


Epoch [197/200], Loss: 0.0934


100%|██████████| 201/201 [14:41<00:00,  4.39s/batch, Epoch=198, Loss=0.0878]


Epoch [198/200], Loss: 0.0878


100%|██████████| 201/201 [14:45<00:00,  4.41s/batch, Epoch=199, Loss=0.0879]


Epoch [199/200], Loss: 0.0879


100%|██████████| 201/201 [14:43<00:00,  4.40s/batch, Epoch=200, Loss=0.11]  

Epoch [200/200], Loss: 0.1099





In [11]:
torch.save(model.state_dict(), 'full_model.pth')
model.load_state_dict(torch.load('full_model.pth'))
model.to(device)
model.eval()

  model.load_state_dict(torch.load('full_model.pth'))


UNet(
  (dconv_down1): ResidualBlock(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1, inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (skip): Sequential(
      (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1, inplace=True)
    )
  )
  (dconv_down2): ResidualBlock(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(

In [1]:
img = cv2.imread("data/raw/deepglobe-2018-dataset/train/2334_sat.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
input_tensor = transform(img)
input_batch = input_tensor.unsqueeze(0).to(device)

with torch.no_grad():
    output = model(input_batch).squeeze()

_, predicted_labels = torch.max(output, 0)
predicted_mask = predicted_labels.squeeze().cpu().numpy()
print(torch.unique(predicted_labels))

label_map_colors = {0: [0, 255, 255],
 1: [255, 255, 0],
 2: [255, 0, 255],
 3: [0, 255, 0],
 4: [0, 0, 255],
 5: [255, 255, 255],
 6: [0, 0, 0]}

colored_mask = np.zeros((predicted_mask.shape[0], predicted_mask.shape[1], 3), dtype=np.uint8)
for label, color in label_map_colors.items():
    colored_mask[predicted_mask == label] = color

plt.subplot(1, 2, 1)
plt.imshow(img)
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(colored_mask)
plt.axis('off')

plt.show()

NameError: name 'cv2' is not defined