In [None]:
# import os
# os.kill(os.getpid(), 9)

In [1]:

import cv2
import torch
import random
from glob import glob
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
!tar -xf images.tar.gz
!tar -xf annotations.tar.gz

--2024-02-17 06:17:31--  http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
Resolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz [following]
--2024-02-17 06:17:31--  https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz [following]
--2024-02-17 06:17:32--  https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz
Resolving thor.robots.ox.ac.uk (thor.robots.ox.ac.uk)... 129.67.95.98
Connecting to thor.robots.ox.ac.uk (thor.robots.ox.ac.uk)|129.67.95.98|:443... connected.
HTTP request sent, awaiting response... 301 Moved Perman

In [4]:
input_dir = "images/"
target_dir = "annotations/trimaps/"

input_img_paths = sorted(glob(input_dir + "/*.jpg"))
target_paths = sorted(glob(target_dir + "/*.png"))

In [5]:
class SegmentDataset(Dataset):
    def __init__(self, image_dir, target_dir, img_size=(200, 200),
                 random_state=1337, train=True, transform=None):

        all_images_path = sorted(glob(image_dir + "/*.jpg"))
        all_targets_path = sorted(glob(target_dir + "/*.png"))

        random.Random(random_state).shuffle(all_images_path)
        random.Random(random_state).shuffle(all_targets_path)

        self.transform = transform
        self.img_size = img_size

        num_val_samples = 1000
        if train:
            self.images_path = all_images_path[num_val_samples:]
            self.targets_path = all_targets_path[num_val_samples:]
        else:
            self.images_path = all_images_path[:num_val_samples]
            self.targets_path = all_targets_path[:num_val_samples]

    def __len__(self):
        return len(self.images_path)

    def image_read(self, path):
        im = cv2.imread(path)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        im = cv2.resize(im, self.img_size)
        return im

    def __getitem__(self, idx):
        image_path = self.images_path[idx]
        target_path = self.targets_path[idx]
        image = self.image_read(image_path)
        target = self.image_read(target_path)[:, :, 0]
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)
        target = torch.from_numpy(target.astype("int64")) - 1  # Change to int64
        return image.float(), target

In [11]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.dConvs = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dConvs(x)

class DownConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.DownSampling = nn.Sequential(DoubleConv(in_channels, out_channels),
                                           nn.MaxPool2d(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.DownSampling(x)

class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.UpSampling = nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
                                         DoubleConv(out_channels, out_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.UpSampling(x)

class MyUnet(nn.Module):
    def __init__(self, input_ch, n_classes):
        super().__init__()
        self.inputs = [input_ch, 64, 128, 256, 512, 1024]
        self.outputs = [1024, 512, 256, 128, 64, n_classes]
        self.encode_blocks = nn.ModuleList([DownConv(self.inputs[i], self.inputs[i + 1]) for i in range(len(self.inputs) - 1)])
        self.decode_blocks = nn.ModuleList([UpConv(self.outputs[i + 1], self.outputs[i]) for i in range(len(self.outputs) - 1)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        CopyCrop = []
        for encode_block in self.encode_blocks:
            x = encode_block(x)
            CopyCrop.append(x.clone())

        for i, decode_block in enumerate(self.decode_blocks):
            x = decode_block(x)
            x = torch.cat((x, CopyCrop[-i - 1]), dim=1)

        return x

In [12]:
model = MyUnet(input_ch=3, n_classes=3)
num_epochs = 10
model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

train_dataset = SegmentDataset(input_dir, target_dir, train=True)
val_dataset = SegmentDataset(input_dir, target_dir, train=False)

train_dl = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=64)

In [13]:
model

MyUnet(
  (encode_blocks): ModuleList(
    (0): DownConv(
      (DownSampling): Sequential(
        (0): DoubleConv(
          (dConvs): Sequential(
            (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (3): ReLU(inplace=True)
          )
        )
        (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): DownConv(
      (DownSampling): Sequential(
        (0): DoubleConv(
          (dConvs): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
            (3): ReLU(inplace=True)
          )
        )
        (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (2): DownConv(
      (DownSampling): Sequential(
        (0): DoubleConv(
 

In [None]:
train_history = []
validation_history = []

for epoch in range(num_epochs):

    model.train()
    losses = []

    with tqdm(train_dl, leave=False) as bar:
        bar.set_description(f"[Epoch: {epoch + 1}/{num_epochs}]")

        for batch_idx, (data, target) in enumerate(bar):

            data = data.to(device)
            target = target.to(device)

            output = model(data)

            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            bar.set_postfix(avg_epoch_loss=f"{sum(losses)/len(losses):.4f}")
    train_history.append(sum(losses)/len(losses))

    model.eval()
    running_testloss = 0.0

    with torch.no_grad():
        for i, (test_data, test_label) in enumerate(val_dl):
            test_output = model(test_data)
            vloss = criterion(test_output, test_label)
            running_testloss += vloss.item()
        avg_vloss = running_testloss / (i + 1)
        validation_history.append(avg_vloss)
        if epoch % 10 == 9:
            print(f'Epoch [{epoch + 1}/{num_epochs}]: loss train: {sum(losses)/len(losses):.3f}, validation: {avg_vloss:.3f}')