In [6]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from PIL import Image
import os

# Lege ein BirdDataset an

In [7]:
class BirdDataset(Dataset):
    def __init__(self, image_paths, image_dir, segmentation_dir, transform_image, transform_mask):
        super(BirdDataset, self).__init__()
        self.image_dir = image_dir
        self.segmentation_dir = segmentation_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        with open(image_paths, 'r') as f:
            self.images_paths = [line.split(" ")[-1] for line in f.readlines()]
    
    def __len__(self):
        return len(self.images_paths)
    
    def __getitem__(self, index):
        image_name = ".".join(self.images_paths[index].split('.')[:-1])

        image = Image.open(os.path.join(self.image_dir, f"{image_name}.jpg")).convert("RGB")
        seg = Image.open(os.path.join(self.segmentation_dir, f"{image_name}.png")).convert("L")

        image = self.transform_image(image)
        seg = self.transform_mask(seg)

        return image, seg

# Lege jetzt eine Funktion an, 
die das Datenset in Trainings- und Validierungsdaten aufteilt und die Daten in Batches liefert

In [8]:
from torch.utils.data import DataLoader
import torch

def load_data_set(image_paths, image_dir, segmentation_dir, transforms, batch_size=8, shuffle=True):
    dataset = BirdDataset(image_paths,
                          image_dir,
                          segmentation_dir,
                          transform_image=transforms[0],
                          transform_mask=transforms[1])
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [11772, 16])

    return DataLoader( train_dataset, batch_size=batch_size, shuffle=shuffle), DataLoader( 
                         val_dataset, batch_size=batch_size, shuffle=shuffle)

# Aufbau der UNet-Architektur

In [9]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()
            
    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x

In [10]:
class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))
    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        #print("x.shape=", x.shape, " p.shape=", p.shape)
        return x, p

In [11]:
import torchvision.transforms.functional as TF
class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c+out_c, out_c)
    def forward(self, inputs, skip):
        x = self.up(inputs)
        #print("x dim = ", x.shape, "  skip = ", skip.shape)
        if x.shape != skip.shape:
            x = TF.resize(x, size=skip.shape[2:])
        x = torch.cat([x, skip], axis=1)
        #print(x.shape)
        x = self.conv(x)
        return x

In [12]:
class UNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        """ Encoder """
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)
        """ Bottleneck """
        self.b = conv_block(512, 1024)
        """ Decoder """
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        """ Classifier """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)
    
    def forward(self, inputs):
        """ Encoder """
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        """ Bottleneck """
        b = self.b(p4)
        #print("bottleneck dim = ", b.shape)
        """ Decoder """
        d1 = self.d1(b, s4)
        #print("d1 dim = ", d1.shape)
        d2 = self.d2(d1, s3)
        #print("d2 dim = ", d2.shape)
        d3 = self.d3(d2, s2)
        #print("d3 dim = ", d3.shape)
        #print("s1 dim = ", s1.shape)
        d4 = self.d4(d3, s1)
        #print("d4 dim = ", d4.shape)
        """ Classifier """
        outputs = self.outputs(d4)
        #print("outputs dim = ", outputs.shape)
        return outputs

In [13]:
def test():
    image = torch.randn((32, 3, 161, 161))
    model = UNet()
    out = model(image)
    print(image.shape, out.shape)
    assert out.shape == (32, 1, 161, 161)

In [14]:
test()



torch.Size([32, 3, 161, 161]) torch.Size([32, 1, 161, 161])


In [25]:
config = {
    "lr": 1e-3,
    "batch_size": 16,
    "image_dir": "CUB_200_2011/CUB_200_2011/images",
    "segmentation_dir": "CUB_200_2011/CUB_200_2011/segmentations",
    "image_paths": "CUB_200_2011/CUB_200_2011/images.txt",
    "epochs": 10,
    "checkpoint": "checkpoint/bird_segmentation_v1.pth",
    "optimiser": "checkpoint/bird_segmentation_v1_optim.pth",
    "continue_train": False,
    #"device": "cuda" if torch.cuda.is_available() else "cpu"
    "device" : "mps" if torch.backends.mps.is_available() else "cpu"
}

In [26]:

transforms_image = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0., 0., 0.), (1., 1., 1.))
])

transforms_mask = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1.,))
])

In [27]:
train_dataset, val_dataset = load_data_set(
    config['image_paths'],
    config['image_dir'],
    config['segmentation_dir'],
    transforms=[transforms_image, transforms_mask],
    batch_size=config['batch_size']
)

print("loaded", len(train_dataset), "batches")

loaded 736 batches


In [28]:
model = UNet().to(config['device'])
optimiser = torch.optim.Adam(params=model.parameters(), lr=config['lr'])

In [29]:
if config['continue_train']:
    state_dict = torch.load(config['checkpoint'])
    optimiser_state = torch.load(config['optimiser'])
    model.load_state_dict(state_dict)
    optimiser.load_state_dict(optimiser_state)

In [30]:
loss_fn = torch.nn.BCEWithLogitsLoss()
#scaler = torch.cuda.amp.GradScaler()

model.train()

UNet(
  (e1): encoder_block(
    (conv): conv_block(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (e2): encoder_block(
    (conv): conv_block(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (pool): MaxPool2d(kernel_size=(2, 2), str

In [31]:
def check_accuracy_and_save(model, optimiser, epoch):
    torch.save(model.state_dict(), config['checkpoint'])
    torch.save(optimiser.state_dict(), config['optimiser'])

    num_correct = 0
    num_pixel = 0
    dice_score = 0

    model.eval()
    with torch.no_grad():
        for x, y in val_dataset:
            x = x.to(config['device'])
            y = y.to(config['device'])

            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixel += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

            torchvision.utils.save_image(preds, f"test/pred/{epoch}.png")
            torchvision.utils.save_image(y, f"test/true/{epoch}.png")

    print(
        f"Dice Score = {dice_score/len(val_dataset)}"
    )
    model.train()

In [32]:
from tqdm import tqdm
import torchvision

def train():
    step = 0
    for epoch in range(config['epochs']):
        loop = tqdm(train_dataset)
        for image, seg in loop:
            image = image.to(config['device'])
            seg = seg.float().to(config['device'])

            #with torch.backends.mps.amp.autocast():
            #with torch.autocast(device_type="mps"):
            pred = model(image)
            loss = loss_fn(pred, seg)

            optimiser.zero_grad()
            loss.backward();
            optimiser.step();
            #scaler.scale(loss).backward()
            #scaler.step(optimiser)
            #scaler.update()

            loop.set_postfix(loss=loss.item())
            step += 1
        check_accuracy_and_save(model, optimiser, epoch)

In [None]:
train()

100%|████████████████████████████| 736/736 [12:55<00:00,  1.05s/it, loss=0.0337]


Dice Score = 0.923539400100708


100%|████████████████████████████| 736/736 [13:06<00:00,  1.07s/it, loss=0.0386]


Dice Score = 0.9237179756164551


100%|████████████████████████████| 736/736 [13:13<00:00,  1.08s/it, loss=0.0313]


Dice Score = 0.9237039089202881


 99%|███████████████████████████▊| 730/736 [12:56<00:06,  1.07s/it, loss=0.0304]

In [24]:
for i in range(1,13):
    image_name = 'Bild' + '{:02d}'.format(i)
    image = Image.open(os.path.join('./test', f"{image_name}.png")).convert("RGB")
    image = transforms_image(image)
    image = image.to(config['device'])
    image = image.reshape(1,3,256,256)
    preds = torch.sigmoid(model(image))
    torchvision.utils.save_image(preds, f"test/pred/{image_name}.png")